@nboyd · Created October 27, 2025
#########
#
# Low-hanging fruit:
# 1. Epitope selection! E.g. using a PeSTo-style predictor or something simple (SASA + charge?)
# 2. Cropping! Could be sped up by a factor of 5-10 by cropping the target properly
# 3. Early stopping: kill runs that are going poorly in early stages
# 4. Filtering + ranking with more models
# 5. Different design models, hyperparameter tuning, etc. E.g. Boltz-2 + templates
#
import modal

# wall-clock budget for the design loop in f(); the Modal function timeout below is set separately
MAX_TIME = int(9.0 * 60 * 60)
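
# Constructing ProtenixMini once presumably forces a weight download; running this at
# image-build time (via .run_function below) caches the weights inside the image.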
def download_protenix():
    from mosaic.models.protenix import ProtenixMini

    _ = ProtenixMini()
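
# Container image: Debian slim + git, with mosaic cloned from GitHub and installed
# alongside CUDA-enabled JAX, and the Protenix weights pre-fetched at build time.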
image = (
    modal.Image.debian_slim()
    .apt_install("git")
    .run_commands("git clone https://github.com/escalante-bio/mosaic.git")
    .workdir("mosaic")
    .run_commands("uv pip install --system -r pyproject.toml")
    .run_commands("uv pip install --system jax[cuda]")
    .run_commands("uv pip install --system .")
    .run_function(download_protenix)
)

app = modal.App("vhh", image=image)


@app.function(gpu="h100", timeout=int(20 * 60 * 60))
def f(pdb_string: str):
    import time

    import equinox as eqx
    import gemmi
    import jax
    import jax.numpy as jnp
    import numpy as np

    import mosaic.losses.structure_prediction as sp
    from mosaic.common import TOKENS
    from mosaic.losses.ablang import AbLangPseudoLikelihood, load_ablang
    from mosaic.losses.esmc import ESMCPseudoLikelihood, load_esmc
    from mosaic.losses.protein_mpnn import InverseFoldingSequenceRecovery
    from mosaic.losses.transformations import SetPositions
    from mosaic.losses.trigram import TrigramLL
    from mosaic.models.protenix import ProtenixMini
    from mosaic.optimizers import gradient_MCMC, simplex_APGM
    from mosaic.proteinmpnn.mpnn import load_abmpnn
    from mosaic.structure_prediction import TargetChain

    APP_START_TIME = time.time()

    d = jax.devices()[0]
    print("Running on device:", d.device_kind)
    if "H100" not in d.device_kind:
        print("Warning: not running on H100!")

    model = ProtenixMini()

    # load the target structure and clean it up
    target_structure = gemmi.read_pdb_string(pdb_string)
    for i, r in enumerate(target_structure[0][0]):
        r.seqid.num = i
    target_structure.assign_label_seq_id(True)
    target_structure.remove_ligands_and_waters()
    target_structure.remove_alternative_conformations()

    masked_framework_sequence = "QVQLVESGGGLVQPGGSLRLSCAASXXXXXXXXXXXLGWFRQAPGQGLEAVAAXXXXXXXXYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCXXXXXXXXXXXXXXXXXXWGQGTLVTVS"
    N = len(masked_framework_sequence)
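    # the X runs are the three CDR loops (11 + 8 + 18 = 37 designable positions);
    # the rest is a fixed VHH framework, pinned to wildtype later via SetPositions
    # (note the 37 matches the PSSM shape drawn in the design loop below)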

    mpnn = load_abmpnn(backbone_noise=0.05)

    # First we "precycle" the target: run 5 recycling iterations on the target alone
    # to ensure the model has properly folded it. This is basically a workaround for
    # the fact that Protenix-mini doesn't support templates.
    target_features, _ = model.target_only_features(
        chains=[
            TargetChain(
                sequence=gemmi.one_letter_code([r.name for r in target_structure[0][0]]),
                use_msa=True,
            ),
        ]
    )
    model_output = model.model_output(
        features=target_features,
        recycling_steps=5,
        key=jax.random.key(np.random.randint(10000)),
    )
    _target_embedding = model_output.trunk_state

    # pad the target embedding with zeros at the binder positions
    M = len(target_structure[0][0])
    padded_embedding = eqx.tree_at(
        lambda e: (e.s, e.z),
        _target_embedding,
        (
            jnp.zeros((N + M, 384)).at[N:].set(_target_embedding.s),
            jnp.zeros((N + M, N + M, 128)).at[N:, N:].set(_target_embedding.z),
        ),
    )
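    # layout: the binder occupies the first N rows (zero-initialized), the target the
    # last M, so the precycled single (s) and pair (z) states line up with the
    # binder-then-target chain order used for the design features below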

    # next, load features for the actual design (binder + target complex)
    features, writer = model.target_only_features(
        chains=[
            TargetChain(
                sequence=masked_framework_sequence,
                use_msa=True,
            ),
            TargetChain(
                sequence=gemmi.one_letter_code([r.name for r in target_structure[0][0]]),
                use_msa=True,
            ),
        ]
    )

    # we'll also add an AbLang pseudo-log-likelihood term
    ablang, ablang_tokenizer = load_ablang("heavy")
    ablang_pll = AbLangPseudoLikelihood(
        model=ablang,
        tokenizer=ablang_tokenizer,
        stop_grad=True,
    )
    # ... and an ESM-C pseudo-log-likelihood term
    ESMCPLL = ESMCPseudoLikelihood(load_esmc("esmc_300m"), stop_grad=True)

    structure_loss = (
        0.1 * sp.PLDDTLoss()
        + 1.0
        * sp.BinderTargetContact(
            paratope_idx=np.array(
                [i for (i, c) in enumerate(masked_framework_sequence) if c == "X"]
            ),
        )
        + 0.50 * sp.TargetBinderPAE()
        + 0.05 * sp.BinderTargetPAE()
        + 0.25 * sp.IPTMLoss()
        + 0.2 * sp.WithinBinderPAE()
        + 0.5 * sp.WithinBinderContact()
        + 0.0
        * InverseFoldingSequenceRecovery(
            mpnn,
            temp=jax.device_put(0.0001),
            bias=50.0
            * jax.nn.one_hot(
                SetPositions.from_sequence(
                    wildtype=masked_framework_sequence, loss=None
                ).wildtype,
                20,
            ),
        )
        # + 0.1 * ProteinMPNNLoss(mpnn, num_samples=4)
    )
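
    # note: the InverseFoldingSequenceRecovery term above carries weight 0.0, so it is
    # effectively disabled; raise the weight to re-enable AbMPNN sequence recovery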

    model_loss = model.build_loss(
        loss=structure_loss,
        features=features,
        recycling_steps=2,
        return_coords=False,
        return_state=False,
    )
    # set the initial recycling state from precycling
    model_loss = eqx.tree_at(
        lambda l: l.initial_recycling_state,
        model_loss,
        padded_embedding,
        is_leaf=lambda x: x is None,
    )

    # add a small trigram LL term (mostly to avoid homopolymer runs)
    trigram_ll = TrigramLL.from_pkl()

    loss = SetPositions.from_sequence(
        wildtype=masked_framework_sequence,
        loss=0.1 * ESMCPLL + 2 * model_loss + 0.1 * ablang_pll + trigram_ll,
    )
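
    # SetPositions pins the framework residues to the wildtype sequence, so only the
    # 37 X positions are actually optimized; the structure loss (weight 2) dominates,
    # with the ESM-C, AbLang, and trigram log-likelihoods acting as sequence priors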

    results = []
    while (time.time() - APP_START_TIME) < MAX_TIME:
        start_time = time.time()
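
        # initialize the 37x20 design PSSM with scaled Gumbel noise, then anneal it
        # through three simplex_APGM stages with decreasing step sizes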
        _pssm = 0.5 * jax.random.gumbel(
            key=jax.random.key(np.random.randint(1000000)), shape=(37, 20)
        )
        _, partial_pssm = simplex_APGM(
            loss_function=loss,
            x=_pssm,
            n_steps=20,
            stepsize=1.5 * np.sqrt(_pssm.shape[0]),
            momentum=0.0,
            scale=1.00,
            serial_evaluation=False,
            logspace=True,
            max_gradient_norm=1.0,
        )
        _, partial_pssm = simplex_APGM(
            loss_function=loss,
            x=partial_pssm,
            n_steps=30,
            stepsize=0.5 * np.sqrt(_pssm.shape[0]),
            momentum=0.0,
            scale=1.1,
            serial_evaluation=False,
            logspace=False,
            max_gradient_norm=1.0,
        )
        print("".join(TOKENS[i] for i in partial_pssm.argmax(-1)))
        _, partial_pssm = simplex_APGM(
            loss_function=loss,
            x=jnp.log(partial_pssm + 1e-5),
            n_steps=30,
            stepsize=0.25 * np.sqrt(_pssm.shape[0]),
            momentum=0.0,
            scale=1.1,
            serial_evaluation=False,
            logspace=True,
            max_gradient_norm=1.0,
        )
        print("".join(TOKENS[i] for i in partial_pssm.argmax(-1)))
        s_mcmc = gradient_MCMC(
            loss=loss,
            sequence=jax.device_put(partial_pssm.argmax(-1)),
            steps=30,
            fix_loss_key=False,
            proposal_temp=1e-5,
            max_path_length=1,
        )
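
        # the APGM stages alternate logspace and simplex parameterizations while
        # shrinking the step size; gradient_MCMC then does a short discrete polish
        # starting from the argmax sequence of the relaxed PSSM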

        finish_time = time.time()
        print(f"Design took {finish_time - start_time:.2f} seconds")

        prediction_inpaint = model.predict(
            PSSM=loss.sequence(jax.nn.one_hot(s_mcmc, 20)),
            writer=writer,
            features=features,
            recycling_steps=3,
            key=jax.random.key(np.random.randint(10000)),
        )
        design_str = "".join(
            TOKENS[i] for i in loss.sequence(jax.nn.one_hot(s_mcmc, 20)).argmax(-1)
        )
        results.append((float(prediction_inpaint.iptm), design_str))
        print(f"IPTM: {prediction_inpaint.iptm:.2f}\t" + design_str)

    return sorted(results, reverse=True)


@app.local_entrypoint()
def main(input_pdb: str):
    from pathlib import Path

    iptms_and_designs = f.remote(pdb_string=Path(input_pdb).read_text())
    # dump the ranked designs to mosaic.faa
    Path("mosaic.faa").write_text(
        "\n".join(
            f">design_{i}_iptm{iptm:.4f}\n{design}"
            for i, (iptm, design) in enumerate(iptms_and_designs)
        )
    )
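

# Usage (assuming this file is saved as, e.g., vhh_design.py):
#   modal run vhh_design.py --input-pdb path/to/target.pdb
# Ranked designs are written to mosaic.faa as FASTA records named design_<i>_iptm<score>.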