Implement a Hybrid RAG with Vertex AI Vector Search + Graph Search using Spanner Graph
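# The script below expects its configuration via environment variables (loaded from a .env
# file in step 1). A sample .env with placeholder values might look like the commented block
# here; the variable names match the os.environ lookups used throughout the script.
#
#   GOOGLE_API_KEY=your-google-api-key
#   GCP_PROJECT_ID=your-gcp-project-id
#   GCP_REGION=us-central1
#   VERTEX_AI_INDEX_ID=your-vector-search-index-id
#   VERTEX_AI_BUCKET_NAME=your-staging-bucket
#   SPANNER_INSTANCE_ID=your-spanner-instance
#   SPANNER_DATABASE_ID=your-spanner-db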
import os
from dotenv import load_dotenv
# --- Updated Imports for Google Gemini & Vertex AI ---
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_google_vertexai import VertexAIVectorSearch
from langchain_community.graphs import SpannerGraph
from langchain.chains import SpannerGraphQAChain
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
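# NOTE: import locations for these integrations have moved between library releases
# (for example, newer langchain-google-vertexai versions expose the vector store as
# VectorSearchVectorStore, and the Spanner graph classes ship in the separate
# langchain-google-spanner package). Adjust the imports above to match your installed versions.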
# --- 1. SETUP: LOAD ENVIRONMENT AND INITIALIZE MODELS ---
# Load environment variables from a .env file for security and configuration.
load_dotenv()
# --- Updated Check for Google API Keys and Project Info ---
if "GOOGLE_API_KEY" not in os.environ:
raise ValueError("GOOGLE_API_KEY is not set in your environment variables.")
GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID")
GCP_REGION = os.environ.get("GCP_REGION")
VERTEX_AI_INDEX_ID = os.environ.get("VERTEX_AI_INDEX_ID")
VERTEX_AI_BUCKET_NAME = os.environ.get("VERTEX_AI_BUCKET_NAME")
if not all([GCP_PROJECT_ID, GCP_REGION, VERTEX_AI_INDEX_ID, VERTEX_AI_BUCKET_NAME]):
    raise ValueError(
        "To use Vertex AI Vector Search, please set GCP_PROJECT_ID, GCP_REGION, "
        "VERTEX_AI_INDEX_ID, and VERTEX_AI_BUCKET_NAME in your .env file."
    )
# --- Initialize Google Gemini Models ---
# Using Gemini 2.5 Flash for generation and the recommended text embedding model.
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-preview-09-2025", temperature=0)
embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
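# Optional sanity check (illustrative, left commented out): the embedding dimensionality must
# match the dimension your Vertex AI Vector Search index was created with.
# probe_vector = embeddings.embed_query("dimension check")
# print(f"Embedding dimension: {len(probe_vector)}")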
# --- 2. BUILD THE RETRIEVAL SUB-CHAINS (As per the diagram) ---
# --- Sub-chain A: Vertex AI Vector Search Chain ---
# This part is updated to use Google Cloud's managed Vertex AI Vector Search.
sample_documents = [
    Document(
        page_content="Project Titan was a highly classified initiative focused on developing next-generation AI. The lead scientist was Dr. Aris Thorne. The project resulted in the 'Prometheus Core' technology.",
        metadata={"source": "doc_titan_summary"},
    ),
    Document(
        page_content="Dr. Aris Thorne, a specialist in neural networks, previously worked at Cybersystems Inc. before leading Project Titan. His work is foundational to modern AI ethics.",
        metadata={"source": "doc_thorne_bio"},
    ),
    Document(
        page_content="Cybersystems Inc. is a technology conglomerate known for its work in robotics and AI. They developed the 'Helios' learning algorithm, which was a precursor to many modern systems.",
        metadata={"source": "doc_cybersystems_info"},
    ),
]
# Initialize Vertex AI Vector Search.
# The `from_documents` method will create embeddings and upload them to your index.
# This is a one-time setup operation for these sample docs.
print("Initializing Vertex AI Vector Search and adding documents...")
vector_store = VertexAIVectorSearch.from_documents(
    documents=sample_documents,
    embedding=embeddings,
    project_id=GCP_PROJECT_ID,
    region=GCP_REGION,
    index_name=VERTEX_AI_INDEX_ID,
    staging_bucket=f"gs://{VERTEX_AI_BUCKET_NAME}",
)
print("Vertex AI Vector Search initialized.")
vector_retriever = vector_store.as_retriever(search_kwargs={"k": 2})
def format_docs(docs: list[Document]) -> str:
"""Helper function to combine document contents into a single string."""
return "\n\n".join(doc.page_content for doc in docs)
# This LCEL chain implements the `async_run_vector_search` logic from the pseudocode.
vector_search_chain = vector_retriever | RunnableLambda(format_docs)
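# Given a plain question string, this sub-chain returns the k=2 most similar document texts
# joined into one context block, e.g. vector_search_chain.invoke("Who led Project Titan?").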
# --- Sub-chain B: Graph Search Chain ---
# This section implements the `async_run_graph_search` logic from the pseudocode.
# It's wrapped in a try/except to ensure the script can run even without Spanner access.
try:
    # IMPORTANT: Replace with your actual Spanner instance and database IDs in a .env file
    INSTANCE_ID = os.environ.get("SPANNER_INSTANCE_ID", "your-spanner-instance")
    DATABASE_ID = os.environ.get("SPANNER_DATABASE_ID", "your-spanner-db")
    graph = SpannerGraph(instance_id=INSTANCE_ID, database_id=DATABASE_ID)
    # This chain converts a user's question into a graph query, executes it, and returns a summary.
    # We still use the Gemini LLM to interpret the question for the graph.
    graph_search_chain = SpannerGraphQAChain.from_llm(llm=llm, graph=graph, verbose=True)
except Exception as e:
print("---")
print("WARNING: SpannerGraph could not be initialized. Graph search will be disabled.")
print(f"Error: {e}")
print("To enable it, set SPANNER_INSTANCE_ID and SPANNER_DATABASE_ID in your .env file.")
print("Using a fallback for demonstration.")
print("---")
# If Spanner isn't available, this dummy lambda provides a default response.
graph_search_chain = RunnableLambda(lambda x: "Graph search is not available.")
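# Either way, graph_search_chain ends up as a Runnable that accepts the user's question,
# so the RunnableParallel fan-out below can treat the real chain and the fallback identically.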
# --- 3. BUILD THE FINAL HYBRID CHAIN ---
# This `RunnableParallel` step corresponds to the "Parallel Fan-Out" in the diagram.
# It runs both chains at the same time.
hybrid_retriever = RunnableParallel(
    vector_context=vector_search_chain,
    graph_context=graph_search_chain,
    question=RunnablePassthrough(),  # Pass the original question through
)
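# For one input question, hybrid_retriever emits a dict whose keys line up with the prompt
# placeholders below, roughly:
#   {"vector_context": "<joined docs>", "graph_context": <graph chain output>, "question": "<original question>"}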
# This prompt template corresponds to the "Final Prompt" and "LLM Synthesizer" steps.
final_prompt_template = """
You are an expert AI assistant. Your task is to synthesize a final answer based ONLY on the provided context from two sources: a vector search (for semantic information) and a graph search (for structured facts).
Do not use any prior knowledge. If the answer is not in the context, state that you cannot answer.
### CONTEXT FROM VECTOR SEARCH:
{vector_context}
### CONTEXT FROM GRAPH SEARCH:
{graph_context}
### USER'S QUESTION:
{question}
### FINAL ANSWER:
"""
final_prompt = ChatPromptTemplate.from_template(final_prompt_template)
# This is the complete, end-to-end chain, piping all the steps together.
# It follows the flow diagram and pseudocode: parallel retrieval fan-out, then LLM synthesis.
custom_hybrid_chain = (
    hybrid_retriever
    | final_prompt
    | llm
    | StrOutputParser()
)
# --- 4. RUN THE CHAIN ---
if __name__ == "__main__":
print("--- Custom Hybrid RAG Chain (using Gemini and Vertex AI Vector Search) ---")
# Example 1: A question best answered by vector search (semantic context)
question1 = "What was the outcome of Project Titan?"
print(f"\n[Question]: {question1}")
response1 = custom_hybrid_chain.invoke(question1)
print(f"[Answer]: {response1}")
# Example 2: A question best answered by graph search (relational facts)
question2 = "Where did Dr. Aris Thorne work before Project Titan?"
print(f"\n[Question]: {question2}")
response2 = custom_hybrid_chain.invoke(question2)
print(f"[Answer]: {response2}")
# Example 3: A hybrid question needing both sources.
question3 = "Tell me about Dr. Aris Thorne and the project he led."
print(f"\n[Question]: {question3}")
response3 = custom_hybrid_chain.invoke(question3)
print(f"[Answer]: {response3}")