Pseudocode (in Python) of the SASR algorithm (see: https://arxiv.org/abs/2505.13026)
```python
def train_SASR(
    D_train,  # Dataset of (question x, chain-of-thought e, answer y) tuples
    W,        # Number of warm-up (pure SFT) steps
    T,        # Number of adaptive (SFT+GRPO) steps
    policy,   # Initial policy π^(0)_θ (e.g., a pretrained LLM)
    G,        # Group size for GRPO (number of sampled chains per prompt)
    gamma,    # Hyperparameter γ used in computing p_t
):
    """
    Returns:
        policy  # The final, fine-tuned policy π_θ after running SASR
    """
    # 1) Warm-up stage: supervised fine-tuning (SFT) only
    # ----------------------------------------------------------------
    # We keep running SFT updates for W steps, then record
    # G_warmup = ||∇_θ L_SFT(θ)||_2 at the final warm-up iteration.
    G_warmup = None
    for i in range(1, W + 1):
        # 1.a) Sample a minibatch of (x, e, y) from D_train
        batch = sample_batch(D_train)  # returns a list of tuples [(x, e, y), …]

        # 1.b) Compute the SFT loss on that batch:
        #      L_SFT(θ) = − E_{(x,e)∈batch} [ Σ_t log π_θ(a_t | s_t) ]
        #      and take one gradient step to update the policy parameters θ.
        policy = optimization_step_SFT(policy, batch)

        # 1.c) If this is the last warm-up step, record the gradient norm:
        if i == W:
            # G_warmup = || ∇_θ L_SFT(θ) ||_2
            G_warmup = compute_SFT_grad_norm(policy, batch)

    # After the loop, we have:
    #   policy     ← θ after W SFT updates
    #   G_warmup   ← final warm-up gradient norm
    #   G_last_SFT ← (not yet defined; will track ∥∇ L_SFT∥ in the adaptive stage)
    # 2) Adaptive training stage: mix SFT and GRPO per step
    # ----------------------------------------------------------------
    # Now we run T more steps. At each step t = 1…T, we:
    #   (a) compute p_t = G_last_SFT / (G_last_SFT + γ * G_warmup)
    #   (b) sample α ∼ Uniform(0, 1)
    #   (c) if α < p_t: do another SFT update (and recompute G_last_SFT)
    #       else:       do one GRPO update (using G sampled chains, grouped into G+ / G−, etc.)
    #
    # Initially, set G_last_SFT to the warm-up gradient norm (so that p_t is between 0 and 1).
    G_last_SFT = G_warmup
    for t in range(1, T + 1):
        # 2.a) Compute adaptive probability p_t:
        #      p_t = G_last_SFT / (G_last_SFT + γ * G_warmup)
        #      (Ensure the denominator ≠ 0; in practice G_warmup > 0 since warm-up had gradients.)
        p_t = G_last_SFT / (G_last_SFT + gamma * G_warmup)

        # 2.b) Sample α ∼ Uniform(0, 1)
        alpha = random_uniform_0_1()

        if alpha < p_t:
            # 2.c.i) Supervised fine-tuning step
            batch = sample_batch(D_train)                  # sample (x, e, y) tuples again
            policy = optimization_step_SFT(policy, batch)  # one gradient step on L_SFT
            # Recompute the gradient norm on this batch to update G_last_SFT:
            G_last_SFT = compute_SFT_grad_norm(policy, batch)
        else:
            # 2.c.ii) GRPO (Group Relative Policy Optimization) step
            # i) Sample one question-answer pair (we ignore the reference CoT here; we want the model to generate its own CoTs):
            x, _, y_true = sample_batch(D_train, batch_size=1)[0]  # extract a single (x, e, y)

            # ii) Using the current policy π_θ, sample G full chains-of-thought ê_i for prompt x:
            sampled_chains = [sample_chain(policy, x) for _ in range(G)]
            # (each sampled_chain_i is a full token sequence ê_i)

            # iii) Extract the final answer ŷ_i from each chain ê_i:
            answers = [extract_answer(chain) for chain in sampled_chains]

            # iv) Compute scalar rewards R(ŷ_i) for each of the G answers:
            #     e.g., R(ŷ_i) = 1 if ŷ_i matches y_true exactly, else 0 (or a more nuanced reward)
            rewards = [compute_reward(pred=y_hat, gold=y_true) for y_hat in answers]

            # v) Partition the G chains into two groups (G+ and G−) based on reward:
            #    Let median_reward = median(rewards). Then
            #      G_plus  = { i : R(ŷ_i) ≥ median_reward }
            #      G_minus = { i : R(ŷ_i) < median_reward }
            G_plus, G_minus = split_by_median_group(rewards, sampled_chains)

            # vi) Form the GRPO loss L_GRPO(θ) over those G chains:
            #       L_GRPO(θ) = (1/G) Σ_{i=1}^G [
            #           min( r_i * A_i, clip(r_i, 1−ε, 1+ε) * A_i )
            #           − β * D_KL[ π_θ ‖ π_ref ]
            #       ]
            #     where:
            #       r_i   = π_θ(ê_i | x) / π_{θ_old}(ê_i | x)
            #       A_i   = +1 if chain i ∈ G_plus else −1
            #       π_ref = the frozen “reference policy” (e.g., the post-warm-up policy, or θ_old)
            #       β, ε  = GRPO hyperparameters (assumed to be configured inside optimization_step_GRPO;
            #               note γ is only used for p_t above, so it is not passed here)
            #
            #     and then take one gradient step on that GRPO objective.
            policy = optimization_step_GRPO(
                policy,
                reference_policy=policy.copy(),  # frozen copy used as π_ref / θ_old for this step
                sampled_chains=sampled_chains,
                rewards=rewards,
                G_plus=G_plus,
                G_minus=G_minus,
            )
            # Note: After this GRPO step, we do NOT immediately update G_last_SFT.
            # We will only update G_last_SFT on the next SFT step.

    # 3) Return the final policy after all W + T updates
    return policy
```
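The helper functions above (`sample_batch`, `optimization_step_SFT`, `compute_SFT_grad_norm`, and so on) are deliberately left abstract. As one possible realization, offered purely as an assumption rather than anything specified in the gist or the paper, the gradient norm ‖∇_θ L_SFT(θ)‖₂ used for `G_warmup` and `G_last_SFT` could be computed in PyTorch along these lines, where `sft_loss` is a hypothetical helper (defined elsewhere) returning the scalar negative log-likelihood on a batch:

```python
import torch

def compute_SFT_grad_norm(policy, batch):
    """Sketch: L2 norm of the SFT gradient, ||∇_θ L_SFT(θ)||_2, on one minibatch.

    Assumes `policy` is a torch.nn.Module and `sft_loss(policy, batch)` is a
    hypothetical helper returning the scalar loss − Σ_t log π_θ(a_t | s_t)
    averaged over the batch.
    """
    policy.zero_grad()
    loss = sft_loss(policy, batch)     # hypothetical SFT loss helper
    loss.backward()
    grads = [p.grad.detach().flatten() for p in policy.parameters() if p.grad is not None]
    grad_norm = torch.cat(grads).norm().item()
    policy.zero_grad()                 # leave no stale gradients behind
    return grad_norm
```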
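The adaptive probability in step 2.a controls the SFT/GRPO mix: while the latest SFT gradient norm is still large relative to the γ-scaled warm-up norm, p_t stays close to 1 and SFT updates dominate; as the SFT gradient shrinks, p_t drops and GRPO updates become more likely. A quick sanity check with made-up numbers (illustrative only, not values from the paper):

```python
def adaptive_prob(G_last_SFT, G_warmup, gamma):
    """p_t = G_last_SFT / (G_last_SFT + γ * G_warmup)."""
    return G_last_SFT / (G_last_SFT + gamma * G_warmup)

# Illustrative values only:
print(adaptive_prob(G_last_SFT=4.0, G_warmup=2.0, gamma=1.0))  # ≈ 0.667 → mostly SFT
print(adaptive_prob(G_last_SFT=0.5, G_warmup=2.0, gamma=1.0))  # 0.2     → mostly GRPO
```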
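The reward and group-split helpers in step 2.c.ii could be as simple as the sketch below: an exact-match 0/1 reward and a partition of chain indices around the median reward. The exact-match criterion and the signatures are assumptions chosen to mirror the pseudocode, not the paper's actual implementation:

```python
from statistics import median

def compute_reward(pred, gold):
    """Sketch: 1.0 for an exact answer match, 0.0 otherwise (assumed reward)."""
    return 1.0 if str(pred).strip() == str(gold).strip() else 0.0

def split_by_median_group(rewards, sampled_chains):
    """Sketch: indices with reward ≥ median go to G+, the rest to G−.

    `sampled_chains` is kept only to mirror the call signature in the
    pseudocode; the split itself needs only the rewards.
    """
    med = median(rewards)
    G_plus = [i for i, r in enumerate(rewards) if r >= med]
    G_minus = [i for i, r in enumerate(rewards) if r < med]
    return G_plus, G_minus
```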
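Finally, a minimal sketch of the GRPO objective from step vi, under stated assumptions: `sequence_logprob(model, x, chain)` is a hypothetical helper returning log π(chain | x) as a scalar tensor, the `eps` and `beta` defaults are placeholders for the clip range ε and KL weight β, and the KL penalty is approximated by a one-sample sequence-level estimate (a simplification; GRPO implementations typically use a per-token unbiased estimator). `optimization_step_GRPO` would backpropagate this loss and take one optimizer step.

```python
import torch

def grpo_loss(policy, reference_policy, old_policy, x, sampled_chains,
              G_plus, sequence_logprob, eps=0.2, beta=0.01):
    """Sketch of L_GRPO from step vi (see assumptions in the comments).

    `sequence_logprob(model, x, chain)` is a hypothetical helper that returns
    log π(chain | x) as a scalar torch tensor for the given model.
    """
    per_chain_terms = []
    for i, chain in enumerate(sampled_chains):
        logp_new = sequence_logprob(policy, x, chain)                     # log π_θ(ê_i | x)
        logp_old = sequence_logprob(old_policy, x, chain).detach()        # log π_{θ_old}(ê_i | x)
        logp_ref = sequence_logprob(reference_policy, x, chain).detach()  # log π_ref(ê_i | x)

        A_i = 1.0 if i in G_plus else -1.0       # group-relative advantage A_i = ±1
        ratio = torch.exp(logp_new - logp_old)   # r_i = π_θ(ê_i|x) / π_{θ_old}(ê_i|x)
        clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps)
        surrogate = torch.min(ratio * A_i, clipped * A_i)

        # One-sample, sequence-level KL proxy: log π_θ − log π_ref approximates
        # KL(π_θ ‖ π_ref) in expectation over chains sampled from π_θ.
        kl_term = logp_new - logp_ref

        per_chain_terms.append(surrogate - beta * kl_term)

    # (1/G) Σ_i [ min(r_i A_i, clip(r_i) A_i) − β·KL ]; negated because the
    # objective is maximized while optimizers minimize.
    return -torch.stack(per_chain_terms).mean()
```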