Skip to content

Instantly share code, notes, and snippets.

@KaoruNishikawa
Last active September 9, 2025 07:46
Show Gist options
  • Select an option

  • Save KaoruNishikawa/cc6712cabc83fb3faf32a60d8190d92d to your computer and use it in GitHub Desktop.

Select an option

Save KaoruNishikawa/cc6712cabc83fb3faf32a60d8190d92d to your computer and use it in GitHub Desktop.
Download the MNIST and Fashion-MNIST datasets without scikit-learn, PyTorch, or TensorFlow
import gzip
import urllib.request
from pathlib import Path
from types import SimpleNamespace
from typing import Literal
import numpy as np
def _read_idx(path: Path, expected_ndim: int) -> np.ndarray:
    """Parse a gzip-compressed IDX file of unsigned bytes into an array.

    The IDX header is a 4-byte magic number (two zero bytes, the element-type
    code ``0x08`` for unsigned bytes, then the number of dimensions), followed
    by one big-endian uint32 size per dimension.  The raw payload follows.

    Parameters
    ----------
    path
        Location of the ``.gz`` file.
    expected_ndim
        Number of dimensions the caller expects (3 for image files, 1 for
        label files).

    Returns
    -------
    numpy.ndarray
        uint8 array whose shape is taken from the file header.

    Raises
    ------
    ValueError
        If the magic number, dimension count or payload size does not match.
    """
    problem = (
        f"{path} is corrupt or not a {expected_ndim}-D unsigned-byte IDX file;"
        " delete it to force a re-download"
    )
    with gzip.open(path, "rb") as f:
        magic = f.read(4)
        if len(magic) != 4 or magic[:3] != b"\x00\x00\x08" or magic[3] != expected_ndim:
            raise ValueError(problem)
        header = f.read(4 * expected_ndim)
        if len(header) != 4 * expected_ndim:
            raise ValueError(problem)
        shape = tuple(int(d) for d in np.frombuffer(header, dtype=">u4"))
        payload = np.frombuffer(f.read(), dtype=np.uint8)
    # Catches truncated downloads that still decompress cleanly.
    if payload.size != int(np.prod(shape)):
        raise ValueError(problem)
    return payload.reshape(shape)


def load_dataset(
    mnist_or_fashionmnist: Literal["mnist", "fashionmnist"],
    *,
    data_home: "str | Path" = "./dataset",
):
    """Download (or reuse a cached copy of) the MNIST or Fashion-MNIST dataset.

    Minimal surrogate for ``sklearn.datasets.fetch_openml`` that does not
    access the OpenML API.  It is roughly equivalent to
    ``fetch_openml("mnist_784", as_frame=False)`` for ``"mnist"`` and
    ``fetch_openml("Fashion-MNIST", as_frame=False)`` for ``"fashionmnist"``.

    The four gzip-compressed IDX files are stored under
    ``<data_home>/<mnist_or_fashionmnist>/``.  They are only downloaded when
    missing, so repeated calls reuse the local copies.

    Parameters
    ----------
    mnist_or_fashionmnist
        Which dataset to load.
    data_home
        Root cache directory.  Defaults to ``./dataset``, relative to the
        current working directory.

    Returns
    -------
    types.SimpleNamespace
        ``data``: uint8 array with one flattened image per row.  Training
        samples come first, followed by test samples.
        ``target``: the matching labels as a string array, in the same order.

    Raises
    ------
    KeyError
        If ``mnist_or_fashionmnist`` is not a known dataset name.
    ValueError
        If a downloaded or cached file is not a valid IDX file.

    Examples
    --------
    >>> mnist = load_dataset("mnist")  # doctest: +SKIP
    >>> fashion = load_dataset("fashionmnist")  # doctest: +SKIP
    """
    base_url = {
        "mnist": "https://ossci-datasets.s3.amazonaws.com/mnist/",
        # NOTE: plain HTTP. S3 static-website endpoints do not serve HTTPS,
        # and no checksum is verified, so the download is not
        # integrity-protected.
        "fashionmnist": "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
    }
    paths = {
        ("train", "data"): "train-images-idx3-ubyte.gz",
        ("train", "label"): "train-labels-idx1-ubyte.gz",
        ("test", "data"): "t10k-images-idx3-ubyte.gz",
        ("test", "label"): "t10k-labels-idx1-ubyte.gz",
    }
    # Resolve the name before touching the filesystem.  An unknown name then
    # raises KeyError without leaving an empty directory behind.
    url_root = base_url[mnist_or_fashionmnist]
    base = Path(data_home) / mnist_or_fashionmnist
    base.mkdir(parents=True, exist_ok=True)

    def fetch(filename: str) -> Path:
        """Return the local path of *filename*, downloading it if absent."""
        target = base / filename
        if target.exists():
            print(f"Using cached {target}")
            return target
        url = url_root + filename
        print(f"Downloading {url} to {target}")
        # Download to a side file and rename it into place only on success.
        # An interrupted download can then never masquerade as a valid cache
        # entry.
        partial = target.with_name(target.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial)
            partial.replace(target)
        finally:
            partial.unlink(missing_ok=True)
        return target

    def load(
        train_or_test: Literal["train", "test"],
        data_or_label: Literal["data", "label"],
    ) -> np.ndarray:
        """Load one split: flattened uint8 images, or string labels."""
        path = fetch(paths[(train_or_test, data_or_label)])
        if data_or_label == "data":
            images = _read_idx(path, expected_ndim=3)
            n_images, *image_shape = images.shape
            return images.reshape(n_images, int(np.prod(image_shape)))
        # String labels mirror the target type returned by fetch_openml.
        return _read_idx(path, expected_ndim=1).astype(str)

    return SimpleNamespace(
        data=np.concatenate([load("train", "data"), load("test", "data")]),
        target=np.concatenate([load("train", "label"), load("test", "label")]),
    )
# Load both datasets at module level.  This runs whenever the script is
# executed *or imported*; each result exposes `.data` and `.target`.
mnist = load_dataset(mnist_or_fashionmnist="mnist")
fashion_mnist = load_dataset(mnist_or_fashionmnist="fashionmnist")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment