Skip to content

Instantly share code, notes, and snippets.

@zhensongren
Created August 20, 2025 19:53
Show Gist options
  • Select an option

  • Save zhensongren/db8775e8793dd2ca73c3d527d001f8be to your computer and use it in GitHub Desktop.

Select an option

Save zhensongren/db8775e8793dd2ca73c3d527d001f8be to your computer and use it in GitHub Desktop.
BO api
# app/main.py
from io import StringIO, BytesIO
from typing import Dict, List, Tuple, Optional
from uuid import uuid4
import hashlib
import pandas as pd
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
# ---- BayBE (verified API)
from baybe import Campaign
from baybe.parameters import CategoricalParameter, NumericalContinuousParameter
from baybe.searchspace import SearchSpace
from baybe.targets import NumericalTarget
from baybe.objectives import ParetoObjective
from baybe.recommenders import TwoPhaseMetaRecommender, FPSRecommender, BotorchRecommender
# BayBE docs show: Campaign(...), add_measurements(...), recommend(batch_size=K).
app = FastAPI(title="BayBE CSV Sync API")

# --- Fixed template headers (exact column order required on every upload)
TEMPLATE_HEADERS = [
    "solvent","temp_C","pressure_bar","resin_pct","batch_id","replicate","y_yield_pct","y_defect_rate"
]
# Columns the optimizer searches over vs. the measured objective columns
PARAM_COLS = ["solvent","temp_C","pressure_bar","resin_pct"]
TARGET_COLS = ["y_yield_pct","y_defect_rate"]

# --- Allowed values / bounds (match your Excel template)
SOLVENT_LEVELS = ["A","B","C"]
# (lo, hi) inclusive bounds per numeric column; used both for upload
# validation and for building the continuous search-space parameters.
BOUNDS = {
    "temp_C": (0.0, 200.0),
    "pressure_bar": (0.0, 5.0),
    "resin_pct": (0.0, 100.0),
    "y_yield_pct": (0.0, 100.0),
    "y_defect_rate": (0.0, 1.0),
}
def _csv_hash(raw_bytes: bytes) -> str:
return hashlib.sha256(raw_bytes).hexdigest()
def _validate_headers(df: pd.DataFrame):
    """Raise HTTPException(400) unless *df* has exactly the template columns, in order."""
    actual = list(df.columns)
    expected = TEMPLATE_HEADERS
    if actual != expected:
        raise HTTPException(400, f"CSV headers must exactly match: {', '.join(TEMPLATE_HEADERS)}")
def _validate_and_cast(df: pd.DataFrame) -> pd.DataFrame:
    """Validate an uploaded template DataFrame and coerce column dtypes.

    Checks solvent levels, numeric bounds, and replicate values, raising
    HTTPException(400) with a descriptive message on the first violation.
    Returns a validated copy; the caller's frame is never mutated.
    """
    out = df.copy()
    # solvent must be one of the allowed categorical levels (NaN = unfilled row)
    bad = ~out["solvent"].isin(SOLVENT_LEVELS) & out["solvent"].notna()
    if bad.any():
        raise HTTPException(400, f"Invalid solvent values: {sorted(out.loc[bad,'solvent'].astype(str).unique())}")
    # numeric parameter/target columns: coerce, then bounds-check
    for col in ["temp_C","pressure_bar","resin_pct","y_yield_pct","y_defect_rate"]:
        if out[col].notna().any():
            try:
                out[col] = pd.to_numeric(out[col], errors="raise")
            except (ValueError, TypeError) as e:
                # BUGFIX: a parse failure previously escaped as an unhandled
                # ValueError (HTTP 500); report it as a client error instead.
                raise HTTPException(400, f"{col} contains non-numeric values: {e}")
            lo, hi = BOUNDS[col]
            bad_range = ~out[col].between(lo, hi) & out[col].notna()
            if bad_range.any():
                raise HTTPException(400, f"{col} has values outside [{lo}, {hi}]")
    # replicate: integer >= 1 when present. Nullable Int64 tolerates rows
    # whose replicate is still blank — plain .astype(int) crashed on the
    # mixed present/missing case.
    if out["replicate"].notna().any():
        try:
            rep = pd.to_numeric(out["replicate"], errors="raise").astype("Int64")
        except (ValueError, TypeError) as e:
            raise HTTPException(400, f"replicate contains non-integer values: {e}")
        if (rep.dropna() < 1).any():
            raise HTTPException(400, "replicate must be >= 1")
        out["replicate"] = rep
    # batch_id is free-form text; normalize to str for downstream concat
    out["batch_id"] = out["batch_id"].astype(str)
    return out
def _param_key(row: pd.Series) -> Tuple:
    """Return a hashable key identifying one parameter combination (for dedup)."""
    values = [row[col] for col in PARAM_COLS]
    return tuple(values)
def _default_campaign() -> Campaign:
    """Build a fresh BayBE campaign over the fixed template search space."""
    parameters = [
        CategoricalParameter(name="solvent", values=SOLVENT_LEVELS),
        NumericalContinuousParameter(name="temp_C", bounds=BOUNDS["temp_C"]),
        NumericalContinuousParameter(name="pressure_bar", bounds=BOUNDS["pressure_bar"]),
        NumericalContinuousParameter(name="resin_pct", bounds=BOUNDS["resin_pct"]),
    ]
    search_space = SearchSpace.from_product(parameters)
    # Pareto objective: maximize yield while minimizing defect rate
    targets = [
        NumericalTarget(name="y_yield_pct", mode="MAX"),
        NumericalTarget(name="y_defect_rate", mode="MIN"),
    ]
    objective = ParetoObjective(targets)
    # Farthest-point sampling for the cold start, then the BoTorch-based
    # Bayesian recommender once measurements exist.
    recommender = TwoPhaseMetaRecommender(FPSRecommender(), BotorchRecommender())
    return Campaign(searchspace=search_space, objective=objective, recommender=recommender)
# --- Server state per campaign
class State:
    """Per-campaign server state: the BayBE campaign plus the authoritative CSV."""

    def __init__(self, campaign: Campaign, csv_df: pd.DataFrame, csv_hash: str):
        self.campaign = campaign
        # Authoritative combined CSV (measured rows + pending recommendations)
        self.csv_df = csv_df
        self.csv_hash = csv_hash
        # Parameter combos already measured or recommended, for deduplication
        params_only = csv_df[PARAM_COLS].dropna()
        self.seen_params = {_param_key(row) for _, row in params_only.iterrows()}
        # Seed the campaign with every row whose targets are fully populated
        complete = csv_df[TARGET_COLS].notna().all(axis=1)
        seed_df = csv_df.loc[complete, PARAM_COLS + TARGET_COLS]
        if not seed_df.empty:
            self.campaign.add_measurements(seed_df)  # BayBE call (DataFrame in)
# In-memory registry of live campaigns, keyed by campaign id (uuid4 string).
# NOTE(review): process-local only — state is lost on restart and not shared
# across workers; confirm single-process deployment before relying on this.
STATES: Dict[str, State] = {}
def _append_recommendations(state: State, batch_size: int) -> pd.DataFrame:
    """Ask BayBE for up to *batch_size* new points and append them to the CSV.

    New rows share a fresh REC-* batch_id, get replicate=1 and empty targets,
    and are reordered to the template columns. Parameter combinations already
    measured or recommended are dropped for idempotency. Returns the updated
    authoritative CSV DataFrame.
    """
    # BayBE requires an explicit batch size
    rec_df = state.campaign.recommend(batch_size=batch_size)
    if rec_df.empty:
        # BUGFIX guard: DataFrame.apply(..., axis=1) on an EMPTY frame returns
        # a DataFrame (not a Series), which would make the dedup mask below
        # malformed. Nothing to append anyway.
        return state.csv_df
    # Drop any already-seen param combos (idempotency)
    keys = rec_df.apply(_param_key, axis=1)
    rec_df = rec_df.loc[~keys.isin(state.seen_params)].copy()
    if rec_df.empty:
        return state.csv_df
    # Fill missing non-optimized columns; targets stay empty until measured
    rec_df["batch_id"] = [f"REC-{uuid4().hex[:6]}"] * len(rec_df)
    rec_df["replicate"] = 1
    for col in TARGET_COLS:
        rec_df[col] = pd.NA
    # Reorder to template and append
    rec_df = rec_df[TEMPLATE_HEADERS]
    state.csv_df = pd.concat([state.csv_df, rec_df], ignore_index=True)
    # Update the dedup set with the newly recommended combos
    for _, row in rec_df.iterrows():
        state.seen_params.add(_param_key(row))
    return state.csv_df
def _csv_response(df: pd.DataFrame) -> StreamingResponse:
    """Stream *df* as CSV, attaching its SHA-256 in the X-CSV-Hash header."""
    payload = BytesIO()
    df.to_csv(payload, index=False)
    raw = payload.getvalue()
    payload.seek(0)
    response = StreamingResponse(payload, media_type="text/csv")
    response.headers["X-CSV-Hash"] = _csv_hash(raw)
    return response
@app.post("/campaigns/import_csv")
async def import_csv(file: UploadFile = File(...), batch_size: int = Form(3)):
    """Create a campaign from an uploaded template CSV.

    Validates the upload, seeds a fresh BayBE campaign with any fully
    measured rows, appends the first batch of recommendations, and streams
    the combined CSV back. The new campaign id is exposed in the
    X-Campaign-Id response header.
    """
    raw = await file.read()
    try:
        df = pd.read_csv(StringIO(raw.decode("utf-8")))
    except Exception as e:
        raise HTTPException(400, f"Failed to read CSV: {e}")
    _validate_headers(df)
    df = _validate_and_cast(df)
    cid = str(uuid4())
    state = State(_default_campaign(), df.copy(), _csv_hash(raw))
    STATES[cid] = state
    # Append first recommendations and return CSV directly
    _append_recommendations(state, batch_size=batch_size)
    resp = _csv_response(state.csv_df)
    # BUGFIX: the campaign id was generated but never returned, so clients
    # had no way to call /campaigns/{cid}/sync. Expose it as a header.
    resp.headers["X-Campaign-Id"] = cid
    return resp
@app.post("/campaigns/{cid}/sync")
async def sync_csv(cid: str, file: UploadFile = File(...), batch_size: int = Form(3), force: int = Form(0)):
    """Merge newly measured rows from an uploaded CSV, then append fresh
    recommendations and return the updated server CSV.

    If the upload hashes identically to the last synced CSV (and force=0),
    the current server CSV is returned unchanged.
    """
    if cid not in STATES:
        raise HTTPException(404, "campaign not found")
    state = STATES[cid]
    raw = await file.read()
    new_hash = _csv_hash(raw)
    if new_hash == state.csv_hash and not force:
        # Nothing changed upstream: just return current server CSV
        return _csv_response(state.csv_df)
    # Parse & validate the uploaded CSV
    try:
        df_up = pd.read_csv(StringIO(raw.decode("utf-8")))
    except Exception as e:
        raise HTTPException(400, f"Failed to read CSV: {e}")
    _validate_headers(df_up)
    df_up = _validate_and_cast(df_up)
    # Rows with complete targets are new measurements
    have_y = df_up[TARGET_COLS].notna().all(axis=1)
    measured_up = df_up.loc[have_y, PARAM_COLS + TARGET_COLS]
    if measured_up.empty:
        # BUGFIX guard: .apply(axis=1) on an empty frame yields a DataFrame,
        # not a Series, which would break the dedup mask below.
        new_meas = measured_up
    else:
        # Deduplicate: only add parameter combos not yet measured server-side
        already_measured = state.csv_df[TARGET_COLS].notna().all(axis=1)
        measured_server_keys = set(state.csv_df.loc[already_measured, PARAM_COLS].apply(_param_key, axis=1))
        new_meas = measured_up.loc[~measured_up.apply(_param_key, axis=1).isin(measured_server_keys)]
    if not new_meas.empty:
        state.campaign.add_measurements(new_meas)  # feed BayBE the new data
        # Merge upstream measured rows into server CSV, aligned to the template
        add_df = new_meas.copy()
        add_df["batch_id"] = "MEAS"
        add_df["replicate"] = 1
        add_df = add_df[TEMPLATE_HEADERS]
        state.csv_df = pd.concat([state.csv_df, add_df], ignore_index=True)
        # BUGFIX: record the measured combos as seen so the idempotency filter
        # in _append_recommendations cannot re-suggest them — the seen_params
        # set is documented to track "measured or recommended" combos, but
        # measurements arriving via sync were never added to it.
        for _, row in new_meas.iterrows():
            state.seen_params.add(_param_key(row))
    # Now append new recommendations and return CSV
    _append_recommendations(state, batch_size=batch_size)
    # Update the stored upstream hash to suppress repeats until the next change
    state.csv_hash = new_hash
    return _csv_response(state.csv_df)
@app.get("/campaigns/{cid}/export_csv")
def export_csv(cid: str):
    """Download the current authoritative CSV for campaign *cid*."""
    state = STATES.get(cid)
    if state is None:
        raise HTTPException(404, "campaign not found")
    return _csv_response(state.csv_df)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment