Created
June 5, 2025 22:58
-
-
Save Hironsan/eaa59f28aacbbd6a5436f3b51fcdea11 to your computer and use it in GitHub Desktop.
fastapi_human_in_the_loop
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.12" | |
| # dependencies = [ | |
| # "fastapi>=0.115.12", | |
| # "langchain>=0.3.25", | |
| # "langgraph>=0.4.8", | |
| # "uvicorn>=0.34.3", | |
| # ] | |
| # /// | |
| from typing import Any, Dict, List, TypedDict | |
| from uuid import uuid4 | |
| import uvicorn | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from langchain.schema.runnable.config import RunnableConfig | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langgraph.constants import END, START | |
| from langgraph.graph import StateGraph | |
| from langgraph.types import Command, interrupt | |
| from pydantic import BaseModel | |
class State(TypedDict):
    """Shared state passed between the nodes of the planning graph."""

    # Subject of the run; the endpoints pass the user's message here.
    topic: str
    # Draft plan written by the ``generate_plan`` node.
    plan: str
    # Reviewer comment stored by ``human_feedback`` when a plan is sent back.
    feedback: str
    # Report text written by the ``write_report`` node.
    final_report: str
def generate_plan(state: State):
    """Produce the draft plan for the current run.

    Reads the ``feedback`` entry of ``state`` (falling back to an empty
    string when the key is absent) and returns a partial state update that
    sets ``plan``.
    """
    reviewer_note = state["feedback"] if "feedback" in state else ""
    plan_text = f"This is a plan. {reviewer_note}"
    return {"plan": plan_text}
def human_feedback(state: State):
    """Ask a human to review the plan and route the run on the answer.

    The value returned by ``interrupt`` decides the next node: ``True``
    continues to ``write_report``; a string is stored under ``feedback`` and
    sends the run back to ``generate_plan``.

    Raises:
        TypeError: If the answer is neither ``True`` nor a string.
    """
    answer = interrupt("Approve the plan (true) or give me feedback:")

    # ``is True`` only matches the bool singleton, so it already implies the
    # isinstance(bool) check.
    if answer is True:
        return Command(goto="write_report")

    if not isinstance(answer, str):
        raise TypeError(f"Interrupt value of type {type(answer)} is not supported.")

    return Command(goto="generate_plan", update={"feedback": answer})
def write_report(state: State):
    """Return the state update that carries the final report text.

    ``state`` is accepted to satisfy the node signature but is not read.
    """
    report = "Here is the final report."
    return {"final_report": report}
# Assemble the workflow: START -> generate_plan -> human_feedback and
# write_report -> END.  human_feedback has no static outgoing edge; it picks
# its successor at run time through the Command it returns.
graph_builder = StateGraph(State)

for _node_name, _node_fn in (
    ("generate_plan", generate_plan),
    ("human_feedback", human_feedback),
    ("write_report", write_report),
):
    graph_builder.add_node(_node_name, _node_fn)

for _source, _target in (
    (START, "generate_plan"),
    ("generate_plan", "human_feedback"),
    ("write_report", END),
):
    graph_builder.add_edge(_source, _target)

# The in-memory checkpointer stores graph state per thread_id, which is what
# lets a later request resume a run that stopped at human_feedback.
checkpointer = MemorySaver()
agent = graph_builder.compile(checkpointer=checkpointer)
app = FastAPI()

# CORS: allow any origin, method and header.
# NOTE: a wildcard origin must not be combined with credentialed requests
# (the FastAPI CORS documentation rules this combination out).  None of the
# endpoints in this file read cookies or authorization headers, so
# credentials are disabled.  If credentials are ever needed, list explicit
# origins instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-process conversation store: thread_id -> {"messages": [Message, ...]}.
# It is a plain module-level dict, so its contents live only as long as this
# process does.
sessions: Dict[str, Any] = {}
class Message(BaseModel):
    """One entry in a conversation transcript."""

    # Text of the message.
    content: str
    # Author tag; the endpoints in this file write "human" or "ai".
    type: str
class ConversationState(BaseModel):
    """Response body holding the transcript of one conversation thread."""

    # Messages in the order the endpoints appended them.
    messages: list[Message]
class HumanMessageRequest(BaseModel):
    """Request body carrying the text sent by the user."""

    # Used as the topic of a new run or as plan feedback, depending on the
    # endpoint that receives it.
    message: str
@app.post("/start_conversation")
async def start_conversation():
    """Open a new conversation and return its identifier.

    Registers an empty transcript in ``sessions`` under a freshly generated
    UUID4 string and returns ``{"thread_id": <that string>}``.
    """
    new_thread = str(uuid4())
    empty_transcript = {"messages": []}
    sessions[new_thread] = empty_transcript
    return {"thread_id": new_thread}
@app.post("/send_message/{thread_id}")
async def send_message(thread_id: str, request: HumanMessageRequest):
    """Run the graph for ``thread_id`` with the message as the topic.

    Appends the user's message and the plan found in the run result to the
    transcript, then returns the updated transcript.

    Raises:
        HTTPException: 404 if ``thread_id`` was never registered.
    """
    if thread_id not in sessions:
        raise HTTPException(status_code=404, detail="Conversation not found")

    config = RunnableConfig(configurable={"thread_id": thread_id})
    # Await the async entry point: calling the synchronous invoke() from an
    # ``async def`` endpoint would block the event loop for the whole run.
    result = await agent.ainvoke({"topic": request.message}, config)

    sessions[thread_id]["messages"].extend(
        [
            Message(content=request.message, type="human"),
            Message(content=result["plan"], type="ai"),
        ]
    )
    return ConversationState(**sessions[thread_id])
@app.post("/approve/{thread_id}")
async def approve(thread_id: str):
    """Approve the pending plan and append the final report to the transcript.

    Resumes the run with ``True`` and returns the updated transcript.

    Raises:
        HTTPException: 404 if ``thread_id`` was never registered; 400 if the
            run is not currently waiting at the ``human_feedback`` node.
    """
    if thread_id not in sessions:
        raise HTTPException(status_code=404, detail="Conversation not found")

    config = RunnableConfig(configurable={"thread_id": thread_id})

    # Only resume a run that is actually waiting for review.  Without this
    # check a premature or repeated call either fails on the missing
    # "final_report" key or appends a duplicate report.  The same check and
    # error are used by the /deep_research endpoint.
    snapshot = await agent.aget_state(config)
    if snapshot.next != ("human_feedback",):
        raise HTTPException(status_code=400, detail="Invalid state transition")

    # Async entry point so the event loop is not blocked during the run.
    result = await agent.ainvoke(Command(resume=True), config)

    sessions[thread_id]["messages"].append(
        Message(content=result["final_report"], type="ai")
    )
    return ConversationState(**sessions[thread_id])
@app.post("/feedback/{thread_id}")
async def feedback(thread_id: str, request: HumanMessageRequest):
    """Send reviewer feedback on the pending plan and record the revised plan.

    Resumes the run with the message text, appends the message and the plan
    found in the run result to the transcript, and returns the updated
    transcript.

    Raises:
        HTTPException: 404 if ``thread_id`` was never registered; 400 if the
            run is not currently waiting at the ``human_feedback`` node.
    """
    if thread_id not in sessions:
        raise HTTPException(status_code=404, detail="Conversation not found")

    config = RunnableConfig(configurable={"thread_id": thread_id})

    # Feedback only makes sense while a plan is awaiting review; reject the
    # call otherwise, with the same check and error as /deep_research.
    snapshot = await agent.aget_state(config)
    if snapshot.next != ("human_feedback",):
        raise HTTPException(status_code=400, detail="Invalid state transition")

    # Async entry point so the event loop is not blocked during the run.
    result = await agent.ainvoke(Command(resume=request.message), config)

    sessions[thread_id]["messages"].extend(
        [
            Message(content=request.message, type="human"),
            Message(content=result["plan"], type="ai"),
        ]
    )
    return ConversationState(**sessions[thread_id])
@app.get("/conversation_state/{thread_id}")
async def get_conversation_state(thread_id: str):
    """Return the stored transcript for ``thread_id``.

    Raises:
        HTTPException: 404 if ``thread_id`` was never registered.
    """
    try:
        session = sessions[thread_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Conversation not found") from None
    return ConversationState(**session)
@app.post("/deep_research/{thread_id}")
async def handle_message(thread_id: str, request: HumanMessageRequest):
    """Single entry point that routes a message by the run's current state.

    * No pending node: the message is the topic of a new run; the message and
      the resulting plan are appended to the transcript.
    * Waiting at ``human_feedback``: the message ``"true"`` (any letter case)
      approves the plan and the final report is appended; any other text is
      passed on as feedback and the message plus the revised plan are
      appended.

    Returns the updated transcript.

    Raises:
        HTTPException: 404 if ``thread_id`` was never registered; 400 if the
            run is waiting at any node other than ``human_feedback``.
    """
    if thread_id not in sessions:
        raise HTTPException(status_code=404, detail="Conversation not found")

    config = RunnableConfig(configurable={"thread_id": thread_id})
    transcript = sessions[thread_id]["messages"]

    # Use the async graph API throughout: the synchronous get_state()/invoke()
    # calls would block the event loop inside this ``async def`` endpoint.
    snapshot = await agent.aget_state(config)
    next_node = snapshot.next

    if not next_node:
        result = await agent.ainvoke({"topic": request.message}, config)
        transcript.extend(
            [
                Message(content=request.message, type="human"),
                Message(content=result["plan"], type="ai"),
            ]
        )
    elif next_node == ("human_feedback",):
        if request.message.lower() == "true":
            # Approval: resume with the boolean the human_feedback node expects.
            result = await agent.ainvoke(Command(resume=True), config)
            transcript.append(Message(content=result["final_report"], type="ai"))
        else:
            result = await agent.ainvoke(Command(resume=request.message), config)
            transcript.extend(
                [
                    Message(content=request.message, type="human"),
                    Message(content=result["plan"], type="ai"),
                ]
            )
    else:
        raise HTTPException(status_code=400, detail="Invalid state transition")

    return ConversationState(**sessions[thread_id])
if __name__ == "__main__":
    # Script entry point: serve the API on all interfaces, port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment