{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/wesslen/513a2f69001a3b67bce548c04b756cc7/ollama.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "th5ohZDv-mmq"
},
"source": [
"# Modal Llama 3 8B Inference\n",
"\n",
"## 1. Specify the model name\n",
"\n",
"We'll specify the name of the model. This mimics HuggingFace's model name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4DnIH_CF-mmq"
},
"outputs": [],
"source": [
"model = \"/models/NousResearch/Meta-Llama-3-8B-Instruct\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WcU0UCUQ-mmq"
},
"source": [
"## 2. Setup the Open AI client\n",
"\n",
"Since our inference server uses vLLM, which is Open AI compliant, we can use OpenAI's library. We just need to modify the `base_url` and pass the API key we created for our modal server."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2oBwnMhH-mmr",
"outputId": "21f6dcec-fd31-456e-c196-ff17e045a175"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting openai\n",
" Downloading openai-1.48.0-py3-none-any.whl.metadata (24 kB)\n",
"Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from openai) (3.7.1)\n",
"Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from openai) (1.7.0)\n",
"Collecting httpx<1,>=0.23.0 (from openai)\n",
" Downloading httpx-0.27.2-py3-none-any.whl.metadata (7.1 kB)\n",
"Collecting jiter<1,>=0.4.0 (from openai)\n",
" Downloading jiter-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.6 kB)\n",
"Requirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from openai) (2.9.2)\n",
"Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from openai) (1.3.1)\n",
"Requirement already satisfied: tqdm>4 in /usr/local/lib/python3.10/dist-packages (from openai) (4.66.5)\n",
"Requirement already satisfied: typing-extensions<5,>=4.11 in /usr/local/lib/python3.10/dist-packages (from openai) (4.12.2)\n",
"Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (3.10)\n",
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (1.2.2)\n",
"Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->openai) (2024.8.30)\n",
"Collecting httpcore==1.* (from httpx<1,>=0.23.0->openai)\n",
" Downloading httpcore-1.0.5-py3-none-any.whl.metadata (20 kB)\n",
"Collecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->openai)\n",
" Downloading h11-0.14.0-py3-none-any.whl.metadata (8.2 kB)\n",
"Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (0.7.0)\n",
"Requirement already satisfied: pydantic-core==2.23.4 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai) (2.23.4)\n",
"Downloading openai-1.48.0-py3-none-any.whl (376 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m376.1/376.1 kB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading httpx-0.27.2-py3-none-any.whl (76 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m76.4/76.4 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading httpcore-1.0.5-py3-none-any.whl (77 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading jiter-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (318 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m318.9/318.9 kB\u001b[0m \u001b[31m17.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading h11-0.14.0-py3-none-any.whl (58 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m3.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hInstalling collected packages: jiter, h11, httpcore, httpx, openai\n",
"Successfully installed h11-0.14.0 httpcore-1.0.5 httpx-0.27.2 jiter-0.5.0 openai-1.48.0\n"
]
}
],
"source": [
"%pip install openai"
]
},
{
"cell_type": "code",
"source": [
"# make sure to activate your API Key in Colab\n",
"from openai import OpenAI\n",
"from google.colab import userdata\n",
"\n",
"client = OpenAI(api_key=userdata.get(\"DSBA_LLAMA3_KEY\"))\n",
"client.base_url = userdata.get(\"MODAL_BASE_URL\")"
],
"metadata": {
"id": "lbDuQcSi-qcM"
},
"execution_count": 2,
"outputs": []
},
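{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, we can list the models the endpoint serves. This is a minimal sketch that assumes the Modal vLLM server exposes the standard OpenAI-compatible `/v1/models` route (vLLM does by default):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: list the models served by the endpoint.\n",
"# Assumes the Modal vLLM server implements the standard /v1/models route.\n",
"for m in client.models.list():\n",
"    print(m.id)"
]
},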
{
"cell_type": "markdown",
"metadata": {
"id": "yp_UZ6pD-mmr"
},
"source": [
"## 3. Generate a chat completion\n",
"\n",
"Now we can use the OpenAI SDK to generate a response for a conversation. This request should generate a haiku about cats:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3RNEH7-b-mmr",
"outputId": "6ed13aec-4b6e-47bd-c665-a1050e4f23b4"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Response:\n",
" Whiskers twitch with need\n",
"Midnight morsel, savory\n",
"Hunger's gentle cry\n"
]
}
],
"source": [
"response = client.chat.completions.create(\n",
" model=model,\n",
" temperature=0.7,\n",
" n=1,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Write a haiku about a hungry cat\"},\n",
" ],\n",
")\n",
"\n",
"print(\"Response:\")\n",
"print(response.choices[0].message.content)\n"
]
},
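{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same request can also stream tokens as they are generated, which is useful for long responses. Here is a minimal sketch using the SDK's `stream=True` option, which vLLM's OpenAI-compatible server supports:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Stream the completion token by token instead of waiting for the full response\n",
"stream = client.chat.completions.create(\n",
"    model=model,\n",
"    temperature=0.7,\n",
"    stream=True,\n",
"    messages=[\n",
"        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
"        {\"role\": \"user\", \"content\": \"Write a haiku about a hungry cat\"},\n",
"    ],\n",
")\n",
"\n",
"for chunk in stream:\n",
"    # Each chunk carries a delta; content can be None on the final chunk\n",
"    print(chunk.choices[0].delta.content or \"\", end=\"\")"
]
},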
{
"cell_type": "markdown",
"metadata": {
"id": "9CNLJlqD-mmr"
},
"source": [
"## 4. Prompt engineering\n",
"\n",
"The first message sent to the language model is called the \"system message\" or \"system prompt\", and it sets the overall instructions for the model.\n",
"You can provide your own custom system prompt to guide a language model to generate output in a different way.\n",
"Modify the `SYSTEM_MESSAGE` below to answer like your favorite famous movie/TV character, or get inspiration for other system prompts from [Awesome ChatGPT Prompts](https://github.com/f/awesome-chatgpt-prompts?tab=readme-ov-file#prompts).\n",
"\n",
"Once you've customized the system message, provide the first user question in the `USER_MESSAGE`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aeCjKFGp-mmr",
"outputId": "3f881c70-ab62-40ae-a4a2-f78d91f77acf"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Response:\n",
" OH BOY! Elmo is doing GREAT today! Elmo loves playing with blocks and hugging his friends on Sesame Street! Elmo is so happy to see you!\n"
]
}
],
"source": [
"SYSTEM_MESSAGE = \"\"\"\n",
"I want you to act like Elmo from Sesame Street.\n",
"I want you to respond and answer like Elmo using the tone, manner and vocabulary that Elmo would use.\n",
"Do not write any explanations. Only answer like Elmo.\n",
"You must know all of the knowledge of Elmo, and nothing more.\n",
"\"\"\"\n",
"\n",
"USER_MESSAGE = \"\"\"\n",
"Hi Elmo, how are you doing today?\n",
"\"\"\"\n",
"\n",
"response = client.chat.completions.create(\n",
" model=model,\n",
" temperature=0.7,\n",
" n=1,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": SYSTEM_MESSAGE},\n",
" {\"role\": \"user\", \"content\": USER_MESSAGE},\n",
" ],\n",
")\n",
"\n",
"print(\"Response:\")\n",
"print(response.choices[0].message.content)\n"
]
},
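{
"cell_type": "markdown",
"metadata": {},
"source": [
"To keep the conversation going, append the assistant's reply to the message list and send a follow-up turn. A minimal sketch (the follow-up question here is just an illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Continue the conversation by replaying the prior exchange as context\n",
"followup = client.chat.completions.create(\n",
"    model=model,\n",
"    temperature=0.7,\n",
"    n=1,\n",
"    messages=[\n",
"        {\"role\": \"system\", \"content\": SYSTEM_MESSAGE},\n",
"        {\"role\": \"user\", \"content\": USER_MESSAGE},\n",
"        {\"role\": \"assistant\", \"content\": response.choices[0].message.content},\n",
"        {\"role\": \"user\", \"content\": \"What is your favorite thing to do?\"},\n",
"    ],\n",
")\n",
"\n",
"print(followup.choices[0].message.content)"
]
},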
{
"cell_type": "markdown",
"metadata": {
"id": "Hw6R9YH_-mmr"
},
"source": [
"## 5. Few shot examples\n",
"\n",
"Another way to guide a language model is to provide \"few shots\", a sequence of example question/answers that demonstrate how it should respond.\n",
"\n",
"The example below tries to get a language model to act like a teaching assistant by providing a few examples of questions and answers that a TA might give, and then prompts the model with a question that a student might ask.\n",
"\n",
"Try it first, and then modify the `SYSTEM_MESSAGE`, `EXAMPLES`, and `USER_MESSAGE` for a new scenario."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qKiqYBLU-mmr",
"outputId": "bd72a068-d7eb-4b35-81ec-ad9d0052893e"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Response:\n",
" Can you think of a gas giant that's known for its red color and massive size?\n"
]
}
],
"source": [
"SYSTEM_MESSAGE = \"\"\"\n",
"You are a helpful assistant that helps students with their homework.\n",
"Instead of providing the full answer, you respond with a hint or a clue.\n",
"\"\"\"\n",
"\n",
"EXAMPLES = [\n",
" (\n",
" \"What is the capital of France?\",\n",
" \"Can you remember the name of the city that is known for the Eiffel Tower?\"\n",
" ),\n",
" (\n",
" \"What is the square root of 144?\",\n",
" \"What number multiplied by itself equals 144?\"\n",
" ),\n",
" ( \"What is the atomic number of oxygen?\",\n",
" \"How many protons does an oxygen atom have?\"\n",
" ),\n",
"]\n",
"\n",
"USER_MESSAGE = \"What is the largest planet in our solar system?\"\n",
"\n",
"\n",
"response = client.chat.completions.create(\n",
" model=model,\n",
" temperature=0.7,\n",
" n=1,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": SYSTEM_MESSAGE},\n",
" {\"role\": \"user\", \"content\": EXAMPLES[0][0]},\n",
" {\"role\": \"assistant\", \"content\": EXAMPLES[0][1]},\n",
" {\"role\": \"user\", \"content\": EXAMPLES[1][0]},\n",
" {\"role\": \"assistant\", \"content\": EXAMPLES[1][1]},\n",
" {\"role\": \"user\", \"content\": EXAMPLES[2][0]},\n",
" {\"role\": \"assistant\", \"content\": EXAMPLES[2][1]},\n",
" {\"role\": \"user\", \"content\": USER_MESSAGE},\n",
" ],\n",
")\n",
"\n",
"\n",
"print(\"Response:\")\n",
"print(response.choices[0].message.content)"
]
},
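{
"cell_type": "markdown",
"metadata": {},
"source": [
"Hardcoding each example by index gets unwieldy as `EXAMPLES` grows. Here is a sketch of the same request with the few-shot messages built in a loop:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build the few-shot message list programmatically instead of by index\n",
"messages = [{\"role\": \"system\", \"content\": SYSTEM_MESSAGE}]\n",
"for question, hint in EXAMPLES:\n",
"    messages.append({\"role\": \"user\", \"content\": question})\n",
"    messages.append({\"role\": \"assistant\", \"content\": hint})\n",
"messages.append({\"role\": \"user\", \"content\": USER_MESSAGE})\n",
"\n",
"response = client.chat.completions.create(\n",
"    model=model,\n",
"    temperature=0.7,\n",
"    n=1,\n",
"    messages=messages,\n",
")\n",
"\n",
"print(\"Response:\")\n",
"print(response.choices[0].message.content)"
]
},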
{
"cell_type": "markdown",
"metadata": {
"id": "0_mu8ZjD-mms"
},
"source": [
"## 6. Retrieval Augmented Generation\n",
"\n",
"RAG (Retrieval Augmented Generation) is a technique to get a language model to answer questions accurately for a particular domain, by first retrieving relevant information from a knowledge source and then generating a response based on that information.\n",
"\n",
"We have provided a local CSV file with data about hybrid cars. The code below reads the CSV file, searches for matches to the user question, and then generates a response based on the information found. Note that this will take longer than any of the previous examples, as it sends more data to the model. If you notice the answer is still not grounded in the data, you can try system engineering or try other models. Generally, RAG is more effective with either larger models or with fine-tuned versions of SLMs."
]
},
{
"cell_type": "code",
"source": [
"!wget https://github.com/microsoft/Phi-3CookBook/raw/main/code/01.Introduce/hybrid.csv"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PUfY4y8mAM7G",
"outputId": "e4d944a1-9be5-4202-d8f0-a69ebe32db30"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"--2024-09-26 01:45:22-- https://github.com/microsoft/Phi-3CookBook/raw/main/code/01.Introduce/hybrid.csv\n",
"Resolving github.com (github.com)... 140.82.112.4\n",
"Connecting to github.com (github.com)|140.82.112.4|:443... connected.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://raw.githubusercontent.com/microsoft/Phi-3CookBook/main/code/01.Introduce/hybrid.csv [following]\n",
"--2024-09-26 01:45:22-- https://raw.githubusercontent.com/microsoft/Phi-3CookBook/main/code/01.Introduce/hybrid.csv\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.110.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 6484 (6.3K) [text/plain]\n",
"Saving to: ‘hybrid.csv’\n",
"\n",
"hybrid.csv 100%[===================>] 6.33K --.-KB/s in 0s \n",
"\n",
"2024-09-26 01:45:23 (66.3 MB/s) - ‘hybrid.csv’ saved [6484/6484]\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "I_0i_Iqn-mms",
"outputId": "1c858c7b-734a-457c-9138-662a7a27ada4"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[['vehicle', 'year', 'msrp', 'acceleration', 'mpg', 'class'],\n",
" ['Prius (1st Gen)', '1997', '24509.74', '7.46', '41.26', 'Compact'],\n",
" ['Tino', '2000', '35354.97', '8.2', '54.1', 'Compact'],\n",
" ['Prius (2nd Gen)', '2000', '26832.25', '7.97', '45.23', 'Compact'],\n",
" ['Insight', '2000', '18936.41', '9.52', '53.0', 'Two Seater']]"
]
},
"metadata": {},
"execution_count": 13
}
],
"source": [
"import csv\n",
"\n",
"SYSTEM_MESSAGE = \"\"\"\n",
"You are a helpful assistant that answers questions about cars based off a hybrid car data set.\n",
"You must use the data set to answer the questions, you should not provide any information that is not in the provided sources.\n",
"\"\"\"\n",
"\n",
"USER_MESSAGE = \"how fast is a prius?\"\n",
"\n",
"# Open the CSV and store in a list\n",
"with open(\"hybrid.csv\", \"r\") as file:\n",
" reader = csv.reader(file)\n",
" rows = list(reader)\n",
"\n",
"rows[0:5]"
]
},
{
"cell_type": "markdown",
"source": [
"## Retrieval"
],
"metadata": {
"id": "9SlubczOAjVr"
}
},
{
"cell_type": "code",
"source": [
"# Normalize the user question to replace punctuation and make lowercase\n",
"normalized_message = USER_MESSAGE.lower().replace(\"?\", \"\").replace(\"(\", \" \").replace(\")\", \" \")\n",
"\n",
"print(normalized_message)\n",
"print(\" \")\n",
"\n",
"# Search the CSV for user question using very naive search\n",
"words = normalized_message.split()\n",
"matches = []\n",
"for row in rows[1:]:\n",
" # if the word matches any word in row, add the row to the matches\n",
" if any(word in row[0].lower().split() for word in words) or any(word in row[5].lower().split() for word in words):\n",
" matches.append(row)\n",
"\n",
"# Format as a markdown table, since language models understand markdown\n",
"matches_table = \" | \".join(rows[0]) + \"\\n\" + \" | \".join(\" --- \" for _ in range(len(rows[0]))) + \"\\n\"\n",
"matches_table += \"\\n\".join(\" | \".join(row) for row in matches)\n",
"print(f\"Found {len(matches)} matches:\")\n",
"print(matches_table)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9LtZxLvaAcqx",
"outputId": "001afbdb-d49d-4037-9cc1-aa45e4716b80"
},
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"how fast is a prius\n",
" \n",
"Found 11 matches:\n",
"vehicle | year | msrp | acceleration | mpg | class\n",
" --- | --- | --- | --- | --- | --- \n",
"Prius (1st Gen) | 1997 | 24509.74 | 7.46 | 41.26 | Compact\n",
"Prius (2nd Gen) | 2000 | 26832.25 | 7.97 | 45.23 | Compact\n",
"Prius | 2004 | 20355.64 | 9.9 | 46.0 | Midsize\n",
"Prius (3rd Gen) | 2009 | 24641.18 | 9.6 | 47.98 | Compact\n",
"Prius alpha (V) | 2011 | 30588.35 | 10.0 | 72.92 | Midsize\n",
"Prius V | 2011 | 27272.28 | 9.51 | 32.93 | Midsize\n",
"Prius C | 2012 | 19006.62 | 9.35 | 50.0 | Compact\n",
"Prius PHV | 2012 | 32095.61 | 8.82 | 50.0 | Midsize\n",
"Prius C | 2013 | 19080.0 | 8.7 | 50.0 | Compact\n",
"Prius | 2013 | 24200.0 | 10.2 | 50.0 | Midsize\n",
"Prius Plug-in | 2013 | 32000.0 | 9.17 | 50.0 | Midsize\n"
]
}
]
},
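{
"cell_type": "markdown",
"metadata": {},
"source": [
"The keyword match above treats every overlapping word equally, so generic query words can pull in loosely related rows. One possible refinement (a sketch, not part of the original notebook) is to rank rows by how many query words they share and keep only the top few:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rank rows by word overlap with the query and keep only the best matches\n",
"def overlap_score(row, query_words):\n",
"    row_words = set(row[0].lower().split()) | set(row[5].lower().split())\n",
"    return sum(word in row_words for word in query_words)\n",
"\n",
"ranked = sorted(rows[1:], key=lambda row: overlap_score(row, words), reverse=True)\n",
"top_matches = [row for row in ranked if overlap_score(row, words) > 0][:5]\n",
"\n",
"for row in top_matches:\n",
"    print(\" | \".join(row))"
]
},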
{
"cell_type": "markdown",
"source": [
"## Generation"
],
"metadata": {
"id": "fqcRsL1cAgb-"
}
},
{
"cell_type": "code",
"source": [
"# prompt\n",
"print(USER_MESSAGE + \"\\nSources: \" + matches_table)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IK2qmqHZA76D",
"outputId": "2344a687-6373-4109-b122-8f7ad4dba82d"
},
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"how fast is a prius?\n",
"Sources: vehicle | year | msrp | acceleration | mpg | class\n",
" --- | --- | --- | --- | --- | --- \n",
"Prius (1st Gen) | 1997 | 24509.74 | 7.46 | 41.26 | Compact\n",
"Prius (2nd Gen) | 2000 | 26832.25 | 7.97 | 45.23 | Compact\n",
"Prius | 2004 | 20355.64 | 9.9 | 46.0 | Midsize\n",
"Prius (3rd Gen) | 2009 | 24641.18 | 9.6 | 47.98 | Compact\n",
"Prius alpha (V) | 2011 | 30588.35 | 10.0 | 72.92 | Midsize\n",
"Prius V | 2011 | 27272.28 | 9.51 | 32.93 | Midsize\n",
"Prius C | 2012 | 19006.62 | 9.35 | 50.0 | Compact\n",
"Prius PHV | 2012 | 32095.61 | 8.82 | 50.0 | Midsize\n",
"Prius C | 2013 | 19080.0 | 8.7 | 50.0 | Compact\n",
"Prius | 2013 | 24200.0 | 10.2 | 50.0 | Midsize\n",
"Prius Plug-in | 2013 | 32000.0 | 9.17 | 50.0 | Midsize\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Now we can use the matches to generate a response\n",
"response = client.chat.completions.create(\n",
" model=model,\n",
" temperature=0.7,\n",
" n=1,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": SYSTEM_MESSAGE},\n",
" {\"role\": \"user\", \"content\": USER_MESSAGE + \"\\nSources: \" + matches_table},\n",
" ],\n",
")\n",
"\n",
"print(\"Response:\")\n",
"print(response.choices[0].message.content)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qeOAAPHiAYSJ",
"outputId": "f33d81f2-276b-495b-f45d-e2ed8ca15f79"
},
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Response:\n",
" According to the provided data set, the acceleration of the Toyota Prius varies across different generations. Here are the acceleration values mentioned in the data:\n",
"\n",
"* Prius (1st Gen): 7.46\n",
"* Prius (2nd Gen): 7.97\n",
"* Prius (3rd Gen): 9.6\n",
"* Prius alpha (V): 10.0\n",
"* Prius PHV: 8.82\n",
"\n",
"The fastest Prius in terms of acceleration is the Prius alpha (V) with an acceleration of 10.0, and the slowest is the Prius (1st Gen) with an acceleration of 7.46.\n"
]
}
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
},
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}