foundry/examples/all.ipynb
2025-12-17 08:34:30 -08:00

{
"cells": [
{
"cell_type": "markdown",
"id": "f3df148e",
"metadata": {},
"source": [
"# Example: End-To-End *De Novo* Protein Design Pipeline\n",
"\n",
"## Overview\n",
"\n",
"This notebook demonstrates an end-to-end protein design workflow using three deep learning models from the Institute for Protein Design:\n",
"\n",
"| Step | Model | Purpose |\n",
"|------|-------|---------|\n",
"| 1. **Backbone Generation** | RFD3 | Generate novel protein backbones via diffusion |\n",
"| 2. **Sequence Design** | MPNN | Design amino acid sequences for the generated backbone |\n",
"| 3. **Structure Validation** | RF3 | Predict the structure of the designed sequence to check designability |\n",
"\n",
"All models share a common data layer, [AtomWorks](https://github.com/RosettaCommons/atomworks), used for both inference and training; structures are passed between models as Biotite `AtomArray` objects.\n",
"\n",
"This notebook assumes the base checkpoints have been downloaded with `foundry install rfd3 ligandmpnn rf3`; alternatively, you can pass checkpoint paths directly. To run this notebook in your Foundry virtual environment, register it as a Jupyter kernel with: `python -m ipykernel install --user --name=foundry --display-name \"foundry\"`.\n",
"\n",
"### Pipeline Flow\n",
"```\n",
"RFD3 (backbone) → MPNN (sequence) → RF3 (validation) → RMSD comparison\n",
"```\n",
"---\n",
"\n",
"## Section 0: Installation\n",
"\n",
"Install the Foundry package (includes RFD3, MPNN, and RF3):\n",
"\n",
"```bash\n",
"pip install 'rc-foundry[all]'\n",
"```\n",
"\n",
"Download the model weights (~6 GB total; this takes a couple of minutes):\n",
"\n",
"```bash\n",
"foundry install rfd3 ligandmpnn rf3\n",
"```\n",
"\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "819e8193",
"metadata": {},
"outputs": [],
"source": [
"# Shared utilities for visualization (from AtomWorks)\n",
"from atomworks.io.utils.visualize import view"
]
},
{
"cell_type": "markdown",
"id": "a7tw5gds8p",
"metadata": {},
"source": [
"## Section 1: Backbone Generation with RFD3\n",
"\n",
"RFdiffusion3 (RFD3) generates *de novo* all-atom protein structures that satisfy user-specified conditioning. This example runs unconditional generation, constrained only by protein length.\n",
"\n",
"**Parameters Used** *(many more are available for more complex protein design tasks)*:\n",
"- `length` (inside `specification`): Target protein length in residues\n",
"- `diffusion_batch_size`: Number of structures to generate per batch\n",
"- `n_batches` (argument to `run`): Number of batches to run\n",
"\n",
"**Outputs:** Dictionary mapping example IDs to lists of `RFD3Output` objects."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d16cb95b-3f4c-4167-952b-278bdb561bf7",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from lightning.fabric import seed_everything\n",
"from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine\n",
"\n",
"# Set seed for reproducibility\n",
"seed_everything(0)\n",
"\n",
"# Configure RFD3 inference\n",
"config = RFD3InferenceConfig(\n",
" specification={\n",
" 'length': 80, # Generate 80-residue proteins\n",
" 'extra': {}, # We are not using any extra specifications here.\n",
" },\n",
" diffusion_batch_size=2, # Generate 2 structures per batch\n",
")\n",
"\n",
"# Initialize engine and run generation\n",
"model = RFD3InferenceEngine(**config)\n",
"outputs = model.run(\n",
" inputs=None, # None for unconditional generation\n",
" out_dir=None, # None to return in memory (no file output)\n",
" n_batches=1, # Generate 1 batch\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "86adad90",
"metadata": {},
"outputs": [],
"source": [
"# View generated example IDs (one key per generated structure)\n",
"outputs.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "35253ef3-2ce1-4dca-958a-63b359b70d73",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Inspect RFD3 outputs and extract the generated backbone\n",
"for idx, data in outputs.items():\n",
" print(f\"Example {idx}: {len(data)} structure(s)\")\n",
" print(f\" Output type: {type(data[0]).__name__}\")\n",
" print(f\" AtomArray: {data[0].atom_array}\")\n",
"\n",
"# Extract the first generated backbone for downstream use\n",
"first_key = next(iter(outputs.keys()))\n",
"atom_array = outputs[first_key][0].atom_array\n",
"\n",
"# Visualize the generated backbone\n",
"view(atom_array)"
]
},
{
"cell_type": "markdown",
"id": "1cziwe2nb26h",
"metadata": {},
"source": [
"---\n",
"\n",
"## Section 2: Sequence Design with MPNN\n",
"\n",
"ProteinMPNN and LigandMPNN (Message Passing Neural Networks) design amino acid sequences that are predicted to fold into a target backbone structure.\n",
"\n",
"**Model Options:**\n",
"- `protein_mpnn`: Original ProteinMPNN for protein-only design\n",
"- `ligand_mpnn`: Extended model supporting ligand-aware design\n",
"\n",
"**Key Parameters:**\n",
"- `batch_size`: Number of sequences to generate per structure\n",
"- `remove_waters`: Whether to exclude water molecules from context"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d07ae413",
"metadata": {},
"outputs": [],
"source": [
"from mpnn.inference_engines.mpnn import MPNNInferenceEngine\n",
"\n",
"# Configure MPNN inference engine\n",
"# See mpnn.utils.inference.MPNN_GLOBAL_INFERENCE_DEFAULTS for all options\n",
"engine_config = {\n",
" \"model_type\": \"ligand_mpnn\", # or \"protein_mpnn\" for vanilla ProteinMPNN\n",
" \"is_legacy_weights\": True, # Required for now for ligand_mpnn and protein_mpnn\n",
" \"out_directory\": None, # Return results in memory\n",
" \"write_structures\": False,\n",
" \"write_fasta\": False,\n",
"}\n",
"\n",
"# Configure per-input inference options\n",
"# See mpnn.utils.inference.MPNN_PER_INPUT_INFERENCE_DEFAULTS for all options\n",
"input_configs = [\n",
" {\n",
" \"batch_size\": 10, # Generate 10 sequences per structure\n",
" \"remove_waters\": True,\n",
" }\n",
"]\n",
"\n",
"# Run sequence design on the RFD3-generated backbone\n",
"model = MPNNInferenceEngine(**engine_config)\n",
"mpnn_outputs = model.run(input_dicts=input_configs, atom_arrays=[atom_array])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75f5e558-3323-4094-9cce-c5a19b54f4cf",
"metadata": {},
"outputs": [],
"source": [
"from biotite.structure import get_residue_starts\n",
"from biotite.sequence import ProteinSequence\n",
"\n",
"# Extract and display the designed sequences\n",
"print(f\"Generated {len(mpnn_outputs)} designed sequences:\\n\")\n",
"\n",
"for i, item in enumerate(mpnn_outputs):\n",
" res_starts = get_residue_starts(item.atom_array)\n",
" # Convert 3-letter codes to 1-letter using Biotite\n",
" seq_1letter = ''.join(\n",
" ProteinSequence.convert_letter_3to1(res_name)\n",
" for res_name in item.atom_array.res_name[res_starts]\n",
" )\n",
" print(f\"Sequence {i+1}: {seq_1letter}\")"
]
},
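{
"cell_type": "markdown",
"id": "sketch-seq-md",
"metadata": {},
"source": [
"The cell above relies on two Biotite helpers. To illustrate roughly what they do, the sketch below re-implements both on toy data: a new residue starts wherever the per-atom `(chain_id, res_id)` annotation changes, and the 3-letter name at each start is mapped to a 1-letter code. `residue_starts` and `THREE_TO_ONE` are hypothetical helpers, not library APIs, and the simplification is deliberate: Biotite's `get_residue_starts` also takes other annotations (such as insertion code and residue name) into account, and this table covers only the 20 standard amino acids."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-seq-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical toy per-atom annotations for a 3-residue chain (not real model output)\n",
"chain_id = np.array(['A'] * 10)\n",
"res_id = np.array([1, 1, 1, 1, 2, 2, 2, 2, 3, 3])\n",
"res_name = np.array(['MET'] * 4 + ['LYS'] * 4 + ['GLY'] * 2)\n",
"\n",
"def residue_starts(chain_id, res_id):\n",
"    # An atom starts a new residue when its (chain_id, res_id) differs from the previous atom's\n",
"    changed = (chain_id[1:] != chain_id[:-1]) | (res_id[1:] != res_id[:-1])\n",
"    return np.concatenate(([0], np.flatnonzero(changed) + 1))\n",
"\n",
"# Standard 20 amino acids: 3-letter -> 1-letter code\n",
"THREE_TO_ONE = {\n",
"    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',\n",
"    'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',\n",
"    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',\n",
"    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',\n",
"}\n",
"\n",
"# Take the residue name at each residue start and translate it\n",
"starts = residue_starts(chain_id, res_id)\n",
"sequence = ''.join(THREE_TO_ONE[name] for name in res_name[starts])\n",
"print(starts.tolist(), sequence)"
]
},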
{
"cell_type": "markdown",
"id": "te84j9ce9rn",
"metadata": {},
"source": [
"---\n",
"\n",
"## Section 3: Structure Prediction with RF3\n",
"\n",
"RF3 (RoseTTAFold 3) predicts protein structures from sequences. By re-folding the MPNN-designed sequence, we can validate whether the design is likely to adopt the intended backbone structure.\n",
"\n",
"**Outputs:** `RF3Output` objects containing:\n",
"- `atom_array`: Predicted structure as Biotite AtomArray\n",
"- `summary_confidences`: Overall confidence metrics (pLDDT, PAE, pTM, etc.)\n",
"- `confidences`: Per-atom/residue confidence scores\n",
"\n",
"**Confidence Metrics:**\n",
"| Metric | Description |\n",
"|--------|-------------|\n",
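"| pLDDT | Predicted local distance difference test: per-atom/per-residue confidence (0-1, higher is better) |\n",
"| PAE | Predicted Aligned Error in Å (lower is better) |\n",
"| PDE | Predicted Distance Error (lower is better) |\n",
"| pTM | Predicted TM-score (0-1, higher is better) |\n",
"| ranking_score | Aggregate score used to rank the predicted models |"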
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d85eda46-4dca-4e92-8946-73b69a5536e3",
"metadata": {},
"outputs": [],
"source": [
"from rf3.inference_engines.rf3 import RF3InferenceEngine\n",
"from rf3.utils.inference import InferenceInput\n",
"\n",
"\n",
"# Initialize RF3 inference engine\n",
"inference_engine = RF3InferenceEngine(ckpt_path='rf3', verbose=False)\n",
"\n",
"# Create input from the first MPNN design (its atom_array carries the designed sequence)\n",
"# RF3 re-folds this sequence so we can check whether it adopts the intended structure\n",
"input_structure = InferenceInput.from_atom_array(mpnn_outputs[0].atom_array, example_id=\"example_protein\")\n",
"rf3_outputs = inference_engine.run(inputs=input_structure)\n",
"\n",
"# Outputs: dict mapping example_id -> list[RF3Output] (multiple models per input)\n",
"print(f\"Output keys: {rf3_outputs.keys()}\")\n",
"print(f\"Number of models for 'example_protein': {len(rf3_outputs['example_protein'])}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7113c03",
"metadata": {},
"outputs": [],
"source": [
"# Extract the top-ranked prediction\n",
"rf3_output = rf3_outputs[\"example_protein\"][0]\n",
"\n",
"# Inspect RF3Output structure\n",
"print(\"RF3Output contains:\")\n",
"print(f\" - atom_array: {len(rf3_output.atom_array)} atoms\")\n",
"print(f\" - summary_confidences: {list(rf3_output.summary_confidences.keys())}\")\n",
"print(f\" - confidences: {list(rf3_output.confidences.keys()) if rf3_output.confidences else None}\")\n",
"\n",
"# Visualize the predicted structure\n",
"view(rf3_output.atom_array)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6mqnovlrygo",
"metadata": {},
"outputs": [],
"source": [
"# Summary confidences: overall model quality metrics\n",
"summary = rf3_output.summary_confidences\n",
"\n",
"print(\"=== Summary Confidences ===\")\n",
"print(f\" Overall pLDDT: {summary['overall_plddt']:.3f}\")\n",
"print(f\" Overall PAE: {summary['overall_pae']:.2f} Å\")\n",
"print(f\" Overall PDE: {summary['overall_pde']:.3f}\")\n",
"print(f\" pTM: {summary['ptm']:.3f}\")\n",
"print(f\" ipTM: {summary.get('iptm', 'N/A (single chain)')}\")\n",
"print(f\" Ranking score: {summary['ranking_score']:.3f}\")\n",
"print(f\" Has clash: {summary['has_clash']}\")"
]
},
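{
"cell_type": "markdown",
"id": "sketch-pae-md",
"metadata": {},
"source": [
"The summary above condenses model confidence into single numbers such as `overall_pae`, while the full PAE matrix is available as `rf3_output.confidences['pae']`. The cell below is a hypothetical illustration, on a toy 3x3 matrix, of simple ways to summarize such a matrix: the mean over all residue pairs, a symmetrized copy, and the fraction of pairs under a cutoff. The toy values, the 5 Å cutoff, and the AlphaFold-style row/column convention are assumptions, and RF3's own `overall_pae` may be defined differently. To try it on the real prediction, replace the toy matrix with `np.asarray(rf3_output.confidences['pae'])`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-pae-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical toy PAE matrix in Å (illustrative values, not RF3 output).\n",
"# Assumed AlphaFold-style convention: row i = residue used for alignment,\n",
"# column j = residue whose position error is estimated, so it need not be symmetric.\n",
"pae = np.array([\n",
"    [0.5, 2.0, 6.0],\n",
"    [3.0, 0.5, 4.0],\n",
"    [8.0, 5.0, 0.5],\n",
"])\n",
"\n",
"mean_pae = pae.mean()  # mean over all residue pairs (one possible overall summary)\n",
"symmetric_pae = (pae + pae.T) / 2  # symmetrized view, e.g. for plotting\n",
"frac_confident = (pae < 5.0).mean()  # fraction of pairs under a hypothetical 5 Å cutoff\n",
"\n",
"print(f'Mean PAE: {mean_pae:.2f} Å')\n",
"print(f'Fraction of pairs below 5 Å: {frac_confident:.2f}')"
]
},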
{
"cell_type": "code",
"execution_count": null,
"id": "g5gupirgy6",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Detailed per-atom/residue confidences\n",
"conf = rf3_output.confidences\n",
"\n",
"print(\"=== Per-Atom/Residue Confidences ===\")\n",
"print(f\" atom_plddts: {len(conf['atom_plddts'])} values (one per atom)\")\n",
"print(f\" atom_chain_ids: {len(conf['atom_chain_ids'])} values\")\n",
"print(f\" token_chain_ids: {len(conf['token_chain_ids'])} values (one per residue)\")\n",
"print(f\" token_res_ids: {len(conf['token_res_ids'])} values\")\n",
"print(f\" PAE matrix: {len(conf['pae'])}x{len(conf['pae'][0])}\")\n",
"\n",
"# Preview first 10 atom pLDDT scores\n",
"print(f\"\\nFirst 10 atom pLDDTs: {np.round(conf['atom_plddts'][:10], 2).tolist()}\")"
]
},
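{
"cell_type": "markdown",
"id": "sketch-plddt-md",
"metadata": {},
"source": [
"`atom_plddts` holds one score per atom, whereas a per-residue profile is often easier to inspect. The cell below sketches one way to collapse per-atom scores into per-residue means with `np.bincount`, assuming an integer array that maps each atom to a 0-based residue index. `per_residue_mean` and the toy values are hypothetical, and RF3 may aggregate its scores differently. With real output, such an index could be built from Biotite's `get_residue_starts` as `np.searchsorted(starts, np.arange(n_atoms), side='right') - 1`, assuming `atom_plddts` follows the atom order of `rf3_output.atom_array`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-plddt-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical toy data (not RF3 output): per-atom pLDDTs and each atom's residue index\n",
"atom_plddts = np.array([0.875, 0.75, 0.625, 0.5, 0.25, 0.5])\n",
"atom_res_idx = np.array([0, 0, 0, 1, 1, 2])\n",
"\n",
"def per_residue_mean(values, res_idx):\n",
"    # Sum the atom scores belonging to each residue, then divide by that residue's atom count\n",
"    sums = np.bincount(res_idx, weights=values)\n",
"    counts = np.bincount(res_idx)\n",
"    return sums / counts\n",
"\n",
"res_plddts = per_residue_mean(atom_plddts, atom_res_idx)\n",
"print(res_plddts.tolist())"
]
},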
{
"cell_type": "markdown",
"id": "0epbqi91bv3n",
"metadata": {},
"source": [
"---\n",
"\n",
"## Section 4: Validation and Export\n",
"\n",
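"The final step superimposes the RF3-predicted structure onto the original RFD3-generated backbone and computes the backbone RMSD (often called self-consistency RMSD). A low value suggests the designed sequence is likely to fold into the intended structure (high designability); since both structures are computational, this is an *in silico* check rather than experimental validation."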
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "238a0241",
"metadata": {},
"outputs": [],
"source": [
"from biotite.structure import rmsd, superimpose\n",
"from atomworks.constants import PROTEIN_BACKBONE_ATOM_NAMES\n",
"import numpy as np\n",
"\n",
"# Get structures for comparison\n",
"aa_generated = atom_array # Original RFD3 backbone (from Section 1)\n",
"aa_refolded = rf3_output.atom_array # RF3-predicted structure\n",
"\n",
"# Filter to backbone atoms (N, CA, C, O)\n",
"bb_generated = aa_generated[np.isin(aa_generated.atom_name, PROTEIN_BACKBONE_ATOM_NAMES)]\n",
"bb_refolded = aa_refolded[np.isin(aa_refolded.atom_name, PROTEIN_BACKBONE_ATOM_NAMES)]\n",
"\n",
"# Superimpose structures and calculate RMSD\n",
"bb_refolded_fitted, _ = superimpose(bb_generated, bb_refolded)\n",
"rmsd_value = rmsd(bb_generated, bb_refolded_fitted)\n",
"\n",
"print(f\"Backbone RMSD: {rmsd_value:.2f} Å\")\n",
"print(f\"\\nInterpretation: {'Excellent' if rmsd_value < 1.0 else 'Good' if rmsd_value < 2.0 else 'Moderate'} designability\")"
]
},
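{
"cell_type": "markdown",
"id": "sketch-kabsch-md",
"metadata": {},
"source": [
"Under the hood, `superimpose` finds the rigid rotation and translation that minimize the squared deviation between matched atoms, and `rmsd` then measures what remains. The cell below is a minimal NumPy sketch of that calculation with the Kabsch algorithm; `kabsch_rmsd` is a hypothetical helper for illustration, not Biotite's implementation, and like `superimpose` it assumes both inputs contain the same atoms in the same order (which is why the backbone filtering above must produce arrays of equal length).\n",
"\n",
"On the real structures, `kabsch_rmsd(bb_generated.coord, bb_refolded.coord)` should closely match `rmsd_value`, provided the two backbones list their atoms in the same order."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-kabsch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def kabsch_rmsd(fixed, mobile):\n",
"    # RMSD between two (N, 3) coordinate sets after optimal rigid superposition (Kabsch)\n",
"    p = fixed - fixed.mean(axis=0)  # center both sets on their centroids\n",
"    q = mobile - mobile.mean(axis=0)\n",
"    h = q.T @ p  # 3x3 cross-covariance matrix\n",
"    u, _, vt = np.linalg.svd(h)\n",
"    d = np.sign(np.linalg.det(vt.T @ u.T))  # -1 would indicate a reflection\n",
"    rotation = vt.T @ np.diag([1.0, 1.0, d]) @ u.T\n",
"    q_fitted = q @ rotation.T  # rotate mobile points onto fixed\n",
"    return float(np.sqrt(np.mean(np.sum((p - q_fitted) ** 2, axis=1))))\n",
"\n",
"# Self-check on toy data: a rotated and translated copy should give RMSD close to 0\n",
"rng = np.random.default_rng(0)\n",
"coords = rng.normal(size=(20, 3))\n",
"theta = np.pi / 5\n",
"rot_z = np.array([\n",
"    [np.cos(theta), -np.sin(theta), 0.0],\n",
"    [np.sin(theta), np.cos(theta), 0.0],\n",
"    [0.0, 0.0, 1.0],\n",
"])\n",
"moved = coords @ rot_z.T + np.array([3.0, -1.0, 2.0])\n",
"print(f'RMSD after superposition: {kabsch_rmsd(coords, moved):.6f}')"
]
},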
{
"cell_type": "code",
"execution_count": null,
"id": "9070ead8",
"metadata": {},
"outputs": [],
"source": [
"from atomworks.io.utils.io_utils import to_cif_file\n",
"\n",
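"# Export structures to CIF format for visualization in PyMOL/ChimeraX.\n",
"# Note: aa_refolded is written in RF3's original coordinate frame (only the backbone copy was fitted above),\n",
"# so align the two files in the viewer (e.g. PyMOL `align` or ChimeraX `matchmaker`) before comparing.\n",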
"to_cif_file(aa_generated, \"generated.cif\")\n",
"to_cif_file(aa_refolded, \"refolded.cif\")\n",
"\n",
"print(\"Exported structures:\")\n",
"print(\" - generated.cif: Original RFD3 backbone\")\n",
"print(\" - refolded.cif: RF3-predicted structure\")"
]
},
{
"cell_type": "markdown",
"id": "6fc67730",
"metadata": {},
"source": [
"### Superimposed Result\n",
"\n",
"The image below shows the generated backbone (RFD3) superimposed with the re-folded structure (RF3). Close alignment indicates that the designed sequence is predicted to recover the intended backbone.\n",
"\n",
"![Superimposed Protein](../docs/_static/superimposed_80_residue_protein.png)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}