Skip to content

Instantly share code, notes, and snippets.

@jiayuasu
Last active March 13, 2026 18:36
Show Gist options
  • Select an option

  • Save jiayuasu/a7ca7047f6f390e4f8c27332d76beca6 to your computer and use it in GitHub Desktop.

Select an option

Save jiayuasu/a7ca7047f6f390e4f8c27332d76beca6 to your computer and use it in GitHub Desktop.
How to reproduce the SedonaDB spatial join results
#!/usr/bin/env python3
"""
Reproduce SedonaDB spatial join benchmark from the blog post.
Dataset: SpatialBench SF1 from HuggingFace
https://huggingface.co/datasets/apache-sedona/spatialbench/tree/main/v0.1.0/sf1
Requirements:
pip install sedonadb huggingface_hub
Usage:
python repro_sedonadb.py
This script downloads the trip and zone tables from HuggingFace,
then runs Q10 and Q11 with a 512 MB memory limit.
"""
import os
import shutil
import signal
import time
from huggingface_hub import hf_hub_download
import sedonadb
# --- Benchmark configuration -------------------------------------------------
REPO_ID = "apache-sedona/spatialbench"  # HuggingFace dataset repo id
HF_BASE = "v0.1.0/sf1"  # path prefix inside the repo: version v0.1.0, scale factor 1
DATA_DIR = "data/sf1"  # local directory the parquet files are downloaded into
SPILL_DIR = "sedona-spill"  # directory SedonaDB spills to under memory pressure
MEMORY_LIMIT = "512m"  # memory cap applied to the SedonaDB session
# Files to download (only trip and zone are needed for Q10 & Q11)
FILES = {
    "trip": ["trip.1.parquet", "trip.2.parquet"],
    "zone": [f"zone.{i}.parquet" for i in range(1, 7)],
}
# SpatialBench queries Q10 and Q11. Geometry columns are stored as WKB,
# hence the ST_GeomFromWKB(...) wrappers around the point/polygon columns.
QUERIES = {
    # Q10: per-pickup-zone trip averages via a point-in-polygon LEFT JOIN
    "q10": """
SELECT
z.z_zonekey, z.z_name AS pickup_zone,
AVG(t.t_dropofftime - t.t_pickuptime) AS avg_duration,
AVG(t.t_distance) AS avg_distance,
COUNT(t.t_tripkey) AS num_trips
FROM zone z
LEFT JOIN trip t ON ST_Within(ST_GeomFromWKB(t.t_pickuploc), ST_GeomFromWKB(z.z_boundary))
GROUP BY z.z_zonekey, z.z_name
ORDER BY avg_duration DESC NULLS LAST, z.z_zonekey ASC
""",
    # Q11: count trips whose pickup and dropoff land in different zones
    # (double spatial join of trip against the zone table)
    "q11": """
SELECT COUNT(*) AS cross_zone_trip_count
FROM trip t
JOIN zone pickup_zone
ON ST_Within(ST_GeomFromWKB(t.t_pickuploc), ST_GeomFromWKB(pickup_zone.z_boundary))
JOIN zone dropoff_zone
ON ST_Within(ST_GeomFromWKB(t.t_dropoffloc), ST_GeomFromWKB(dropoff_zone.z_boundary))
WHERE pickup_zone.z_zonekey != dropoff_zone.z_zonekey
""",
}
class QueryTimeout(Exception):
pass
def _timeout_handler(signum, frame):
raise QueryTimeout("Query exceeded timeout")
def download_data():
"""Download parquet files from HuggingFace if not already present."""
for table, files in FILES.items():
table_dir = os.path.join(DATA_DIR, table)
os.makedirs(table_dir, exist_ok=True)
for fname in files:
local_path = os.path.join(table_dir, fname)
if os.path.exists(local_path):
print(f" {table}/{fname} — already downloaded")
continue
print(f" Downloading {table}/{fname}...", end=" ", flush=True)
downloaded = hf_hub_download(
repo_id=REPO_ID,
filename=f"{HF_BASE}/{table}/{fname}",
repo_type="dataset",
local_dir=DATA_DIR,
)
# hf_hub_download may place files in a nested structure; copy to flat layout
if not os.path.exists(local_path):
os.rename(downloaded, local_path)
print("done")
def run_query(sd, query_name, query_sql):
print(f" {query_name}...", end=" ", flush=True)
old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
signal.alarm(300)
try:
t0 = time.perf_counter()
pdf = sd.sql(query_sql).to_pandas()
elapsed = time.perf_counter() - t0
signal.alarm(0)
print(f"{elapsed:.2f}s ({len(pdf):,} rows)")
return True, elapsed
except QueryTimeout:
elapsed = time.perf_counter() - t0
print(f"TIMEOUT after {elapsed:.1f}s")
return False, elapsed
except Exception as e:
signal.alarm(0)
elapsed = time.perf_counter() - t0
print(f"FAILED after {elapsed:.1f}s: {e}")
return False, elapsed
finally:
signal.signal(signal.SIGALRM, old_handler)
def main():
    """Download SpatialBench SF1 data, then run Q10 and Q11 under MEMORY_LIMIT."""
    print("SedonaDB SpatialBench SF1 — Q10 & Q11")
    print(f" Memory limit: {MEMORY_LIMIT}")
    print()
    # Step 1: Download data (skips files already present locally)
    print("Downloading data from HuggingFace...")
    download_data()
    print()
    # Step 2: Set up SedonaDB. Start from an empty spill dir so the size
    # report at the end reflects only this run.
    if os.path.exists(SPILL_DIR):
        shutil.rmtree(SPILL_DIR)
    os.makedirs(SPILL_DIR, exist_ok=True)
    sd = sedonadb.connect()
    sd.options.memory_limit = MEMORY_LIMIT
    # NOTE(review): presumably selects DataFusion's fair-spill memory pool —
    # confirm against the sedonadb options documentation.
    sd.options.memory_pool_type = "fair"
    sd.options.temp_dir = SPILL_DIR
    sd.sql("SET datafusion.execution.spill_compression = 'lz4_frame'").execute()
    # Step 3: Register tables — each table directory holds its parquet parts
    print("Registering tables...")
    for table in FILES:
        path = os.path.join(DATA_DIR, table)
        sd.read_parquet(os.path.abspath(path)).to_view(table, overwrite=True)
        print(f" {table}")
    print()
    # Step 4: Run queries
    print(f"Running queries (memory_limit = {MEMORY_LIMIT}):")
    for qname, qsql in QUERIES.items():
        run_query(sd, qname, qsql)
    # Report total bytes spilled to disk during the run
    spill_size = sum(
        os.path.getsize(os.path.join(root, f))
        for root, _, files in os.walk(SPILL_DIR)
        for f in files
    )
    print(f"\nSpill to disk: {spill_size / (1024*1024):.1f} MB")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment