#!/usr/bin/env python

import shutil
from pathlib import Path
from typing import Literal

import pyarrow as pa
import pyarrow.parquet as pq


class ParquetShardWriter:
    """
    A writer for creating multiple Parquet shard files with automatic size-based rollover.

    This class manages writing large datasets to multiple Parquet files (shards),
    automatically creating new shards when a size threshold is reached. It supports
    context management for safe resource cleanup.

    Parameters
    ----------
    base_path : str or Path
        Directory path where Parquet shard files will be written.
    schema : pa.Schema
        PyArrow schema defining the structure of the data to be written.
    shard_size_bytes : int, default 5_368_709_120
        Maximum size in bytes for each shard before rolling over to a new file.
        Default is 5,368,709,120 bytes (5 GiB).
    compression : {'snappy', 'gzip', 'brotli', 'zstd', 'lz4', 'none'} or dict or None, default 'zstd'
        Compression codec to use. Can be a string specifying the codec for all columns,
        a dict mapping column names to codecs, or None for no compression.
        Default is 'zstd'.
    row_group_size : int or None, default 10_000
        Number of rows per row group. If None, uses PyArrow's default behavior.
        Default is 10,000.
    overwrite : bool, default False
        If True, removes and recreates the output directory if it already exists.
        If False, raises FileExistsError when the directory exists.
        Default is False.
    verbose : bool, default False
        If True, prints progress messages. Default is False.

    Raises
    ------
    FileExistsError
        If `base_path` already exists and `overwrite` is False.

    Attributes
    ----------
    base_path : Path
        The base directory path for output files.
    schema : pa.Schema
        The PyArrow schema for the data.
    shard_size_bytes : int
        Maximum size threshold for each shard.
    compression : str or dict or None
        The compression codec configuration.
    row_group_size : int or None
        Number of rows per row group.
    shard_index : int
        Current shard number (incremented for each new shard).
    writer : pq.ParquetWriter or None
        Current active Parquet writer instance.
    current_size : int
        Accumulated in-memory (uncompressed) size, in bytes, of the data written to
        the current shard. The on-disk shard is typically smaller once compression
        is applied.
    verbose : bool
        Whether to print progress messages.

    Examples
    --------
    >>> import pyarrow as pa
    >>> from pathlib import Path
    >>> schema = pa.schema([("label", pa.string()), ("value", pa.uint32())])
    >>> with ParquetShardWriter("output", schema, verbose=True) as writer:
    ...     for i in range(1_000):
    ...         batch_data = {
    ...             "label": [f"Text {i}-{j}" for j in range(100)],
    ...             "value": [i] * 100,
    ...         }
    ...         writer.write_batch(batch_data)
    """

    def __init__(
        self,
        base_path: str | Path,
        schema: pa.Schema,
        shard_size_bytes: int = 5_368_709_120,  # 5 GiB
        compression: Literal["snappy", "gzip", "brotli", "zstd", "lz4", "none"]
        | dict[str, str]
        | None = "zstd",
        row_group_size: int | None = 10_000,  # rows per row group
        overwrite: bool = False,
        verbose: bool = False,
    ) -> None:
        self.base_path = Path(base_path)
        self.schema = schema
        self.shard_size_bytes = shard_size_bytes
        self.compression = compression
        self.row_group_size = row_group_size
        self.shard_index = 0
        self.writer: pq.ParquetWriter | None = None
        self.current_size = 0
        self.verbose = verbose
        # Handle existing output directory
        if self.base_path.exists():
            if overwrite:
                if self.verbose:
                    print(f"Removing existing directory: {self.base_path}")
                shutil.rmtree(self.base_path)
            else:
                raise FileExistsError(
                    f"Output directory '{self.base_path}' already exists."
                )
        self.base_path.mkdir(parents=True, exist_ok=False)

    def __enter__(self) -> "ParquetShardWriter":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    def _open_new_shard(self) -> None:
        """Open a new Parquet shard for writing."""
        if self.writer:
            self.writer.close()
        shard_name = f"shard_{self.shard_index:05d}.parquet"
        shard_path = self.base_path / shard_name
        if self.verbose:
            print(f"Opening new shard: {shard_path}")
        self.writer = pq.ParquetWriter(
            str(shard_path),
            self.schema,
            compression=self.compression,
            use_dictionary=True,
        )
        self.current_size = 0
        self.shard_index += 1

    def write_batch(self, data: dict | pa.Table) -> None:
        """
        Write a batch of data (as a dict of lists or a pyarrow.Table).

        Automatically rolls over to a new shard when the size limit is exceeded.

        Parameters
        ----------
        data : dict or pa.Table
            Data to write. If dict, keys are column names and values are lists.
            If pa.Table, it must match the schema provided at initialization.

        Raises
        ------
        TypeError
            If data is not a dict or pyarrow.Table.
        """
        if isinstance(data, dict):
            table = pa.Table.from_pydict(data, schema=self.schema)
        elif isinstance(data, pa.Table):
            table = data
        else:
            raise TypeError("Data must be a dict or pyarrow.Table")
        if self.writer is None:
            self._open_new_shard()
        # Size is measured on the in-memory table, so on-disk shards will be
        # smaller than `shard_size_bytes` once compression is applied.
        batch_size = table.nbytes
        # Only roll over if the current shard already holds data; otherwise an
        # oversized first batch would leave an empty shard file behind.
        if self.current_size > 0 and self.current_size + batch_size > self.shard_size_bytes:
            self._open_new_shard()
        self.writer.write_table(table, row_group_size=self.row_group_size)
        self.current_size += batch_size

    def close(self) -> None:
        """Close the current Parquet writer if open."""
        if self.writer:
            self.writer.close()
            self.writer = None
            if self.verbose:
                print("Closed current shard writer.")


# Example usage
schema = pa.schema([("label", pa.string()), ("value", pa.uint32())])
with ParquetShardWriter(
    "output",
    schema,
    verbose=True,
    shard_size_bytes=5_368_709_120 // 10_000,  # ~0.5 MiB per shard, so rollover is easy to observe
    overwrite=True,
) as writer:
    for i in range(5_000):
        batch_data = {
            "label": [f"Text {i}-{j}" for j in range(1_000)],
            "value": [i] * 1_000,
        }
        writer.write_batch(batch_data)
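
Since write_batch also accepts a pyarrow.Table, the sketch below shows the same kind of loop driven by pre-built tables; the "output_tables" directory name is only an illustration and is not part of the gist above.

# Sketch: feeding pre-built pyarrow Tables to write_batch ("output_tables" is hypothetical).
with ParquetShardWriter("output_tables", schema, overwrite=True) as table_writer:
    for i in range(100):
        table = pa.Table.from_pydict(
            {"label": [f"Row {i}-{j}" for j in range(1_000)], "value": [i] * 1_000},
            schema=schema,
        )
        table_writer.write_batch(table)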
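
A minimal sketch of reading the shards back, assuming the "output" directory produced by the example above: pyarrow.dataset treats the directory of shard_*.parquet files as a single logical dataset.

import pyarrow.dataset as ds

# Open all shard_*.parquet files in "output" as one dataset.
dataset = ds.dataset("output", format="parquet")
print(dataset.count_rows())         # 5_000 batches x 1_000 rows = 5_000_000 rows
print(dataset.head(5).to_pydict())  # peek at the first few rows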