Skip to content

Instantly share code, notes, and snippets.

@rkern
Created July 11, 2023 01:08
Show Gist options
  • Select an option

  • Save rkern/4cd064617a16b93554ec91704c3f3f14 to your computer and use it in GitHub Desktop.

Select an option

Save rkern/4cd064617a16b93554ec91704c3f3f14 to your computer and use it in GitHub Desktop.
Demonstrate successful worker_init_fn usage to seed the np.random PRNG
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 54,
"id": "c8dad3ae-14fc-49d0-b278-9a4141be6495",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'2.0.1+cu117'"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from functools import partial\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"torch.__version__"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "3bb48ddb-ccdb-410b-b68f-53f7cbf965ef",
"metadata": {},
"outputs": [],
"source": [
"class SemiRandom(Dataset):\n",
" def __init__(self, length):\n",
" self.length = length\n",
" \n",
" def __len__(self):\n",
" return self.length\n",
" \n",
" def __getitem__(self, idx):\n",
" return idx, np.random.random()\n",
"\n",
"\n",
"def worker_init_fn(id, split_seed: int):\n",
" # Approach recommended by the NumPy RNG author: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562\n",
" # Another good resource: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/\n",
" process_seed = torch.initial_seed()\n",
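" # PyTorch seeds each worker with base_seed + worker id, so subtract the id to\n",
" # back out the full base_seed and feed all of its bits to the SeedSequence.\n",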
" base_seed = process_seed - id\n",
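" # NOTE: an earlier TODO said split_seed seemed to have no impact. It is part of\n",
" # the SeedSequence entropy below, so different split_seed values produce\n",
" # different NumPy states even when the base_seed and worker id are the same.\n",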
" ss = np.random.SeedSequence(\n",
" [id, base_seed, split_seed]\n",
" ) # split_seed (added by Rylan) separates otherwise identically seeded loaders, e.g. train vs. test.\n",
" # More than 128 bits (4 32-bit words) would be overkill.\n",
" np_rng_seed = ss.generate_state(4)\n",
" np.random.seed(np_rng_seed)\n",
"\n",
"\n",
"torch.manual_seed(5175199322731862326)\n",
"train_dataloader = DataLoader(SemiRandom(20), batch_size=5, num_workers=4, worker_init_fn=partial(worker_init_fn, split_seed=0))\n",
"test_dataloader = DataLoader(SemiRandom(20), batch_size=5, num_workers=4, worker_init_fn=partial(worker_init_fn, split_seed=1))"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "f8066855-b5d3-4247-bc67-3f9f3be8aeae",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[tensor([0, 1, 2, 3, 4]),\n",
" tensor([0.3508, 0.8795, 0.9726, 0.7649, 0.7594], dtype=torch.float64)],\n",
" [tensor([5, 6, 7, 8, 9]),\n",
" tensor([0.3570, 0.2279, 0.7705, 0.4676, 0.5711], dtype=torch.float64)],\n",
" [tensor([10, 11, 12, 13, 14]),\n",
" tensor([0.8072, 0.2873, 0.5976, 0.5522, 0.4358], dtype=torch.float64)],\n",
" [tensor([15, 16, 17, 18, 19]),\n",
" tensor([0.3842, 0.3760, 0.7089, 0.1838, 0.4593], dtype=torch.float64)]]"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(train_dataloader)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "734f1f9e-cf2b-4b6c-8da7-ecc68d3ece0a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[tensor([0, 1, 2, 3, 4]),\n",
" tensor([0.3472, 0.4817, 0.9577, 0.1545, 0.8269], dtype=torch.float64)],\n",
" [tensor([5, 6, 7, 8, 9]),\n",
" tensor([0.3793, 0.2993, 0.7823, 0.1517, 0.2514], dtype=torch.float64)],\n",
" [tensor([10, 11, 12, 13, 14]),\n",
" tensor([0.7631, 0.2101, 0.3980, 0.7416, 0.5016], dtype=torch.float64)],\n",
" [tensor([15, 16, 17, 18, 19]),\n",
" tensor([0.9101, 0.0793, 0.7852, 0.2601, 0.1762], dtype=torch.float64)]]"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(test_dataloader)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a99ef9da-b09b-47ed-9435-87dc971421d9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment