Last active
April 3, 2024 15:06
-
-
Save zzhuolun/ae2a566e6a986c44a0f73e1698ead900 to your computer and use it in GitHub Desktop.
Basic I/O operations (cp, cp -r, ls, mkdir) that work the same on AWS S3 URIs and local paths, built on smart_open.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import json | |
| from typing import List, Tuple | |
| import boto3 | |
| from smart_open import open | |
| from botocore.exceptions import ClientError | |
def load_json(path_to_json: str) -> dict:
    """Read and decode the JSON document at ``path_to_json`` (local path or S3 URI)."""
    with open(path_to_json) as json_file:
        decoded = json.load(json_file)
    return decoded
def save_json(data: dict, path_to_json: str) -> None:
    """Write ``data`` as 4-space-indented JSON to ``path_to_json`` (local path or S3 URI), overwriting it."""
    write_mode = 'w'
    with open(path_to_json, write_mode) as out_file:
        json.dump(data, fp=out_file, indent=4)
def makedirs(directory: str) -> None:
    """Create ``directory`` (and parents) locally if missing; do nothing for ``s3://`` URIs."""
    if directory.startswith('s3://'):
        # No directory object is created for S3 prefixes; nothing to do here.
        return
    os.makedirs(directory, exist_ok=True)
def uri_to_bucket_name_and_key(uri: str) -> Tuple[str, str]:
    """Split an ``s3://bucket/key`` URI into ``(bucket_name, key)``.

    The key is returned exactly as written after the bucket name. S3 keys always
    use '/' and may legitimately contain repeated or trailing slashes, so
    ``os.path.join`` is deliberately not used: it would emit backslashes on
    Windows and collapse empty segments (``a//b`` -> ``a/b``).

    Args:
        uri: An S3 URI such as ``s3://my-bucket/some/prefix/file.txt``.

    Returns:
        ``(bucket_name, key)``. ``key`` is ``''`` for ``s3://bucket`` and
        ``s3://bucket/``, and keeps any trailing slash otherwise.

    Raises:
        ValueError: If ``uri`` does not start with ``s3://``.
    """
    scheme = 's3://'
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not uri.startswith(scheme):
        raise ValueError(f'Not an S3 URI: {uri!r}')
    bucket_name, _, key = uri[len(scheme):].partition('/')
    return bucket_name, key
def is_s3_file(uri: str) -> bool:
    """Check if the uri is an existing S3 object.

    Args:
        uri: An ``s3://bucket/key`` URI.

    Returns:
        True if ``head_object`` succeeds for the key. False for a bare bucket URI
        (empty key), and False whenever ``head_object`` raises ``ClientError``
        (e.g. not found, or access denied).
    """
    bucket_name, key = uri_to_bucket_name_and_key(uri)
    if not key:
        # A bucket root is never an object. Calling head_object with an empty Key
        # raises botocore's ParamValidationError (not a ClientError), which used to
        # escape the handler below and crash callers such as s3_path_exists.
        return False
    s3_client = boto3.client('s3')
    try:
        s3_client.head_object(Bucket=bucket_name, Key=key)
    except ClientError:
        return False
    return True
def is_s3_dir(uri: str) -> bool:
    """Check if the uri is an S3 "directory" (a prefix with at least one entry under it).

    Args:
        uri: An ``s3://bucket[/prefix]`` URI. An empty key checks the bucket root.

    Returns:
        True if listing ``<key>/`` with delimiter '/' returns at least one object or
        common prefix, else False.
    """
    bucket_name, key = uri_to_bucket_name_and_key(uri)
    if key != '':
        # Normalize to exactly one trailing slash so "foo" does not match "foobar/...".
        key = key.rstrip('/') + '/'
    s3_client = boto3.client('s3')
    # A single entry is enough to answer the question; don't fetch a full page.
    response = s3_client.list_objects_v2(
        Bucket=bucket_name, Prefix=key, Delimiter='/', MaxKeys=1
    )
    return response.get('KeyCount', 0) > 0
def is_file(file_path: str) -> bool:
    """Return True if ``file_path`` is an existing local regular file or S3 object."""
    if file_path.startswith('s3://'):
        return is_s3_file(file_path)
    return os.path.isfile(file_path)
def s3_path_exists(s3_path: str) -> bool:
    """Return True if ``s3_path`` is an existing S3 object or a non-empty S3 prefix."""
    assert s3_path.startswith('s3://')
    # Short-circuit: the prefix listing is only issued when the object check fails.
    return is_s3_file(s3_path) or is_s3_dir(s3_path)
def list_local_dir(dir_path: str) -> Tuple[List[str], List[str]]:
    """Split the entries of the local directory ``dir_path`` into ``(folders, filenames)``.

    Every entry for which ``os.path.isdir`` is false (regular files, broken
    symlinks, ...) goes into ``filenames``. If ``dir_path`` is missing or is not
    a directory, two empty lists are returned.
    """
    if not os.path.isdir(dir_path):
        return [], []
    entries = os.listdir(dir_path)
    is_folder = [os.path.isdir(os.path.join(dir_path, name)) for name in entries]
    folders = [name for name, flag in zip(entries, is_folder) if flag]
    filenames = [name for name, flag in zip(entries, is_folder) if not flag]
    return folders, filenames
def list_s3_dir(dir_uri: str) -> Tuple[List[str], List[str]]:
    """List the immediate sub-folders and file names under an S3 prefix.

    Args:
        dir_uri: An ``s3://bucket[/prefix]`` URI; a trailing slash is optional.

    Returns:
        ``(folders, filenames)`` as bare names relative to ``dir_uri`` (folders
        without their trailing slash), matching the shape of ``list_local_dir``.
    """
    bucket, key = uri_to_bucket_name_and_key(dir_uri)
    if key != '':
        key = key.rstrip('/') + '/'
    client = boto3.client('s3')
    # A single list_objects_v2 call returns at most 1000 entries; the paginator
    # follows continuation tokens so large prefixes are listed completely.
    paginator = client.get_paginator('list_objects_v2')
    folders = []
    filenames = []
    for page in paginator.paginate(Bucket=bucket, Prefix=key, Delimiter='/'):
        for common_prefix in page.get('CommonPrefixes', []):
            # "<key>sub/" -> "sub". Stripping the known prefix (rather than
            # os.path.abspath/basename) keeps this independent of the local OS and cwd.
            name = common_prefix['Prefix'][len(key):].rstrip('/')
            if name:
                folders.append(name)
        for obj in page.get('Contents', []):
            name = obj['Key'][len(key):]
            # A key equal to the prefix itself is a zero-byte "folder marker"
            # object, not a file inside the folder; skip it.
            if name:
                filenames.append(name)
    return folders, filenames
def list_dir(dir_path_or_uri: str) -> Tuple[List[str], List[str]]:
    """Return ``(folders, filenames)`` directly under a local directory or an S3 prefix."""
    lister = list_s3_dir if dir_path_or_uri.startswith('s3://') else list_local_dir
    return lister(dir_path_or_uri)
def smart_copy(src_path: str, dst_path: str, chunk_size: int = 1024 * 1024) -> None:
    """Copy one file from ``src_path`` to ``dst_path``; either may be a local path or S3 URI.

    The data is streamed in ``chunk_size``-byte pieces, so the whole file is never
    held in memory.

    Args:
        src_path: Source file (local path or ``s3://`` URI).
        dst_path: Destination file (local path or ``s3://`` URI); overwritten.
        chunk_size: Bytes per read/write. Defaults to 1 MiB; the old 1 KiB default
            produced identical output but looped far more in Python.

    Raises:
        Any error from opening the source propagates before the destination is
        opened, so no empty or truncated destination is left behind.
    """
    # Source is opened first (the with-items are entered left to right): a missing or
    # unreadable source now fails before the destination is created or truncated.
    with open(src_path, 'rb') as file_src, open(dst_path, 'wb') as file_dst:
        while True:
            chunk = file_src.read(chunk_size)
            if not chunk:
                break
            file_dst.write(chunk)
def _join_path(base: str, name: str) -> str:
    """Append ``name`` to a local path or S3 URI using the separator that location requires."""
    if base.startswith('s3://'):
        # S3 keys always use '/', whatever the local OS (os.path.join uses '\\' on Windows).
        return (base if base.endswith('/') else base + '/') + name
    return os.path.join(base, name)


def smart_copytree(src_dir: str, dst_dir: str) -> None:
    """S3/local agnostic recursive copy of the contents of ``src_dir`` into ``dst_dir``.

    ``dst_dir`` is created first (only has an effect for local paths). Then every
    file directly under ``src_dir`` is copied, and each sub-folder is copied
    recursively into the matching sub-folder of ``dst_dir``.
    """
    makedirs(dst_dir)
    folders, filenames = list_dir(src_dir)
    for filename in filenames:
        smart_copy(_join_path(src_dir, filename), _join_path(dst_dir, filename))
    for folder in folders:
        smart_copytree(_join_path(src_dir, folder), _join_path(dst_dir, folder))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment