Last active
December 26, 2024 19:08
-
-
Save intellectronica/9b190aca94bf4372c4b08e8b016922ec to your computer and use it in GitHub Desktop.
pydantic-ai-openai-strict.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "private_outputs": true, | |
| "provenance": [], | |
| "authorship_tag": "ABX9TyP3kCA/NmPkzL7PebGZ35z1", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/intellectronica/9b190aca94bf4372c4b08e8b016922ec/pydantic-ai-openai-strict.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "tCSjwdW6sbf1" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "%pip install pydantic_ai asyncpg opentelemetry-instrumentation-asyncpg logfire[asyncpg] nest_asyncio\n", | |
| "from IPython.display import clear_output ; clear_output()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import os\n", | |
| "from google.colab import userdata\n", | |
| "\n", | |
| "for secret in ['OPENAI_API_KEY', 'LOGFIRE_TOKEN']:\n", | |
| " os.environ[secret] = userdata.get(secret)" | |
| ], | |
| "metadata": { | |
| "id": "02-od_5vsiOD" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import nest_asyncio\n", | |
| "\n", | |
| "nest_asyncio.apply()" | |
| ], | |
| "metadata": { | |
| "id": "Jk0utnZitSKG" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from openai import AsyncOpenAI\n", | |
| "import logfire\n", | |
| "\n", | |
| "openai_client = AsyncOpenAI()\n", | |
| "\n", | |
| "_ = logfire.configure(console=False)\n", | |
| "_ = logfire.instrument_openai(openai_client)\n", | |
| "_ = logfire.instrument_asyncpg()" | |
| ], | |
| "metadata": { | |
| "id": "cY3Q1t9qt_pf" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "##########################################################################\n", | |
| "# Hack to get Pydantic AI using OpenAI structured outputs in strict mode #\n", | |
| "##########################################################################\n", | |
| "\n", | |
| "from pydantic_ai.models.openai import OpenAIModel\n", | |
| "from pydantic_ai.tools import ToolDefinition\n", | |
| "from openai.types.chat import ChatCompletionToolParam\n", | |
| "\n", | |
| "class StrictOpenAIModel(OpenAIModel):\n", | |
| " \"\"\"OpenAIModel with strict mode enabled.\n", | |
| "\n", | |
| " This class can be used instead of OpenAIModel to enable strict mode\n", | |
| " for all tool calls, including any typed results.\n", | |
| " \"\"\"\n", | |
| " @staticmethod\n", | |
| " def _map_tool_definition(f: ToolDefinition) -> ChatCompletionToolParam:\n", | |
| " \"\"\"Redefinition of _map_tool_definition to enable strict mode.\n", | |
| "\n", | |
| " Function calls can use strict mode (guaranteeing adherence to the\n", | |
| " defined schema) by specifying `strict: true` for the function, and by\n", | |
| " ensuring that the parameters for the function explicitly include\n", | |
| " `additionalProperties: false`.\n", | |
| " \"\"\"\n", | |
| " tool_def = OpenAIModel._map_tool_definition(f)\n", | |
| " tool_def['function']['strict'] = True\n", | |
| " tool_def['function']['parameters']['additionalProperties'] = False\n", | |
| " return tool_def" | |
| ], | |
| "metadata": { | |
| "id": "QOXKqis10o2R" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from textwrap import wrap\n", | |
| "\n", | |
| "from pydantic import BaseModel, Field\n", | |
| "from pydantic_ai import Agent\n", | |
| "\n", | |
| "\n", | |
| "class ElaborateResponse(BaseModel):\n", | |
| " thoughts: str = Field(..., description=(\n", | |
| " \"Step-by-step thought process and reasoning. \"\n", | |
| " \"Approximately 300 tokens.\"),\n", | |
| " )\n", | |
| " answer: str = Field(\n", | |
| " ..., description=\"Final answer to the question.\",\n", | |
| " )\n", | |
| "\n", | |
| "openai_model = StrictOpenAIModel('gpt-4o', openai_client=openai_client)\n", | |
| "agent = Agent(openai_model, result_type=ElaborateResponse)\n", | |
| "result = await agent.run(\"How many 'r's are in the word strawberry?\")\n", | |
| "\n", | |
| "print('Answer:', result.data.answer)\n", | |
| "print('\\n---\\n')\n", | |
| "print('\\n'.join(wrap(result.data.thoughts)))" | |
| ], | |
| "metadata": { | |
| "id": "vYws9wUGuD-R" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Author
Nice! You can probably simplify `_map_tool_definition` to:

    def _map_tool_definition(f: ToolDefinition) -> ChatCompletionToolParam:
        tool_param = super()._map_tool_definition(f)
        tool_param['function']['strict'] = True
        return tool_param
@samuelcolvin Almost. super() can't be used like this, but we can just call the method from the class explicitly. And we need to also set additionalProperties.
This works:
@staticmethod
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionToolParam:
    """Map a tool definition to an OpenAI tool param with strict mode on.

    Builds the tool param via ``OpenAIModel._map_tool_definition`` and then
    sets ``strict`` to ``True`` on the function entry and
    ``additionalProperties`` to ``False`` on its parameters.

    Args:
        f: The tool definition to convert.

    Returns:
        The tool param produced by the base mapping, with the two
        strict-mode keys set.
    """
    # super() can't be used in this static method, so the base class is
    # named explicitly.
    tool_def = OpenAIModel._map_tool_definition(f)
    tool_def['function']['strict'] = True
    # Must be set alongside `strict`; enabling `strict` alone is not enough.
    tool_def['function']['parameters']['additionalProperties'] = False
    return tool_def
Author
Gist updated. Thanks for the suggestion, this is more elegant and less duplicative.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Nice! You can probably simplify
`_map_tool_definition` to