Skip to content

Instantly share code, notes, and snippets.

@DSamuelHodge
Created February 14, 2025 17:35
Show Gist options
  • Select an option

  • Save DSamuelHodge/8339005e9eaf8e79799f8c711aaceefa to your computer and use it in GitHub Desktop.

Select an option

Save DSamuelHodge/8339005e9eaf8e79799f8c711aaceefa to your computer and use it in GitHub Desktop.
GRPO equation explained in Manim
from manim import *
class GRPOExplanation(MovingCameraScene):
    """Animate the GRPO objective and explain it one term at a time.

    The scene writes a title, the GRPO objective ``J_GRPO(theta)`` and the
    expanded group term ``G-hat``. It then visits each highlighted term in
    turn:
    both equations are hidden, the camera pans to the term, the term is boxed
    and explained in a short caption, and finally the camera returns to the
    saved default frame and the full equations are shown again.
    """

    # Horizontal space (scene units) kept free on each side of the frame when
    # the equations are shrunk to fit its width.
    FRAME_MARGIN = 0.5

    def construct(self):
        title_group = self._build_title()
        grpo, ghat = self._build_equations(title_group)

        # Save the default camera frame state so every pan can be undone.
        self.camera.frame.save_state()

        # Initial display
        self.play(Write(title_group))
        self.wait()
        self.play(Write(grpo))
        self.play(Write(ghat))
        self.wait()

        # The title is meant to stay visible while the equations are hidden.
        # It is otherwise drawn at fixed scene coordinates, so it would slide
        # (partly) out of view whenever the camera pans to a term below or to
        # the right of the centre. Pin it to the camera frame instead.
        frame = self.camera.frame
        title_offset = title_group.get_center() - frame.get_center()
        title_group.add_updater(
            lambda mob: mob.move_to(frame.get_center() + title_offset)
        )

        for component in self._build_components(grpo, ghat).values():
            self._spotlight(component, grpo, ghat)

        # The camera is back at its saved frame; unpin the title again.
        title_group.clear_updaters()

        # Final pause showing the complete equation, then clean up.
        self.wait(2)
        self.play(
            Unwrite(grpo),
            Unwrite(ghat),
            Unwrite(title_group),
        )
        self.wait()

    @staticmethod
    def _build_title():
        """Return the title/subtitle group placed near the top edge."""
        title = Text("DeepSeek-R1 Reinforcement Learning", font_size=36)
        subtitle = Text("Group Relative Policy Optimization (GRPO)", font_size=24)
        title_group = VGroup(title, subtitle).arrange(DOWN, buff=0.5)
        title_group.to_edge(UP, buff=1)
        return title_group

    def _build_equations(self, title_group):
        """Create and position the GRPO objective and the expanded G-hat term.

        Both ``MathTex`` objects are built from several substrings so that
        individual terms can be addressed by index (see ``_build_components``).
        The expanded term can be wider than the frame at the default font
        size, so both equations are shrunk by one common factor (keeping
        their relative size) until the wider one fits inside the frame. They
        are then stacked below ``title_group``.

        Returns:
            ``(grpo, ghat)``: the objective and the expanded term.
        """
        # grpo[0] objective, grpo[1] "=", grpo[2] expectation, grpo[3] [G-hat]
        grpo = MathTex(
            r"J_{GRPO}(\theta)", "=",
            r"E_{q\sim P(Q);\{o_i\}_{i=1}^G\sim\pi_{\theta_{old}}(O|q)}",
            r"[\hat{\mathcal{G}}]",
        )
        # ghat[0] "G-hat = 1/G", ghat[1] sum, ghat[2] min(ratio * A_i,
        # ghat[3] clipped term, ghat[4] KL penalty
        ghat = MathTex(
            r"\hat{\mathcal{G}} = \frac{1}{G}", r"\sum_{i=1}^G",
            r"\left[\min\left(\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}A_i,",
            r"\text{clip}\left(\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{old}}(o_i|q)},1-\varepsilon,1+\varepsilon\right)A_i\right)\right]",
            r"-\beta D_{KL}(\pi_\theta||\pi_{ref})",
        )

        # Fit to the frame width before positioning, so the overview shot is
        # not cropped at the left/right edges.
        max_width = config.frame_width - 2 * self.FRAME_MARGIN
        widest = max(grpo.width, ghat.width)
        if widest > max_width:
            factor = max_width / widest
            grpo.scale(factor)
            ghat.scale(factor)

        grpo.next_to(title_group, DOWN, buff=1)
        ghat.next_to(grpo, DOWN, buff=1)
        return grpo, ghat

    @staticmethod
    def _build_components(grpo, ghat):
        """Return the terms to explain, keyed by name, in presentation order.

        Each value is a dict with ``"box"`` (a coloured SurroundingRectangle
        around the term), ``"equation"`` (the term itself, a part of ``grpo``
        or ``ghat``) and ``"explanation"`` (a caption placed to the right of
        the box). Call this only after the equations are positioned, because
        boxes and captions are placed from the terms' current coordinates.
        """
        specs = (
            ("objective", grpo[0], BLUE,
             "GRPO Objective Function\nOptimizes policy parameters"),
            # The equation samples o_i from pi_theta_old, so the caption must
            # say "old policy", not "current policy".
            ("expectation", grpo[2], GREEN,
             "Expectation over queries and responses\nResponses sampled from old policy"),
            ("sum", VGroup(ghat[0], ghat[1]), YELLOW,
             "Average over group responses\nAggregates multiple outputs"),
            ("policy_ratio", ghat[2], RED,
             "Policy ratio comparing new vs old\nMeasures policy update magnitude"),
            ("clipping", ghat[3], PURPLE,
             "Clipping to limit policy updates\nPrevents too large changes"),
            ("kl_div", ghat[4], ORANGE,
             "KL divergence regularization\nMaintains proximity to base model"),
        )

        components = {}
        for key, equation, color, caption in specs:
            box = SurroundingRectangle(equation, buff=0.1, color=color)
            explanation = Text(caption, font_size=24)
            explanation.next_to(box, RIGHT, buff=1)
            components[key] = {
                "box": box,
                "equation": equation,
                "explanation": explanation,
            }
        return components

    def _spotlight(self, component, grpo, ghat):
        """Isolate one component, explain it, then return to the full view."""
        equation = component["equation"]
        box = component["box"]
        explanation = component["explanation"]

        # Hide both equations; only the (frame-pinned) title stays visible.
        self.play(FadeOut(grpo), FadeOut(ghat), run_time=0.5)

        # Pan so the highlighted term and its caption are centred together.
        focus_center = VGroup(equation, explanation).get_center()
        self.play(self.camera.frame.animate.move_to(focus_center), run_time=1)

        # Show the term, then its box and caption.
        self.play(FadeIn(equation), run_time=0.5)
        self.play(Create(box), Write(explanation), run_time=1)
        self.wait(2)  # reading time

        # Clear the current component.
        self.play(
            Uncreate(box),
            Unwrite(explanation),
            FadeOut(equation),
            run_time=0.8,
        )

        # Return to the saved frame and show the full equations again.
        self.play(Restore(self.camera.frame), run_time=1)
        self.play(FadeIn(grpo), FadeIn(ghat), run_time=0.5)
        self.wait(0.5)
if __name__ == "__main__":
    # Running the file directly renders the scene with Manim's current
    # configuration: build a GRPOExplanation and call render() on it.
    GRPOExplanation().render()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment