Skip to content

Instantly share code, notes, and snippets.

@agcom
Last active November 2, 2025 13:47
Show Gist options
  • Select an option

  • Save agcom/963f6a494b4f20fc8437380289e5a63f to your computer and use it in GitHub Desktop.

Select an option

Save agcom/963f6a494b4f20fc8437380289e5a63f to your computer and use it in GitHub Desktop.
Creative stratified split of a dataset using scikit-learn's `train_test_split` function recursively
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-02T13:47:05.358868Z",
"start_time": "2025-11-02T13:47:01.931326Z"
}
},
"cell_type": "code",
"source": [
"from torchvision.datasets.mnist import MNIST\n",
"\n",
"dataset = MNIST('./dataset/', train=True, download=False)"
],
"id": "db534eb0b832929f",
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-02T13:47:06.626458Z",
"start_time": "2025-11-02T13:47:05.368305Z"
}
},
"cell_type": "code",
"source": [
"from stratified import stratified_split\n",
"from random import Random\n",
"\n",
"from torch.utils.data.dataset import Subset\n",
"\n",
"splits = stratified_split(list(range(len(dataset))), labels=dataset.targets, n_splits=10, rng=Random(42))\n",
"splits = tuple(Subset(dataset, split) for split in splits)"
],
"id": "c98213c371f2939e",
"outputs": [],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-11-02T13:47:13.173297Z",
"start_time": "2025-11-02T13:47:06.707030Z"
}
},
"cell_type": "code",
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import plotly.express as px\n",
"\n",
"splits_labels_counts = []\n",
"for split_idx, split in enumerate(splits):\n",
"\tlabels, counts = np.unique([dataset.targets[i] for i in split.indices], return_counts=True)\n",
"\tsplits_labels_counts.append(dict(zip(labels, counts)) | {'split_index': split_idx})\n",
"\n",
"df = pd.DataFrame(splits_labels_counts)\n",
"df = df.melt(id_vars=('split_index',), value_vars=set(df.columns) - {'split_index'}, var_name='label', value_name='count')\n",
"\n",
"px.bar(\n",
"\tdf,\n",
"\tx='split_index', y='count', color='label',\n",
"\tlabels={'split_index': 'Split Index', 'count': 'Count', 'label': 'Label'}\n",
").show()"
],
"id": "980fd550b116a35c",
"outputs": [
{
"data": {
"application/vnd.plotly.v1+json": {
"data": [
{
"hovertemplate": "Label=0<br>Split Index=%{x}<br>Count=%{y}<extra></extra>",
"legendgroup": "0",
"marker": {
"color": "#636efa",
"pattern": {
"shape": ""
}
},
"name": "0",
"orientation": "v",
"showlegend": true,
"textposition": "auto",
"x": {
"dtype": "i1",
"bdata": "AAECAwQFBgcICQ=="
},
"xaxis": "x",
"y": {
"dtype": "i2",
"bdata": "UAJQAlECUAJRAlACUAJQAlECUAI="
},
"yaxis": "y",
"type": "bar"
},
{
"hovertemplate": "Label=1<br>Split Index=%{x}<br>Count=%{y}<extra></extra>",
"legendgroup": "1",
"marker": {
"color": "#EF553B",
"pattern": {
"shape": ""
}
},
"name": "1",
"orientation": "v",
"showlegend": true,
"textposition": "auto",
"x": {
"dtype": "i1",
"bdata": "AAECAwQFBgcICQ=="
},
"xaxis": "x",
"y": {
"dtype": "i2",
"bdata": "ogKiAqICogKiAqICowKiAqMCogI="
},
"yaxis": "y",
"type": "bar"
},
{
"hovertemplate": "Label=2<br>Split Index=%{x}<br>Count=%{y}<extra></extra>",
"legendgroup": "2",
"marker": {
"color": "#00cc96",
"pattern": {
"shape": ""
}
},
"name": "2",
"orientation": "v",
"showlegend": true,
"textposition": "auto",
"x": {
"dtype": "i1",
"bdata": "AAECAwQFBgcICQ=="
},
"xaxis": "x",
"y": {
"dtype": "i2",
"bdata": "VAJUAlQCVAJUAlQCVAJTAlMCVAI="
},
"yaxis": "y",
"type": "bar"
},
{
"hovertemplate": "Label=3<br>Split Index=%{x}<br>Count=%{y}<extra></extra>",
"legendgroup": "3",
"marker": {
"color": "#ab63fa",
"pattern": {
"shape": ""
}
},
"name": "3",
"orientation": "v",
"showlegend": true,
"textposition": "auto",
"x": {
"dtype": "i1",
"bdata": "AAECAwQFBgcICQ=="
},
"xaxis": "x",
"y": {
"dtype": "i2",
"bdata": "ZQJlAmUCZQJlAmUCZQJlAmYCZQI="
},
"yaxis": "y",
"type": "bar"
},
{
"hovertemplate": "Label=4<br>Split Index=%{x}<br>Count=%{y}<extra></extra>",
"legendgroup": "4",
"marker": {
"color": "#FFA15A",
"pattern": {
"shape": ""
}
},
"name": "4",
"orientation": "v",
"showlegend": true,
"textposition": "auto",
"x": {
"dtype": "i1",
"bdata": "AAECAwQFBgcICQ=="
},
"xaxis": "x",
"y": {
"dtype": "i2",
"bdata": "SAJIAkgCSAJIAkgCSQJIAkgCSQI="
},
"yaxis": "y",
"type": "bar"
},
{
"hovertemplate": "Label=5<br>Split Index=%{x}<br>Count=%{y}<extra></extra>",
"legendgroup": "5",
"marker": {
"color": "#19d3f3",
"pattern": {
"shape": ""
}
},
"name": "5",
"orientation": "v",
"showlegend": true,
"textposition": "auto",
"x": {
"dtype": "i1",
"bdata": "AAECAwQFBgcICQ=="
},
"xaxis": "x",
"y": {
"dtype": "i2",
"bdata": "HgIeAh4CHgIeAh4CHgIfAh4CHgI="
},
"yaxis": "y",
"type": "bar"
},
{
"hovertemplate": "Label=6<br>Split Index=%{x}<br>Count=%{y}<extra></extra>",
"legendgroup": "6",
"marker": {
"color": "#FF6692",
"pattern": {
"shape": ""
}
},
"name": "6",
"orientation": "v",
"showlegend": true,
"textposition": "auto",
"x": {
"dtype": "i1",
"bdata": "AAECAwQFBgcICQ=="
},
"xaxis": "x",
"y": {
"dtype": "i2",
"bdata": "UAJQAlACUAJQAlACTwJQAk8CUAI="
},
"yaxis": "y",
"type": "bar"
},
{
"hovertemplate": "Label=7<br>Split Index=%{x}<br>Count=%{y}<extra></extra>",
"legendgroup": "7",
"marker": {
"color": "#B6E880",
"pattern": {
"shape": ""
}
},
"name": "7",
"orientation": "v",
"showlegend": true,
"textposition": "auto",
"x": {
"dtype": "i1",
"bdata": "AAECAwQFBgcICQ=="
},
"xaxis": "x",
"y": {
"dtype": "i2",
"bdata": "cwJzAnICcwJyAnMCcgJyAnMCcgI="
},
"yaxis": "y",
"type": "bar"
},
{
"hovertemplate": "Label=8<br>Split Index=%{x}<br>Count=%{y}<extra></extra>",
"legendgroup": "8",
"marker": {
"color": "#FF97FF",
"pattern": {
"shape": ""
}
},
"name": "8",
"orientation": "v",
"showlegend": true,
"textposition": "auto",
"x": {
"dtype": "i1",
"bdata": "AAECAwQFBgcICQ=="
},
"xaxis": "x",
"y": {
"dtype": "i2",
"bdata": "SQJJAkkCSQJJAkkCSQJKAkkCSQI="
},
"yaxis": "y",
"type": "bar"
},
{
"hovertemplate": "Label=9<br>Split Index=%{x}<br>Count=%{y}<extra></extra>",
"legendgroup": "9",
"marker": {
"color": "#FECB52",
"pattern": {
"shape": ""
}
},
"name": "9",
"orientation": "v",
"showlegend": true,
"textposition": "auto",
"x": {
"dtype": "i1",
"bdata": "AAECAwQFBgcICQ=="
},
"xaxis": "x",
"y": {
"dtype": "i2",
"bdata": "UwJTAlMCUwJTAlMCUwJTAlICUwI="
},
"yaxis": "y",
"type": "bar"
}
],
"layout": {
"template": {
"data": {
"histogram2dcontour": [
{
"type": "histogram2dcontour",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"choropleth": [
{
"type": "choropleth",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
],
"histogram2d": [
{
"type": "histogram2d",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"heatmap": [
{
"type": "heatmap",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"contourcarpet": [
{
"type": "contourcarpet",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
],
"contour": [
{
"type": "contour",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"surface": [
{
"type": "surface",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"mesh3d": [
{
"type": "mesh3d",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
],
"scatter": [
{
"marker": {
"line": {
"color": "#283442"
}
},
"type": "scatter"
}
],
"parcoords": [
{
"type": "parcoords",
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scatterpolargl": [
{
"type": "scatterpolargl",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"bar": [
{
"error_x": {
"color": "#f2f5fa"
},
"error_y": {
"color": "#f2f5fa"
},
"marker": {
"line": {
"color": "rgb(17,17,17)",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "bar"
}
],
"scattergeo": [
{
"type": "scattergeo",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scatterpolar": [
{
"type": "scatterpolar",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"histogram": [
{
"marker": {
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "histogram"
}
],
"scattergl": [
{
"marker": {
"line": {
"color": "#283442"
}
},
"type": "scattergl"
}
],
"scatter3d": [
{
"type": "scatter3d",
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scattermap": [
{
"type": "scattermap",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scattermapbox": [
{
"type": "scattermapbox",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scatterternary": [
{
"type": "scatterternary",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scattercarpet": [
{
"type": "scattercarpet",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"carpet": [
{
"aaxis": {
"endlinecolor": "#A2B1C6",
"gridcolor": "#506784",
"linecolor": "#506784",
"minorgridcolor": "#506784",
"startlinecolor": "#A2B1C6"
},
"baxis": {
"endlinecolor": "#A2B1C6",
"gridcolor": "#506784",
"linecolor": "#506784",
"minorgridcolor": "#506784",
"startlinecolor": "#A2B1C6"
},
"type": "carpet"
}
],
"table": [
{
"cells": {
"fill": {
"color": "#506784"
},
"line": {
"color": "rgb(17,17,17)"
}
},
"header": {
"fill": {
"color": "#2a3f5f"
},
"line": {
"color": "rgb(17,17,17)"
}
},
"type": "table"
}
],
"barpolar": [
{
"marker": {
"line": {
"color": "rgb(17,17,17)",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "barpolar"
}
],
"pie": [
{
"automargin": true,
"type": "pie"
}
]
},
"layout": {
"autotypenumbers": "strict",
"colorway": [
"#636efa",
"#EF553B",
"#00cc96",
"#ab63fa",
"#FFA15A",
"#19d3f3",
"#FF6692",
"#B6E880",
"#FF97FF",
"#FECB52"
],
"font": {
"color": "#f2f5fa"
},
"hovermode": "closest",
"hoverlabel": {
"align": "left"
},
"paper_bgcolor": "rgb(17,17,17)",
"plot_bgcolor": "rgb(17,17,17)",
"polar": {
"bgcolor": "rgb(17,17,17)",
"angularaxis": {
"gridcolor": "#506784",
"linecolor": "#506784",
"ticks": ""
},
"radialaxis": {
"gridcolor": "#506784",
"linecolor": "#506784",
"ticks": ""
}
},
"ternary": {
"bgcolor": "rgb(17,17,17)",
"aaxis": {
"gridcolor": "#506784",
"linecolor": "#506784",
"ticks": ""
},
"baxis": {
"gridcolor": "#506784",
"linecolor": "#506784",
"ticks": ""
},
"caxis": {
"gridcolor": "#506784",
"linecolor": "#506784",
"ticks": ""
}
},
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"sequential": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
],
"sequentialminus": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
],
"diverging": [
[
0,
"#8e0152"
],
[
0.1,
"#c51b7d"
],
[
0.2,
"#de77ae"
],
[
0.3,
"#f1b6da"
],
[
0.4,
"#fde0ef"
],
[
0.5,
"#f7f7f7"
],
[
0.6,
"#e6f5d0"
],
[
0.7,
"#b8e186"
],
[
0.8,
"#7fbc41"
],
[
0.9,
"#4d9221"
],
[
1,
"#276419"
]
]
},
"xaxis": {
"gridcolor": "#283442",
"linecolor": "#506784",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "#283442",
"automargin": true,
"zerolinewidth": 2
},
"yaxis": {
"gridcolor": "#283442",
"linecolor": "#506784",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "#283442",
"automargin": true,
"zerolinewidth": 2
},
"scene": {
"xaxis": {
"backgroundcolor": "rgb(17,17,17)",
"gridcolor": "#506784",
"linecolor": "#506784",
"showbackground": true,
"ticks": "",
"zerolinecolor": "#C8D4E3",
"gridwidth": 2
},
"yaxis": {
"backgroundcolor": "rgb(17,17,17)",
"gridcolor": "#506784",
"linecolor": "#506784",
"showbackground": true,
"ticks": "",
"zerolinecolor": "#C8D4E3",
"gridwidth": 2
},
"zaxis": {
"backgroundcolor": "rgb(17,17,17)",
"gridcolor": "#506784",
"linecolor": "#506784",
"showbackground": true,
"ticks": "",
"zerolinecolor": "#C8D4E3",
"gridwidth": 2
}
},
"shapedefaults": {
"line": {
"color": "#f2f5fa"
}
},
"annotationdefaults": {
"arrowcolor": "#f2f5fa",
"arrowhead": 0,
"arrowwidth": 1
},
"geo": {
"bgcolor": "rgb(17,17,17)",
"landcolor": "rgb(17,17,17)",
"subunitcolor": "#506784",
"showland": true,
"showlakes": true,
"lakecolor": "rgb(17,17,17)"
},
"title": {
"x": 0.05
},
"updatemenudefaults": {
"bgcolor": "#506784",
"borderwidth": 0
},
"sliderdefaults": {
"bgcolor": "#C8D4E3",
"borderwidth": 1,
"bordercolor": "rgb(17,17,17)",
"tickwidth": 0
},
"mapbox": {
"style": "dark"
}
}
},
"xaxis": {
"anchor": "y",
"domain": [
0.0,
1.0
],
"title": {
"text": "Split Index"
}
},
"yaxis": {
"anchor": "x",
"domain": [
0.0,
1.0
],
"title": {
"text": "Count"
}
},
"legend": {
"title": {
"text": "Label"
},
"tracegroupgap": 0
},
"margin": {
"t": 60
},
"barmode": "relative"
},
"config": {
"plotlyServerURL": "https://plot.ly"
}
}
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
}
],
"execution_count": 3
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"language": "python",
"display_name": "Python 3 (ipykernel)"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
import random
from random import Random
from typing import Any
from sklearn.model_selection import train_test_split
def stratified_split(*arrays, labels, n_splits: int, rnd: Random | None = None) -> tuple[Any, ...]:
if n_splits == 1:
return arrays
else:
splits = train_test_split(
*arrays, labels,
train_size=1 / n_splits,
# Found the following range from its respective error message if passed a random float.
random_state=rnd.randint(0, 4294967295) if rnd is not None else random.randint(0, 4294967295),
stratify=labels
)
labels_split, labels_rest = splits[-2:]
arrays_splits = splits[:-2]
return *arrays_splits[::2], *stratified_split(*arrays_splits[1::2], labels=labels_rest, n_splits=n_splits - 1, rnd=rnd)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment