@philerooski
Created December 3, 2025 22:54
An updated version of the script used to load RDS snapshot data into Snowflake
"""
Load snapshot data from Snowflake stage into tables.
This script dynamically lists all prefixes under the specified prefix_base
from the stage, derives table names, creates tables using INFER_SCHEMA,
and logs all operations to LOAD_LOG.
See `python load_snapshot_data.py --help`
"""
import json
import sys
from typing import Optional
import snowflake.connector
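# Example invocation (a sketch; assumes Snowflake connection parameters are picked up
# from the connector's own configuration, since connect() below takes no arguments):
#
#   python load_snapshot_data.py --prefix-base dev566 --stage-name RDS_SNAPSHOT_POC_STAGE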
# Tables that have a schema inferred as BINARY, but ought to be VARIANT
TABLES_WITH_VARIANT_COLUMNS = {
"ASYNCH_JOB_STATUS",
"CHALLENGE_TEAM",
"DATA_ACCESS_SUBMISSION_STATUS",
"DISCUSSION_THREAD",
"DOWNLOAD_ORDER",
"EVALUATION",
"EVALUATION_SUBMISSION",
"MESSAGE_TO_USER",
"MULTIPART_UPLOAD",
"MULTIPART_UPLOAD_PART_STATE",
"OAUTH_AUTHORIZATION_CODE",
"QUIZ_RESPONSE",
"RESEARCH_PROJECT",
"STATISTICS_MONTHLY_STATUS",
"SUBSTATUS_ANNOTATIONS_BLOB",
"TABLE_STATUS",
"USER_GROUP",
"V2_WIKI_MARKDOWN",
"V2_WIKI_OWNERS",
"VERIFICATION_STATE",
}
# Tables that have actual binary/compressed data and should keep BINARY columns
TABLES_WITH_TRUE_BINARY_DATA = {
"ACCESS_REQUIREMENT_REVISION",
"ACTIVITY",
"CHALLENGE",
"DATA_ACCESS_REQUEST",
"DATA_ACCESS_SUBMISSION",
"MEMBERSHIP_INVITATION_SUBMISSION",
"MEMBERSHIP_REQUEST_SUBMISSION",
"NODE_REVISION",
"PERSONAL_ACCESS_TOKEN",
"TEAM",
"USER_PROFILE",
"VERIFICATION_SUBMISSION",
}
def derive_data_type(prefix: str, prefix_base: str) -> str:
"""
Derive data type (table name) from prefix.
Example: 'dev566/dev566.NODE/1/' -> 'NODE'
Args:
prefix: The prefix string to parse (e.g. from `s3_prefixes.txt`)
prefix_base: The base prefix used in S3 keys/stage paths (e.g. 'dev566')
Returns:
The data type/table name extracted from the prefix
"""
# Split by '<prefix_base>.' and take the second part
parts = prefix.split(f"{prefix_base}.", 1)
if len(parts) < 2:
raise ValueError(f"Invalid prefix format: {prefix}")
# Split by '/' and take the first part
data_type = parts[1].split("/")[0]
return data_type
def derive_stage_path(data_type: str, prefix_base: str) -> str:
"""
Derive stage path from data type.
Args:
data_type: The data type/table name
prefix_base: The base prefix used in S3 keys/stage paths
Returns:
Stage path string formed as '<prefix_base>.<DATA_TYPE>/1/'
"""
return f"{prefix_base}.{data_type}/1/"
def log_operation(
cursor,
prefix: str,
data_type: str,
stage_path: str,
phase: str,
status: str,
sql_text: Optional[str] = None,
error_msg: Optional[str] = None,
):
"""
Log an operation to the LOAD_LOG table.
Args:
cursor: Snowflake cursor for executing queries
prefix: The prefix being processed
data_type: The derived data type
stage_path: The stage path
        phase: The current phase (e.g. START, PARSE_PREFIX, CREATE_TABLE, COPY_DATA)
status: Status of the operation (OK, RUN, FAIL)
sql_text: Optional SQL text being executed
error_msg: Optional error message
"""
log_sql = """
INSERT INTO LOAD_LOG (PREFIX, DATA_TYPE, STAGE_PATH, PHASE, STATUS, SQL_TEXT, ERROR_MESSAGE)
VALUES (%s, %s, %s, %s, %s, %s, %s)
"""
cursor.execute(
log_sql, (prefix, data_type, stage_path, phase, status, sql_text, error_msg)
)
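# Example call, matching how process_prefix logs a successful copy
# (the prefix and paths are illustrative):
#   log_operation(cursor, "dev566/dev566.NODE/1/", "NODE", "dev566.NODE/1/", "COPY_DATA", "OK")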
def create_table_from_schema(cursor, data_type: str, stage_path: str, stage_name: str) -> str:
"""
Generate CREATE TABLE SQL using INFER_SCHEMA.
    For tables in TABLES_WITH_VARIANT_COLUMNS, columns inferred as BINARY can be
    converted to VARIANT; that conversion is currently disabled (see the
    convert_binary_to_variant flag below). Tables in TABLES_WITH_TRUE_BINARY_DATA
    keep their BINARY columns to preserve binary/compressed data.
Args:
cursor: Snowflake cursor for executing queries
data_type: The table name to create
stage_path: The stage path containing the data files
stage_name: The Snowflake stage name
Returns:
The SQL statement that was generated
"""
# First, infer the schema
cursor.execute(
f"""
SELECT ARRAY_AGG(OBJECT_CONSTRUCT(*))
FROM TABLE(
INFER_SCHEMA(
LOCATION => '@{stage_name}/{stage_path}',
FILE_FORMAT => 'parquet_ff'
)
)
"""
)
schema_array = cursor.fetchone()[0]
if isinstance(schema_array, str):
schema_array = json.loads(schema_array)
    ### This was unnecessary when loading prod data ###
# Check if this table should convert BINARY to VARIANT
# convert_binary_to_variant = data_type in TABLES_WITH_VARIANT_COLUMNS
convert_binary_to_variant = False
# Build column definitions manually
column_defs = []
for column_def in schema_array:
col_name = column_def.get("COLUMN_NAME") or column_def.get("column_name")
col_type = column_def.get("TYPE") or column_def.get("type")
        # Nested defaults keep an explicit NULLABLE = False from being overridden
        nullable = column_def.get("NULLABLE", column_def.get("nullable", True))
        # Convert BINARY to VARIANT for tables whose blob columns should be queryable JSON
        if col_type == "BINARY" and convert_binary_to_variant:
            col_type = "VARIANT"
# Build column definition
null_clause = "" if nullable else " NOT NULL"
column_defs.append(f'"{col_name}" {col_type}{null_clause}')
columns_sql = ",\n ".join(column_defs)
create_sql = f"""
CREATE OR REPLACE TABLE {data_type} (
{columns_sql}
)
"""
return create_sql
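# Illustrative result for a hypothetical NODE table (actual column names and types
# come from INFER_SCHEMA over the staged parquet files, not from this sketch):
#   CREATE OR REPLACE TABLE NODE (
#       "ID" NUMBER(38, 0) NOT NULL,
#       "NAME" TEXT,
#       "CREATED_ON" TIMESTAMP_NTZ
#   )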
def copy_data_into_table(cursor, data_type: str, stage_path: str, stage_name: str) -> str:
"""
    Generate COPY INTO SQL to load data from the stage into the table; the caller executes it.
Args:
cursor: Snowflake cursor for executing queries
data_type: The table name to copy data into
stage_path: The stage path containing the parquet files
stage_name: The Snowflake stage name
Returns:
The SQL statement that was generated
"""
copy_sql = f"""
COPY INTO {data_type}
FROM @{stage_name}/{stage_path}
FILE_FORMAT = (TYPE = PARQUET BINARY_AS_TEXT = FALSE)
MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE
PATTERN = '^.*\\.parquet$'
"""
return copy_sql
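# Illustrative result for data_type="NODE" with the default stage name
# (again a sketch, not output from a live run):
#   COPY INTO NODE
#   FROM @RDS_SNAPSHOT_POC_STAGE/dev566.NODE/1/
#   FILE_FORMAT = (TYPE = PARQUET BINARY_AS_TEXT = FALSE)
#   MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE
#   PATTERN = '^.*\.parquet$'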
def setup_temp_tables(cursor):
"""
    Create the logging table needed for the snapshot loading process.
Creates:
- LOAD_LOG: Logs all operations and errors
Args:
cursor: Snowflake cursor for executing queries
"""
print("Setting up logging table...")
# Create LOAD_LOG table
cursor.execute(
"""
CREATE OR REPLACE TABLE LOAD_LOG (
PREFIX STRING,
DATA_TYPE STRING,
STAGE_PATH STRING,
        PHASE STRING, -- e.g. 'START', 'PARSE_PREFIX', 'CREATE_TABLE', 'COPY_DATA'
        STATUS STRING, -- e.g. 'OK', 'RUN', 'FAIL'
SQL_TEXT STRING, -- SQL we attempted to run (if applicable)
ERROR_MESSAGE STRING, -- populated on failure
LOG_TS TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP()
)
"""
)
print(" ✓ Created LOAD_LOG table")
def list_prefixes_from_stage(cursor, prefix_base: str, stage_name: str) -> list:
"""
List all prefixes under the given prefix_base from the stage.
Queries the stage to find all directories matching the pattern:
{prefix_base}/{prefix_base}.{data_type}/1/
Args:
cursor: Snowflake cursor for executing queries
prefix_base: Base prefix to search under (e.g. 'dev566')
stage_name: The Snowflake stage name
Returns:
List of prefix strings formatted as '{prefix_base}/{prefix_base}.{data_type}/1/'
"""
print(f"Listing prefixes under {prefix_base}/ in stage...")
# List all files/directories in the stage under prefix_base
cursor.execute(
f"""
LIST @{stage_name}/{prefix_base}
"""
)
results = cursor.fetchall()
prefixes = set()
# Parse the results to extract unique prefix patterns
# Results from LIST contain columns: name, size, md5, last_modified
for row in results:
file_path = row[0] # The 'name' column
# Extract the prefix pattern: {prefix_base}/{prefix_base}.{data_type}/1/
# Example file_path: 's3://synapse-rds-snapshots-dev/test-export/dev566/dev566.NODE/1/part-00000-dcd6d72f-8ee9-400c-94ce-f1f66644c5d3-c000.gz.parquet'
# We want to extract: 'dev566/dev566.NODE/1/'
# The file_path from LIST includes the full S3 path
# Find the prefix_base in the path and extract from there
if f"/{prefix_base}/{prefix_base}." in file_path:
# Find where our prefix pattern starts
idx = file_path.find(f"/{prefix_base}/{prefix_base}.")
# Extract everything after the leading slash
relevant_path = file_path[idx + 1:]
# Split and reconstruct the prefix pattern
parts = relevant_path.split('/')
if len(parts) >= 3 and parts[2] == '1':
# Reconstruct the prefix: prefix_base/prefix_base.DATA_TYPE/1/
prefix = f"{parts[0]}/{parts[1]}/{parts[2]}/"
prefixes.add(prefix)
prefix_list = sorted(list(prefixes))
print(f" ✓ Found {len(prefix_list)} unique prefixes")
return prefix_list
def process_prefix(
cursor, prefix: str, prefix_base: str, stage_name: str
) -> bool:
"""
Process a single prefix: derive data type, create table, and log operations.
Args:
cursor: Snowflake cursor for executing queries
prefix: The prefix to process (listed from stage)
prefix_base: Base prefix used to derive the stage path and table name
stage_name: The Snowflake stage name
Returns:
True if processing succeeded, False if an error occurred
"""
data_type = None
stage_path = None
current_phase = "INIT"
current_sql = None
try:
# Derive data type from prefix
current_phase = "PARSE_PREFIX"
# derive using the provided prefix_base
data_type = derive_data_type(prefix, prefix_base=prefix_base)
stage_path = derive_stage_path(data_type, prefix_base=prefix_base)
# Log: start processing this prefix
log_operation(cursor, prefix, data_type, stage_path, "START", "OK")
# Generate CREATE TABLE SQL
create_sql = create_table_from_schema(cursor, data_type, stage_path, stage_name)
# Log that we're about to run CREATE TABLE
log_operation(
cursor, prefix, data_type, stage_path, "CREATE_TABLE", "RUN", create_sql
)
# Execute CREATE TABLE
current_phase = "CREATE_TABLE"
current_sql = create_sql
cursor.execute(create_sql)
# Log success
log_operation(cursor, prefix, data_type, stage_path, "CREATE_TABLE", "OK")
# Generate COPY INTO SQL
copy_sql = copy_data_into_table(cursor, data_type, stage_path, stage_name)
# Log that we're about to run COPY INTO
log_operation(
cursor, prefix, data_type, stage_path, "COPY_DATA", "RUN", copy_sql
)
# Execute COPY INTO
current_phase = "COPY_DATA"
current_sql = copy_sql
cursor.execute(copy_sql)
# Log success
log_operation(cursor, prefix, data_type, stage_path, "COPY_DATA", "OK")
return True
except Exception as e:
# Log error with full details
error_msg = f"{type(e).__name__}: {str(e)}"
# Use the tracked phase for accurate error reporting
log_operation(
cursor,
prefix,
data_type if data_type else "UNKNOWN",
stage_path if stage_path else "UNKNOWN",
current_phase,
"FAIL",
sql_text=current_sql,
error_msg=error_msg,
)
print(
f"Error in {current_phase} for prefix {prefix}: {error_msg}",
file=sys.stderr,
)
# Don't re-raise - continue processing other prefixes
return False
def load_snapshot_data(prefix_base: str, stage_name: str, database=None, schema=None):
"""
Main function to load snapshot data from stage into tables.
Args:
prefix_base: Base prefix string used to derive stage paths. This value
is supplied from the CLI `--prefix-base` argument and defaults to
'dev566' there.
stage_name: The Snowflake stage name (from CLI `--stage-name` argument)
database: Optional database name to use
schema: Optional schema name to use
"""
    # Connect to Snowflake; no arguments are passed, so connection parameters are
    # expected to come from the connector's own configuration (e.g. a default
    # connection in connections.toml)
    conn = snowflake.connector.connect()
cursor = conn.cursor()
try:
# Set database and schema if provided
if database:
cursor.execute(f"USE DATABASE {database}")
if schema:
cursor.execute(f"USE SCHEMA {schema}")
# Set up temporary tables (LOAD_LOG)
setup_temp_tables(cursor)
conn.commit()
# List all prefixes from the stage under prefix_base
prefixes = list_prefixes_from_stage(cursor, prefix_base, stage_name)
print(f"Processing {len(prefixes)} prefixes...")
# Track successes and failures
success_count = 0
failure_count = 0
# Process each prefix
for prefix in prefixes:
print(f"Processing prefix: {prefix}")
result = process_prefix(cursor, prefix, prefix_base, stage_name)
conn.commit() # Commit after each prefix
if result is True:
success_count += 1
elif result is False:
failure_count += 1
print(f"\n{'='*60}")
print("Processing complete!")
print(f" ✓ Successful: {success_count}")
print(f" ✗ Failed: {failure_count}")
print(f" Total: {len(prefixes)}")
print(f"{'='*60}")
except Exception as e:
print(f"Fatal error: {e}", file=sys.stderr)
conn.rollback()
raise
finally:
cursor.close()
conn.close()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="Load snapshot data from Snowflake stage"
)
parser.add_argument(
"--prefix-base",
dest="prefix_base",
default="dev566",
help="Base S3 prefix (relative to the Snowflake stage) of Parquet data (default: dev566)",
)
parser.add_argument(
"--stage-name",
dest="stage_name",
default="RDS_SNAPSHOT_POC_STAGE",
help="Snowflake stage name (default: RDS_SNAPSHOT_POC_STAGE)",
)
args = parser.parse_args()
load_snapshot_data(
prefix_base=args.prefix_base,
stage_name=args.stage_name
)