Skip to content

Instantly share code, notes, and snippets.

@AlpinDale
Created August 13, 2025 07:31
Show Gist options
  • Select an option

  • Save AlpinDale/f21dbac4606cff9f23756a5addfe39a5 to your computer and use it in GitHub Desktop.

Select an option

Save AlpinDale/f21dbac4606cff9f23756a5addfe39a5 to your computer and use it in GitHub Desktop.
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>MNIST</title>
<style>
:root { --bg:#0f1116; --fg:#e6e6e6; --muted:#a3a3a3; --panel:#1b1f2a; --border:#333; }
html, body { height: 100%; }
body { margin: 0; font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial; background: var(--bg); color: var(--fg); }
.container { max-width: 1100px; margin: 0 auto; padding: 16px; }
h1 { font-size: 22px; margin: 8px 0 16px; }
.wrap { display: grid; grid-template-columns: 360px 1fr; gap: 24px; align-items: start; }
.card { background: var(--bg); border: 1px solid var(--border); border-radius: 12px; padding: 16px; }
.row { display: flex; gap: 12px; align-items: center; margin: 8px 0; flex-wrap: wrap; }
#canvas { background: #000; border-radius: 8px; box-shadow: inset 0 0 0 1px var(--border); touch-action: none; }
button { padding: 8px 12px; border-radius: 8px; border: 1px solid var(--border); background: var(--panel); color: var(--fg); cursor: pointer; }
button:active { transform: translateY(1px); }
input[type=range] { width: 160px; }
.bar { height: 14px; background: linear-gradient(90deg, #6ee7b7, #3b82f6); border-radius: 7px; }
.bar-wrap { background: var(--panel); border-radius: 7px; height: 14px; width: 100%; }
.prob-row { display: grid; grid-template-columns: 20px 1fr 44px; gap: 8px; align-items: center; margin: 6px 0; }
.muted { color: var(--muted); font-size: 12px; }
.pred { font-size: 42px; font-weight: 700; }
@media (max-width: 900px){ .wrap { grid-template-columns: 1fr; } }
</style>
<script>
// Load onnxruntime-web (which provides the global `ort` used below) from the CDN.
// The element is kept in `ortScript` so the model loader can wait for its load event.
const ortScript = Object.assign(document.createElement('script'), {
  src: 'https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js',
});
document.head.append(ortScript);
</script>
</head>
<body>
<div class="container">
<h1>MNIST</h1>
<div class="wrap">
<div class="card">
<canvas id="canvas" width="320" height="320"></canvas>
<div class="row">
<button id="clearBtn">🔄 Clear</button>
<button id="predictBtn">🔮 Predict</button>
</div>
<div class="row">
<label>Brush</label>
<input id="brush" type="range" min="5" max="40" value="22" />
<label><input id="auto" type="checkbox" checked /> Auto predict</label>
</div>
<div class="muted">Tip: Draw a large white digit on the black canvas.</div>
</div>
<div class="card">
<h3 style="margin-top:0">Prediction</h3>
<div id="pred" class="pred">–</div>
<div id="probs" style="margin-top: 12px"></div>
<h3>Processed 28×28</h3>
<canvas id="preview" width="160" height="160" style="background:#000;border-radius:8px;box-shadow:inset 0 0 0 1px var(--border)"></canvas>
</div>
</div>
</div>
<script>
// Base64-encoded ONNX model bytes (decoded with atob and handed to
// ort.InferenceSession.create). Empty in this copy of the page: embed the
// encoded model here before use; loadModel() rejects while it is empty.
const modelBase64 = ``;
let session = null;
// In-flight (or completed) load, shared by every concurrent loadModel() caller.
let sessionPromise = null;

/**
 * Resolves once the injected onnxruntime-web script has defined `window.ort`;
 * rejects if that script fails to load.
 * @returns {Promise<void>}
 */
function waitForOrt() {
  return new Promise((resolve, reject) => {
    if (window.ort) {
      resolve();
      return;
    }
    ortScript.addEventListener('load', () => resolve(), { once: true });
    ortScript.addEventListener(
      'error',
      () => reject(new Error('Failed to load onnxruntime-web')),
      { once: true },
    );
  });
}

/**
 * Decodes `modelBase64` and creates the ONNX Runtime session (WASM backend),
 * storing it in the module-level `session`.
 *
 * Concurrent calls share a single load. Without this, the page-load call and an
 * early prediction would build two sessions. If loading fails, the cached
 * promise is dropped so a later call can retry.
 *
 * @returns {Promise<void>} settles when `session` is ready
 * @throws {Error} (as a rejection) if the model string is empty, the runtime
 *   script fails to load, or ONNX Runtime rejects the model
 */
function loadModel() {
  if (sessionPromise === null) {
    sessionPromise = (async () => {
      if (modelBase64.length === 0) {
        throw new Error('modelBase64 is empty: embed the base64-encoded ONNX model first');
      }
      await waitForOrt();
      const binary = atob(modelBase64);
      // atob yields one char per byte (code points 0-255).
      const bytes = Uint8Array.from(binary, (ch) => ch.charCodeAt(0));
      session = await ort.InferenceSession.create(bytes.buffer, { executionProviders: ['wasm'] });
    })().catch((err) => {
      sessionPromise = null; // allow a retry on the next call
      throw err;
    });
  }
  return sessionPromise;
}
const canvas = document.getElementById('canvas');
// preprocess() reads this canvas back with getImageData on every prediction;
// willReadFrequently lets the browser optimize for repeated read-backs.
const ctx = canvas.getContext('2d', { willReadFrequently: true });
const preview = document.getElementById('preview');
const pctx = preview.getContext('2d');
const autoPredict = document.getElementById('auto');
ctx.fillStyle = '#000';
ctx.fillRect(0, 0, canvas.width, canvas.height);
let drawing = false;
let brush = 22;
const white = '#fff';
// End point of the most recently drawn segment of the current stroke.
let lastX = 0;
let lastY = 0;

/** Begins a stroke at canvas coordinates (x, y). */
function startDraw(x, y) {
  drawing = true;
  lastX = x;
  lastY = y;
}

/**
 * Extends the current stroke to (x, y). Only the new segment is stroked; round
 * caps and joins make consecutive segments render as one continuous line, so the
 * whole path need not be re-stroked on every move.
 */
function drawTo(x, y) {
  if (!drawing) return;
  ctx.beginPath();
  ctx.moveTo(lastX, lastY);
  ctx.lineTo(x, y);
  ctx.strokeStyle = white;
  ctx.lineWidth = brush;
  ctx.lineCap = 'round';
  ctx.lineJoin = 'round';
  ctx.stroke();
  lastX = x;
  lastY = y;
}

/**
 * Finishes the current stroke and, when "Auto predict" is checked, runs a
 * prediction. Does nothing when no stroke is in progress. mouseup is observed on
 * the whole window, so clicks on the buttons or other controls must not trigger
 * predictions.
 */
function endDraw() {
  if (!drawing) return;
  drawing = false;
  if (autoPredict.checked) predict();
}

/**
 * Maps a mouse or touch event to canvas pixel coordinates, compensating for any
 * CSS scaling of the canvas element.
 * @returns {[number, number]} [x, y] in canvas pixels
 */
function getXY(evt) {
  const rect = canvas.getBoundingClientRect();
  const point = evt.touches ? evt.touches[0] : evt;
  const scaleX = rect.width ? canvas.width / rect.width : 1;
  const scaleY = rect.height ? canvas.height / rect.height : 1;
  return [(point.clientX - rect.left) * scaleX, (point.clientY - rect.top) * scaleY];
}

canvas.addEventListener('mousedown', (e) => startDraw(...getXY(e)));
canvas.addEventListener('mousemove', (e) => drawTo(...getXY(e)));
window.addEventListener('mouseup', endDraw);
// passive: false because these handlers call preventDefault() to stop scrolling.
canvas.addEventListener('touchstart', (e) => { e.preventDefault(); startDraw(...getXY(e)); }, { passive: false });
canvas.addEventListener('touchmove', (e) => { e.preventDefault(); drawTo(...getXY(e)); }, { passive: false });
canvas.addEventListener('touchend', (e) => { e.preventDefault(); endDraw(); }, { passive: false });
// An interrupted touch must not leave the stroke open.
canvas.addEventListener('touchcancel', () => endDraw());
document.getElementById('brush').addEventListener('input', (e) => {
  brush = Number.parseInt(e.target.value, 10);
});
document.getElementById('clearBtn').addEventListener('click', () => {
  ctx.fillStyle = '#000'; ctx.fillRect(0, 0, canvas.width, canvas.height);
  pctx.fillStyle = '#000'; pctx.fillRect(0, 0, preview.width, preview.height);
  document.getElementById('pred').textContent = '–';
  document.getElementById('probs').innerHTML = '';
});
document.getElementById('predictBtn').addEventListener('click', () => predict());
/**
 * Collapses RGBA pixel data to a single luma value per pixel, using the
 * 0.299/0.587/0.114 (BT.601) channel weights; the alpha channel is ignored.
 * Output values keep the 0-255 scale of the input channels.
 * @param {ArrayLike<number>} data - flat RGBA values, 4 entries per pixel
 * @returns {Float32Array} one luma value per pixel
 */
function toGrayscale(data) {
  const pixelCount = data.length / 4;
  const luma = new Float32Array(pixelCount);
  for (let p = 0; p < pixelCount; p++) {
    const base = p * 4;
    luma[p] = 0.299 * data[base] + 0.587 * data[base + 1] + 0.114 * data[base + 2];
  }
  return luma;
}
/**
 * Finds the tight bounding box of all pixels brighter than `thr`.
 * @param {ArrayLike<number>} gray - row-major intensities, w*h entries
 * @param {number} w - image width in pixels
 * @param {number} h - image height in pixels
 * @param {number} [thr=10] - a pixel counts as ink only if strictly greater
 * @returns {[number, number, number, number] | null} [x0, y0, x1, y1] with x1
 *   and y1 exclusive, or null when no pixel exceeds the threshold
 */
function findBBox(gray, w, h, thr = 10) {
  let left = w;
  let top = h;
  let right = -1;
  let bottom = -1;
  for (let row = 0; row < h; row++) {
    const offset = row * w;
    for (let col = 0; col < w; col++) {
      if (!(gray[offset + col] > thr)) continue;
      left = Math.min(left, col);
      right = Math.max(right, col);
      top = Math.min(top, row);
      bottom = Math.max(bottom, row);
    }
  }
  // right only becomes >= 0 once some pixel has passed the threshold.
  return right < 0 ? null : [left, top, right + 1, bottom + 1];
}
/**
 * Converts the drawing into the model's 28×28 input.
 *
 * The ink's bounding box (pixels with luma > 10) is scaled so its longer side
 * becomes 20 px, then centered on a black 28×28 field. The field is also drawn
 * onto the preview canvas, enlarged without smoothing.
 * NOTE(review): the MNIST dataset itself centers digits by center of mass rather
 * than by bounding box; bounding-box centering is kept to match the original page.
 *
 * @returns {Float32Array} 28*28 row-major intensities in [0, 1]; all zeros when
 *   nothing has been drawn
 */
function preprocess() {
  const inputSize = 28; // side length of the model input
  const digitBox = 20; // the digit's longer side is scaled to this many pixels
  const w = canvas.width;
  const h = canvas.height;
  // Single read-back of the drawing; the crop below is taken from the canvas directly.
  const gray = toGrayscale(ctx.getImageData(0, 0, w, h).data);
  const bbox = findBBox(gray, w, h, 10);

  const out = document.createElement('canvas');
  out.width = inputSize;
  out.height = inputSize;
  const octx = out.getContext('2d');
  octx.fillStyle = '#000';
  octx.fillRect(0, 0, inputSize, inputSize);

  if (bbox) {
    const [x0, y0, x1, y1] = bbox;
    const cw = x1 - x0;
    const ch = y1 - y0;
    const scale = digitBox / Math.max(cw, ch);
    const nw = Math.max(1, Math.round(cw * scale));
    const nh = Math.max(1, Math.round(ch * scale));
    const left = Math.floor((inputSize - nw) / 2);
    const top = Math.floor((inputSize - nh) / 2);
    // This is usually a large downscale. Request the higher-quality filter so
    // thin strokes are less likely to be aliased away.
    octx.imageSmoothingEnabled = true;
    octx.imageSmoothingQuality = 'high';
    // Crop the bounding box and scale it into place in one step.
    octx.drawImage(canvas, x0, y0, cw, ch, left, top, nw, nh);
  }

  // Always refresh the preview so it never shows a stale digit.
  pctx.imageSmoothingEnabled = false;
  pctx.drawImage(out, 0, 0, preview.width, preview.height);

  const outGray = toGrayscale(octx.getImageData(0, 0, inputSize, inputSize).data);
  return outGray.map((v) => v / 255);
}
// Incremented by every predict() call; lets superseded calls discard their results.
let latestPredictionId = 0;

/**
 * Numerically stable softmax: subtracts the maximum logit before exponentiating.
 * @param {ArrayLike<number>} logits
 * @returns {number[]} probabilities that sum to 1
 */
function softmax(logits) {
  const values = Array.from(logits);
  const max = Math.max(...values);
  const exps = values.map((v) => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((v) => v / sum);
}

/**
 * Replaces the contents of #probs with one labelled bar per class.
 * @param {number[]} probs - class probabilities; index = class label
 */
function renderProbabilities(probs) {
  const rows = probs.map((p, d) => {
    const pct = `${(p * 100).toFixed(1)}%`;
    const row = document.createElement('div');
    row.className = 'prob-row';
    const label = document.createElement('div');
    label.textContent = String(d);
    const barWrap = document.createElement('div');
    barWrap.className = 'bar-wrap';
    const bar = document.createElement('div');
    bar.className = 'bar';
    bar.style.width = pct;
    const val = document.createElement('div');
    val.textContent = pct;
    val.style.textAlign = 'right';
    barWrap.appendChild(bar);
    row.append(label, barWrap, val);
    return row;
  });
  document.getElementById('probs').replaceChildren(...rows);
}

/**
 * Runs the model on the current drawing and displays the predicted class and
 * per-class probabilities, loading the model first if necessary.
 *
 * If a newer call starts before this one finishes, this call's result is
 * discarded, so out-of-order completions cannot show a stale digit. Failures are
 * logged to the console rather than escaping as unhandled rejections, because
 * callers are event handlers that do not await this.
 * @returns {Promise<void>}
 */
async function predict() {
  const id = ++latestPredictionId;
  try {
    if (!session) await loadModel();
    if (id !== latestPredictionId) return;
    const vec = preprocess();
    if (!vec) return;
    const input = new ort.Tensor('float32', vec, [1, 1, 28, 28]);
    const outputMap = await session.run({ [session.inputNames[0]]: input });
    if (id !== latestPredictionId) return;
    const probs = softmax(outputMap[session.outputNames[0]].data);
    const pred = probs.indexOf(Math.max(...probs));
    document.getElementById('pred').textContent = String(pred);
    renderProbabilities(probs);
  } catch (err) {
    console.error('Prediction failed:', err);
  }
}
// Start loading the model on page load so the first prediction is fast. A
// failure is logged here instead of surfacing as an unhandled rejection;
// predict() calls loadModel() again if no session exists yet.
loadModel().catch((err) => console.error('Failed to load MNIST model:', err));
</script>
</body>
</html>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment