Skip to content

Instantly share code, notes, and snippets.

@natolambert
Last active March 4, 2026 23:07
Show Gist options

  • Select an option

  • Save natolambert/0a6ad2e9f513d7a72b76d9e3a7b0bbb1 to your computer and use it in GitHub Desktop.
OLMo 3 throughput slide plots — original script by Finbarr Timbers (https://github.com/finbarrtimbers)
"""
Generate high-resolution slide plots comparing OLMo 3 model inference throughput.
Compares GPU requirements, training time, and GPU-hours across context lengths
for OLMo 3 7B (MHA), OLMo 3 32B (GQA), and OLMo 3 7B Hybrid variants.
Original script by Finbarr Timbers (https://github.com/finbarrtimbers).
"""
import io
import math
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
matplotlib.use("Agg")
# ---------------------------------------------------------------------------
# Benchmark data
# ---------------------------------------------------------------------------
# 7B vs 32B on a single 8-GPU node.
# Columns: model id, data-parallel degree, tensor-parallel degree, generation
# length in tokens (prompt excluded; load_data adds PROMPT_LENGTH), batch wall
# time in seconds, MFU/MBU utilization figures, concurrent requests per DP
# replica, and throughput in tokens/sec.
BASELINE_CSV = """\
model,data_parallel,tensor_parallel,generation_length,batch_time,MFU,MBU,actual_concurrency,tokens_per_sec
olmo3_32b,1,8,1024,37.39,1.14,44,363,1752.81
olmo3_32b,1,8,4096,168.72,1.02,46,182,1553.7
olmo3_32b,1,8,8192,402.49,0.89,45,156,1302.62
olmo3_32b,1,8,16000,765.83,0.87,48,123,1337.11
olmo3_32b,1,8,32000,1996.16,0.81,50,85,1025.97
olmo3_7b,8,1,1024,22.53,1.7,8,37.97,2908.74
olmo3_7b,8,1,4096,121.14,1.3,7.4,23.73,2163.99
olmo3_7b,8,1,8192,347.26,1.03,6,20,1509.78
olmo3_7b,8,1,16000,606.08,1.1,8,16,1689.54
olmo3_7b,8,1,32000,1642.01,0.94,7,11,1247.25
"""
# All four models (baseline + hybrid variants)
# NOTE(review): the olmo3_32b rows below use dp=2/tp=4 while BASELINE_CSV uses
# dp=1/tp=8 for identical timings; load_data() multiplies concurrency by
# data_parallel, so the same model gets a different effective concurrency in
# the two tables — confirm which parallelism layout the runs actually used.
HYBRID_CSV = """\
model,data_parallel,tensor_parallel,generation_length,batch_time,MFU,MBU,actual_concurrency,tokens_per_sec
olmo3_32b,2,4,1024,37.39,1.14,44,363,1752.81
olmo3_32b,2,4,4096,168.72,1.02,46,182,1553.7
olmo3_32b,2,4,8192,402.49,0.89,45,156,1302.62
olmo3_32b,2,4,16000,765.83,0.87,48,123,1337.11
olmo3_32b,2,4,32000,1996.16,0.81,50,85,1025.97
olmo3_7b,8,1,1024,22.53,1.7,8,37.97,2908.74
olmo3_7b,8,1,4096,121.14,1.3,7.4,23.73,2163.99
olmo3_7b,8,1,8192,347.26,1.03,6,20,1509.78
olmo3_7b,8,1,16000,606.08,1.1,8,16,1689.54
olmo3_7b,8,1,32000,1642.01,0.94,7,11,1247.25
olmo3_7b_hybrid,4,2,1024,17.95,1.87,18,817.69,3483
olmo3_7b_hybrid,4,2,4096,85.81,1.68,19,237.87,3023
olmo3_7b_hybrid,4,2,8192,233.33,1.33,18,123.42,2247
olmo3_7b_hybrid,4,2,16000,705.82,1,16,62.9,1486
olmo3_7b_hybrid,4,2,32000,1398.72,1.25,10,32,1499.34
olmo3_7b_hybrid_eager,4,2,1024,48.15,0.73,7,817.69,1361
olmo3_7b_hybrid_eager,4,2,4096,183.98,0.79,9,237.87,1425
olmo3_7b_hybrid_eager,4,2,8192,384.23,0.81,11,123.42,1365
olmo3_7b_hybrid_eager,4,2,16000,855.17,0.8,13,62.9,1197
olmo3_7b_hybrid_eager,4,2,32000,3133.38,0.54,11,32,653.61
"""
PROMPT_LENGTH = 2048  # prompt tokens added on top of every generation length
BATCH_SIZE = 1024  # episodes generated per training step
TOTAL_EPISODES = 1_725_440  # total episodes in the training run
TOTAL_STEPS = TOTAL_EPISODES // BATCH_SIZE  # steps needed at BATCH_SIZE/step
GPUS_PER_NODE = 8  # GPU allocation is rounded up to whole 8-GPU nodes
# Display names and colors for each model
MODEL_META = {
"olmo3_7b": ("OLMo 3 7B (MHA)", "#f0529c"),
"olmo3_32b": ("OLMo 3 32B (GQA)", "#105257"),
"olmo3_7b_hybrid": ("OLMo 3 7B Hybrid (w/o enforce eager)", "#b11be8"),
"olmo3_7b_hybrid_eager": ("OLMo 3 7B Hybrid (w/ enforce eager)", "#0fcb8c"),
}
# Canonical ordering for legend consistency
MODEL_ORDER = [
"olmo3_7b",
"olmo3_32b",
"olmo3_7b_hybrid",
"olmo3_7b_hybrid_eager",
]
# Context-length tick labels (generation_length + PROMPT_LENGTH -> display label)
CONTEXT_LABELS = {3072: "1K", 6144: "4K", 10240: "8K", 18048: "16K", 34048: "32K"}
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def load_data(csv_text: str) -> pd.DataFrame:
    """Parse one embedded benchmark CSV into a DataFrame.

    Two derived adjustments are applied in place:
    - generation_length becomes the full context length (prompt + generation),
    - actual_concurrency is scaled from per-replica to total concurrency by
      multiplying with the data-parallel degree.
    """
    frame = pd.read_csv(io.StringIO(csv_text))
    frame["generation_length"] = frame["generation_length"] + PROMPT_LENGTH
    frame["actual_concurrency"] = frame["actual_concurrency"] * frame["data_parallel"]
    return frame
# ---------------------------------------------------------------------------
# Metric functions (applied per-row)
# ---------------------------------------------------------------------------
def gpus_needed(row: pd.Series) -> int:
    """Return the GPU count needed to run the whole batch concurrently.

    Allocation is in whole nodes: the node count is the batch size divided by
    the row's total concurrency, rounded up, times GPUS_PER_NODE.
    """
    nodes = math.ceil(BATCH_SIZE / row["actual_concurrency"])
    return nodes * GPUS_PER_NODE
def training_time_hours(row: pd.Series) -> float:
    """Return total generation wall-clock time over the full run, in hours."""
    total_seconds = row["batch_time"] * TOTAL_STEPS
    return total_seconds / 3600.0
def gpu_hours_thousands(row: pd.Series) -> float:
    """Return total GPU-hours for the full run, in thousands.

    Composes the two sibling metrics instead of re-deriving them, so the
    whole-node rounding rule lives only in gpus_needed().
    """
    return training_time_hours(row) * gpus_needed(row) / 1000.0
# ---------------------------------------------------------------------------
# Plotting helpers
# ---------------------------------------------------------------------------
def _plot_models(ax, df, y_fn, models):
    """Draw one line per requested model onto `ax`.

    Models are iterated in MODEL_ORDER (not `models` order) so legend entries
    stay consistent across slides; `y_fn` maps each DataFrame row to a y-value.
    """
    for key in MODEL_ORDER:
        if key not in models:
            continue
        subset = df.loc[df["model"] == key].sort_values("generation_length")
        if subset.empty:
            continue
        display_name, line_color = MODEL_META[key]
        y_values = subset.apply(y_fn, axis=1)
        ax.plot(
            subset["generation_length"],
            y_values,
            marker="o",
            linewidth=3,
            markersize=8,
            label=display_name,
            color=line_color,
        )
def _style_axes(ax, df, ylabel, *, title=None, legend_loc="upper left", ylim=None):
    """Apply slide-friendly styling to `ax`: labels, ticks, limits, grid,
    spines, and (when any lines are labeled) the legend."""
    if title:
        ax.set_title(title, fontsize=22, fontweight="bold", pad=16)
    ax.set_xlabel("Context length", fontsize=18)
    ax.set_ylabel(ylabel, fontsize=18)

    # X ticks at the observed context lengths, with "1K"-style labels where
    # known and a comma-formatted fallback otherwise.
    tick_positions = sorted(df["generation_length"].unique())
    tick_labels = []
    for pos in tick_positions:
        tick_labels.append(CONTEXT_LABELS.get(int(pos), f"{int(pos):,}"))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, fontsize=14)

    # Anchor y at zero; add 10% headroom unless an explicit cap was given.
    upper = ylim if ylim is not None else ax.get_ylim()[1] * 1.10
    ax.set_ylim(0, upper)

    ax.tick_params(axis="y", labelsize=14)
    ax.grid(axis="y", linestyle="--", linewidth=0.8, alpha=0.4)
    ax.set_axisbelow(True)

    # Keep only the left/bottom spines, slightly softened.
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    for side in ("left", "bottom"):
        ax.spines[side].set_linewidth(1.0)
        ax.spines[side].set_color("#444444")

    handles, labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend(handles, labels, loc=legend_loc, frameon=False, fontsize=13,
                  handlelength=2.6, handletextpad=0.8, borderaxespad=0.8)
def make_plot(df, y_fn, ylabel, title, models, filename, *, legend_loc="upper left", ylim=None):
    """Render one slide chart and save it as a high-resolution PNG.

    Args:
        df: benchmark DataFrame from load_data().
        y_fn: per-row metric function (gpus_needed, training_time_hours, ...).
        ylabel / title: axis label and chart title.
        models: model keys to include on this chart.
        filename: output PNG path.
        legend_loc / ylim: forwarded to _style_axes().
    """
    fig, ax = plt.subplots(figsize=(12, 7))
    _plot_models(ax, df, y_fn, models)
    _style_axes(ax, df, ylabel, title=title, legend_loc=legend_loc, ylim=ylim)
    fig.tight_layout()
    fig.savefig(filename, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close(fig)
    # Bug fix: the original f-string had no placeholder and never reported
    # which file was written; interpolate the actual output path.
    print(f"  Saved: {filename}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Load both benchmark tables and render every slide image."""
    df_baseline = load_data(BASELINE_CSV)
    df_hybrid = load_data(HYBRID_CSV)

    base = ["olmo3_7b", "olmo3_32b"]
    plus_hybrid = [*base, "olmo3_7b_hybrid"]
    all_models = [*plus_hybrid, "olmo3_7b_hybrid_eager"]

    # Shared y-limit so the GPU-hours charts (slides 6 and 7b) are comparable.
    gpu_hours_ylim = 105

    # One row per chart: (data, metric, ylabel, title, models, filename, ylim).
    # Slides 1-3 compare the 7B/32B baseline, 4-6 add the hybrid model, and
    # 7a/7b/8 include the enforce-eager variant as well.
    chart_specs = [
        (df_baseline, gpus_needed, "GPUs needed",
         "GPUs needed for batch concurrency", base, "slide1.png", None),
        (df_baseline, training_time_hours, "Training time (hours)",
         "Generation time (given those GPUs)", base, "slide2.png", None),
        (df_baseline, gpu_hours_thousands, "GPU hours (thousands)",
         "GPU hours vs context length", base, "slide3.png", None),
        (df_hybrid, gpus_needed, "GPUs needed",
         "GPUs needed for batch concurrency", plus_hybrid, "slide4.png", None),
        (df_hybrid, training_time_hours, "Training time (hours)",
         "Generation time (given those GPUs)", plus_hybrid, "slide5.png", None),
        (df_hybrid, gpu_hours_thousands, "GPU hours (thousands)",
         "GPU hours vs context length", plus_hybrid, "slide6.png", gpu_hours_ylim),
        (df_hybrid, training_time_hours, "Training time (hours)",
         "Generation time (given those GPUs)", all_models, "slide7a.png", None),
        (df_hybrid, gpu_hours_thousands, "GPU hours (thousands)",
         "GPU hours vs context length", all_models, "slide7b.png", gpu_hours_ylim),
        (df_hybrid, gpus_needed, "GPUs needed",
         "GPUs needed for batch concurrency", all_models, "slide8.png", None),
    ]
    for df, y_fn, ylabel, title, models, filename, ylim in chart_specs:
        make_plot(df, y_fn, ylabel, title, models, filename, ylim=ylim)

    print("Done! 8 slides (9 images) saved.")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment