Gist by @zzhuolun · Last active April 3, 2024
Basic IO operations (cp, cp -r, ls, mkdir) that are AWS S3/local path agnostic, using smart_open.
import json
import os
from typing import List, Tuple

import boto3
from botocore.exceptions import ClientError
from smart_open import open


def load_json(path_to_json: str) -> dict:
    """Load a JSON file from a local path or an s3 uri."""
    with open(path_to_json) as file:
        return json.load(file)


def save_json(data: dict, path_to_json: str) -> None:
    """Save a dict as JSON to a local path or an s3 uri."""
    with open(path_to_json, 'w') as file:
        json.dump(data, file, indent=4)


def makedirs(directory: str) -> None:
    """Create a local directory; s3 'directories' are implicit, so s3 uris are a no-op."""
    if not directory.startswith('s3://'):
        os.makedirs(directory, exist_ok=True)


def uri_to_bucket_name_and_key(uri: str) -> Tuple[str, str]:
    """Split an 's3://bucket/key' uri into (bucket_name, key)."""
    assert uri.startswith('s3://')
    path = uri.split('/')
    bucket_name = path[2]
    key = os.path.join(*path[3:]) if len(path) > 3 else ''
    return bucket_name, key


def is_s3_file(uri: str) -> bool:
    """Check if the uri is an s3 object."""
    bucket_name, key = uri_to_bucket_name_and_key(uri)
    s3_client = boto3.client('s3')
    try:
        s3_client.head_object(Bucket=bucket_name, Key=key)
        return True
    except ClientError:
        return False


def is_s3_dir(uri: str) -> bool:
    """Check if the uri is an s3 'directory', i.e. a prefix with at least one object under it."""
    bucket_name, key = uri_to_bucket_name_and_key(uri)
    s3_client = boto3.client('s3')
    if key != '':
        key = key.rstrip('/') + '/'
    response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=key, Delimiter='/')
    return response['KeyCount'] != 0


def is_file(file_path: str) -> bool:
    """Check if the path/uri points to an existing local file or s3 object."""
    return os.path.isfile(file_path) if not file_path.startswith('s3://') else is_s3_file(file_path)


def s3_path_exists(s3_path: str) -> bool:
    """Check if an AWS s3 uri exists, either as an object or as a 'directory' prefix."""
    assert s3_path.startswith('s3://')
    return is_s3_file(s3_path) or is_s3_dir(s3_path)


def list_local_dir(dir_path: str) -> Tuple[List[str], List[str]]:
    """List (folders, filenames) directly under a local directory."""
    folders = []
    filenames = []
    if os.path.isdir(dir_path):
        for entry in os.listdir(dir_path):
            if os.path.isdir(os.path.join(dir_path, entry)):
                folders.append(entry)
            else:
                filenames.append(entry)
    return folders, filenames


def list_s3_dir(dir_uri: str) -> Tuple[List[str], List[str]]:
    """List (folders, filenames) directly under an s3 prefix.

    Delimiter='/' makes the listing non-recursive: sub-prefixes come back as
    CommonPrefixes and objects as Contents. Note that a single list_objects_v2
    call returns at most 1000 keys.
    """
    client = boto3.client('s3')
    bucket, key = uri_to_bucket_name_and_key(dir_uri)
    if key != '':
        key = key.rstrip('/') + '/'
    response = client.list_objects_v2(Bucket=bucket, Prefix=key, Delimiter='/')
    # basename(abspath(...)) strips the leading prefix and any trailing slash,
    # keeping only the last path component.
    folders = [os.path.basename(os.path.abspath(prefix['Prefix'])) for prefix in response.get('CommonPrefixes', [])]
    filenames = [os.path.basename(os.path.abspath(content['Key'])) for content in response.get('Contents', [])]
    assert len(folders) + len(filenames) == response['KeyCount']
    return folders, filenames


def list_dir(dir_path_or_uri: str) -> Tuple[List[str], List[str]]:
    """List the folders and files under the input path/uri."""
    if not dir_path_or_uri.startswith('s3://'):
        return list_local_dir(dir_path_or_uri)
    else:
        return list_s3_dir(dir_path_or_uri)


def smart_copy(src_path: str, dst_path: str, chunk_size: int = 1024) -> None:
    """Copy a file from src_path to dst_path, using smart_open to support s3 paths."""
    with open(dst_path, 'wb') as file_dst:
        with open(src_path, 'rb') as file_src:
            while True:
                chunk = file_src.read(chunk_size)
                if not chunk:
                    break
                file_dst.write(chunk)


def smart_copytree(src_dir: str, dst_dir: str) -> None:
    """Recursively copy a directory from src_dir to dst_dir, S3/local agnostic."""
    makedirs(dst_dir)
    folders, filenames = list_dir(src_dir)
    for file in filenames:
        smart_copy(os.path.join(src_dir, file), os.path.join(dst_dir, file))
    for folder in folders:
        smart_copytree(os.path.join(src_dir, folder), os.path.join(dst_dir, folder))
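
For reference, here is a minimal usage sketch (not part of the original gist). The bucket name and local paths are placeholders; it assumes AWS credentials are already configured for boto3/smart_open, and a POSIX path separator, since the helpers join paths with os.path.join.

if __name__ == '__main__':
    # Hypothetical example: mirror a local run folder to s3 (like `cp -r`).
    smart_copytree('./experiments/run_01', 's3://my-bucket/experiments/run_01')

    # Read a JSON file back, regardless of whether it lives locally or on s3.
    config = load_json('s3://my-bucket/experiments/run_01/config.json')

    # Non-recursive listing of what ended up under the prefix (like `ls`).
    folders, files = list_dir('s3://my-bucket/experiments/run_01')
    print(folders, files)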