@sgbaird
Created May 28, 2025 21:00
ax-mwe-for-sergio-saving.ipynb
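The notebook enforces the three-part composition x1 + x2 + x3 == total by having Ax optimize only x1 and x2 under the linear constraint x1 + x2 <= total, then computing x3 = total - (x1 + x2) after each suggestion. Below is a minimal NumPy-only sketch of that reparameterization. It reuses the notebook's `branin3` objective and the trial-0 parameters that Ax logged. `TOTAL` and `recover_x3` are hypothetical names introduced only for illustration, and Ax is not needed to run it.

```python
import numpy as np

TOTAL = 10.0  # same value as `total` in the notebook


def branin3(x1, x2, x3, c1):
    """Copy of the notebook's objective: Branin in (x1, x2), scaled by x3, plus a categorical penalty."""
    y = float(
        (x2 - 5.1 / (4 * np.pi**2) * x1**2 + 5.0 / np.pi * x1 - 6.0) ** 2
        + 10 * (1 - 1.0 / (8 * np.pi)) * np.cos(x1)
        + 10
    )
    # Contrived way to incorporate x3 into the objective
    y = y * (1 + 0.1 * x1 * x2 * x3)
    # Made-up penalty based on category
    penalty_lookup = {"A": 1.0, "B": 0.0, "C": 2.0}
    y += penalty_lookup[c1]
    return y


def recover_x3(x1, x2, total=TOTAL):
    """Hypothetical helper: derive the dependent component from the two free ones."""
    if x1 + x2 > total:
        raise ValueError("x1 + x2 must not exceed total")
    return total - (x1 + x2)


# Trial 0 parameters as logged by Ax in the notebook output
x1, x2, c1 = 3.6701425071805716, 2.1122976299375296, "C"
x3 = recover_x3(x1, x2)
result = branin3(x1, x2, x3, c1)
print(x3, result)
```

With the logged parameters, the recovered x3 and the objective should match the values printed in the notebook's cells. The inequality x1 + x2 <= total is what keeps x3 non-negative, so Ax only has to handle two range parameters plus one linear constraint.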
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/sgbaird/aa140b187e72491a278d98039253f94e/ax-mwe-for-sergio-saving.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "57xih43KFdOY",
"outputId": "d3538b65-581d-4abd-c15f-1dbe0aad8ba4"
},
"outputs": [],
"source": [
"%pip install ax-platform==0.4.3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zA74rOiWoDro"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def branin3(x1, x2, x3, c1):\n",
" y = float(\n",
" (x2 - 5.1 / (4 * np.pi**2) * x1**2 + 5.0 / np.pi * x1 - 6.0) ** 2\n",
" + 10 * (1 - 1.0 / (8 * np.pi)) * np.cos(x1)\n",
" + 10\n",
" )\n",
"\n",
" # Contrived way to incorporate x3 into the objective\n",
" y = y * (1 + 0.1 * x1 * x2 * x3)\n",
"\n",
" # add a made-up penalty based on category\n",
" penalty_lookup = {\"A\": 1.0, \"B\": 0.0, \"C\": 2.0}\n",
" y += penalty_lookup[c1]\n",
"\n",
" return y"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "FJtWMajrhG32",
"outputId": "40d861a7-8be2-4746-e533-942811e23cc7"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[WARNING 05-27 21:37:57] ax.service.utils.with_db_settings_base: Ax currently requires a sqlalchemy version below 2.0. This will be addressed in a future release. Disabling SQL storage in Ax for now, if you would like to use SQL storage please install Ax with mysql extras via `pip install ax-platform[mysql]`.\n",
"[INFO 05-27 21:37:57] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.\n",
"[INFO 05-27 21:37:57] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.\n",
"[INFO 05-27 21:37:57] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.\n",
"[INFO 05-27 21:37:57] ax.service.utils.instantiation: Inferred value type of ParameterType.STRING for parameter c1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.\n",
"/usr/local/lib/python3.11/dist-packages/ax/service/utils/instantiation.py:248: AxParameterWarning: `sort_values` is not specified for `ChoiceParameter` \"c1\". Defaulting to `False` for parameters of `ParameterType` STRING. To override this behavior (or avoid this warning), specify `sort_values` during `ChoiceParameter` construction.\n",
" return ChoiceParameter(\n",
"[INFO 05-27 21:37:57] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 10.0]), ChoiceParameter(name='c1', parameter_type=STRING, values=['A', 'B', 'C'], is_ordered=False, sort_values=False)], parameter_constraints=[ParameterConstraint(1.0*x1 + 1.0*x2 <= 10.0)]).\n",
"/usr/local/lib/python3.11/dist-packages/ax/modelbridge/cross_validation.py:463: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.\n",
" warn(\"Encountered exception in computing model fit quality: \" + str(e))\n",
"[INFO 05-27 21:37:57] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 3.670143, 'x2': 2.112298, 'c1': 'C'} using model Sobol.\n",
"[INFO 05-27 21:37:57] ax.service.ax_client: Saved JSON-serialized state of optimization to `ax_client_snapshot.json`.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Suggested next experiment (trial #0) -- x1: 3.6701425071805716, x2: 2.1122976299375296, x3: 4.217559862881899, c1: C\n"
]
}
],
"source": [
"# Generated by Honegumi (https://arxiv.org/abs/2502.06815)\n",
"# %pip install ax-platform==0.4.3 matplotlib\n",
"import numpy as np\n",
"import pandas as pd\n",
"from ax.service.ax_client import AxClient, ObjectiveProperties\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"from ax.modelbridge.factory import Models\n",
"from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy\n",
"\n",
"\n",
"obj1_name = \"branin\"\n",
"\n",
"\n",
"# Define total for compositional constraint, where x1 + x2 + x3 == total\n",
"total = 10.0\n",
"\n",
"gs = GenerationStrategy(\n",
" steps=[\n",
" GenerationStep(\n",
" model=Models.SOBOL,\n",
" num_trials=5, # how many Sobol trials to perform (rule of thumb: 2 * number of params)\n",
" min_trials_observed=3,\n",
" max_parallelism=5,\n",
" model_kwargs={\"seed\": 999},\n",
" ),\n",
" GenerationStep(\n",
" model=Models.SAASBO,\n",
" num_trials=-1,\n",
" max_parallelism=3,\n",
" model_kwargs={},\n",
" ),\n",
" ]\n",
")\n",
"\n",
"ax_client = AxClient(generation_strategy=gs)\n",
"# if using SAASBO is too slow, remove `gs` entirely and just use `ax_client = AxClient()`, which will give you reasonable defaults\n",
"\n",
"ax_client.create_experiment(\n",
" parameters=[\n",
" {\"name\": \"x1\", \"type\": \"range\", \"bounds\": [0.0, total]},\n",
" {\"name\": \"x2\", \"type\": \"range\", \"bounds\": [0.0, total]},\n",
" {\n",
" \"name\": \"c1\",\n",
" \"type\": \"choice\",\n",
" \"is_ordered\": False,\n",
" \"values\": [\"A\", \"B\", \"C\"],\n",
" },\n",
" ],\n",
" objectives={\n",
" obj1_name: ObjectiveProperties(minimize=True),\n",
" },\n",
" parameter_constraints=[\n",
" f\"x1 + x2 <= {total}\", # reparameterized compositional constraint, which is a type of sum constraint\n",
" ],\n",
")\n",
"\n",
"\n",
"parameterization, trial_index = ax_client.get_next_trial()\n",
"\n",
"# extract parameters\n",
"x1 = parameterization[\"x1\"]\n",
"x2 = parameterization[\"x2\"]\n",
"x3 = total - (x1 + x2) # composition constraint: x1 + x2 + x3 == total\n",
"\n",
"c1 = parameterization[\"c1\"]\n",
"\n",
"print(f\"Suggested next experiment (trial #{trial_index}) -- x1: {x1}, x2: {x2}, x3: {x3}, c1: {c1}\")\n",
"\n",
"snapshot_fpath = \"ax_client_snapshot.json\"\n",
"ax_client.save_to_json_file(snapshot_fpath)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "ZKnLwJmdii6R",
"outputId": "6fdfa44f-d368-4ae4-e28f-95c8459472df"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9.487827279043518\n"
]
}
],
"source": [
"del ax_client # delete the client to show that it is reloaded from the JSON snapshot in a later cell (for illustration only)\n",
"results = branin3(x1, x2, x3, c1) # this is where you \"run the experiment\"\n",
"print(results)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "vzd_8070inpr",
"outputId": "8f16bfe9-17e6-4da8-efe3-337720567a8f"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/ax/storage/json_store/decoder.py:303: AxParameterWarning: `sort_values` is not specified for `ChoiceParameter` \"c1\". Defaulting to `False` for parameters of `ParameterType` STRING. To override this behavior (or avoid this warning), specify `sort_values` during `ChoiceParameter` construction.\n",
" return _class(\n",
"[INFO 05-27 21:37:57] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.\n"
]
}
],
"source": [
"ax_client = AxClient.load_from_json_file(snapshot_fpath)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "gytHOHxAiqCv",
"outputId": "dc96b827-7e7a-4898-d18e-d415998e9442"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[INFO 05-27 21:37:57] ax.service.ax_client: Completed trial 0 with data: {'branin': (9.487827, None)}.\n"
]
}
],
"source": [
"# NOTE: we need to pass trial_index to tell AxClient which parameterization these results correspond to\n",
"ax_client.complete_trial(trial_index=trial_index, raw_data={\"branin\": results})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "iOh-9ZedjhfJ",
"outputId": "a88cac84-c31c-4565-a0ed-6ef70f58aeb6"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/ax/modelbridge/cross_validation.py:463: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.\n",
" warn(\"Encountered exception in computing model fit quality: \" + str(e))\n",
"[INFO 05-27 21:37:57] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 1.085235, 'x2': 4.961436, 'c1': 'C'} using model Sobol.\n",
"[INFO 05-27 21:37:57] ax.service.ax_client: Saved JSON-serialized state of optimization to `ax_client_snapshot.json`.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'x1': 1.0852348804473877, 'x2': 4.961435543373227, 'c1': 'C'} 1\n",
"Suggested next experiment -- x1: 1.0852348804473877, x2: 4.961435543373227, x3: 3.953329576179385, c1: C\n"
]
}
],
"source": [
"parameterization, trial_index = ax_client.get_next_trial() # run this cell only once per iteration (sequential optimization: complete the trial before requesting another)\n",
"\n",
"print(parameterization, trial_index)\n",
"\n",
"# extract parameters\n",
"x1 = parameterization[\"x1\"]\n",
"x2 = parameterization[\"x2\"]\n",
"x3 = total - (x1 + x2) # composition constraint: x1 + x2 + x3 == total\n",
"\n",
"c1 = parameterization[\"c1\"]\n",
"\n",
"print(f\"Suggested next experiment -- x1: {x1}, x2: {x2}, x3: {x3}, c1: {c1}\")\n",
"\n",
"ax_client.save_to_json_file(snapshot_fpath)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "c0y-Grs3kHV4",
"outputId": "946ab777-7950-4583-ab96-5175c9244a3f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"48.206864586996545\n"
]
}
],
"source": [
"del ax_client\n",
"results = branin3(x1, x2, x3, c1) # this is where you \"run the experiment\"\n",
"print(results)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "ITNW1oyNkbUi",
"outputId": "d3b8bc3a-eca8-4b1a-c2fa-9e410c67cdf8"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/ax/storage/json_store/decoder.py:303: AxParameterWarning: `sort_values` is not specified for `ChoiceParameter` \"c1\". Defaulting to `False` for parameters of `ParameterType` STRING. To override this behavior (or avoid this warning), specify `sort_values` during `ChoiceParameter` construction.\n",
" return _class(\n",
"[INFO 05-27 21:37:57] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.\n"
]
}
],
"source": [
"ax_client = AxClient.load_from_json_file(snapshot_fpath)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "MkkjhzOMk7AF",
"outputId": "af6a4d94-c80c-4342-a1f1-0280a1c88bec"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[INFO 05-27 21:37:57] ax.service.ax_client: Completed trial 1 with data: {'branin': (48.206865, None)}.\n"
]
}
],
"source": [
"ax_client.complete_trial(trial_index=trial_index, raw_data={\"branin\": results})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true,
"base_uri": "https://localhost:8080/"
},
"id": "BpQxf7-yl0W7",
"outputId": "f8db80ab-aa8c-4f36-82b3-96c6a498eaf7"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/ax/modelbridge/cross_validation.py:463: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.\n",
" warn(\"Encountered exception in computing model fit quality: \" + str(e))\n",
"[INFO 05-27 21:37:57] ax.service.ax_client: Generated new trial 2 with parameters {'x1': 5.255021, 'x2': 1.09099, 'c1': 'A'} using model Sobol.\n",
"[INFO 05-27 21:37:57] ax.service.ax_client: Completed trial 2 with data: {'branin': (47.334254, None)}.\n",
"/usr/local/lib/python3.11/dist-packages/ax/modelbridge/cross_validation.py:463: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.\n",
" warn(\"Encountered exception in computing model fit quality: \" + str(e))\n",
"[INFO 05-27 21:37:57] ax.service.ax_client: Generated new trial 3 with parameters {'x1': 4.206534, 'x2': 2.918926, 'c1': 'A'} using model Sobol.\n",
"[INFO 05-27 21:37:57] ax.service.ax_client: Completed trial 3 with data: {'branin': (33.207714, None)}.\n",
"/usr/local/lib/python3.11/dist-packages/ax/modelbridge/cross_validation.py:463: UserWarning: Encountered exception in computing model fit quality: RandomModelBridge does not support prediction.\n",
" warn(\"Encountered exception in computing model fit quality: \" + str(e))\n",
"[INFO 05-27 21:37:57] ax.service.ax_client: Generated new trial 4 with parameters {'x1': 0.46391, 'x2': 6.076421, 'c1': 'B'} using model Sobol.\n",
"[INFO 05-27 21:37:57] ax.service.ax_client: Completed trial 4 with data: {'branin': (37.937763, None)}.\n",
"[INFO 05-27 21:39:48] ax.service.ax_client: Generated new trial 5 with parameters {'x1': 8.525951, 'x2': 0.0, 'c1': 'C'} using model SAASBO.\n",
"[INFO 05-27 21:39:48] ax.service.ax_client: Completed trial 5 with data: {'branin': (9.339133, None)}.\n",
"[INFO 05-27 21:41:16] ax.service.ax_client: Generated new trial 6 with parameters {'x1': 10.0, 'x2': 0.0, 'c1': 'B'} using model SAASBO.\n",
"[INFO 05-27 21:41:16] ax.service.ax_client: Completed trial 6 with data: {'branin': (10.960889, None)}.\n",
"/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n",
" warnings.warn(\n",
"[INFO 05-27 21:42:53] ax.service.ax_client: Generated new trial 7 with parameters {'x1': 6.282715, 'x2': 0.0, 'c1': 'B'} using model SAASBO.\n",
"[INFO 05-27 21:42:53] ax.service.ax_client: Completed trial 7 with data: {'branin': (20.812079, None)}.\n",
"/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n",
" warnings.warn(\n",
"[INFO 05-27 21:44:38] ax.service.ax_client: Generated new trial 8 with parameters {'x1': 0.0, 'x2': 10.0, 'c1': 'A'} using model SAASBO.\n",
"[INFO 05-27 21:44:38] ax.service.ax_client: Completed trial 8 with data: {'branin': (36.602113, None)}.\n",
"/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n",
" warnings.warn(\n",
"[INFO 05-27 21:46:07] ax.service.ax_client: Generated new trial 9 with parameters {'x1': 10.0, 'x2': 0.0, 'c1': 'C'} using model SAASBO.\n",
"[INFO 05-27 21:46:07] ax.service.ax_client: Completed trial 9 with data: {'branin': (12.960889, None)}.\n",
"/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n",
" warnings.warn(\n",
"[INFO 05-27 21:47:48] ax.service.ax_client: Generated new trial 10 with parameters {'x1': 10.0, 'x2': 0.0, 'c1': 'A'} using model SAASBO.\n",
"[INFO 05-27 21:47:48] ax.service.ax_client: Completed trial 10 with data: {'branin': (11.960889, None)}.\n",
"/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n",
" warnings.warn(\n",
"[INFO 05-27 21:49:39] ax.service.ax_client: Generated new trial 11 with parameters {'x1': 3.232036, 'x2': 0.0, 'c1': 'C'} using model SAASBO.\n",
"[INFO 05-27 21:49:39] ax.service.ax_client: Completed trial 11 with data: {'branin': (7.301467, None)}.\n",
"/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n",
" warnings.warn(\n",
"[INFO 05-27 21:51:30] ax.service.ax_client: Generated new trial 12 with parameters {'x1': 7.784494, 'x2': 0.0, 'c1': 'C'} using model SAASBO.\n",
"[INFO 05-27 21:51:30] ax.service.ax_client: Completed trial 12 with data: {'branin': (14.737298, None)}.\n",
"/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n",
" warnings.warn(\n",
"[INFO 05-27 21:53:28] ax.service.ax_client: Generated new trial 13 with parameters {'x1': 4.273055, 'x2': 0.0, 'c1': 'C'} using model SAASBO.\n",
"[INFO 05-27 21:53:28] ax.service.ax_client: Completed trial 13 with data: {'branin': (10.343237, None)}.\n",
"/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n",
" warnings.warn(\n",
"[INFO 05-27 21:55:23] ax.service.ax_client: Generated new trial 14 with parameters {'x1': 0.0, 'x2': 0.0, 'c1': 'C'} using model SAASBO.\n",
"[INFO 05-27 21:55:23] ax.service.ax_client: Completed trial 14 with data: {'branin': (57.602113, None)}.\n",
"[INFO 05-27 21:57:07] ax.service.ax_client: Generated new trial 15 with parameters {'x1': 3.387063, 'x2': 0.0, 'c1': 'B'} using model SAASBO.\n",
"[INFO 05-27 21:57:07] ax.service.ax_client: Completed trial 15 with data: {'branin': (5.059481, None)}.\n",
"/usr/local/lib/python3.11/dist-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n",
" warnings.warn(\n",
"[INFO 05-27 21:58:36] ax.service.ax_client: Generated new trial 16 with parameters {'x1': 9.123101, 'x2': 0.0, 'c1': 'B'} using model SAASBO.\n",
"[INFO 05-27 21:58:36] ax.service.ax_client: Completed trial 16 with data: {'branin': (5.814624, None)}.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 20min 12s, sys: 5.97 s, total: 20min 18s\n",
"Wall time: 20min 38s\n"
]
}
],
"source": [
"%%time\n",
"# run several more trials in a loop, without saving/loading, so that the end-of-campaign analysis has enough data to be meaningful\n",
"\n",
"for _ in range(15): # i.e., 15 more trials (trials 2 through 16)\n",
" parameterization, trial_index = ax_client.get_next_trial()\n",
"\n",
" # extract parameters\n",
" x1 = parameterization[\"x1\"]\n",
" x2 = parameterization[\"x2\"]\n",
" x3 = total - (x1 + x2) # composition constraint: x1 + x2 + x3 == total\n",
"\n",
" c1 = parameterization[\"c1\"]\n",
"\n",
" results = branin3(x1, x2, x3, c1)\n",
" ax_client.complete_trial(trial_index=trial_index, raw_data={\"branin\": results})"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QBafjdSsjafe"
},
"source": [
"## End of Campaign / Analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true
},
"id": "OfhEK-ARhfMy",
"outputId": "fd56de59-a94b-4367-96c4-1a1ae96e4ad3"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[WARNING 05-27 21:59:52] ax.service.utils.report_utils: Column reason missing for all trials. Not appending column.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameters: {'x1': 3.387062556889936, 'x2': 0.0, 'c1': 'B'}\n",
"Best metrics: ({'branin': np.float64(5.78071107054917)}, {'branin': {'branin': np.float64(11.477393275837601)}})\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 900x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_parameters, metrics = ax_client.get_best_parameters()\n",
"\n",
"print(f\"Best parameters: {best_parameters}\")\n",
"print(f\"Best metrics: {metrics}\")\n",
"\n",
"\n",
"# Plot results\n",
"objectives = ax_client.objective_names\n",
"df = ax_client.get_trials_data_frame()\n",
"\n",
"fig, ax = plt.subplots(figsize=(6, 4), dpi=150)\n",
"# ax.scatter(df.index, df[objectives], ec=\"k\", fc=\"none\", label=\"Observed\")\n",
"ax.plot(\n",
" df.index,\n",
" np.minimum.accumulate(df[objectives]),\n",
" color=\"#0033FF\",\n",
" lw=2,\n",
" label=\"Best to Trial\",\n",
")\n",
"ax.set_xlabel(\"Trial Number\")\n",
"ax.set_ylabel(objectives[0])\n",
"\n",
"ax.legend()\n",
"plt.show()"
]
}
],
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyOTZohWpK3IF+V/aV5qVukH",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}