Created November 14, 2025 01:26
import json
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, TypedDict

from dotenv import load_dotenv
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_openai import AzureChatOpenAI
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, StateGraph

load_dotenv(".envrc")


# State definition
class AgentState(TypedDict):
    """State shared across the agent's graph nodes."""

    # Plain list (last-write-wins) instead of Annotated[..., operator.add]:
    # every node mutates and returns the full message list, so an additive
    # reducer would duplicate messages on each step.
    messages: List[BaseMessage]
    current_topic: str
    quality_score: float
    checkpoint_metadata: Dict[str, Any]
    should_rollback: bool
    rollback_checkpoint_id: Optional[str]
    state_history: List[Dict[str, Any]]  # per-run log of state transitions
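
# Overall flow, as wired in ContextEngineeringAgent.build_graph() below:
#
#   evaluate_quality -> detect_topic_change -> process_message
#       process_message then routes on should_continue():
#         "rollback"   -> rollback -> process_message (retry after restoring a checkpoint)
#         "checkpoint" -> create_checkpoint -> END
#         "continue"   -> END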


class ContextEngineeringAgent:
    """Agent that implements context engineering: quality gating, context compression, checkpoints, and rollback."""

    def __init__(self, llm: AzureChatOpenAI, enable_logging: bool = True):
        self.llm = llm
        self.checkpointer = InMemorySaver()
        self.quality_threshold = 0.6
        self.topic_change_threshold = 0.3
        self.saved_checkpoints = {}
        self.enable_logging = enable_logging
        self.state_transitions = []  # record of state transitions
        self.session_id = str(uuid.uuid4())
        self.start_time = datetime.now()

    def _record_state(self, state: AgentState, node_name: str, action: Optional[str] = None):
        """Record the current state for logging."""
        if not self.enable_logging:
            return
        record = {
            "timestamp": datetime.now().isoformat(),
            "node": node_name,
            "action": action,
            "quality_score": state.get("quality_score", 0),
            "current_topic": state.get("current_topic", ""),
            "should_rollback": state.get("should_rollback", False),
            "checkpoint_count": len(self.saved_checkpoints),
            "last_message": state["messages"][-1].content[:1000] if state.get("messages") else None,
            "last_message_size": len(state["messages"][-1].content) if state.get("messages") else 0,
        }
        self.state_transitions.append(record)
        # Also append to the per-run state history
        if "state_history" not in state:
            state["state_history"] = []
        state["state_history"].append(record)

    def evaluate_quality(self, state: AgentState) -> AgentState:
        """Score the output quality with an LLM-as-a-judge pattern."""
        self._record_state(state, "evaluate_quality", "start")
        if len(state["messages"]) < 2:
            state["quality_score"] = 1.0
            self._record_state(state, "evaluate_quality", "skipped: not enough messages")
            return state
        # Evaluate quality from the most recent exchange
        recent_messages = state["messages"][-4:] if len(state["messages"]) >= 4 else state["messages"]
        evaluation_prompt = f"""
Rate the quality of the following conversation on a scale from 0 to 1.
Criteria:
- Coherence of the discussion
- Factual accuracy
- Appropriateness to the context
Conversation:
{[{"role": m.type, "content": m.content} for m in recent_messages]}
Output only the numeric score.
"""
        try:
            response = self.llm.invoke([SystemMessage(content=evaluation_prompt)])
            score = float(response.content.strip())
            state["quality_score"] = max(0.0, min(1.0, score))
            self._record_state(state, "evaluate_quality", f"quality score: {state['quality_score']:.2f}")
        except Exception as e:
            state["quality_score"] = 0.7
            self._record_state(state, "evaluate_quality", f"evaluation error: {str(e)}")
        # Flag a rollback when quality drops below the threshold
        if state["quality_score"] < self.quality_threshold:
            state["should_rollback"] = True
            self._record_state(state, "evaluate_quality", "rollback flag set due to low quality")
        return state
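
    # Note: the judge prompt asks for a bare number, so float(response.content) is brittle;
    # any extra text from the model lands in the except branch and silently falls back to 0.7.
    # A structured-output call or a regex that pulls out the first float would be more robust
    # (left as a suggestion, not implemented in this gist).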

    def detect_topic_change(self, state: AgentState) -> AgentState:
        """Detect topic changes and prune context that is no longer relevant."""
        self._record_state(state, "detect_topic_change", "start")
        if len(state["messages"]) < 3:
            self._record_state(state, "detect_topic_change", "skipped: not enough messages")
            return state
        # Compare the latest message against the preceding ones
        recent_msg = state["messages"][-1].content if state["messages"] else ""
        previous_msgs = " ".join([m.content for m in state["messages"][-5:-1]]) if len(state["messages"]) > 1 else ""
        topic_analysis_prompt = f"""
Rate the topical similarity of the following two texts on a scale from 0 to 1.
Earlier text: {previous_msgs[:500]}
Latest text: {recent_msg}
Output only the numeric score.
"""
        try:
            response = self.llm.invoke([SystemMessage(content=topic_analysis_prompt)])
            similarity = float(response.content.strip())
            self._record_state(state, "detect_topic_change", f"topic similarity: {similarity:.2f}")
            # If the topic shifted substantially, summarize or drop the old context
            if similarity < self.topic_change_threshold:
                self._record_state(state, "detect_topic_change", "topic change detected - compressing old context")
                state = self._compress_old_context(state)
        except Exception as e:
            self._record_state(state, "detect_topic_change", f"detection error: {str(e)}")
        return state
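
    # A similarity below topic_change_threshold (0.3) is treated as a topic switch and
    # triggers _compress_old_context() so the prompt shrinks before the next LLM call.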

    def _compress_old_context(self, state: AgentState) -> AgentState:
        """Summarize old context to compress the message history."""
        self._record_state(state, "_compress_old_context", "start")
        if len(state["messages"]) <= 4:
            self._record_state(state, "_compress_old_context", "skipped: too few messages")
            return state
        # Summarize everything except the four most recent messages
        old_messages = state["messages"][:-4]
        recent_messages = state["messages"][-4:]
        old_message_count = len(old_messages)
        summary_prompt = f"""
Summarize the following conversation history concisely (within 100 characters):
{[{"role": m.type, "content": m.content} for m in old_messages]}
"""
        try:
            response = self.llm.invoke([SystemMessage(content=summary_prompt)])
            summary = SystemMessage(content=f"[Summary of earlier conversation]: {response.content}")
            # Keep only the summary plus the recent messages
            state["messages"] = [summary] + recent_messages
            self._record_state(state, "_compress_old_context", f"success: summarized {old_message_count} messages")
        except Exception as e:
            # On error, keep only the recent messages
            state["messages"] = recent_messages
            self._record_state(state, "_compress_old_context", f"summarization error: {str(e)}")
        return state
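
    # Compression keeps one synthetic SystemMessage (the summary) plus the last four
    # messages, so the prompt drops to at most five messages whenever a topic change fires.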

    def create_checkpoint(self, state: AgentState) -> AgentState:
        """Create a checkpoint at important points in the conversation."""
        checkpoint_id = str(uuid.uuid4())
        self._record_state(state, "create_checkpoint", f"checkpoint created: {checkpoint_id[:8]}")
        # Store checkpoint data and metadata
        self.saved_checkpoints[checkpoint_id] = {
            "timestamp": datetime.now().isoformat(),
            "messages": state["messages"].copy(),
            "topic": state["current_topic"],
            "quality_score": state["quality_score"],
        }
        state["checkpoint_metadata"] = {"last_checkpoint_id": checkpoint_id, "created_at": datetime.now().isoformat()}
        # Cap the number of retained checkpoints (memory management)
        if len(self.saved_checkpoints) > 10:
            oldest_key = min(self.saved_checkpoints.keys(), key=lambda k: self.saved_checkpoints[k]["timestamp"])
            del self.saved_checkpoints[oldest_key]
            self._record_state(state, "create_checkpoint", f"oldest checkpoint removed: {oldest_key[:8]}")
        return state
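
    # These checkpoints live in self.saved_checkpoints on the agent itself, separate from
    # the LangGraph InMemorySaver passed to compile(); at most 10 are kept, evicting the
    # oldest by timestamp.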

    def rollback_to_checkpoint(self, state: AgentState) -> AgentState:
        """Roll the conversation back to a saved checkpoint."""
        self._record_state(state, "rollback_to_checkpoint", "start")
        if not state.get("should_rollback"):
            self._record_state(state, "rollback_to_checkpoint", "rollback not needed")
            return state
        # Decide which checkpoint to roll back to
        checkpoint_id = state.get("rollback_checkpoint_id")
        if not checkpoint_id and self.saved_checkpoints:
            # Find the most recent checkpoint that met the quality threshold
            for cid in reversed(list(self.saved_checkpoints.keys())):
                if self.saved_checkpoints[cid]["quality_score"] >= self.quality_threshold:
                    checkpoint_id = cid
                    break
        if checkpoint_id and checkpoint_id in self.saved_checkpoints:
            checkpoint = self.saved_checkpoints[checkpoint_id]
            # Restore the saved state
            old_message_count = len(state["messages"])
            state["messages"] = checkpoint["messages"].copy()
            state["current_topic"] = checkpoint["topic"]
            state["quality_score"] = checkpoint["quality_score"]
            state["should_rollback"] = False
            # Append a notification about the rollback
            rollback_msg = SystemMessage(
                content="[System]: Conversation quality degraded, so the state was rolled back to an earlier point."
            )
            state["messages"].append(rollback_msg)
            self._record_state(
                state,
                "rollback_to_checkpoint",
                f"rolled back to {checkpoint_id[:8]} (messages: {old_message_count} -> {len(state['messages'])})",
            )
        else:
            self._record_state(state, "rollback_to_checkpoint", "no suitable checkpoint found")
        return state
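
    # Rollback target selection: an explicit rollback_checkpoint_id wins; otherwise the most
    # recently created checkpoint whose quality_score met the threshold is used, and if none
    # qualifies the state is left untouched.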

    def process_message(self, state: AgentState) -> AgentState:
        """Generate the assistant's response to the current messages."""
        self._record_state(state, "process_message", "start")
        try:
            # Generate a response with the LLM
            response = self.llm.invoke(state["messages"])
            state["messages"].append(response)
            self._record_state(
                state,
                "process_message",
                f"response generated: {len(response.content)} characters",
            )
        except Exception as e:
            self._record_state(state, "process_message", f"response generation error: {str(e)}")
        return state

    def should_continue(self, state: AgentState) -> str:
        """Decide how the graph should proceed after a response."""
        if state.get("should_rollback"):
            decision = "rollback"
        elif state["quality_score"] >= 0.8:
            decision = "checkpoint"
        else:
            decision = "continue"
        self._record_state(state, "should_continue", f"decision: {decision}")
        return decision
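
    # The strings returned here must match the mapping passed to add_conditional_edges()
    # in build_graph(): "rollback", "checkpoint", and "continue".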

    def export_state_history(self, output_dir: str = "./logs"):
        """Write the recorded state history to log files."""
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Dump the detailed state transitions as JSON
        transitions_file = output_path / f"state_transitions_{self.session_id[:8]}_{timestamp}.json"
        with open(transitions_file, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "session_id": self.session_id,
                    "start_time": self.start_time.isoformat(),
                    "end_time": datetime.now().isoformat(),
                    "total_transitions": len(self.state_transitions),
                    "quality_threshold": self.quality_threshold,
                    "topic_change_threshold": self.topic_change_threshold,
                    "transitions": self.state_transitions,
                },
                f,
                ensure_ascii=False,
                indent=2,
            )
        # Generate a summary report
        summary_file = output_path / f"summary_report_{self.session_id[:8]}_{timestamp}.txt"
        with open(summary_file, "w", encoding="utf-8") as f:
            f.write(self._generate_summary_report())
        # Dump checkpoint metadata
        checkpoints_file = output_path / f"checkpoints_{self.session_id[:8]}_{timestamp}.json"
        checkpoint_data = {
            checkpoint_id: {
                "timestamp": data["timestamp"],
                "topic": data["topic"],
                "quality_score": data["quality_score"],
            }
            for checkpoint_id, data in self.saved_checkpoints.items()
        }
        with open(checkpoints_file, "w", encoding="utf-8") as f:
            json.dump(checkpoint_data, f, ensure_ascii=False, indent=2)
        print("\n=== Log files written ===")
        print(f"State transitions: {transitions_file}")
        print(f"Summary: {summary_file}")
        print(f"Checkpoints: {checkpoints_file}")
        return {
            "transitions_file": str(transitions_file),
            "summary_file": str(summary_file),
            "checkpoints_file": str(checkpoints_file),
        }
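
    # Each export writes three files: a JSON dump of every transition, a plain-text summary
    # report, and a JSON index of checkpoint metadata (timestamps, topics, and scores only).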

    def _generate_summary_report(self) -> str:
        """Build a human-readable summary report of the session."""
        duration = (datetime.now() - self.start_time).total_seconds()
        # Compute statistics
        quality_scores = [
            t.get("quality_score", 0) for t in self.state_transitions if t.get("quality_score") is not None
        ]
        rollback_count = sum(
            1 for t in self.state_transitions if t.get("action") and "rolled back to" in t.get("action", "")
        )
        checkpoint_count = sum(1 for t in self.state_transitions if t.get("node") == "create_checkpoint")
        topic_changes = sum(
            1 for t in self.state_transitions if t.get("action") and "topic change detected" in t.get("action", "")
        )
        avg_score = sum(quality_scores) / len(quality_scores) if quality_scores else 0
        max_score = max(quality_scores) if quality_scores else 0
        min_score = min(quality_scores) if quality_scores else 0
        report = f"""
====================================
Context Engineering Execution Report
====================================
Session info:
--------------
Session ID: {self.session_id}
Start time: {self.start_time.strftime("%Y-%m-%d %H:%M:%S")}
End time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
Execution time: {duration:.2f} seconds
Statistics:
---------
Total state transitions: {len(self.state_transitions)}
Quality evaluations: {len(quality_scores)}
Average quality score: {avg_score:.3f}
Highest quality score: {max_score:.3f}
Lowest quality score: {min_score:.3f}
Rollbacks performed: {rollback_count}
Checkpoints created: {checkpoint_count}
Topic changes detected: {topic_changes}
Node execution frequency:
--------------
"""
        # Count how often each node ran
        node_counts = {}
        for transition in self.state_transitions:
            node = transition.get("node", "unknown")
            node_counts[node] = node_counts.get(node, 0) + 1
        for node, count in sorted(node_counts.items(), key=lambda x: x[1], reverse=True):
            report += f"  {node}: {count}\n"
        # Append the execution timeline
        report += "\nExecution timeline:\n"
        report += "-" * 50 + "\n"
        for i, transition in enumerate(self.state_transitions):
            timestamp = transition.get("timestamp", "")
            node = transition.get("node", "")
            action = transition.get("action", "")
            report += f"{i + 1:3d}. [{timestamp[11:19]}] {node}: {action}\n"
        return report

    def build_graph(self) -> StateGraph:
        """Build the LangGraph workflow."""
        workflow = StateGraph(AgentState)
        # Nodes
        workflow.add_node("evaluate_quality", self.evaluate_quality)
        workflow.add_node("detect_topic_change", self.detect_topic_change)
        workflow.add_node("process_message", self.process_message)
        workflow.add_node("create_checkpoint", self.create_checkpoint)
        workflow.add_node("rollback", self.rollback_to_checkpoint)
        # Edges
        workflow.set_entry_point("evaluate_quality")
        workflow.add_edge("evaluate_quality", "detect_topic_change")
        workflow.add_edge("detect_topic_change", "process_message")
        # Conditional routing after the response is generated
        workflow.add_conditional_edges(
            "process_message",
            self.should_continue,
            {"rollback": "rollback", "checkpoint": "create_checkpoint", "continue": END},
        )
        workflow.add_edge("create_checkpoint", END)
        workflow.add_edge("rollback", "process_message")
        return workflow.compile(checkpointer=self.checkpointer)
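

# A minimal sketch of streaming the compiled graph node by node instead of calling invoke().
# stream_run is a hypothetical helper, not part of the original gist; it assumes the same
# state/config shapes used in main() below. With stream_mode="updates", each yielded item
# maps a node name to the update that node returned, which makes it easy to watch quality
# scores and rollback decisions as they happen.
def stream_run(app, state, config):
    for update in app.stream(state, config, stream_mode="updates"):
        for node_name, node_update in update.items():
            print(f"[{node_name}] quality={node_update.get('quality_score')}")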


# Usage example
def main():
    llm = AzureChatOpenAI(model="gpt-5-mini", temperature=0.7)
    # Create the agent (logging enabled)
    agent = ContextEngineeringAgent(llm, enable_logging=True)
    app = agent.build_graph()
    print("=== Workflow structure ===")
    print(app.get_graph().draw_mermaid())
    print("========================\n")
    # Initial state
    initial_state = {
        "messages": [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content="Please explain Python list comprehensions."),
        ],
        "current_topic": "Python",
        "quality_score": 1.0,
        "checkpoint_metadata": {},
        "should_rollback": False,
        "rollback_checkpoint_id": None,
        "state_history": [],
    }
    # Manage the session with a thread ID
    thread_id = str(uuid.uuid4())
    config = {"configurable": {"thread_id": thread_id}}
    # First run
    print("=== First run ===")
    result = app.invoke(initial_state, config)
    print(f"Quality score: {result['quality_score']:.3f}")
    print(f"Message count: {len(result['messages'])}")
    print(f"State history entries: {len(result.get('state_history', []))}")
    # Test a topic change
    print("\n=== Topic change test ===")
    result["messages"].append(HumanMessage(content="By the way, how is the weather today?"))
    # Run again (topic change detection)
    result = app.invoke(result, config)
    print(f"Message count after topic change: {len(result['messages'])}")
    print(f"State history entries: {len(result.get('state_history', []))}")
    # Add another exchange
    print("\n=== Additional dialogue ===")
    result["messages"].append(HumanMessage(content="Please also explain Python decorators."))
    result = app.invoke(result, config)
    # Simulate a low-quality response
    result["quality_score"] = 0.4  # force the quality down
    result["should_rollback"] = True
    result = app.invoke(result, config)
    print(f"Final message count: {len(result['messages'])}")
    print(f"Final state history entries: {len(result.get('state_history', []))}")
    # Write the log files
    print("\n=== Log export ===")
    output_files = agent.export_state_history()
    # Show the generated summary report
    with open(output_files["summary_file"], "r", encoding="utf-8") as f:
        print("\n" + "=" * 50)
        print("Generated summary report:")
        print("=" * 50)
        print(f.read())


if __name__ == "__main__":
    main()