AWS Utilities
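Four small AWS helper scripts: run templated Athena queries, run CloudWatch Logs Insights queries, fetch or tail a CloudWatch log stream, and list ECS Fargate tasks (with an option to pull each task's logs).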
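The first script runs a parameterised Athena query (read from a local .sql template) once per input row, polls each execution until it completes, and writes the combined results to a TSV file.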
#!/usr/bin/env python3
import argparse
import json
import time
from dataclasses import dataclass
from datetime import datetime
from typing import ClassVar, Tuple

import boto3
from pygments import highlight
from pygments.lexers import JsonLexer
from pygments.formatters import TerminalTrueColorFormatter as Formatter
from pygments.styles import get_style_by_name

def ppd(d, indent=None):
    'Pretty-print a dict as syntax-highlighted JSON.'
    print(highlight(json.dumps(d, indent=indent, default=str), JsonLexer(), Formatter(style=get_style_by_name('material'))).strip())

def ppj(j, indent=2):
    'Pretty-print a JSON string.'
    ppd(json.loads(j), indent=indent)

@dataclass
class AthenaQuery:
    database: str
    work_group: str
    results_bucket: str
    query_key: str
    catalog: str = 'AwsDataCatalog'
    s3_acl_option: ClassVar[str] = 'BUCKET_OWNER_FULL_CONTROL'

    def __post_init__(self):
        self.client = boto3.client('athena')

    @property
    def output_location(self) -> str:
        return f's3://{self.results_bucket}/customerapi/au-identity-history/query_results/'

    def execute(self, query: str) -> str:
        'Start a query execution and return its ID.'
        return self.client.start_query_execution(
            QueryString=query,
            QueryExecutionContext={
                'Database': self.database,
                'Catalog': self.catalog,
            },
            ResultConfiguration={
                'OutputLocation': self.output_location,
                'AclConfiguration': {
                    'S3AclOption': self.s3_acl_option,
                }
            },
            WorkGroup=self.work_group,
        )['QueryExecutionId']

    def query_string(self, kwargs: dict) -> str:
        'Read the local SQL template and interpolate the given kwargs.'
        with open(f'{self.query_key}.sql') as istream:
            return istream.read().format(**kwargs)

    def poll_query(self, query_execution_id: str) -> None:
        'Poll every 10 seconds until the query leaves the QUEUED/RUNNING states.'
        while True:
            response = self.client.get_query_execution(QueryExecutionId=query_execution_id)
            ppd({
                'timestamp': datetime.now().astimezone().isoformat(),
                'QueryExecutionId': query_execution_id,
                'State': response['QueryExecution']['Status']['State'],
            })
            if response['QueryExecution']['Status']['State'] not in ['QUEUED', 'RUNNING']:
                break
            time.sleep(10)

    def get_query_results(self, query_execution_id: str) -> dict:
        return self.client.get_query_results(QueryExecutionId=query_execution_id)

    @staticmethod
    def _query_results_to_table(row_id: str, query_results: dict) -> list:
        'Flatten the Athena response into rows, prefixing each data row with the given row_id.'
        table = []
        for i, raw_row in enumerate(query_results['ResultSet']['Rows']):
            if i == 0:
                row = ['__row_id__']  # the first row holds the column headers
            else:
                row = [row_id]
            for cell in raw_row['Data']:
                row.append(cell['VarCharValue'])
            table.append(row)
        return table

    def query(self, row_id: str, query_kwargs: dict) -> Tuple[list, dict]:
        'Execute the templated query, wait for it, and return (table, raw_results).'
        query_execution_id = self.execute(self.query_string(query_kwargs))
        self.poll_query(query_execution_id)
        results = self.get_query_results(query_execution_id)
        table = AthenaQuery._query_results_to_table(row_id, results)
        return table, results

def parse_args():
    parser = argparse.ArgumentParser(description='Run Athena Queries')
    parser.add_argument('--database', type=str, required=True, help='Athena Database')
    parser.add_argument('--work-group', type=str, required=True, help='Athena Work Group')
    parser.add_argument('--results-bucket', type=str, required=True, help='S3 Bucket for Athena Query Results')
    parser.add_argument('--query-key', type=str, required=True, help='Key for the Athena Query (a local <key>.sql template)')
    parser.add_argument('--queries', type=json.loads, required=True, help='JSON list of per-query kwargs, each with a __row_id__')
    parser.add_argument('--ofpath', type=str, required=True, help='Output File')
    return parser.parse_args().__dict__

def run():
    args = parse_args()
    q = AthenaQuery(
        args['database'],
        args['work_group'],
        args['results_bucket'],
        args['query_key'],
    )
    ppd({'timestamp': datetime.now().astimezone().isoformat(), 'msg': 'Starting Query'})
    combined_table = []
    for i, query_kwargs in enumerate(args['queries']):
        table, results = q.query(
            row_id=query_kwargs['__row_id__'],
            query_kwargs={
                'new_dataset_id': query_kwargs['new_dataset_id'],
                'old_dataset_id': query_kwargs['old_dataset_id'],
            }
        )
        ppd({
            'timestamp': datetime.now().astimezone().isoformat(),
            'msg': 'received query results',
            'query_kwargs': query_kwargs,
            'results': results,
        })
        if i == 0:
            combined_table.append(table[0])  # keep the header row only once
        combined_table.extend(table[1:])
    for row in combined_table:
        ppd(row)
    with open(args['ofpath'], 'w') as ostream:
        for row in combined_table:
            print('\t'.join(map(str, row)), file=ostream)

if __name__ == '__main__':
    run()
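A minimal sketch of using AthenaQuery programmatically (all values hypothetical); it assumes a compare_datasets.sql template in the working directory with {new_dataset_id} and {old_dataset_id} placeholders, and note that output_location hardcodes a key prefix under the results bucket. On the command line, --queries takes the same kwargs as a JSON list, e.g. '[{"__row_id__": "run-1", "new_dataset_id": "abc", "old_dataset_id": "def"}]'.

# hypothetical values throughout; requires AWS credentials and a
# compare_datasets.sql template next to the script
q = AthenaQuery(
    database='my_db',
    work_group='primary',
    results_bucket='my-results-bucket',
    query_key='compare_datasets',
)
table, results = q.query(
    row_id='run-1',
    query_kwargs={'new_dataset_id': 'abc', 'old_dataset_id': 'def'},
)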
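The second script runs a CloudWatch Logs Insights query over a recent time window and writes each result row as a line of JSON.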
#!/usr/bin/env python3
import argparse
import json
import time
from itertools import count

import boto3

client = boto3.client('logs')
SLEEP_TIME = 5

class STATUSES:
    PENDING = ['Scheduled', 'Running', 'Unknown']
    FINISHED = ['Cancelled', 'Complete', 'Failed', 'Timeout']

def run(log_group_name, query, window, ofpath):
    print('starting query')
    response = client.start_query(
        logGroupName=log_group_name,
        startTime=int(time.time() - window),
        endTime=int(time.time()),
        queryString=query,
    )
    print(f'{response=}')
    query_id = response['queryId']
    # poll until the query reaches a terminal status
    for i in count():
        print(f'polling for result, {i=}, {SLEEP_TIME=}')
        response = client.get_query_results(queryId=query_id)
        print(f'{response["statistics"]=} {response["status"]=}')
        if (status := response['status']) in STATUSES.PENDING:
            print(status)
            time.sleep(SLEEP_TIME)
        elif status in STATUSES.FINISHED:
            break
    results = response['results']
    print(f'writing results to file {ofpath=}')
    with open(ofpath, 'w') as ostream:
        for result in results:
            # each result is a list of {'field': ..., 'value': ...} pairs
            r = {values['field']: values['value'] for values in result}
            try:
                r['@message'] = json.loads(r['@message'])
            except (KeyError, ValueError):
                pass  # leave non-JSON messages as-is
            print(json.dumps(r), file=ostream)
    print(f'{response["statistics"]=} {response["status"]=}')

def parse_args():
    parser = argparse.ArgumentParser(description='Run a CloudWatch Logs Insights query')
    parser.add_argument('log_group_name', type=str, help='the logGroupName')
    parser.add_argument('window', type=int, help='time window of query, in seconds (e.g. 86400 for the last day)')
    parser.add_argument('ofpath', type=str, help='the output file path')
    parser.add_argument('query', type=str, help='the insights query to run')
    args = parser.parse_args().__dict__
    print(args)
    return args

if __name__ == '__main__':
    run(**parse_args())
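A sketch of a programmatic call (log group name and query are hypothetical); the query uses standard Logs Insights syntax, and the results land in ofpath as one JSON object per line.

# fetch the last day's messages containing "ERROR" (hypothetical log group)
run(
    log_group_name='/syslog/my-cluster',
    query='fields @timestamp, @message | filter @message like /ERROR/ | limit 100',
    window=86400,  # seconds
    ofpath='errors.jsonl',
)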
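The third script (imported as get_log_events by the task lister below) fetches or tails a single CloudWatch log stream.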
#!/usr/bin/env python3
'''
This is a helper script to fetch & poll a specific AWS CloudWatch log stream.
- You'll need the logGroupName, logStreamName, and the region of the log stream.
- You can also provide the URL of the log stream, and the script will parse it for you.
Usage:
    -logGroupName   the logGroupName
    -logStreamName  the logStreamName
    -region         the region of the logStream
    -ofpath         the output file path
    --poll          enable poll mode (continuously wait for new logs)
    --tail          enable tail mode (start from the last log events)
    --desc          consume logs in reverse order (recent -> old)
'''
import argparse
import json
import os
import sys
import time
from datetime import datetime
from itertools import cycle, islice
from typing import Tuple
from urllib.parse import unquote, urlparse, parse_qs

import boto3
import botocore
from pygments import highlight
from pygments.lexers import JsonLexer
from pygments.formatters import TerminalTrueColorFormatter as Formatter
from pygments.styles import get_style_by_name

if os.environ.get('NO_STYLE'):
    def ppd(d, indent=2, style='material'):
        print(json.dumps(d, indent=indent, default=str), flush=True)
else:
    def ppd(d, indent=2, style='material'):
        print(highlight(json.dumps(d, indent=indent, default=str), JsonLexer(), Formatter(style=get_style_by_name(style))).strip(), flush=True)

def parse_args():
    parser = argparse.ArgumentParser(description='Get all logs from a log stream, via the log group/stream names or the cloudwatch log stream URL')
    parser.add_argument('-logGroupName', type=str, help='the logGroupName', required=False)
    parser.add_argument('-logStreamName', type=str, help='the logStreamName', required=False)
    parser.add_argument('-logStreamURL', type=str, help='the URL of the logStream', required=False)
    parser.add_argument('-region', type=str, help='the region of the logStream', required=False)
    parser.add_argument('-ofpath', type=str, help='the output file path')
    parser.add_argument('--poll', action='store_true', help='enable poll mode')
    parser.add_argument('--tail', action='store_true', help='enable tail mode')
    parser.add_argument('--desc', action='store_true', help='consume logs recent -> old')
    args = parser.parse_args().__dict__
    if args['logStreamURL']:
        region, log_group, log_stream = parse_cloudwatch_url(args['logStreamURL'])
        ppd({'url': args['logStreamURL'], 'region': region, 'log_group': log_group, 'log_stream': log_stream}, indent=None)
        args.update({'region': region, 'logGroupName': log_group, 'logStreamName': log_stream})
    del args['logStreamURL']
    ppd(args, indent=None)
    return args

def parse_cloudwatch_url(url: str) -> Tuple[str, str, str]:
    '''Parse (region, log_group, log_stream) out of a CloudWatch console URL.

    The console escapes special characters in the URL fragment as $-prefixed
    sequences (e.g. $252F for a double-encoded "/"), so they are converted
    back to %-escapes and unquoted twice.
    '''
    if '$' in url:
        parts = unquote(url.replace('$', '%')).split('/')
    else:
        parts = url.split('/')
    _scheme, _, _host, _service, _params, _lg, log_group, _le, log_stream = parts
    region = parse_qs(urlparse(url).query)['region'][0]
    return region, unquote(log_group), unquote(log_stream)

def write_events(response, ostream):
    for i, event in enumerate(response['events']):
        if i == 0:
            # clear any leftover sleep-animation characters on the current line
            print('\r', flush=True, end='')
        try:
            event['message'] = json.loads(event['message'])
        except ValueError:
            pass  # leave non-JSON messages as plain strings
        print(json.dumps(event), file=ostream, flush=True)
        ppd(event['message'], indent=None)
    return len(response['events'])

def get_log_times(client, logGroupName, logStreamName):
    'Wait until the stream has at least one event, then return its first/last event timestamps.'
    while True:
        response = client.describe_log_streams(
            logGroupName=logGroupName, logStreamNamePrefix=logStreamName,
        )
        stream = response['logStreams'][0]
        if 'firstEventTimestamp' not in stream:
            time.sleep(5)
            continue
        return stream['firstEventTimestamp'], stream.get('lastEventTimestamp')

def print_sleep_animation(steps: int, rotations: int = 20, total_sleep: int = 10):
    # account for the time.sleep() inside the loop, by dividing by (steps-1)
    delay = (total_sleep / (steps - 1)) / rotations
    for i in range(steps):
        for el in islice(cycle(['-', '\\', '|', '/']), rotations):
            print('\r' + '.' * i + el, end='', file=sys.stderr)
            time.sleep(delay)
        print('\r' + '.' * i + '↩', file=sys.stderr, end='', flush=True)
    # clear the line, then move the cursor back to its beginning, ready to
    # either print another animation or write the next logs
    print('\r' + ' ' * (steps + 1), end='', flush=True)
    print('\r', end='', flush=True)

def run(logGroupName, logStreamName, ofpath, region, poll=False, tail=False, desc=True):
    kwargs = {
        'logGroupName': logGroupName,
        'logStreamName': logStreamName,
        'startFromHead': True,
        'limit': 1_000,
    }
    client = boto3.client('logs', region_name=region)
    try:
        startTime, endTime = get_log_times(client, logGroupName, logStreamName)
        if tail:
            # start 120 seconds before the last event (fall back to the first)
            kwargs['startTime'] = int(endTime or startTime) - (120 * 1_000)
        else:
            kwargs['startTime'] = startTime
        ppd({
            'log_start': datetime.fromtimestamp(int(startTime) / 1_000),
            'log_end': endTime and datetime.fromtimestamp(int(endTime) / 1_000),
            'starting_from': datetime.fromtimestamp(int(kwargs['startTime']) / 1_000),
        }, indent=None)
    except botocore.exceptions.NoCredentialsError:
        print('❌ Error: Unable to locate AWS credentials!', file=sys.stderr)
        sys.exit(1)
    ppd(kwargs, indent=None)
    time.sleep(1)
    with open(ofpath, 'w') as ostream:
        while True:
            response = client.get_log_events(**kwargs)
            write_events(response, ostream)
            if desc:
                next_token = response.get('nextBackwardToken')
            else:
                next_token = response.get('nextForwardToken')
            # the stream is exhausted once the service hands back the same token
            if kwargs.get('nextToken', '') == next_token:
                if poll:
                    print_sleep_animation(steps=10, rotations=12, total_sleep=5)
                else:
                    break
            kwargs['nextToken'] = next_token

if __name__ == '__main__':
    run(**parse_args())
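For illustration, a hypothetical console URL (region, cluster, and task id are made up) and what parse_cloudwatch_url extracts from it; this assumes the logsV2 URL shape the parser expects, where each / in the group and stream names is escaped as $252F.

url = ('https://ap-southeast-2.console.aws.amazon.com/cloudwatch/home'
       '?region=ap-southeast-2#logsV2:log-groups/log-group/'
       '$252Fsyslog$252Fmy-cluster/log-events/fargate$252Fmy-task-id')
region, log_group, log_stream = parse_cloudwatch_url(url)
# -> ('ap-southeast-2', '/syslog/my-cluster', 'fargate/my-task-id')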
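The fourth script lists ECS Fargate tasks in a cluster, filters them by their task_type tag, and interactively offers to download each task's CloudWatch logs via the previous script.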
#!/usr/bin/env python3
import json
import sys
from collections import namedtuple
from typing import Tuple

import boto3

import get_log_events

cluster, status, task_type = sys.argv[1:]

# instantiate an ECS client, and list all tasks for the cluster
client = boto3.client('ecs')
print(f'listing "{status}" tasks for cluster "{cluster}"')
response = client.list_tasks(
    cluster=cluster,
    desiredStatus=status,
)
task_arns = response['taskArns']

FargateTask = namedtuple('FargateTask', [
    'started_at', 'task', 'task_definition', 'task_type', 'status', 'region', 'container_overrides',
    'tags',
    'log_stream', 'log_group',
])

def cloudwatch_log_stream(task_arn: str) -> Tuple[str, str]:
    'e.g. "arn:aws:ecs:ap-southeast-2:732655618226:task/au-identity-history/f6b7c25e8f604797965b341bda3832ac"'
    _arn, _aws, _ecs, _region, _account, task = task_arn.split(':')
    task = task.split('/', 1)[-1]   # -> "<cluster>/<task-id>"
    cluster = task.split('/')[0]
    return f'/syslog/{cluster}', f'fargate/{task}'

def parse_task_type(task: dict) -> str:
    for tag in task['tags']:
        if tag['key'] == 'task_type':
            return tag['value']
    return 'unknown'

def parse_task(task: dict) -> FargateTask:
    log_group, log_stream = cloudwatch_log_stream(task['taskArn'])
    return FargateTask(**{
        'started_at': task.get('startedAt', ''),
        'task': task['taskArn'],
        'task_definition': task['taskDefinitionArn'],
        'task_type': parse_task_type(task),
        'status': task['lastStatus'],
        'region': task['clusterArn'].split(':')[3],
        'container_overrides': {el['name']: el['value'] for el in task['overrides']['containerOverrides'][0]['environment']},
        'tags': {el['key']: el['value'] for el in task['tags']},
        'log_stream': log_stream,
        'log_group': log_group,
    })

def format_task(task: FargateTask) -> str:
    return f'{task.status} | {task.started_at} | {task.task_type} | {task.container_overrides["REQUEST_ID"]} | {task.tags.get("client", ""):<30s} | {task.log_stream}'

print('describing tasks:')
print(json.dumps(task_arns, indent=2))
response = client.describe_tasks(
    cluster=cluster,
    tasks=task_arns,
    include=['TAGS'],
)

tasks = []
for raw_task in response['tasks']:
    task = parse_task(raw_task)
    if task.task_type != task_type:
        print(f'skipping task of type "{task.task_type}" (wanted "{task_type}")')
        continue
    tasks.append(task)

print('\n' + '-' * 50)
for task in tasks:
    print(format_task(task))
print('\n' + '-' * 50)

for task in tasks:
    y = input(f'- {format_task(task)}\nget logs? y/n: ')
    if y.lower() != 'y':
        continue
    log_group, log_stream = cloudwatch_log_stream(task.task)
    get_log_events.run(log_group, log_stream, f'{task.container_overrides["REQUEST_ID"]}.log', task.region)
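Invocation takes three positional arguments: the cluster name, the desired task status, and the task_type tag to filter on, e.g. `python3 list_tasks.py my-cluster RUNNING activation` (the script filename and the values here are hypothetical). For each task you confirm, the logs are written to `<REQUEST_ID>.log` via get_log_events.run.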