@philerooski
Created November 20, 2025 23:32
"""
Load snapshot data from a Snowflake stage into tables.
This script processes prefixes from the PREFIX_LIST table, derives table names,
creates tables using INFER_SCHEMA, and logs all operations to LOAD_LOG.
See `--help` for the optional `--only-affected` flag.
"""
import json
import sys
from typing import Optional

import snowflake.connector

# Tables whose BINARY columns actually contain text/JSON data and should be
# recreated with VARIANT (excludes tables with true binary data, which must stay BINARY)
TABLES_WITH_BINARY_COLUMNS = {
"ASYNCH_JOB_STATUS",
"CHALLENGE_TEAM",
"DATA_ACCESS_SUBMISSION_STATUS",
"DISCUSSION_THREAD",
"DOWNLOAD_ORDER",
"EVALUATION",
"EVALUATION_SUBMISSION",
"MESSAGE_TO_USER",
"MULTIPART_UPLOAD",
"MULTIPART_UPLOAD_PART_STATE",
"OAUTH_AUTHORIZATION_CODE",
"QUIZ_RESPONSE",
"RESEARCH_PROJECT",
"STATISTICS_MONTHLY_STATUS",
"SUBSTATUS_ANNOTATIONS_BLOB",
"TABLE_STATUS",
"USER_GROUP",
"V2_WIKI_MARKDOWN",
"V2_WIKI_OWNERS",
"VERIFICATION_STATE",
}
# Tables that have actual binary/compressed data and should keep BINARY columns
TABLES_WITH_TRUE_BINARY_DATA = {
"ACCESS_REQUIREMENT_REVISION",
"ACTIVITY",
"CHALLENGE",
"DATA_ACCESS_REQUEST",
"DATA_ACCESS_SUBMISSION",
"MEMBERSHIP_INVITATION_SUBMISSION",
"MEMBERSHIP_REQUEST_SUBMISSION",
"NODE_REVISION",
"PERSONAL_ACCESS_TOKEN",
"TEAM",
"USER_PROFILE",
"VERIFICATION_SUBMISSION",
}
# Combined set of all affected tables
ALL_AFFECTED_TABLES = TABLES_WITH_BINARY_COLUMNS | TABLES_WITH_TRUE_BINARY_DATA
def derive_data_type(prefix: str) -> str:
"""
Derive data type (table name) from prefix.
Example: 'dev566/dev566.NODE/1/' -> 'NODE'
Args:
prefix: The prefix string to parse
Returns:
The data type extracted from the prefix
"""
# Split by 'dev566.' and take the second part
parts = prefix.split("dev566.", 1)
if len(parts) < 2:
raise ValueError(f"Invalid prefix format: {prefix}")
# Split by '/' and take the first part
data_type = parts[1].split("/")[0]
return data_type
def derive_stage_path(data_type: str) -> str:
"""
Derive stage path from data type.
Args:
data_type: The data type/table name
Returns:
Stage path: dev566.<DATA_TYPE>/1/
"""
return f"dev566.{data_type}/1/"
def log_operation(
cursor,
prefix: str,
data_type: str,
stage_path: str,
phase: str,
status: str,
sql_text: Optional[str] = None,
error_msg: Optional[str] = None,
):
"""
Log an operation to the LOAD_LOG table.
Args:
cursor: Snowflake cursor for executing queries
prefix: The prefix being processed
data_type: The derived data type
stage_path: The stage path
phase: The current phase (e.g. START, CREATE_TABLE, COPY_DATA)
status: Status of the operation (OK, RUN, FAIL)
sql_text: Optional SQL text being executed
error_msg: Optional error message
"""
log_sql = """
INSERT INTO LOAD_LOG (PREFIX, DATA_TYPE, STAGE_PATH, PHASE, STATUS, SQL_TEXT, ERROR_MESSAGE)
VALUES (%s, %s, %s, %s, %s, %s, %s)
"""
cursor.execute(
log_sql, (prefix, data_type, stage_path, phase, status, sql_text, error_msg)
)
def create_table_from_schema(cursor, data_type: str, stage_path: str) -> str:
"""
Generate CREATE TABLE SQL using INFER_SCHEMA, replacing BINARY with VARIANT
for tables that do not hold true binary data.
Args:
cursor: Snowflake cursor for executing queries
data_type: The table name to create
stage_path: The stage path containing the data files
Returns:
The SQL statement that was generated
"""
# First, infer the schema
cursor.execute(
f"""
SELECT ARRAY_AGG(OBJECT_CONSTRUCT(*))
FROM TABLE(
INFER_SCHEMA(
LOCATION => '@RDS_SNAPSHOT_POC_STAGE/{stage_path}',
FILE_FORMAT => 'parquet_ff'
)
)
"""
)
schema_array = cursor.fetchone()[0]
# Replace BINARY with VARIANT in the schema (only for tables that don't have true binary data)
if isinstance(schema_array, str):
schema_array = json.loads(schema_array)
# Check if this table should keep BINARY columns or convert to VARIANT
keep_binary = data_type in TABLES_WITH_TRUE_BINARY_DATA
# Build column definitions manually
column_defs = []
for column_def in schema_array:
col_name = column_def.get("COLUMN_NAME") or column_def.get("column_name")
col_type = column_def.get("TYPE") or column_def.get("type")
nullable = column_def.get("NULLABLE") or column_def.get("nullable", True)
# Replace BINARY with VARIANT only if this table doesn't have true binary data
if col_type == "BINARY" and not keep_binary:
col_type = "VARIANT"
# Build column definition
null_clause = "" if nullable else " NOT NULL"
column_defs.append(f'"{col_name}" {col_type}{null_clause}')
columns_sql = ",\n ".join(column_defs)
create_sql = f"""
CREATE OR REPLACE TABLE {data_type} (
{columns_sql}
)
"""
return create_sql
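# Illustrative shape of the generated DDL (the column names/types below are made up;
# the real ones come from INFER_SCHEMA over the staged parquet files):
#
#   CREATE OR REPLACE TABLE NODE (
#       "ID" NUMBER(38, 0) NOT NULL,
#       "NAME" TEXT,
#       "ANNOTATIONS" VARIANT  -- inferred as BINARY, rewritten to VARIANT
#   )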
def copy_data_into_table(cursor, data_type: str, stage_path: str) -> str:
"""
Generate the COPY INTO SQL to load data from stage into the table (the caller executes it).
Args:
cursor: Snowflake cursor for executing queries
data_type: The table name to copy data into
stage_path: The stage path containing the parquet files
Returns:
The SQL statement that was generated
"""
copy_sql = f"""
COPY INTO {data_type}
FROM @RDS_SNAPSHOT_POC_STAGE/{stage_path}
FILE_FORMAT = (TYPE = PARQUET BINARY_AS_TEXT = FALSE)
MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE
PATTERN = '^.*\\.parquet$'
"""
return copy_sql
def setup_temp_tables(cursor):
"""
Create the working tables needed for the snapshot loading process.
Creates:
- LOAD_LOG: regular table that logs all operations and errors
- PREFIX_LIST: temporary table holding the list of S3 prefixes to process
Args:
cursor: Snowflake cursor for executing queries
"""
print("Setting up temporary tables...")
# Create LOAD_LOG table
cursor.execute(
"""
CREATE OR REPLACE TABLE LOAD_LOG (
PREFIX STRING,
DATA_TYPE STRING,
STAGE_PATH STRING,
PHASE STRING, -- e.g. 'START', 'CREATE_TABLE', 'COPY_DATA'
STATUS STRING, -- 'OK', 'RUN', or 'FAIL'
SQL_TEXT STRING, -- SQL we attempted to run (if applicable)
ERROR_MESSAGE STRING, -- populated on failure
LOG_TS TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP()
)
"""
)
print(" ✓ Created LOAD_LOG table")
# Create PREFIX_LIST table
cursor.execute(
"""
CREATE OR REPLACE TEMP TABLE PREFIX_LIST (PREFIX STRING)
"""
)
print(" ✓ Created PREFIX_LIST table")
# Load prefixes from stage
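# s3_prefixes.txt is expected to contain one prefix per line (e.g. "dev566/dev566.NODE/1/");
# FIELD_DELIMITER = NONE makes Snowflake load each whole line as a single column value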
cursor.execute(
"""
COPY INTO PREFIX_LIST
FROM @RDS_SNAPSHOT_POC_STAGE/s3_prefixes.txt
FILE_FORMAT = (
TYPE = CSV
FIELD_DELIMITER = NONE
SKIP_HEADER = 0
)
"""
)
# Get count of loaded prefixes
cursor.execute("SELECT COUNT(*) FROM PREFIX_LIST")
prefix_count = cursor.fetchone()[0]
print(f" ✓ Loaded {prefix_count} prefixes from s3_prefixes.txt")
def process_prefix(
cursor, prefix: str, only_affected_tables: bool = False
) -> Optional[bool]:
"""
Process a single prefix: derive data type, create table, and log operations.
Args:
cursor: Snowflake cursor for executing queries
prefix: The prefix to process
only_affected_tables: If True, only process tables in ALL_AFFECTED_TABLES
Returns:
True if processing succeeded, False if an error occurred, None if skipped
"""
data_type = None
stage_path = None
current_phase = "INIT"
current_sql = None
try:
# Derive data type from prefix
current_phase = "PARSE_PREFIX"
data_type = derive_data_type(prefix)
stage_path = derive_stage_path(data_type)
# Skip if only processing affected tables and this isn't one of them
if only_affected_tables and data_type not in ALL_AFFECTED_TABLES:
print(f" ⊘ Skipping {data_type} (not in affected tables list)")
return None
# Log: start processing this prefix
log_operation(cursor, prefix, data_type, stage_path, "START", "OK")
# Generate CREATE TABLE SQL
create_sql = create_table_from_schema(cursor, data_type, stage_path)
# Log that we're about to run CREATE TABLE
log_operation(
cursor, prefix, data_type, stage_path, "CREATE_TABLE", "RUN", create_sql
)
# Execute CREATE TABLE
current_phase = "CREATE_TABLE"
current_sql = create_sql
cursor.execute(create_sql)
# Log success
log_operation(cursor, prefix, data_type, stage_path, "CREATE_TABLE", "OK")
# Generate COPY INTO SQL
copy_sql = copy_data_into_table(cursor, data_type, stage_path)
# Log that we're about to run COPY INTO
log_operation(
cursor, prefix, data_type, stage_path, "COPY_DATA", "RUN", copy_sql
)
# Execute COPY INTO
current_phase = "COPY_DATA"
current_sql = copy_sql
cursor.execute(copy_sql)
# Log success
log_operation(cursor, prefix, data_type, stage_path, "COPY_DATA", "OK")
return True
except Exception as e:
# Log error with full details
error_msg = f"{type(e).__name__}: {str(e)}"
# Use the tracked phase for accurate error reporting
log_operation(
cursor,
prefix,
data_type if data_type else "UNKNOWN",
stage_path if stage_path else "UNKNOWN",
current_phase,
"FAIL",
sql_text=current_sql,
error_msg=error_msg,
)
print(
f"Error in {current_phase} for prefix {prefix}: {error_msg}",
file=sys.stderr,
)
# Don't re-raise - continue processing other prefixes
return False
def load_snapshot_data(database=None, schema=None, only_affected_tables=False):
"""
Main function to load snapshot data from stage into tables.
Args:
database: Optional database name to use
schema: Optional schema name to use
only_affected_tables: If True, only process tables that had BINARY columns
"""
# Connect to Snowflake
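# connect() is called with no arguments, so account/credentials are assumed to come
# from the connector's default configuration (e.g. a default connection in
# connections.toml or equivalent environment-based settings)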
conn = snowflake.connector.connect()
cursor = conn.cursor()
try:
# Set database and schema if provided
if database:
cursor.execute(f"USE DATABASE {database}")
if schema:
cursor.execute(f"USE SCHEMA {schema}")
# Set up temporary tables (LOAD_LOG and PREFIX_LIST)
setup_temp_tables(cursor)
conn.commit()
# Fetch all distinct prefixes
cursor.execute(
"""
SELECT DISTINCT PREFIX
FROM PREFIX_LIST
WHERE PREFIX IS NOT NULL
ORDER BY PREFIX
"""
)
prefixes = cursor.fetchall()
if only_affected_tables:
print(
f"\n⚠️ AFFECTED TABLES MODE: Only processing {len(ALL_AFFECTED_TABLES)} tables that had BINARY columns"
)
print(
f" - {len(TABLES_WITH_BINARY_COLUMNS)} tables will convert BINARY → VARIANT (text data)"
)
print(
f" - {len(TABLES_WITH_TRUE_BINARY_DATA)} tables will keep BINARY columns (binary data)"
)
print(
f"Processing {len(prefixes)} prefixes (will skip non-affected tables)...\n"
)
else:
print(f"Processing {len(prefixes)} prefixes...")
# Track successes, failures, and skips
success_count = 0
failure_count = 0
skipped_count = 0
# Process each prefix
for (prefix,) in prefixes:
print(f"Processing prefix: {prefix}")
result = process_prefix(cursor, prefix, only_affected_tables)
conn.commit() # Commit after each prefix
if result is True:
success_count += 1
elif result is False:
failure_count += 1
elif result is None:
skipped_count += 1
print(f"\n{'='*60}")
print("Processing complete!")
print(f" ✓ Successful: {success_count}")
print(f" ✗ Failed: {failure_count}")
if skipped_count > 0:
print(f" ⊘ Skipped: {skipped_count}")
print(f" Total: {len(prefixes)}")
print(f"{'='*60}")
except Exception as e:
print(f"Fatal error: {e}", file=sys.stderr)
conn.rollback()
raise
finally:
cursor.close()
conn.close()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="Load snapshot data from Snowflake stage"
)
parser.add_argument(
"--only-affected",
action="store_true",
help=f"Only process tables that had BINARY columns ({len(ALL_AFFECTED_TABLES)} tables: "
f"{len(TABLES_WITH_BINARY_COLUMNS)} convert to VARIANT, "
f"{len(TABLES_WITH_TRUE_BINARY_DATA)} keep BINARY)",
)
args = parser.parse_args()
load_snapshot_data(only_affected_tables=args.only_affected)
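# Example invocation (script filename is illustrative; assumes a default Snowflake
# connection is configured for snowflake.connector.connect()):
#
#   python load_snapshot.py --only-affected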