Created
November 18, 2025 09:31
-
-
Save ogrisel/8c645970dcd9478679ad1a7669f25c18 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "da11cd5d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Setup for the benchmark below. These environment variables must be set before\n", | |
| "# scipy / torch are imported in this kernel, so this cell has to run first.\n", | |
| "import os\n", | |
| "# Enable SciPy's array API support so scipy.stats.quantile accepts torch tensors.\n", | |
| "os.environ[\"SCIPY_ARRAY_API\"] = \"1\"\n", | |
| "# Let torch fall back to CPU for operators not implemented on the MPS backend.\n", | |
| "os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n", | |
| "\n", | |
| "import sklearn\n", | |
| "# Make sklearn helpers dispatch on the input's array namespace (NumPy, torch, ...).\n", | |
| "sklearn.set_config(array_api_dispatch=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "e64da94d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from sklearn.utils._array_api import get_namespace_and_device\n", | |
| "from scipy.stats import quantile\n", | |
| "from sklearn.utils.stats import _weighted_percentile as _weighted_percentile_sklearn\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "\n", | |
| "def _weighted_percentile_scipy(array, sample_weight, percentile_rank=10, average=False, xp=None):\n", | |
| " xp, _, device = get_namespace_and_device(array)\n", | |
| "\n", | |
| " array = xp.asarray(array, device=device)\n", | |
| " if not xp.isdtype(array.dtype, \"real floating\"):\n", | |
| " array = xp.asarray(array, dtype=xp.asarray([0.0]).dtype, device=device)\n", | |
| "\n", | |
| " p = xp.asarray(percentile_rank, dtype=array.dtype, device=device) / 100\n", | |
| " sample_weight = xp.asarray(sample_weight, dtype=array.dtype, device=device)\n", | |
| "\n", | |
| " results = quantile(\n", | |
| " array.T if array.ndim == 2 else array,\n", | |
| " p=p,\n", | |
| " axis=-1 if array.ndim == 2 else None,\n", | |
| " nan_policy=\"omit\",\n", | |
| " weights=sample_weight.T if sample_weight.ndim == 2 else sample_weight,\n", | |
| " method=\"averaged_inverted_cdf\" if average else \"inverted_cdf\",\n", | |
| " )\n", | |
| " return results" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "d314b4f1", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "rng = np.random.default_rng(42)\n", | |
| "n_samples = int(1e5)\n", | |
| "n_features = int(1e3)\n", | |
| "data_np = rng.lognormal(size=(n_samples, n_features)).astype(np.float32)\n", | |
| "weights_np = (\n", | |
| " rng.uniform(low=-1, high=5, size=data_np.shape[0]).clip(min=0).astype(np.float32)\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "33a393ab", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 7.22 s, sys: 236 ms, total: 7.45 s\n", | |
| "Wall time: 7.47 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
| "results_np_sklearn = _weighted_percentile_sklearn(\n", | |
| " data_np, sample_weight=weights_np, percentile_rank=[5, 25, 50, 75, 95], average=True\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "e540373f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 7.33 s, sys: 191 ms, total: 7.53 s\n", | |
| "Wall time: 7.57 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
| "results_np_scipy = _weighted_percentile_scipy(\n", | |
| " data_np, sample_weight=weights_np, percentile_rank=[5, 25, 50, 75, 95], average=True\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "07ae966e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "np.testing.assert_allclose(results_np_sklearn, results_np_scipy)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "f40ad1ab", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "\n", | |
| "data_torch_mps = torch.from_numpy(data_np).to(\"mps\")\n", | |
| "weights_torch_mps = torch.from_numpy(weights_np).to(\"mps\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "dfe6ef10", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 409 ms, sys: 341 ms, total: 750 ms\n", | |
| "Wall time: 2.28 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
| "results_torch_sklearn = _weighted_percentile_sklearn(\n", | |
| " data_torch_mps, sample_weight=weights_torch_mps, percentile_rank=[5, 25, 50, 75, 95], average=True\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "5f8d972d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 174 ms, sys: 175 ms, total: 349 ms\n", | |
| "Wall time: 1.52 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
| "results_torch_scipy = _weighted_percentile_scipy(\n", | |
| " data_torch_mps, sample_weight=weights_torch_mps, percentile_rank=[5, 25, 50, 75, 95], average=True\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "6926d920", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "np.testing.assert_allclose(results_torch_sklearn.cpu(), results_torch_scipy.cpu())" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "dev", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.13.7" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment