Last active
September 9, 2025 07:46
-
-
Save KaoruNishikawa/cc6712cabc83fb3faf32a60d8190d92d to your computer and use it in GitHub Desktop.
Download the MNIST dataset without Scikit-learn, PyTorch, or TensorFlow
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import gzip | |
| import urllib.request | |
| from pathlib import Path | |
| from types import SimpleNamespace | |
| from typing import Literal | |
| import numpy as np | |
def load_dataset(
    mnist_or_fashionmnist: Literal["mnist", "fashionmnist"],
    data_dir: "str | Path" = "./dataset",
) -> SimpleNamespace:
    """Download (or reuse cached copies of) the MNIST or Fashion-MNIST dataset.

    Minimal surrogate for ``sklearn.datasets.fetch_openml`` that reads the
    original gzipped IDX files directly, without accessing the OpenML API.
    Output dtypes may differ from ``fetch_openml``; see Returns.

    Parameters
    ----------
    mnist_or_fashionmnist
        Which dataset to fetch.
    data_dir
        Cache directory. Files are stored under ``data_dir/<dataset name>/``.
        A file that already exists there is reused instead of re-downloaded.
        The default ``"./dataset"`` is resolved relative to the current
        working directory.

    Returns
    -------
    SimpleNamespace
        ``data``: ``uint8`` array of shape ``(n_samples, rows * cols)``, with
        the train split first and the test split after it.
        ``target``: labels converted to ``str``, in the same order as ``data``.

    Raises
    ------
    ValueError
        If a file is not an unsigned-byte IDX file of the expected rank, or
        its payload size does not match the header.
    Download errors raised by ``urllib.request.urlretrieve`` propagate.

    Examples
    --------
    >>> load_dataset("mnist")
    >>> sklearn.datasets.fetch_openml("mnist_784", as_frame=False)
    >>> load_dataset("fashionmnist")
    >>> sklearn.datasets.fetch_openml("Fashion-MNIST", as_frame=False)
    """
    base_url = {
        "mnist": "https://ossci-datasets.s3.amazonaws.com/mnist/",
        # NOTE(review): plain HTTP. S3 static-website endpoints do not serve
        # HTTPS, and no integrity check is performed on the downloaded file.
        "fashionmnist": "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
    }
    paths = {
        ("train", "data"): "train-images-idx3-ubyte.gz",
        ("train", "label"): "train-labels-idx1-ubyte.gz",
        ("test", "data"): "t10k-images-idx3-ubyte.gz",
        ("test", "label"): "t10k-labels-idx1-ubyte.gz",
    }
    base = Path(data_dir) / mnist_or_fashionmnist
    base.mkdir(parents=True, exist_ok=True)

    def fetch(filename: str) -> Path:
        """Return the local path of *filename*, downloading it only if not cached."""
        target = base / filename
        if target.exists():
            return target
        url = base_url[mnist_or_fashionmnist] + filename
        print(f"Downloading {url} to {target}")
        # Download to a side file and rename it only after success, so an
        # interrupted transfer never leaves a truncated file that looks cached.
        partial = target.with_name(target.name + ".part")
        urllib.request.urlretrieve(url, partial)
        partial.replace(target)
        return target

    def read_idx(path: Path, ndim: int) -> np.ndarray:
        """Parse a gzipped unsigned-byte IDX file into an array of its declared shape."""
        with gzip.open(path, "rb") as f:
            magic = f.read(4)
            # IDX magic: two zero bytes, dtype code 0x08 (unsigned byte), then rank.
            if magic != bytes([0, 0, 0x08, ndim]):
                raise ValueError(
                    f"{path}: unexpected IDX magic {magic.hex()!r} "
                    f"(expected uint8 data of rank {ndim})"
                )
            header = f.read(4 * ndim)
            if len(header) != 4 * ndim:
                raise ValueError(f"{path}: truncated IDX header")
            # Each dimension size is a big-endian unsigned 32-bit integer.
            shape = tuple(int(d) for d in np.frombuffer(header, dtype=">u4"))
            content = np.frombuffer(f.read(), dtype=np.uint8)
        if content.size != int(np.prod(shape)):
            raise ValueError(
                f"{path}: payload has {content.size} bytes, header declares shape {shape}"
            )
        return content.reshape(shape)

    def load(
        train_or_test: Literal["train", "test"],
        data_or_label: Literal["data", "label"],
    ) -> np.ndarray:
        path = fetch(paths[(train_or_test, data_or_label)])
        if data_or_label == "data":
            images = read_idx(path, ndim=3)
            # Flatten each (rows, cols) image into one row of rows * cols pixels.
            return images.reshape(images.shape[0], images.shape[1] * images.shape[2])
        return read_idx(path, ndim=1).astype(str)

    return SimpleNamespace(
        data=np.r_[load("train", "data"), load("test", "data")],
        target=np.r_[load("train", "label"), load("test", "label")],
    )
# Load both datasets when this module runs. MNIST is evaluated first, then
# Fashion-MNIST, exactly as two separate assignments would be.
mnist, fashion_mnist = load_dataset("mnist"), load_dataset("fashionmnist")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment