# Gist by @danksky, created December 24, 2024 18:04
import bs4
import getpass
import os
from langchain import hub
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from typing_extensions import List, TypedDict
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.tools import tool
from langchain.callbacks import LangChainTracer
from langsmith import Client
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_mistralai import MistralAIEmbeddings
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFacePipeline
from langchain_google_vertexai import VertexAIEmbeddings
from langchain_community.llms.huggingface_hub import HuggingFaceHub
if not os.environ.get("MISTRALAI_API_KEY"):
os.environ["MISTRALAI_API_KEY"] = getpass.getpass("Enter API key for MistralAI: ")
# Initialize components
# Initialize LangChainTracer and set up callbacks
tracer_project_name = "rag-with-history"
api_url = "https://api.smith.langchain.com"
api_key = "REDACTED"
callbacks = [
    LangChainTracer(
        project_name=tracer_project_name,
        client=Client(api_url=api_url, api_key=api_key),
    )
]
model = ChatOpenAI(model="gpt-4o-mini-2024-07-18", callbacks=callbacks)
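# The tracer callback attached above logs each model call to LangSmith under the
# "rag-with-history" project, so every graph run can be inspected there.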
# print("Setting model from LLM...")
# model = ChatHuggingFace(llm=llm)
# Awaiting access approval from Meta... (https://huggingface.co/settings/gated-repos) (https://python.langchain.com/docs/integrations/chat/huggingface/)
# llm = HuggingFaceHub(
#     repo_id="meta-llama/Llama-3.3-70B-Instruct",
#     task="text-generation",
#     model_kwargs={
#         "max_new_tokens": 512,
#         "top_k": 30,
#         "temperature": 0.1,
#         "repetition_penalty": 1.03,
#     },
#     huggingfacehub_api_token="REDACTED",
# )
# model = ChatHuggingFace(llm=llm)
# embeddings = MistralAIEmbeddings(model="mistral-embed")
# embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
# embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
# embeddings = VertexAIEmbeddings(model="text-embedding-004")
vector_store = InMemoryVectorStore(embedding=embeddings)
# Load and process documents
print("Loading external website...")
bs4_strainer = bs4.SoupStrainer(class_=("box"))
loader = WebBaseLoader(
    web_paths=("https://support.irembo.gov.rw/en/support/solutions/articles/47001222259-frequently-asked-questions-about-driving-licenses",),
    bs_kwargs={"parse_only": bs4_strainer},
)
docs = loader.load()
print("Splitting external website text...")
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
all_splits = text_splitter.split_documents(docs)
# Index chunks
_ = vector_store.add_documents(documents=all_splits)
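# At this point every chunk has been embedded with text-embedding-3-large and stored
# in the in-memory vector store, ready for similarity search at query time.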
@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query."""
    # print("----\nretrieve invoked", query)
    retrieved_docs = vector_store.similarity_search(query, k=2)
    serialized = "\n\n".join(
        (f" Source: {doc.metadata}\n" f" Content: {doc.page_content}")
        for doc in retrieved_docs
    )
    return serialized, retrieved_docs
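# With response_format="content_and_artifact" the tool returns a (content, artifact)
# tuple: the serialized string becomes the ToolMessage content shown to the model,
# while the raw Document objects are kept on ToolMessage.artifact for later inspection.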
# Step 1: Generate an AIMessage that may include a tool-call to be sent.
def query_or_respond(state: MessagesState):
"""Generate tool call for retrieval or respond."""
# print("----\nquery_or_respond invoked", state)
llm_with_tools = model.bind_tools([retrieve])
response = llm_with_tools.invoke(state["messages"])
# MessagesState appends messages to state instead of overwriting
return {"messages": [response]}
# Step 2: Execute the retrieval.
tools = ToolNode([retrieve])
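# ToolNode executes any tool calls found on the most recent AIMessage and appends the
# resulting ToolMessages to the message state.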
# Step 3: Generate a response using the retrieved content.
def generate(state: MessagesState):
"""Generate answer."""
# print("----\ngenerate invoked", state)
# Get generated ToolMessages
recent_tool_messages = []
for message in reversed(state["messages"]):
if message.type == "tool":
recent_tool_messages.append(message)
else:
break
tool_messages = recent_tool_messages[::-1]
# Format into prompt
docs_content = "\n\n".join(doc.content for doc in tool_messages)
system_message_content = (
"You are an assistant for question-answering tasks. "
"Use the following pieces of retrieved context to answer "
"the question. If you don't know the answer, say that you "
"don't know. Use three sentences maximum and keep the "
"answer concise."
"\n\n"
f"{docs_content}"
)
conversation_messages = [
message
for message in state["messages"]
if message.type in ("human", "system")
or (message.type == "ai" and not message.tool_calls)
]
prompt = [SystemMessage(system_message_content)] + conversation_messages
# Run
response = model.invoke(prompt)
return {"messages": [response]}
# Build graph
graph_builder = StateGraph(MessagesState)
graph_builder.add_node(query_or_respond)
graph_builder.add_node(tools)
graph_builder.add_node(generate)
graph_builder.set_entry_point("query_or_respond")
graph_builder.add_conditional_edges(
"query_or_respond",
tools_condition,
{END: END, "tools": "tools"},
)
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)
# Compile graph with memory persistence
memory = MemorySaver()
print("Compiling graph...")
graph = graph_builder.compile(checkpointer=memory)
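# MemorySaver checkpoints the message history per thread_id, so later invocations that
# reuse the same thread_id continue the same conversation.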
# Use the Graph for Conversations
if __name__ == "__main__":
    # Initial conversation setup
    initial_message = "How much do I pay for an application for a provisional driving license, and which country is this document about?"
    config = {"configurable": {"thread_id": "conversation_123"}}
    print("Starting graph invocation...")
    response = graph.invoke(
        {"messages": [{"role": "user", "content": initial_message}]},
        config=config,
    )
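    # A follow-up turn on the same thread_id would reuse the checkpointed history,
    # e.g. (hypothetical question, not part of the original run):
    # followup = graph.invoke(
    #     {"messages": [{"role": "user", "content": "How long is the provisional license valid?"}]},
    #     config=config,
    # )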
    # Inspect conversation history
    chat_history = graph.get_state(config).values["messages"]
    print("\n\n ===== \n\n")
    # print(chat_history)
    for message in chat_history:
        if isinstance(message, HumanMessage):
            print(f"Human: {message.content}")
        elif isinstance(message, AIMessage):
            print(f"AI: {message.content}")
        elif isinstance(message, SystemMessage):
            print(f"System: {message.content}")
        elif isinstance(message, ToolMessage):
            print(f"Tool: {message.content}")
        elif hasattr(message, "artifact"):  # For ToolMessage or similar
            artifacts = message.artifact
            print("Tool Message: ")
            if artifacts:
                for artifact in artifacts:
                    metadata = artifact.get("metadata", {})
                    page_content = artifact.get("page_content", "")
                    print(f" - Metadata: {metadata}")
                    print(f" - Content: {page_content[:200].strip().replace(chr(10), '').replace(chr(13), '')}...")
        else:
            print(f"Unknown message type with details: {vars(message)}")