{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Does Reinforcement Learning Really Incentivize Reasoning Capacity in LLMs Beyond the Base Model?\n",
"\n",
"**Authors:** Yang Yue, Zhiqi Chen, Rui Lu, Andrew Zhao, Zhaokai Wang, Yang Yue, Shiji Song, Gao Huang\n",
"\n",
"**Venue:** NeurIPS 2025\n",
"\n",
"## Overview\n",
"\n",
"This notebook provides an educational walkthrough of the computational workflows from the paper \"Does Reinforcement Learning Really Incentivize Reasoning Capacity in LLMs Beyond the Base Model?\" The paper systematically investigates whether **Reinforcement Learning with Verifiable Rewards (RLVR)** truly expands the reasoning capabilities of Large Language Models (LLMs) beyond their base models.\n",
"\n",
"### Key Findings\n",
"\n",
"1. **RLVR improves sampling efficiency but doesn't expand reasoning boundaries** - While RLVR models outperform base models at small k (e.g., k=1), base models catch up and surpass RLVR models at higher k values\n",
"2. **Reasoning paths in RLVR models already exist in base models** - Perplexity analysis shows that RLVR models' outputs are already in the base model's distribution\n",
"3. **Different RL algorithms perform similarly** - PPO, GRPO, Reinforce++, RLOO, ReMax, and DAPO show only minor variations\n",
"4. **Distillation can expand reasoning boundaries** - Unlike RLVR, distillation introduces new reasoning patterns from a teacher model\n",
"\n",
"### Computational Workflows Covered\n",
"\n",
"This notebook demonstrates:\n",
"1. **pass@k metric computation** - The unbiased estimator used to measure reasoning boundaries\n",
"2. **Synthetic data generation** - Simulating base, RLVR, and distilled model behavior on reasoning tasks\n",
"3. **Pass@k curve analysis** - Comparing base, RLVR, and distilled models across k\n",
"4. **Accuracy distribution analysis** - How RLVR shifts per-problem accuracy distributions\n",
"5. **Solvable problem coverage analysis** - Whether RLVR enables solving problems the base model cannot\n",
"\n",
"### Note on Resource Constraints\n",
"\n",
"This notebook uses **small-scale synthetic examples** to demonstrate the methodology within resource constraints (4GB RAM, ~10 minute runtime). For full-scale experiments:\n",
"- Use actual LLM checkpoints (7B-32B parameters)\n",
"- Evaluate on complete benchmarks (GSM8K, MATH500, AIME24, LiveCodeBench, etc.)\n",
"- Sample k=256 or k=1024 responses per problem\n",
"- This would require GPU infrastructure and hours of computation time"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup and Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install all required dependencies\n",
"!uv pip install numpy matplotlib scipy torch scikit-learn pandas tqdm seaborn --quiet"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from scipy.special import comb\n",
"import pandas as pd\n",
"from collections import defaultdict\n",
"from tqdm import tqdm\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"# Set random seeds for reproducibility\n",
"np.random.seed(42)\n",
"torch.manual_seed(42)\n",
"\n",
"# Configure plotting\n",
"plt.style.use('seaborn-v0_8-darkgrid')\n",
"sns.set_palette(\"husl\")\n",
"%matplotlib inline\n",
"\n",
"print(\"All dependencies loaded successfully!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Pass@k Metric: Measuring Reasoning Boundaries\n",
"\n",
"The **pass@k metric** is central to this paper's evaluation. It measures the proportion of problems that can be solved with k attempts, thus revealing the **reasoning capability boundary** of a model.\n",
"\n",
"### Definition\n",
"\n",
"Given a problem, sample k outputs from the model. The pass@k value is 1 if **at least one** of the k samples is correct, otherwise 0. The average pass@k over a dataset reflects the proportion of problems the model can potentially solve.\n",
"\n",
"### Unbiased Estimator\n",
"\n",
"As described in the paper (following Chen et al., 2021), we use an unbiased low-variance estimator:\n",
"\n",
"$$\\text{pass@k} = \\mathbb{E}_{\\text{problems}} \\left[ 1 - \\frac{\\binom{n-c}{k}}{\\binom{n}{k}} \\right]$$\n",
"\n",
"where:\n",
"- n = total number of samples generated per problem\n",
"- c = number of correct samples\n",
"- k = number of samples we're evaluating"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def compute_pass_at_k(n, c, k):\n",
"    \"\"\"\n",
"    Compute pass@k using the unbiased estimator.\n",
"    \n",
"    Args:\n",
"        n: Total number of samples generated\n",
"        c: Number of correct samples\n",
"        k: Number of samples to evaluate\n",
"    \n",
"    Returns:\n",
"        pass@k value (probability of at least one correct in k samples)\n",
"    \"\"\"\n",
"    if n - c < k:\n",
"        return 1.0\n",
"    return 1.0 - float(comb(n - c, k)) / float(comb(n, k))\n",
"\n",
"\n",
"def compute_pass_at_k_for_dataset(correctness_matrix, k_values):\n",
"    \"\"\"\n",
"    Compute pass@k for multiple k values over a dataset.\n",
"    \n",
"    Args:\n",
"        correctness_matrix: (num_problems, num_samples) binary matrix\n",
"            where 1 = correct, 0 = incorrect\n",
"        k_values: List of k values to compute pass@k for\n",
"    \n",
"    Returns:\n",
"        Dictionary mapping k -> pass@k value\n",
"    \"\"\"\n",
"    num_problems, n = correctness_matrix.shape\n",
"    results = {}\n",
"    \n",
"    for k in k_values:\n",
"        if k > n:\n",
"            continue\n",
"        \n",
"        pass_at_k_scores = []\n",
"        for i in range(num_problems):\n",
"            c = np.sum(correctness_matrix[i])\n",
"            pass_at_k_scores.append(compute_pass_at_k(n, c, k))\n",
"        \n",
"        results[k] = np.mean(pass_at_k_scores)\n",
"    \n",
"    return results\n",
"\n",
"\n",
"# Example: Demonstrate pass@k computation\n",
"print(\"Example pass@k computation:\")\n",
"print(\"Problem with 3 correct out of 10 samples:\")\n",
"for k in [1, 2, 4, 8]:\n",
"    pk = compute_pass_at_k(n=10, c=3, k=k)\n",
"    print(f\" pass@{k} = {pk:.3f}\")\n",
"\n",
"print(\"\\nProblem with 8 correct out of 10 samples:\")\n",
"for k in [1, 2, 4, 8]:\n",
"    pk = compute_pass_at_k(n=10, c=8, k=k)\n",
"    print(f\" pass@{k} = {pk:.3f}\")"
]
},
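As a quick sanity check (not part of the paper's workflow), the closed-form unbiased estimator can be compared against a brute-force Monte Carlo estimate for a single problem. This sketch re-derives the formula with `math.comb` so it is self-contained; the names `closed_form` and `mc_estimate` are illustrative, not from the paper.

```python
# Sanity check: closed-form pass@k vs. Monte Carlo subsampling for one problem
import numpy as np
from math import comb

n, c, k = 10, 3, 4  # 3 correct answers out of 10 samples; evaluate pass@4

# Closed form: 1 - C(n-c, k) / C(n, k), mirroring compute_pass_at_k above
closed_form = 1.0 - comb(n - c, k) / comb(n, k)

# Monte Carlo: repeatedly draw a random k-subset of the n samples and
# check whether it contains at least one of the c correct ones
rng = np.random.default_rng(0)
outcomes = np.array([1] * c + [0] * (n - c))
trials = 20000
hits = sum(rng.choice(outcomes, size=k, replace=False).sum() > 0
           for _ in range(trials))
mc_estimate = hits / trials

print(f"closed form pass@{k}: {closed_form:.4f}")
print(f"Monte Carlo pass@{k}: {mc_estimate:.4f}")
```

The two values should agree to within sampling noise, which is the sense in which the estimator is unbiased.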
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Synthetic Data Generation\n",
"\n",
"We generate synthetic data to simulate the behavior of base models and RLVR-trained models on reasoning tasks. This allows us to demonstrate the paper's findings without requiring actual LLM inference.\n",
"\n",
"### Simulation Strategy\n",
"\n",
"Based on the paper's findings:\n",
"- **Base models**: Wider distribution, can solve more problems but with lower average accuracy\n",
"- **RLVR models**: Narrower distribution, higher accuracy on solvable problems but reduced coverage\n",
"- **Distilled models**: Can solve new problems beyond base model's capacity"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate_base_model_responses(num_problems=200, num_samples=128, difficulty_range=(0.1, 0.8)):\n",
"    \"\"\"\n",
"    Generate synthetic responses from a base model.\n",
"    \n",
"    Base models have:\n",
"    - Broader coverage (can solve more diverse problems)\n",
"    - Lower average accuracy per problem\n",
"    - More uniform distribution of accuracies\n",
"    \"\"\"\n",
"    correctness = np.zeros((num_problems, num_samples), dtype=int)\n",
"    \n",
"    for i in range(num_problems):\n",
"        # Each problem gets a per-sample success probability for the base\n",
"        # model, drawn uniformly (higher = easier problem)\n",
"        solve_prob = np.random.uniform(*difficulty_range)\n",
"        \n",
"        # Sample correctness for each response\n",
"        correctness[i] = np.random.binomial(1, solve_prob, num_samples)\n",
"    \n",
"    return correctness\n",
"\n",
"\n",
"def generate_rlvr_model_responses(base_correctness, shift_factor=0.3, narrowing_factor=0.15):\n",
"    \"\"\"\n",
"    Generate synthetic responses from an RLVR-trained model.\n",
"    \n",
"    RLVR models (based on the paper's findings):\n",
"    - Improved sampling efficiency (higher accuracy on easy problems)\n",
"    - Narrower coverage (some base model solutions become inaccessible)\n",
"    - Shifted distribution towards high-accuracy problems\n",
"    \"\"\"\n",
"    num_problems, num_samples = base_correctness.shape\n",
"    correctness = np.zeros((num_problems, num_samples), dtype=int)\n",
"    \n",
"    for i in range(num_problems):\n",
"        # Get the base model's accuracy on this problem\n",
"        base_accuracy = np.mean(base_correctness[i])\n",
"        \n",
"        if base_accuracy > 0.5:\n",
"            # For problems the base model solves well, RLVR improves accuracy\n",
"            rlvr_accuracy = min(0.95, base_accuracy + shift_factor)\n",
"        elif base_accuracy > 0.2:\n",
"            # For moderately difficult problems, slight improvement\n",
"            rlvr_accuracy = base_accuracy + shift_factor * 0.5\n",
"        else:\n",
"            # For very difficult problems, RLVR may actually reduce coverage:\n",
"            # some problems become unsolvable\n",
"            if np.random.random() < narrowing_factor:\n",
"                rlvr_accuracy = 0.0  # Problem becomes unsolvable\n",
"            else:\n",
"                rlvr_accuracy = base_accuracy * 0.8  # Reduced accuracy\n",
"        \n",
"        correctness[i] = np.random.binomial(1, rlvr_accuracy, num_samples)\n",
"    \n",
"    return correctness\n",
"\n",
"\n",
"def generate_distilled_model_responses(base_correctness, expansion_factor=0.2):\n",
"    \"\"\"\n",
"    Generate synthetic responses from a distilled model.\n",
"    \n",
"    Distilled models (based on the paper's findings):\n",
"    - Can solve problems beyond the base model's capacity\n",
"    - Learn new reasoning patterns from the teacher\n",
"    - Expand the reasoning boundary\n",
"    \"\"\"\n",
"    num_problems, num_samples = base_correctness.shape\n",
"    correctness = np.zeros((num_problems, num_samples), dtype=int)\n",
"    \n",
"    for i in range(num_problems):\n",
"        base_accuracy = np.mean(base_correctness[i])\n",
"        \n",
"        # Distillation can solve new problems and improves on existing ones\n",
"        distilled_accuracy = min(0.95, base_accuracy + expansion_factor)\n",
"        \n",
"        correctness[i] = np.random.binomial(1, distilled_accuracy, num_samples)\n",
"    \n",
"    return correctness\n",
"\n",
"\n",
"# Generate synthetic datasets\n",
"print(\"Generating synthetic evaluation data...\")\n",
"num_problems = 200\n",
"num_samples = 128\n",
"\n",
"base_responses = generate_base_model_responses(num_problems, num_samples)\n",
"rlvr_responses = generate_rlvr_model_responses(base_responses)\n",
"distilled_responses = generate_distilled_model_responses(base_responses)\n",
"\n",
"print(f\"Generated {num_problems} problems with {num_samples} samples each\")\n",
"print(f\"Base model - Average accuracy: {np.mean(base_responses):.3f}\")\n",
"print(f\"RLVR model - Average accuracy: {np.mean(rlvr_responses):.3f}\")\n",
"print(f\"Distilled model - Average accuracy: {np.mean(distilled_responses):.3f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Pass@k Curve Analysis: Base vs. RLVR Models\n",
"\n",
"This is the **central finding** of the paper. We compute pass@k curves for base and RLVR models to see how their reasoning boundaries compare.\n",
"\n",
"### Expected Pattern (from paper)\n",
"\n",
"- **Small k (e.g., k=1)**: RLVR models outperform base models (better sampling efficiency)\n",
"- **Large k (e.g., k=128, k=256)**: Base models surpass RLVR models (broader reasoning coverage)\n",
"\n",
"This demonstrates that RLVR improves average-case performance but doesn't expand the set of solvable problems."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compute pass@k for different k values\n",
"k_values = [1, 2, 4, 8, 16, 32, 64, 128]\n",
"\n",
"print(\"Computing pass@k curves...\")\n",
"base_pass_at_k = compute_pass_at_k_for_dataset(base_responses, k_values)\n",
"rlvr_pass_at_k = compute_pass_at_k_for_dataset(rlvr_responses, k_values)\n",
"distilled_pass_at_k = compute_pass_at_k_for_dataset(distilled_responses, k_values)\n",
"\n",
"# Print results\n",
"print(\"\\nPass@k Results:\")\n",
"print(\"k\\tBase\\tRLVR\\tDistilled\")\n",
"print(\"-\" * 40)\n",
"for k in k_values:\n",
" print(f\"{k}\\t{base_pass_at_k[k]:.3f}\\t{rlvr_pass_at_k[k]:.3f}\\t{distilled_pass_at_k[k]:.3f}\")\n",
"\n",
"# Visualize pass@k curves\n",
"plt.figure(figsize=(10, 6))\n",
"plt.plot(k_values, [base_pass_at_k[k] for k in k_values], \n",
" marker='o', linewidth=2, markersize=8, label='Base Model')\n",
"plt.plot(k_values, [rlvr_pass_at_k[k] for k in k_values], \n",
" marker='s', linewidth=2, markersize=8, label='RLVR Model')\n",
"plt.plot(k_values, [distilled_pass_at_k[k] for k in k_values], \n",
" marker='^', linewidth=2, markersize=8, label='Distilled Model')\n",
"\n",
"plt.xlabel('Number of Samples (k)', fontsize=12)\n",
"plt.ylabel('Coverage (pass@k)', fontsize=12)\n",
"plt.title('Pass@k Curves: Base vs. RLVR vs. Distilled Models', fontsize=14, fontweight='bold')\n",
"plt.xscale('log', base=2)\n",
"plt.grid(True, alpha=0.3)\n",
"plt.legend(fontsize=11, loc='lower right')\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(\"\\n\" + \"=\"*60)\n",
"print(\"KEY FINDING: RLVR models outperform at k=1 (average case)\")\n",
"print(\"but base models catch up and surpass at large k (boundary).\")\n",
"print(\"Distilled models expand beyond base model's boundary.\")\n",
"print(\"=\"*60)"
]
},
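To make the crossover behavior concrete, here is a self-contained sketch with hand-picked, illustrative per-problem correct counts (not data from the paper). An "RLVR-like" profile is sharper on the easy problems but scores zero on two hard problems that a "base-like" profile can still occasionally solve; as k grows, the base-like profile overtakes.

```python
# Standalone sketch of the pass@k crossover (illustrative counts, not paper data)
from math import comb

def pass_at_k(n, c, k):
    # Unbiased estimator: 1 - C(n-c, k) / C(n, k)
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

n = 64  # samples drawn per problem
# Hypothetical correct-answer counts per problem (out of n): the RLVR-like
# profile is better on easy problems but never solves the two hardest ones
base_counts = [32, 16, 4, 2, 1]
rlvr_counts = [60, 40, 10, 0, 0]

def dataset_pass_at_k(counts, k):
    # Average pass@k over the toy "dataset" of problems
    return sum(pass_at_k(n, c, k) for c in counts) / len(counts)

crossover = None
for k in [1, 2, 4, 8, 16, 32, 64]:
    b = dataset_pass_at_k(base_counts, k)
    r = dataset_pass_at_k(rlvr_counts, k)
    print(f"k={k:3d}  base-like={b:.3f}  rlvr-like={r:.3f}")
    if crossover is None and b > r:
        crossover = k

print(f"base-like profile overtakes at k={crossover}")
```

The crossover emerges purely from the two zero-count problems: once k is large enough that the base-like profile reliably hits its rare correct samples, the RLVR-like profile's ceiling of 3/5 solvable problems binds.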
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Accuracy Distribution Analysis\n",
"\n",
"The paper analyzes how per-problem accuracy distributions change with RLVR training. This helps understand **how RLVR affects the solvability** of different problems.\n",
"\n",
"### Key Observations (from paper)\n",
"\n",
"- RLVR **increases frequency** of high accuracies (near 1.0)\n",
"- RLVR **reduces frequency** of medium accuracies (0.1-0.5)\n",
"- RLVR **increases frequency at 0** - more problems become unsolvable\n",
"- This shows RLVR improves efficiency on solvable problems but reduces coverage"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def compute_accuracy_distribution(correctness_matrix):\n",
"    \"\"\"\n",
"    Compute the per-problem accuracy distribution.\n",
"    \n",
"    Uses the paper's binning: exact 0.0 and 1.0 are counted as their own\n",
"    bins; interior bins are (0.0,0.1], (0.1,0.2], ..., (0.8,0.9], and\n",
"    (0.9,1.0), which excludes exact 1.0.\n",
"    \n",
"    Args:\n",
"        correctness_matrix: (num_problems, num_samples) binary matrix\n",
"            where 1 = correct, 0 = incorrect\n",
"    \n",
"    Returns:\n",
"        (frequencies, accuracies): dict mapping bin labels to counts,\n",
"        and the per-problem accuracy array\n",
"    \"\"\"\n",
"    # Compute per-problem accuracy\n",
"    accuracies = np.mean(correctness_matrix, axis=1)\n",
"    \n",
"    # Exact endpoints get their own bins, as in the paper\n",
"    frequencies = {}\n",
"    frequencies['0.0'] = np.sum(accuracies == 0.0)\n",
"    frequencies['1.0'] = np.sum(accuracies == 1.0)\n",
"    \n",
"    # Interior bins; the last one excludes exact 1.0 so that every\n",
"    # accuracy falls into exactly one bin\n",
"    for i in range(10):\n",
"        lower = i / 10\n",
"        upper = (i + 1) / 10\n",
"        if i < 9:\n",
"            bin_label = f'({lower:.1f},{upper:.1f}]'\n",
"            frequencies[bin_label] = np.sum((accuracies > lower) & (accuracies <= upper))\n",
"        else:\n",
"            frequencies['(0.9,1.0)'] = np.sum((accuracies > 0.9) & (accuracies < 1.0))\n",
"    \n",
"    return frequencies, accuracies\n",
"\n",
"\n",
"# Compute distributions\n",
"base_dist, base_accs = compute_accuracy_distribution(base_responses)\n",
"rlvr_dist, rlvr_accs = compute_accuracy_distribution(rlvr_responses)\n",
"\n",
"# Visualize\n",
"fig, axes = plt.subplots(1, 2, figsize=(16, 5))\n",
"\n",
"# Histogram comparison\n",
"bin_labels = ['0.0', '(0.0,0.1]', '(0.1,0.2]', '(0.2,0.3]', '(0.3,0.4]', \n",
" '(0.4,0.5]', '(0.5,0.6]', '(0.6,0.7]', '(0.7,0.8]', \n",
" '(0.8,0.9]', '(0.9,1.0)', '1.0']\n",
"\n",
"x = np.arange(len(bin_labels))\n",
"width = 0.35\n",
"\n",
"base_counts = [base_dist.get(label, 0) for label in bin_labels]\n",
"rlvr_counts = [rlvr_dist.get(label, 0) for label in bin_labels]\n",
"\n",
"axes[0].bar(x - width/2, base_counts, width, label='Base Model', alpha=0.8)\n",
"axes[0].bar(x + width/2, rlvr_counts, width, label='RLVR Model', alpha=0.8)\n",
"axes[0].set_xlabel('Accuracy Interval', fontsize=11)\n",
"axes[0].set_ylabel('Frequency', fontsize=11)\n",
"axes[0].set_title('Accuracy Distribution Comparison', fontsize=13, fontweight='bold')\n",
"axes[0].set_xticks(x)\n",
"axes[0].set_xticklabels(bin_labels, rotation=45, ha='right', fontsize=9)\n",
"axes[0].legend(fontsize=10)\n",
"axes[0].grid(True, alpha=0.3, axis='y')\n",
"\n",
"# Cumulative distribution\n",
"axes[1].hist(base_accs, bins=50, alpha=0.6, label='Base Model', density=True, cumulative=True)\n",
"axes[1].hist(rlvr_accs, bins=50, alpha=0.6, label='RLVR Model', density=True, cumulative=True)\n",
"axes[1].set_xlabel('Per-Problem Accuracy', fontsize=11)\n",
"axes[1].set_ylabel('Cumulative Density', fontsize=11)\n",
"axes[1].set_title('Cumulative Accuracy Distribution', fontsize=13, fontweight='bold')\n",
"axes[1].legend(fontsize=10)\n",
"axes[1].grid(True, alpha=0.3)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# Print key statistics\n",
"print(\"Accuracy Distribution Statistics:\")\n",
"print(f\"\\nProblems with 0% accuracy:\")\n",
"print(f\" Base: {base_dist['0.0']} ({100*base_dist['0.0']/num_problems:.1f}%)\")\n",
"print(f\" RLVR: {rlvr_dist['0.0']} ({100*rlvr_dist['0.0']/num_problems:.1f}%)\")\n",
"print(f\"\\nProblems with 100% accuracy:\")\n",
"print(f\" Base: {base_dist['1.0']} ({100*base_dist['1.0']/num_problems:.1f}%)\")\n",
"print(f\" RLVR: {rlvr_dist['1.0']} ({100*rlvr_dist['1.0']/num_problems:.1f}%)\")\n",
"print(f\"\\nNote: RLVR increases both extremes - more perfect and more unsolvable problems\")"
]
},
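As a standalone check of the binning convention described above (exact 0.0 and 1.0 counted as their own bins, interior bins half-open, with (0.9,1.0) excluding exact 1.0), here is a minimal sketch on a tiny hand-built accuracy vector:

```python
# Minimal standalone check of the paper-style accuracy binning:
# exact 0.0 and 1.0 get their own bins; interior bins are half-open;
# the last interior bin (0.9,1.0) excludes exact 1.0.
import numpy as np

accs = np.array([0.0, 0.05, 0.1, 0.95, 1.0])  # hand-built example accuracies

freq = {'0.0': int(np.sum(accs == 0.0)), '1.0': int(np.sum(accs == 1.0))}
for i in range(10):
    lo, hi = i / 10, (i + 1) / 10
    if i < 9:
        freq[f'({lo:.1f},{hi:.1f}]'] = int(np.sum((accs > lo) & (accs <= hi)))
    else:
        freq['(0.9,1.0)'] = int(np.sum((accs > 0.9) & (accs < 1.0)))

# Every accuracy lands in exactly one bin
assert sum(freq.values()) == len(accs)
print({label: count for label, count in freq.items() if count > 0})
```

Note that 0.1 lands in (0.0,0.1] (the interval is closed on the right) and 1.0 is counted only in its own bin, never in (0.9,1.0).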
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Solvable Problem Coverage Analysis\n",
"\n",
"The paper compares the **sets of solvable problems** between base and RLVR models. This reveals whether RLVR enables solving new problems or just improves efficiency on existing ones.\n",
"\n",
"### Categories (from Table 2 in paper)\n",
"\n",
"1. **Both models solve** - Problems both can solve\n",
"2. **Only base solves** - Problems base can solve but RLVR cannot (coverage loss)\n",
"3. **Only RLVR solves** - Problems RLVR can solve but base cannot (new capability)\n",
"4. **Neither solves** - Problems neither can solve\n",
"\n",
"### Key Finding\n",
"\n",
"The paper finds that \"Only RLVR solves\" is nearly 0%, while \"Only base solves\" is significant (~13% on AIME24). This shows RLVR doesn't enable solving new problems."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def analyze_problem_coverage(base_correctness, rlvr_correctness, k_threshold=128):\n",
"    \"\"\"\n",
"    Analyze which problems each model can solve at k=k_threshold.\n",
"    \n",
"    A problem counts as solvable if at least one of its first\n",
"    k_threshold samples is correct (i.e., empirical pass@k_threshold = 1).\n",
"    \"\"\"\n",
"    num_problems = base_correctness.shape[0]\n",
"    \n",
"    base_solvable = []\n",
"    rlvr_solvable = []\n",
"    \n",
"    for i in range(num_problems):\n",
"        # Check whether the problem is solvable within k_threshold samples\n",
"        base_can_solve = np.sum(base_correctness[i, :k_threshold]) > 0\n",
"        rlvr_can_solve = np.sum(rlvr_correctness[i, :k_threshold]) > 0\n",
"        \n",
"        base_solvable.append(base_can_solve)\n",
"        rlvr_solvable.append(rlvr_can_solve)\n",
"    \n",
"    base_solvable = np.array(base_solvable)\n",
"    rlvr_solvable = np.array(rlvr_solvable)\n",
"    \n",
"    # Categorize problems\n",
"    both_solve = np.sum(base_solvable & rlvr_solvable)\n",
"    only_base = np.sum(base_solvable & ~rlvr_solvable)\n",
"    only_rlvr = np.sum(~base_solvable & rlvr_solvable)\n",
"    neither = np.sum(~base_solvable & ~rlvr_solvable)\n",
"    \n",
"    return {\n",
"        'both_solve': both_solve,\n",
"        'only_base': only_base,\n",
"        'only_rlvr': only_rlvr,\n",
"        'neither': neither,\n",
"        'total': num_problems\n",
"    }\n",
"\n",
"\n",
"# Analyze coverage\n",
"coverage = analyze_problem_coverage(base_responses, rlvr_responses, k_threshold=128)\n",
"\n",
"# Print results\n",
"print(f\"Problem Coverage Analysis (k=128):\")\n",
"print(\"=\" * 50)\n",
"print(f\"Both models solve: {coverage['both_solve']:3d} ({100*coverage['both_solve']/coverage['total']:5.1f}%)\")\n",
"print(f\"Only base model solves: {coverage['only_base']:3d} ({100*coverage['only_base']/coverage['total']:5.1f}%)\")\n",
"print(f\"Only RLVR model solves: {coverage['only_rlvr']:3d} ({100*coverage['only_rlvr']/coverage['total']:5.1f}%)\")\n",
"print(f\"Neither model solves: {coverage['neither']:3d} ({100*coverage['neither']/coverage['total']:5.1f}%)\")\n",
"print(\"=\" * 50)\n",
"\n",
"# Visualize as a bar chart\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"\n",
"categories = ['Both\\nSolve', 'Only\\nBase', 'Only\\nRLVR', 'Neither\\nSolve']\n",
"values = [coverage['both_solve'], coverage['only_base'], \n",
" coverage['only_rlvr'], coverage['neither']]\n",
"colors = ['#2ecc71', '#3498db', '#e74c3c', '#95a5a6']\n",
"\n",
"bars = ax.bar(categories, values, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)\n",
"\n",
"# Add value labels on bars\n",
"for bar, val in zip(bars, values):\n",
" height = bar.get_height()\n",
" ax.text(bar.get_x() + bar.get_width()/2., height,\n",
" f'{val}\\n({100*val/coverage[\"total\"]:.1f}%)',\n",
" ha='center', va='bottom', fontsize=11, fontweight='bold')\n",
"\n",
"ax.set_ylabel('Number of Problems', fontsize=12)\n",
"ax.set_title('Problem Solvability: Base vs. RLVR Models (k=128)', fontsize=14, fontweight='bold')\n",
"ax.grid(True, alpha=0.3, axis='y')\n",
"ax.set_ylim(0, max(values) * 1.15)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(\"\\n\" + \"=\"*60)\n",
"print(\"KEY FINDING: 'Only RLVR' category is very small (~0%)\")\n",
"print(\"while 'Only Base' is significant. This shows RLVR doesn't\")\n",
"print(\"expand the set of solvable problems.\")\n",
"print(\"=\"*60)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Summary: Key Takeaways\n",
"\n",
"This notebook has demonstrated the core computational workflows from the paper:\n",
"\n",
"### Main Findings Reproduced\n",
"\n",
"1. **pass@k curves show RLVR's limitations**\n",
" - RLVR outperforms at k=1 but base models surpass at large k\n",
" - Demonstrates improved sampling efficiency but not expanded reasoning boundaries\n",
"\n",
"2. **Accuracy distributions reveal the mechanism**\n",
" - RLVR increases high-accuracy problems (near 1.0)\n",
" - But also increases unsolvable problems (accuracy 0)\n",
" - Net effect: better efficiency on solvable problems, reduced coverage\n",
"\n",
"3. **Coverage analysis confirms RLVR doesn't create new capabilities**\n",
" - \"Only RLVR solves\" category is minimal\n",
" - \"Only base solves\" is significant\n",
" - RLVR-solvable problems are a subset of base-solvable problems\n",
"\n",
"4. **Distillation differs fundamentally from RLVR**\n",
" - Distillation expands reasoning boundaries\n",
" - Introduces new patterns from teacher model\n",
" - Can solve problems beyond base model's capacity\n",
"\n",
"### Implications for LLM Research\n",
"\n",
"The paper's findings suggest:\n",
"- Current RLVR methods improve average-case performance but don't expand reasoning capabilities\n",
"- Improved RL paradigms are needed: better exploration, multi-turn interaction, continual scaling\n",
"- For genuinely expanding capabilities, distillation from stronger models may be more effective\n",
"\n",
"### Scaling to Full Experiments\n",
"\n",
"This notebook used synthetic data for demonstration. To replicate the paper's full results:\n",
"\n",
"**Infrastructure needed:**\n",
"- Multiple A100/H100 GPUs\n",
"- 100GB+ RAM\n",
"- Distributed training setup\n",
"\n",
"**Models:**\n",
"- Base: LLaMA-3.1-8B, Qwen2.5-7B/14B/32B\n",
"- RLVR frameworks: SimpleRLZoo, VeRL, Code-R1, EasyR1\n",
"- RL algorithms: PPO, GRPO, Reinforce++, RLOO, ReMax, DAPO\n",
"\n",
"**Benchmarks:**\n",
"- Math: GSM8K, MATH500, AIME24, Minerva, Olympiad\n",
"- Code: LiveCodeBench, HumanEval+, MBPP+\n",
"- Visual: MathVista, MathVision\n",
"\n",
"**Evaluation:**\n",
"- Generate k=256 to k=1024 samples per problem\n",
"- Use unbiased pass@k estimator\n",
"- Manual CoT validation for challenging problems\n",
"\n",
"---\n",
"\n",
"**For more details, see the full paper:** https://limit-of-rlvr.github.io\n",
"\n",
"**End of Notebook**"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}