#!/usr/bin/env python
import shutil
from pathlib import Path
from typing import Literal

import pyarrow as pa
import pyarrow.parquet as pq


class ParquetShardWriter:
    """
    A writer for creating multiple Parquet shard files with automatic size-based rollover.

    This class manages writing large datasets to multiple Parquet files (shards),
    automatically creating new shards when a size threshold is reached. It supports
    context management for safe resource cleanup.

    Parameters
    ----------
    base_path : str or Path
        Directory path where Parquet shard files will be written.
    schema : pa.Schema
        PyArrow schema defining the structure of the data to be written.
    shard_size_bytes : int, default 5_368_709_120
        Maximum size in bytes for each shard before rolling over to a new file.
        Default is 5,368,709,120 bytes (5 GiB).
    compression : {'snappy', 'gzip', 'brotli', 'zstd', 'lz4', 'none'} or dict or None, default 'zstd'
        Compression codec to use. Can be a string specifying the codec for all columns,
        a dict mapping column names to codecs, or None for no compression.
        Default is 'zstd'.
    row_group_size : int or None, default 10_000
        Number of rows per row group. If None, uses PyArrow's default behavior.
        Default is 10,000.
    overwrite : bool, default False
        If True, removes and recreates the output directory if it already exists.
        If False, raises FileExistsError when the directory exists.
        Default is False.
    verbose : bool, default False
        If True, prints progress messages. Default is False.

    Raises
    ------
    FileExistsError
        If `base_path` already exists and `overwrite` is False.

    Attributes
    ----------
    base_path : Path
        The base directory path for output files.
    schema : pa.Schema
        The PyArrow schema for the data.
    shard_size_bytes : int
        Maximum size threshold for each shard.
    compression : str or dict or None
        The compression codec configuration.
    row_group_size : int or None
        Number of rows per row group.
    shard_index : int
        Current shard number (incremented for each new shard).
    writer : pq.ParquetWriter or None
        Current active Parquet writer instance.
    current_size : int
        Accumulated size in bytes of the current shard.
    verbose : bool
        Whether to print progress messages.

    Examples
    --------
    >>> import pyarrow as pa
    >>> from pathlib import Path
    >>> schema = pa.schema([("label", pa.string()), ("value", pa.uint32())])
    >>> with ParquetShardWriter("output", schema, verbose=True) as writer:
    ...     for i in range(1_000):
    ...         batch_data = {
    ...             "label": [f"Text {i}-{j}" for j in range(100)],
    ...             "value": [i] * 100
    ...         }
    ...         writer.write_batch(batch_data)
    """

    def __init__(
        self,
        base_path: str | Path,
        schema: pa.Schema,
        shard_size_bytes: int = 5_368_709_120,  # 5 GiB
        compression: Literal["snappy", "gzip", "brotli", "zstd", "lz4", "none"]
        | dict[str, str]
        | None = "zstd",
        row_group_size: int | None = 10_000,  # rows per row group
        overwrite: bool = False,
        verbose: bool = False,
    ) -> None:
        self.base_path = Path(base_path)
        self.schema = schema
        self.shard_size_bytes = shard_size_bytes
        self.compression = compression
        self.row_group_size = row_group_size
        self.shard_index = 0
        self.writer: pq.ParquetWriter | None = None
        self.current_size = 0
        self.verbose = verbose
        # Handle existing output directory
        if self.base_path.exists():
            if overwrite:
                if self.verbose:
                    print(f"Removing existing directory: {self.base_path}")
                shutil.rmtree(self.base_path)
            else:
                raise FileExistsError(
                    f"Output directory '{self.base_path}' already exists."
                )
        self.base_path.mkdir(parents=True, exist_ok=False)

    def __enter__(self) -> "ParquetShardWriter":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    def _open_new_shard(self) -> None:
        """Open a new Parquet shard for writing."""
        if self.writer:
            self.writer.close()
        shard_name = f"shard_{self.shard_index:05d}.parquet"
        shard_path = self.base_path / shard_name
        if self.verbose:
            print(f"Opening new shard: {shard_path}")
        self.writer = pq.ParquetWriter(
            str(shard_path),
            self.schema,
            compression=self.compression,
            use_dictionary=True,
        )
        self.current_size = 0
        self.shard_index += 1

    def write_batch(self, data: dict | pa.Table) -> None:
        """
        Write a batch of data (as a dict of lists or a pyarrow.Table).

        Automatically rolls over to a new shard when the size limit is exceeded.

        Parameters
        ----------
        data : dict or pa.Table
            Data to write. If dict, keys are column names and values are lists.
            If pa.Table, it must match the schema provided at initialization.

        Raises
        ------
        TypeError
            If data is not a dict or pyarrow.Table.
        """
        if isinstance(data, dict):
            table = pa.Table.from_pydict(data, schema=self.schema)
        elif isinstance(data, pa.Table):
            table = data
        else:
            raise TypeError("Data must be a dict or pyarrow.Table")
        if self.writer is None:
            self._open_new_shard()
        batch_size = table.nbytes
        # Roll over only if the current shard already holds data; otherwise a
        # single batch larger than the size limit would leave an empty shard behind.
        if (
            self.current_size > 0
            and self.current_size + batch_size > self.shard_size_bytes
        ):
            self._open_new_shard()
        self.writer.write_table(table, self.row_group_size)
        self.current_size += batch_size

    def close(self) -> None:
        """Close the current Parquet writer if open."""
        if self.writer:
            self.writer.close()
            self.writer = None
            if self.verbose:
                print("Closed current shard writer.")


# Example usage
schema = pa.schema([("label", pa.string()), ("value", pa.uint32())])
with ParquetShardWriter(
    "output",
    schema,
    verbose=True,
    shard_size_bytes=5_368_709_120 // 10_000,
    overwrite=True,
) as writer:
    for i in range(5_000):
        batch_data = {
            "label": [f"Text {i}-{j}" for j in range(1_000)],
            "value": [i] * 1_000,
        }
        writer.write_batch(batch_data)
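
# A minimal sketch (not part of the original gist) of reading the shards back to
# verify the output. It assumes the "output" directory created by the example
# above and relies on pyarrow.parquet.read_table, which accepts a directory of
# Parquet files and concatenates them into a single table.
table = pq.read_table("output")
print(f"Read {table.num_rows} rows from {writer.shard_index} shard(s)")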