workflows_and_memory.ipynb
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "authorship_tag": "ABX9TyOxLYYgUeQ/vTMTWJG7vN6Z", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/AstraBert/d24f41d5328dfe4410f235c582b5c275/workflows_and_memory.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Memory-Augmented Workflows for Customer Support\n", | |
| "\n", | |
| "LlamaIndex offers [memory](https://developers.llamaindex.ai/python/framework/module_guides/deploying/agents/memory/#remote-memory) features that allow developers to create **context-aware** AI-powered applications.\n", | |
| "\n", | |
| "In this notebook, we will explore how we can connect an external memory to our vector database." | |
| ], | |
| "metadata": { | |
| "id": "6-1jsT9foTEa" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## 1 . Install needed dependencies" | |
| ], | |
| "metadata": { | |
| "id": "CtwYERbXqpFR" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "bOHBihMCm0KW", | |
| "outputId": "45fac612-5d3b-4a58-eac6-719c39ca95dc" | |
| }, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.9/11.9 MB\u001b[0m \u001b[31m93.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.1/56.1 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m337.3/337.3 kB\u001b[0m \u001b[31m25.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m61.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.9/50.9 kB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m144.4/144.4 kB\u001b[0m \u001b[31m11.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", | |
| "ipython 7.34.0 requires jedi>=0.16, which is not installed.\u001b[0m\u001b[31m\n", | |
| "\u001b[0m" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "! pip install -q llama-index-core llama-index-vector-stores-qdrant llama-index-embeddings-openai llama-index-llms-openai llama-index-workflows" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## 2. Define the memory\n", | |
| "\n", | |
| "In order to create the memory, we will need:\n", | |
| "\n", | |
| "- An embedding model\n", | |
| "- A vector database to store long-term memory" | |
| ], | |
| "metadata": { | |
| "id": "qTJ0Xk8RrGbr" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Vector Store and Embedding Model" | |
| ], | |
| "metadata": { | |
| "id": "rIc45fIkt7_7" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from google.colab import userdata\n", | |
| "\n", | |
| "QDRANT_API_KEY = userdata.get('QDRANT_API_KEY')\n", | |
| "QDRANT_URL = userdata.get('QDRANT_URL')\n", | |
| "QDRANT_PORT = 6333" | |
| ], | |
| "metadata": { | |
| "id": "uzNygX9__VS-" | |
| }, | |
| "execution_count": 1, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import os\n", | |
| "\n", | |
| "os.environ[\"OPENAI_API_KEY\"] = userdata.get('OPENAI_API_KEY')" | |
| ], | |
| "metadata": { | |
| "id": "QWAbOrFZtdEn" | |
| }, | |
| "execution_count": 2, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from qdrant_client import AsyncQdrantClient, QdrantClient\n", | |
| "from qdrant_client.models import VectorParams, Distance\n", | |
| "from llama_index.vector_stores.qdrant import QdrantVectorStore\n", | |
| "from llama_index.embeddings.openai import OpenAIEmbedding\n", | |
| "from llama_index.core.schema import TextNode\n", | |
| "from llama_index.core import Settings\n", | |
| "\n", | |
| "\n", | |
| "embedding_model = OpenAIEmbedding(model=\"text-embedding-3-small\", dimensions=768)\n", | |
| "Settings.embed_model = embedding_model\n", | |
| "aqd_client = AsyncQdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, port=QDRANT_PORT)\n", | |
| "qd_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, port=QDRANT_PORT)\n", | |
| "qd_client.create_collection(\"workflow_memory\", vectors_config={\"dense-text\": VectorParams(size=768, distance=Distance.COSINE)})\n", | |
| "vector_store = QdrantVectorStore(aclient=aqd_client, client=qd_client, collection_name=\"workflow_memory\", dense_vector_name=\"dense-text\")" | |
| ], | |
| "metadata": { | |
| "id": "cmdBPQRHt_Lh" | |
| }, | |
| "execution_count": 3, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Memory" | |
| ], | |
| "metadata": { | |
| "id": "q4uD-W7D5Z-q" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from llama_index.core.memory import VectorMemory\n", | |
| "\n", | |
| "memory = VectorMemory.from_defaults(\n", | |
| " vector_store=vector_store,\n", | |
| " embed_model=embedding_model,\n", | |
| " retriever_kwargs={\"similarity_top_k\": 10},\n", | |
| ")\n", | |
| "\n", | |
| "def get_memory(*args, **kwargs) -> VectorMemory:\n", | |
| " return memory" | |
| ], | |
| "metadata": { | |
| "id": "GBke-7ND5fDO" | |
| }, | |
| "execution_count": 4, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Let's try to insert and retrieve some information, to see how our memory behaves" | |
| ], | |
| "metadata": { | |
| "id": "eywb3WIQ6EsU" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from llama_index.core.llms import ChatMessage\n", | |
| "\n", | |
| "messages = [\n", | |
| " ChatMessage(role=\"system\", content=\"You are a helpful AI assistant that provides concise, accurate answers.\"),\n", | |
| " ChatMessage(role=\"user\", content=\"Hey, can you explain what vector databases are?\"),\n", | |
| " ChatMessage(role=\"assistant\", content=\"Sure! A vector database stores embeddings—numeric representations of data—so that semantic search and similarity queries can be done efficiently.\"),\n", | |
| " ChatMessage(role=\"user\", content=\"Interesting. How is that different from a regular SQL database?\"),\n", | |
| " ChatMessage(role=\"assistant\", content=\"SQL databases are optimized for structured data and exact matches, while vector databases excel at approximate nearest neighbor searches on high-dimensional data.\"),\n", | |
| " ChatMessage(role=\"user\", content=\"Got it. Can you give me an example use case?\"),\n", | |
| " ChatMessage(role=\"assistant\", content=\"One common example is a recommendation system—using embeddings of products or users to find similar items.\"),\n", | |
| " ChatMessage(role=\"user\", content=\"Cool. Do they scale well with millions of records?\"),\n", | |
| " ChatMessage(role=\"assistant\", content=\"Yes, most modern vector databases like Qdrant are designed to scale to billions of vectors efficiently.\"),\n", | |
| " ChatMessage(role=\"user\", content=\"Thanks, that clears things up!\")\n", | |
| "]\n", | |
| "\n", | |
| "await memory.aput_messages(messages)" | |
| ], | |
| "metadata": { | |
| "id": "kUllxe_i6LTG" | |
| }, | |
| "execution_count": 9, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "answer = memory.get(\"What is the difference between a vector database and a regular SQL database?\")" | |
| ], | |
| "metadata": { | |
| "id": "qAPRXOa-6hrA" | |
| }, | |
| "execution_count": 10, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "answer[1]" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "r-xywiFnBWyg", | |
| "outputId": "7058c882-68f5-48b9-e28b-3e1cb5a0dc75" | |
| }, | |
| "execution_count": 11, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "ChatMessage(role=<MessageRole.ASSISTANT: 'assistant'>, additional_kwargs={}, blocks=[TextBlock(block_type='text', text='SQL databases are optimized for structured data and exact matches, while vector databases excel at approximate nearest neighbor searches on high-dimensional data.')])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 11 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Ok, let's now refresh the memory so that these experiment data do not persist within our actual workflow:" | |
| ], | |
| "metadata": { | |
| "id": "R08VKz0rBibX" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "memory.reset()" | |
| ], | |
| "metadata": { | |
| "id": "_VFEgvYCBpNJ" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Design the workflow\n", | |
| "\n", | |
| "In order for our customer management workflow to work correctly, we will need to:\n", | |
| "\n", | |
| "- Create a new session for each new user, and keep track of it in a session register\n", | |
| "- Activate the memory if we have records from a previous session\n", | |
| "- Retrieve relevant information about the user's previous inquiries, based on past interactions\n", | |
| "- Fetch details about the user's request from an API and reply based on the response\n", | |
| "\n", | |
| "In the following sections, we will create all the needed components, divided within:\n", | |
| "\n", | |
| "- Resources\n", | |
| "- Events\n", | |
| "- Workflow State\n", | |
| "- Workflow Itself" | |
| ], | |
| "metadata": { | |
| "id": "lFaG-b5MBycs" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Resources\n", | |
| "\n", | |
| "Resources are external dependencies that we can make available to the steps of the workflow: apart from the memory we defined above, we will need an LLM and a 'mock' API client to reply to the LLM's requests." | |
| ], | |
| "metadata": { | |
| "id": "J42MvNywC8XQ" | |
| } | |
| }, | |
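| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "A resource is injected by annotating a step parameter with `Resource(factory)`: the workflow runtime calls the factory and passes its return value into the step. Here is a minimal sketch of the pattern (the `Counter` class and `get_counter` factory are hypothetical, and only illustrate the injection):" | |
| ], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Minimal sketch of resource injection (hypothetical Counter resource)\n", | |
| "from typing import Annotated\n", | |
| "\n", | |
| "from workflows import Workflow, step\n", | |
| "from workflows.events import StartEvent, StopEvent\n", | |
| "from workflows.resource import Resource\n", | |
| "\n", | |
| "class Counter:\n", | |
| "    def __init__(self):\n", | |
| "        self.count = 0\n", | |
| "\n", | |
| "def get_counter(*args, **kwargs) -> Counter:\n", | |
| "    # Called by the workflow runtime to provide the resource\n", | |
| "    return Counter()\n", | |
| "\n", | |
| "class CounterWorkflow(Workflow):\n", | |
| "    @step\n", | |
| "    async def count_up(self, ev: StartEvent, counter: Annotated[Counter, Resource(get_counter)]) -> StopEvent:\n", | |
| "        counter.count += 1\n", | |
| "        return StopEvent(result=f\"Counter is now {counter.count}\")" | |
| ], | |
| "metadata": {}, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |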
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### LLM" | |
| ], | |
| "metadata": { | |
| "id": "5HygdEbfsUff" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from llama_index.llms.openai import OpenAI\n", | |
| "\n", | |
| "llm = OpenAI(model=\"gpt-4.1-mini\")" | |
| ], | |
| "metadata": { | |
| "id": "vvyIFykGC7Ui" | |
| }, | |
| "execution_count": 5, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from pydantic import BaseModel, Field\n", | |
| "from llama_index.core.llms.structured_llm import StructuredLLM\n", | |
| "from typing import Literal\n", | |
| "\n", | |
| "class ContextRelevance(BaseModel):\n", | |
| " relevance: int = Field(ge=0,le=100, description=\"Relevance of the retrieved information, measured out of 100.\")\n", | |
| " reasons: str = Field(description=\"Brief explanation (maximum 50 words) of the evaluation\")\n", | |
| "\n", | |
| "class MemoryInformation(BaseModel):\n", | |
| " information_list: list[str] = Field(description=\"A list of information to store into memory\")\n", | |
| "\n", | |
| "class OrderRequestBody(BaseModel):\n", | |
| " order_id: str = Field(description=\"ID of the order\")\n", | |
| "\n", | |
| "class PaymentRequestBody(BaseModel):\n", | |
| " payment_id: str = Field(description=\"ID of the payment\")\n", | |
| "\n", | |
| "class ComplaintRequestBody(BaseModel):\n", | |
| " complaint_id: str = Field(description=\"ID of the complaint\")\n", | |
| "\n", | |
| "class ApiRequest(BaseModel):\n", | |
| " endpoint: Literal[\"/orders\", \"/payments\", \"/complaints\"] = Field(description=\"The endpoint to call\")\n", | |
| " request_body: OrderRequestBody | PaymentRequestBody | ComplaintRequestBody = Field(description=\"The body of the request\")\n", | |
| "\n", | |
| "llm_memory = llm.as_structured_llm(MemoryInformation)\n", | |
| "llm_api_call = llm.as_structured_llm(ApiRequest)\n", | |
| "llm_relevance = llm.as_structured_llm(ContextRelevance)\n", | |
| "\n", | |
| "def get_llm_memory(*args, **kwargs) -> StructuredLLM:\n", | |
| " return llm_memory\n", | |
| "\n", | |
| "def get_llm_api_call(*args, **kwargs) -> StructuredLLM:\n", | |
| " return llm_api_call\n", | |
| "\n", | |
| "def get_llm_relevance(*args, **kwargs) -> StructuredLLM:\n", | |
| " return llm_relevance" | |
| ], | |
| "metadata": { | |
| "id": "RZe1qC3fEGw3" | |
| }, | |
| "execution_count": 6, | |
| "outputs": [] | |
| }, | |
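| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "As a quick sanity check (this runs one real OpenAI call), a structured LLM replies with JSON conforming to its schema, which we can validate back into the Pydantic model. The order ID below is hypothetical, only there for illustration:" | |
| ], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from llama_index.core.llms import ChatMessage\n", | |
| "\n", | |
| "# The response content is JSON that parses back into ApiRequest\n", | |
| "res = await llm_api_call.achat([ChatMessage(content=\"What is the status of my order with ID 'abc-123'?\")])\n", | |
| "print(ApiRequest.model_validate_json(res.message.content))" | |
| ], | |
| "metadata": {}, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |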
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### API Client\n", | |
| "\n", | |
| "> _The following implementation is a mock API client, make sure you replace it with a real one!_" | |
| ], | |
| "metadata": { | |
| "id": "_4QBgGzKEE5x" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from enum import Enum\n", | |
| "import random\n", | |
| "\n", | |
| "class OrderStatus(Enum):\n", | |
| " SHIPPED = \"shipped\"\n", | |
| " DELIVERED = \"delivered\"\n", | |
| "\n", | |
| "class PaymentStatus(Enum):\n", | |
| " PAID = \"paid\"\n", | |
| " UNPAID = \"unpaid\"\n", | |
| "\n", | |
| "class ComplaintStatus(Enum):\n", | |
| " RESOLVED = \"resolved\"\n", | |
| " UNRESOLVED = \"unresolved\"\n", | |
| "\n", | |
| "class CustomersApiClient:\n", | |
| " def __init__(self):\n", | |
| " self.endpoints_to_methods = {\n", | |
| " \"/orders\": self.get_order_details,\n", | |
| " \"/payments\": self.get_payment_details,\n", | |
| " \"/complaints\": self.get_complaint_details\n", | |
| " }\n", | |
| " def execute_query(self, query: ApiRequest):\n", | |
| " return self.endpoints_to_methods[query.endpoint](query.request_body)\n", | |
| "\n", | |
| " def get_order_details(self, order: OrderRequestBody):\n", | |
| " return {\n", | |
| " \"status\": OrderStatus.SHIPPED if random.randint(0,1) == 1 else OrderStatus.DELIVERED\n", | |
| " }\n", | |
| " def get_payment_details(self, payment: PaymentRequestBody):\n", | |
| " return {\n", | |
| " \"status\": PaymentStatus.PAID if random.randint(0,1) == 1 else PaymentStatus.UNPAID\n", | |
| " }\n", | |
| " def get_complaint_details(self, complaint: ComplaintRequestBody):\n", | |
| " return {\n", | |
| " \"status\": ComplaintStatus.RESOLVED if random.randint(0,1) == 1 else ComplaintStatus.UNRESOLVED\n", | |
| " }\n", | |
| "\n", | |
| "def get_api_client(*args, **kwargs) -> CustomersApiClient:\n", | |
| " return CustomersApiClient()" | |
| ], | |
| "metadata": { | |
| "id": "FcDatrbdGGg1" | |
| }, | |
| "execution_count": 7, | |
| "outputs": [] | |
| }, | |
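| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "For reference, here is one possible shape for a real client, keeping the same `execute_query` interface so it can be swapped in as a resource. The base URL, bearer authentication, and JSON response format are assumptions: adapt them to your actual customer support API." | |
| ], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Hedged sketch of an HTTP-backed client (assumed base URL, auth, and response shape)\n", | |
| "import requests\n", | |
| "\n", | |
| "class HttpCustomersApiClient:\n", | |
| "    def __init__(self, base_url: str, api_token: str):\n", | |
| "        self.base_url = base_url.rstrip(\"/\")\n", | |
| "        self.headers = {\"Authorization\": f\"Bearer {api_token}\"}\n", | |
| "\n", | |
| "    def execute_query(self, query: ApiRequest):\n", | |
| "        # POST the structured request body to the endpoint chosen by the LLM\n", | |
| "        response = requests.post(\n", | |
| "            f\"{self.base_url}{query.endpoint}\",\n", | |
| "            json=query.request_body.model_dump(),\n", | |
| "            headers=self.headers,\n", | |
| "            timeout=30,\n", | |
| "        )\n", | |
| "        response.raise_for_status()\n", | |
| "        return response.json()" | |
| ], | |
| "metadata": {}, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |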
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### User Sessions Manager" | |
| ], | |
| "metadata": { | |
| "id": "oerl66ShG0D3" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import uuid\n", | |
| "\n", | |
| "class SessionsManager:\n", | |
| " def __init__(self):\n", | |
| " self.users = {\n", | |
| " \"user1\": \"uuid-3-2-1\",\n", | |
| " \"user2\": \"uuid-4-5-6\",\n", | |
| " \"user3\": \"uuid-7-8-9\"\n", | |
| " }\n", | |
| " self.sessions = {}\n", | |
| " def get_user_id(self, username: str) -> str:\n", | |
| " return self.users.get(username)\n", | |
| " def get_user_session(self, username: str) -> str:\n", | |
| " session = self.sessions.get(self.get_user_id(username), \"\")\n", | |
| " if not session:\n", | |
| " session = self.create_session(username)\n", | |
| " return session\n", | |
| " def create_session(self, username: str) -> str:\n", | |
| " self.sessions[self.get_user_id(username)] = str(uuid.uuid4())\n", | |
| " return self.sessions[self.get_user_id(username)]\n", | |
| "\n", | |
| "def get_sessions_manager() -> SessionsManager:\n", | |
| " return SessionsManager()" | |
| ], | |
| "metadata": { | |
| "id": "ZTUwqsqOGzBE" | |
| }, | |
| "execution_count": 8, | |
| "outputs": [] | |
| }, | |
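| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "A quick illustration of the session behavior: the first call for a user creates a session, and subsequent calls reuse it." | |
| ], | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Repeated calls for the same user return the same session ID\n", | |
| "sm = SessionsManager()\n", | |
| "first = sm.get_user_session(\"user1\")\n", | |
| "second = sm.get_user_session(\"user1\")\n", | |
| "assert first == second\n", | |
| "print(first)" | |
| ], | |
| "metadata": {}, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |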
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Events" | |
| ], | |
| "metadata": { | |
| "id": "uENc_X2XJ4C9" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from workflows.events import StartEvent, StopEvent, Event\n", | |
| "\n", | |
| "class InputEvent(StartEvent):\n", | |
| " username: str\n", | |
| " session_id: str\n", | |
| " message: str\n", | |
| "\n", | |
| "class MemoryEvent(Event):\n", | |
| " retrieved: list[str]\n", | |
| " put: list[str]\n", | |
| "\n", | |
| "class ProgressEvent(Event):\n", | |
| " progress: str\n", | |
| "\n", | |
| "class OutputEvent(StopEvent):\n", | |
| " message: str" | |
| ], | |
| "metadata": { | |
| "id": "KtYQfxyWGxEP" | |
| }, | |
| "execution_count": 9, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Workflow State" | |
| ], | |
| "metadata": { | |
| "id": "o1POCncgAjnB" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "class WorkflowState(BaseModel):\n", | |
| " username: str = \"\"\n", | |
| " session_id: str = \"\"\n", | |
| " user_request: str = \"\"" | |
| ], | |
| "metadata": { | |
| "id": "cI90Sd_PAmSh" | |
| }, | |
| "execution_count": 15, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Workflow" | |
| ], | |
| "metadata": { | |
| "id": "BwsYIvuVBqoF" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from workflows import Workflow, Context, step\n", | |
| "from llama_index.core.llms import ChatMessage\n", | |
| "from workflows.resource import Resource\n", | |
| "from typing import Annotated\n", | |
| "\n", | |
| "def info_to_chatmessage(info: list[str], session_id: str):\n", | |
| " messages = []\n", | |
| " for item in info:\n", | |
| " messages.append(ChatMessage(role=\"user\", content=item, additional_kwargs={\"session_id\": session_id}))\n", | |
| " return messages\n", | |
| "\n", | |
| "class CustomerSupportWorkflow(Workflow):\n", | |
| " @step\n", | |
| " async def handle_input(self, ev: InputEvent, ctx: Context[WorkflowState], llm_memory: Annotated[StructuredLLM, Resource(get_llm_memory)], llm_relevance: Annotated[StructuredLLM, Resource(get_llm_relevance)], memory: Annotated[VectorMemory, Resource(get_memory)]) -> MemoryEvent:\n", | |
| " async with ctx.store.edit_state() as state:\n", | |
| " state.username = ev.username\n", | |
| " state.session_id = ev.session_id\n", | |
| " state.user_request = ev.message\n", | |
| " result = memory.get(ev.message)\n", | |
| " retrieved = []\n", | |
| " put = []\n", | |
| " for res in result:\n", | |
| " if res.additional_kwargs.get(\"session_id\") == ev.session_id:\n", | |
| " retrieved.append(res.content)\n", | |
| " if len(retrieved) > 0:\n", | |
| " ctx.write_event_to_stream(\n", | |
| " ProgressEvent(progress=\"Retrieved content from memory:\\n\" + \"\\n- \".join(retrieved))\n", | |
| " )\n", | |
| " response = await llm_relevance.achat([ChatMessage(content=\"Retrieved content: \" + \"\\n\\n---\\n\\n\".join(retrieved), role=\"assistant\"), ChatMessage(content=f\"Please evaluate the retrieved content for relevance in the context of this question: '{ev.message}'\")])\n", | |
| " rel = ContextRelevance.model_validate_json(response.message.content)\n", | |
| " score = rel.relevance\n", | |
| " else:\n", | |
| " score = 0\n", | |
| " ctx.write_event_to_stream(\n", | |
| " ProgressEvent(progress=f\"Relevance of the retrieved content: {score}%\")\n", | |
| " )\n", | |
| " if score > 0:\n", | |
| " mem_res = await llm_memory.achat([ChatMessage(content=\"Retrieved content: \" + \"\\n\\n---\\n\\n\".join(retrieved), role=\"assistant\"), ChatMessage(content=f\"Based on the retrieved context, extract a list of information to store in memory from this message: '{ev.message}'\")])\n", | |
| " mem = MemoryInformation.model_validate_json(mem_res.message.content)\n", | |
| " await memory.aput_messages(info_to_chatmessage(mem.information_list, ev.session_id))\n", | |
| " put = mem.information_list\n", | |
| " else:\n", | |
| " mem_res = await llm_memory.achat([ChatMessage(content=f\"Extract a list of information to store in memory from this message: '{ev.message}'\")])\n", | |
| " mem = MemoryInformation.model_validate_json(mem_res.message.content)\n", | |
| " await memory.aput_messages(info_to_chatmessage(mem.information_list, ev.session_id))\n", | |
| " put = mem.information_list\n", | |
| " if len(put) > 0:\n", | |
| " ctx.write_event_to_stream(\n", | |
| " ProgressEvent(progress=\"Stored the following content in memory:\\n\" + \"\\n- \".join(put))\n", | |
| " )\n", | |
| " else:\n", | |
| " ctx.write_event_to_stream(\n", | |
| " ProgressEvent(progress=\"No content was stored in memory\")\n", | |
| " )\n", | |
| " return MemoryEvent(retrieved=retrieved, put=put)\n", | |
| " @step\n", | |
| " async def send_api_request(self, ev: MemoryEvent, ctx: Context[WorkflowState], llm_api_call: Annotated[StructuredLLM, Resource(get_llm_api_call)], api_client: Annotated[CustomersApiClient, Resource(get_api_client)]) -> OutputEvent:\n", | |
| " state = await ctx.store.get_state()\n", | |
| " if ev.retrieved:\n", | |
| " res = await llm_api_call.achat([ChatMessage(content=\"Retrieved content: \" + \"\\n\\n---\\n\\n\".join(ev.retrieved), role=\"assistant\"), ChatMessage(content=f\"Based on the retrieved context, please choose an endpoint for the API request (among the available '/order', '/payments', '/complaints') and generate the request body that best suit the user's request: '{state.user_request}'\")])\n", | |
| " else:\n", | |
| " res = await llm_api_call.achat([ChatMessage(content=f\"Based on the retrieved context, please choose an endpoint for the API request (among the available '/order', '/payments', '/complaints') and generate the request body that best suit the user's request: '{state.user_request}'\")])\n", | |
| " request = ApiRequest.model_validate_json(res.message.content)\n", | |
| " ctx.write_event_to_stream(\n", | |
| " ProgressEvent(progress=f\"Sending API request to https://afaboulousapp.com{request.endpoint}\")\n", | |
| " )\n", | |
| " result = api_client.execute_query(request)\n", | |
| " if request.endpoint == \"/orders\":\n", | |
| " return OutputEvent(message=f\"Hello {state.username}\\nThe status of your order is: {result['status'].value}\")\n", | |
| " elif request.endpoint == \"/payments\":\n", | |
| " return OutputEvent(message=f\"Hello {state.username}\\nThe status of your payment is: {result['status'].value}\")\n", | |
| " else:\n", | |
| " return OutputEvent(message=f\"Hello {state.username}\\nThe status of your complaint is: {result['status'].value}\")" | |
| ], | |
| "metadata": { | |
| "id": "CtsmpULlBqAX" | |
| }, | |
| "execution_count": 28, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Execute the Workflow\n", | |
| "\n", | |
| "Here we create a simple text-based interface to run the workflow" | |
| ], | |
| "metadata": { | |
| "id": "aSdyJJrEJqjh" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "session_manager = get_sessions_manager()\n", | |
| "\n", | |
| "async def run_workflow_interactive(username: str, message: str):\n", | |
| " wf = CustomerSupportWorkflow(timeout=300)\n", | |
| " session_id = session_manager.get_user_session(username=username)\n", | |
| " handler = wf.run(start_event=InputEvent(username=username, session_id=session_id, message=message))\n", | |
| " async for event in handler.stream_events():\n", | |
| " if isinstance(event, ProgressEvent):\n", | |
| " print(event.progress)\n", | |
| " result = await handler\n", | |
| " print(result.message)" | |
| ], | |
| "metadata": { | |
| "id": "D-6SOBx_KEUr" | |
| }, | |
| "execution_count": 29, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Let's run the workflow as `user1` for the first time, and ask about the status of our order with ID `order-hello-world`." | |
| ], | |
| "metadata": { | |
| "id": "2BDhHRtTK3Qj" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from rich.prompt import Prompt\n", | |
| "\n", | |
| "username = Prompt.ask(\"Enter your username\")\n", | |
| "message = Prompt.ask(\"Enter your message\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 104 | |
| }, | |
| "id": "iTx1UH_vLSZu", | |
| "outputId": "94854c5f-f7a9-41a7-93bf-65023aa2e856" | |
| }, | |
| "execution_count": 20, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "Enter your username: " | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Enter your username: </pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "user1\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "Enter your message: " | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Enter your message: </pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "I would like to know the status of my order, with ID 'order-hello-world'\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "await run_workflow_interactive(username, message)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "nuDcwjE0NHgh", | |
| "outputId": "4937a430-c6cb-4af0-907c-24c9288bf42f" | |
| }, | |
| "execution_count": 30, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Relevance of the retrieved content: 0%\n", | |
| "Stored the following content in memory:\n", | |
| "User wants to know the status of their order with ID 'order-hello-world'\n", | |
| "Sending API request to https://afaboulousapp.com/orders\n", | |
| "Hello user1\n", | |
| "The status of your order is: shipped\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Perfect! Now let's re-run the workflow as `user1` and ask for the status of our order." | |
| ], | |
| "metadata": { | |
| "id": "Mxw1vzdPLQVq" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from rich.prompt import Prompt\n", | |
| "\n", | |
| "username = Prompt.ask(\"Enter your username\")\n", | |
| "message = Prompt.ask(\"Enter your message\")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 104 | |
| }, | |
| "id": "byUKQrkIO3xm", | |
| "outputId": "114003b4-8516-4fcc-ad52-5edf782dca9b" | |
| }, | |
| "execution_count": 31, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "Enter your username: " | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Enter your username: </pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "user1\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "Enter your message: " | |
| ], | |
| "text/html": [ | |
| "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Enter your message: </pre>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "What is the status of my order?\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "await run_workflow_interactive(username, message)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "fJU3183QO6ak", | |
| "outputId": "137256f1-7be3-4abc-b27a-9270adfa80a5" | |
| }, | |
| "execution_count": 32, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Retrieved content from memory:\n", | |
| "User wants to know the status of their order with ID 'order-hello-world'\n", | |
| "Relevance of the retrieved content: 100%\n", | |
| "Stored the following content in memory:\n", | |
| "User asked about the status of their order.\n", | |
| "Sending API request to https://afaboulousapp.com/orders\n", | |
| "Hello user1\n", | |
| "The status of your order is: delivered\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "As you can see, the customer support workflow 'remembered' our order ID and retrieved the status for us!" | |
| ], | |
| "metadata": { | |
| "id": "9BeeGbkSPQ9n" | |
| } | |
| } | |
| ] | |
| } |