
@hsm207
Created June 2, 2025 18:36
Pseudocode (in Python) of the SASR algorithm (see: https://arxiv.org/abs/2505.13026)
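The pseudocode below leaves its helpers (sample_batch, optimization_step_SFT, compute_SFT_grad_norm, sample_chain, extract_answer, compute_reward, split_by_median_group, optimization_step_GRPO) undefined. As one illustration of the two SFT helpers that drive the adaptive probability, here is a minimal PyTorch-style sketch; the policy.nll(x, e, y) method and policy.optimizer attribute are assumptions of this sketch, not part of the gist or the paper.

def optimization_step_SFT(policy, batch):
    # One SFT gradient step: minimize the average negative log-likelihood of the
    # reference chain-of-thought e and answer y given the question x.
    # Sketch assumption: `policy` bundles an nn.Module and its optimizer
    # (policy.optimizer) and exposes a hypothetical nll(x, e, y) returning a scalar loss tensor.
    policy.optimizer.zero_grad()
    loss = sum(policy.nll(x, e, y) for (x, e, y) in batch) / len(batch)
    loss.backward()
    policy.optimizer.step()
    return policy

def compute_SFT_grad_norm(policy, batch):
    # ||∇_θ L_SFT(θ)||_2 on the given batch, without taking an optimizer step.
    policy.optimizer.zero_grad()
    loss = sum(policy.nll(x, e, y) for (x, e, y) in batch) / len(batch)
    loss.backward()
    squared = sum(p.grad.pow(2).sum() for p in policy.parameters() if p.grad is not None)
    return squared.sqrt().item()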
def train_SASR(
    D_train,  # Dataset of (question x, chain-of-thought e, answer y) tuples
    W,        # Number of warm-up (pure SFT) steps
    T,        # Number of adaptive (SFT+GRPO) steps
    policy,   # Initial policy π^(0)_θ (e.g., a pretrained LLM)
    G,        # Group size for GRPO (number of sampled chains per prompt)
    gamma,    # Hyperparameter γ used in computing p_t
):
    """
    Returns:
        policy  # The final, fine-tuned policy π_θ after running SASR
    """
    # 1) Warm-up stage: supervised fine-tuning (SFT) only
    # ----------------------------------------------------------------
    # We keep running SFT updates for W steps, then record
    # G_warmup = ||∇_θ L_SFT(θ)||_2 at the final warm-up iteration.
    G_warmup = None
    for i in range(1, W + 1):
        # 1.a) Sample a minibatch of (x, e, y) from D_train
        batch = sample_batch(D_train)  # returns a list of tuples [(x, e, y), …]

        # 1.b) Compute the SFT loss on that batch:
        #          L_SFT(θ) = − E_{(x,e)∈batch} [ Σ_t log π_θ(a_t | s_t) ]
        #      and take one gradient step to update the policy parameters θ.
        policy = optimization_step_SFT(policy, batch)

        # 1.c) If this is the last warm-up step, record the gradient norm:
        if i == W:
            # G_warmup = || ∇_θ L_SFT(θ) ||_2
            G_warmup = compute_SFT_grad_norm(policy, batch)

    # After the loop, we have:
    #   policy     ← θ after W SFT updates
    #   G_warmup   ← final warm-up gradient norm
    #   G_last_SFT ← (not yet defined; will track ‖∇ L_SFT‖_2 in the adaptive stage)
    # 2) Adaptive training stage: mix SFT and GRPO per step
    # ----------------------------------------------------------------
    # Now we run T more steps. At each step t = 1…T, we:
    #   (a) compute p_t = G_last_SFT / (G_last_SFT + γ * G_warmup)
    #   (b) sample α ∼ Uniform(0, 1)
    #   (c) if α < p_t: do another SFT update (and recompute G_last_SFT)
    #       else:       do one GRPO update (using G sampled chains, grouped into G+ / G−, etc.)
    #
    # Initially, set G_last_SFT to the warm-up gradient norm (so that p_t is between 0 and 1).
    G_last_SFT = G_warmup
    for t in range(1, T + 1):
        # 2.a) Compute the adaptive probability p_t:
        #          p_t = G_last_SFT / (G_last_SFT + γ * G_warmup)
        #      (Ensure the denominator ≠ 0; in practice G_warmup > 0 since warm-up had gradients.)
        p_t = G_last_SFT / (G_last_SFT + gamma * G_warmup)

        # 2.b) Sample α ∼ Uniform(0, 1)
        alpha = random_uniform_0_1()

        if alpha < p_t:
            # 2.c.i) Supervised fine-tuning step
            batch = sample_batch(D_train)                  # sample (x, e, y) tuples again
            policy = optimization_step_SFT(policy, batch)  # one gradient step on L_SFT
            # Recompute the gradient norm on this batch to update G_last_SFT:
            G_last_SFT = compute_SFT_grad_norm(policy, batch)
        else:
            # 2.c.ii) GRPO (Group Relative Policy Optimization) step
            # i) Sample one question-answer pair (we ignore the CoT here; we want the model to generate CoTs):
            x, _, y_true = sample_batch(D_train, batch_size=1)[0]  # extract a single (x, e, y)

            # ii) Using the current policy π_θ, sample G full chains-of-thought ê_i for prompt x:
            sampled_chains = [sample_chain(policy, x) for _ in range(G)]
            # (each sampled_chain_i is a full token sequence ê_i)

            # iii) Extract the final answer ŷ_i from each chain ê_i:
            answers = [extract_answer(chain) for chain in sampled_chains]

            # iv) Compute scalar rewards R(ŷ_i) for each of the G answers:
            #     e.g., R(ŷ_i) = 1 if ŷ_i matches y_true exactly, else 0 (or a more nuanced reward)
            rewards = [compute_reward(pred=y_hat, gold=y_true) for y_hat in answers]

            # v) Partition the G chains into two groups (G+ and G−) based on reward:
            #      Let median_reward = median(rewards). Then
            #        G_plus  = { i : R(ŷ_i) ≥ median_reward }
            #        G_minus = { i : R(ŷ_i) < median_reward }
            G_plus, G_minus = split_by_median_group(rewards, sampled_chains)

            # vi) Form the GRPO loss L_GRPO(θ) over those G chains:
            #        L_GRPO(θ) = (1/G) ∑_{i=1}^G [
            #            min( r_i * A_i, clip(r_i, 1−ε, 1+ε) * A_i )
            #            − β * D_KL[ π_θ ‖ π_ref ]
            #        ]
            #      where:
            #        r_i   = π_θ(ê_i | x) / π_{θ_old}(ê_i | x)
            #        A_i   = +1 if chain i ∈ G_plus else −1
            #        π_ref = the frozen “reference policy” from warm-up
            #        β, ε  = GRPO hyperparameters
            #
            #      and then take one gradient step on that GRPO objective.
            policy = optimization_step_GRPO(
                policy,
                reference_policy=policy.copy(),  # π_ref is typically the policy from warm-up (or θ_old)
                sampled_chains=sampled_chains,
                rewards=rewards,
                G_plus=G_plus,
                G_minus=G_minus,
                gamma=gamma,
            )
            # Note: After this GRPO step, we do NOT immediately update G_last_SFT.
            # We will only update G_last_SFT on the next SFT step.

    # 3) Return the final policy after all W + T updates
    return policy
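
For completeness, here is a sketch of the clipped GRPO objective spelled out in the comments of step (vi), assuming sequence-level log-probabilities have already been computed for each sampled chain under π_θ, π_{θ_old}, and π_ref. The function name grpo_loss, the k3-style KL estimator, and the default ε/β values are choices of this sketch, not taken from the gist or the paper.

import torch

def grpo_loss(new_logps, old_logps, ref_logps, in_G_plus, eps=0.2, beta=0.04):
    # new_logps, old_logps, ref_logps: shape (G,) tensors with the summed
    # log-probability of each sampled chain ê_i under π_θ, π_{θ_old}, π_ref.
    # in_G_plus: shape (G,) boolean tensor, True for chains in G+.
    # eps, beta: illustrative values for the GRPO hyperparameters ε and β.
    advantages = in_G_plus.float() * 2.0 - 1.0            # A_i = +1 for G+, −1 for G−
    ratio = torch.exp(new_logps - old_logps)               # r_i
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps)
    surrogate = torch.minimum(ratio * advantages, clipped * advantages)
    # Per-chain estimate of D_KL[π_θ ‖ π_ref] (the commonly used "k3" estimator)
    kl = torch.exp(ref_logps - new_logps) - (ref_logps - new_logps) - 1.0
    # The GRPO objective is maximized, so return its negative as the loss to minimize.
    return -(surrogate - beta * kl).mean()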