Skip to content

Instantly share code, notes, and snippets.

@evanthebouncy
Created January 23, 2026 15:09
Show Gist options
  • Select an option

  • Save evanthebouncy/d60ea04a95594c1ab77916f4b8e6d0cc to your computer and use it in GitHub Desktop.

Select an option

Save evanthebouncy/d60ea04a95594c1ab77916f4b8e6d0cc to your computer and use it in GitHub Desktop.
Manually play the grid game in the terminal.
# play_gridworld_terminal.py
# Terminal-playable GridWorld + trajectory tracking (only dependency: numpy)
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
@dataclass(frozen=True)
class StepResult:
    """Immutable result of one call to SimpleGridWorld.step (Gym-style fields)."""
    obs: int          # observation; always 0 in this terminal version (see step())
    reward: float     # +1.0 only on the rewarded edges into H, otherwise 0.0
    terminated: bool  # True once the agent has reached the terminal state H
    truncated: bool   # True once the step count reaches max_steps
    info: Dict        # auxiliary data, e.g. {"state_name": ..., "steps": ...}
class SimpleGridWorld:
    """
    3x3 grid world with a blocked center cell:

        A B C
        D X E
        F G H   (A is the start, H is terminal)

    Actions are the integers 0..3 (up/down/left/right); the ACTIONS map
    translates the single-letter commands u/d/l/r into those integers.

    Reward is 0.0 everywhere except on the two edges that enter H:
    (E, down) -> H and (G, right) -> H, each worth +1.0.
    """

    ACTIONS = {"u": 0, "d": 1, "l": 2, "r": 3}
    UP, DOWN, LEFT, RIGHT = 0, 1, 2, 3

    # (row, col) displacement for each action code.
    _DELTAS = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

    def __init__(self, max_steps: int = 200, seed: Optional[int] = None):
        """Build the fixed 3x3 layout; `seed` only feeds the RNG attribute."""
        self.max_steps = int(max_steps)
        self.nrows, self.ncols = 3, 3
        self.blocked = (1, 1)
        self.start_state = "A"
        self.terminal = "H"
        self.pos_to_name = {
            (0, 0): "A", (0, 1): "B", (0, 2): "C",
            (1, 0): "D", (1, 2): "E",
            (2, 0): "F", (2, 1): "G", (2, 2): "H",
        }
        self.name_to_pos = {name: pos for pos, name in self.pos_to_name.items()}
        # Only these (state, action) edges pay out; everything else is 0.0.
        self.reward_edges = {
            ("E", self.DOWN): ("H", 1.0),
            ("G", self.RIGHT): ("H", 1.0),
        }
        # NOTE(review): the RNG is never consulted by the (deterministic)
        # dynamics in this file; kept for interface parity.
        self.rng = np.random.default_rng(seed)
        self.agent = self.start_state
        self.steps = 0

    def reset(self) -> Tuple[str, Dict]:
        """Return (state_name, info). Start is always A."""
        self.agent, self.steps = self.start_state, 0
        return self.agent, {"state_name": self.agent}

    def step(self, action_int: int) -> StepResult:
        """Apply one action (0..3) and return a StepResult for the transition."""
        if action_int not in (0, 1, 2, 3):
            raise ValueError("action must be 0(up),1(down),2(left),3(right)")
        self.steps += 1
        prev_state = self.agent
        next_state = self._move(prev_state, action_int)
        edge = self.reward_edges.get((prev_state, action_int))
        paid = edge is not None and next_state == edge[0]
        self.agent = next_state
        return StepResult(
            obs=0,  # unused in terminal version; kept for consistency
            reward=float(edge[1]) if paid else 0.0,
            terminated=self.agent == self.terminal,
            truncated=self.steps >= self.max_steps,
            info={"state_name": self.agent, "steps": self.steps},
        )

    def render(self) -> str:
        """ASCII picture: blocked cell is ###, the agent's cell is [X]."""
        def cell_at(r: int, c: int) -> str:
            if (r, c) == self.blocked:
                return "###"
            name = self.pos_to_name[(r, c)]
            return f"[{name}]" if name == self.agent else f" {name} "

        return "\n".join(
            " ".join(cell_at(r, c) for c in range(self.ncols))
            for r in range(self.nrows)
        )

    def _move(self, state_name: str, action: int) -> str:
        """Deterministic transition: H absorbs; walls and the block bump."""
        if state_name == self.terminal:
            return state_name  # absorbing
        r, c = self.name_to_pos[state_name]
        dr, dc = self._DELTAS[action]
        nr, nc = r + dr, c + dc
        inside = 0 <= nr < self.nrows and 0 <= nc < self.ncols
        if inside and (nr, nc) != self.blocked:
            return self.pos_to_name[(nr, nc)]
        return state_name  # bumped into a wall or the blocked center
def _show(env: SimpleGridWorld,
          traj: List[Tuple[str, str, float, str]],
          last: Optional[Tuple[str, str, float, str]] = None) -> None:
    """Print the board, the trajectory so far, and optionally the last transition.

    Factored out of play_terminal, where this display sequence was
    duplicated three times (startup, reset, and after each move).
    """
    print()
    print(env.render())
    print("Trajectory:", traj)
    if last is not None:
        print(f"Last: {last}")
    print()


def play_terminal() -> None:
    """Interactive terminal loop for SimpleGridWorld.

    Reads u/d/l/r moves from stdin, records the trajectory as
    (s, a, r, s') tuples, and supports 'reset', 'help' and 'q'.
    Returns when the user quits.
    """
    env = SimpleGridWorld(max_steps=200)
    env.reset()
    # Trajectory as list of tuples: (s, a, r, s')
    traj: List[Tuple[str, str, float, str]] = []
    print("SimpleGridWorld terminal")
    print("Commands: u/d/l/r to move, q to quit, reset to restart, help to show this.")
    _show(env, traj)
    while True:
        cmd = input("move> ").strip().lower()
        if cmd in ("q", "quit", "exit"):
            print("bye")
            break
        if cmd in ("h", "help", "?"):
            print("Commands: u/d/l/r, q, reset, help")
            continue
        if cmd in ("reset", "rset"):
            env.reset()
            traj.clear()
            _show(env, traj)
            continue
        if cmd not in ("u", "d", "l", "r"):
            print("Invalid. Use u/d/l/r (or help).")
            continue
        s = env.agent
        res = env.step(env.ACTIONS[cmd])
        s2 = env.agent
        traj.append((s, cmd, res.reward, s2))
        _show(env, traj, last=(s, cmd, res.reward, s2))
        if res.reward != 0.0:
            print("🎉 reward!")
            print()
        if res.terminated:
            print("✅ Reached terminal state H. Type 'reset' to play again, or 'q' to quit.")
            print()
        elif res.truncated:
            print("⏱️ Max steps reached. Type 'reset' to play again, or 'q' to quit.")
            print()
if __name__ == "__main__":
    # Run the interactive game only when executed as a script, not on import.
    play_terminal()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment