Last active
January 16, 2025 03:43
-
-
Save codezakh/26cbf72a967010056cc2832360894f54 to your computer and use it in GitHub Desktop.
A single-file implementation of a generic, modular search algorithm for natural language reasoning.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from enum import Enum | |
| from typing import ( | |
| Callable, | |
| Generic, | |
| Iterator, | |
| Optional, | |
| Protocol, | |
| Sequence, | |
| TypeVar, | |
| ) | |
| from loguru import logger | |
| from pydantic import BaseModel, Field | |
| from typing_extensions import Self, assert_never | |
| from ulid import ULID | |
| T = TypeVar("T") | |
class Node(BaseModel):
    """One reasoning step in the search tree, linked to the step it came from."""

    # Text of the reasoning step held by this node.
    reasoning_step: str
    # Terminal flag; defaults to False.
    is_terminal: bool = False
    # Node this one descends from, or None when there is none. Hidden from repr.
    parent: Optional[Self] = Field(default=None, repr=False)
    # Level value of the node; defaults to 0.
    level: int = 0
    # Identifier created fresh for every node instance.
    ulid: ULID = Field(default_factory=ULID)

    def get_reasoning_trace(self) -> list[Self]:
        """
        Collect the chain of nodes obtained by following ``parent`` links.

        The returned list starts with this node and ends with the node that
        has no parent.
        """
        trace: list[Self] = []
        cursor = self
        while cursor:
            trace.append(cursor)
            cursor = cursor.parent
        return trace
class SuccessorFunction(Protocol):
    """Callable protocol for producing the successors of a node."""

    def __call__(self, state: Node) -> Sequence[Node]:
        """Return the successor nodes of ``state``."""
        ...
class IsTerminalTestFunction(Protocol):
    """Callable protocol for deciding whether a node is terminal."""

    def __call__(self, state: Node) -> bool:
        """Return whether ``state`` is a terminal node."""
        ...
class RankingFunction(Protocol):
    """Callable protocol for scoring a node."""

    def __call__(self, state: Node) -> float:
        """Return the score assigned to ``state``."""
        ...
class SearchContainer(Protocol, Generic[T]):
    """Container protocol for the nodes held during the search."""

    def append(self, item: T) -> None:
        """Add ``item`` to the container."""
        ...

    def popleft(self) -> T:
        """Remove an item from the container and return it."""
        ...

    def __bool__(self) -> bool:
        """Truth value of the container."""
        ...

    def __len__(self) -> int:
        """Number of items in the container."""
        ...

    def peek_left(self) -> Optional[T]:
        """Return an item, or None, without the removal implied by ``popleft``."""
        ...

    def __iter__(self) -> Iterator[T]:
        """Iterate over the items in the container."""
        ...
class SearchState(Enum):
    """Outcome reported by a single search step."""

    # No nodes were left to visit.
    FRONTIER_EMPTY = "FRONTIER_EMPTY"
    # A node was visited; the search can keep going.
    STEP_COMPLETE = "STEP_COMPLETE"
    # The node budget was used up.
    BUDGET_EXCEEDED = "BUDGET_EXCEEDED"
class SearchResult(BaseModel):
    """Result of a search run: the final state plus the nodes it produced."""

    # State reported by the last search step.
    search_state: SearchState
    # Nodes collected by the search.
    nodes: list[Node]
| class GenericSearch: | |
| """ | |
| Generic search class that can be used to implement any search algorithm. | |
| Parameters: | |
| ------------ | |
| initial_state: Node | |
| The root node of the search. For example, could contain the initial prompt. | |
| successor_fn: SuccessorFunction | |
| A function that given the current node, returns the successors of the node. | |
| For example, could return the next step in the reasoning process. | |
| check_is_terminal_fn: IsTerminalTestFunction | |
| A function that given the current node, returns whether the node is a terminal node. | |
| For example, could return whether the reasoning step has produced a final answer. | |
| container_factory: Callable[[], SearchContainer[Node]] | |
| A function that returns a new instance of the search container. | |
| For example, could return a queue or a stack. | |
| node_budget: Optional[int] | |
| The maximum number of nodes to visit. | |
| beam_width: Optional[int] | |
| The maximum number of successors to consider at each step. | |
| beam_depth: Optional[int] | |
| The maximum length of any given reasoning trace. | |
| """ | |
| def __init__( | |
| self, | |
| initial_state: Node, | |
| successor_fn: SuccessorFunction, | |
| check_is_terminal_fn: IsTerminalTestFunction, | |
| container_factory: Callable[[], SearchContainer[Node]], | |
| node_budget: Optional[int] = None, | |
| beam_width: Optional[int] = None, | |
| beam_depth: Optional[int] = None, | |
| ): | |
| self.unvisited_nodes: SearchContainer[Node] = ( | |
| container_factory() | |
| ) | |
| self.unvisited_nodes.append(initial_state) | |
| self.visited_nodes: list[Node] = [] | |
| self.level = initial_state.level | |
| self.successor_fn = successor_fn | |
| self.check_is_terminal_node = check_is_terminal_fn | |
| self.node_budget = node_budget | |
| self.beam_width = beam_width | |
| self.beam_depth = beam_depth | |
| @property | |
| def nodes(self) -> list[Node]: | |
| """ | |
| Returns every node created during the search. | |
| This includes nodes that are in the queue and nodes that have been visited. | |
| """ | |
| return self.visited_nodes + list(self.unvisited_nodes) | |
| @staticmethod | |
| def _expand_node( | |
| node: Node, | |
| successor_fn: SuccessorFunction, | |
| visited_nodes: list[Node], | |
| unvisited_nodes: SearchContainer[Node], | |
| beam_depth: Optional[int], | |
| node_budget: Optional[int], | |
| beam_width: Optional[int], | |
| ) -> Sequence[Node]: | |
| if beam_depth is not None and node.level >= beam_depth: | |
| logger.info( | |
| f"Skipping expansion. beam_depth={beam_depth} >= node.level={node.level}" | |
| ) | |
| return [] | |
| successors = successor_fn(node) | |
| if beam_width is not None and len(successors) > beam_width: | |
| successors = successors[:beam_width] | |
| if node_budget is not None: | |
| budget_remaining = node_budget - len(visited_nodes) - len(unvisited_nodes) | |
| if budget_remaining <= 0: | |
| logger.info( | |
| f"Skipping expansion. budget_used={len(visited_nodes) + len(unvisited_nodes)} >= node_budget={node_budget}" | |
| ) | |
| return [] | |
| successors = successors[:budget_remaining] | |
| return successors | |
| def expand_node( | |
| self, node: Node | |
| ) -> Sequence[Node]: | |
| return self._expand_node( | |
| node=node, | |
| successor_fn=self.successor_fn, | |
| visited_nodes=self.visited_nodes, | |
| unvisited_nodes=self.unvisited_nodes, | |
| beam_depth=self.beam_depth, | |
| node_budget=self.node_budget, | |
| beam_width=self.beam_width, | |
| ) | |
| def step(self) -> SearchState: | |
| # Check that we have no exceeded the search budget. | |
| if self.node_budget is not None: | |
| budget_expended = len(self.visited_nodes) + len(self.unvisited_nodes) | |
| budget_remaining = self.node_budget - budget_expended | |
| if budget_remaining <= 0: | |
| logger.info( | |
| f"Terminating search. Budget exceeded. budget_expended={len(self.visited_nodes)} + {len(self.unvisited_nodes)}" | |
| f" >= node_budget={self.node_budget}" | |
| ) | |
| return SearchState.BUDGET_EXCEEDED | |
| if self.unvisited_nodes: | |
| node = self.unvisited_nodes.peek_left() | |
| assert node is not None | |
| node = self.unvisited_nodes.popleft() | |
| self.visited_nodes.append(node) | |
| logger.info( | |
| f"Visiting node. node.level={node.level} plan_length={len(node.get_reasoning_trace())} queue_size={len(self.unvisited_nodes)}" | |
| ) | |
| # Check if the node is a terminal node, and if so, terminate the step. | |
| # We do not expand terminal nodes. | |
| node.is_terminal = (is_terminal := self.check_is_terminal_node(node)) | |
| if is_terminal: | |
| logger.info(f"Goal found at level {node.level}") | |
| return SearchState.STEP_COMPLETE | |
| # Expand the node and add its successors to the unvisited nodes. | |
| successors = self.expand_node(node) | |
| logger.info( | |
| f"Expanded node at level {node.level} into {len(successors)} successors" | |
| ) | |
| for successor in successors: | |
| self.unvisited_nodes.append(successor) | |
| return SearchState.STEP_COMPLETE | |
| else: | |
| return SearchState.FRONTIER_EMPTY | |
| def run(self) -> SearchResult: | |
| logger.info(f"Initial state: {self.unvisited_nodes.peek_left()}") | |
| logger.info("Running search indefinitely until node expansion limit reached.") | |
| while True: | |
| search_state = self.step() | |
| match search_state: | |
| case SearchState.BUDGET_EXCEEDED: | |
| break | |
| case SearchState.FRONTIER_EMPTY: | |
| break | |
| case SearchState.STEP_COMPLETE: | |
| pass | |
| case _: | |
| assert_never(search_state) | |
| return SearchResult( | |
| search_state=search_state, | |
| nodes=self.nodes, | |
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from enum import Enum | |
| from typing import ( | |
| Callable, | |
| Generic, | |
| Iterator, | |
| Optional, | |
| Protocol, | |
| Sequence, | |
| TypeVar, | |
| ) | |
| from loguru import logger | |
| from pydantic import BaseModel, Field | |
| from typing_extensions import Self, assert_never | |
| from ulid import ULID | |
| T = TypeVar("T") | |
class Node(BaseModel):
    """One reasoning step in the search tree, with statistics and a path cost."""

    # Text of the reasoning step held by this node.
    reasoning_step: str
    # Terminal flag; defaults to False.
    is_terminal: bool = False
    # Node this one descends from, or None when there is none. Hidden from repr.
    parent: Optional[Self] = Field(default=None, repr=False)
    # Level value of the node; defaults to 0.
    level: int = 0
    # Identifier created fresh for every node instance.
    ulid: ULID = Field(default_factory=ULID)
    # Visit counter and accumulated value; both start at zero.
    visits: int = 0
    total_value: float = 0.0
    # Cost of the path to this node; GenericSearch._expand_node overwrites it
    # with the parent's path_cost plus the edge cost when a cost function is given.
    path_cost: float = 0.0

    def get_reasoning_trace(self) -> list[Self]:
        """
        Collect the chain of nodes obtained by following ``parent`` links.

        The returned list starts with this node and ends with the node that
        has no parent.
        """
        trace: list[Self] = []
        cursor = self
        while cursor:
            trace.append(cursor)
            cursor = cursor.parent
        return trace
class SuccessorFunction(Protocol):
    """Callable protocol for producing the successors of a node."""

    def __call__(self, state: Node) -> Sequence[Node]:
        """Return the successor nodes of ``state``."""
        ...
class IsTerminalTestFunction(Protocol):
    """Callable protocol for deciding whether a node is terminal."""

    def __call__(self, state: Node) -> bool:
        """Return whether ``state`` is a terminal node."""
        ...
class RankingFunction(Protocol):
    """Callable protocol for scoring a node."""

    def __call__(self, state: Node) -> float:
        """Return the score assigned to ``state``."""
        ...
class SearchContainer(Protocol, Generic[T]):
    """Container protocol for the nodes held during the search."""

    def append(self, item: T) -> None:
        """Add ``item`` to the container."""
        ...

    def popleft(self) -> T:
        """Remove an item from the container and return it."""
        ...

    def __bool__(self) -> bool:
        """Truth value of the container."""
        ...

    def __len__(self) -> int:
        """Number of items in the container."""
        ...

    def peek_left(self) -> Optional[T]:
        """Return an item, or None, without the removal implied by ``popleft``."""
        ...

    def __iter__(self) -> Iterator[T]:
        """Iterate over the items in the container."""
        ...
class SearchState(Enum):
    """Outcome reported by a single search step."""

    # No nodes were left to visit.
    FRONTIER_EMPTY = "FRONTIER_EMPTY"
    # A node was visited; the search can keep going.
    STEP_COMPLETE = "STEP_COMPLETE"
    # The node budget was used up.
    BUDGET_EXCEEDED = "BUDGET_EXCEEDED"
class SearchResult(BaseModel):
    """Result of a search run: the final state plus the nodes it produced."""

    # State reported by the last search step.
    search_state: SearchState
    # Nodes collected by the search.
    nodes: list[Node]
class CostFunction(Protocol):
    """Callable protocol for the cost of moving from one node to another."""

    def __call__(self, from_state: Node, to_state: Node) -> float:
        """Return the cost between ``from_state`` and ``to_state``."""
        ...
class HeuristicFunction(Protocol):
    """Callable protocol for estimating the cost from a node to the goal."""

    def __call__(self, state: Node) -> float:
        """Return the estimated remaining cost for ``state``."""
        ...
| class GenericSearch: | |
| """ | |
| Generic search class that can be used to implement any search algorithm. | |
| Parameters: | |
| ------------ | |
| initial_state: Node | |
| The root node of the search. For example, could contain the initial prompt. | |
| successor_fn: SuccessorFunction | |
| A function that given the current node, returns the successors of the node. | |
| For example, could return the next step in the reasoning process. | |
| check_is_terminal_fn: IsTerminalTestFunction | |
| A function that given the current node, returns whether the node is a terminal node. | |
| For example, could return whether the reasoning step has produced a final answer. | |
| container_factory: Callable[[], SearchContainer[Node]] | |
| A function that returns a new instance of the search container. | |
| For example, could return a queue or a stack. | |
| node_budget: Optional[int] | |
| The maximum number of nodes to visit. | |
| beam_width: Optional[int] | |
| The maximum number of successors to consider at each step. | |
| beam_depth: Optional[int] | |
| The maximum length of any given reasoning trace. | |
| """ | |
| def __init__( | |
| self, | |
| initial_state: Node, | |
| successor_fn: SuccessorFunction, | |
| check_is_terminal_fn: IsTerminalTestFunction, | |
| container_factory: Callable[[], SearchContainer[Node]], | |
| cost_fn: Optional[CostFunction] = None, | |
| heuristic_fn: Optional[HeuristicFunction] = None, | |
| node_budget: Optional[int] = None, | |
| beam_width: Optional[int] = None, | |
| beam_depth: Optional[int] = None, | |
| ): | |
| self.unvisited_nodes: SearchContainer[Node] = ( | |
| container_factory() | |
| ) | |
| self.unvisited_nodes.append(initial_state) | |
| self.visited_nodes: list[Node] = [] | |
| self.level = initial_state.level | |
| self.successor_fn = successor_fn | |
| self.check_is_terminal_node = check_is_terminal_fn | |
| self.node_budget = node_budget | |
| self.beam_width = beam_width | |
| self.beam_depth = beam_depth | |
| self.cost_fn = cost_fn | |
| self.heuristic_fn = heuristic_fn | |
| @property | |
| def nodes(self) -> list[Node]: | |
| """ | |
| Returns every node created during the search. | |
| This includes nodes that are in the queue and nodes that have been visited. | |
| """ | |
| return self.visited_nodes + list(self.unvisited_nodes) | |
| @staticmethod | |
| def _expand_node( | |
| node: Node, | |
| successor_fn: SuccessorFunction, | |
| visited_nodes: list[Node], | |
| unvisited_nodes: SearchContainer[Node], | |
| beam_depth: Optional[int], | |
| node_budget: Optional[int], | |
| beam_width: Optional[int], | |
| cost_fn: Optional[CostFunction] = None, | |
| ) -> Sequence[Node]: | |
| if beam_depth is not None and node.level >= beam_depth: | |
| logger.info( | |
| f"Skipping expansion. beam_depth={beam_depth} >= node.level={node.level}" | |
| ) | |
| return [] | |
| successors = successor_fn(node) | |
| if cost_fn is not None: | |
| for successor in successors: | |
| edge_cost = cost_fn(node, successor) | |
| successor.path_cost = node.path_cost + edge_cost | |
| if beam_width is not None and len(successors) > beam_width: | |
| successors = successors[:beam_width] | |
| if node_budget is not None: | |
| budget_remaining = node_budget - len(visited_nodes) - len(unvisited_nodes) | |
| if budget_remaining <= 0: | |
| logger.info( | |
| f"Skipping expansion. budget_used={len(visited_nodes) + len(unvisited_nodes)} >= node_budget={node_budget}" | |
| ) | |
| return [] | |
| successors = successors[:budget_remaining] | |
| return successors | |
| def expand_node( | |
| self, node: Node | |
| ) -> Sequence[Node]: | |
| return self._expand_node( | |
| node=node, | |
| successor_fn=self.successor_fn, | |
| visited_nodes=self.visited_nodes, | |
| unvisited_nodes=self.unvisited_nodes, | |
| beam_depth=self.beam_depth, | |
| node_budget=self.node_budget, | |
| beam_width=self.beam_width, | |
| cost_fn=self.cost_fn, | |
| ) | |
| def step(self) -> SearchState: | |
| # Check that we have no exceeded the search budget. | |
| if self.node_budget is not None: | |
| budget_expended = len(self.visited_nodes) + len(self.unvisited_nodes) | |
| budget_remaining = self.node_budget - budget_expended | |
| if budget_remaining <= 0: | |
| logger.info( | |
| f"Terminating search. Budget exceeded. budget_expended={len(self.visited_nodes)} + {len(self.unvisited_nodes)}" | |
| f" >= node_budget={self.node_budget}" | |
| ) | |
| return SearchState.BUDGET_EXCEEDED | |
| if self.unvisited_nodes: | |
| node = self.unvisited_nodes.peek_left() | |
| assert node is not None | |
| node = self.unvisited_nodes.popleft() | |
| self.visited_nodes.append(node) | |
| logger.info( | |
| f"Visiting node. node.level={node.level} plan_length={len(node.get_reasoning_trace())} queue_size={len(self.unvisited_nodes)}" | |
| ) | |
| # Check if the node is a terminal node, and if so, terminate the step. | |
| # We do not expand terminal nodes. | |
| node.is_terminal = (is_terminal := self.check_is_terminal_node(node)) | |
| if is_terminal: | |
| logger.info(f"Goal found at level {node.level}") | |
| return SearchState.STEP_COMPLETE | |
| # Expand the node and add its successors to the unvisited nodes. | |
| successors = self.expand_node(node) | |
| logger.info( | |
| f"Expanded node at level {node.level} into {len(successors)} successors" | |
| ) | |
| for successor in successors: | |
| self.unvisited_nodes.append(successor) | |
| return SearchState.STEP_COMPLETE | |
| else: | |
| return SearchState.FRONTIER_EMPTY | |
| def run(self) -> SearchResult: | |
| logger.info(f"Initial state: {self.unvisited_nodes.peek_left()}") | |
| logger.info("Running search indefinitely until node expansion limit reached.") | |
| while True: | |
| search_state = self.step() | |
| match search_state: | |
| case SearchState.BUDGET_EXCEEDED: | |
| break | |
| case SearchState.FRONTIER_EMPTY: | |
| break | |
| case SearchState.STEP_COMPLETE: | |
| pass | |
| case _: | |
| assert_never(search_state) | |
| return SearchResult( | |
| search_state=search_state, | |
| nodes=self.nodes, | |
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from enum import Enum | |
| from typing import ( | |
| Callable, | |
| Generic, | |
| Iterator, | |
| Optional, | |
| Protocol, | |
| Sequence, | |
| TypeVar, | |
| ) | |
| from loguru import logger | |
| from pydantic import BaseModel, Field | |
| from typing_extensions import Self, assert_never | |
| from ulid import ULID | |
| from math import sqrt, log | |
| T = TypeVar("T") | |
class Node(BaseModel):
    """One reasoning step in the search tree, with visit statistics for UCB1 scoring."""

    # Text of the reasoning step held by this node.
    reasoning_step: str
    # Terminal flag; defaults to False.
    is_terminal: bool = False
    # Node this one descends from, or None when there is none. Hidden from repr.
    parent: Optional[Self] = Field(default=None, repr=False)
    # Level value of the node; defaults to 0.
    level: int = 0
    # Identifier created fresh for every node instance.
    ulid: ULID = Field(default_factory=ULID)
    # Number of backpropagated updates and the sum of their values.
    visits: int = 0
    total_value: float = 0.0

    def get_reasoning_trace(self) -> list[Self]:
        """
        Get the reasoning trace represented by the node.

        The returned list starts with this node and follows ``parent`` links
        until a node without a parent is reached.
        """
        lineage: list[Self] = []
        current = self
        while current is not None:
            lineage.append(current)
            current = current.parent
        return lineage

    def ucb1_score(self, exploration_weight: float = sqrt(2)) -> float:
        """
        Calculate the UCB1 score for this node.

        Returns infinity for a node without visits or without a parent.
        Otherwise the score is the mean value plus
        ``exploration_weight * sqrt(log(parent.visits) / visits)``.
        """
        if self.visits <= 0 or self.parent is None:
            return float("inf")
        exploitation = self.total_value / self.visits
        # log() rejects non-positive input; a parent that was never counted is
        # treated as visited once, which makes the exploration term zero.
        parent_visits = max(self.parent.visits, 1)
        exploration = exploration_weight * sqrt(log(parent_visits) / self.visits)
        return exploitation + exploration

    def backpropagate(self, value: float) -> None:
        """Add one visit and ``value`` to this node and to every ancestor."""
        current = self
        while current is not None:
            current.visits += 1
            current.total_value += value
            current = current.parent
class SuccessorFunction(Protocol):
    """Callable protocol for producing the successors of a node."""

    def __call__(self, state: Node) -> Sequence[Node]:
        """Return the successor nodes of ``state``."""
        ...
class IsTerminalTestFunction(Protocol):
    """Callable protocol for deciding whether a node is terminal."""

    def __call__(self, state: Node) -> bool:
        """Return whether ``state`` is a terminal node."""
        ...
class RankingFunction(Protocol):
    """Callable protocol for scoring a node."""

    def __call__(self, state: Node) -> float:
        """Return the score assigned to ``state``."""
        ...
class SearchContainer(Protocol, Generic[T]):
    """Container protocol for the nodes held during the search."""

    def append(self, item: T) -> None:
        """Add ``item`` to the container."""
        ...

    def popleft(self) -> T:
        """Remove an item from the container and return it."""
        ...

    def __bool__(self) -> bool:
        """Truth value of the container."""
        ...

    def __len__(self) -> int:
        """Number of items in the container."""
        ...

    def peek_left(self) -> Optional[T]:
        """Return an item, or None, without the removal implied by ``popleft``."""
        ...

    def __iter__(self) -> Iterator[T]:
        """Iterate over the items in the container."""
        ...
class SearchState(Enum):
    """Outcome reported by a single search step."""

    # No nodes were left to visit.
    FRONTIER_EMPTY = "FRONTIER_EMPTY"
    # A node was visited; the search can keep going.
    STEP_COMPLETE = "STEP_COMPLETE"
    # The node budget was used up.
    BUDGET_EXCEEDED = "BUDGET_EXCEEDED"
class SearchResult(BaseModel):
    """Result of a search run: the final state plus the nodes it produced."""

    # State reported by the last search step.
    search_state: SearchState
    # Nodes collected by the search.
    nodes: list[Node]
class SimulationFunction(Protocol):
    """Callable protocol for simulating from a node to obtain an estimated value."""

    def __call__(self, state: Node) -> float:
        """Return the estimated value obtained by simulating from ``state``."""
        ...
| class GenericSearch: | |
| """ | |
| Generic search class that can be used to implement any search algorithm. | |
| Parameters: | |
| ------------ | |
| initial_state: Node | |
| The root node of the search. For example, could contain the initial prompt. | |
| successor_fn: SuccessorFunction | |
| A function that given the current node, returns the successors of the node. | |
| For example, could return the next step in the reasoning process. | |
| check_is_terminal_fn: IsTerminalTestFunction | |
| A function that given the current node, returns whether the node is a terminal node. | |
| For example, could return whether the reasoning step has produced a final answer. | |
| container_factory: Callable[[], SearchContainer[Node]] | |
| A function that returns a new instance of the search container. | |
| For example, could return a queue or a stack. | |
| simulation_fn: Optional[SimulationFunction] | |
| A function that simulates from a given node to get an estimated value. | |
| node_budget: Optional[int] | |
| The maximum number of nodes to visit. | |
| beam_width: Optional[int] | |
| The maximum number of successors to consider at each step. | |
| beam_depth: Optional[int] | |
| The maximum length of any given reasoning trace. | |
| """ | |
| def __init__( | |
| self, | |
| initial_state: Node, | |
| successor_fn: SuccessorFunction, | |
| check_is_terminal_fn: IsTerminalTestFunction, | |
| container_factory: Callable[[], SearchContainer[Node]], | |
| simulation_fn: Optional[SimulationFunction] = None, | |
| node_budget: Optional[int] = None, | |
| beam_width: Optional[int] = None, | |
| beam_depth: Optional[int] = None, | |
| ): | |
| self.unvisited_nodes: SearchContainer[Node] = ( | |
| container_factory() | |
| ) | |
| self.unvisited_nodes.append(initial_state) | |
| self.visited_nodes: list[Node] = [] | |
| self.level = initial_state.level | |
| self.successor_fn = successor_fn | |
| self.check_is_terminal_node = check_is_terminal_fn | |
| self.node_budget = node_budget | |
| self.beam_width = beam_width | |
| self.beam_depth = beam_depth | |
| self.simulation_fn = simulation_fn | |
| @property | |
| def nodes(self) -> list[Node]: | |
| """ | |
| Returns every node created during the search. | |
| This includes nodes that are in the queue and nodes that have been visited. | |
| """ | |
| return self.visited_nodes + list(self.unvisited_nodes) | |
| @staticmethod | |
| def _expand_node( | |
| node: Node, | |
| successor_fn: SuccessorFunction, | |
| visited_nodes: list[Node], | |
| unvisited_nodes: SearchContainer[Node], | |
| beam_depth: Optional[int], | |
| node_budget: Optional[int], | |
| beam_width: Optional[int], | |
| ) -> Sequence[Node]: | |
| if beam_depth is not None and node.level >= beam_depth: | |
| logger.info( | |
| f"Skipping expansion. beam_depth={beam_depth} >= node.level={node.level}" | |
| ) | |
| return [] | |
| successors = successor_fn(node) | |
| if beam_width is not None and len(successors) > beam_width: | |
| successors = successors[:beam_width] | |
| if node_budget is not None: | |
| budget_remaining = node_budget - len(visited_nodes) - len(unvisited_nodes) | |
| if budget_remaining <= 0: | |
| logger.info( | |
| f"Skipping expansion. budget_used={len(visited_nodes) + len(unvisited_nodes)} >= node_budget={node_budget}" | |
| ) | |
| return [] | |
| successors = successors[:budget_remaining] | |
| return successors | |
| def expand_node( | |
| self, node: Node | |
| ) -> Sequence[Node]: | |
| return self._expand_node( | |
| node=node, | |
| successor_fn=self.successor_fn, | |
| visited_nodes=self.visited_nodes, | |
| unvisited_nodes=self.unvisited_nodes, | |
| beam_depth=self.beam_depth, | |
| node_budget=self.node_budget, | |
| beam_width=self.beam_width, | |
| ) | |
| def step(self) -> SearchState: | |
| # Check that we have no exceeded the search budget. | |
| if self.node_budget is not None: | |
| budget_expended = len(self.visited_nodes) + len(self.unvisited_nodes) | |
| budget_remaining = self.node_budget - budget_expended | |
| if budget_remaining <= 0: | |
| logger.info( | |
| f"Terminating search. Budget exceeded. budget_expended={len(self.visited_nodes)} + {len(self.unvisited_nodes)}" | |
| f" >= node_budget={self.node_budget}" | |
| ) | |
| return SearchState.BUDGET_EXCEEDED | |
| if self.unvisited_nodes: | |
| node = self.unvisited_nodes.peek_left() | |
| assert node is not None | |
| node = self.unvisited_nodes.popleft() | |
| self.visited_nodes.append(node) | |
| logger.info( | |
| f"Visiting node. node.level={node.level} plan_length={len(node.get_reasoning_trace())} queue_size={len(self.unvisited_nodes)}" | |
| ) | |
| # Check if the node is a terminal node, and if so, terminate the step. | |
| # We do not expand terminal nodes. | |
| node.is_terminal = (is_terminal := self.check_is_terminal_node(node)) | |
| # Add simulation and backpropagation for MCTS | |
| if self.simulation_fn and not is_terminal: | |
| value = self.simulation_fn(node) | |
| node.backpropagate(value) | |
| if is_terminal: | |
| logger.info(f"Goal found at level {node.level}") | |
| return SearchState.STEP_COMPLETE | |
| # Expand the node and add its successors to the unvisited nodes. | |
| successors = self.expand_node(node) | |
| logger.info( | |
| f"Expanded node at level {node.level} into {len(successors)} successors" | |
| ) | |
| for successor in successors: | |
| self.unvisited_nodes.append(successor) | |
| return SearchState.STEP_COMPLETE | |
| else: | |
| return SearchState.FRONTIER_EMPTY | |
| def run(self) -> SearchResult: | |
| logger.info(f"Initial state: {self.unvisited_nodes.peek_left()}") | |
| logger.info("Running search indefinitely until node expansion limit reached.") | |
| while True: | |
| search_state = self.step() | |
| match search_state: | |
| case SearchState.BUDGET_EXCEEDED: | |
| break | |
| case SearchState.FRONTIER_EMPTY: | |
| break | |
| case SearchState.STEP_COMPLETE: | |
| pass | |
| case _: | |
| assert_never(search_state) | |
| return SearchResult( | |
| search_state=search_state, | |
| nodes=self.nodes, | |
| ) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment