{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inspired by https://repost.aws/articles/ARDEyn8B0aQLud6ZGK7yyX3Q/generative-ai-with-amazon-bedrock-and-amazon-athena and https://gist.github.com/terrisgit/873cb0d6a095cea7cb744474b888d886\n",
"\n",
"See also https://terrislinenbach.medium.com/querying-aws-athena-using-natural-language-with-langchain-and-chatgpt-e99505dad996\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install dependencies: \n",
"\n",
"```shell\n",
"pip install --upgrade pip setuptools\n",
"pip install langchain langchain-openai sqlalchemy PyAthena\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from urllib.parse import quote_plus\n",
"from sqlalchemy import create_engine\n",
"from langchain_community.utilities import SQLDatabase\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_openai import ChatOpenAI"
]
},
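{
"cell_type": "markdown",
"metadata": {},
"source": [
"`ChatOpenAI` reads its API key from the `OPENAI_API_KEY` environment variable. The check below is an optional addition (not part of the original gist) that fails fast when the key is missing.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Fail early if the OpenAI API key is not configured\n",
"assert os.environ.get(\"OPENAI_API_KEY\"), \"Set the OPENAI_API_KEY environment variable\""
]
},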
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"AWS_REGION = \"us-west-2\"  # Change me\n",
"SCHEMA_NAME = \"changeme\"  # Change me; Athena calls this a database\n",
"S3_STAGING_DIR = \"s3://bucket/changeme\"  # Change me; Athena writes query results here"
]
},
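{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyAthena authenticates through the standard boto3 credential chain (environment variables, shared config files, instance profiles). This optional check, not part of the original gist, fails fast when no credentials can be resolved. `boto3` is assumed to be available; it is installed as a PyAthena dependency.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"\n",
"# PyAthena uses the standard boto3 credential chain; confirm credentials resolve\n",
"assert boto3.Session().get_credentials() is not None, \"No AWS credentials found\""
]
},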
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"connect_str = \"awsathena+rest://athena.{region_name}.amazonaws.com:443/{schema_name}?s3_staging_dir={s3_staging_dir}\"\n",
"\n",
"engine = create_engine(connect_str.format(\n",
"    region_name=AWS_REGION,\n",
"    schema_name=SCHEMA_NAME,\n",
"    s3_staging_dir=quote_plus(S3_STAGING_DIR),\n",
"))\n",
"\n",
"db = SQLDatabase(engine)\n",
"schema = db.get_table_info()"
]
},
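{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (an addition, not part of the original gist): confirm the Athena connection works and list the tables the chain will see. `get_usable_table_names` is part of the LangChain `SQLDatabase` API.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Verify the Athena connection and show which tables the prompt will include\n",
"print(db.get_usable_table_names())"
]
},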
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question.\n",
"Schema: {schema}\n",
"\n",
"Question: {question}\n",
"SQL Query:\"\"\"\n",
"\n",
"prompt = ChatPromptTemplate.from_template(template)\n",
"model = ChatOpenAI(model=\"gpt-4\", temperature=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Modify the question below"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"question = \"Which table has the most rows?\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_schema(_):\n",
"    # The chain passes its input dict here; ignore it and return fresh table info\n",
"    return db.get_table_info()\n",
"\n",
"sql_response = (\n",
"    RunnablePassthrough.assign(schema=get_schema)\n",
"    | prompt\n",
"    | model.bind(stop=[\"\\nSQLResult:\"])\n",
"    | StrOutputParser()\n",
")\n",
"\n",
"sql_response.invoke({\"question\": question})"
]
},
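{
"cell_type": "markdown",
"metadata": {},
"source": [
"Chat models sometimes wrap generated SQL in markdown code fences. The helper below is an assumed post-processing step (not part of the original gist) that strips such fences so `db.run` receives plain SQL.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def strip_sql_fences(text: str) -> str:\n",
"    \"\"\"Remove surrounding markdown code fences from generated SQL, if present.\"\"\"\n",
"    cleaned = text.strip()\n",
"    if cleaned.startswith(\"```\"):\n",
"        # Drop the opening fence line (e.g. ```sql)\n",
"        cleaned = cleaned.split(\"\\n\", 1)[1] if \"\\n\" in cleaned else \"\"\n",
"    if cleaned.endswith(\"```\"):\n",
"        cleaned = cleaned[:-3]\n",
"    return cleaned.strip()\n",
"\n",
"# Example: strip_sql_fences('```sql\\nSELECT 1\\n```') == 'SELECT 1'"
]
},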
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Based on the table schema below, the question, the SQL query, and the SQL response, write a natural language response.\n",
"{schema}\n",
"\n",
"Question: {question}\n",
"SQL Query: {query}\n",
"SQL Response: {response}\"\"\"\n",
"prompt_response = ChatPromptTemplate.from_template(template)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"full_chain = (\n",
"    RunnablePassthrough.assign(query=sql_response).assign(\n",
"        schema=get_schema,\n",
"        response=lambda x: db.run(x[\"query\"]),\n",
"    )\n",
"    | prompt_response\n",
"    | model\n",
")\n",
"\n",
"full_chain.invoke({\"question\": question})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}