Skip to content

Instantly share code, notes, and snippets.

@codezakh
Last active January 16, 2025 03:43
Show Gist options
  • Select an option

  • Save codezakh/26cbf72a967010056cc2832360894f54 to your computer and use it in GitHub Desktop.

Select an option

Save codezakh/26cbf72a967010056cc2832360894f54 to your computer and use it in GitHub Desktop.
A single-file implementation of a generic, modular search algorithm for natural-language reasoning.
from enum import Enum
from typing import (
Callable,
Generic,
Iterator,
Optional,
Protocol,
Sequence,
TypeVar,
)
from loguru import logger
from pydantic import BaseModel, Field
from typing_extensions import Self, assert_never
from ulid import ULID
# Element type of SearchContainer; GenericSearch uses it as SearchContainer[Node].
T = TypeVar("T")
class Node(BaseModel):
    """
    Class defining the interface for a node in the search tree.

    Each node holds one reasoning step plus a link to the node it was derived
    from, so the whole chain of steps can be recovered by following ``parent``.
    """

    # Text of this single reasoning step.
    reasoning_step: str
    # Result of check_is_terminal_fn, written when GenericSearch visits the node.
    is_terminal: bool = False
    # Node this one was derived from (None for the root). repr=False keeps the
    # repr from printing the whole ancestry.
    parent: Optional[Self] = Field(default=None, repr=False)
    # Tree level; GenericSearch does not expand nodes with level >= beam_depth.
    # Not computed here -- presumably set by successor_fn (parent.level + 1).
    level: int = 0
    # Unique id; a fresh ULID unless one is supplied.
    ulid: ULID = Field(default_factory=ULID)

    def get_reasoning_trace(self) -> list[Self]:
        """
        Get the reasoning trace represented by the node.

        Follows ``parent`` links from this node up to the root. The list is
        ordered leaf-first: ``trace[0]`` is ``self`` and ``trace[-1]`` is the
        root. Reverse it for root-to-leaf order.
        """
        lineage: list[Self] = []
        current: Optional[Self] = self
        # Explicit None check instead of relying on the model's truthiness.
        while current is not None:
            lineage.append(current)
            current = current.parent
        return lineage
class SuccessorFunction(Protocol):
    """
    Callable that expands a node.

    Given the node being expanded, returns the candidate child nodes.
    GenericSearch may keep only a prefix of the returned sequence.
    """

    def __call__(self, state: Node) -> Sequence[Node]:
        ...
class IsTerminalTestFunction(Protocol):
    """
    Callable that decides whether a node ends its reasoning trace.

    GenericSearch stores the answer on ``node.is_terminal`` and does not
    expand nodes for which it returns True.
    """

    def __call__(self, state: Node) -> bool:
        ...
class RankingFunction(Protocol):
    """
    Callable that scores a node with a single float.

    Not referenced by GenericSearch here; presumably consumed by a ranking or
    priority container -- confirm with callers.
    """

    def __call__(self, state: Node) -> float:
        ...
class SearchContainer(Protocol, Generic[T]):
    """
    Structural interface for the search frontier.

    GenericSearch adds new nodes with ``append`` and takes the next node to
    visit with ``popleft``. Which node comes out first is up to the
    implementation, so the container decides the search order.
    """

    def append(self, item: T) -> None:
        """Add ``item`` to the container."""
        ...

    def popleft(self) -> T:
        """Remove and return the next item to visit."""
        ...

    def __bool__(self) -> bool:
        """Whether any item is left (GenericSearch tests this before popping)."""
        ...

    def __len__(self) -> int:
        """Number of stored items (counted against the node budget)."""
        ...

    def peek_left(self) -> Optional[T]:
        """Next item without removing it; presumably None when empty."""
        ...

    def __iter__(self) -> Iterator[T]:
        """Iterate over the stored items without removing them."""
        ...
class SearchState(Enum):
    """Outcome of a single GenericSearch.step() call."""

    # No nodes are left to visit.
    FRONTIER_EMPTY = "FRONTIER_EMPTY"
    # One node was visited (and expanded unless terminal).
    STEP_COMPLETE = "STEP_COMPLETE"
    # The node budget is spent; nothing was visited.
    BUDGET_EXCEEDED = "BUDGET_EXCEEDED"
class SearchResult(BaseModel):
    """What GenericSearch.run() returns: why it stopped and every node it created."""

    # State reported by the final call to step().
    search_state: SearchState
    # Visited nodes first, then nodes still queued (GenericSearch.nodes).
    nodes: list[Node]
class GenericSearch:
"""
Generic search class that can be used to implement any search algorithm.
Parameters:
------------
initial_state: Node
The root node of the search. For example, could contain the initial prompt.
successor_fn: SuccessorFunction
A function that given the current node, returns the successors of the node.
For example, could return the next step in the reasoning process.
check_is_terminal_fn: IsTerminalTestFunction
A function that given the current node, returns whether the node is a terminal node.
For example, could return whether the reasoning step has produced a final answer.
container_factory: Callable[[], SearchContainer[Node]]
A function that returns a new instance of the search container.
For example, could return a queue or a stack.
node_budget: Optional[int]
The maximum number of nodes to visit.
beam_width: Optional[int]
The maximum number of successors to consider at each step.
beam_depth: Optional[int]
The maximum length of any given reasoning trace.
"""
def __init__(
self,
initial_state: Node,
successor_fn: SuccessorFunction,
check_is_terminal_fn: IsTerminalTestFunction,
container_factory: Callable[[], SearchContainer[Node]],
node_budget: Optional[int] = None,
beam_width: Optional[int] = None,
beam_depth: Optional[int] = None,
):
self.unvisited_nodes: SearchContainer[Node] = (
container_factory()
)
self.unvisited_nodes.append(initial_state)
self.visited_nodes: list[Node] = []
self.level = initial_state.level
self.successor_fn = successor_fn
self.check_is_terminal_node = check_is_terminal_fn
self.node_budget = node_budget
self.beam_width = beam_width
self.beam_depth = beam_depth
@property
def nodes(self) -> list[Node]:
"""
Returns every node created during the search.
This includes nodes that are in the queue and nodes that have been visited.
"""
return self.visited_nodes + list(self.unvisited_nodes)
@staticmethod
def _expand_node(
node: Node,
successor_fn: SuccessorFunction,
visited_nodes: list[Node],
unvisited_nodes: SearchContainer[Node],
beam_depth: Optional[int],
node_budget: Optional[int],
beam_width: Optional[int],
) -> Sequence[Node]:
if beam_depth is not None and node.level >= beam_depth:
logger.info(
f"Skipping expansion. beam_depth={beam_depth} >= node.level={node.level}"
)
return []
successors = successor_fn(node)
if beam_width is not None and len(successors) > beam_width:
successors = successors[:beam_width]
if node_budget is not None:
budget_remaining = node_budget - len(visited_nodes) - len(unvisited_nodes)
if budget_remaining <= 0:
logger.info(
f"Skipping expansion. budget_used={len(visited_nodes) + len(unvisited_nodes)} >= node_budget={node_budget}"
)
return []
successors = successors[:budget_remaining]
return successors
def expand_node(
self, node: Node
) -> Sequence[Node]:
return self._expand_node(
node=node,
successor_fn=self.successor_fn,
visited_nodes=self.visited_nodes,
unvisited_nodes=self.unvisited_nodes,
beam_depth=self.beam_depth,
node_budget=self.node_budget,
beam_width=self.beam_width,
)
def step(self) -> SearchState:
# Check that we have no exceeded the search budget.
if self.node_budget is not None:
budget_expended = len(self.visited_nodes) + len(self.unvisited_nodes)
budget_remaining = self.node_budget - budget_expended
if budget_remaining <= 0:
logger.info(
f"Terminating search. Budget exceeded. budget_expended={len(self.visited_nodes)} + {len(self.unvisited_nodes)}"
f" >= node_budget={self.node_budget}"
)
return SearchState.BUDGET_EXCEEDED
if self.unvisited_nodes:
node = self.unvisited_nodes.peek_left()
assert node is not None
node = self.unvisited_nodes.popleft()
self.visited_nodes.append(node)
logger.info(
f"Visiting node. node.level={node.level} plan_length={len(node.get_reasoning_trace())} queue_size={len(self.unvisited_nodes)}"
)
# Check if the node is a terminal node, and if so, terminate the step.
# We do not expand terminal nodes.
node.is_terminal = (is_terminal := self.check_is_terminal_node(node))
if is_terminal:
logger.info(f"Goal found at level {node.level}")
return SearchState.STEP_COMPLETE
# Expand the node and add its successors to the unvisited nodes.
successors = self.expand_node(node)
logger.info(
f"Expanded node at level {node.level} into {len(successors)} successors"
)
for successor in successors:
self.unvisited_nodes.append(successor)
return SearchState.STEP_COMPLETE
else:
return SearchState.FRONTIER_EMPTY
def run(self) -> SearchResult:
logger.info(f"Initial state: {self.unvisited_nodes.peek_left()}")
logger.info("Running search indefinitely until node expansion limit reached.")
while True:
search_state = self.step()
match search_state:
case SearchState.BUDGET_EXCEEDED:
break
case SearchState.FRONTIER_EMPTY:
break
case SearchState.STEP_COMPLETE:
pass
case _:
assert_never(search_state)
return SearchResult(
search_state=search_state,
nodes=self.nodes,
)
from enum import Enum
from typing import (
Callable,
Generic,
Iterator,
Optional,
Protocol,
Sequence,
TypeVar,
)
from loguru import logger
from pydantic import BaseModel, Field
from typing_extensions import Self, assert_never
from ulid import ULID
# Element type of SearchContainer; GenericSearch uses it as SearchContainer[Node].
T = TypeVar("T")
class Node(BaseModel):
    """
    Class defining the interface for a node in the search tree.

    Besides the reasoning step and its link to the parent, a node carries
    search statistics (visit count, accumulated value, path cost).
    """

    # Text of this single reasoning step.
    reasoning_step: str
    # Result of check_is_terminal_fn, written when GenericSearch visits the node.
    is_terminal: bool = False
    # Node this one was derived from (None for the root). repr=False keeps the
    # repr from printing the whole ancestry.
    parent: Optional[Self] = Field(default=None, repr=False)
    # Tree level; GenericSearch does not expand nodes with level >= beam_depth.
    # Not computed here -- presumably set by successor_fn (parent.level + 1).
    level: int = 0
    # Unique id; a fresh ULID unless one is supplied.
    ulid: ULID = Field(default_factory=ULID)
    # Visit count and accumulated value. Not read or written by this version of
    # GenericSearch; presumably for value-based search variants.
    visits: int = 0
    total_value: float = 0.0
    # Parent's path_cost plus the edge cost; set by GenericSearch when a
    # cost_fn is supplied, otherwise left at its default.
    path_cost: float = 0.0

    def get_reasoning_trace(self) -> list[Self]:
        """
        Get the reasoning trace represented by the node.

        Follows ``parent`` links from this node up to the root. The list is
        ordered leaf-first: ``trace[0]`` is ``self`` and ``trace[-1]`` is the
        root. Reverse it for root-to-leaf order.
        """
        lineage: list[Self] = []
        current: Optional[Self] = self
        # Explicit None check instead of relying on the model's truthiness.
        while current is not None:
            lineage.append(current)
            current = current.parent
        return lineage
class SuccessorFunction(Protocol):
    """
    Callable that expands a node.

    Given the node being expanded, returns the candidate child nodes.
    GenericSearch may keep only a prefix of the returned sequence.
    """

    def __call__(self, state: Node) -> Sequence[Node]:
        ...
class IsTerminalTestFunction(Protocol):
    """
    Callable that decides whether a node ends its reasoning trace.

    GenericSearch stores the answer on ``node.is_terminal`` and does not
    expand nodes for which it returns True.
    """

    def __call__(self, state: Node) -> bool:
        ...
class RankingFunction(Protocol):
    """
    Callable that scores a node with a single float.

    Not referenced by GenericSearch here; presumably consumed by a ranking or
    priority container -- confirm with callers.
    """

    def __call__(self, state: Node) -> float:
        ...
class SearchContainer(Protocol, Generic[T]):
    """
    Structural interface for the search frontier.

    GenericSearch adds new nodes with ``append`` and takes the next node to
    visit with ``popleft``. Which node comes out first is up to the
    implementation, so the container decides the search order.
    """

    def append(self, item: T) -> None:
        """Add ``item`` to the container."""
        ...

    def popleft(self) -> T:
        """Remove and return the next item to visit."""
        ...

    def __bool__(self) -> bool:
        """Whether any item is left (GenericSearch tests this before popping)."""
        ...

    def __len__(self) -> int:
        """Number of stored items (counted against the node budget)."""
        ...

    def peek_left(self) -> Optional[T]:
        """Next item without removing it; presumably None when empty."""
        ...

    def __iter__(self) -> Iterator[T]:
        """Iterate over the stored items without removing them."""
        ...
class SearchState(Enum):
    """Outcome of a single GenericSearch.step() call."""

    # No nodes are left to visit.
    FRONTIER_EMPTY = "FRONTIER_EMPTY"
    # One node was visited (and expanded unless terminal).
    STEP_COMPLETE = "STEP_COMPLETE"
    # The node budget is spent; nothing was visited.
    BUDGET_EXCEEDED = "BUDGET_EXCEEDED"
class SearchResult(BaseModel):
    """What GenericSearch.run() returns: why it stopped and every node it created."""

    # State reported by the final call to step().
    search_state: SearchState
    # Visited nodes first, then nodes still queued (GenericSearch.nodes).
    nodes: list[Node]
class CostFunction(Protocol):
    """
    Callable that prices the edge from ``from_state`` to ``to_state``.

    GenericSearch adds the returned value to the parent's ``path_cost`` to
    obtain the child's ``path_cost``.
    """

    def __call__(self, from_state: Node, to_state: Node) -> float:
        ...
class HeuristicFunction(Protocol):
    """
    Callable that estimates the remaining cost from ``state`` to a goal.

    GenericSearch keeps it as ``heuristic_fn`` but never calls it itself.
    """

    def __call__(self, state: Node) -> float:
        ...
class GenericSearch:
"""
Generic search class that can be used to implement any search algorithm.
Parameters:
------------
initial_state: Node
The root node of the search. For example, could contain the initial prompt.
successor_fn: SuccessorFunction
A function that given the current node, returns the successors of the node.
For example, could return the next step in the reasoning process.
check_is_terminal_fn: IsTerminalTestFunction
A function that given the current node, returns whether the node is a terminal node.
For example, could return whether the reasoning step has produced a final answer.
container_factory: Callable[[], SearchContainer[Node]]
A function that returns a new instance of the search container.
For example, could return a queue or a stack.
node_budget: Optional[int]
The maximum number of nodes to visit.
beam_width: Optional[int]
The maximum number of successors to consider at each step.
beam_depth: Optional[int]
The maximum length of any given reasoning trace.
"""
def __init__(
self,
initial_state: Node,
successor_fn: SuccessorFunction,
check_is_terminal_fn: IsTerminalTestFunction,
container_factory: Callable[[], SearchContainer[Node]],
cost_fn: Optional[CostFunction] = None,
heuristic_fn: Optional[HeuristicFunction] = None,
node_budget: Optional[int] = None,
beam_width: Optional[int] = None,
beam_depth: Optional[int] = None,
):
self.unvisited_nodes: SearchContainer[Node] = (
container_factory()
)
self.unvisited_nodes.append(initial_state)
self.visited_nodes: list[Node] = []
self.level = initial_state.level
self.successor_fn = successor_fn
self.check_is_terminal_node = check_is_terminal_fn
self.node_budget = node_budget
self.beam_width = beam_width
self.beam_depth = beam_depth
self.cost_fn = cost_fn
self.heuristic_fn = heuristic_fn
@property
def nodes(self) -> list[Node]:
"""
Returns every node created during the search.
This includes nodes that are in the queue and nodes that have been visited.
"""
return self.visited_nodes + list(self.unvisited_nodes)
@staticmethod
def _expand_node(
node: Node,
successor_fn: SuccessorFunction,
visited_nodes: list[Node],
unvisited_nodes: SearchContainer[Node],
beam_depth: Optional[int],
node_budget: Optional[int],
beam_width: Optional[int],
cost_fn: Optional[CostFunction] = None,
) -> Sequence[Node]:
if beam_depth is not None and node.level >= beam_depth:
logger.info(
f"Skipping expansion. beam_depth={beam_depth} >= node.level={node.level}"
)
return []
successors = successor_fn(node)
if cost_fn is not None:
for successor in successors:
edge_cost = cost_fn(node, successor)
successor.path_cost = node.path_cost + edge_cost
if beam_width is not None and len(successors) > beam_width:
successors = successors[:beam_width]
if node_budget is not None:
budget_remaining = node_budget - len(visited_nodes) - len(unvisited_nodes)
if budget_remaining <= 0:
logger.info(
f"Skipping expansion. budget_used={len(visited_nodes) + len(unvisited_nodes)} >= node_budget={node_budget}"
)
return []
successors = successors[:budget_remaining]
return successors
def expand_node(
self, node: Node
) -> Sequence[Node]:
return self._expand_node(
node=node,
successor_fn=self.successor_fn,
visited_nodes=self.visited_nodes,
unvisited_nodes=self.unvisited_nodes,
beam_depth=self.beam_depth,
node_budget=self.node_budget,
beam_width=self.beam_width,
cost_fn=self.cost_fn,
)
def step(self) -> SearchState:
# Check that we have no exceeded the search budget.
if self.node_budget is not None:
budget_expended = len(self.visited_nodes) + len(self.unvisited_nodes)
budget_remaining = self.node_budget - budget_expended
if budget_remaining <= 0:
logger.info(
f"Terminating search. Budget exceeded. budget_expended={len(self.visited_nodes)} + {len(self.unvisited_nodes)}"
f" >= node_budget={self.node_budget}"
)
return SearchState.BUDGET_EXCEEDED
if self.unvisited_nodes:
node = self.unvisited_nodes.peek_left()
assert node is not None
node = self.unvisited_nodes.popleft()
self.visited_nodes.append(node)
logger.info(
f"Visiting node. node.level={node.level} plan_length={len(node.get_reasoning_trace())} queue_size={len(self.unvisited_nodes)}"
)
# Check if the node is a terminal node, and if so, terminate the step.
# We do not expand terminal nodes.
node.is_terminal = (is_terminal := self.check_is_terminal_node(node))
if is_terminal:
logger.info(f"Goal found at level {node.level}")
return SearchState.STEP_COMPLETE
# Expand the node and add its successors to the unvisited nodes.
successors = self.expand_node(node)
logger.info(
f"Expanded node at level {node.level} into {len(successors)} successors"
)
for successor in successors:
self.unvisited_nodes.append(successor)
return SearchState.STEP_COMPLETE
else:
return SearchState.FRONTIER_EMPTY
def run(self) -> SearchResult:
logger.info(f"Initial state: {self.unvisited_nodes.peek_left()}")
logger.info("Running search indefinitely until node expansion limit reached.")
while True:
search_state = self.step()
match search_state:
case SearchState.BUDGET_EXCEEDED:
break
case SearchState.FRONTIER_EMPTY:
break
case SearchState.STEP_COMPLETE:
pass
case _:
assert_never(search_state)
return SearchResult(
search_state=search_state,
nodes=self.nodes,
)
from enum import Enum
from typing import (
Callable,
Generic,
Iterator,
Optional,
Protocol,
Sequence,
TypeVar,
)
from loguru import logger
from pydantic import BaseModel, Field
from typing_extensions import Self, assert_never
from ulid import ULID
from math import sqrt, log
# Element type of SearchContainer; GenericSearch uses it as SearchContainer[Node].
T = TypeVar("T")
class Node(BaseModel):
    """
    Class defining the interface for a node in the search tree.

    Besides the reasoning step and its link to the parent, a node keeps the
    visit count and accumulated value used for UCB1 selection and
    backpropagation (as in Monte Carlo tree search).
    """

    # Text of this single reasoning step.
    reasoning_step: str
    # Result of check_is_terminal_fn, written when GenericSearch visits the node.
    is_terminal: bool = False
    # Node this one was derived from (None for the root). repr=False keeps the
    # repr from printing the whole ancestry.
    parent: Optional[Self] = Field(default=None, repr=False)
    # Tree level; GenericSearch does not expand nodes with level >= beam_depth.
    # Not computed here -- presumably set by successor_fn (parent.level + 1).
    level: int = 0
    # Unique id; a fresh ULID unless one is supplied.
    ulid: ULID = Field(default_factory=ULID)
    # Number of backpropagate() calls that passed through this node.
    visits: int = 0
    # Sum of the values backpropagated through this node.
    total_value: float = 0.0

    def get_reasoning_trace(self) -> list[Self]:
        """
        Get the reasoning trace represented by the node.

        Follows ``parent`` links from this node up to the root. The list is
        ordered leaf-first: ``trace[0]`` is ``self`` and ``trace[-1]`` is the
        root. Reverse it for root-to-leaf order.
        """
        lineage: list[Self] = []
        current: Optional[Self] = self
        # Explicit None check instead of relying on the model's truthiness.
        while current is not None:
            lineage.append(current)
            current = current.parent
        return lineage

    def ucb1_score(self, exploration_weight: float = sqrt(2)) -> float:
        """
        Calculate the UCB1 score for this node.

            total_value / visits
            + exploration_weight * sqrt(ln(parent.visits) / visits)

        Returns +inf for a node that has never been visited and for the root
        (no parent), so such nodes win any max-score selection.

        NOTE(review): assumes ``parent.visits >= 1`` whenever
        ``self.visits >= 1``. ``backpropagate`` guarantees this because it
        increments every ancestor. If the counters are set by hand with
        ``parent.visits == 0``, ``log`` raises ValueError.
        """
        if self.visits == 0:
            return float('inf')
        if self.parent is None:
            return float('inf')
        exploitation = self.total_value / self.visits
        exploration = exploration_weight * sqrt(log(self.parent.visits) / self.visits)
        return exploitation + exploration

    def backpropagate(self, value: float) -> None:
        """
        Update node statistics after a simulation.

        Adds one visit and ``value`` to this node and to every ancestor up to
        and including the root.
        """
        current: Optional[Self] = self
        while current is not None:
            current.visits += 1
            current.total_value += value
            current = current.parent
class SuccessorFunction(Protocol):
    """
    Callable that expands a node.

    Given the node being expanded, returns the candidate child nodes.
    GenericSearch may keep only a prefix of the returned sequence.
    """

    def __call__(self, state: Node) -> Sequence[Node]:
        ...
class IsTerminalTestFunction(Protocol):
    """
    Callable that decides whether a node ends its reasoning trace.

    GenericSearch stores the answer on ``node.is_terminal``. It skips both
    simulation and expansion for nodes where this returns True.
    """

    def __call__(self, state: Node) -> bool:
        ...
class RankingFunction(Protocol):
    """
    Callable that scores a node with a single float.

    Not referenced by GenericSearch here; presumably consumed by a ranking or
    priority container -- confirm with callers.
    """

    def __call__(self, state: Node) -> float:
        ...
class SearchContainer(Protocol, Generic[T]):
    """
    Structural interface for the search frontier.

    GenericSearch adds new nodes with ``append`` and takes the next node to
    visit with ``popleft``. Which node comes out first is up to the
    implementation, so the container decides the search order.
    """

    def append(self, item: T) -> None:
        """Add ``item`` to the container."""
        ...

    def popleft(self) -> T:
        """Remove and return the next item to visit."""
        ...

    def __bool__(self) -> bool:
        """Whether any item is left (GenericSearch tests this before popping)."""
        ...

    def __len__(self) -> int:
        """Number of stored items (counted against the node budget)."""
        ...

    def peek_left(self) -> Optional[T]:
        """Next item without removing it; presumably None when empty."""
        ...

    def __iter__(self) -> Iterator[T]:
        """Iterate over the stored items without removing them."""
        ...
class SearchState(Enum):
    """Outcome of a single GenericSearch.step() call."""

    # No nodes are left to visit.
    FRONTIER_EMPTY = "FRONTIER_EMPTY"
    # One node was visited (and expanded unless terminal).
    STEP_COMPLETE = "STEP_COMPLETE"
    # The node budget is spent; nothing was visited.
    BUDGET_EXCEEDED = "BUDGET_EXCEEDED"
class SearchResult(BaseModel):
    """What GenericSearch.run() returns: why it stopped and every node it created."""

    # State reported by the final call to step().
    search_state: SearchState
    # Visited nodes first, then nodes still queued (GenericSearch.nodes).
    nodes: list[Node]
class SimulationFunction(Protocol):
    """
    Callable that estimates the value of ``state`` (e.g. via a rollout).

    GenericSearch calls it on each visited non-terminal node and passes the
    returned float to ``node.backpropagate``.
    """

    def __call__(self, state: Node) -> float:
        ...
class GenericSearch:
"""
Generic search class that can be used to implement any search algorithm.
Parameters:
------------
initial_state: Node
The root node of the search. For example, could contain the initial prompt.
successor_fn: SuccessorFunction
A function that given the current node, returns the successors of the node.
For example, could return the next step in the reasoning process.
check_is_terminal_fn: IsTerminalTestFunction
A function that given the current node, returns whether the node is a terminal node.
For example, could return whether the reasoning step has produced a final answer.
container_factory: Callable[[], SearchContainer[Node]]
A function that returns a new instance of the search container.
For example, could return a queue or a stack.
simulation_fn: Optional[SimulationFunction]
A function that simulates from a given node to get an estimated value.
node_budget: Optional[int]
The maximum number of nodes to visit.
beam_width: Optional[int]
The maximum number of successors to consider at each step.
beam_depth: Optional[int]
The maximum length of any given reasoning trace.
"""
def __init__(
self,
initial_state: Node,
successor_fn: SuccessorFunction,
check_is_terminal_fn: IsTerminalTestFunction,
container_factory: Callable[[], SearchContainer[Node]],
simulation_fn: Optional[SimulationFunction] = None,
node_budget: Optional[int] = None,
beam_width: Optional[int] = None,
beam_depth: Optional[int] = None,
):
self.unvisited_nodes: SearchContainer[Node] = (
container_factory()
)
self.unvisited_nodes.append(initial_state)
self.visited_nodes: list[Node] = []
self.level = initial_state.level
self.successor_fn = successor_fn
self.check_is_terminal_node = check_is_terminal_fn
self.node_budget = node_budget
self.beam_width = beam_width
self.beam_depth = beam_depth
self.simulation_fn = simulation_fn
@property
def nodes(self) -> list[Node]:
"""
Returns every node created during the search.
This includes nodes that are in the queue and nodes that have been visited.
"""
return self.visited_nodes + list(self.unvisited_nodes)
@staticmethod
def _expand_node(
node: Node,
successor_fn: SuccessorFunction,
visited_nodes: list[Node],
unvisited_nodes: SearchContainer[Node],
beam_depth: Optional[int],
node_budget: Optional[int],
beam_width: Optional[int],
) -> Sequence[Node]:
if beam_depth is not None and node.level >= beam_depth:
logger.info(
f"Skipping expansion. beam_depth={beam_depth} >= node.level={node.level}"
)
return []
successors = successor_fn(node)
if beam_width is not None and len(successors) > beam_width:
successors = successors[:beam_width]
if node_budget is not None:
budget_remaining = node_budget - len(visited_nodes) - len(unvisited_nodes)
if budget_remaining <= 0:
logger.info(
f"Skipping expansion. budget_used={len(visited_nodes) + len(unvisited_nodes)} >= node_budget={node_budget}"
)
return []
successors = successors[:budget_remaining]
return successors
def expand_node(
self, node: Node
) -> Sequence[Node]:
return self._expand_node(
node=node,
successor_fn=self.successor_fn,
visited_nodes=self.visited_nodes,
unvisited_nodes=self.unvisited_nodes,
beam_depth=self.beam_depth,
node_budget=self.node_budget,
beam_width=self.beam_width,
)
def step(self) -> SearchState:
# Check that we have no exceeded the search budget.
if self.node_budget is not None:
budget_expended = len(self.visited_nodes) + len(self.unvisited_nodes)
budget_remaining = self.node_budget - budget_expended
if budget_remaining <= 0:
logger.info(
f"Terminating search. Budget exceeded. budget_expended={len(self.visited_nodes)} + {len(self.unvisited_nodes)}"
f" >= node_budget={self.node_budget}"
)
return SearchState.BUDGET_EXCEEDED
if self.unvisited_nodes:
node = self.unvisited_nodes.peek_left()
assert node is not None
node = self.unvisited_nodes.popleft()
self.visited_nodes.append(node)
logger.info(
f"Visiting node. node.level={node.level} plan_length={len(node.get_reasoning_trace())} queue_size={len(self.unvisited_nodes)}"
)
# Check if the node is a terminal node, and if so, terminate the step.
# We do not expand terminal nodes.
node.is_terminal = (is_terminal := self.check_is_terminal_node(node))
# Add simulation and backpropagation for MCTS
if self.simulation_fn and not is_terminal:
value = self.simulation_fn(node)
node.backpropagate(value)
if is_terminal:
logger.info(f"Goal found at level {node.level}")
return SearchState.STEP_COMPLETE
# Expand the node and add its successors to the unvisited nodes.
successors = self.expand_node(node)
logger.info(
f"Expanded node at level {node.level} into {len(successors)} successors"
)
for successor in successors:
self.unvisited_nodes.append(successor)
return SearchState.STEP_COMPLETE
else:
return SearchState.FRONTIER_EMPTY
def run(self) -> SearchResult:
logger.info(f"Initial state: {self.unvisited_nodes.peek_left()}")
logger.info("Running search indefinitely until node expansion limit reached.")
while True:
search_state = self.step()
match search_state:
case SearchState.BUDGET_EXCEEDED:
break
case SearchState.FRONTIER_EMPTY:
break
case SearchState.STEP_COMPLETE:
pass
case _:
assert_never(search_state)
return SearchResult(
search_state=search_state,
nodes=self.nodes,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment