Keep track of AWS Athena history using Binary Search Algorithm


AWS Athena


AWS Athena is an AWS service built on Presto. Much as Hive is used to query files stored on Hadoop, Athena queries files directly on AWS S3, which makes it a great tool for data exploration and manipulation on your data lake.


However, it can be quite costly if some of your Athena queries run out of control, e.g. duplicated runs caused by data pipeline errors, or queries that scan large volumes of data without contributing much value. It is a good idea to keep track of your Athena query history. For example, you might want to find out which queries were the most costly, or scanned the most data, over the last 7 days or the last month. You can then isolate those problematic queries and either tune their performance or remove them from your pipelines. To do so, we use AWS Athena APIs such as get_query_execution, batch_get_query_execution and get_paginator to get the metrics. But lots of queries run each day in my company, so it is impractical to iterate through every one of them just to find where a given date starts. Because the list of query IDs is sorted by submission time (newest first), binary search comes in handy: it locates the cutoff date in the whole history with only a handful of API calls, so that only the smaller period after it has to be fetched and filtered.


AWS Athena API to pull data in AWS S3


Below is the code snippet I wrote to pull the history data using binary search and save it into JSON files, one file per day:


# ...
def main():
    """Export one window of Athena query history to S3 as JSON or CSV.

    Relies on names prepared in the elided section below (``data``, ``now``,
    ``cutoff_date``, ``query_execution_ids`` and the two
    ``get_each_*_with_client`` callables).

    Steps:
      1. If ``cutoff_date`` is set, binary-search the (newest-first) id list
         and drop every id older than that date.
      2. Fetch execution details in chunks of 50 ids on a 2-thread pool.
      3. Keep records submitted after the cutoff day and before midnight of
         ``up_to_date`` (``data["up_to_date"]`` or ``now``).
      4. Write the records to a local file and upload it under a
         ``YYYY/MM/DD/`` key for the day before ``up_to_date``.
    """
    #...

    if cutoff_date:
        max_query_ids_index = binary_search_index_of_query_submission_date(
            get_each_execution_with_client, query_execution_ids, cutoff_date)
        query_execution_ids = query_execution_ids[:max_query_ids_index+1]

    query_ids_chunks = get_query_ids_in_chunks(query_execution_ids, 50)
    loop_start = datetime.datetime.now()
    pool = ThreadPool(2)
    try:
        final_list_of_dict_in_chunks = pool.map(
            get_each_batch_execution_with_client, query_ids_chunks)
    finally:
        # Release the worker threads even when a batch call raises.
        pool.close()
        pool.join()
    loop_end = datetime.datetime.now()

    def hours_minutes(td):
        # total_seconds() keeps whole days; td.seconds alone would drop them.
        elapsed = int(td.total_seconds())
        return elapsed // 3600, (elapsed // 60) % 60

    hours, minutes = hours_minutes(loop_end-loop_start)
    logger.info(
        f"Time used to get all data is {hours} hours, {minutes} minutes")
    final_list_of_dict = list(
        chain.from_iterable(final_list_of_dict_in_chunks))

    if data.get("up_to_date"):
        up_to_date = datetime.datetime.strptime(
            data.get("up_to_date"), '%Y-%m-%d')
    else:
        up_to_date = now

    # Window bounds are computed once instead of once per record.
    upper_bound = up_to_date.replace(hour=0, minute=0, second=0, microsecond=0)
    lower_bound = None
    if cutoff_date:
        # The cutoff day itself is excluded: start at the following midnight.
        lower_bound = datetime.datetime.strptime(
            cutoff_date, '%Y-%m-%d') + datetime.timedelta(1)

    def in_window(record):
        submitted = datetime.datetime.strptime(
            record["SubmissionDateTime"], '%Y-%m-%d %H:%M:%S')
        if lower_bound is not None and submitted < lower_bound:
            return False
        return submitted < upper_bound

    final_list_of_dict = [
        record for record in final_list_of_dict if in_window(record)]

    if not final_list_of_dict:
        # Nothing to index for the log line or to write; skip the upload.
        logger.info(
            "No queries found in the requested date range; nothing to upload")
        return

    logger.info(
        f"Final queries dates ranging from {final_list_of_dict[-1]['SubmissionDateTime']} to {final_list_of_dict[0]['SubmissionDateTime']}")

    s3 = S3(data.get("dest_s3_bucket"), prefix=data.get("dest_s3_prefix"))

    # Files are foldered by the last full day covered by the window.
    date_folder_on_s3 = up_to_date-datetime.timedelta(1)
    year = date_folder_on_s3.year
    month = f"{date_folder_on_s3.month:02d}"
    day = f"{date_folder_on_s3.day:02d}"

    if data.get("output_type") == "json":
        extension = "json"
        out_file = write_to_json(final_list_of_dict, "athena_history")
    else:
        extension = "csv"
        out_file = write_to_csv(final_list_of_dict, "athena_history")

    s3.put(
        out_file,
        f"{year}/{month}/{day}/athena_history_{year}_{month}_{day}.{extension}")


def write_to_csv(final_list_of_dict, outfile):
    """Write a list of flat dicts to ``outfile`` as CSV and return the path.

    The header row is taken from the keys of the first record. An empty (or
    ``None``) list produces an empty file instead of raising ``IndexError``.

    Args:
        final_list_of_dict: records to write; all are expected to share the
            keys of the first one.
        outfile: path of the file to create or overwrite.

    Returns:
        ``outfile``, unchanged.
    """
    fieldnames = list(final_list_of_dict[0].keys()) if final_list_of_dict else []
    # newline='' lets the csv module control line endings itself.
    with open(outfile, 'w', newline='') as f:
        if fieldnames:
            w = csv.DictWriter(f, fieldnames=fieldnames)
            w.writeheader()
            w.writerows(final_list_of_dict)
    return outfile


def write_to_json(final_list_of_dict, outfile):
    """Write records to ``outfile`` as JSON Lines and return the path.

    Each record is serialised with ``json.dumps`` on its own line; there is
    no trailing newline after the last record.

    The file is truncated first (mode ``'w'``, matching ``write_to_csv``).
    Appending would glue a new run onto leftovers of a previous one with no
    separating newline, producing a corrupt line and duplicated history.

    Args:
        final_list_of_dict: JSON-serialisable records.
        outfile: path of the file to create or overwrite.

    Returns:
        ``outfile``, unchanged.
    """
    with open(outfile, 'w', newline='') as f:
        f.write("\n".join(json.dumps(record) for record in final_list_of_dict))
    return outfile


def get_query_ids_in_chunks(query_ids, chunk_size):
    """Split ``query_ids`` into consecutive slices of at most ``chunk_size``.

    The final slice holds whatever is left over and may be shorter.
    """
    chunks = []
    for start in range(0, len(query_ids), chunk_size):
        stop = start + chunk_size
        chunks.append(query_ids[start:stop])
    return chunks


@AWSRetry.backoff(tries=3, delay=3, added_exceptions=["ThrottlingException"])
def get_each_batch_execution(client, ids_chunk):
    """Fetch a batch of query executions and flatten each into a plain dict.

    Args:
        client: Athena client exposing ``batch_get_query_execution``.
        ids_chunk: list of query execution ids to look up in one call.

    Returns:
        One dict per entry of the response's ``QueryExecutions`` list, with
        timestamps rendered as ``'%Y-%m-%d %H:%M:%S'`` strings.
        ``CompletionDateTime`` and the two statistics are ``None`` when the
        service did not report them (e.g. a query that has not finished),
        rather than raising ``AttributeError``.
    """
    timestamp_format = '%Y-%m-%d %H:%M:%S'
    resp = client.batch_get_query_execution(QueryExecutionIds=ids_chunk)

    def flatten(execution):
        status = execution["Status"]
        completed_at = status.get("CompletionDateTime")
        statistics = execution.get("Statistics") or {}
        return {
            "QueryExecutionId": execution["QueryExecutionId"],
            "Query": execution.get("Query"),
            "StatementType": execution.get("StatementType"),
            # Nested dicts are stringified so every value stays a flat scalar.
            "ResultConfiguration": str(execution.get("ResultConfiguration")),
            "QueryExecutionContext": str(execution.get("QueryExecutionContext")),
            "State": status.get("State"),
            "StateChangeReason": status.get("StateChangeReason"),
            "SubmissionDateTime": status["SubmissionDateTime"].strftime(timestamp_format),
            "CompletionDateTime": completed_at.strftime(timestamp_format) if completed_at is not None else None,
            "EngineExecutionTimeInMillis": statistics.get("EngineExecutionTimeInMillis"),
            "DataScannedInBytes": statistics.get("DataScannedInBytes"),
            "WorkGroup": execution.get("WorkGroup")
        }

    return [flatten(execution) for execution in resp["QueryExecutions"]]


@AWSRetry.backoff(tries=3, delay=3, added_exceptions=["ThrottlingException"])
def get_each_execution(client, id):
    """Fetch a single query execution and flatten it into a plain dict.

    Args:
        client: Athena client exposing ``get_query_execution``.
        id: query execution id to look up. (The name shadows the builtin but
            is kept so keyword callers keep working.)

    Returns:
        A dict of the execution's fields. ``SubmissionDateTime`` is the raw
        value from the response (the binary search compares it as a
        datetime); ``SubmissionDateTimeString`` and ``CompletionDateTime``
        are ``'%Y-%m-%d %H:%M:%S'`` strings. ``CompletionDateTime`` and the
        two statistics are ``None`` when the service did not report them,
        rather than raising ``AttributeError``.
    """
    timestamp_format = '%Y-%m-%d %H:%M:%S'
    execution = client.get_query_execution(QueryExecutionId=id)["QueryExecution"]
    status = execution["Status"]
    submitted_at = status["SubmissionDateTime"]
    completed_at = status.get("CompletionDateTime")
    statistics = execution.get("Statistics") or {}
    return {
        "QueryExecutionId": execution["QueryExecutionId"],
        "Query": execution.get("Query"),
        "StatementType": execution.get("StatementType"),
        # Nested dicts are stringified so every value stays a flat scalar.
        "ResultConfiguration": str(execution.get("ResultConfiguration")),
        "QueryExecutionContext": str(execution.get("QueryExecutionContext")),
        "State": status.get("State"),
        "StateChangeReason": status.get("StateChangeReason"),
        "SubmissionDateTime": submitted_at,
        "SubmissionDateTimeString": submitted_at.strftime(timestamp_format),
        "CompletionDateTime": completed_at.strftime(timestamp_format) if completed_at is not None else None,
        "EngineExecutionTimeInMillis": statistics.get("EngineExecutionTimeInMillis"),
        "DataScannedInBytes": statistics.get("DataScannedInBytes"),
        "WorkGroup": execution.get("WorkGroup")
    }


def binary_search_index_of_query_submission_date(get_each_execution_with_client, query_ids, submission_date):
    """Binary-search a newest-first id list for a query submitted on a date.

    Only the calendar day of each probed query's ``SubmissionDateTime`` is
    compared; its time of day and tzinfo are ignored.

    Args:
        get_each_execution_with_client: callable taking a query id and
            returning a dict with a datetime under ``"SubmissionDateTime"``.
        query_ids: query execution ids sorted by submission time, newest
            first.
        submission_date: target day as a ``'%Y-%m-%d'`` string.

    Returns:
        The index of *a* query submitted on that day (not necessarily the
        first or last one). If no query matches, the index of the oldest
        query that is still newer than the target day, which is ``-1`` when
        there is none (including an empty ``query_ids``). Callers slicing
        with ``[:index + 1]`` therefore get exactly the newer queries.
    """
    left = 0
    right = len(query_ids) - 1

    logger.info(f"Searching {submission_date}")

    target_day = datetime.datetime.strptime(submission_date, "%Y-%m-%d").date()

    # Submission time of the probe that last moved `left`, i.e. of index
    # left - 1; remembered so the final log line needs no extra API call.
    nearest_newer_submission = None

    while left <= right:
        midpoint = left + (right - left) // 2
        midpoint_submission = get_each_execution_with_client(
            query_ids[midpoint])["SubmissionDateTime"]
        midpoint_day = midpoint_submission.date()

        if midpoint_day == target_day:
            return midpoint

        if target_day < midpoint_day:
            # The list is newest first, so older days live to the right.
            left = midpoint + 1
            nearest_newer_submission = midpoint_submission
        else:
            # Target is newer than the midpoint: continue in the left half.
            right = midpoint - 1

    if left == 0:
        # Nothing is newer than the target; index -1 must not be
        # dereferenced here because it would wrap to the oldest query.
        logger.info(
            "Date you specified not found. No query is newer than it, returning -1")
    else:
        logger.info(
            f"Date you specified not found. Returning the index of nearest date {nearest_newer_submission}")
    return left - 1


@AWSRetry.backoff(tries=3, delay=60, added_exceptions=["ThrottlingException"])
def get_all_execution_ids(client):
    """Collect every query execution id returned by ``list_query_executions``.

    The listing is read in rounds of at most 5000 ids (pages of 50). After
    each round the ``NextToken`` of the last page is used as the starting
    token of the next round; the loop stops when the last page has none.

    Args:
        client: Athena client exposing ``get_paginator``.

    Returns:
        List of query execution ids in the order the service returned them.
    """
    next_token = None
    no_of_page = 0
    query_execution_ids = []
    paginator = client.get_paginator('list_query_executions')

    @AWSRetry.backoff(tries=3, delay=60, added_exceptions=["ThrottlingException"])
    def iterate_paginator(response_iterator):
        # Accumulate into locals only: if the decorator re-runs this helper
        # after a failure part-way through (assumed retry behaviour of
        # AWSRetry — confirm), the ids of the aborted attempt are discarded
        # instead of being appended to the shared result twice.
        round_ids = []
        round_pages = 0
        last_page = None

        for page in response_iterator:
            round_ids.extend(page["QueryExecutionIds"])
            round_pages += 1
            last_page = page

        return round_ids, last_page, round_pages

    while True:
        response_iterator = paginator.paginate(PaginationConfig={
            'MaxItems': 5000,
            'PageSize': 50,
            'StartingToken': next_token})

        round_ids, page, round_pages = iterate_paginator(response_iterator)
        query_execution_ids.extend(round_ids)
        no_of_page += round_pages

        # No page at all, or a last page without a token, means we are done.
        if page is None or "NextToken" not in page:
            break
        next_token = page["NextToken"]

        logger.info(f"Processed pages {str(no_of_page)} to get execution ids")

    return query_execution_ids

Sample result


{"QueryExecutionId": "f37eee2d-22cb-457d-8dd5-6243ba316dde", "Query": "select count(*) from a_schema.table", "StatementType": "DML", "ResultConfiguration": "{'OutputLocation': 's3://bucket/f37eee2d-22cb-457d-8dd5-6243ba316dde.csv'}", "QueryExecutionContext": "{'Database': 'schema'}", "State": "SUCCEEDED", "StateChangeReason": null, "SubmissionDateTime": "2020-07-04 21:06:27", "CompletionDateTime": "2020-07-04 21:06:31", "EngineExecutionTimeInMillis": 3802, "DataScannedInBytes": 1634006466, "WorkGroup": "primary"}
{"QueryExecutionId": "029a8baf-b9e3-4f2d-ac22-3e4de976ac50", "Query": "show partitions a_schema.table", "StatementType": "UTILITY", "ResultConfiguration": "{'OutputLocation': 's3://bucket/029a8baf-b9e3-4f2d-ac22-323de976ac50.txt'}", "QueryExecutionContext": "{'Database': 'schema'}", "State": "SUCCEEDED", "StateChangeReason": null, "SubmissionDateTime": "2020-07-04 23:59:23", "CompletionDateTime": "2020-07-04 23:59:27", "EngineExecutionTimeInMillis": 4583, "DataScannedInBytes": 0, "WorkGroup": "primary"}

Set up Athena views on top of the downloaded files


After we download the data, we can set up a table and a few views:


-- Raw landing table over the uploaded history files: each line of a file is
-- exposed as a single string column named json, which the athena_history
-- view then parses with json_extract_scalar.
-- NOTE(review): the tab field delimiter presumably never occurs inside a
-- record, so the whole line lands in the one column — confirm for your data.
CREATE EXTERNAL TABLE `athena_history_json`(
  `json` string)
ROW FORMAT DELIMITED
  FIELDS TERMINATED BY '\t'
STORED AS INPUTFORMAT
  'org.apache.hadoop.mapred.TextInputFormat'
OUTPUTFORMAT
  'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
LOCATION
  's3://bucket/athena_history'
TBLPROPERTIES (
  'has_encrypted_data'='false');


-- Flattens each raw JSON line of athena_history_json into typed columns.
--
-- ResultConfiguration and QueryExecutionContext are written by the exporter
-- as Python-repr strings (single-quoted keys and values), i.e. they are JSON
-- string scalars, not nested objects. A dotted path such as
-- '$.ResultConfiguration.OutputLocation' therefore yields NULL. They are
-- first pulled out as text and their single quotes swapped for double quotes,
-- and the result is parsed as JSON. Values that themselves contain quotes
-- will not survive the swap and come out as NULL.
-- Database is read from QueryExecutionContext (it is not part of
-- ResultConfiguration).
--
-- QueryTag: for queries starting with '--', the text after that first '--'.
-- QueryCost: 5 per 2^40 bytes scanned.
CREATE OR REPLACE VIEW athena_history AS
WITH parsed AS (
  SELECT
    q."json" AS doc
  , json_extract_scalar(q."json", '$.Query') AS query_text
  , replace(json_extract_scalar(q."json", '$.ResultConfiguration'), '''', '"') AS result_configuration
  , replace(json_extract_scalar(q."json", '$.QueryExecutionContext'), '''', '"') AS query_execution_context
  , CAST(json_extract_scalar(q."json", '$.DataScannedInBytes') AS bigint) AS data_scanned_in_bytes
  FROM
    a_schema.athena_history_json q
)
SELECT
  json_extract_scalar(doc, '$.QueryExecutionId') "QueryExecutionId"
, query_text "Query"
, substr(query_text, 1, 50) "QueryShort"
, (CASE WHEN (substr(query_text, 1, 2) = '--') THEN trim(split(query_text, '--')[2]) ELSE '' END) "QueryTag"
, json_extract_scalar(doc, '$.StatementType') "StatementType"
, json_extract_scalar(result_configuration, '$.OutputLocation') "OutputLocation"
, json_extract_scalar(query_execution_context, '$.Database') "Database"
, json_extract_scalar(doc, '$.State') "State"
, json_extract_scalar(doc, '$.StateChangeReason') "StateChangeReason"
, json_extract_scalar(doc, '$.SubmissionDateTime') "SubmissionDateTime"
, json_extract_scalar(doc, '$.CompletionDateTime') "CompletionDateTime"
, CAST(json_extract_scalar(doc, '$.EngineExecutionTimeInMillis') AS integer) "EngineExecutionTimeInMillis"
, data_scanned_in_bytes "DataScannedInBytes"
, json_extract_scalar(doc, '$.WorkGroup') "WorkGroup"
, CAST(((data_scanned_in_bytes / power(2, 40)) * 5) AS decimal(30,20)) "QueryCost"
FROM
  parsed ;


-- Sum of querycost per submission day, newest day first.
CREATE OR REPLACE VIEW athena_history_daily_cost AS
SELECT
  CAST(CAST(submissiondatetime AS timestamp) AS date) AS query_date
, sum(querycost) AS query_cost
FROM
  a_schema.athena_history
-- Ordinal 1 is the query_date expression above.
GROUP BY 1
ORDER BY 1 DESC ;



-- Cost per query type over the last 30 days, most expensive first.
-- query_type is the query tag (cut at its first ' date' if it has one) when
-- the query carries a tag, otherwise the shortened query text.
CREATE OR REPLACE VIEW athena_history_query_cost AS
SELECT
  query_type
, sum(query_cost) AS query_cost
FROM (
  SELECT
    CASE
      WHEN querytag <> '' AND strpos(querytag, ' date') > 0 THEN split(querytag, ' date')[1]
      WHEN querytag <> '' THEN querytag
      ELSE queryshort
    END AS query_type
  , querycost AS query_cost
  FROM
    a_schema.athena_history
  WHERE CAST(submissiondatetime AS timestamp) >= date_add('day', -30, current_timestamp)
) recent
GROUP BY query_type
ORDER BY query_cost DESC ;

Keep track of AWS Athena history using Binary Search Algorithm
arrow_back

Previous

Depth First Search in Practice

Next

Set up Laravel Homestead Development environment
arrow_forward