Created
August 20, 2025 19:53
-
-
Save zhensongren/db8775e8793dd2ca73c3d527d001f8be to your computer and use it in GitHub Desktop.
Bayesian-optimization (BayBE) CSV sync API — FastAPI service that imports an experiment-template CSV, seeds a BayBE campaign, and returns recommendations.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # app/main.py | |
| from io import StringIO, BytesIO | |
| from typing import Dict, List, Tuple, Optional | |
| from uuid import uuid4 | |
| import hashlib | |
| import pandas as pd | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| # ---- BayBE (verified API) | |
| from baybe import Campaign | |
| from baybe.parameters import CategoricalParameter, NumericalContinuousParameter | |
| from baybe.searchspace import SearchSpace | |
| from baybe.targets import NumericalTarget | |
| from baybe.objectives import ParetoObjective | |
| from baybe.recommenders import TwoPhaseMetaRecommender, FPSRecommender, BotorchRecommender | |
| # BayBE docs show: Campaign(...), add_measurements(...), recommend(batch_size=K). :contentReference[oaicite:2]{index=2} | |
# Single FastAPI application exposing the CSV import/sync/export endpoints below.
app = FastAPI(title="BayBE CSV Sync API")
# --- Fixed template headers
# Input columns the optimizer varies, and the measured response columns.
PARAM_COLS = ["solvent", "temp_C", "pressure_bar", "resin_pct"]
TARGET_COLS = ["y_yield_pct", "y_defect_rate"]
# Exact column order every uploaded/exported CSV must follow.
TEMPLATE_HEADERS = PARAM_COLS + ["batch_id", "replicate"] + TARGET_COLS
# --- Allowed values / bounds (match your Excel template)
SOLVENT_LEVELS = ["A", "B", "C"]
BOUNDS = {
    "temp_C": (0.0, 200.0),
    "pressure_bar": (0.0, 5.0),
    "resin_pct": (0.0, 100.0),
    "y_yield_pct": (0.0, 100.0),
    "y_defect_rate": (0.0, 1.0),
}
| def _csv_hash(raw_bytes: bytes) -> str: | |
| return hashlib.sha256(raw_bytes).hexdigest() | |
def _validate_headers(df: pd.DataFrame) -> None:
    """Reject any upload whose columns differ from the fixed template (order matters)."""
    actual = list(df.columns)
    if actual != TEMPLATE_HEADERS:
        raise HTTPException(400, f"CSV headers must exactly match: {', '.join(TEMPLATE_HEADERS)}")
def _validate_and_cast(df: pd.DataFrame) -> pd.DataFrame:
    """Validate an uploaded template DataFrame and coerce column dtypes.

    Checks (each failure raises ``HTTPException(400)``):
      * ``solvent`` values are limited to ``SOLVENT_LEVELS`` (blank cells allowed),
      * every numeric column parses and lies within its ``BOUNDS`` range,
      * ``replicate`` is an integer >= 1 where present.

    Returns a casted copy; ``batch_id`` is passed through as ``str``.
    """
    out = df.copy()

    # solvent: only whitelisted levels; NaN (empty cells) are tolerated
    bad = ~out["solvent"].isin(SOLVENT_LEVELS) & out["solvent"].notna()
    if bad.any():
        raise HTTPException(400, f"Invalid solvent values: {sorted(out.loc[bad,'solvent'].astype(str).unique())}")

    # numeric bounds: cast, then range-check only the non-missing cells
    for col in ["temp_C", "pressure_bar", "resin_pct", "y_yield_pct", "y_defect_rate"]:
        if out[col].notna().any():
            try:
                out[col] = pd.to_numeric(out[col], errors="raise")
            except (ValueError, TypeError) as e:
                # surface parse failures as a client error rather than a 500
                raise HTTPException(400, f"{col} must be numeric: {e}")
            lo, hi = BOUNDS[col]
            bad_range = ~out[col].between(lo, hi) & out[col].notna()
            if bad_range.any():
                raise HTTPException(400, f"{col} has values outside [{lo}, {hi}]")

    # replicate: integer >= 1 when present.  Cast to the nullable Int64 dtype
    # so columns mixing filled and blank cells don't crash on int conversion.
    if out["replicate"].notna().any():
        try:
            rep = pd.to_numeric(out["replicate"], errors="raise")
        except (ValueError, TypeError) as e:
            raise HTTPException(400, f"replicate must be numeric: {e}")
        if (rep < 1).any():  # NaN compares False, so missing cells pass
            raise HTTPException(400, "replicate must be >= 1")
        out["replicate"] = rep.astype("Int64")

    # batch_id passthrough (free-form label)
    out["batch_id"] = out["batch_id"].astype(str)
    return out
def _param_key(row: pd.Series) -> Tuple:
    """Hashable identity of a row, built solely from its parameter columns."""
    values = (row[col] for col in PARAM_COLS)
    return tuple(values)
def _default_campaign() -> Campaign:
    """Build the fixed BayBE campaign for the CSV template.

    Product search space over the four template parameters, a two-target
    Pareto objective (maximize yield, minimize defect rate), and a two-phase
    recommender: FPS sampling to bootstrap, then a BoTorch-based recommender.
    """
    parameters = [
        CategoricalParameter(name="solvent", values=SOLVENT_LEVELS),
        NumericalContinuousParameter(name="temp_C", bounds=BOUNDS["temp_C"]),
        NumericalContinuousParameter(name="pressure_bar", bounds=BOUNDS["pressure_bar"]),
        NumericalContinuousParameter(name="resin_pct", bounds=BOUNDS["resin_pct"]),
    ]
    targets = [
        NumericalTarget(name="y_yield_pct", mode="MAX"),
        NumericalTarget(name="y_defect_rate", mode="MIN"),
    ]
    recommender = TwoPhaseMetaRecommender(FPSRecommender(), BotorchRecommender())
    # BayBE docs: Campaign(...), add_measurements(...), recommend(batch_size=K).
    return Campaign(
        searchspace=SearchSpace.from_product(parameters),
        objective=ParetoObjective(targets),
        recommender=recommender,
    )
| # --- Server state per campaign | |
class State:
    """Server-side state for one campaign.

    Holds the BayBE campaign, the authoritative combined CSV (measured rows
    plus pending recommendations), the hash of the last upstream CSV, and the
    set of parameter combinations already seen (for deduplication).
    """

    def __init__(self, campaign: Campaign, csv_df: pd.DataFrame, csv_hash: str):
        self.campaign = campaign
        self.csv_df = csv_df  # authoritative combined CSV (measured + recommendations)
        self.csv_hash = csv_hash
        # Track seen parameter combinations (measured or recommended) to dedupe
        params_only = csv_df[PARAM_COLS].dropna()
        self.seen_params = {_param_key(row) for _, row in params_only.iterrows()}
        # Seed the campaign with rows where both targets are present
        complete = csv_df[TARGET_COLS].notna().all(axis=1)
        seed_df = csv_df.loc[complete, PARAM_COLS + TARGET_COLS]
        if len(seed_df):
            self.campaign.add_measurements(seed_df)  # BayBE call (DataFrame in)
# In-memory registry of live campaigns, keyed by campaign id (uuid4 string).
# NOTE(review): process-local and unbounded — state is lost on restart and
# never evicted; confirm this is acceptable for the deployment model.
STATES: Dict[str, State] = {}
def _append_recommendations(state: State, batch_size: int) -> pd.DataFrame:
    """Ask BayBE for up to *batch_size* new points and append them to the CSV.

    Already-seen parameter combinations are dropped for idempotency; new rows
    get one shared ``REC-xxxxxx`` batch id, ``replicate = 1`` and empty
    targets.  Returns the (possibly unchanged) combined CSV DataFrame.

    NOTE(review): dedup compares continuous parameters by exact float
    equality, so near-duplicate recommendations can slip through — confirm
    that is acceptable.
    """
    # Ask BayBE for new points (requires explicit batch_size).
    rec_df = state.campaign.recommend(batch_size=batch_size)  # DataFrame with PARAM_COLS
    # Guard before .apply: on an empty frame it does not yield a Series of keys.
    if rec_df.empty:
        return state.csv_df
    # Drop any already-seen param combos (idempotency)
    keys = rec_df.apply(_param_key, axis=1)
    rec_df = rec_df.loc[~keys.isin(state.seen_params)]
    if rec_df.empty:
        return state.csv_df
    # Fill missing non-optimized columns and empty targets
    rec_df = rec_df.copy()
    rec_df["batch_id"] = f"REC-{uuid4().hex[:6]}"  # one shared id per recommendation batch
    rec_df["replicate"] = 1
    for col in TARGET_COLS:
        rec_df[col] = pd.NA
    # Reorder to template and append to the authoritative CSV
    rec_df = rec_df[TEMPLATE_HEADERS]
    state.csv_df = pd.concat([state.csv_df, rec_df], ignore_index=True)
    # Remember the new keys so later calls don't re-recommend them
    state.seen_params.update(_param_key(r) for _, r in rec_df.iterrows())
    return state.csv_df
def _csv_response(df: pd.DataFrame) -> StreamingResponse:
    """Stream *df* as CSV, exposing its SHA-256 in an ``X-CSV-Hash`` header."""
    payload = df.to_csv(index=False).encode("utf-8")
    response = StreamingResponse(BytesIO(payload), media_type="text/csv")
    response.headers["X-CSV-Hash"] = _csv_hash(payload)
    return response
| @app.post("/campaigns/import_csv") | |
| async def import_csv(file: UploadFile = File(...), batch_size: int = Form(3)): | |
| raw = await file.read() | |
| try: | |
| df = pd.read_csv(StringIO(raw.decode("utf-8"))) | |
| except Exception as e: | |
| raise HTTPException(400, f"Failed to read CSV: {e}") | |
| _validate_headers(df) | |
| df = _validate_and_cast(df) | |
| cid = str(uuid4()) | |
| state = State(_default_campaign(), df.copy(), _csv_hash(raw)) | |
| STATES[cid] = state | |
| # Append first recommendations and return CSV directly | |
| _append_recommendations(state, batch_size=batch_size) | |
| return _csv_response(state.csv_df) | |
| @app.post("/campaigns/{cid}/sync") | |
| async def sync_csv(cid: str, file: UploadFile = File(...), batch_size: int = Form(3), force: int = Form(0)): | |
| if cid not in STATES: | |
| raise HTTPException(404, "campaign not found") | |
| state = STATES[cid] | |
| raw = await file.read() | |
| new_hash = _csv_hash(raw) | |
| if new_hash == state.csv_hash and not force: | |
| # Nothing changed upstream: just return current server CSV | |
| return _csv_response(state.csv_df) | |
| # Parse & validate the uploaded CSV | |
| try: | |
| df_up = pd.read_csv(StringIO(raw.decode("utf-8"))) | |
| except Exception as e: | |
| raise HTTPException(400, f"Failed to read CSV: {e}") | |
| _validate_headers(df_up) | |
| df_up = _validate_and_cast(df_up) | |
| # Find rows with complete targets (new measurements) | |
| have_y = df_up[TARGET_COLS].notna().all(axis=1) | |
| measured_up = df_up.loc[have_y, PARAM_COLS + TARGET_COLS] | |
| # Deduplicate: only add parameter combos not yet seen as measured in server CSV | |
| already_measured = state.csv_df[TARGET_COLS].notna().all(axis=1) | |
| measured_server_keys = set(state.csv_df.loc[already_measured, PARAM_COLS].apply(_param_key, axis=1)) | |
| new_meas = measured_up.loc[~measured_up.apply(_param_key, axis=1).isin(measured_server_keys)] | |
| if not new_meas.empty: | |
| state.campaign.add_measurements(new_meas) # BayBE call. :contentReference[oaicite:5]{index=5} | |
| # Merge upstream measured rows into server CSV (keep latest values) | |
| # Align to template by adding aux cols | |
| add_df = new_meas.copy() | |
| add_df["batch_id"] = "MEAS" | |
| add_df["replicate"] = 1 | |
| add_df = add_df[TEMPLATE_HEADERS] | |
| state.csv_df = pd.concat([state.csv_df, add_df], ignore_index=True) | |
| # Now append new recommendations and return CSV | |
| _append_recommendations(state, batch_size=batch_size) | |
| # Update the stored upstream hash to suppress repeats until the next change | |
| state.csv_hash = new_hash | |
| return _csv_response(state.csv_df) | |
@app.get("/campaigns/{cid}/export_csv")
def export_csv(cid: str):
    """Download the current combined CSV (measurements plus pending recommendations)."""
    state = STATES.get(cid)
    if state is None:
        raise HTTPException(404, "campaign not found")
    return _csv_response(state.csv_df)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment