An updated version of the script used to load RDS snapshot data into Snowflake
| """ | |
| Load snapshot data from Snowflake stage into tables. | |
| This script dynamically lists all prefixes under the specified prefix_base | |
| from the stage, derives table names, creates tables using INFER_SCHEMA, | |
| and logs all operations to LOAD_LOG. | |
| See `python load_snapshot_data.py --help` | |
| """ | |
import json
import sys
from typing import Optional

import snowflake.connector
# Tables that have a schema inferred as BINARY, but ought to be VARIANT
TABLES_WITH_VARIANT_COLUMNS = {
    "ASYNCH_JOB_STATUS",
    "CHALLENGE_TEAM",
    "DATA_ACCESS_SUBMISSION_STATUS",
    "DISCUSSION_THREAD",
    "DOWNLOAD_ORDER",
    "EVALUATION",
    "EVALUATION_SUBMISSION",
    "MESSAGE_TO_USER",
    "MULTIPART_UPLOAD",
    "MULTIPART_UPLOAD_PART_STATE",
    "OAUTH_AUTHORIZATION_CODE",
    "QUIZ_RESPONSE",
    "RESEARCH_PROJECT",
    "STATISTICS_MONTHLY_STATUS",
    "SUBSTATUS_ANNOTATIONS_BLOB",
    "TABLE_STATUS",
    "USER_GROUP",
    "V2_WIKI_MARKDOWN",
    "V2_WIKI_OWNERS",
    "VERIFICATION_STATE",
}

# Tables that have actual binary/compressed data and should keep BINARY columns
TABLES_WITH_TRUE_BINARY_DATA = {
    "ACCESS_REQUIREMENT_REVISION",
    "ACTIVITY",
    "CHALLENGE",
    "DATA_ACCESS_REQUEST",
    "DATA_ACCESS_SUBMISSION",
    "MEMBERSHIP_INVITATION_SUBMISSION",
    "MEMBERSHIP_REQUEST_SUBMISSION",
    "NODE_REVISION",
    "PERSONAL_ACCESS_TOKEN",
    "TEAM",
    "USER_PROFILE",
    "VERIFICATION_SUBMISSION",
}


def derive_data_type(prefix: str, prefix_base: str) -> str:
    """
    Derive data type (table name) from prefix.

    Example: 'dev566/dev566.NODE/1/' -> 'NODE'

    Args:
        prefix: The prefix string to parse (e.g. from `s3_prefixes.txt`)
        prefix_base: The base prefix used in S3 keys/stage paths (e.g. 'dev566')

    Returns:
        The data type/table name extracted from the prefix
    """
    # Split by '<prefix_base>.' and take the second part
    parts = prefix.split(f"{prefix_base}.", 1)
    if len(parts) < 2:
        raise ValueError(f"Invalid prefix format: {prefix}")
    # Split by '/' and take the first part
    data_type = parts[1].split("/")[0]
    return data_type


def derive_stage_path(data_type: str, prefix_base: str) -> str:
    """
    Derive stage path from data type.

    Args:
        data_type: The data type/table name
        prefix_base: The base prefix used in S3 keys/stage paths

    Returns:
        Stage path string formed as '<prefix_base>.<DATA_TYPE>/1/'
    """
    return f"{prefix_base}.{data_type}/1/"


def log_operation(
    cursor,
    prefix: str,
    data_type: str,
    stage_path: str,
    phase: str,
    status: str,
    sql_text: Optional[str] = None,
    error_msg: Optional[str] = None,
):
    """
    Log an operation to the LOAD_LOG table.

    Args:
        cursor: Snowflake cursor for executing queries
        prefix: The prefix being processed
        data_type: The derived data type
        stage_path: The stage path
        phase: The current phase (e.g. START, CREATE_TABLE, COPY_DATA)
        status: Status of the operation (OK, RUN, FAIL)
        sql_text: Optional SQL text being executed
        error_msg: Optional error message
    """
    log_sql = """
        INSERT INTO LOAD_LOG (PREFIX, DATA_TYPE, STAGE_PATH, PHASE, STATUS, SQL_TEXT, ERROR_MESSAGE)
        VALUES (%s, %s, %s, %s, %s, %s, %s)
    """
    cursor.execute(
        log_sql, (prefix, data_type, stage_path, phase, status, sql_text, error_msg)
    )


def create_table_from_schema(cursor, data_type: str, stage_path: str, stage_name: str) -> str:
    """
    Generate CREATE TABLE SQL using INFER_SCHEMA.

    For tables in TABLES_WITH_TRUE_BINARY_DATA, converts VARIANT columns to BINARY
    to preserve binary/compressed data (this override is currently disabled below).

    Args:
        cursor: Snowflake cursor for executing queries
        data_type: The table name to create
        stage_path: The stage path containing the data files
        stage_name: The Snowflake stage name

    Returns:
        The SQL statement that was generated
    """
    # First, infer the schema (assumes a named PARQUET file format 'parquet_ff' exists)
    cursor.execute(
        f"""
        SELECT ARRAY_AGG(OBJECT_CONSTRUCT(*))
        FROM TABLE(
            INFER_SCHEMA(
                LOCATION => '@{stage_name}/{stage_path}',
                FILE_FORMAT => 'parquet_ff'
            )
        )
        """
    )
    schema_array = cursor.fetchone()[0]
    if isinstance(schema_array, str):
        schema_array = json.loads(schema_array)

    ### This was unnecessary when loading prod data ###
    # Whether to rewrite this table's inferred VARIANT columns as BINARY
    # convert_binary_to_variant = data_type in TABLES_WITH_VARIANT_COLUMNS
    convert_binary_to_variant = False

    # Build column definitions manually
    column_defs = []
    for column_def in schema_array:
        col_name = column_def.get("COLUMN_NAME") or column_def.get("column_name")
        col_type = column_def.get("TYPE") or column_def.get("type")
        nullable = column_def.get("NULLABLE", column_def.get("nullable", True))
        # Rewrite inferred VARIANT columns as BINARY when the override is enabled
        if col_type == "VARIANT" and convert_binary_to_variant:
            col_type = "BINARY"
        # Build column definition
        null_clause = "" if nullable else " NOT NULL"
        column_defs.append(f'"{col_name}" {col_type}{null_clause}')
    columns_sql = ",\n        ".join(column_defs)
    create_sql = f"""
    CREATE OR REPLACE TABLE {data_type} (
        {columns_sql}
    )
    """
    return create_sql
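
# For illustration only: with data_type="NODE" and a hypothetical inferred schema
# of an ID integer column and a nullable NAME text column, the generated DDL
# would look roughly like:
#
#   CREATE OR REPLACE TABLE NODE (
#       "ID" NUMBER(38, 0) NOT NULL,
#       "NAME" TEXT
#   )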


def copy_data_into_table(cursor, data_type: str, stage_path: str, stage_name: str) -> str:
    """
    Generate COPY INTO SQL to load data from stage into table.

    The statement is only generated here; it is executed by the caller
    (see process_prefix).

    Args:
        cursor: Snowflake cursor for executing queries
        data_type: The table name to copy data into
        stage_path: The stage path containing the parquet files
        stage_name: The Snowflake stage name

    Returns:
        The SQL statement that was generated
    """
    copy_sql = f"""
    COPY INTO {data_type}
    FROM @{stage_name}/{stage_path}
    FILE_FORMAT = (TYPE = PARQUET BINARY_AS_TEXT = FALSE)
    MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE
    PATTERN = '^.*\\.parquet$'
    """
    return copy_sql
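
# For example, for data_type="NODE" under the default prefix base and stage,
# the generated statement is:
#
#   COPY INTO NODE
#   FROM @RDS_SNAPSHOT_POC_STAGE/dev566.NODE/1/
#   FILE_FORMAT = (TYPE = PARQUET BINARY_AS_TEXT = FALSE)
#   MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE
#   PATTERN = '^.*\.parquet$'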


def setup_temp_tables(cursor):
    """
    Create the logging table needed for the snapshot loading process.

    Creates:
    - LOAD_LOG: Logs all operations and errors

    Args:
        cursor: Snowflake cursor for executing queries
    """
    print("Setting up logging table...")
    # Create LOAD_LOG table
    cursor.execute(
        """
        CREATE OR REPLACE TABLE LOAD_LOG (
            PREFIX STRING,
            DATA_TYPE STRING,
            STAGE_PATH STRING,
            PHASE STRING,           -- e.g. 'START', 'CREATE_TABLE', 'COPY_DATA'
            STATUS STRING,          -- e.g. 'OK', 'RUN', 'FAIL'
            SQL_TEXT STRING,        -- SQL we attempted to run (if applicable)
            ERROR_MESSAGE STRING,   -- populated on failure
            LOG_TS TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP()
        )
        """
    )
    print("  ✓ Created LOAD_LOG table")


def list_prefixes_from_stage(cursor, prefix_base: str, stage_name: str) -> list:
    """
    List all prefixes under the given prefix_base from the stage.

    Queries the stage to find all directories matching the pattern:
    {prefix_base}/{prefix_base}.{data_type}/1/

    Args:
        cursor: Snowflake cursor for executing queries
        prefix_base: Base prefix to search under (e.g. 'dev566')
        stage_name: The Snowflake stage name

    Returns:
        List of prefix strings formatted as '{prefix_base}/{prefix_base}.{data_type}/1/'
    """
    print(f"Listing prefixes under {prefix_base}/ in stage...")

    # List all files/directories in the stage under prefix_base
    cursor.execute(
        f"""
        LIST @{stage_name}/{prefix_base}
        """
    )
    results = cursor.fetchall()

    prefixes = set()
    # Parse the results to extract unique prefix patterns
    # Results from LIST contain columns: name, size, md5, last_modified
    for row in results:
        file_path = row[0]  # The 'name' column
        # Extract the prefix pattern: {prefix_base}/{prefix_base}.{data_type}/1/
        # Example file_path: 's3://synapse-rds-snapshots-dev/test-export/dev566/dev566.NODE/1/part-00000-dcd6d72f-8ee9-400c-94ce-f1f66644c5d3-c000.gz.parquet'
        # We want to extract: 'dev566/dev566.NODE/1/'
        # The file_path from LIST includes the full S3 path
        # Find the prefix_base in the path and extract from there
        if f"/{prefix_base}/{prefix_base}." in file_path:
            # Find where our prefix pattern starts
            idx = file_path.find(f"/{prefix_base}/{prefix_base}.")
            # Extract everything after the leading slash
            relevant_path = file_path[idx + 1:]
            # Split and reconstruct the prefix pattern
            parts = relevant_path.split('/')
            if len(parts) >= 3 and parts[2] == '1':
                # Reconstruct the prefix: prefix_base/prefix_base.DATA_TYPE/1/
                prefix = f"{parts[0]}/{parts[1]}/{parts[2]}/"
                prefixes.add(prefix)

    prefix_list = sorted(list(prefixes))
    print(f"  ✓ Found {len(prefix_list)} unique prefixes")
    return prefix_list


def process_prefix(
    cursor, prefix: str, prefix_base: str, stage_name: str
) -> Optional[bool]:
    """
    Process a single prefix: derive data type, create table, and log operations.

    Args:
        cursor: Snowflake cursor for executing queries
        prefix: The prefix to process (listed from stage)
        prefix_base: Base prefix used to derive the stage path and table name
        stage_name: The Snowflake stage name

    Returns:
        True if processing succeeded, False if an error occurred
    """
    data_type = None
    stage_path = None
    current_phase = "INIT"
    current_sql = None
    try:
        # Derive data type from prefix
        current_phase = "PARSE_PREFIX"
        # derive using the provided prefix_base
        data_type = derive_data_type(prefix, prefix_base=prefix_base)
        stage_path = derive_stage_path(data_type, prefix_base=prefix_base)

        # Log: start processing this prefix
        log_operation(cursor, prefix, data_type, stage_path, "START", "OK")

        # Generate CREATE TABLE SQL
        create_sql = create_table_from_schema(cursor, data_type, stage_path, stage_name)

        # Log that we're about to run CREATE TABLE
        log_operation(
            cursor, prefix, data_type, stage_path, "CREATE_TABLE", "RUN", create_sql
        )

        # Execute CREATE TABLE
        current_phase = "CREATE_TABLE"
        current_sql = create_sql
        cursor.execute(create_sql)

        # Log success
        log_operation(cursor, prefix, data_type, stage_path, "CREATE_TABLE", "OK")

        # Generate COPY INTO SQL
        copy_sql = copy_data_into_table(cursor, data_type, stage_path, stage_name)

        # Log that we're about to run COPY INTO
        log_operation(
            cursor, prefix, data_type, stage_path, "COPY_DATA", "RUN", copy_sql
        )

        # Execute COPY INTO
        current_phase = "COPY_DATA"
        current_sql = copy_sql
        cursor.execute(copy_sql)

        # Log success
        log_operation(cursor, prefix, data_type, stage_path, "COPY_DATA", "OK")
        return True
    except Exception as e:
        # Log error with full details
        error_msg = f"{type(e).__name__}: {str(e)}"
        # Use the tracked phase for accurate error reporting
        log_operation(
            cursor,
            prefix,
            data_type if data_type else "UNKNOWN",
            stage_path if stage_path else "UNKNOWN",
            current_phase,
            "FAIL",
            sql_text=current_sql,
            error_msg=error_msg,
        )
        print(
            f"Error in {current_phase} for prefix {prefix}: {error_msg}",
            file=sys.stderr,
        )
        # Don't re-raise - continue processing other prefixes
        return False


def load_snapshot_data(prefix_base: str, stage_name: str, database=None, schema=None):
    """
    Main function to load snapshot data from stage into tables.

    Args:
        prefix_base: Base prefix string used to derive stage paths. This value
            is supplied from the CLI `--prefix-base` argument and defaults to
            'dev566' there.
        stage_name: The Snowflake stage name (from CLI `--stage-name` argument)
        database: Optional database name to use
        schema: Optional schema name to use
    """
    # Connect to Snowflake; connection parameters (account, user, auth, etc.) are
    # expected to come from the connector's default configuration, not this script
    conn = snowflake.connector.connect()
    cursor = conn.cursor()
    try:
        # Set database and schema if provided
        if database:
            cursor.execute(f"USE DATABASE {database}")
        if schema:
            cursor.execute(f"USE SCHEMA {schema}")

        # Set up the logging table (LOAD_LOG)
        setup_temp_tables(cursor)
        conn.commit()

        # List all prefixes from the stage under prefix_base
        prefixes = list_prefixes_from_stage(cursor, prefix_base, stage_name)
        print(f"Processing {len(prefixes)} prefixes...")

        # Track successes and failures
        success_count = 0
        failure_count = 0

        # Process each prefix
        for prefix in prefixes:
            print(f"Processing prefix: {prefix}")
            result = process_prefix(cursor, prefix, prefix_base, stage_name)
            conn.commit()  # Commit after each prefix
            if result is True:
                success_count += 1
            elif result is False:
                failure_count += 1

        print(f"\n{'='*60}")
        print("Processing complete!")
        print(f"  ✓ Successful: {success_count}")
        print(f"  ✗ Failed: {failure_count}")
        print(f"  Total: {len(prefixes)}")
        print(f"{'='*60}")
    except Exception as e:
        print(f"Fatal error: {e}", file=sys.stderr)
        conn.rollback()
        raise
    finally:
        cursor.close()
        conn.close()
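
# Programmatic use is also possible; the database and schema names below are
# illustrative placeholders, not values defined by this script:
#
#   load_snapshot_data(
#       prefix_base="dev566",
#       stage_name="RDS_SNAPSHOT_POC_STAGE",
#       database="MY_DATABASE",   # hypothetical
#       schema="MY_SCHEMA",       # hypothetical
#   )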


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Load snapshot data from Snowflake stage"
    )
    parser.add_argument(
        "--prefix-base",
        dest="prefix_base",
        default="dev566",
        help="Base S3 prefix (relative to the Snowflake stage) of Parquet data (default: dev566)",
    )
    parser.add_argument(
        "--stage-name",
        dest="stage_name",
        default="RDS_SNAPSHOT_POC_STAGE",
        help="Snowflake stage name (default: RDS_SNAPSHOT_POC_STAGE)",
    )
    args = parser.parse_args()

    load_snapshot_data(
        prefix_base=args.prefix_base,
        stage_name=args.stage_name
    )