Created
January 23, 2026 15:09
-
-
Save evanthebouncy/d60ea04a95594c1ab77916f4b8e6d0cc to your computer and use it in GitHub Desktop.
manually play the grid game
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # play_gridworld_terminal.py | |
| # Terminal-playable GridWorld + trajectory tracking (only dependency: numpy) | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Optional, Tuple | |
| import numpy as np | |
@dataclass(frozen=True)
class StepResult:
    """Immutable record describing the outcome of one environment step."""

    # Observation slot; the grid world in this file always fills it with 0.
    obs: int
    # Reward earned by the transition that was just taken.
    reward: float
    # True once the agent stands on the terminal state.
    terminated: bool
    # True once the step budget has been used up.
    truncated: bool
    # Free-form extras (the grid world stores "state_name" and "steps").
    info: Dict
class SimpleGridWorld:
    """
    Deterministic grid world with named cells and one impassable cell.

    Layout (3x3 with blocked center):
        A B C
        D X E
        F G H   (H is terminal)

    Actions: u/d/l/r
    Start: A
    Rewards:
        - (E, d) -> H gives +1.0
        - (G, r) -> H gives +1.0
    """

    UP, DOWN, LEFT, RIGHT = 0, 1, 2, 3
    ACTIONS = {"u": UP, "d": DOWN, "l": LEFT, "r": RIGHT}

    def __init__(self, max_steps: int = 200, seed: Optional[int] = None):
        """Build the grid tables and place the agent on the start state.

        Args:
            max_steps: Step count at which `step` starts reporting truncation.
            seed: Seed for `self.rng`; no method in this class draws from it.
        """
        self.max_steps = int(max_steps)
        self.nrows, self.ncols = 3, 3
        self.blocked = (1, 1)
        self.start_state = "A"
        self.terminal = "H"

        # One string per grid row; the character at the blocked position is
        # a placeholder and never becomes a state name.
        layout = ("ABC", "D#E", "FGH")
        self.pos_to_name = {
            (r, c): name
            for r, row in enumerate(layout)
            for c, name in enumerate(row)
            if (r, c) != self.blocked
        }
        self.name_to_pos = {name: pos for pos, name in self.pos_to_name.items()}

        # Rewards only on these edges (otherwise 0.0)
        self.reward_edges = {
            ("E", self.DOWN): ("H", 1.0),
            ("G", self.RIGHT): ("H", 1.0),
        }

        self.rng = np.random.default_rng(seed)
        self.agent = self.start_state
        self.steps = 0

    def reset(self) -> Tuple[str, Dict]:
        """Return (state_name, info). Start is always A."""
        self.steps = 0
        self.agent = self.start_state
        return self.agent, {"state_name": self.agent}

    def step(self, action_int: int) -> StepResult:
        """Apply one action and report the transition.

        Args:
            action_int: 0 (up), 1 (down), 2 (left) or 3 (right).

        Returns:
            A StepResult with the reward, the terminated/truncated flags and
            an info dict carrying the new state name and the step count.

        Raises:
            ValueError: If `action_int` is not one of the four action codes.
        """
        if action_int not in (0, 1, 2, 3):
            raise ValueError("action must be 0(up),1(down),2(left),3(right)")

        self.steps += 1
        previous = self.agent
        landed = self._move(previous, action_int)

        # A reward is paid only when the (state, action) edge is registered
        # and the move actually arrived at that edge's target.
        target, bonus = self.reward_edges.get((previous, action_int), (None, 0.0))
        reward = float(bonus) if landed == target else 0.0

        self.agent = landed
        return StepResult(
            obs=0,  # unused in terminal version; kept for consistency
            reward=reward,
            terminated=(self.agent == self.terminal),
            truncated=(self.steps >= self.max_steps),
            info={"state_name": self.agent, "steps": self.steps},
        )

    def render(self) -> str:
        """Return the grid as text, with the agent's cell shown in brackets."""

        def draw(pos):
            if pos == self.blocked:
                return "###"
            name = self.pos_to_name[pos]
            return f"[{name}]" if name == self.agent else f" {name} "

        return "\n".join(
            " ".join(draw((r, c)) for c in range(self.ncols))
            for r in range(self.nrows)
        )

    def _move(self, state_name: str, action: int) -> str:
        """Return the state reached from `state_name` by `action`.

        The terminal state is absorbing; moves that would leave the grid or
        enter the blocked cell leave the state unchanged.
        """
        if state_name == self.terminal:
            return state_name  # absorbing

        row, col = self.name_to_pos[state_name]
        offsets = {
            self.UP: (-1, 0),
            self.DOWN: (1, 0),
            self.LEFT: (0, -1),
            self.RIGHT: (0, 1),
        }
        d_row, d_col = offsets[action]
        target = (row + d_row, col + d_col)

        # bump rules: walls and the blocked cell both keep the agent in place
        inside = 0 <= target[0] < self.nrows and 0 <= target[1] < self.ncols
        if not inside or target == self.blocked:
            return state_name
        return self.pos_to_name[target]
def play_terminal():
    """Run an interactive SimpleGridWorld session on the terminal.

    Reads commands from stdin in a loop:
      - u/d/l/r: take a step, record (s, a, r, s') in the trajectory and
        print the board, the full trajectory and the last transition.
      - reset (or the typo alias "rset"): restart the episode and clear the
        trajectory.
      - h/help/?: print the command summary.
      - q/quit/exit: leave the loop.

    The loop also ends cleanly when stdin is closed (EOF) or the user presses
    Ctrl-C, instead of crashing with a traceback. Moves typed after the
    episode has terminated or been truncated are still passed to the
    environment; only a reminder to reset or quit is printed.
    """
    env = SimpleGridWorld(max_steps=200)
    env.reset()

    # Trajectory as list of tuples: (s, a, r, s')
    traj: List[Tuple[str, str, float, str]] = []

    print("SimpleGridWorld terminal")
    print("Commands: u/d/l/r to move, q to quit, reset to restart, help to show this.")
    print()
    print(env.render())
    print("Trajectory:", traj)
    print()

    while True:
        try:
            cmd = input("move> ").strip().lower()
        except (EOFError, KeyboardInterrupt):
            # stdin closed (Ctrl-D / piped input exhausted) or Ctrl-C:
            # finish the prompt line and exit like an explicit quit.
            print()
            print("bye")
            break

        if cmd in ("q", "quit", "exit"):
            print("bye")
            break

        if cmd in ("h", "help", "?"):
            print("Commands: u/d/l/r, q, reset, help")
            continue

        if cmd in ("reset", "rset"):
            env.reset()
            traj.clear()
            print()
            print(env.render())
            print("Trajectory:", traj)
            print()
            continue

        # Validate against the environment's own action table so the accepted
        # commands cannot drift from what the environment understands.
        if cmd not in env.ACTIONS:
            print("Invalid. Use u/d/l/r (or help).")
            continue

        s = env.agent
        res = env.step(env.ACTIONS[cmd])
        s2 = env.agent
        transition = (s, cmd, res.reward, s2)
        traj.append(transition)

        print()
        print(env.render())
        print("Trajectory:", traj)
        print(f"Last: {transition}")
        print()

        if res.reward != 0.0:
            print("🎉 reward!")
            print()

        # Termination takes precedence over truncation when both are set.
        if res.terminated:
            print("✅ Reached terminal state H. Type 'reset' to play again, or 'q' to quit.")
            print()
        elif res.truncated:
            print("⏱️ Max steps reached. Type 'reset' to play again, or 'q' to quit.")
            print()
# Start the interactive session only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    play_terminal()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment