Last active
March 13, 2026 18:36
-
-
Save jiayuasu/a7ca7047f6f390e4f8c27332d76beca6 to your computer and use it in GitHub Desktop.
How to reproduce the SedonaDB spatial join results
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Reproduce SedonaDB spatial join benchmark from the blog post. | |
| Dataset: SpatialBench SF1 from HuggingFace | |
| https://huggingface.co/datasets/apache-sedona/spatialbench/tree/main/v0.1.0/sf1 | |
| Requirements: | |
| pip install sedonadb huggingface_hub | |
| Usage: | |
| python repro_sedonadb.py | |
| This script downloads the trip and zone tables from HuggingFace, | |
| then runs Q10 and Q11 with a 512 MB memory limit. | |
| """ | |
| import os | |
| import shutil | |
| import signal | |
| import time | |
| from huggingface_hub import hf_hub_download | |
| import sedonadb | |
# --- Benchmark configuration ---
REPO_ID = "apache-sedona/spatialbench"  # HuggingFace dataset repository
HF_BASE = "v0.1.0/sf1"                  # path prefix inside the repo (scale factor 1)
DATA_DIR = "data/sf1"                   # local directory the parquet files land in
SPILL_DIR = "sedona-spill"              # scratch directory SedonaDB spills to
MEMORY_LIMIT = "512m"                   # memory cap applied to the SedonaDB session
# Files to download (only trip and zone are needed for Q10 & Q11)
FILES = {
    "trip": ["trip.1.parquet", "trip.2.parquet"],
    "zone": [f"zone.{i}.parquet" for i in range(1, 7)],
}
# SQL text for the two spatial-join benchmark queries, keyed by query name.
# Geometry columns are stored as WKB, hence the ST_GeomFromWKB wrapping.
QUERIES = {
    # Q10: per-pickup-zone trip statistics via a zone-LEFT-JOIN-trip
    # point-in-polygon join (zones with no trips still appear, with NULLs).
    "q10": """
    SELECT
        z.z_zonekey, z.z_name AS pickup_zone,
        AVG(t.t_dropofftime - t.t_pickuptime) AS avg_duration,
        AVG(t.t_distance) AS avg_distance,
        COUNT(t.t_tripkey) AS num_trips
    FROM zone z
    LEFT JOIN trip t ON ST_Within(ST_GeomFromWKB(t.t_pickuploc), ST_GeomFromWKB(z.z_boundary))
    GROUP BY z.z_zonekey, z.z_name
    ORDER BY avg_duration DESC NULLS LAST, z.z_zonekey ASC
    """,
    # Q11: count trips whose pickup and dropoff fall in different zones
    # (two point-in-polygon joins against the zone table).
    "q11": """
    SELECT COUNT(*) AS cross_zone_trip_count
    FROM trip t
    JOIN zone pickup_zone
        ON ST_Within(ST_GeomFromWKB(t.t_pickuploc), ST_GeomFromWKB(pickup_zone.z_boundary))
    JOIN zone dropoff_zone
        ON ST_Within(ST_GeomFromWKB(t.t_dropoffloc), ST_GeomFromWKB(dropoff_zone.z_boundary))
    WHERE pickup_zone.z_zonekey != dropoff_zone.z_zonekey
    """,
}
class QueryTimeout(Exception):
    """Raised by the SIGALRM handler when a query exceeds its time budget."""
def _timeout_handler(signum, frame):
    # SIGALRM callback: interrupt the in-flight query by raising QueryTimeout.
    raise QueryTimeout("Query exceeded timeout")
def download_data():
    """Download the trip/zone parquet files from HuggingFace into DATA_DIR.

    Files already present at the flat DATA_DIR/<table>/<fname> location are
    skipped, so the function is safe to re-run. hf_hub_download mirrors the
    repo's nested path layout under DATA_DIR, so each file is then moved into
    the flat layout the rest of the script expects.
    """
    for table, files in FILES.items():
        table_dir = os.path.join(DATA_DIR, table)
        os.makedirs(table_dir, exist_ok=True)
        for fname in files:
            local_path = os.path.join(table_dir, fname)
            if os.path.exists(local_path):
                print(f" {table}/{fname} — already downloaded")
                continue
            print(f" Downloading {table}/{fname}...", end=" ", flush=True)
            downloaded = hf_hub_download(
                repo_id=REPO_ID,
                filename=f"{HF_BASE}/{table}/{fname}",
                repo_type="dataset",
                local_dir=DATA_DIR,
            )
            # hf_hub_download may place files in a nested structure (and may
            # return a path into the HF cache on another filesystem). Use
            # shutil.move rather than os.rename: rename raises OSError on
            # cross-device moves, while move falls back to copy+delete.
            if not os.path.exists(local_path):
                shutil.move(downloaded, local_path)
            print("done")
def run_query(sd, query_name, query_sql):
    """Execute one SQL query under a 300-second SIGALRM wall-clock timeout.

    Parameters:
        sd: SedonaDB connection (anything exposing .sql(sql).to_pandas()).
        query_name: label used in progress output.
        query_sql: SQL text to execute.

    Returns:
        (succeeded, elapsed_seconds) — succeeded is False on timeout or any
        other error; elapsed is measured even on failure.

    Note: SIGALRM requires Unix and must be installed from the main thread.
    """
    print(f" {query_name}...", end=" ", flush=True)
    old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
    signal.alarm(300)
    # Start the clock before entering try so the except clauses can always
    # reference t0 (it was previously assigned inside try — a latent
    # NameError if an exception fired before the assignment).
    t0 = time.perf_counter()
    try:
        pdf = sd.sql(query_sql).to_pandas()
        elapsed = time.perf_counter() - t0
        print(f"{elapsed:.2f}s ({len(pdf):,} rows)")
        return True, elapsed
    except QueryTimeout:
        elapsed = time.perf_counter() - t0
        print(f"TIMEOUT after {elapsed:.1f}s")
        return False, elapsed
    except Exception as e:
        elapsed = time.perf_counter() - t0
        print(f"FAILED after {elapsed:.1f}s: {e}")
        return False, elapsed
    finally:
        # Cancel any pending alarm and restore the previous handler on every
        # exit path (the original only cancelled on two of the three paths).
        signal.alarm(0)
        signal.signal(signal.SIGALRM, old_handler)
def main():
    """Drive the full benchmark: download, configure, register, query, report."""
    print("SedonaDB SpatialBench SF1 — Q10 & Q11")
    print(f" Memory limit: {MEMORY_LIMIT}")
    print()

    # Fetch the parquet inputs (a no-op for files already on disk).
    print("Downloading data from HuggingFace...")
    download_data()
    print()

    # Start from an empty spill directory so the final size report only
    # reflects this run.
    if os.path.exists(SPILL_DIR):
        shutil.rmtree(SPILL_DIR)
    os.makedirs(SPILL_DIR, exist_ok=True)

    # Open a SedonaDB session constrained to MEMORY_LIMIT, spilling to disk.
    sd = sedonadb.connect()
    sd.options.memory_limit = MEMORY_LIMIT
    sd.options.memory_pool_type = "fair"
    sd.options.temp_dir = SPILL_DIR
    sd.sql("SET datafusion.execution.spill_compression = 'lz4_frame'").execute()

    # Expose each table's parquet directory as a view named after the table.
    print("Registering tables...")
    for tbl in FILES:
        tbl_path = os.path.abspath(os.path.join(DATA_DIR, tbl))
        sd.read_parquet(tbl_path).to_view(tbl, overwrite=True)
        print(f" {tbl}")
    print()

    # Execute both queries, then report how much data was spilled to disk.
    print(f"Running queries (memory_limit = {MEMORY_LIMIT}):")
    for qname in QUERIES:
        run_query(sd, qname, QUERIES[qname])

    spill_bytes = 0
    for root, _, filenames in os.walk(SPILL_DIR):
        for fn in filenames:
            spill_bytes += os.path.getsize(os.path.join(root, fn))
    print(f"\nSpill to disk: {spill_bytes / (1024*1024):.1f} MB")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment