DynamoDBStatePersistence for pydantic-ai pydantic-graph
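One file, one class: a DynamoDB-backed implementation of pydantic_graph's BaseStatePersistence. Each snapshot becomes a single table item keyed by `run_id` plus a reversed-time ULID sort key, with optional brotli compression and TTL expiry. As a warm-up, here is a minimal sketch of the sort-key trick the class relies on, assuming the `ulid` dependency is the python-ulid package (the same `ULID.from_timestamp` API the module imports):

from datetime import datetime, timedelta

from ulid import ULID


def reversed_ts(dt: datetime) -> int:
    # Seconds remaining until datetime.max; later datetimes yield smaller values.
    # (Sketch: mirrors DynamoDBStatePersistence.get_reversed_timestamp below.)
    return int(round(datetime.max.timestamp() - dt.timestamp()))


earlier = datetime(2025, 1, 1, 12, 0, 0)
later = earlier + timedelta(seconds=10)
# A ULID encodes its timestamp in the most significant bits, so the snapshot
# taken *later* gets the lexicographically *smaller* sort key, and a plain
# ascending DynamoDB query returns snapshots newest-first.
assert str(ULID.from_timestamp(reversed_ts(later))) < str(
    ULID.from_timestamp(reversed_ts(earlier))
)

The full module: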
from __future__ import annotations

import dataclasses
import json
import logging
import time
from contextlib import asynccontextmanager
from datetime import datetime
from enum import Enum
from typing import Any, AsyncIterator, Generic, List, Optional, TypeVar

from ulid import ULID  # type: ignore

try:
    from aiobotocore.session import get_session  # type: ignore
except Exception:  # pragma: no cover

    def get_session():  # type: ignore
        raise RuntimeError("aiobotocore is required for DynamoDBStatePersistence")


try:  # exceptions
    from botocore.exceptions import ClientError  # type: ignore
except Exception:  # pragma: no cover
    ClientError = Exception  # type: ignore

from pydantic import BaseModel, RootModel
from pydantic_graph import GraphRunResult  # type: ignore
from pydantic_graph.nodes import BaseNode

try:
    from pydantic_graph.exceptions import GraphRuntimeError  # type: ignore
except Exception:  # pragma: no cover
    GraphRuntimeError = RuntimeError  # type: ignore

import pydantic_graph.nodes as pydantic_graph_nodes
from pydantic_graph.nodes import End
from pydantic_graph.persistence import (
    BaseStatePersistence,
    EndSnapshot,
    NodeSnapshot,
    Snapshot,
    build_snapshot_list_type_adapter,
)

try:
    from brotli import compress as br_compress  # type: ignore
    from brotli import decompress as br_decompress
except Exception:  # pragma: no cover
    br_compress = None  # type: ignore
    br_decompress = None  # type: ignore

StateT = TypeVar("StateT")
RunEndT = TypeVar("RunEndT")
class DynamoDBStatePersistence(
    BaseStatePersistence[StateT, RunEndT], Generic[StateT, RunEndT]
):
    """DynamoDB-backed persistence for pydantic_graph snapshots.

    Storage model (one item per snapshot):
    - PK: `run_id` (string)
    - SK: `snapshot_id` (string), generated as a reversed-time ULID with a suffix:
        - Node: "{ULID.from_timestamp(get_reversed_timestamp())!s}#node#<node>"
        - Other kinds: "{ULID.from_timestamp(get_reversed_timestamp())!s}#{kind}"
    - Attribute `snapshot_json`: JSON of the Snapshot (NodeSnapshot | EndSnapshot)
    - Attribute `last_updated`: unix seconds
    - Optional attribute `expires_at`: TTL timestamp (unix seconds)

    Notes:
    - Because the sort key embeds a reversed timestamp, a standard ascending
      query returns snapshots newest-first.
    - Methods follow the behavior of FileStatePersistence in pydantic_graph.
    """

    def __init__(
        self,
        table_name: str,
        run_id: str,
        *,
        ttl_seconds: Optional[int] = None,
        table_region: Optional[str] = None,
        table_host: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        use_compression: bool = True,
        running_timeout_seconds: int = 60,
    ) -> None:
        self._run_id: str = run_id
        self._ttl_seconds: Optional[int] = ttl_seconds
        self._table_name = table_name
        self._region = table_region or "us-east-1"
        self._endpoint_url = table_host
        self._aws_access_key_id = aws_access_key_id
        self._aws_secret_access_key = aws_secret_access_key
        self._aws_session_token = aws_session_token
        self._snapshots_type_adapter: Any | None = None
        self._state_type: Any | None = None
        self._run_end_type: Any | None = None
        self._session = get_session()
        self._hash_key_name: str = "run_id"
        self._range_key_name: str = "snapshot_id"
        self._logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
        # Compression is only effective when brotli is importable.
        self._use_compression: bool = use_compression and br_compress is not None
        self._running_timeout_seconds: int = int(running_timeout_seconds)
        self._logger.debug(
            "Initialized table=%s region=%s endpoint=%s run_id=%s ttl=%s",
            self._table_name,
            self._region,
            self._endpoint_url,
            self._run_id,
            self._ttl_seconds,
        )
    # ----- Utils -----

    @staticmethod
    def _to_json_safe(value: Any) -> Any:
        """Recursively convert `value` into something `json.dumps` can handle."""
        if value is None:
            return None
        if isinstance(value, Enum):
            return value.value
        if dataclasses.is_dataclass(value):
            try:
                return DynamoDBStatePersistence._to_json_safe(dataclasses.asdict(value))
            except Exception:
                pass
        if isinstance(value, dict):
            return {
                k: DynamoDBStatePersistence._to_json_safe(v) for k, v in value.items()
            }
        if isinstance(value, (list, tuple, set)):
            return [DynamoDBStatePersistence._to_json_safe(v) for v in value]
        # Pydantic models
        model_dump = getattr(value, "model_dump", None)
        if callable(model_dump):
            try:
                return DynamoDBStatePersistence._to_json_safe(model_dump(mode="json"))
            except Exception:
                return DynamoDBStatePersistence._to_json_safe(model_dump())
        dict_method = getattr(value, "dict", None)
        if callable(dict_method):
            try:
                return DynamoDBStatePersistence._to_json_safe(dict_method())
            except Exception:
                pass
        # Return the value itself if JSON-serializable, else fall back to its
        # string representation.
        try:
            json.dumps(value)
            return value
        except Exception:
            return str(value)

    @staticmethod
    def get_reversed_timestamp(dt: datetime | None = None) -> int:
        """Calculate a reversed timestamp relative to datetime.max.

        Computes the difference in seconds between the maximum representable
        datetime (datetime.max) and `dt` (or the current time if `dt` is None),
        rounded to the nearest integer.

        A reversed timestamp is useful where descending chronological order is
        desired from standard ascending sorts (e.g., database range keys).

        Args:
            dt: The datetime to reverse. If None, datetime.now() is used.

        Returns:
            The number of seconds from the given datetime (or now) until
            datetime.max.
        """
        if dt is None:
            return int(round(datetime.max.timestamp() - datetime.now().timestamp(), 0))
        else:
            return int(round(datetime.max.timestamp() - dt.timestamp(), 0))

    @staticmethod
    def generate_snapshot_id(node_id: str) -> str:
        ulid_text = str(
            ULID.from_timestamp(DynamoDBStatePersistence.get_reversed_timestamp())
        )
        return f"{ulid_text}#{node_id}"
    # ----- Public API -----

    async def load_all(self) -> List[Snapshot[StateT, RunEndT]]:
        items = await self._query_all()
        snapshots: list[Any] = []
        for item in items:
            # Prefer typed deserialization when snapshot_json is present
            if item.get("snapshot_json") and self._snapshots_type_adapter is not None:
                snap = self._deserialize_snapshot_item(item)
                if snap is not None:
                    snapshots.append(snap)
                    continue
            # Fallback: legacy dict format expected by tests
            kind = item.get("kind", {}).get("S")
            snap_id = item.get(self._range_key_name, {}).get("S")
            status = item.get("status", {}).get("S")
            node_name = item.get("node", {}).get("S")
            duration_value = item.get("duration", {}).get("N")
            duration = float(duration_value) if duration_value is not None else None
            # Legacy tests store state/result under state_json/result_json
            state_json_s = item.get("state_json", {}).get("S")
            result_json_s = item.get("result_json", {}).get("S")
            snapshot: dict[str, Any] = {"kind": kind, "id": snap_id}
            if status is not None:
                snapshot["status"] = status
            if node_name is not None:
                snapshot["node"] = node_name
            if duration is not None:
                snapshot["duration"] = duration
            if state_json_s is not None:
                try:
                    snapshot["state"] = json.loads(state_json_s)
                except Exception:
                    snapshot["state"] = state_json_s
            if result_json_s is not None:
                try:
                    snapshot["result"] = json.loads(result_json_s)
                except Exception:
                    snapshot["result"] = result_json_s
            snapshots.append(snapshot)
        return snapshots  # type: ignore[return-value]

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        items = await self._query_first()
        if not items:
            return None
        item = items[0]
        sk = item.get(self._range_key_name, {}).get("S")
        if not sk:
            return None
        kind = item.get("kind", {}).get("S")
        status = item.get("status", {}).get("S")
        # If running and stale, requeue as pending and increment retries; otherwise skip
        if status == "running":
            start_raw = item.get("start_ts", {}).get("N")
            try:
                start_val = float(start_raw) if start_raw is not None else None
            except Exception:
                start_val = None
            # Still within the timeout: do nothing. Stale: requeue as pending
            # and increment retries.
            if (
                start_val is not None
                and (time.time() - start_val) > self._running_timeout_seconds
            ):
                async with self._session.create_client(
                    "dynamodb", **self._build_client_kwargs()
                ) as client:
                    await client.update_item(
                        TableName=self._table_name,
                        Key={
                            self._hash_key_name: {"S": self._run_id},
                            self._range_key_name: {"S": sk},
                        },
                        UpdateExpression="SET #st = :pending, #start = :start ADD #ret :one",
                        ExpressionAttributeNames={
                            "#st": "status",
                            "#start": "start_ts",
                            "#ret": "retries",
                        },
                        ExpressionAttributeValues={
                            ":pending": {"S": "pending"},
                            ":start": {"N": str(round(time.time(), 0))},
                            ":one": {"N": "1"},
                        },
                    )
                status = "pending"
            else:
                return None
        # If a node errored, set it back to pending and increment retries, then proceed
        if kind == "node" and status == "error":
            async with self._session.create_client(
                "dynamodb", **self._build_client_kwargs()
            ) as client:
                await client.update_item(
                    TableName=self._table_name,
                    Key={
                        self._hash_key_name: {"S": self._run_id},
                        self._range_key_name: {"S": sk},
                    },
                    UpdateExpression="SET #st = :pending, #start = :start ADD #ret :one",
                    ExpressionAttributeNames={
                        "#st": "status",
                        "#start": "start_ts",
                        "#ret": "retries",
                    },
                    ExpressionAttributeValues={
                        ":pending": {"S": "pending"},
                        ":start": {"N": str(round(time.time(), 0))},
                        ":one": {"N": "1"},
                    },
                )
            status = "pending"
        if kind != "node" or status != "pending":
            if kind == "end":
                raise GraphRuntimeError("graph already executed")
            return None
        # Convert to a proper NodeSnapshot when possible, else return a legacy dict
        if item.get("snapshot_json") and self._snapshots_type_adapter is not None:
            snapshot = self._deserialize_snapshot_item(item)
            if isinstance(snapshot, NodeSnapshot):
                snapshot.id = sk
                snapshot.status = "pending"
                return snapshot
            return None
        # Legacy dict response for tests
        node_name = item.get("node", {}).get("S")
        state_s_legacy = item.get("state_json", {}).get("S")
        snapshot_dict: dict[str, Any] = {"kind": kind, "id": sk, "status": status}
        if node_name is not None:
            snapshot_dict["node"] = node_name
        if state_s_legacy is not None:
            try:
                snapshot_dict["state"] = json.loads(state_s_legacy)
            except Exception:
                snapshot_dict["state"] = state_s_legacy
        return snapshot_dict  # type: ignore[return-value]
    async def get_graph_result(self) -> GraphRunResult[BaseModel, BaseModel] | None:
        """Return the final GraphRunResult if the first item is an 'end' snapshot, else None.

        Reads only one record (the first by sort key) for efficiency; thanks to
        the reversed-timestamp sort key, that record is the newest snapshot.
        """
        items = await self._query_first()
        if not items:
            return None
        item = items[0]
        kind = item.get("kind", {}).get("S")
        if kind != "end":
            return None
        # Decode state/result (brotli-compressed bytes or plain JSON strings)
        state_b = item.get("state", {}).get("B")
        state_s = item.get("state", {}).get("S") if state_b is None else None
        result_b = item.get("result", {}).get("B")
        result_s = item.get("result", {}).get("S") if result_b is None else None
        decoded_state: Any = None
        decoded_result: Any = None
        if self._use_compression and state_b is not None and br_decompress is not None:
            try:
                decoded_state = json.loads(br_decompress(state_b).decode("utf-8"))
            except Exception:
                decoded_state = None
        elif state_s is not None:
            try:
                decoded_state = json.loads(state_s)
            except Exception:
                decoded_state = state_s
        if self._use_compression and result_b is not None and br_decompress is not None:
            try:
                decoded_result = json.loads(br_decompress(result_b).decode("utf-8"))
            except Exception:
                decoded_result = None
        elif result_s is not None:
            try:
                decoded_result = json.loads(result_s)
            except Exception:
                decoded_result = result_s
        if decoded_result is None:
            return None

        # Wrap decoded payloads into typed BaseModel instances when possible
        class _AnyStateModel(RootModel[Any]):
            pass

        class _AnyOutputModel(RootModel[Any]):
            pass

        try:
            state_model: BaseModel = (
                self._state_type.model_validate(decoded_state)  # type: ignore[assignment]
                if getattr(self, "_state_type", None) is not None
                and hasattr(self._state_type, "model_validate")
                else _AnyStateModel.model_validate(decoded_state)
            )
        except Exception:
            state_model = _AnyStateModel.model_validate(decoded_state)
        try:
            output_model: BaseModel = (
                self._run_end_type.model_validate(decoded_result)  # type: ignore[assignment]
                if getattr(self, "_run_end_type", None) is not None
                and hasattr(self._run_end_type, "model_validate")
                else _AnyOutputModel.model_validate(decoded_result)
            )
        except Exception:
            output_model = _AnyOutputModel.model_validate(decoded_result)
        return GraphRunResult(output=output_model, state=state_model, persistence=self)
    async def snapshot_node(self, *args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        """Snapshot the state of a graph before running a node.

        Supports both the pydantic-graph signature `(state, next_node)` and a
        legacy signature `(snapshot_id, state, next_node)` used in local tests.
        """
        snapshot_id: str | None = None
        state: Any
        next_node: Any
        if len(args) == 2:
            state, next_node = args
            try:
                snapshot_id = next_node.get_snapshot_id()
            except Exception:
                snapshot_id = None
        elif len(args) == 3:
            snapshot_id, state, next_node = args
        else:
            raise TypeError(
                "snapshot_node expected (state, next_node) or (snapshot_id, state, next_node)"
            )
        if snapshot_id is None:
            try:
                snapshot_id = next_node.get_snapshot_id()
            except Exception:
                raise ValueError("next_node must provide get_snapshot_id()")
        try:
            node_id = next_node.get_node_id()
        except Exception:
            raise ValueError("next_node must provide get_node_id()")
        snapshot = {
            "kind": "node",
            "id": snapshot_id,
            "status": "created",
            "node": node_id,
            "state": DynamoDBStatePersistence._to_json_safe(state),
        }
        await self._put_snapshot_async(snapshot_id, snapshot)

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:  # type: ignore[override]
        exists = await self._get_item_by_snapshot_id(snapshot_id)
        if exists is not None:
            return
        snapshot = {
            "kind": "node",
            "id": snapshot_id,
            "status": "created",
            "state": DynamoDBStatePersistence._to_json_safe(state),
        }
        await self._put_snapshot_async(snapshot_id, snapshot)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:  # type: ignore[override]
        snapshot_id = end.get_snapshot_id()
        snapshot = {
            "kind": "end",
            "id": snapshot_id,
            "state": DynamoDBStatePersistence._to_json_safe(state),
            "result": DynamoDBStatePersistence._to_json_safe(end.data),
        }
        await self._put_snapshot_async(snapshot_id, snapshot)
    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:  # type: ignore[override]
        start = time.perf_counter()
        await self._set_running(snapshot_id)
        try:
            yield
        except Exception as e:
            duration = time.perf_counter() - start
            exception = (
                f"{e.__class__.__name__}: {e.message}"
                if hasattr(e, "message")
                else str(e)
            )
            await self._update_status_and_duration(
                snapshot_id, "error", duration, exception
            )
            raise
        else:
            duration = time.perf_counter() - start
            await self._update_status_and_duration(
                snapshot_id, "success", duration, None
            )
    # ----- helpers -----

    def _build_client_kwargs(self) -> dict[str, Any]:
        kwargs: dict[str, Any] = {"region_name": self._region}
        if self._endpoint_url:
            kwargs["endpoint_url"] = self._endpoint_url
            # DynamoDB Local accepts any credentials, but botocore still
            # requires something to be present. Provide dummy creds if missing.
            if not (self._aws_access_key_id and self._aws_secret_access_key):
                kwargs.update(
                    {
                        "aws_access_key_id": "dummy",
                        "aws_secret_access_key": "dummy",
                        "aws_session_token": "dummy",
                    }
                )
        if self._aws_access_key_id and self._aws_secret_access_key:
            kwargs.update(
                {
                    "aws_access_key_id": self._aws_access_key_id,
                    "aws_secret_access_key": self._aws_secret_access_key,
                }
            )
        if self._aws_session_token and "aws_session_token" not in kwargs:
            kwargs["aws_session_token"] = self._aws_session_token
        return kwargs
    # ----- Internal serialization helpers -----

    async def _update_status_and_duration(
        self,
        snapshot_id: str,
        status: str,
        duration: float,
        exception: str | None = None,
    ) -> None:
        update_expression = "SET #st = :status, #dur = :duration"
        names = {"#st": "status", "#dur": "duration"}
        values: dict[str, Any] = {
            ":status": {"S": status},
            ":duration": {"N": str(float(duration))},
        }
        if exception is not None:
            update_expression += ", #exc = :exception"
            names["#exc"] = "exception"
            values[":exception"] = {"S": exception}
        async with self._session.create_client(
            "dynamodb", **self._build_client_kwargs()
        ) as client:
            await client.update_item(
                TableName=self._table_name,
                Key={
                    self._hash_key_name: {"S": self._run_id},
                    self._range_key_name: {"S": snapshot_id},
                },
                UpdateExpression=update_expression,
                ExpressionAttributeNames=names,
                ExpressionAttributeValues=values,
            )

    async def _set_running(self, snapshot_id: str) -> None:
        async with self._session.create_client(
            "dynamodb", **self._build_client_kwargs()
        ) as client:
            await client.update_item(
                TableName=self._table_name,
                Key={
                    self._hash_key_name: {"S": self._run_id},
                    self._range_key_name: {"S": snapshot_id},
                },
                UpdateExpression="SET #st = :running, #start = :start",
                ConditionExpression="#st IN (:created, :pending)",
                ExpressionAttributeNames={"#st": "status", "#start": "start_ts"},
                ExpressionAttributeValues={
                    ":running": {"S": "running"},
                    ":created": {"S": "created"},
                    ":pending": {"S": "pending"},
                    ":start": {"N": str(round(time.time(), 0))},
                },
            )
    async def _put_snapshot_async(self, sk: str, snapshot: Any) -> None:
        snapshot_json_bytes: bytes | None = None
        if isinstance(snapshot, dict):
            safe = self._to_json_safe(snapshot)
        else:
            # Dataclass snapshots (NodeSnapshot / EndSnapshot)
            try:
                kind_val = getattr(snapshot, "kind", None)
                node_name_val = getattr(
                    getattr(snapshot, "node", None), "get_node_id", lambda: None
                )()
                status_val = getattr(snapshot, "status", None)
                duration_val = getattr(snapshot, "duration", None)
                state_val = getattr(snapshot, "state", None)
                result_val = getattr(getattr(snapshot, "result", None), "data", None)
            except Exception:
                kind_val = None
                node_name_val = None
                status_val = None
                duration_val = None
                state_val = None
                result_val = None
            safe = {
                "kind": kind_val,
                "node": node_name_val,
                "status": status_val,
                "duration": duration_val,
                "state": self._to_json_safe(state_val),
                "result": self._to_json_safe(result_val),
            }
            if self._snapshots_type_adapter is not None:
                try:
                    snapshot_json_bytes = self._snapshots_type_adapter.dump_json(
                        [snapshot], indent=2
                    )
                except Exception:
                    snapshot_json_bytes = None
        now = int(time.time())
        expires_at_value = (now + int(self._ttl_seconds)) if self._ttl_seconds else None
        async with self._session.create_client(
            "dynamodb", **self._build_client_kwargs()
        ) as client:
            item: dict[str, Any] = {
                self._hash_key_name: {"S": self._run_id},
                self._range_key_name: {"S": sk},
                "kind": {"S": str(safe.get("kind", ""))},
                **(
                    {"node": {"S": str(safe.get("node"))}}
                    if safe.get("node") is not None
                    else {}
                ),
                **(
                    {"status": {"S": str(safe.get("status"))}}
                    if safe.get("status") is not None
                    else {}
                ),
                **(
                    {"duration": {"N": str(float(safe.get("duration")))}}
                    if safe.get("duration") is not None
                    else {}
                ),
                "last_updated": {"N": str(now)},
            }
            if safe.get("state") is not None:
                try:
                    raw = json.dumps(safe.get("state"), ensure_ascii=False).encode(
                        "utf-8"
                    )
                    if self._use_compression and br_compress is not None:
                        item["state"] = {"B": br_compress(raw)}
                    else:
                        item["state"] = {"S": raw.decode("utf-8")}
                except Exception:
                    item["state"] = {"S": json.dumps(safe.get("state"))}
            if safe.get("result") is not None:
                try:
                    rawr = json.dumps(safe.get("result"), ensure_ascii=False).encode(
                        "utf-8"
                    )
                    if self._use_compression and br_compress is not None:
                        item["result"] = {"B": br_compress(rawr)}
                    else:
                        item["result"] = {"S": rawr.decode("utf-8")}
                except Exception:
                    item["result"] = {"S": json.dumps(safe.get("result"))}
            # Store the full snapshot JSON when available
            try:
                if isinstance(snapshot, dict):
                    json_bytes = json.dumps(safe, ensure_ascii=False).encode("utf-8")
                else:
                    json_bytes = snapshot_json_bytes
                if json_bytes is not None:
                    if self._use_compression and br_compress is not None:
                        item["snapshot_json"] = {"B": br_compress(json_bytes)}
                    else:
                        item["snapshot_json"] = {"S": json_bytes.decode("utf-8")}
            except Exception:
                pass
            if expires_at_value is not None:
                item["expires_at"] = {"N": str(expires_at_value)}
            await client.put_item(TableName=self._table_name, Item=item)
    async def _query_all(self) -> list[dict]:
        async with self._session.create_client(
            "dynamodb", **self._build_client_kwargs()
        ) as client:
            self._logger.debug(
                "QUERY async: table=%s run_id=%s", self._table_name, self._run_id
            )
            items: list[dict] = []
            exclusive_start_key = None
            while True:
                params = {
                    "TableName": self._table_name,
                    "KeyConditionExpression": "#hk = :rid",
                    "ExpressionAttributeNames": {"#hk": self._hash_key_name},
                    "ExpressionAttributeValues": {":rid": {"S": self._run_id}},
                    "ConsistentRead": True,
                }
                if exclusive_start_key is not None:
                    params["ExclusiveStartKey"] = exclusive_start_key
                resp = await client.query(**params)  # type: ignore[arg-type]
                items.extend(resp.get("Items", []))
                exclusive_start_key = resp.get("LastEvaluatedKey")
                if not exclusive_start_key:
                    break
            return items

    async def _query_first(self) -> list[dict]:
        async with self._session.create_client(
            "dynamodb", **self._build_client_kwargs()
        ) as client:
            params: dict[str, Any] = {
                "TableName": self._table_name,
                "KeyConditionExpression": "#hk = :rid",
                "ExpressionAttributeNames": {"#hk": self._hash_key_name},
                "ExpressionAttributeValues": {":rid": {"S": self._run_id}},
                "Limit": 1,
                "ScanIndexForward": True,
                "ConsistentRead": True,
            }
            resp = await client.query(**params)  # type: ignore[arg-type]
            return resp.get("Items", [])

    async def _get_item_by_snapshot_id(self, snapshot_id: str) -> dict | None:
        async with self._session.create_client(
            "dynamodb", **self._build_client_kwargs()
        ) as client:
            resp = await client.get_item(
                TableName=self._table_name,
                Key={
                    self._hash_key_name: {"S": self._run_id},
                    self._range_key_name: {"S": snapshot_id},
                },
                ConsistentRead=True,
            )
            return resp.get("Item")
    # ----- Types integration -----

    def should_set_types(self) -> bool:  # type: ignore[override]
        return self._snapshots_type_adapter is None

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:  # type: ignore[override]
        self._state_type = state_type
        self._run_end_type = run_end_type
        try:
            self._snapshots_type_adapter = build_snapshot_list_type_adapter(
                state_type, run_end_type
            )
        except Exception:
            self._snapshots_type_adapter = None

    # ----- Deserialization helpers -----

    def _deserialize_snapshot_item(
        self, item: dict
    ) -> Snapshot[StateT, RunEndT] | None:
        snap_json_b = item.get("snapshot_json", {}).get("B")
        snap_json_s = (
            item.get("snapshot_json", {}).get("S") if snap_json_b is None else None
        )
        raw_json: bytes | None = None
        if snap_json_b is not None and br_decompress is not None:
            try:
                raw_json = br_decompress(snap_json_b)
            except Exception:
                raw_json = None
        elif snap_json_s is not None:
            raw_json = snap_json_s.encode("utf-8")
        if raw_json is not None and self._snapshots_type_adapter is not None:
            try:
                lst = self._snapshots_type_adapter.validate_json(raw_json)
                if isinstance(lst, list) and lst:
                    return lst[0]
            except Exception:
                pass
        kind = item.get("kind", {}).get("S")
        if kind == "end":
            state_val = self._decode_state_field(item)
            result_val = self._decode_result_field(item)
            end_obj = End(result_val)
            state_model: Any = self._coerce_state_model(state_val)
            return EndSnapshot(state=state_model, result=end_obj)
        elif kind == "node":
            # Without the full snapshot JSON we cannot safely reconstruct the
            # node instance. Returning None makes Graph.iter_from_persistence
            # raise, and the workflow restarts.
            return None
        else:
            return None
    def _decode_state_field(self, item: dict) -> Any:
        state_b = item.get("state", {}).get("B")
        state_s = item.get("state", {}).get("S") if state_b is None else None
        # Legacy key used in tests
        state_json_s = (
            item.get("state_json", {}).get("S")
            if state_s is None and state_b is None
            else None
        )
        if self._use_compression and state_b is not None and br_decompress is not None:
            try:
                return json.loads(br_decompress(state_b).decode("utf-8"))
            except Exception:
                return None
        elif state_s is not None:
            try:
                return json.loads(state_s)
            except Exception:
                return state_s
        elif state_json_s is not None:
            try:
                return json.loads(state_json_s)
            except Exception:
                return state_json_s
        return None

    def _decode_result_field(self, item: dict) -> Any:
        result_b = item.get("result", {}).get("B")
        result_s = item.get("result", {}).get("S") if result_b is None else None
        result_json_s = (
            item.get("result_json", {}).get("S")
            if result_s is None and result_b is None
            else None
        )
        if self._use_compression and result_b is not None and br_decompress is not None:
            try:
                return json.loads(br_decompress(result_b).decode("utf-8"))
            except Exception:
                return None
        elif result_s is not None:
            try:
                return json.loads(result_s)
            except Exception:
                return result_s
        elif result_json_s is not None:
            try:
                return json.loads(result_json_s)
            except Exception:
                return result_json_s
        return None

    def _coerce_state_model(self, state_val: Any) -> Any:
        if getattr(self, "_state_type", None) is not None and hasattr(
            self._state_type, "model_validate"
        ):
            try:
                return self._state_type.model_validate(state_val)  # type: ignore[return-value]
            except Exception:
                return state_val
        return state_val
# Removed workflow-specific reconstruction; persistence must remain generic.

# Monkey-patch pydantic_graph's snapshot-id generator so that every snapshot id
# the library produces uses the reversed-time ULID scheme defined above.
setattr(
    pydantic_graph_nodes,
    "generate_snapshot_id",
    DynamoDBStatePersistence.generate_snapshot_id,
)  # type: ignore
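To run this against an actual table, the class expects a partition key `run_id` (S) and sort key `snapshot_id` (S), with TTL enabled on the optional `expires_at` attribute. A minimal provisioning sketch, assuming DynamoDB Local and a hypothetical table name `graph_snapshots` (neither is prescribed by the gist):

import asyncio

from aiobotocore.session import get_session


async def create_snapshots_table() -> None:
    session = get_session()
    async with session.create_client(
        "dynamodb",
        region_name="us-east-1",
        endpoint_url="http://localhost:8000",  # DynamoDB Local; drop for real AWS
        aws_access_key_id="dummy",  # DynamoDB Local accepts any credentials
        aws_secret_access_key="dummy",
    ) as client:
        await client.create_table(
            TableName="graph_snapshots",  # hypothetical table name
            KeySchema=[
                {"AttributeName": "run_id", "KeyType": "HASH"},
                {"AttributeName": "snapshot_id", "KeyType": "RANGE"},
            ],
            AttributeDefinitions=[
                {"AttributeName": "run_id", "AttributeType": "S"},
                {"AttributeName": "snapshot_id", "AttributeType": "S"},
            ],
            BillingMode="PAY_PER_REQUEST",
        )
        # Expire items via the optional `expires_at` attribute (unix seconds).
        await client.update_time_to_live(
            TableName="graph_snapshots",
            TimeToLiveSpecification={"AttributeName": "expires_at", "Enabled": True},
        )


if __name__ == "__main__":
    asyncio.run(create_snapshots_table())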
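Finally, a minimal end-to-end sketch of wiring the persistence into a graph run. The one-node graph, run id, and table name are illustrative, and it assumes pydantic_graph's `Graph.run(..., persistence=...)` keyword and that `DynamoDBStatePersistence` from the module above is in scope; treat it as a starting point rather than a verified example:

import asyncio
from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class CounterState:
    count: int = 0


@dataclass
class Increment(BaseNode[CounterState, None, int]):
    async def run(self, ctx: GraphRunContext[CounterState]) -> End[int]:
        ctx.state.count += 1
        return End(ctx.state.count)


async def main() -> None:
    persistence = DynamoDBStatePersistence(
        table_name="graph_snapshots",  # hypothetical table from the sketch above
        run_id="run-2025-08-25-001",  # illustrative run id (the partition key)
        table_host="http://localhost:8000",  # DynamoDB Local endpoint
        ttl_seconds=24 * 3600,  # snapshots expire after one day
    )
    graph = Graph(nodes=[Increment])
    result = await graph.run(
        Increment(), state=CounterState(), persistence=persistence
    )
    print(result.output)  # 1 on a fresh run


if __name__ == "__main__":
    asyncio.run(main())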