Created
December 24, 2024 18:04
-
-
Save danksky/7f1d4928a121a7303e58ac6c45e5fcf7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import bs4 | |
| import getpass | |
| import os | |
| from langchain import hub | |
| from langchain_community.document_loaders import WebBaseLoader | |
| from langchain_core.documents import Document | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from typing_extensions import List, TypedDict | |
| from langchain_core.vectorstores import InMemoryVectorStore | |
| from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
| from langgraph.graph import END, MessagesState, StateGraph | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_core.tools import tool | |
| from langchain.callbacks import LangChainTracer | |
| from langsmith import Client | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_mistralai import MistralAIEmbeddings | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
| from langchain_google_vertexai import VertexAIEmbeddings | |
| from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline | |
| from langchain_community.llms.huggingface_hub import HuggingFaceHub | |
# The chat model and the embeddings used by this script are both OpenAI
# clients (ChatOpenAI / OpenAIEmbeddings), which authenticate with
# OPENAI_API_KEY. Ask for that key when it is missing, rather than for a
# Mistral key that no active code path reads.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")
# Initialize components
# Initialize LangChainTracer and set up callbacks
tracer_project_name = "rag-with-history"
api_url = "https://api.smith.langchain.com"
# The LangSmith credential is taken from the environment so that no secret is
# stored in the source file. Both spellings of the variable are accepted; when
# neither is set the value is None and the client is left to apply its own
# default handling.
api_key = os.environ.get("LANGSMITH_API_KEY") or os.environ.get("LANGCHAIN_API_KEY")
callbacks = [
    LangChainTracer(
        project_name=tracer_project_name,
        client=Client(api_url=api_url, api_key=api_key),
    )
]
# Chat model shared by the graph nodes; every call is reported through the
# tracer callback configured above.
model = ChatOpenAI(model="gpt-4o-mini-2024-07-18", callbacks=callbacks)
| # print("Setting model from LLM...") | |
| # model = ChatHuggingFace(llm=llm) | |
| # Awaiting access approval from Meta... (https://huggingface.co/settings/gated-repos) (https://python.langchain.com/docs/integrations/chat/huggingface/) | |
| # llm = HuggingFaceHub( | |
| # repo_id="meta-llama/Llama-3.3-70B-Instruct", | |
| # task="text-generation", | |
| # model_kwargs={ | |
| # "max_new_tokens": 512, | |
| # "top_k": 30, | |
| # "temperature": 0.1, | |
| # "repetition_penalty": 1.03, | |
| # }, | |
| # huggingfacehub_api_token="REDACTED", | |
| # ) | |
| # model = ChatHuggingFace(llm=llm) | |
# Embedding backend used to index and query the document chunks.
# Alternatives that were tried and are currently disabled:
#   MistralAIEmbeddings(model="mistral-embed")
#   HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
#   OpenAIEmbeddings(model="text-embedding-3-small")
#   VertexAIEmbeddings(model="text-embedding-004")
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")

# Vector store that the retrieval tool searches; it embeds with `embeddings`.
vector_store = InMemoryVectorStore(embedding=embeddings)
# Load and process documents
print("Loading external website...")

# Restrict HTML parsing to elements that carry the "box" CSS class.
bs4_strainer = bs4.SoupStrainer(class_="box")
loader = WebBaseLoader(
    web_paths=(
        "https://support.irembo.gov.rw/en/support/solutions/articles/47001222259-frequently-asked-questions-about-driving-licenses",
    ),
    bs_kwargs={"parse_only": bs4_strainer},
)
docs = loader.load()

print("Splitting external website text...")

# Chunks of up to 1000 characters, with 200 characters of overlap between
# consecutive chunks.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
)
all_splits = text_splitter.split_documents(docs)

# Index chunks (the value returned by add_documents is not used).
_ = vector_store.add_documents(documents=all_splits)
@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query."""
    # Fetch the two chunks most similar to the query from the module-level
    # vector store.
    matches = vector_store.similarity_search(query, k=2)

    # One "Source / Content" entry per hit; entries are separated by a blank
    # line in the text handed back to the model.
    entries = [
        f" Source: {hit.metadata}\n Content: {hit.page_content}"
        for hit in matches
    ]

    # First element is the textual content, second is the artifact (the raw
    # documents), matching response_format="content_and_artifact".
    return "\n\n".join(entries), matches
# Step 1: Generate an AIMessage that may include a tool-call to be sent.
def query_or_respond(state: MessagesState):
    """Run the model, with the retrieve tool bound, over the conversation.

    Args:
        state: Graph state; only its "messages" entry is read.

    Returns:
        A state update whose "messages" list holds the single model reply.
        The reply may carry a tool call for `retrieve` or be a direct answer.
    """
    tool_aware_model = model.bind_tools([retrieve])
    ai_message = tool_aware_model.invoke(state["messages"])
    # MessagesState appends returned messages to the existing history
    # instead of replacing it, so only the new message is returned.
    return {"messages": [ai_message]}
# Step 2: Execute the retrieval.
# Graph node that runs the `retrieve` tool for tool calls produced in step 1.
tools = ToolNode(tools=[retrieve])
# Step 3: Generate a response using the retrieved content.
def generate(state: MessagesState):
    """Answer the user from the most recent retrieval results.

    Args:
        state: Graph state; only its "messages" entry is read.

    Returns:
        A state update whose "messages" list holds the model's answer.
    """
    history = state["messages"]

    # Walk backwards from the end of the history to find where the trailing,
    # uninterrupted run of tool messages begins.
    first_tool_index = len(history)
    while first_tool_index > 0 and history[first_tool_index - 1].type == "tool":
        first_tool_index -= 1
    tool_messages = history[first_tool_index:]

    # Retrieved text, one tool result per paragraph.
    docs_content = "\n\n".join(tool_message.content for tool_message in tool_messages)

    instructions = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Use three sentences maximum and keep the "
        "answer concise."
    )
    system_message_content = f"{instructions}\n\n{docs_content}"

    # Keep human and system messages plus AI messages that are plain replies;
    # AI messages that only request tools, and the tool results themselves,
    # are left out of the prompt.
    conversation_messages = []
    for past_message in history:
        if past_message.type == "ai":
            if past_message.tool_calls:
                continue
        elif past_message.type not in ("human", "system"):
            continue
        conversation_messages.append(past_message)

    prompt = [SystemMessage(system_message_content), *conversation_messages]

    answer = model.invoke(prompt)
    return {"messages": [answer]}
# Build graph
graph_builder = StateGraph(MessagesState)

# Register the three nodes. The wiring below addresses them by the names
# "query_or_respond", "tools" and "generate".
for _node in (query_or_respond, tools, generate):
    graph_builder.add_node(_node)

graph_builder.set_entry_point("query_or_respond")

# After the first model turn, tools_condition selects the next hop: either the
# run ends (END) or control moves on to the "tools" node.
graph_builder.add_conditional_edges(
    "query_or_respond",
    tools_condition,
    {END: END, "tools": "tools"},
)

# A tool run is always followed by answer generation, which ends the run.
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)

# Compile graph with memory persistence
memory = MemorySaver()
print("Compiling graph...")
graph = graph_builder.compile(checkpointer=memory)
# Use the Graph for Conversations
if __name__ == "__main__":

    def _artifact_field(artifact, field, default):
        """Read `field` from one artifact entry.

        Entries that expose a callable ``get`` (dict-style) are read through
        it; anything else is read by attribute. The retrieve tool returns
        document objects as its artifact, and those expose ``metadata`` and
        ``page_content`` as attributes, so calling ``.get`` on them
        unconditionally would raise AttributeError.
        """
        getter = getattr(artifact, "get", None)
        if callable(getter):
            return getter(field, default)
        return getattr(artifact, field, default)

    # Initial conversation setup
    initial_message = "How much do I pay for an application for a provisional driving license, and which country is this document about?"
    config = {"configurable": {"thread_id": "conversation_123"}}

    print("Starting graph invocation...")
    # The return value is not kept: the complete transcript is read back from
    # the checkpointed state for this thread just below.
    graph.invoke(
        {"messages": [{"role": "user", "content": initial_message}]},
        config=config,
    )

    # Inspect conversation history
    chat_history = graph.get_state(config).values["messages"]
    print("\n\n ===== \n\n")

    # Translation table that deletes line feeds (10) and carriage returns (13)
    # so each content preview prints on a single line.
    drop_line_breaks = {10: None, 13: None}

    for message in chat_history:
        if isinstance(message, HumanMessage):
            print(f"Human: {message.content}")
        elif isinstance(message, AIMessage):
            print(f"AI: {message.content}")
        elif isinstance(message, SystemMessage):
            print(f"System: {message.content}")
        elif isinstance(message, ToolMessage):
            print(f"Tool: {message.content}")
        elif hasattr(message, "artifact"):
            # Fallback for objects that are not ToolMessage instances but
            # still carry an artifact; ToolMessage is handled just above.
            artifacts = message.artifact
            print("Tool Message: ")
            if artifacts:
                for artifact in artifacts:
                    metadata = _artifact_field(artifact, "metadata", {})
                    page_content = _artifact_field(artifact, "page_content", "")
                    preview = page_content[:200].strip().translate(drop_line_breaks)
                    print(f" - Metadata: {metadata}")
                    print(f" - Content: {preview}...")
        else:
            print(f"Unknown message type with details: {vars(message)}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment