Created
July 11, 2023 01:08
-
-
Save rkern/4cd064617a16b93554ec91704c3f3f14 to your computer and use it in GitHub Desktop.
Demonstrate successful worker_init_fn usage to seed the np.random PRNG
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 54, | |
| "id": "c8dad3ae-14fc-49d0-b278-9a4141be6495", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "'2.0.1+cu117'" | |
| ] | |
| }, | |
| "execution_count": 54, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "from functools import partial\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "import torch\n", | |
| "from torch.utils.data import Dataset, DataLoader\n", | |
| "\n", | |
| "torch.__version__" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 50, | |
| "id": "3bb48ddb-ccdb-410b-b68f-53f7cbf965ef", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class SemiRandom(Dataset):\n", | |
| " def __init__(self, length):\n", | |
| " self.length = length\n", | |
| " \n", | |
| " def __len__(self):\n", | |
| " return self.length\n", | |
| " \n", | |
| " def __getitem__(self, idx):\n", | |
| " return idx, np.random.random()\n", | |
| "\n", | |
| "\n", | |
| "def worker_init_fn(id, split_seed: int):\n", | |
| " # Recommended by the NumPy RNG author: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562\n", | |
| " # Another good resource: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/\n", | |
| " process_seed = torch.initial_seed()\n", | |
| " # Back out the base_seed so we can use all the bits.\n", | |
| " base_seed = process_seed - id\n", | |
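| " # TODO: split_seed seems to have no visible impact here. It is mixed into the SeedSequence entropy below; confirm whether each loader already receives a different base_seed, which would mask its effect.\n", | |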
| " ss = np.random.SeedSequence(\n", | |
| " [id, base_seed, split_seed]\n", | |
| " ) # Rylan added split seed.\n", | |
| " # More than 128 bits (4 32-bit words) would be overkill.\n", | |
| " np_rng_seed = ss.generate_state(4)\n", | |
| " np.random.seed(np_rng_seed)\n", | |
| "\n", | |
| "\n", | |
| "torch.manual_seed(5175199322731862326)\n", | |
| "train_dataloader = DataLoader(SemiRandom(20), batch_size=5, num_workers=4, worker_init_fn=partial(worker_init_fn, split_seed=0))\n", | |
| "test_dataloader = DataLoader(SemiRandom(20), batch_size=5, num_workers=4, worker_init_fn=partial(worker_init_fn, split_seed=1))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 51, | |
| "id": "f8066855-b5d3-4247-bc67-3f9f3be8aeae", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[[tensor([0, 1, 2, 3, 4]),\n", | |
| " tensor([0.3508, 0.8795, 0.9726, 0.7649, 0.7594], dtype=torch.float64)],\n", | |
| " [tensor([5, 6, 7, 8, 9]),\n", | |
| " tensor([0.3570, 0.2279, 0.7705, 0.4676, 0.5711], dtype=torch.float64)],\n", | |
| " [tensor([10, 11, 12, 13, 14]),\n", | |
| " tensor([0.8072, 0.2873, 0.5976, 0.5522, 0.4358], dtype=torch.float64)],\n", | |
| " [tensor([15, 16, 17, 18, 19]),\n", | |
| " tensor([0.3842, 0.3760, 0.7089, 0.1838, 0.4593], dtype=torch.float64)]]" | |
| ] | |
| }, | |
| "execution_count": 51, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "list(train_dataloader)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 52, | |
| "id": "734f1f9e-cf2b-4b6c-8da7-ecc68d3ece0a", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[[tensor([0, 1, 2, 3, 4]),\n", | |
| " tensor([0.3472, 0.4817, 0.9577, 0.1545, 0.8269], dtype=torch.float64)],\n", | |
| " [tensor([5, 6, 7, 8, 9]),\n", | |
| " tensor([0.3793, 0.2993, 0.7823, 0.1517, 0.2514], dtype=torch.float64)],\n", | |
| " [tensor([10, 11, 12, 13, 14]),\n", | |
| " tensor([0.7631, 0.2101, 0.3980, 0.7416, 0.5016], dtype=torch.float64)],\n", | |
| " [tensor([15, 16, 17, 18, 19]),\n", | |
| " tensor([0.9101, 0.0793, 0.7852, 0.2601, 0.1762], dtype=torch.float64)]]" | |
| ] | |
| }, | |
| "execution_count": 52, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "list(test_dataloader)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "a99ef9da-b09b-47ed-9435-87dc971421d9", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.8.10" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment