"""Clean, fast single-layer binary network for n-bit parity.
Fully vectorized and JIT-compiled global k=1 weight flipping.
"""
import jax
import jax.numpy as jnp
from functools import partial
from time import time
from tqdm import tqdm

# Configuration
N_BITS = 16
N_HIDDEN = 1024
N_STEPS = 500
SEED = 42
LOG_INTERVAL = 10


def generate_all_patterns(n_bits):
    """Generate all 2^n_bits parity patterns."""
    n_patterns = 2**n_bits
    patterns = []
    labels = []
    for i in range(n_patterns):
        bits = [(i >> j) & 1 for j in range(n_bits)]
        pattern = jnp.array([2 * b - 1 for b in bits], dtype=jnp.float32)
        parity = sum(bits) % 2
        label = jnp.array([2 * parity - 1], dtype=jnp.float32)
        patterns.append(pattern)
        labels.append(label)
    return jnp.stack(patterns), jnp.stack(labels)


def init_network(n_bits, n_hidden, key):
    """Initialize binary weights {-1, +1}, no biases."""
    key1, key2 = jax.random.split(key)
    W1 = jax.random.choice(key1, jnp.array([-1.0, 1.0]), shape=(n_hidden, n_bits))
    W2 = jax.random.choice(key2, jnp.array([-1.0, 1.0]), shape=(1, n_hidden))
    return W1, W2


@partial(jax.jit, static_argnums=())
def forward(W1, W2, X):
    """Forward pass: X -> hidden -> output."""
    h1 = jnp.where(X @ W1.T >= 0, 1.0, -1.0)  # [batch, n_hidden]
    output = jnp.where(h1 @ W2.T >= 0, 1.0, -1.0)  # [batch, 1]
    return h1, output


@partial(jax.jit, static_argnums=())
def backward(W1, W2, X, h1, target, output):
    """Backprop to compute gradients.

    The sign activations are treated as identity in the backward pass
    (straight-through-style surrogate gradients).
    """
    B = X.shape[0]
    delta2 = target - output  # [batch, 1]
    g_W2 = delta2.T @ h1 / B  # [1, n_hidden]
    delta1 = delta2 @ W2  # [batch, n_hidden]
    g_W1 = delta1.T @ X / B  # [n_hidden, n_bits]
    return g_W1, g_W2


@partial(jax.jit, static_argnums=())
def update_step(W1, W2, X, target):
    """Vectorized global k=1: flip single best weight across entire network."""
    # Forward pass
    h1, output = forward(W1, W2, X)

    # Backward pass
    g_W1, g_W2 = backward(W1, W2, X, h1, target, output)

    # Compute benefits: |gradient| when gradient disagrees with weight sign
    W1_flat = W1.flatten()
    g_W1_flat = g_W1.flatten()
    benefits1 = jnp.abs(g_W1_flat) * (jnp.sign(g_W1_flat) != W1_flat)
    W2_flat = W2.flatten()
    g_W2_flat = g_W2.flatten()
    benefits2 = jnp.abs(g_W2_flat) * (jnp.sign(g_W2_flat) != W2_flat)

    # Find global best weight to flip
    all_benefits = jnp.concatenate([benefits1, benefits2])
    best_idx = jnp.argmax(all_benefits)

    # Create flip masks (at most one True per mask; exactly one flip overall)
    n_params_W1 = W1.size
    mask_W1 = jnp.arange(W1.size) == best_idx
    mask_W2 = jnp.arange(W2.size) == (best_idx - n_params_W1)

    # Apply flips: multiply by -1 where mask is True
    W1_flat = W1_flat * jnp.where(mask_W1, -1.0, 1.0)
    W2_flat = W2_flat * jnp.where(mask_W2, -1.0, 1.0)
    return W1_flat.reshape(W1.shape), W2_flat.reshape(W2.shape)


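# A minimal sanity-check sketch (not part of the original gist): by
# construction, a single call to update_step flips exactly one weight
# across W1 and W2 for any input batch. The helper below is hypothetical
# and is not called by the training script.
def _check_single_flip(seed=0):
    X, y = generate_all_patterns(2)  # tiny 2-bit parity problem
    W1, W2 = init_network(2, 4, jax.random.PRNGKey(seed))
    W1_new, W2_new = update_step(W1, W2, X, y)
    n_changed = int((W1_new != W1).sum() + (W2_new != W2).sum())
    assert n_changed == 1, f"expected exactly one flip, got {n_changed}"
    return n_changed

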
def train(n_bits, n_hidden, n_steps, seed):
    """Train network on n-bit parity."""
    # Generate and shuffle data
    X_all, y_all = generate_all_patterns(n_bits)
    key = jax.random.PRNGKey(seed)
    perm = jax.random.permutation(key, len(X_all))
    X_all, y_all = X_all[perm], y_all[perm]

    # Initialize network
    W1, W2 = init_network(n_bits, n_hidden, key)
    n_params = W1.size + W2.size
    print(f"\n{'=' * 70}")
    print(f"Training {n_bits}-bit parity: [{n_bits}, {n_hidden}, 1]")
    print(f"Patterns: {len(X_all):,} | Parameters: {n_params:,}")
    print(f"{'=' * 70}\n")

    # Warmup JIT (note: this also applies one update to the weights)
    W1, W2 = update_step(W1, W2, X_all, y_all)
    jax.block_until_ready(W1)

    # Train
    start = time()
    pbar = tqdm(range(n_steps), desc=f"{n_bits}b/{n_hidden}h")
    for step in pbar:
        W1, W2 = update_step(W1, W2, X_all, y_all)
        if step % LOG_INTERVAL == 0 or step == n_steps - 1:
            jax.block_until_ready(W1)
            _, output = forward(W1, W2, X_all)
            acc = float((output == y_all).mean())
            mse = float(jnp.mean((output - y_all) ** 2))
            elapsed = time() - start
            steps_per_sec = (step + 1) / elapsed
            pbar.set_postfix(
                {
                    "acc": f"{acc:.4f}",
                    "mse": f"{mse:.4f}",
                    "step/s": f"{steps_per_sec:.1f}",
                }
            )
            if acc >= 1.0:
                pbar.close()
                print(f"\n✓ Converged to 100% at step {step}")
                print(f"  Time: {elapsed:.2f}s ({steps_per_sec:.1f} steps/s)")
                return step, elapsed

    pbar.close()
    print(f"\n✗ Did not converge in {n_steps} steps")
    return n_steps, time() - start


if __name__ == "__main__":
    train(N_BITS, N_HIDDEN, N_STEPS, SEED)
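    # Hypothetical quick smoke test with smaller, illustrative settings
    # (not part of the original gist):
    # train(8, 256, 200, SEED)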