@tmck-code
Last active October 28, 2025 22:20
AWS Utilities
#!/usr/bin/env python3
import argparse
from dataclasses import dataclass
import time
from typing import ClassVar
from datetime import datetime
import boto3
import json
from pygments import highlight
from pygments.lexers import JsonLexer
from pygments.formatters import TerminalTrueColorFormatter as Formatter
from pygments.styles import get_style_by_name
def ppd(d, indent=None): print(highlight(json.dumps(d, indent=indent, default=str), JsonLexer(), Formatter(style=get_style_by_name('material'))).strip())
def ppj(j, indent=2): ppd(json.loads(j), indent=indent)
@dataclass
class AthenaQuery:
    database: str
    work_group: str
    results_bucket: str
    query_key: str
    catalog: str = 'AwsDataCatalog'
    s3_acl_option: ClassVar[str] = 'BUCKET_OWNER_FULL_CONTROL'

    def __post_init__(self):
        self.client = boto3.client('athena')

    @property
    def output_location(self) -> str:
        return f's3://{self.results_bucket}/customerapi/au-identity-history/query_results/'

    def execute(self, query: str) -> str:
        return self.client.start_query_execution(
            QueryString=query,
            QueryExecutionContext={
                'Database': self.database,
                'Catalog': self.catalog,
            },
            ResultConfiguration={
                'OutputLocation': self.output_location,
                'AclConfiguration': {
                    'S3AclOption': self.s3_acl_option,
                }
            },
            WorkGroup=self.work_group
        )['QueryExecutionId']
    def query_string(self, kwargs: dict) -> str:
        # render the query template at '<query_key>.sql', substituting the given kwargs
        return open(f'{self.query_key}.sql').read().format(**kwargs)

    def poll_query(self, query_execution_id: str) -> dict:
        # poll every 10 seconds until the query leaves the QUEUED/RUNNING states
        while True:
            response = self.client.get_query_execution(QueryExecutionId=query_execution_id)
            ppd({
                'timestamp': datetime.now().astimezone().isoformat(),
                'QueryExecutionId': query_execution_id,
                'State': response['QueryExecution']['Status']['State']
            })
            if response['QueryExecution']['Status']['State'] not in ['QUEUED', 'RUNNING']:
                break
            time.sleep(10)
        return response

    def get_query_results(self, query_execution_id: str) -> dict:
        return self.client.get_query_results(QueryExecutionId=query_execution_id)

    @staticmethod
    def _query_results_to_table(row_id: str, query_results: dict) -> list:
        table = []
        for i, raw_row in enumerate(query_results['ResultSet']['Rows']):
            if i == 0:
                row = ['__row_id__']
            else:
                row = [row_id]
            for cell in raw_row['Data']:
                row.append(cell['VarCharValue'])
            table.append(row)
        return table

    def query(self, row_id: str, query_kwargs: dict) -> tuple:
        query_execution_id = self.execute(self.query_string(query_kwargs))
        self.poll_query(query_execution_id)
        results = self.get_query_results(query_execution_id)
        table = AthenaQuery._query_results_to_table(row_id, results)
        return table, results
def parse_args():
    parser = argparse.ArgumentParser(description='Run Athena Queries')
    parser.add_argument('--database', type=str, required=True, help='Athena Database')
    parser.add_argument('--work-group', type=str, required=True, help='Athena Work Group')
    parser.add_argument('--results-bucket', type=str, required=True, help='S3 Bucket for Athena Query Results')
    parser.add_argument('--query-key', type=str, required=True, help='Key for the Athena Query')
    parser.add_argument('--queries', type=json.loads, required=True, help='Dataset IDs')
    parser.add_argument('--ofpath', type=str, required=True, help='Output File')
    return parser.parse_args().__dict__


def run():
    args = parse_args()
    q = AthenaQuery(
        args['database'],
        args['work_group'],
        args['results_bucket'],
        args['query_key'],
    )
    ppd({'timestamp': datetime.now().astimezone().isoformat(), 'msg': 'Starting Query'})

    combined_table = []
    for i, query_kwargs in enumerate(args['queries']):
        table, results = q.query(
            row_id=query_kwargs['__row_id__'],
            query_kwargs={
                'new_dataset_id': query_kwargs['new_dataset_id'],
                'old_dataset_id': query_kwargs['old_dataset_id'],
            }
        )
        ppd({
            'timestamp': datetime.now().astimezone().isoformat(),
            'msg': 'received query results',
            'query_kwargs': query_kwargs,
            'results': results
        })
        if i == 0:
            combined_table.append(table[0])
        combined_table.extend(table[1:])

    for row in combined_table:
        ppd(row)

    with open(args['ofpath'], 'w') as ostream:
        for row in combined_table:
            print('\t'.join(list(map(str, row))), file=ostream)


if __name__ == '__main__':
    run()
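# Example invocation (a sketch; the script filename, database, work group, bucket, and
# dataset IDs below are hypothetical - only the flag names come from parse_args() above):
#   ./athena_query.py \
#       --database my_database --work-group primary --results-bucket my-athena-results \
#       --query-key compare_datasets \
#       --queries '[{"__row_id__": "run-1", "new_dataset_id": "abc", "old_dataset_id": "def"}]' \
#       --ofpath results.tsv
# This expects a 'compare_datasets.sql' template alongside the script, containing
# '{new_dataset_id}' and '{old_dataset_id}' placeholders for query_string() to fill in.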
#!/usr/bin/env python3
import argparse
from itertools import count
import os
import json
import time
import boto3
client = boto3.client('logs')
SLEEP_TIME = 5
class STATUSES:
    PENDING = ['Scheduled', 'Running', 'Timeout', 'Unknown']
    FINISHED = ['Cancelled', 'Complete', 'Failed']
def run(log_group_name, query, window, ofpath):
    print('starting query')
    response = client.start_query(
        logGroupName = log_group_name,
        startTime = int(time.time()-window),
        endTime = int(time.time()),
        queryString = query
    )
    print(f'{response=}')
    query_id = response['queryId']

    for i in count():
        print(f'polling for result, {i=}, {SLEEP_TIME=}')
        response = client.get_query_results(queryId=query_id)
        print(f'{response["statistics"]=} {response["status"]=}')
        if (status := response['status']) in STATUSES.PENDING:
            print(status)
            time.sleep(SLEEP_TIME)
        elif status in STATUSES.FINISHED:
            break

    results = response['results']
    print(f'writing results to file {ofpath=}')
    with open(os.path.basename(ofpath), 'w') as ostream:
        for result in results:
            r = {}
            for values in result:
                r[values['field']] = values['value']
            r['@message'] = json.loads(r['@message'])
            print(json.dumps(r), file=ostream)
    print(f'{response["statistics"]=} {response["status"]=}')
def parse_args():
    parser = argparse.ArgumentParser(description='Run a CloudWatch Logs Insights query and write the results to a file')
    parser.add_argument('log_group_name', type=str, help='the logGroupName')
    parser.add_argument('window', type=int, default=86400, help='time window of query, in seconds')
    parser.add_argument('ofpath', type=str, help='the output file path')
    parser.add_argument('query', type=str, help='the insights query to run')
    args = parser.parse_args().__dict__
    print(args)
    return args


if __name__ == '__main__':
    run(**parse_args())
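# Example invocation (a sketch; the script filename and log group name are hypothetical -
# the positional arguments are log_group_name, window, ofpath, query as defined above):
#   ./insights_query.py /aws/lambda/my-function 86400 results.jsonl \
#       'fields @timestamp, @message | sort @timestamp desc | limit 100'
# Each result row is written to the output file as one JSON object per line, with the
# @message field parsed as JSON.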
#!/usr/bin/env python3
'''
This is a helper script to fetch & poll a specific AWS CloudWatch log stream.
- You'll need the logGroupName, logStreamName, and the region of the log stream.
- You can also provide the URL of the log stream, and the script will parse it for you.
Usage:
-logGroupName the logGroupName
-logStreamName the logStreamName
-region the region of the logStream
-ofpath the output file path
--poll enable poll mode (continuously wait for new logs)
--tail enable tail mode (start from the last log events)
--desc consume logs in reverse order (recent -> old)
'''
import argparse
from datetime import datetime
from itertools import cycle, islice
import json
import os, sys
import time
from typing import Tuple
from urllib.parse import unquote, urlparse, parse_qs
import boto3
import botocore
from pygments import highlight
from pygments.lexers import JsonLexer
from pygments.formatters import TerminalTrueColorFormatter as Formatter
from pygments.styles import get_style_by_name
def ppd(d, indent=2, style='material'): print(highlight(json.dumps(d, indent=indent, default=str), JsonLexer(), Formatter(style=get_style_by_name(style))).strip(), flush=True)
if os.environ.get('NO_STYLE'):
    def ppd(d, indent=2, style='material'): print(json.dumps(d, indent=indent, default=str), flush=True)
def parse_args():
    parser = argparse.ArgumentParser(description='Get all logs from a log stream, via the log group/stream names or the cloudwatch log stream URL')
    parser.add_argument('-logGroupName', type=str, help='the logGroupName', required=False)
    parser.add_argument('-logStreamName', type=str, help='the logStreamName', required=False)
    parser.add_argument('-logStreamURL', type=str, help='the URL of the logStream', required=False)
    parser.add_argument('-region', type=str, help='the region of the logStream', required=False)
    parser.add_argument('-ofpath', type=str, help='the output file path')
    parser.add_argument('--poll', action='store_true', help='enable poll mode')
    parser.add_argument('--tail', action='store_true', help='enable tail mode')
    parser.add_argument('--desc', action='store_true', help='consume logs recent -> old')
    args = parser.parse_args().__dict__

    if args['logStreamURL']:
        region, log_group, log_stream = parse_cloudwatch_url(args['logStreamURL'])
        ppd({'url': args['logStreamURL'], 'region': region, 'log_group': log_group, 'log_stream': log_stream}, indent=None)
        args.update({'region': region, 'logGroupName': log_group, 'logStreamName': log_stream})
    del args['logStreamURL']

    ppd(args, indent=None)
    return args
def parse_cloudwatch_url(url: str) -> Tuple[str, str, str]:
    # the cloudwatch console escapes special characters with '$' sequences,
    # so translate them back to '%' before unquoting
    if '$' in url:
        print(unquote(url.replace('$', '%')).split('/'))
        scheme, _, host, service, params, _lg, log_group, _le, log_stream = unquote(url.replace('$', '%')).split('/')
    else:
        print(unquote(url).split('/'))
        scheme, _, host, service, params, _lg, log_group, _le, log_stream = url.split('/')
    region = parse_qs(urlparse(url).query)['region'][0]
    return region, unquote(log_group), unquote(log_stream)
def write_events(response, ostream):
    for i, event in enumerate(response['events']):
        if i == 0:
            print('\r', flush=True, end='')
        try:
            event['message'] = json.loads(event['message'])
        except ValueError:
            pass
        print(json.dumps(event), file=ostream, flush=True)
        ppd(event['message'], indent=None)
    return len(response['events'])
def get_log_times(client, logGroupName, logStreamName):
    # wait until the stream has at least one event, then return its first/last event timestamps
    while True:
        response = client.describe_log_streams(
            logGroupName=logGroupName, logStreamNamePrefix=logStreamName,
        )
        stream = response['logStreams'][0]
        if 'firstEventTimestamp' not in stream:
            time.sleep(5)
            continue
        return stream['firstEventTimestamp'], stream.get('lastEventTimestamp')
def print_sleep_animation(steps: int, rotations: int = 20, total_sleep: int = 10):
    # account for the time.sleep() at the end of the loop, by dividing by (steps-1)
    delay = (total_sleep/(steps-1))/rotations
    for i in range(steps):
        for el in islice(cycle(['-', '\\', '|', '/']), rotations):
            print('\r' + '.'*i + el, end='', file=sys.stderr)
            time.sleep(delay)
        print('\r' + '.'*i + '↩', file=sys.stderr, end='', flush=True)
        # time.sleep(delay)
    # clear the line
    print('\r'+ ' '*(steps+1), end='', flush=True)
    # move the print cursor to the beginning of the line, getting ready to
    # either to print another animation, or to write the next logs
    print('\r', end='', flush=True)
def run(logGroupName, logStreamName, ofpath, region, poll=False, tail=False, desc=True):
    kwargs = {
        'logGroupName': logGroupName,
        'logStreamName': logStreamName,
        # 'startFromHead': desc,
        'startFromHead': True,
        'limit': 1_000,
    }
    client = boto3.client('logs', region_name=region)

    try:
        startTime, endTime = get_log_times(client, logGroupName, logStreamName)
        if tail:
            kwargs['startTime'] = int(endTime)-(120*1_000) # 120 seconds ago
        else:
            kwargs['startTime'] = startTime
        ppd({
            'log_start': datetime.fromtimestamp(int(startTime)/1_000),
            'log_end': endTime and datetime.fromtimestamp(int(endTime)/1_000),
            'starting_from': datetime.fromtimestamp(int(kwargs['startTime'])/1_000)
        }, indent=None)
    except botocore.exceptions.NoCredentialsError as e:
        print('❌ Error: Unable to locate AWS credentials!', file=sys.stderr)
        sys.exit(1)

    ppd(kwargs, indent=None)
    time.sleep(1)

    with open(ofpath, 'w') as ostream:
        while True:
            response = client.get_log_events(**kwargs)
            n = write_events(response, ostream)
            if desc:
                next_token = response.get('nextBackwardToken')
            else:
                next_token = response.get('nextForwardToken')
            # print(kwargs.get('nextToken', ''), next_token, file=sys.stderr)
            if kwargs.get('nextToken', '') == response['nextForwardToken']:
                if poll:
                    print_sleep_animation(steps=10, rotations=12, total_sleep=5)
                else:
                    break
            kwargs['nextToken'] = next_token


if __name__ == '__main__':
    run(**parse_args())
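# Example invocation (a sketch; the log group/stream names and region are hypothetical -
# the flag names come from parse_args() above, and the filename 'get_log_events.py' matches
# the module name imported by the ECS helper below):
#   ./get_log_events.py -logGroupName /syslog/my-cluster -logStreamName fargate/my-cluster/abc123 \
#       -region ap-southeast-2 -ofpath task.log --poll
# Alternatively, pass -logStreamURL with the console URL of the stream, and the group/stream/region
# will be parsed from it.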
#!/usr/bin/env python3
import json
import sys
import time
from collections import namedtuple
import boto3
import get_log_events
cluster, status, task_type = sys.argv[1:]
# instantiate an ECS client, and list all tasks for the cluster
client = boto3.client('ecs')
print(f'listing "{status}" tasks for cluster "{cluster}"')
response = client.list_tasks(
    cluster = cluster,
    desiredStatus = status,
)
task_arns = response['taskArns']
FargateTask = namedtuple('FargateTask', [
    'started_at', 'task', 'task_definition', 'task_type', 'status', 'region', 'container_overrides',
    'tags',
    'log_stream', 'log_group',
])
def cloudwatch_log_stream(task_arn: str) -> tuple:
    # e.g. "arn:aws:ecs:ap-southeast-2:732655618226:task/au-identity-history/f6b7c25e8f604797965b341bda3832ac"
    _arn, _aws, _ecs, _region, _account, task = task_arn.split(':')
    task = task.split('/', 1)[-1]
    cluster = task.split('/')[0]
    return f'/syslog/{cluster}', f'fargate/{task}'
def parse_task_type(task: dict):
    for tag in task['tags']:
        if tag['key'] == 'task_type':
            return tag['value']
    return 'unknown'


def parse_task(task: dict) -> FargateTask:
    log_group, log_stream = cloudwatch_log_stream(task['taskArn'])
    return FargateTask(**{
        'started_at': task.get('startedAt', ''),
        'task': task['taskArn'],
        'task_definition': task['taskDefinitionArn'],
        'task_type': parse_task_type(task),
        'status': task['lastStatus'],
        'region': task['clusterArn'].split(':')[3],
        'container_overrides': {el['name']: el['value'] for el in task['overrides']['containerOverrides'][0]['environment']},
        'tags': {el['key']: el['value'] for el in task['tags']},
        'log_stream': log_stream,
        'log_group': log_group,
    })


def print_task(task: FargateTask):
    return f'{task.status} | {task.started_at} | {task.task_type} | {task.container_overrides["REQUEST_ID"]} | {task.tags.get("client", ""):<30s} | {task.log_stream}'
print('describing tasks:')
print(json.dumps(task_arns, indent=2))
response = client.describe_tasks(
    cluster = cluster,
    tasks = task_arns,
    include = ['TAGS'],
)
tasks = []
for raw_task in response['tasks']:
    task = parse_task(raw_task)
    if task.task_type != task_type:
        print(f'skipping task of type "{task.task_type}" (looking for "{task_type}")')
        continue
    tasks.append(task)

print('\n' + '-'*50)
for task in tasks:
    print(print_task(task))
print('\n' + '-'*50)

# interactively offer to fetch the CloudWatch logs for each matching task
for task in tasks:
    y = input(f'- {print_task(task)}\nget logs? y/n: ')
    if y.lower() != 'y':
        continue
    log_group, log_stream = cloudwatch_log_stream(task.task)
    get_log_events.run(log_group, log_stream, f'{task.container_overrides["REQUEST_ID"]}.log', task.region)
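# Example invocation (a sketch; the script filename and cluster name are hypothetical, and
# the positional arguments are cluster, desired status, and the 'task_type' tag to match):
#   ./list_fargate_tasks.py my-cluster RUNNING activation
# For each selected task it calls get_log_events.run(...) above, writing the task's logs to
# '<REQUEST_ID>.log'.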