@mrjsj, created October 30, 2025 20:44
PySpark gap-free identity column
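
The snippet below compares four ways of adding an identity (row index) column to a PySpark DataFrame: zipWithIndex on the underlying RDD, monotonically_increasing_id (unique but not gap-free), row_number over a single global window, and per-partition row numbers combined with a running per-partition offset.
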
import pyspark.sql.functions as F
from pyspark.sql import DataFrame
from pyspark.sql.types import LongType, StructField, StructType
from pyspark.sql.window import Window


def identity_column_with_zip_with_index(df: DataFrame, offset: int = 0) -> DataFrame:
    """Gap-free index via RDD zipWithIndex(); requires a DataFrame -> RDD -> DataFrame round trip."""
    rdd_with_index = df.rdd.zipWithIndex()
    # All columns of df are stored in the _1 column as a struct, and _2 contains the index
    new_schema = StructType(
        [
            StructField("_1", df.schema),
            StructField("_2", LongType(), False),
        ]
    )
    # Important to provide the schema - otherwise Spark needs to infer the schema of the RDD
    df = (
        rdd_with_index.toDF(schema=new_schema)
        .select(
            F.col("_1.*"),
            (F.col("_2") + F.lit(offset)).alias("identity_column"),
        )
    )
    return df


def identity_column_with_monotonically_increasing_id(df: DataFrame, offset: int = 0) -> DataFrame:
    """Unique and increasing, but NOT gap-free: the generated ids encode the partition id,
    so there are large jumps between partitions."""
    df = df.withColumn(
        "identity_column",
        F.monotonically_increasing_id() + F.lit(offset)
    )
    return df


def identity_column_with_row_number(df: DataFrame, offset: int = 0) -> DataFrame:
    """Gap-free, but the unpartitioned window moves all rows into a single partition.
    row_number() starts at 1, so the first value is offset + 1."""
    window = Window.orderBy(F.lit(1))
    df = df.withColumn(
        "identity_column",
        F.row_number().over(window) + F.lit(offset)
    )
    return df


def identity_column_with_partitioned_row_number(df: DataFrame, offset: int = 0) -> DataFrame:
    """Gap-free index from per-partition row numbers plus a running offset per partition,
    avoiding a single global window over all rows."""
    original_columns = df.columns
    # assign the physical partition id to each row
    df = df.withColumn(
        "partition_id",
        F.spark_partition_id()
    )
    # add a row number within each partition
    partition_window = Window.partitionBy("partition_id").orderBy(F.lit(1))
    df = df.withColumn(
        "row_num",
        F.row_number().over(partition_window)
    )
    # compute how many rows are in each partition
    partition_counts = (
        df.groupBy("partition_id")
        .agg(F.max(F.col("row_num")).alias("partition_count"))
    )
    # running cumulative sum of partition counts ordered by partition_id
    cumulative_window = (
        Window.partitionBy(F.lit(1))
        .orderBy("partition_id")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    partition_counts = partition_counts.withColumn(
        "cumulative_count", F.sum("partition_count").over(cumulative_window)
    )
    # previous cumulative count (to offset row numbers by the size of earlier partitions)
    prev_cumulative_window = Window.partitionBy(F.lit(1)).orderBy("partition_id")
    partition_counts = partition_counts.withColumn(
        "prev_cumulative_count",
        F.coalesce(F.lag(F.col("cumulative_count")).over(prev_cumulative_window), F.lit(0)),
    )
    # join back to the original rows and compute the identity column
    df = df.join(partition_counts, on="partition_id", how="left")
    df = df.withColumn(
        "identity_column",
        F.col("row_num") + F.col("prev_cumulative_count") + F.lit(offset),
    )
    # return the identity column plus the original columns
    df = df.select("identity_column", *original_columns)
    return df
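
A minimal usage sketch, assuming a running SparkSession named spark; the sample data and column name are illustrative:

# assumes an active SparkSession called `spark`; sample data is hypothetical
sample_df = spark.createDataFrame([("a",), ("b",), ("c",)], ["value"])

with_id = identity_column_with_zip_with_index(sample_df, offset=1)
with_id.show()
# expected identity_column values: 1, 2, 3 (gap-free, starting at the offset)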