@flolas
Created August 25, 2025 22:50
DynamoDBStatePersistence for pydantic-ai pydantic-graph
from __future__ import annotations

import dataclasses
import json
import logging
import time
from contextlib import asynccontextmanager
from datetime import datetime
from enum import Enum
from typing import Any, AsyncIterator, Generic, List, Optional, TypeVar

from ulid import ULID  # type: ignore

try:
    from aiobotocore.session import get_session  # type: ignore
except Exception:  # pragma: no cover

    def get_session():  # type: ignore
        raise RuntimeError("aiobotocore is required for DynamoDBStatePersistence")


try:  # exceptions
    from botocore.exceptions import ClientError  # type: ignore
except Exception:  # pragma: no cover
    ClientError = Exception  # type: ignore

from pydantic import BaseModel, RootModel
from pydantic_graph import GraphRunResult  # type: ignore
from pydantic_graph.nodes import BaseNode

try:
    from pydantic_graph.exceptions import GraphRuntimeError  # type: ignore
except Exception:  # pragma: no cover
    GraphRuntimeError = RuntimeError  # type: ignore

import pydantic_graph.nodes as pydantic_graph_nodes
from pydantic_graph.nodes import End
from pydantic_graph.persistence import (
    BaseStatePersistence,
    EndSnapshot,
    NodeSnapshot,
    Snapshot,
    build_snapshot_list_type_adapter,
)

try:
    from brotli import compress as br_compress  # type: ignore
    from brotli import decompress as br_decompress
except Exception:  # pragma: no cover
    br_compress = None  # type: ignore
    br_decompress = None  # type: ignore

StateT = TypeVar("StateT")
RunEndT = TypeVar("RunEndT")

class DynamoDBStatePersistence(
    BaseStatePersistence[StateT, RunEndT], Generic[StateT, RunEndT]
):
    """DynamoDB-backed persistence for pydantic_graph snapshots.

    Storage model (one item per snapshot):
    - PK: `run_id` (string)
    - SK: `snapshot_id` (string), a reversed-timestamp ULID with a suffix:
        - Node: "{ULID.from_timestamp(get_reversed_timestamp())!s}#node#<node>"
        - Other kinds: "{ULID.from_timestamp(get_reversed_timestamp())!s}#{kind}"
    - Attribute `snapshot_json`: JSON of the Snapshot (NodeSnapshot | EndSnapshot)
    - Attribute `last_updated`: unix seconds
    - Optional attribute `expires_at`: TTL timestamp (unix seconds)

    Notes:
    - Because the ULID timestamp is reversed, an ascending sort on the sort key
      returns the newest snapshot first.
    - Methods follow the behavior of FileStatePersistence in pydantic_graph.
    """

    def __init__(
        self,
        table_name: str,
        run_id: str,
        *,
        ttl_seconds: Optional[int] = None,
        table_region: Optional[str] = None,
        table_host: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        use_compression: bool = True,
        running_timeout_seconds: int = 60,
    ) -> None:
        self._run_id: str = run_id
        self._ttl_seconds: Optional[int] = ttl_seconds
        self._table_name = table_name
        self._region = table_region or "us-east-1"
        self._endpoint_url = table_host
        self._aws_access_key_id = aws_access_key_id
        self._aws_secret_access_key = aws_secret_access_key
        self._aws_session_token = aws_session_token
        self._snapshots_type_adapter: Any | None = None
        self._state_type: Any | None = None
        self._run_end_type: Any | None = None
        self._session = get_session()
        self._hash_key_name: str = "run_id"
        self._range_key_name: str = "snapshot_id"
        self._logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
        self._use_compression: bool = use_compression and br_compress is not None
        self._running_timeout_seconds: int = int(running_timeout_seconds)
        self._logger.debug(
            "Initialized table=%s region=%s endpoint=%s run_id=%s ttl=%s",
            self._table_name,
            self._region,
            self._endpoint_url,
            self._run_id,
            self._ttl_seconds,
        )

    # ----- Utils -----
    @staticmethod
    def _to_json_safe(value: Any) -> Any:
        if value is None:
            return None
        if isinstance(value, Enum):
            return value.value
        if dataclasses.is_dataclass(value):
            try:
                return DynamoDBStatePersistence._to_json_safe(dataclasses.asdict(value))
            except Exception:
                pass
        if isinstance(value, dict):
            return {
                k: DynamoDBStatePersistence._to_json_safe(v) for k, v in value.items()
            }
        if isinstance(value, (list, tuple, set)):
            return [DynamoDBStatePersistence._to_json_safe(v) for v in value]
        # Pydantic models
        model_dump = getattr(value, "model_dump", None)
        if callable(model_dump):
            try:
                return DynamoDBStatePersistence._to_json_safe(model_dump(mode="json"))
            except Exception:
                return DynamoDBStatePersistence._to_json_safe(model_dump())
        dict_method = getattr(value, "dict", None)
        if callable(dict_method):
            try:
                return DynamoDBStatePersistence._to_json_safe(dict_method())
            except Exception:
                pass
        # Fallback to string representation
        try:
            json.dumps(value)
            return value
        except Exception:
            return str(value)

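    # For illustration (hypothetical dataclass, not part of the original gist):
    #     @dataclasses.dataclass
    #     class Point:
    #         x: int
    #     DynamoDBStatePersistence._to_json_safe({"p": Point(1)})  # -> {"p": {"x": 1}}
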
    @staticmethod
    def get_reversed_timestamp(dt: datetime | None = None) -> int:
        """Calculates a reversed timestamp relative to datetime.max.

        This function computes the difference in seconds between the maximum
        representable datetime (datetime.max) and the provided datetime `dt`
        (or the current time if `dt` is None). The result is rounded to the
        nearest integer.

        This reversed timestamp can be useful for sorting purposes where
        descending chronological order is desired when using standard ascending
        sorts (e.g., in database range keys).

        Args:
            dt: The datetime object to calculate the reversed timestamp from.
                If None, the current time (datetime.now()) is used.

        Returns:
            An integer representing the number of seconds from the given
            datetime (or now) until datetime.max.
        """
        if dt is None:
            return int(round(datetime.max.timestamp() - datetime.now().timestamp(), 0))
        else:
            return int(round(datetime.max.timestamp() - dt.timestamp(), 0))

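    # Ordering property (illustrative):
    #     t1 = DynamoDBStatePersistence.get_reversed_timestamp(datetime(2025, 1, 1))
    #     t2 = DynamoDBStatePersistence.get_reversed_timestamp(datetime(2025, 1, 2))
    #     assert t2 < t1  # later datetimes yield smaller reversed values,
    #                     # so an ascending sort on them is newest-first
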
    @staticmethod
    def generate_snapshot_id(node_id: str) -> str:
        ulid_text = str(
            ULID.from_timestamp(DynamoDBStatePersistence.get_reversed_timestamp())
        )
        return f"{ulid_text}#{node_id}"

    # ----- Public API -----
    async def load_all(self) -> List[Snapshot[StateT, RunEndT]]:
        items = await self._query_all()
        snapshots: list[Any] = []
        for item in items:
            # Prefer typed deserialization when snapshot_json is present
            if item.get("snapshot_json") and self._snapshots_type_adapter is not None:
                snap = self._deserialize_snapshot_item(item)
                if snap is not None:
                    snapshots.append(snap)
                    continue
            # Fallback legacy dict format expected by tests
            kind = item.get("kind", {}).get("S")
            snap_id = item.get(self._range_key_name, {}).get("S")
            status = item.get("status", {}).get("S")
            node_name = item.get("node", {}).get("S")
            duration_value = item.get("duration", {}).get("N")
            duration = float(duration_value) if duration_value is not None else None
            # Legacy tests store state/result under state_json/result_json
            state_json_s = item.get("state_json", {}).get("S")
            result_json_s = item.get("result_json", {}).get("S")
            snapshot: dict[str, Any] = {"kind": kind, "id": snap_id}
            if status is not None:
                snapshot["status"] = status
            if node_name is not None:
                snapshot["node"] = node_name
            if duration is not None:
                snapshot["duration"] = duration
            if state_json_s is not None:
                try:
                    snapshot["state"] = json.loads(state_json_s)
                except Exception:
                    snapshot["state"] = state_json_s
            if result_json_s is not None:
                try:
                    snapshot["result"] = json.loads(result_json_s)
                except Exception:
                    snapshot["result"] = result_json_s
            snapshots.append(snapshot)
        return snapshots  # type: ignore[return-value]

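    # Note: because the sort key embeds a reversed timestamp, the ascending
    # query above returns snapshots newest-first.
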
    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        items = await self._query_first()
        if not items:
            return None
        item = items[0]
        sk = item.get(self._range_key_name, {}).get("S")
        if not sk:
            return None
        kind = item.get("kind", {}).get("S")
        status = item.get("status", {}).get("S")
        # If running and stale, requeue to pending and increment retries; otherwise, skip
        if status == "running":
            start_raw = item.get("start_ts", {}).get("N")
            try:
                start_val = float(start_raw) if start_raw is not None else None
            except Exception:
                start_val = None
            # If still within timeout, do nothing; otherwise requeue as pending and increment retries
            if (
                start_val is not None
                and (time.time() - start_val) > self._running_timeout_seconds
            ):
                async with self._session.create_client(
                    "dynamodb", **self._build_client_kwargs()
                ) as client:
                    await client.update_item(
                        TableName=self._table_name,
                        Key={
                            self._hash_key_name: {"S": self._run_id},
                            self._range_key_name: {"S": sk},
                        },
                        UpdateExpression="SET #st = :pending, #start = :start ADD #ret :one",
                        ExpressionAttributeNames={
                            "#st": "status",
                            "#start": "start_ts",
                            "#ret": "retries",
                        },
                        ExpressionAttributeValues={
                            ":pending": {"S": "pending"},
                            ":start": {"N": str(round(time.time(), 0))},
                            ":one": {"N": "1"},
                        },
                    )
                status = "pending"
            else:
                return None
        # If error, set to pending and increment retries, then proceed
        if kind == "node" and status == "error":
            async with self._session.create_client(
                "dynamodb", **self._build_client_kwargs()
            ) as client:
                await client.update_item(
                    TableName=self._table_name,
                    Key={
                        self._hash_key_name: {"S": self._run_id},
                        self._range_key_name: {"S": sk},
                    },
                    UpdateExpression="SET #st = :pending, #start = :start ADD #ret :one",
                    ExpressionAttributeNames={
                        "#st": "status",
                        "#start": "start_ts",
                        "#ret": "retries",
                    },
                    ExpressionAttributeValues={
                        ":pending": {"S": "pending"},
                        ":start": {"N": str(round(time.time(), 0))},
                        ":one": {"N": "1"},
                    },
                )
            status = "pending"
        if kind != "node" or status != "pending":
            if kind == "end":
                raise GraphRuntimeError("graph already executed")
            return None
        # Convert to a proper NodeSnapshot when possible, else return legacy dict
        if item.get("snapshot_json") and self._snapshots_type_adapter is not None:
            snapshot = self._deserialize_snapshot_item(item)
            if isinstance(snapshot, NodeSnapshot):
                snapshot.id = sk
                snapshot.status = "pending"
                return snapshot
            return None
        # Legacy dict response for tests
        node_name = item.get("node", {}).get("S")
        state_s_legacy = item.get("state_json", {}).get("S")
        snapshot_dict: dict[str, Any] = {"kind": kind, "id": sk, "status": status}
        if node_name is not None:
            snapshot_dict["node"] = node_name
        if state_s_legacy is not None:
            try:
                snapshot_dict["state"] = json.loads(state_s_legacy)
            except Exception:
                snapshot_dict["state"] = state_s_legacy
        return snapshot_dict  # type: ignore[return-value]

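    # Status lifecycle, as implemented above: snapshot_node() writes items as
    # "created"; load_next() hands out "pending" items, requeueing stale
    # "running" or failed "error" items back to "pending" (incrementing
    # `retries`); record_run() moves an item to "running", then to "success"
    # or "error".
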
    async def get_graph_result(self) -> GraphRunResult[BaseModel, BaseModel] | None:
        """Return the final GraphRunResult if the first item is an 'end' snapshot; else None.

        Reads only one record (first by SK) for efficiency.
        """
        items = await self._query_first()
        if not items:
            return None
        item = items[0]
        kind = item.get("kind", {}).get("S")
        if kind != "end":
            return None
        # Decode state/result
        state_b = item.get("state", {}).get("B")
        state_s = item.get("state", {}).get("S") if state_b is None else None
        result_b = item.get("result", {}).get("B")
        result_s = item.get("result", {}).get("S") if result_b is None else None
        decoded_state: Any = None
        decoded_result: Any = None
        if self._use_compression and state_b is not None and br_decompress is not None:
            try:
                decoded_state = json.loads(br_decompress(state_b).decode("utf-8"))
            except Exception:
                decoded_state = None
        elif state_s is not None:
            try:
                decoded_state = json.loads(state_s)
            except Exception:
                decoded_state = state_s
        if self._use_compression and result_b is not None and br_decompress is not None:
            try:
                decoded_result = json.loads(br_decompress(result_b).decode("utf-8"))
            except Exception:
                decoded_result = None
        elif result_s is not None:
            try:
                decoded_result = json.loads(result_s)
            except Exception:
                decoded_result = result_s
        if decoded_result is None:
            return None

        # Wrap decoded payloads into typed BaseModel instances if possible
        class _AnyStateModel(RootModel[Any]):
            pass

        class _AnyOutputModel(RootModel[Any]):
            pass

        try:
            state_model: BaseModel = (
                self._state_type.model_validate(decoded_state)  # type: ignore[assignment]
                if getattr(self, "_state_type", None) is not None
                and hasattr(self._state_type, "model_validate")
                else _AnyStateModel.model_validate(decoded_state)
            )
        except Exception:
            state_model = _AnyStateModel.model_validate(decoded_state)
        try:
            output_model: BaseModel = (
                self._run_end_type.model_validate(decoded_result)  # type: ignore[assignment]
                if getattr(self, "_run_end_type", None) is not None
                and hasattr(self._run_end_type, "model_validate")
                else _AnyOutputModel.model_validate(decoded_result)
            )
        except Exception:
            output_model = _AnyOutputModel.model_validate(decoded_result)
        return GraphRunResult(output=output_model, state=state_model, persistence=self)

    async def snapshot_node(self, *args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        """Snapshot the state of a graph before running a node.

        Supports both the pydantic-graph signature `(state, next_node)` and a
        legacy signature `(snapshot_id, state, next_node)` used in local tests.
        """
        snapshot_id: str | None = None
        state: Any
        next_node: Any
        if len(args) == 2:
            state, next_node = args
            try:
                snapshot_id = next_node.get_snapshot_id()
            except Exception:
                snapshot_id = None
        elif len(args) == 3:
            snapshot_id, state, next_node = args
        else:
            raise TypeError(
                "snapshot_node expected (state,next_node) or (snapshot_id,state,next_node)"
            )
        if snapshot_id is None:
            try:
                snapshot_id = next_node.get_snapshot_id()
            except Exception:
                raise ValueError("next_node must provide get_snapshot_id()")
        try:
            node_id = next_node.get_node_id()
        except Exception:
            raise ValueError("next_node must provide get_node_id()")
        snapshot = {
            "kind": "node",
            "id": snapshot_id,
            "status": "created",
            "node": node_id,
            "state": DynamoDBStatePersistence._to_json_safe(state),
        }
        await self._put_snapshot_async(snapshot_id, snapshot)

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:  # type: ignore[override]
        exists = await self._get_item_by_snapshot_id(snapshot_id)
        if exists is not None:
            return
        snapshot = {
            "kind": "node",
            "id": snapshot_id,
            "status": "created",
            "state": DynamoDBStatePersistence._to_json_safe(state),
        }
        await self._put_snapshot_async(snapshot_id, snapshot)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:  # type: ignore[override]
        snapshot_id = end.get_snapshot_id()
        snapshot = {
            "kind": "end",
            "id": snapshot_id,
            "state": DynamoDBStatePersistence._to_json_safe(state),
            "result": DynamoDBStatePersistence._to_json_safe(end.data),
        }
        await self._put_snapshot_async(snapshot_id, snapshot)

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:  # type: ignore[override]
        start = time.perf_counter()
        await self._set_running(snapshot_id)
        try:
            yield
        except Exception as e:
            duration = time.perf_counter() - start
            exception = (
                f"{e.__class__.__name__}: {e.message}"
                if hasattr(e, "message")
                else str(e)
            )
            await self._update_status_and_duration(
                snapshot_id, "error", duration, exception
            )
            raise
        else:
            duration = time.perf_counter() - start
            await self._update_status_and_duration(
                snapshot_id, "success", duration, None
            )

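    # Typical use (a sketch of how a pydantic-graph run drives this manager):
    #     snapshot = await persistence.load_next()
    #     async with persistence.record_run(snapshot.id):
    #         next_node = await snapshot.node.run(ctx)  # running -> success/error
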
    # ----- helpers -----
    def _build_client_kwargs(self) -> dict[str, Any]:
        kwargs: dict[str, Any] = {"region_name": self._region}
        if self._endpoint_url:
            kwargs["endpoint_url"] = self._endpoint_url
            # When using a local endpoint, DynamoDB Local accepts any credentials but
            # botocore still requires something present. Provide dummy creds if missing.
            if not (self._aws_access_key_id and self._aws_secret_access_key):
                kwargs.update(
                    {
                        "aws_access_key_id": "dummy",
                        "aws_secret_access_key": "dummy",
                        "aws_session_token": "dummy",
                    }
                )
        if self._aws_access_key_id and self._aws_secret_access_key:
            kwargs.update(
                {
                    "aws_access_key_id": self._aws_access_key_id,
                    "aws_secret_access_key": self._aws_secret_access_key,
                }
            )
        if self._aws_session_token and "aws_session_token" not in kwargs:
            kwargs["aws_session_token"] = self._aws_session_token
        return kwargs

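    # The target table is assumed to already exist. A minimal creation sketch
    # with the AWS CLI (table name is hypothetical; key names match the
    # defaults above):
    #     aws dynamodb create-table --table-name graph-snapshots \
    #         --attribute-definitions AttributeName=run_id,AttributeType=S \
    #             AttributeName=snapshot_id,AttributeType=S \
    #         --key-schema AttributeName=run_id,KeyType=HASH \
    #             AttributeName=snapshot_id,KeyType=RANGE \
    #         --billing-mode PAY_PER_REQUEST
    #     aws dynamodb update-time-to-live --table-name graph-snapshots \
    #         --time-to-live-specification Enabled=true,AttributeName=expires_at
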
    # ----- Internal update helpers -----
    async def _update_status_and_duration(
        self,
        snapshot_id: str,
        status: str,
        duration: float,
        exception: str | None = None,
    ) -> None:
        async with self._session.create_client(
            "dynamodb", **self._build_client_kwargs()
        ) as client:
            await client.update_item(
                TableName=self._table_name,
                Key={
                    self._hash_key_name: {"S": self._run_id},
                    self._range_key_name: {"S": snapshot_id},
                },
                UpdateExpression="SET #st = :status, #dur = :duration"
                + (", #exc = :exception" if exception is not None else ""),
                ExpressionAttributeNames={"#st": "status", "#dur": "duration"}
                if exception is None
                else {"#st": "status", "#dur": "duration", "#exc": "exception"},
                ExpressionAttributeValues={
                    ":status": {"S": status},
                    ":duration": {"N": str(float(duration))},
                }
                if exception is None
                else {
                    ":status": {"S": status},
                    ":duration": {"N": str(float(duration))},
                    ":exception": {"S": exception},
                },
            )

    async def _set_running(self, snapshot_id: str) -> None:
        async with self._session.create_client(
            "dynamodb", **self._build_client_kwargs()
        ) as client:
            await client.update_item(
                TableName=self._table_name,
                Key={
                    self._hash_key_name: {"S": self._run_id},
                    self._range_key_name: {"S": snapshot_id},
                },
                UpdateExpression="SET #st = :running, #start = :start",
                ConditionExpression="#st IN (:created, :pending)",
                ExpressionAttributeNames={"#st": "status", "#start": "start_ts"},
                ExpressionAttributeValues={
                    ":running": {"S": "running"},
                    ":created": {"S": "created"},
                    ":pending": {"S": "pending"},
                    ":start": {"N": str(round(time.time(), 0))},
                },
            )

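    # The ConditionExpression above acts as an optimistic lock: if another
    # worker has already moved the item out of "created"/"pending", the update
    # fails with ConditionalCheckFailedException instead of double-claiming
    # the snapshot.
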
    async def _put_snapshot_async(self, sk: str, snapshot: Any) -> None:
        if isinstance(snapshot, dict):
            safe = self._to_json_safe(snapshot)
            snapshot_json_bytes: bytes | None = None
            kind_val = str(safe.get("kind", ""))
            node_name_val = (
                str(safe.get("node")) if safe.get("node") is not None else None
            )
            status_val = (
                str(safe.get("status")) if safe.get("status") is not None else None
            )
            duration_val = (
                float(safe.get("duration"))
                if safe.get("duration") is not None
                else None
            )
        else:
            # dataclass snapshots
            try:
                kind_val = getattr(snapshot, "kind", None)
                node_name_val = getattr(
                    getattr(snapshot, "node", None), "get_node_id", lambda: None
                )()
                status_val = getattr(snapshot, "status", None)
                duration_val = getattr(snapshot, "duration", None)
                state_val = getattr(snapshot, "state", None)
                result_val = getattr(getattr(snapshot, "result", None), "data", None)
            except Exception:
                kind_val = None
                node_name_val = None
                status_val = None
                duration_val = None
                state_val = None
                result_val = None
            safe = {
                "kind": kind_val,
                "node": node_name_val,
                "status": status_val,
                "duration": duration_val,
                "state": self._to_json_safe(state_val),
                "result": self._to_json_safe(result_val),
            }
            snapshot_json_bytes = None
            if self._snapshots_type_adapter is not None:
                try:
                    snapshot_json_bytes = self._snapshots_type_adapter.dump_json(
                        [snapshot], indent=2
                    )
                except Exception:
                    snapshot_json_bytes = None
        now = int(time.time())
        expires_at_value = (now + int(self._ttl_seconds)) if self._ttl_seconds else None
        async with self._session.create_client(
            "dynamodb", **self._build_client_kwargs()
        ) as client:
            item: dict[str, Any] = {
                self._hash_key_name: {"S": self._run_id},
                self._range_key_name: {"S": sk},
                "kind": {"S": str(safe.get("kind", ""))},
                **(
                    {"node": {"S": str(safe.get("node"))}}
                    if safe.get("node") is not None
                    else {}
                ),
                **(
                    {"status": {"S": str(safe.get("status"))}}
                    if safe.get("status") is not None
                    else {}
                ),
                **(
                    {"duration": {"N": str(float(safe.get("duration")))}}
                    if safe.get("duration") is not None
                    else {}
                ),
                "last_updated": {"N": str(now)},
            }
            if safe.get("state") is not None:
                try:
                    raw = json.dumps(safe.get("state"), ensure_ascii=False).encode(
                        "utf-8"
                    )
                    if self._use_compression and br_compress is not None:
                        item["state"] = {"B": br_compress(raw)}
                    else:
                        item["state"] = {"S": raw.decode("utf-8")}
                except Exception:
                    item["state"] = {"S": json.dumps(safe.get("state"))}
            if safe.get("result") is not None:
                try:
                    rawr = json.dumps(safe.get("result"), ensure_ascii=False).encode(
                        "utf-8"
                    )
                    if self._use_compression and br_compress is not None:
                        item["result"] = {"B": br_compress(rawr)}
                    else:
                        item["result"] = {"S": rawr.decode("utf-8")}
                except Exception:
                    item["result"] = {"S": json.dumps(safe.get("result"))}
            # Store full snapshot JSON if available
            try:
                json_bytes = None
                if isinstance(snapshot, dict):
                    json_bytes = json.dumps(safe, ensure_ascii=False).encode("utf-8")
                else:
                    json_bytes = snapshot_json_bytes
                if json_bytes is not None:
                    if self._use_compression and br_compress is not None:
                        item["snapshot_json"] = {"B": br_compress(json_bytes)}
                    else:
                        item["snapshot_json"] = {"S": json_bytes.decode("utf-8")}
            except Exception:
                pass
            if expires_at_value is not None:
                item["expires_at"] = {"N": str(expires_at_value)}
            await client.put_item(TableName=self._table_name, Item=item)

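    # Shape of a stored item (illustrative values; binary attributes shown
    # symbolically, compression enabled):
    #     {
    #         "run_id": {"S": "run-123"},
    #         "snapshot_id": {"S": "<reversed-ULID>#node#MyNode"},
    #         "kind": {"S": "node"},
    #         "status": {"S": "created"},
    #         "state": {"B": <brotli(JSON)>},
    #         "snapshot_json": {"B": <brotli(JSON)>},
    #         "last_updated": {"N": "1756162200"},
    #         "expires_at": {"N": "1756248600"},
    #     }
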
    async def _query_all(self) -> list[dict]:
        async with self._session.create_client(
            "dynamodb", **self._build_client_kwargs()
        ) as client:
            self._logger.debug(
                "QUERY async: table=%s run_id=%s", self._table_name, self._run_id
            )
            items: list[dict] = []
            exclusive_start_key = None
            while True:
                params = {
                    "TableName": self._table_name,
                    "KeyConditionExpression": "#hk = :rid",
                    "ExpressionAttributeNames": {"#hk": self._hash_key_name},
                    "ExpressionAttributeValues": {":rid": {"S": self._run_id}},
                    "ConsistentRead": True,
                }
                if exclusive_start_key is not None:
                    params["ExclusiveStartKey"] = exclusive_start_key
                resp = await client.query(**params)  # type: ignore[arg-type]
                items.extend(resp.get("Items", []))
                exclusive_start_key = resp.get("LastEvaluatedKey")
                if not exclusive_start_key:
                    break
            return items

    async def _query_first(self) -> list[dict]:
        async with self._session.create_client(
            "dynamodb", **self._build_client_kwargs()
        ) as client:
            params: dict[str, Any] = {
                "TableName": self._table_name,
                "KeyConditionExpression": "#hk = :rid",
                "ExpressionAttributeNames": {"#hk": self._hash_key_name},
                "ExpressionAttributeValues": {":rid": {"S": self._run_id}},
                "Limit": 1,
                "ScanIndexForward": True,
                "ConsistentRead": True,
            }
            resp = await client.query(**params)  # type: ignore[arg-type]
            return resp.get("Items", [])

    async def _get_item_by_snapshot_id(self, snapshot_id: str) -> dict | None:
        async with self._session.create_client(
            "dynamodb", **self._build_client_kwargs()
        ) as client:
            resp = await client.get_item(
                TableName=self._table_name,
                Key={
                    self._hash_key_name: {"S": self._run_id},
                    self._range_key_name: {"S": snapshot_id},
                },
                ConsistentRead=True,
            )
            return resp.get("Item")

    # ----- Types integration -----
    def should_set_types(self) -> bool:  # type: ignore[override]
        return self._snapshots_type_adapter is None

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:  # type: ignore[override]
        self._state_type = state_type
        self._run_end_type = run_end_type
        try:
            self._snapshots_type_adapter = build_snapshot_list_type_adapter(
                state_type, run_end_type
            )
        except Exception:
            self._snapshots_type_adapter = None

    # ----- Deserialization helpers -----
    def _deserialize_snapshot_item(
        self, item: dict
    ) -> Snapshot[StateT, RunEndT] | None:
        snap_json_b = item.get("snapshot_json", {}).get("B")
        snap_json_s = (
            item.get("snapshot_json", {}).get("S") if snap_json_b is None else None
        )
        raw_json: bytes | None = None
        if snap_json_b is not None and br_decompress is not None:
            try:
                raw_json = br_decompress(snap_json_b)
            except Exception:
                raw_json = None
        elif snap_json_s is not None:
            raw_json = snap_json_s.encode("utf-8")
        if raw_json is not None and self._snapshots_type_adapter is not None:
            try:
                lst = self._snapshots_type_adapter.validate_json(raw_json)
                if isinstance(lst, list) and lst:
                    return lst[0]
            except Exception:
                pass
        kind = item.get("kind", {}).get("S")
        if kind == "end":
            state_val = self._decode_state_field(item)
            result_val = self._decode_result_field(item)
            end_obj = End(result_val)
            state_model: Any = self._coerce_state_model(state_val)
            return EndSnapshot(state=state_model, result=end_obj)
        elif kind == "node":
            # Without a full snapshot JSON we cannot safely reconstruct the node instance.
            # Returning None will cause Graph.iter_from_persistence to raise and the workflow to restart.
            return None
        else:
            return None

    def _decode_state_field(self, item: dict) -> Any:
        state_b = item.get("state", {}).get("B")
        state_s = item.get("state", {}).get("S") if state_b is None else None
        # Legacy key used in tests
        state_json_s = (
            item.get("state_json", {}).get("S")
            if state_s is None and state_b is None
            else None
        )
        if self._use_compression and state_b is not None and br_decompress is not None:
            try:
                return json.loads(br_decompress(state_b).decode("utf-8"))
            except Exception:
                return None
        elif state_s is not None:
            try:
                return json.loads(state_s)
            except Exception:
                return state_s
        elif state_json_s is not None:
            try:
                return json.loads(state_json_s)
            except Exception:
                return state_json_s
        return None

    def _decode_result_field(self, item: dict) -> Any:
        result_b = item.get("result", {}).get("B")
        result_s = item.get("result", {}).get("S") if result_b is None else None
        result_json_s = (
            item.get("result_json", {}).get("S")
            if result_s is None and result_b is None
            else None
        )
        if self._use_compression and result_b is not None and br_decompress is not None:
            try:
                return json.loads(br_decompress(result_b).decode("utf-8"))
            except Exception:
                return None
        elif result_s is not None:
            try:
                return json.loads(result_s)
            except Exception:
                return result_s
        elif result_json_s is not None:
            try:
                return json.loads(result_json_s)
            except Exception:
                return result_json_s
        return None

    def _coerce_state_model(self, state_val: Any) -> Any:
        if getattr(self, "_state_type", None) is not None and hasattr(
            self._state_type, "model_validate"
        ):
            try:
                return self._state_type.model_validate(state_val)  # type: ignore[return-value]
            except Exception:
                return state_val
        return state_val


# Removed workflow-specific reconstruction; persistence must remain generic
setattr(
    pydantic_graph_nodes,
    "generate_snapshot_id",
    DynamoDBStatePersistence.generate_snapshot_id,
)  # type: ignore
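# The setattr above routes pydantic_graph's snapshot-id generation through the
# reversed-ULID scheme so that sort keys order newest-first.
#
# Minimal usage sketch (not part of the original gist). The graph, node, table
# name, and endpoint below are hypothetical; it assumes a table with `run_id`
# (HASH) and `snapshot_id` (RANGE) string keys already exists:
#
#     import asyncio
#     from dataclasses import dataclass
#     from pydantic_graph import BaseNode, End, Graph, GraphRunContext
#
#     @dataclass
#     class MyState:
#         count: int = 0
#
#     @dataclass
#     class Increment(BaseNode[MyState, None, int]):
#         async def run(self, ctx: GraphRunContext[MyState]) -> End[int]:
#             ctx.state.count += 1
#             return End(ctx.state.count)
#
#     async def main() -> None:
#         graph = Graph(nodes=[Increment])
#         persistence = DynamoDBStatePersistence(
#             table_name="graph-snapshots",        # hypothetical table
#             run_id="run-123",
#             table_host="http://localhost:8000",  # e.g. DynamoDB Local
#             ttl_seconds=24 * 3600,               # items expire after a day
#         )
#         result = await graph.run(
#             Increment(), state=MyState(), persistence=persistence
#         )
#         print(result.output)
#
#     asyncio.run(main())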