Thomas Angeland Alvtron

Alvtron / pytorch_stratified_split.md
Last active January 13, 2026 12:21
Split a PyTorch Dataset into two subsets using stratified random sampling.

Stratified dataset split in PyTorch

When working with imbalanced data for machine learning tasks in PyTorch, a simple random split might not divide under-represented classes proportionally between the subsets. The resulting splits might not reflect the real-world population, leading to poor predictive performance in the resulting model.
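To give a feel for how likely this is, here is a small back-of-the-envelope sketch. It is not part of the split function itself. The helper name `prob_class_missing` and the numbers (100 samples, 5 of them in the minority class, a test subset of 20) are made up for illustration. It computes the chance that a uniformly random split puts no minority sample at all into the test subset.

```python
import math


def prob_class_missing(n_total: int, n_minority: int, n_test: int) -> float:
    """Probability that a uniformly random subset of size n_test
    contains no sample from a class with n_minority members.

    This is the hypergeometric probability of drawing zero minority
    samples: C(n_total - n_minority, n_test) / C(n_total, n_test).
    """
    return math.comb(n_total - n_minority, n_test) / math.comb(n_total, n_test)


# Hypothetical dataset: 100 samples, 5 in the minority class, 20 held out.
p = prob_class_missing(n_total=100, n_minority=5, n_test=20)
print(f"{p:.3f}")  # prints 0.319
```

Under these assumed numbers, roughly one random split in three would leave the test subset without a single minority sample. A stratified split avoids that by sampling each class separately.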

Therefore, I have created a simple function for conducting a stratified split with random shuffling, similar to StratifiedShuffleSplit from scikit-learn (https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html).

import random
import math
import torch.utils.data