Created
October 27, 2025 14:59
-
-
Save nboyd/2185b69c15ca87000b31f2111d423167 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ######### | |
| # | |
| # Low-hanging fruit: | |
| # 1. Epitope selection! E.g. using Pesto-style predictor or something simple (SASA + charge?) | |
| # 2. Cropping! Could be sped up by a factor of 5-10 by cropping the target properly | |
| # 3. Early stopping. Kill runs that are going poorly in early stages | |
| # 4. Filtering + ranking with more models | |
| # 5. Different design models, hyperparameter tuning etc. E.g. Boltz2 + templates | |
| # | |
| import modal | |
# Wall-clock budget for the design loop, in seconds (9 hours). The remote
# function stops starting new design rounds once this much time has elapsed.
MAX_TIME = 9 * 60 * 60
def download_protenix():
    """Instantiate ``ProtenixMini`` once and discard the instance.

    This is passed to ``Image.run_function`` so it executes while the
    container image is being built. Presumably construction fetches and
    caches the model weights inside the image -- TODO confirm against mosaic.
    """
    # Imported lazily: mosaic is only installed inside the container image.
    from mosaic.models import protenix

    protenix.ProtenixMini()
# Container image, built in stages: a Debian base with git, a checkout of the
# mosaic repository, its Python dependencies (plus CUDA-enabled JAX), mosaic
# itself, and finally one ProtenixMini instantiation at build time.
image = modal.Image.debian_slim().apt_install("git")
image = image.run_commands("git clone https://github.com/escalante-bio/mosaic.git")
image = image.workdir("mosaic")
for _install_command in (
    "uv pip install --system -r pyproject.toml",
    "uv pip install --system jax[cuda]",
    "uv pip install --system .",
):
    # One run_commands call per command, in the original order.
    image = image.run_commands(_install_command)
image = image.run_function(download_protenix)

# Modal application; every function registered on it runs in `image`.
app = modal.App("vhh", image=image)
@app.function(gpu="h100", timeout=int(20 * 60 * 60))
def f(pdb_string: str):
    """Design nanobody (VHH) loop sequences against a target and score them.

    The first chain of the first model in ``pdb_string`` is used as the
    target. Positions marked ``X`` in a fixed framework sequence are
    optimised against a weighted sum of structure-prediction losses
    (Protenix-mini) and sequence likelihood terms (ESM-C, AbLang, trigram).
    Each round runs three ``simplex_APGM`` stages followed by
    ``gradient_MCMC``, then re-predicts the complex to obtain an ipTM score.
    Rounds repeat until ``MAX_TIME`` seconds have elapsed since the function
    started; the budget is only checked before a round begins, so the last
    round may overrun it.

    Args:
        pdb_string: Contents of a PDB file describing the target.

    Returns:
        A list of ``(iptm, sequence)`` tuples, one per completed round,
        sorted in descending order (highest ipTM first). ``sequence`` is the
        full binder sequence as returned by ``loss.sequence``.
    """
    # Imports are local because these packages exist only in the container.
    import time
    import equinox as eqx
    import gemmi
    import jax
    import jax.numpy as jnp
    import numpy as np
    import mosaic.losses.structure_prediction as sp
    from mosaic.common import TOKENS
    from mosaic.losses.ablang import AbLangPseudoLikelihood, load_ablang
    from mosaic.losses.esmc import ESMCPseudoLikelihood, load_esmc
    from mosaic.losses.protein_mpnn import (
        InverseFoldingSequenceRecovery,
    )
    from mosaic.losses.transformations import SetPositions
    from mosaic.models.protenix import ProtenixMini
    from mosaic.optimizers import gradient_MCMC, simplex_APGM
    from mosaic.proteinmpnn.mpnn import load_abmpnn
    from mosaic.structure_prediction import TargetChain
    from mosaic.losses.trigram import TrigramLL

    app_start_time = time.time()

    d = jax.devices()[0]
    print("Running on device:", d.device_kind)
    if "H100" not in d.device_kind:
        print(
            "Warning: not running on H100! "
        )

    model = ProtenixMini()

    # load target structure and clean up
    target_structure = gemmi.read_pdb_string(pdb_string)
    # Renumber residues of the first chain consecutively from 0.
    for i, r in enumerate(target_structure[0][0]):
        r.seqid.num = i
    target_structure.assign_label_seq_id(True)
    target_structure.remove_ligands_and_waters()
    target_structure.remove_alternative_conformations()

    # One-letter target sequence, computed once and reused for both the
    # target-only and the binder+target feature sets below.
    target_sequence = gemmi.one_letter_code(
        [r.name for r in target_structure[0][0]]
    )

    # Framework with the designable positions masked as "X".
    masked_framework_sequence = "QVQLVESGGGLVQPGGSLRLSCAASXXXXXXXXXXXLGWFRQAPGQGLEAVAAXXXXXXXXYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCXXXXXXXXXXXXXXXXXXWGQGTLVTVS"
    N = len(masked_framework_sequence)
    # Indices of the masked positions. Their count sets the PSSM shape, so it
    # is derived here rather than hard-coded (it was previously a literal 37).
    design_positions = np.array(
        [i for (i, c) in enumerate(masked_framework_sequence) if c == "X"]
    )
    n_design = len(design_positions)

    mpnn = load_abmpnn(backbone_noise=0.05)

    # first we "precycle" the target: run 5 recycling iterations on the target alone to ensure the model has properly folded it
    # this is basically to get around the fact that Protenix-mini doesn't support templates
    target_features, _ = model.target_only_features(
        chains=[
            TargetChain(
                sequence=target_sequence,
                use_msa=True,
            ),
        ]
    )
    model_output = model.model_output(
        features=target_features,
        recycling_steps=5,
        key=jax.random.key(np.random.randint(10000)),
    )
    _target_embedding = model_output.trunk_state

    # add zeros for binder to target_embedding
    # The binder occupies the first N tokens and the target the last M, so
    # the precycled target state is written into the trailing block.
    # NOTE(review): 384 and 128 are taken to be the trunk's single and pair
    # channel widths -- confirm against the Protenix-mini configuration.
    M = len(target_structure[0][0])
    padded_embedding = eqx.tree_at(
        lambda e: (e.s, e.z),
        _target_embedding,
        (
            jnp.zeros((N + M, 384)).at[N:].set(_target_embedding.s),
            jnp.zeros((N + M, N + M, 128)).at[N:, N:].set(_target_embedding.z),
        ),
    )

    # next we load features for actual design
    features, writer = model.target_only_features(
        chains=[
            TargetChain(
                sequence=masked_framework_sequence,
                use_msa=True,
            ),
            TargetChain(
                sequence=target_sequence,
                use_msa=True,
            ),
        ]
    )

    # we'll also add AbLang LL
    ablang, ablang_tokenizer = load_ablang("heavy")
    ablang_pll = AbLangPseudoLikelihood(
        model=ablang,
        tokenizer=ablang_tokenizer,
        stop_grad=True,
    )

    # and ESMC PLL
    ESMCPLL = ESMCPseudoLikelihood(load_esmc("esmc_300m"), stop_grad=True)

    structure_loss = (
        0.1 * sp.PLDDTLoss()
        + 1
        * sp.BinderTargetContact(
            paratope_idx=design_positions,
        )
        + 0.50 * sp.TargetBinderPAE()
        + 0.05 * sp.BinderTargetPAE()
        + 0.25 * sp.IPTMLoss()
        + 0.2 * sp.WithinBinderPAE()
        + 0.5 * sp.WithinBinderContact()
        # Weight 0.0: the inverse-folding term is kept in the sum but does
        # not contribute to the total.
        + 0.0
        * InverseFoldingSequenceRecovery(
            mpnn,
            temp=jax.device_put(0.0001),
            bias=50.0
            * jax.nn.one_hot(
                SetPositions.from_sequence(
                    wildtype=masked_framework_sequence, loss=None
                ).wildtype,
                20,
            ),
        )
        # + 0.1 * ProteinMPNNLoss(mpnn, num_samples=4)
    )

    model_loss = model.build_loss(
        loss=structure_loss,
        features=features,
        recycling_steps=2,
        return_coords=False,
        return_state=False,
    )
    # set initial recycling state from precycling
    # `is_leaf` treats None as a leaf so the (initially None) field can be
    # replaced by tree_at.
    model_loss = eqx.tree_at(
        lambda l: l.initial_recycling_state,
        model_loss,
        padded_embedding,
        is_leaf=lambda x: x is None,
    )

    # add a small trigram LL term (mostly to avoid homopolymer runs)
    trigram_ll = TrigramLL.from_pkl()

    loss = SetPositions.from_sequence(
        wildtype=masked_framework_sequence,
        loss=0.1 * ESMCPLL + 2 * model_loss + 0.1 * ablang_pll + trigram_ll,
    )

    results = []
    while (time.time() - app_start_time) < MAX_TIME:
        start_time = time.time()

        # Random Gumbel initialisation over the designable positions only.
        _pssm = 0.5 * jax.random.gumbel(
            key=jax.random.key(np.random.randint(1000000)),
            shape=(n_design, 20),
        )

        # Stage 1: 20 steps with the largest step size.
        _, partial_pssm = simplex_APGM(
            loss_function=loss,
            x=_pssm,
            n_steps=20,
            stepsize=1.5 * np.sqrt(_pssm.shape[0]),
            momentum=0.0,
            scale=1.00,
            serial_evaluation=False,
            logspace=True,
            max_gradient_norm=1.0,
        )
        # Stage 2: 30 steps with a smaller step size and logspace disabled.
        _, partial_pssm = simplex_APGM(
            loss_function=loss,
            x=partial_pssm,
            n_steps=30,
            stepsize=0.5 * np.sqrt(_pssm.shape[0]),
            momentum=0.0,
            scale=1.1,
            serial_evaluation=False,
            logspace=False,
            max_gradient_norm=1.0,
        )
        print("".join(TOKENS[i] for i in partial_pssm.argmax(-1)))

        # Stage 3: 30 steps starting from the log of the stage-2 result;
        # the 1e-5 offset keeps the log finite at zero entries.
        _, partial_pssm = simplex_APGM(
            loss_function=loss,
            x=jnp.log(partial_pssm + 1e-5),
            n_steps=30,
            stepsize=0.25 * np.sqrt(_pssm.shape[0]),
            momentum=0.0,
            scale=1.1,
            serial_evaluation=False,
            logspace=True,
            max_gradient_norm=1.0,
        )
        print("".join(TOKENS[i] for i in partial_pssm.argmax(-1)))

        # Discrete refinement starting from the per-position argmax.
        s_mcmc = gradient_MCMC(
            loss=loss,
            sequence=jax.device_put(partial_pssm.argmax(-1)),
            steps=30,
            fix_loss_key=False,
            proposal_temp=1e-5,
            max_path_length=1,
        )
        finish_time = time.time()
        print(f"Design took {finish_time - start_time:.2f} seconds")

        # Expand the designed positions into the full binder sequence once
        # and reuse it for both the prediction and the printed string.
        full_design = loss.sequence(jax.nn.one_hot(s_mcmc, 20))
        prediction_inpaint = model.predict(
            PSSM=full_design,
            writer=writer,
            features=features,
            recycling_steps=3,
            key=jax.random.key(np.random.randint(10000)),
        )
        design_str = "".join(TOKENS[i] for i in full_design.argmax(-1))
        results.append((float(prediction_inpaint.iptm), design_str))
        print(f"IPTM: {prediction_inpaint.iptm:.2f} \t" + design_str)

    return sorted(results, reverse=True)
@app.local_entrypoint()
def main(input_pdb: str):
    """Run the remote design job on a PDB file and save the results.

    Reads the file at ``input_pdb``, passes its text to ``f`` remotely, and
    writes the returned ``(iptm, design)`` pairs to ``mosaic.faa`` in the
    current directory as FASTA-style records.

    Args:
        input_pdb: Path to the target PDB file on the local machine.
    """
    from pathlib import Path

    pdb_text = Path(input_pdb).read_text()
    ranked_designs = f.remote(pdb_string=pdb_text)

    # One two-line record per design: a header carrying the rank and ipTM,
    # then the sequence.
    records = []
    for rank, (score, sequence) in enumerate(ranked_designs):
        records.append(f">design_{rank}_iptm{score:.4f}\n{sequence}")

    # dump to mosaic.faa
    Path("mosaic.faa").write_text("\n".join(records))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment