Created February 14, 2025 17:35
-
-
Save DSamuelHodge/8339005e9eaf8e79799f8c711aaceefa to your computer and use it in GitHub Desktop.
GRPO equation explained in Manim
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from manim import * | |
class GRPOExplanation(MovingCameraScene):
    """Animate the DeepSeek-R1 GRPO objective and explain each of its terms.

    The scene writes a title and the two-line GRPO objective, then walks
    through the equation term by term. For each term it:

    1. fades out everything except the title;
    2. pans the camera to the term;
    3. draws a coloured box and a short explanation beside it;
    4. restores the camera and shows the full equation again.
    """

    def construct(self):
        """Build the equations, spotlight every component, then clean up."""
        title_group = self._build_title()
        grpo, ghat = self._build_equations(title_group)

        # Save the default camera frame so each spotlight pan can be undone
        # with a single ``Restore``.
        self.camera.frame.save_state()

        # Initial display
        self.play(Write(title_group))
        self.wait()
        self.play(Write(grpo))
        self.play(Write(ghat))
        self.wait()

        # Dict insertion order fixes the order in which terms are explained.
        for component in self._build_components(grpo, ghat).values():
            self._spotlight(component, grpo, ghat)

        # Final pause showing the complete equation
        self.wait(2)

        # Clean up
        self.play(
            Unwrite(grpo),
            Unwrite(ghat),
            Unwrite(title_group),
        )
        self.wait()

    @staticmethod
    def _build_title():
        """Return the title/subtitle group anchored near the top edge."""
        title = Text("DeepSeek-R1 Reinforcement Learning", font_size=36)
        subtitle = Text("Group Relative Policy Optimization (GRPO)", font_size=24)
        title_group = VGroup(title, subtitle).arrange(DOWN, buff=0.5)
        title_group.to_edge(UP, buff=1)
        return title_group

    @staticmethod
    def _build_equations(title_group):
        """Return ``(grpo, ghat)`` positioned below ``title_group``.

        Both formulas are split into substrings so individual terms can be
        addressed by index later: ``grpo[0]`` is the objective, ``grpo[2]``
        the expectation subscript, ``ghat[0:2]`` the group average,
        ``ghat[2]`` the ratio term, ``ghat[3]`` the clipped term and
        ``ghat[4]`` the KL penalty.
        """
        grpo = MathTex(
            r"J_{GRPO}(\theta)", "=",
            r"E_{q\sim P(Q);\{o_i\}_{i=1}^G\sim\pi_{\theta_{old}}(O|q)}",
            r"[\hat{\mathcal{G}}]",
        )

        # pi_theta(o_i|q) / pi_theta_old(o_i|q), used by both min() arguments.
        ratio = r"\frac{\pi_\theta(o_i|q)}{\pi_{\theta_{old}}(o_i|q)}"
        # The KL penalty belongs INSIDE the group average, as in the GRPO
        # objective of the DeepSeek-R1 paper, so the closing bracket is
        # attached to the KL substring rather than to the clip term.
        ghat = MathTex(
            r"\hat{\mathcal{G}} = \frac{1}{G}", r"\sum_{i=1}^G",
            r"\left[\min\left(" + ratio + "A_i,",
            r"\text{clip}\left(" + ratio
            + r",1-\varepsilon,1+\varepsilon\right)A_i\right)",
            r"-\beta D_{KL}(\pi_\theta||\pi_{ref})\right]",
        )

        # NOTE(review): the expanded G-hat line is long and may be wider than
        # the frame at the default font size; shrink only when it overflows.
        max_width = config.frame_width - 1
        for equation in (grpo, ghat):
            if equation.width > max_width:
                equation.scale_to_fit_width(max_width)

        grpo.next_to(title_group, DOWN, buff=1)
        ghat.next_to(grpo, DOWN, buff=1)
        return grpo, ghat

    @staticmethod
    def _build_components(grpo, ghat):
        """Return the ordered mapping of terms to highlight.

        Each value is a dict with the ``equation`` part to show, a coloured
        ``box`` around it and an ``explanation`` Text placed to the box's right.
        """

        def entry(part, color, text):
            # Box the term and place its explanation one unit to the right.
            box = SurroundingRectangle(part, buff=.1, color=color)
            explanation = Text(text, font_size=24).next_to(box, RIGHT, buff=1)
            return {"box": box, "equation": part, "explanation": explanation}

        return {
            "objective": entry(
                grpo[0], BLUE,
                "GRPO Objective Function\nOptimizes policy parameters",
            ),
            "expectation": entry(
                grpo[2], GREEN,
                # Outputs o_i are drawn from pi_theta_old, per the subscript.
                "Expectation over queries and responses\nResponses sampled from old policy",
            ),
            "sum": entry(
                VGroup(ghat[0], ghat[1]), YELLOW,
                "Average over group responses\nAggregates multiple outputs",
            ),
            "policy_ratio": entry(
                ghat[2], RED,
                "Policy ratio comparing new vs old\nMeasures policy update magnitude",
            ),
            "clipping": entry(
                ghat[3], PURPLE,
                "Clipping to limit policy updates\nPrevents too large changes",
            ),
            "kl_div": entry(
                ghat[4], ORANGE,
                "KL divergence regularization\nMaintains proximity to base model",
            ),
        }

    def _spotlight(self, component, grpo, ghat):
        """Isolate one component, explain it, then restore the full equation."""
        # Fade out everything except the title
        self.play(FadeOut(grpo), FadeOut(ghat), run_time=0.5)

        # Pan the camera to the centre of the term together with its explanation
        focus = VGroup(component["equation"], component["explanation"])
        self.play(
            self.camera.frame.animate.move_to(focus.get_center()),
            run_time=1,
        )

        # Show the term, then its box and explanation
        self.play(FadeIn(component["equation"]), run_time=0.5)
        self.play(
            Create(component["box"]),
            Write(component["explanation"]),
            run_time=1,
        )
        self.wait(2)  # reading time

        # Clear the current component
        self.play(
            Uncreate(component["box"]),
            Unwrite(component["explanation"]),
            FadeOut(component["equation"]),
            run_time=0.8,
        )

        # Return the camera to its saved state and show the full equation again
        self.play(Restore(self.camera.frame), run_time=1)
        self.play(FadeIn(grpo), FadeIn(ghat), run_time=0.5)
        self.wait(0.5)
if __name__ == "__main__":
    # Allows running this file directly with ``python`` instead of the
    # ``manim`` command-line tool: build the scene and render it in-process.
    scene = GRPOExplanation()
    scene.render()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.