Skip to content

Instantly share code, notes, and snippets.

@iantimmis
Last active August 5, 2025 15:04
Show Gist options
  • Select an option

  • Save iantimmis/42edf155fa1ac5ae9e226422735b24cc to your computer and use it in GitHub Desktop.

Select an option

Save iantimmis/42edf155fa1ac5ae9e226422735b24cc to your computer and use it in GitHub Desktop.
James-Stein Estimation
import numpy as np
def run_trial(dim=3, sigma=1.0, rng=None):
    """Run one trial comparing the naive and James-Stein estimators.

    Draws ``dim`` hidden means uniformly from [0, 1), observes each one once
    with Gaussian noise of standard deviation ``sigma``, and scores two
    estimates of the hidden means:

    * naive: each observation is used as its own estimate;
    * James-Stein (positive-part): the whole observation vector ``x`` is shrunk
      toward the origin by ``max(0, 1 - (dim - 2) * sigma**2 / ||x||**2)``.

    Parameters
    ----------
    dim : int, optional
        Number of means estimated jointly. James-Stein only improves on the
        naive estimator for ``dim >= 3``. Defaults to 3.
    sigma : float, optional
        Known noise standard deviation of every observation. Defaults to 1.0.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. ``None`` (default) uses NumPy's legacy global
        generator, so ``np.random.seed(...)`` still controls the draws.

    Returns
    -------
    tuple of (numpy.float64, numpy.float64)
        ``(naive_mse, james_stein_mse)``, each averaged over the ``dim``
        coordinates.

    Raises
    ------
    ValueError
        If ``dim < 3``.
    """
    if dim < 3:
        raise ValueError(f"James-Stein estimation requires dim >= 3, got {dim}")
    # The legacy module-level functions keep `np.random.seed` reproducibility.
    rng = np.random if rng is None else rng

    # Hidden variables: the true means we are trying to estimate.
    means = rng.random(dim)
    # One noisy observation per hidden mean.
    obs = rng.normal(loc=means, scale=sigma)

    # Naive estimation: trust each observation on its own.
    ne_mse = np.mean((means - obs) ** 2)

    # James-Stein estimation: shrink all observations jointly toward 0.
    # The numerator (dim - 2) * sigma**2 equals 1 for the defaults
    # dim=3, sigma=1, which is the value this demo was built around.
    norm_sq = np.sum(obs ** 2)
    shrinkage = max(0.0, 1 - (dim - 2) * sigma ** 2 / norm_sq)
    jse_mse = np.mean((means - shrinkage * obs) ** 2)
    return ne_mse, jse_mse
def run_experiment(trial_count, trial_fn=None):
    """Run ``trial_count`` trials and report which estimator had lower MSE.

    Prints the mean naive and James-Stein MSE over all trials, then the
    winner. A tie is reported as a win for the naive estimator.

    Parameters
    ----------
    trial_count : int
        Number of trials to average over; must be at least 1.
    trial_fn : callable, optional
        Zero-argument callable returning ``(naive_mse, james_stein_mse)``
        for a single trial. Defaults to ``run_trial``.

    Returns
    -------
    bool
        True if the James-Stein mean MSE is strictly lower than the naive one.

    Raises
    ------
    ValueError
        If ``trial_count < 1``. Without this check the mean of an empty list
        is NaN, and the NaN comparison would silently declare the naive
        estimator the winner.
    """
    if trial_count < 1:
        raise ValueError(f"trial_count must be at least 1, got {trial_count}")
    if trial_fn is None:
        trial_fn = run_trial

    # Collect one (naive_mse, james_stein_mse) pair per trial, then split
    # them into two sequences so each can be averaged independently.
    results = [trial_fn() for _ in range(trial_count)]
    ne_errors, jse_errors = zip(*results)
    mean_ne_err = np.mean(ne_errors)
    mean_jse_err = np.mean(jse_errors)
    print(f"Naive Est. MSE: {mean_ne_err}")
    print(f"James-Stein Est. MSE: {mean_jse_err}")
    jse_win = bool(mean_ne_err > mean_jse_err)
    if jse_win:
        print("Winner: James-Stein Estimate")
    else:
        print("Winner: Naive Estimation")
    return jse_win
# Experiment configuration: how many independent experiments to run and how
# many trials each experiment averages over.
experiment_count = 5
trial_count = 1000

if __name__ == "__main__":
    # Repeat the whole experiment several times so the outcome is not a single
    # lucky draw, and count how often James-Stein wins. The guard keeps the
    # simulation (and its printing) from running when this file is imported.
    jse_wins = 0
    for i in range(experiment_count):
        print(f"Experiment {i+1}")
        if run_experiment(trial_count):
            jse_wins += 1
        print("----------------------------")
    print(f"James-Stein Estimation winrate = {round(jse_wins/experiment_count*100)}%")
# Example output (exact values differ from run to run because no random seed is set):
# Experiment 1
# Naive Est. MSE: 0.9792049390277211
# James-Stein Est. MSE: 0.6422371792832063
# Winner: James-Stein Estimate
# ----------------------------
# Experiment 2
# Naive Est. MSE: 0.9729797681574638
# James-Stein Est. MSE: 0.6441794373594659
# Winner: James-Stein Estimate
# ----------------------------
# Experiment 3
# Naive Est. MSE: 1.02369053560002
# James-Stein Est. MSE: 0.6752645156636279
# Winner: James-Stein Estimate
# ----------------------------
# Experiment 4
# Naive Est. MSE: 1.0137424522022978
# James-Stein Est. MSE: 0.6696755267702634
# Winner: James-Stein Estimate
# ----------------------------
# Experiment 5
# Naive Est. MSE: 0.9881777449461806
# James-Stein Est. MSE: 0.6501366130770861
# Winner: James-Stein Estimate
# ----------------------------
# James-Stein Estimation winrate = 100%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment