Keep track of AWS Athena history using the binary search algorithm
AWS Athena
AWS Athena is an AWS service based on Presto, which is similar to Hive for querying files on Hadoop. Athena can query files directly on AWS S3, which makes it a perfect tool for data exploration and manipulation on your data lake.
However, it can become quite costly if some of your Athena queries run out of control, e.g. duplicated runs caused by data pipeline errors, or queries that scan large volumes of data without contributing much value. It is a good idea to keep track of your Athena query history. For example, you might want to find out which queries were the most costly or scanned the largest amounts of data over the last 7 days or the last month, then isolate those problematic queries and either tune their performance or remove them from your pipelines. To do so, we use the AWS Athena APIs, e.g. get_query_execution, batch_get_query_execution and get_paginator, to pull the metrics. But a huge number of queries run every day in my company, and it is impractical to walk through the whole history one execution at a time. Binary search comes in handy here: since the execution IDs come back newest first, we can binary-search for the cutoff date and only fetch the slice of history we care about.
AWS Athena API to pull history data into AWS S3
Below is the code snippet I wrote to pull the history data using binary search and save it into JSON files day by day:
# ...
def main():
#...
if cutoff_date:
max_query_ids_index = binary_search_index_of_query_submission_date(get_each_execution_with_client, query_execution_ids, cutoff_date)
query_execution_ids = query_execution_ids[:max_query_ids_index+1]
query_ids_chunks = get_query_ids_in_chunks(query_execution_ids, 50)
loop_start = datetime.datetime.now()
pool = ThreadPool(2)
final_list_of_dict_in_chunks = pool.map(
get_each_batch_execution_with_client, query_ids_chunks)
pool.close()
pool.join()
loop_end = datetime.datetime.now()
def hours_minutes(td):
return td.seconds//3600, (td.seconds//60) % 60
hours, minutes = hours_minutes(loop_end-loop_start)
logger.info(
f"Time used to get all data is {str(hours) } hours, {str(minutes)} minutes")
final_list_of_dict = list(
chain.from_iterable(final_list_of_dict_in_chunks))
if data.get("up_to_date"):
up_to_date = datetime.datetime.strptime(
data.get("up_to_date"), '%Y-%m-%d')
else:
up_to_date = now
final_list_of_dict = list(filter(lambda d: datetime.datetime.strptime(
d["SubmissionDateTime"], '%Y-%m-%d %H:%M:%S') >= datetime.datetime.strptime(cutoff_date, '%Y-%m-%d').replace(hour=0, minute=0, second=0, microsecond=0)+datetime.timedelta(1) and datetime.datetime.strptime(
d["SubmissionDateTime"], '%Y-%m-%d %H:%M:%S') < up_to_date.replace(hour=0, minute=0, second=0, microsecond=0), final_list_of_dict))
logger.info(
f"Final queries dates ranging from {final_list_of_dict[-1]['SubmissionDateTime']} to {final_list_of_dict[0]['SubmissionDateTime']}")
s3 = S3(data.get("dest_s3_bucket"), prefix=data.get("dest_s3_prefix"))
date_folder_on_s3 = up_to_date-datetime.timedelta(1)
if data.get("output_type") == "json":
out_file = write_to_json(final_list_of_dict, "athena_history")
s3.put(out_file, f"{str(date_folder_on_s3.year)}/{str(date_folder_on_s3.month).zfill(2)}/{str(date_folder_on_s3.day).zfill(2)}/athena_history_{str(date_folder_on_s3.year)}_{str(date_folder_on_s3.month).zfill(2)}_{str(date_folder_on_s3.day).zfill(2)}.json")
else:
out_file = write_to_csv(final_list_of_dict, "athena_history")
s3.put(out_file, f"{str(date_folder_on_s3.year)}/{str(date_folder_on_s3.month).zfill(2)}/{str(date_folder_on_s3.day).zfill(2)}/athena_history_{str(date_folder_on_s3.year)}_{str(date_folder_on_s3.month).zfill(2)}_{str(date_folder_on_s3.day).zfill(2)}.csv")
def write_to_csv(final_list_of_dict, outfile):
with open(outfile, 'w', newline='') as f:
w = csv.DictWriter(f, fieldnames=list(final_list_of_dict[0].keys()))
w.writeheader()
w.writerows(final_list_of_dict)
return outfile
def write_to_json(final_list_of_dict, outfile):
with open(outfile, 'a', newline='') as f:
for idx, dic in enumerate(final_list_of_dict):
json.dump(dic, f)
if idx != len(final_list_of_dict)-1:
f.write("\n")
return outfile
def get_query_ids_in_chunks(query_ids, chunk_size):
return [query_ids[i:i+chunk_size]
for i in range(0, len(query_ids), chunk_size)]
@AWSRetry.backoff(tries=3, delay=3, added_exceptions=["ThrottlingException"])
def get_each_batch_execution(client, ids_chunk):
resp = client.batch_get_query_execution(QueryExecutionIds=ids_chunk)
executions = resp["QueryExecutions"]
return [{
"QueryExecutionId": execution["QueryExecutionId"],
"Query": execution.get("Query"),
"StatementType": execution.get("StatementType"),
"ResultConfiguration": str(execution.get("ResultConfiguration")),
"QueryExecutionContext": str(execution.get("QueryExecutionContext")),
"State": execution["Status"].get("State"),
"StateChangeReason": execution["Status"].get("StateChangeReason"),
"SubmissionDateTime": execution["Status"]["SubmissionDateTime"].strftime('%Y-%m-%d %H:%M:%S'),
"CompletionDateTime": execution["Status"].get("CompletionDateTime").strftime('%Y-%m-%d %H:%M:%S'),
"EngineExecutionTimeInMillis": execution.get("Statistics").get("EngineExecutionTimeInMillis"),
"DataScannedInBytes": execution.get("Statistics").get("DataScannedInBytes"),
"WorkGroup": execution.get("WorkGroup")
} for execution in executions]
@AWSRetry.backoff(tries=3, delay=3, added_exceptions=["ThrottlingException"])
def get_each_execution(client, id):
resp = client.get_query_execution(QueryExecutionId=id)
return {
"QueryExecutionId": resp["QueryExecution"]["QueryExecutionId"],
"Query": resp["QueryExecution"].get("Query"),
"StatementType": resp["QueryExecution"].get("StatementType"),
"ResultConfiguration": str(resp["QueryExecution"].get("ResultConfiguration")),
"QueryExecutionContext": str(resp["QueryExecution"].get("QueryExecutionContext")),
"State": resp["QueryExecution"]["Status"].get("State"),
"StateChangeReason": resp["QueryExecution"]["Status"].get("StateChangeReason"),
"SubmissionDateTime": resp["QueryExecution"]["Status"]["SubmissionDateTime"],
"SubmissionDateTimeString": resp["QueryExecution"]["Status"]["SubmissionDateTime"].strftime('%Y-%m-%d %H:%M:%S'),
"CompletionDateTime": resp["QueryExecution"]["Status"].get("CompletionDateTime").strftime('%Y-%m-%d %H:%M:%S'),
"EngineExecutionTimeInMillis": resp["QueryExecution"].get("Statistics").get("EngineExecutionTimeInMillis"),
"DataScannedInBytes": resp["QueryExecution"].get("Statistics").get("DataScannedInBytes"),
"WorkGroup": resp["QueryExecution"].get("WorkGroup")
}
def binary_search_index_of_query_submission_date(get_each_execution_with_client, query_ids, submission_date):
left = 0
right = len(query_ids)-1
logger.info(f"Searching {submission_date}")
submission_date_parsed = datetime.datetime.strptime(
submission_date, "%Y-%m-%d").replace(tzinfo=None).replace(hour=0, minute=0, second=0, microsecond=0)
while left <= right:
midpoint = left + (right - left)//2
midpoint_result_dict = get_each_execution_with_client(
query_ids[midpoint])
midpoint_date_parsed = midpoint_result_dict["SubmissionDateTime"].replace(
tzinfo=None).replace(hour=0, minute=0, second=0, microsecond=0)
if midpoint_date_parsed == submission_date_parsed:
return midpoint
else:
# note the query id list is sorted in reversed order so if search date is less than midpoint date, it should be on right hand side
if submission_date_parsed < midpoint_date_parsed:
left = midpoint+1
else:
# if search date > midpoint date, it should be on left hand side of the list
right = midpoint-1
logger.info(
f"Date you specified not found. Returning the index of nearest date {get_each_execution_with_client(query_ids[left-1])['SubmissionDateTime']}")
return left-1
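# A toy sanity check of the reversed-order search above (not part of the original
# script): indices stand in for query execution ids and a stub fetcher replaces the
# AWS call. The dates are sorted newest first, exactly like the ids returned by
# list_query_executions.
fake_history = [
    {"SubmissionDateTime": datetime.datetime(2020, 7, day)} for day in (6, 5, 4, 3, 2, 1)
]
fake_ids = list(range(len(fake_history)))
cutoff_index = binary_search_index_of_query_submission_date(
    lambda i: fake_history[i], fake_ids, "2020-07-04")
# cutoff_index == 2, so fake_ids[:cutoff_index + 1] keeps July 6th, 5th and 4th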
@AWSRetry.backoff(tries=3, delay=60, added_exceptions=["ThrottlingException"])
def get_all_execution_ids(client):
next_token = None
no_of_page = 0
query_execution_ids = []
@AWSRetry.backoff(tries=3, delay=60, added_exceptions=["ThrottlingException"])
def iterate_paginator(response_iterator, query_execution_ids, no_of_page):
for page in response_iterator:
query_execution_ids.extend(page["QueryExecutionIds"])
no_of_page = no_of_page + 1
return page, no_of_page
while True:
paginator = client.get_paginator('list_query_executions')
response_iterator = paginator.paginate(PaginationConfig={
'MaxItems': 5000,
'PageSize': 50,
'StartingToken': next_token})
page, no_of_page = iterate_paginator(
response_iterator, query_execution_ids, no_of_page)
try:
next_token = page["NextToken"]
except KeyError:
break
logger.info(f"Processed pages {str(no_of_page)} to get execution ids")
return query_execution_ids
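The snippet above references a few names it never defines: the imports, the boto3 client, logger, the data configuration dict, the S3 and AWSRetry helpers, and the *_with_client functions. A rough sketch of that glue, assuming the *_with_client helpers are just the functions above with the client pre-bound (the S3 wrapper, the AWSRetry decorator and the configuration values are placeholders, not the original implementation):
import csv
import datetime
import json
import logging
from functools import partial
from itertools import chain
from multiprocessing.pool import ThreadPool

import boto3

# Hypothetical in-house helpers: a thin S3 upload wrapper and a retry/backoff decorator
from my_utils import S3, AWSRetry

logger = logging.getLogger(__name__)
client = boto3.client("athena")
now = datetime.datetime.now()

# Job configuration read via data.get(...) in main(); values are placeholders
data = {
    "up_to_date": "2020-07-05",
    "dest_s3_bucket": "my-bucket",
    "dest_s3_prefix": "athena_history",
    "output_type": "json",
}
cutoff_date = "2020-07-01"  # used by the binary search to trim older history

# Bind the client once so the helpers can be passed to ThreadPool.map and the binary search
get_each_execution_with_client = partial(get_each_execution, client)
get_each_batch_execution_with_client = partial(get_each_batch_execution, client)

query_execution_ids = get_all_execution_ids(client)  # newest first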
Sample result
{"QueryExecutionId": "f37eee2d-22cb-457d-8dd5-6243ba316dde", "Query": "select count(*) from a_schema.table", "StatementType": "DML", "ResultConfiguration": "{'OutputLocation': 's3://bucket/f37eee2d-22cb-457d-8dd5-6243ba316dde.csv'}", "QueryExecutionContext": "{'Database': 'schema'}", "State": "SUCCEEDED", "StateChangeReason": null, "SubmissionDateTime": "2020-07-04 21:06:27", "CompletionDateTime": "2020-07-04 21:06:31", "EngineExecutionTimeInMillis": 3802, "DataScannedInBytes": 1634006466, "WorkGroup": "primary"}
{"QueryExecutionId": "029a8baf-b9e3-4f2d-ac22-3e4de976ac50", "Query": "show partitions a_schema.table", "StatementType": "UTILITY", "ResultConfiguration": "{'OutputLocation': 's3://bucket/029a8baf-b9e3-4f2d-ac22-323de976ac50.txt'}", "QueryExecutionContext": "{'Database': 'schema'}", "State": "SUCCEEDED", "StateChangeReason": null, "SubmissionDateTime": "2020-07-04 23:59:23", "CompletionDateTime": "2020-07-04 23:59:27", "EngineExecutionTimeInMillis": 4583, "DataScannedInBytes": 0, "WorkGroup": "primary"}
Set up Athena views on top of the downloaded files
After we download the data, we can set up a table and a few views (an example query against them follows the view definitions):
CREATE EXTERNAL TABLE `athena_history_json`(
`json` string)
ROW FORMAT DELIMITED
FIELDS TERMINATED BY '\t'
STORED AS INPUTFORMAT
'org.apache.hadoop.mapred.TextInputFormat'
OUTPUTFORMAT
'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
LOCATION
's3://bucket/athena_history'
TBLPROPERTIES (
'has_encrypted_data'='false');
CREATE OR REPLACE VIEW athena_history AS
SELECT
"json_extract_scalar"("q"."json", '$.QueryExecutionId') "QueryExecutionId"
, "json_extract_scalar"("q"."json", '$.Query') "Query"
, "substr"("json_extract_scalar"("q"."json", '$.Query'), 1, 50) "QueryShort"
, (CASE WHEN ("substr"("json_extract_scalar"("q"."json", '$.Query'), 1, 2) = '--') THEN "trim"("split"("json_extract_scalar"("q"."json", '$.Query'), '--')[2]) ELSE '' END) "QueryTag"
, "json_extract_scalar"("q"."json", '$.StatementType') "StatementType"
, "json_extract_scalar"("q"."json", '$.ResultConfiguration.OutputLocation') "OutputLocation"
, "json_extract_scalar"("q"."json", '$.ResultConfiguration.QueryExecutionContext.Database') "Database"
, "json_extract_scalar"("q"."json", '$.State') "State"
, "json_extract_scalar"("q"."json", '$.StateChangeReason') "StateChangeReason"
, "json_extract_scalar"("q"."json", '$.SubmissionDateTime') "SubmissionDateTime"
, "json_extract_scalar"("q"."json", '$.CompletionDateTime') "CompletionDateTime"
, CAST("json_extract_scalar"("q"."json", '$.EngineExecutionTimeInMillis') AS integer) "EngineExecutionTimeInMillis"
, CAST("json_extract_scalar"("q"."json", '$.DataScannedInBytes') AS bigint) "DataScannedInBytes"
, "json_extract_scalar"("q"."json", '$.WorkGroup') "WorkGroup"
, CAST(((CAST("json_extract_scalar"("q"."json", '$.DataScannedInBytes') AS bigint) / "power"(2, 40)) * 5) AS decimal(30,20)) "QueryCost"
FROM
a_schema.athena_history_json q ;
CREATE OR REPLACE VIEW athena_history_daily_cost AS
SELECT
CAST(CAST("submissiondatetime" AS timestamp) AS date) "query_date"
, "sum"("querycost") "query_cost"
FROM
a_schema.athena_history
GROUP BY CAST(CAST("submissiondatetime" AS timestamp) AS date)
ORDER BY CAST(CAST("submissiondatetime" AS timestamp) AS date) DESC ;
CREATE OR REPLACE VIEW athena_history_query_cost AS
WITH
qc AS (
SELECT
(CASE WHEN ("querytag" <> '') THEN (CASE WHEN ("strpos"("querytag", ' date') > 0) THEN "split"("querytag", ' date')[1] ELSE "querytag" END) ELSE "queryshort" END) "query_type"
, "querycost" "query_cost"
FROM
a_schema.athena_history
WHERE (CAST("submissiondatetime" AS timestamp) >= "date_add"('day', -30, current_timestamp))
)
SELECT
"query_type"
, "sum"("query_cost") "query_cost"
FROM
qc
GROUP BY "query_type"
ORDER BY "query_cost" DESC ;