Last active
November 2, 2025 13:47
-
-
Save agcom/963f6a494b4f20fc8437380289e5a63f to your computer and use it in GitHub Desktop.
Creative stratified split of a dataset using scikit-learn's `train_test_split` function recursively
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2025-11-02T13:47:05.358868Z", | |
| "start_time": "2025-11-02T13:47:01.931326Z" | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "from torchvision.datasets.mnist import MNIST\n", | |
| "\n", | |
| "dataset = MNIST('./dataset/', train=True, download=False)" | |
| ], | |
| "id": "db534eb0b832929f", | |
| "outputs": [], | |
| "execution_count": 1 | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2025-11-02T13:47:06.626458Z", | |
| "start_time": "2025-11-02T13:47:05.368305Z" | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "from stratified import stratified_split\n", | |
| "from random import Random\n", | |
| "\n", | |
| "from torch.utils.data.dataset import Subset\n", | |
| "\n", | |
| "splits = stratified_split(list(range(len(dataset))), labels=dataset.targets, n_splits=10, rnd=Random(42))\n", | |
| "splits = tuple(Subset(dataset, split) for split in splits)" | |
| ], | |
| "id": "c98213c371f2939e", | |
| "outputs": [], | |
| "execution_count": 2 | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2025-11-02T13:47:13.173297Z", | |
| "start_time": "2025-11-02T13:47:06.707030Z" | |
| } | |
| }, | |
| "cell_type": "code", | |
| "source": [ | |
| "import numpy as np\n", | |
| "import pandas as pd\n", | |
| "import plotly.express as px\n", | |
| "\n", | |
| "splits_labels_counts = []\n", | |
| "for split_idx, split in enumerate(splits):\n", | |
| "\tlabels, counts = np.unique([dataset.targets[i] for i in split.indices], return_counts=True)\n", | |
| "\tsplits_labels_counts.append(dict(zip(labels, counts)) | {'split_index': split_idx})\n", | |
| "\n", | |
| "df = pd.DataFrame(splits_labels_counts)\n", | |
| "df = df.melt(id_vars=('split_index',), value_vars=set(df.columns) - {'split_index'}, var_name='label', value_name='count')\n", | |
| "\n", | |
| "px.bar(\n", | |
| "\tdf,\n", | |
| "\tx='split_index', y='count', color='label',\n", | |
| "\tlabels={'split_index': 'Split Index', 'count': 'Count', 'label': 'Label'}\n", | |
| ").show()" | |
| ], | |
| "id": "980fd550b116a35c", | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/vnd.plotly.v1+json": { | |
| "data": [ | |
| { | |
| "hovertemplate": "Label=0<br>Split Index=%{x}<br>Count=%{y}<extra></extra>", | |
| "legendgroup": "0", | |
| "marker": { | |
| "color": "#636efa", | |
| "pattern": { | |
| "shape": "" | |
| } | |
| }, | |
| "name": "0", | |
| "orientation": "v", | |
| "showlegend": true, | |
| "textposition": "auto", | |
| "x": { | |
| "dtype": "i1", | |
| "bdata": "AAECAwQFBgcICQ==" | |
| }, | |
| "xaxis": "x", | |
| "y": { | |
| "dtype": "i2", | |
| "bdata": "UAJQAlECUAJRAlACUAJQAlECUAI=" | |
| }, | |
| "yaxis": "y", | |
| "type": "bar" | |
| }, | |
| { | |
| "hovertemplate": "Label=1<br>Split Index=%{x}<br>Count=%{y}<extra></extra>", | |
| "legendgroup": "1", | |
| "marker": { | |
| "color": "#EF553B", | |
| "pattern": { | |
| "shape": "" | |
| } | |
| }, | |
| "name": "1", | |
| "orientation": "v", | |
| "showlegend": true, | |
| "textposition": "auto", | |
| "x": { | |
| "dtype": "i1", | |
| "bdata": "AAECAwQFBgcICQ==" | |
| }, | |
| "xaxis": "x", | |
| "y": { | |
| "dtype": "i2", | |
| "bdata": "ogKiAqICogKiAqICowKiAqMCogI=" | |
| }, | |
| "yaxis": "y", | |
| "type": "bar" | |
| }, | |
| { | |
| "hovertemplate": "Label=2<br>Split Index=%{x}<br>Count=%{y}<extra></extra>", | |
| "legendgroup": "2", | |
| "marker": { | |
| "color": "#00cc96", | |
| "pattern": { | |
| "shape": "" | |
| } | |
| }, | |
| "name": "2", | |
| "orientation": "v", | |
| "showlegend": true, | |
| "textposition": "auto", | |
| "x": { | |
| "dtype": "i1", | |
| "bdata": "AAECAwQFBgcICQ==" | |
| }, | |
| "xaxis": "x", | |
| "y": { | |
| "dtype": "i2", | |
| "bdata": "VAJUAlQCVAJUAlQCVAJTAlMCVAI=" | |
| }, | |
| "yaxis": "y", | |
| "type": "bar" | |
| }, | |
| { | |
| "hovertemplate": "Label=3<br>Split Index=%{x}<br>Count=%{y}<extra></extra>", | |
| "legendgroup": "3", | |
| "marker": { | |
| "color": "#ab63fa", | |
| "pattern": { | |
| "shape": "" | |
| } | |
| }, | |
| "name": "3", | |
| "orientation": "v", | |
| "showlegend": true, | |
| "textposition": "auto", | |
| "x": { | |
| "dtype": "i1", | |
| "bdata": "AAECAwQFBgcICQ==" | |
| }, | |
| "xaxis": "x", | |
| "y": { | |
| "dtype": "i2", | |
| "bdata": "ZQJlAmUCZQJlAmUCZQJlAmYCZQI=" | |
| }, | |
| "yaxis": "y", | |
| "type": "bar" | |
| }, | |
| { | |
| "hovertemplate": "Label=4<br>Split Index=%{x}<br>Count=%{y}<extra></extra>", | |
| "legendgroup": "4", | |
| "marker": { | |
| "color": "#FFA15A", | |
| "pattern": { | |
| "shape": "" | |
| } | |
| }, | |
| "name": "4", | |
| "orientation": "v", | |
| "showlegend": true, | |
| "textposition": "auto", | |
| "x": { | |
| "dtype": "i1", | |
| "bdata": "AAECAwQFBgcICQ==" | |
| }, | |
| "xaxis": "x", | |
| "y": { | |
| "dtype": "i2", | |
| "bdata": "SAJIAkgCSAJIAkgCSQJIAkgCSQI=" | |
| }, | |
| "yaxis": "y", | |
| "type": "bar" | |
| }, | |
| { | |
| "hovertemplate": "Label=5<br>Split Index=%{x}<br>Count=%{y}<extra></extra>", | |
| "legendgroup": "5", | |
| "marker": { | |
| "color": "#19d3f3", | |
| "pattern": { | |
| "shape": "" | |
| } | |
| }, | |
| "name": "5", | |
| "orientation": "v", | |
| "showlegend": true, | |
| "textposition": "auto", | |
| "x": { | |
| "dtype": "i1", | |
| "bdata": "AAECAwQFBgcICQ==" | |
| }, | |
| "xaxis": "x", | |
| "y": { | |
| "dtype": "i2", | |
| "bdata": "HgIeAh4CHgIeAh4CHgIfAh4CHgI=" | |
| }, | |
| "yaxis": "y", | |
| "type": "bar" | |
| }, | |
| { | |
| "hovertemplate": "Label=6<br>Split Index=%{x}<br>Count=%{y}<extra></extra>", | |
| "legendgroup": "6", | |
| "marker": { | |
| "color": "#FF6692", | |
| "pattern": { | |
| "shape": "" | |
| } | |
| }, | |
| "name": "6", | |
| "orientation": "v", | |
| "showlegend": true, | |
| "textposition": "auto", | |
| "x": { | |
| "dtype": "i1", | |
| "bdata": "AAECAwQFBgcICQ==" | |
| }, | |
| "xaxis": "x", | |
| "y": { | |
| "dtype": "i2", | |
| "bdata": "UAJQAlACUAJQAlACTwJQAk8CUAI=" | |
| }, | |
| "yaxis": "y", | |
| "type": "bar" | |
| }, | |
| { | |
| "hovertemplate": "Label=7<br>Split Index=%{x}<br>Count=%{y}<extra></extra>", | |
| "legendgroup": "7", | |
| "marker": { | |
| "color": "#B6E880", | |
| "pattern": { | |
| "shape": "" | |
| } | |
| }, | |
| "name": "7", | |
| "orientation": "v", | |
| "showlegend": true, | |
| "textposition": "auto", | |
| "x": { | |
| "dtype": "i1", | |
| "bdata": "AAECAwQFBgcICQ==" | |
| }, | |
| "xaxis": "x", | |
| "y": { | |
| "dtype": "i2", | |
| "bdata": "cwJzAnICcwJyAnMCcgJyAnMCcgI=" | |
| }, | |
| "yaxis": "y", | |
| "type": "bar" | |
| }, | |
| { | |
| "hovertemplate": "Label=8<br>Split Index=%{x}<br>Count=%{y}<extra></extra>", | |
| "legendgroup": "8", | |
| "marker": { | |
| "color": "#FF97FF", | |
| "pattern": { | |
| "shape": "" | |
| } | |
| }, | |
| "name": "8", | |
| "orientation": "v", | |
| "showlegend": true, | |
| "textposition": "auto", | |
| "x": { | |
| "dtype": "i1", | |
| "bdata": "AAECAwQFBgcICQ==" | |
| }, | |
| "xaxis": "x", | |
| "y": { | |
| "dtype": "i2", | |
| "bdata": "SQJJAkkCSQJJAkkCSQJKAkkCSQI=" | |
| }, | |
| "yaxis": "y", | |
| "type": "bar" | |
| }, | |
| { | |
| "hovertemplate": "Label=9<br>Split Index=%{x}<br>Count=%{y}<extra></extra>", | |
| "legendgroup": "9", | |
| "marker": { | |
| "color": "#FECB52", | |
| "pattern": { | |
| "shape": "" | |
| } | |
| }, | |
| "name": "9", | |
| "orientation": "v", | |
| "showlegend": true, | |
| "textposition": "auto", | |
| "x": { | |
| "dtype": "i1", | |
| "bdata": "AAECAwQFBgcICQ==" | |
| }, | |
| "xaxis": "x", | |
| "y": { | |
| "dtype": "i2", | |
| "bdata": "UwJTAlMCUwJTAlMCUwJTAlICUwI=" | |
| }, | |
| "yaxis": "y", | |
| "type": "bar" | |
| } | |
| ], | |
| "layout": { | |
| "template": { | |
| "data": { | |
| "histogram2dcontour": [ | |
| { | |
| "type": "histogram2dcontour", | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| }, | |
| "colorscale": [ | |
| [ | |
| 0.0, | |
| "#0d0887" | |
| ], | |
| [ | |
| 0.1111111111111111, | |
| "#46039f" | |
| ], | |
| [ | |
| 0.2222222222222222, | |
| "#7201a8" | |
| ], | |
| [ | |
| 0.3333333333333333, | |
| "#9c179e" | |
| ], | |
| [ | |
| 0.4444444444444444, | |
| "#bd3786" | |
| ], | |
| [ | |
| 0.5555555555555556, | |
| "#d8576b" | |
| ], | |
| [ | |
| 0.6666666666666666, | |
| "#ed7953" | |
| ], | |
| [ | |
| 0.7777777777777778, | |
| "#fb9f3a" | |
| ], | |
| [ | |
| 0.8888888888888888, | |
| "#fdca26" | |
| ], | |
| [ | |
| 1.0, | |
| "#f0f921" | |
| ] | |
| ] | |
| } | |
| ], | |
| "choropleth": [ | |
| { | |
| "type": "choropleth", | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| } | |
| ], | |
| "histogram2d": [ | |
| { | |
| "type": "histogram2d", | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| }, | |
| "colorscale": [ | |
| [ | |
| 0.0, | |
| "#0d0887" | |
| ], | |
| [ | |
| 0.1111111111111111, | |
| "#46039f" | |
| ], | |
| [ | |
| 0.2222222222222222, | |
| "#7201a8" | |
| ], | |
| [ | |
| 0.3333333333333333, | |
| "#9c179e" | |
| ], | |
| [ | |
| 0.4444444444444444, | |
| "#bd3786" | |
| ], | |
| [ | |
| 0.5555555555555556, | |
| "#d8576b" | |
| ], | |
| [ | |
| 0.6666666666666666, | |
| "#ed7953" | |
| ], | |
| [ | |
| 0.7777777777777778, | |
| "#fb9f3a" | |
| ], | |
| [ | |
| 0.8888888888888888, | |
| "#fdca26" | |
| ], | |
| [ | |
| 1.0, | |
| "#f0f921" | |
| ] | |
| ] | |
| } | |
| ], | |
| "heatmap": [ | |
| { | |
| "type": "heatmap", | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| }, | |
| "colorscale": [ | |
| [ | |
| 0.0, | |
| "#0d0887" | |
| ], | |
| [ | |
| 0.1111111111111111, | |
| "#46039f" | |
| ], | |
| [ | |
| 0.2222222222222222, | |
| "#7201a8" | |
| ], | |
| [ | |
| 0.3333333333333333, | |
| "#9c179e" | |
| ], | |
| [ | |
| 0.4444444444444444, | |
| "#bd3786" | |
| ], | |
| [ | |
| 0.5555555555555556, | |
| "#d8576b" | |
| ], | |
| [ | |
| 0.6666666666666666, | |
| "#ed7953" | |
| ], | |
| [ | |
| 0.7777777777777778, | |
| "#fb9f3a" | |
| ], | |
| [ | |
| 0.8888888888888888, | |
| "#fdca26" | |
| ], | |
| [ | |
| 1.0, | |
| "#f0f921" | |
| ] | |
| ] | |
| } | |
| ], | |
| "contourcarpet": [ | |
| { | |
| "type": "contourcarpet", | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| } | |
| ], | |
| "contour": [ | |
| { | |
| "type": "contour", | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| }, | |
| "colorscale": [ | |
| [ | |
| 0.0, | |
| "#0d0887" | |
| ], | |
| [ | |
| 0.1111111111111111, | |
| "#46039f" | |
| ], | |
| [ | |
| 0.2222222222222222, | |
| "#7201a8" | |
| ], | |
| [ | |
| 0.3333333333333333, | |
| "#9c179e" | |
| ], | |
| [ | |
| 0.4444444444444444, | |
| "#bd3786" | |
| ], | |
| [ | |
| 0.5555555555555556, | |
| "#d8576b" | |
| ], | |
| [ | |
| 0.6666666666666666, | |
| "#ed7953" | |
| ], | |
| [ | |
| 0.7777777777777778, | |
| "#fb9f3a" | |
| ], | |
| [ | |
| 0.8888888888888888, | |
| "#fdca26" | |
| ], | |
| [ | |
| 1.0, | |
| "#f0f921" | |
| ] | |
| ] | |
| } | |
| ], | |
| "surface": [ | |
| { | |
| "type": "surface", | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| }, | |
| "colorscale": [ | |
| [ | |
| 0.0, | |
| "#0d0887" | |
| ], | |
| [ | |
| 0.1111111111111111, | |
| "#46039f" | |
| ], | |
| [ | |
| 0.2222222222222222, | |
| "#7201a8" | |
| ], | |
| [ | |
| 0.3333333333333333, | |
| "#9c179e" | |
| ], | |
| [ | |
| 0.4444444444444444, | |
| "#bd3786" | |
| ], | |
| [ | |
| 0.5555555555555556, | |
| "#d8576b" | |
| ], | |
| [ | |
| 0.6666666666666666, | |
| "#ed7953" | |
| ], | |
| [ | |
| 0.7777777777777778, | |
| "#fb9f3a" | |
| ], | |
| [ | |
| 0.8888888888888888, | |
| "#fdca26" | |
| ], | |
| [ | |
| 1.0, | |
| "#f0f921" | |
| ] | |
| ] | |
| } | |
| ], | |
| "mesh3d": [ | |
| { | |
| "type": "mesh3d", | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| } | |
| ], | |
| "scatter": [ | |
| { | |
| "marker": { | |
| "line": { | |
| "color": "#283442" | |
| } | |
| }, | |
| "type": "scatter" | |
| } | |
| ], | |
| "parcoords": [ | |
| { | |
| "type": "parcoords", | |
| "line": { | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| } | |
| } | |
| ], | |
| "scatterpolargl": [ | |
| { | |
| "type": "scatterpolargl", | |
| "marker": { | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| } | |
| } | |
| ], | |
| "bar": [ | |
| { | |
| "error_x": { | |
| "color": "#f2f5fa" | |
| }, | |
| "error_y": { | |
| "color": "#f2f5fa" | |
| }, | |
| "marker": { | |
| "line": { | |
| "color": "rgb(17,17,17)", | |
| "width": 0.5 | |
| }, | |
| "pattern": { | |
| "fillmode": "overlay", | |
| "size": 10, | |
| "solidity": 0.2 | |
| } | |
| }, | |
| "type": "bar" | |
| } | |
| ], | |
| "scattergeo": [ | |
| { | |
| "type": "scattergeo", | |
| "marker": { | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| } | |
| } | |
| ], | |
| "scatterpolar": [ | |
| { | |
| "type": "scatterpolar", | |
| "marker": { | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| } | |
| } | |
| ], | |
| "histogram": [ | |
| { | |
| "marker": { | |
| "pattern": { | |
| "fillmode": "overlay", | |
| "size": 10, | |
| "solidity": 0.2 | |
| } | |
| }, | |
| "type": "histogram" | |
| } | |
| ], | |
| "scattergl": [ | |
| { | |
| "marker": { | |
| "line": { | |
| "color": "#283442" | |
| } | |
| }, | |
| "type": "scattergl" | |
| } | |
| ], | |
| "scatter3d": [ | |
| { | |
| "type": "scatter3d", | |
| "line": { | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| }, | |
| "marker": { | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| } | |
| } | |
| ], | |
| "scattermap": [ | |
| { | |
| "type": "scattermap", | |
| "marker": { | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| } | |
| } | |
| ], | |
| "scattermapbox": [ | |
| { | |
| "type": "scattermapbox", | |
| "marker": { | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| } | |
| } | |
| ], | |
| "scatterternary": [ | |
| { | |
| "type": "scatterternary", | |
| "marker": { | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| } | |
| } | |
| ], | |
| "scattercarpet": [ | |
| { | |
| "type": "scattercarpet", | |
| "marker": { | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| } | |
| } | |
| ], | |
| "carpet": [ | |
| { | |
| "aaxis": { | |
| "endlinecolor": "#A2B1C6", | |
| "gridcolor": "#506784", | |
| "linecolor": "#506784", | |
| "minorgridcolor": "#506784", | |
| "startlinecolor": "#A2B1C6" | |
| }, | |
| "baxis": { | |
| "endlinecolor": "#A2B1C6", | |
| "gridcolor": "#506784", | |
| "linecolor": "#506784", | |
| "minorgridcolor": "#506784", | |
| "startlinecolor": "#A2B1C6" | |
| }, | |
| "type": "carpet" | |
| } | |
| ], | |
| "table": [ | |
| { | |
| "cells": { | |
| "fill": { | |
| "color": "#506784" | |
| }, | |
| "line": { | |
| "color": "rgb(17,17,17)" | |
| } | |
| }, | |
| "header": { | |
| "fill": { | |
| "color": "#2a3f5f" | |
| }, | |
| "line": { | |
| "color": "rgb(17,17,17)" | |
| } | |
| }, | |
| "type": "table" | |
| } | |
| ], | |
| "barpolar": [ | |
| { | |
| "marker": { | |
| "line": { | |
| "color": "rgb(17,17,17)", | |
| "width": 0.5 | |
| }, | |
| "pattern": { | |
| "fillmode": "overlay", | |
| "size": 10, | |
| "solidity": 0.2 | |
| } | |
| }, | |
| "type": "barpolar" | |
| } | |
| ], | |
| "pie": [ | |
| { | |
| "automargin": true, | |
| "type": "pie" | |
| } | |
| ] | |
| }, | |
| "layout": { | |
| "autotypenumbers": "strict", | |
| "colorway": [ | |
| "#636efa", | |
| "#EF553B", | |
| "#00cc96", | |
| "#ab63fa", | |
| "#FFA15A", | |
| "#19d3f3", | |
| "#FF6692", | |
| "#B6E880", | |
| "#FF97FF", | |
| "#FECB52" | |
| ], | |
| "font": { | |
| "color": "#f2f5fa" | |
| }, | |
| "hovermode": "closest", | |
| "hoverlabel": { | |
| "align": "left" | |
| }, | |
| "paper_bgcolor": "rgb(17,17,17)", | |
| "plot_bgcolor": "rgb(17,17,17)", | |
| "polar": { | |
| "bgcolor": "rgb(17,17,17)", | |
| "angularaxis": { | |
| "gridcolor": "#506784", | |
| "linecolor": "#506784", | |
| "ticks": "" | |
| }, | |
| "radialaxis": { | |
| "gridcolor": "#506784", | |
| "linecolor": "#506784", | |
| "ticks": "" | |
| } | |
| }, | |
| "ternary": { | |
| "bgcolor": "rgb(17,17,17)", | |
| "aaxis": { | |
| "gridcolor": "#506784", | |
| "linecolor": "#506784", | |
| "ticks": "" | |
| }, | |
| "baxis": { | |
| "gridcolor": "#506784", | |
| "linecolor": "#506784", | |
| "ticks": "" | |
| }, | |
| "caxis": { | |
| "gridcolor": "#506784", | |
| "linecolor": "#506784", | |
| "ticks": "" | |
| } | |
| }, | |
| "coloraxis": { | |
| "colorbar": { | |
| "outlinewidth": 0, | |
| "ticks": "" | |
| } | |
| }, | |
| "colorscale": { | |
| "sequential": [ | |
| [ | |
| 0.0, | |
| "#0d0887" | |
| ], | |
| [ | |
| 0.1111111111111111, | |
| "#46039f" | |
| ], | |
| [ | |
| 0.2222222222222222, | |
| "#7201a8" | |
| ], | |
| [ | |
| 0.3333333333333333, | |
| "#9c179e" | |
| ], | |
| [ | |
| 0.4444444444444444, | |
| "#bd3786" | |
| ], | |
| [ | |
| 0.5555555555555556, | |
| "#d8576b" | |
| ], | |
| [ | |
| 0.6666666666666666, | |
| "#ed7953" | |
| ], | |
| [ | |
| 0.7777777777777778, | |
| "#fb9f3a" | |
| ], | |
| [ | |
| 0.8888888888888888, | |
| "#fdca26" | |
| ], | |
| [ | |
| 1.0, | |
| "#f0f921" | |
| ] | |
| ], | |
| "sequentialminus": [ | |
| [ | |
| 0.0, | |
| "#0d0887" | |
| ], | |
| [ | |
| 0.1111111111111111, | |
| "#46039f" | |
| ], | |
| [ | |
| 0.2222222222222222, | |
| "#7201a8" | |
| ], | |
| [ | |
| 0.3333333333333333, | |
| "#9c179e" | |
| ], | |
| [ | |
| 0.4444444444444444, | |
| "#bd3786" | |
| ], | |
| [ | |
| 0.5555555555555556, | |
| "#d8576b" | |
| ], | |
| [ | |
| 0.6666666666666666, | |
| "#ed7953" | |
| ], | |
| [ | |
| 0.7777777777777778, | |
| "#fb9f3a" | |
| ], | |
| [ | |
| 0.8888888888888888, | |
| "#fdca26" | |
| ], | |
| [ | |
| 1.0, | |
| "#f0f921" | |
| ] | |
| ], | |
| "diverging": [ | |
| [ | |
| 0, | |
| "#8e0152" | |
| ], | |
| [ | |
| 0.1, | |
| "#c51b7d" | |
| ], | |
| [ | |
| 0.2, | |
| "#de77ae" | |
| ], | |
| [ | |
| 0.3, | |
| "#f1b6da" | |
| ], | |
| [ | |
| 0.4, | |
| "#fde0ef" | |
| ], | |
| [ | |
| 0.5, | |
| "#f7f7f7" | |
| ], | |
| [ | |
| 0.6, | |
| "#e6f5d0" | |
| ], | |
| [ | |
| 0.7, | |
| "#b8e186" | |
| ], | |
| [ | |
| 0.8, | |
| "#7fbc41" | |
| ], | |
| [ | |
| 0.9, | |
| "#4d9221" | |
| ], | |
| [ | |
| 1, | |
| "#276419" | |
| ] | |
| ] | |
| }, | |
| "xaxis": { | |
| "gridcolor": "#283442", | |
| "linecolor": "#506784", | |
| "ticks": "", | |
| "title": { | |
| "standoff": 15 | |
| }, | |
| "zerolinecolor": "#283442", | |
| "automargin": true, | |
| "zerolinewidth": 2 | |
| }, | |
| "yaxis": { | |
| "gridcolor": "#283442", | |
| "linecolor": "#506784", | |
| "ticks": "", | |
| "title": { | |
| "standoff": 15 | |
| }, | |
| "zerolinecolor": "#283442", | |
| "automargin": true, | |
| "zerolinewidth": 2 | |
| }, | |
| "scene": { | |
| "xaxis": { | |
| "backgroundcolor": "rgb(17,17,17)", | |
| "gridcolor": "#506784", | |
| "linecolor": "#506784", | |
| "showbackground": true, | |
| "ticks": "", | |
| "zerolinecolor": "#C8D4E3", | |
| "gridwidth": 2 | |
| }, | |
| "yaxis": { | |
| "backgroundcolor": "rgb(17,17,17)", | |
| "gridcolor": "#506784", | |
| "linecolor": "#506784", | |
| "showbackground": true, | |
| "ticks": "", | |
| "zerolinecolor": "#C8D4E3", | |
| "gridwidth": 2 | |
| }, | |
| "zaxis": { | |
| "backgroundcolor": "rgb(17,17,17)", | |
| "gridcolor": "#506784", | |
| "linecolor": "#506784", | |
| "showbackground": true, | |
| "ticks": "", | |
| "zerolinecolor": "#C8D4E3", | |
| "gridwidth": 2 | |
| } | |
| }, | |
| "shapedefaults": { | |
| "line": { | |
| "color": "#f2f5fa" | |
| } | |
| }, | |
| "annotationdefaults": { | |
| "arrowcolor": "#f2f5fa", | |
| "arrowhead": 0, | |
| "arrowwidth": 1 | |
| }, | |
| "geo": { | |
| "bgcolor": "rgb(17,17,17)", | |
| "landcolor": "rgb(17,17,17)", | |
| "subunitcolor": "#506784", | |
| "showland": true, | |
| "showlakes": true, | |
| "lakecolor": "rgb(17,17,17)" | |
| }, | |
| "title": { | |
| "x": 0.05 | |
| }, | |
| "updatemenudefaults": { | |
| "bgcolor": "#506784", | |
| "borderwidth": 0 | |
| }, | |
| "sliderdefaults": { | |
| "bgcolor": "#C8D4E3", | |
| "borderwidth": 1, | |
| "bordercolor": "rgb(17,17,17)", | |
| "tickwidth": 0 | |
| }, | |
| "mapbox": { | |
| "style": "dark" | |
| } | |
| } | |
| }, | |
| "xaxis": { | |
| "anchor": "y", | |
| "domain": [ | |
| 0.0, | |
| 1.0 | |
| ], | |
| "title": { | |
| "text": "Split Index" | |
| } | |
| }, | |
| "yaxis": { | |
| "anchor": "x", | |
| "domain": [ | |
| 0.0, | |
| 1.0 | |
| ], | |
| "title": { | |
| "text": "Count" | |
| } | |
| }, | |
| "legend": { | |
| "title": { | |
| "text": "Label" | |
| }, | |
| "tracegroupgap": 0 | |
| }, | |
| "margin": { | |
| "t": 60 | |
| }, | |
| "barmode": "relative" | |
| }, | |
| "config": { | |
| "plotlyServerURL": "https://plot.ly" | |
| } | |
| } | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data", | |
| "jetTransient": { | |
| "display_id": null | |
| } | |
| } | |
| ], | |
| "execution_count": 3 | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "name": "python3", | |
| "language": "python", | |
| "display_name": "Python 3 (ipykernel)" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import random | |
| from random import Random | |
| from typing import Any | |
| from sklearn.model_selection import train_test_split | |
| def stratified_split(*arrays, labels, n_splits: int, rnd: Random | None = None) -> tuple[Any, ...]: | |
| if n_splits == 1: | |
| return arrays | |
| else: | |
| splits = train_test_split( | |
| *arrays, labels, | |
| train_size=1 / n_splits, | |
| # Found the following range from its respective error message if passed a random float. | |
| random_state=rnd.randint(0, 4294967295) if rnd is not None else random.randint(0, 4294967295), | |
| stratify=labels | |
| ) | |
| labels_split, labels_rest = splits[-2:] | |
| arrays_splits = splits[:-2] | |
| return *arrays_splits[::2], *stratified_split(*arrays_splits[1::2], labels=labels_rest, n_splits=n_splits - 1, rnd=rnd) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment