Created
December 19, 2024 16:15
-
-
Save Taytay/eaf8fd438628138b2280ebbcae14513a to your computer and use it in GitHub Desktop.
Illustrate parallel duckdb UDFs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# %%\n", | |
| "import pyarrow as pa\n", | |
| "import duckdb\n", | |
| "from duckdb.typing import VARCHAR\n", | |
| "from duckdb.functional import PythonUDFType\n", | |
| "from time import sleep\n", | |
| "import threading\n", | |
| "\n", | |
| "import time\n", | |
| "from typing import Any\n", | |
| "\n", | |
| "\n", | |
| "num_calls = 0\n", | |
| "average_num_passed_into_each_udf_call = 0\n", | |
| "num_calls_currently_in_flight = 0\n", | |
| "max_num_calls_currently_in_flight = 0\n", | |
| "\n", | |
| "\n", | |
| "# Arrow python implementation (operates over 2048 rows at a time)\n", | |
| "def my_arrow_udf(url_arr: pa.ChunkedArray) -> pa.Array:\n", | |
| " global num_calls\n", | |
| " global average_num_passed_into_each_udf\n", | |
| " global num_calls_currently_in_flight\n", | |
| " global max_num_calls_currently_in_flight\n", | |
| "\n", | |
| " with threading.Lock():\n", | |
| " num_calls += 1\n", | |
| " num_calls_currently_in_flight += 1\n", | |
| " max_num_calls_currently_in_flight = max(max_num_calls_currently_in_flight, num_calls_currently_in_flight)\n", | |
| "\n", | |
| " urls = []\n", | |
| " for chunk in url_arr.chunks:\n", | |
| " chunk_as_list = chunk.to_pylist()\n", | |
| " urls.extend(chunk_as_list)\n", | |
| "\n", | |
| " average_num_passed_into_each_udf = (len(urls) * (num_calls - 1) + len(chunk_as_list)) / num_calls\n", | |
| "\n", | |
| " # We sleep just a bit to illustrate the parallelism\n", | |
| " # If this were not using parallelism, we'd sleep for 0.1s * 586 calls, = 58.6 seconds\n", | |
| " # but since we use parallelism, and have 10 threads, we sleep for 0.1s * 58.6 = 5.86 seconds\n", | |
| " sleep(0.1)\n", | |
| " results: list[str] = []\n", | |
| " results = [\"foo\" for _ in range(len(urls))]\n", | |
| "\n", | |
| " with threading.Lock():\n", | |
| " num_calls_currently_in_flight -= 1\n", | |
| "\n", | |
| " return pa.array(results, type=pa.string())\n", | |
| "\n", | |
| "\n", | |
| "try:\n", | |
| " duckdb.remove_function(\"my_arrow_udf\")\n", | |
| "except Exception:\n", | |
| " pass\n", | |
| "duckdb.create_function(\"my_arrow_udf\", my_arrow_udf, [VARCHAR], VARCHAR, type=PythonUDFType.ARROW)\n", | |
| "\n", | |
| "# Set up a sample table: 1.2 million rows that all carry the same URL.\n", | |
| "# range() yields a single column, which fills `id` via SELECT *; at roughly\n", | |
| "# 2048 rows per UDF call this gives DuckDB several hundred batches to spread\n", | |
| "# across its threads.\n", | |
| "duckdb.sql(\"CREATE OR REPLACE TABLE example_table (id INTEGER, url VARCHAR)\")\n", | |
| "duckdb.sql(\n", | |
| " \"\"\"\n", | |
| " INSERT INTO example_table \n", | |
| " SELECT *, 'https://example.com' as url \n", | |
| " FROM range(1_200_000)\n", | |
| " \"\"\"\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "4660805115b2472aa3425b004b4738dd", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Arrow UDF took 6.561621904373169 seconds\n", | |
| "Number of calls: 586\n", | |
| "Number of rows in each chunk per call: 2048.0\n", | |
| "Max number of calls ever in flight: 10\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# %%\n", | |
| "# Execute and fetch the results\n", | |
| "num_calls = 0\n", | |
| "average_num_passed_into_each_udf = 0\n", | |
| "num_calls_currently_in_flight = 0\n", | |
| "max_num_calls_currently_in_flight = 0\n", | |
| "start_time: float = time.time()\n", | |
| "res_arrow: list[Any] = duckdb.sql(\n", | |
| " \"SELECT my_arrow_udf(url) FROM example_table\"\n", | |
| ").fetchall() # Will run over all rows simultaneously\n", | |
| "end_time: float = time.time()\n", | |
| "print(f\"Arrow UDF took {end_time - start_time} seconds\")\n", | |
| "print(f\"Number of calls: {num_calls}\")\n", | |
| "print(f\"Number of rows in each chunk per call: {average_num_passed_into_each_udf}\")\n", | |
| "print(f\"Max number of calls ever in flight: {max_num_calls_currently_in_flight}\")\n" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": ".venv", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.10.14" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment