Last active
May 21, 2024 13:07
-
-
Save terrisgit/873cb0d6a095cea7cb744474b888d886 to your computer and use it in GitHub Desktop.
Langchain + AWS Athena - https://terrislinenbach.medium.com/querying-aws-athena-using-natural-language-with-langchain-and-chatgpt-e99505dad996
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Inspired by https://repost.aws/articles/ARDEyn8B0aQLud6ZGK7yyX3Q/generative-ai-with-amazon-bedrock-and-amazon-athena and https://gist.github.com/terrisgit/873cb0d6a095cea7cb744474b888d886\n", | |
| "\n", | |
| "See also https://terrislinenbach.medium.com/querying-aws-athena-using-natural-language-with-langchain-and-chatgpt-e99505dad996\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Install dependencies: \n", | |
| "\n", | |
| "```shell\n", | |
| "pip install --upgrade pip setuptools\n", | |
| "pip install langchain langchain-community langchain-openai sqlalchemy PyAthena\n", | |
| "```" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from urllib.parse import quote_plus\n", | |
| "from sqlalchemy import create_engine\n", | |
| "from langchain_community.utilities import SQLDatabase\n", | |
| "from langchain_core.output_parsers import StrOutputParser\n", | |
| "from langchain_core.runnables import RunnablePassthrough\n", | |
| "from langchain_core.prompts import ChatPromptTemplate\n", | |
| "from langchain_openai import ChatOpenAI" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "AWS_REGION = \"us-west-2\" # Change me\n", | |
| "SCHEMA_NAME = \"changeme\" # Athena calls this a database\n", | |
| "S3_STAGING_DIR = \"s3://bucket/changeme\" # Change me" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "connect_str = \"awsathena+rest://athena.{region_name}.amazonaws.com:443/{schema_name}?s3_staging_dir={s3_staging_dir}\"\n", | |
| "\n", | |
| "engine = create_engine(connect_str.format(\n", | |
| " region_name=AWS_REGION,\n", | |
| " schema_name=SCHEMA_NAME,\n", | |
| " s3_staging_dir=quote_plus(S3_STAGING_DIR)\n", | |
| "))\n", | |
| "\n", | |
| "db = SQLDatabase(engine)\n", | |
| "schema = db.get_table_info()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question.\n", | |
| "Schema: {schema}\n", | |
| "\n", | |
| "Question: {question}\n", | |
| "SQL Query:\"\"\"\n", | |
| "\n", | |
| "prompt = ChatPromptTemplate.from_template(template)\n", | |
| "model = ChatOpenAI( model_name= \"gpt-4\", temperature= 0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Modify the question below to ask about the tables in your own Athena database." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "question = \"Which table has the most rows?\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def get_schema(_):\n", | |
| " return db.get_table_info()\n", | |
| "\n", | |
| "sql_response = (\n", | |
| " RunnablePassthrough.assign(schema=get_schema)\n", | |
| " | prompt\n", | |
| " | model.bind(stop=[\"\\nSQLResult:\"])\n", | |
| " | StrOutputParser()\n", | |
| ")\n", | |
| "\n", | |
| "sql_response.invoke({\"question\": question})" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "template = \"\"\"Based on the table schema below, question, sql query, and sql response, output a natural language response.\n", | |
| "{schema}\n", | |
| "\n", | |
| "Question: {question}\n", | |
| "SQL Query: {query}\n", | |
| "SQL Response: {response}\"\"\"\n", | |
| "prompt_response = ChatPromptTemplate.from_template(template)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
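| "Build the full chain: generate the SQL query, run it against Athena, then turn the result into a natural language answer.\n", | |
| "\n", | |
| "```python\n", | |
| "full_chain = (\n", | |
| " RunnablePassthrough.assign(query=sql_response).assign(\n", | |
| " schema=get_schema,\n", | |
| " response=lambda x: db.run(x[\"query\"]),\n", | |
| " )\n", | |
| " | prompt_response\n", | |
| " | model\n", | |
| ")\n", | |
| "```" | |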
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "full_chain = (\n", | |
| " RunnablePassthrough.assign(query=sql_response).assign(\n", | |
| " schema=get_schema,\n", | |
| " response=lambda x: db.run(x[\"query\"]),\n", | |
| " )\n", | |
| " | prompt_response\n", | |
| " | model\n", | |
| ")\n", | |
| "\n", | |
| "full_chain.invoke({\"question\": question})" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.11.6" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment