Skip to content

Instantly share code, notes, and snippets.

@ogrisel
Created November 18, 2025 09:31
Show Gist options
  • Select an option

  • Save ogrisel/8c645970dcd9478679ad1a7669f25c18 to your computer and use it in GitHub Desktop.

Select an option

Save ogrisel/8c645970dcd9478679ad1a7669f25c18 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "da11cd5d",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"SCIPY_ARRAY_API\"] = \"1\"\n",
"os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n",
"\n",
"import sklearn\n",
"sklearn.set_config(array_api_dispatch=True)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e64da94d",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.utils._array_api import get_namespace_and_device\n",
"from scipy.stats import quantile\n",
"from sklearn.utils.stats import _weighted_percentile as _weighted_percentile_sklearn\n",
"import numpy as np\n",
"\n",
"\n",
"def _weighted_percentile_scipy(array, sample_weight, percentile_rank=10, average=False, xp=None):\n",
"    \"\"\"Compute weighted percentiles of `array` by delegating to `scipy.stats.quantile`.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    array : array-like\n",
"        Values to summarize. A 2D input is transposed and reduced along the\n",
"        last axis of the transposed array (i.e. axis 0 of the input); any\n",
"        other dimensionality is passed as is with `axis=None`.\n",
"    sample_weight : array-like\n",
"        Weights, converted to the dtype and device of `array`. A 2D weight\n",
"        array is transposed before being passed to `quantile`.\n",
"    percentile_rank : float or array-like of float, default=10\n",
"        Percentile rank(s) on a 0-100 scale; divided by 100 to obtain `p`.\n",
"    average : bool, default=False\n",
"        If True use the \"averaged_inverted_cdf\" method, otherwise\n",
"        \"inverted_cdf\".\n",
"    xp : module, default=None\n",
"        Array namespace to use. When None, it is inferred from `array`.\n",
"\n",
"    Returns\n",
"    -------\n",
"    results\n",
"        The value returned by `scipy.stats.quantile` called with\n",
"        `nan_policy=\"omit\"`.\n",
"    \"\"\"\n",
"    # The device always comes from `array`; the namespace is only inferred\n",
"    # when the caller did not provide one explicitly.\n",
"    inferred_xp, _, device = get_namespace_and_device(array)\n",
"    if xp is None:\n",
"        xp = inferred_xp\n",
"\n",
"    array = xp.asarray(array, device=device)\n",
"    if not xp.isdtype(array.dtype, \"real floating\"):\n",
"        # Non-floating input: cast to the dtype the namespace assigns to a\n",
"        # Python float.\n",
"        array = xp.asarray(array, dtype=xp.asarray([0.0]).dtype, device=device)\n",
"\n",
"    p = xp.asarray(percentile_rank, dtype=array.dtype, device=device) / 100\n",
"    sample_weight = xp.asarray(sample_weight, dtype=array.dtype, device=device)\n",
"\n",
"    is_2d = array.ndim == 2\n",
"    return quantile(\n",
"        array.T if is_2d else array,\n",
"        p=p,\n",
"        axis=-1 if is_2d else None,\n",
"        nan_policy=\"omit\",\n",
"        weights=sample_weight.T if sample_weight.ndim == 2 else sample_weight,\n",
"        method=\"averaged_inverted_cdf\" if average else \"inverted_cdf\",\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d314b4f1",
"metadata": {},
"outputs": [],
"source": [
"rng = np.random.default_rng(42)\n",
"n_samples = int(1e5)\n",
"n_features = int(1e3)\n",
"data_np = rng.lognormal(size=(n_samples, n_features)).astype(np.float32)\n",
"weights_np = (\n",
" rng.uniform(low=-1, high=5, size=data_np.shape[0]).clip(min=0).astype(np.float32)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "33a393ab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 7.22 s, sys: 236 ms, total: 7.45 s\n",
"Wall time: 7.47 s\n"
]
}
],
"source": [
"%%time\n",
"results_np_sklearn = _weighted_percentile_sklearn(\n",
" data_np, sample_weight=weights_np, percentile_rank=[5, 25, 50, 75, 95], average=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e540373f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 7.33 s, sys: 191 ms, total: 7.53 s\n",
"Wall time: 7.57 s\n"
]
}
],
"source": [
"%%time\n",
"results_np_scipy = _weighted_percentile_scipy(\n",
" data_np, sample_weight=weights_np, percentile_rank=[5, 25, 50, 75, 95], average=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "07ae966e",
"metadata": {},
"outputs": [],
"source": [
"np.testing.assert_allclose(results_np_sklearn, results_np_scipy)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f40ad1ab",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"data_torch_mps = torch.from_numpy(data_np).to(\"mps\")\n",
"weights_torch_mps = torch.from_numpy(weights_np).to(\"mps\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dfe6ef10",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 409 ms, sys: 341 ms, total: 750 ms\n",
"Wall time: 2.28 s\n"
]
}
],
"source": [
"%%time\n",
"results_torch_sklearn = _weighted_percentile_sklearn(\n",
" data_torch_mps, sample_weight=weights_torch_mps, percentile_rank=[5, 25, 50, 75, 95], average=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "5f8d972d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 174 ms, sys: 175 ms, total: 349 ms\n",
"Wall time: 1.52 s\n"
]
}
],
"source": [
"%%time\n",
"results_torch_scipy = _weighted_percentile_scipy(\n",
" data_torch_mps, sample_weight=weights_torch_mps, percentile_rank=[5, 25, 50, 75, 95], average=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "6926d920",
"metadata": {},
"outputs": [],
"source": [
"np.testing.assert_allclose(results_torch_sklearn.cpu(), results_torch_scipy.cpu())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment