feat: collab notebook

This commit is contained in:
ncorley
2025-12-03 09:31:41 -08:00
parent 395737750a
commit 04cb7f2ca6
3 changed files with 452 additions and 0 deletions

View File

@@ -23,6 +23,9 @@ foundry install rfd3 ligandmpnn rf3 --checkpoint_dir <path/to/ckpt/dir>
>*See `examples/all.ipynb` for how to run each model in a notebook.*
### Google Colab
For an interactive Google Colab notebook walking through a basic design pipeline with RFD3, MPNN, and RF3, please see the [IPD Design Pipeline Tutorial](https://colab.research.google.com/drive/1ZwIMV3n9h0ZOnIXX0GyKUuoiahgifBxh?usp=sharing).
### RFdiffusion3 (RFD3)
[RFdiffusion3](https://www.biorxiv.org/content/10.1101/2025.09.18.676967v2) is an all-atom generative model capable of designing protein structures under complex constraints.

Binary file not shown.

After

Width:  |  Height:  |  Size: 204 KiB

View File

@@ -0,0 +1,449 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "colab-setup",
"metadata": {},
"source": [
"## Google Colab Setup\n",
"\n",
"**GPU Required:** Before running, enable GPU runtime:\n",
"1. Go to **Runtime → Change runtime type**\n",
"2. Select **T4 GPU** (or better)\n",
"3. Click **Save**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "colab-install-package",
"metadata": {},
"outputs": [],
"source": [
"# Install dependencies (skip if already done)\n",
"import os\n",
"\n",
"# Set environment variables\n",
"os.environ['CCD_MIRROR_PATH'] = ''\n",
"os.environ['PDB_MIRROR_PATH'] = ''\n",
"\n",
"if not os.path.isfile(\"FOUNDRY_READY\"):\n",
" print(\"Installing rc-foundry...\")\n",
" \n",
" # Uninstall torchvision first to avoid operator conflicts\n",
" os.system(\"pip uninstall -y torchvision\")\n",
" \n",
" # Install rc-foundry\n",
" os.system(\"pip install -q 'rc-foundry[all]'\")\n",
" \n",
" # Mark as ready\n",
" os.system(\"touch FOUNDRY_READY\")\n",
" \n",
" print(\"Done!\")\n",
"else:\n",
" print(\"rc-foundry already installed.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "colab-download-weights",
"metadata": {},
"outputs": [],
"source": [
"# Download model weights (skips already-downloaded models automatically)\n",
"# In total, ~6GB (3GB for RFD3, 3GB for RF3, <100MB for MPNN); may take a few minutes depending on your connection speed\n",
"os.system(\"foundry install rfd3 ligandmpnn rf3\")"
]
},
{
"cell_type": "markdown",
"id": "f3df148e",
"metadata": {},
"source": [
"# Example: End-To-End *De Novo* Protein Design Pipeline\n",
"\n",
"## Overview\n",
"\n",
"This notebook demonstrates an end-to-end protein design workflow using three deep learning networks from the Institute for Protein Design:\n",
"\n",
"| Step | Model | Purpose |\n",
"|------|-------|---------|\n",
"| 1. **Generation** | RFD3 | Generate novel proteins via diffusion |\n",
"| 2. **Sequence Design** | MPNN | Design amino acid sequences for the generated backbone |\n",
"| 3. **Structure Validation via Refolding** | RF3 | Predict the structure from designed sequence to validate designability |\n",
"\n",
"All models are unified through [AtomWorks](https://github.com/RosettaCommons/atomworks) (for both inference and training), relying on Biotite `AtomArray` objects.\n",
"\n",
"### Pipeline Flow\n",
"```\n",
"RFD3 (backbone) → MPNN (sequence) → RF3 (validation) → RMSD comparison\n",
"```\n",
"\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "819e8193",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings('ignore', module='atomworks')\n",
"\n",
"# Shared utilities for visualization (from AtomWorks)\n",
"from atomworks.io.utils.visualize import view"
]
},
{
"cell_type": "markdown",
"id": "a7tw5gds8p",
"metadata": {},
"source": [
"## Section 1: All-Atom Generation with RFD3\n",
"\n",
"RFdiffusion3 (RFD3) generates *de novo* all-atom proteins that meet specific conditioning requirements.\n",
"\n",
"**Parameters Used** *(many more are available for more complex protein design tasks)*:\n",
"- `length`: Target protein length in residues\n",
"- `diffusion_batch_size`: Number of structures to generate per batch\n",
"- `n_batches`: Number of batches to run\n",
"\n",
"**Outputs:** Dictionary of `RFD3Output` objects."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d16cb95b-3f4c-4167-952b-278bdb561bf7",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from lightning.fabric import seed_everything\n",
"from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine\n",
"\n",
"# Set seed for reproducibility\n",
"seed_everything(0)\n",
"\n",
"# Configure RFD3 inference\n",
"config = RFD3InferenceConfig(\n",
" specification={\n",
" 'length': 80, # Generate 80-residue proteins\n",
" },\n",
" diffusion_batch_size=2, # Generate 2 structures per batch\n",
")\n",
"\n",
"# Initialize engine and run generation\n",
"model = RFD3InferenceEngine(**config)\n",
"outputs = model.run(\n",
" inputs=None, # None for unconditional generation\n",
" out_dir=None, # None to return in memory (no file output)\n",
" n_batches=1, # Generate 1 batch\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "86adad90",
"metadata": {},
"outputs": [],
"source": [
"# View generated example IDs (one key per generated structure)\n",
"outputs.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "35253ef3-2ce1-4dca-958a-63b359b70d73",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Inspect RFD3 outputs and extract the generated structures\n",
"for idx, data in outputs.items():\n",
" print(f\"Batch {idx}: {len(data)} structure(s)\")\n",
" print(f\" Output type: {type(data[0]).__name__}\")\n",
" print(f\" AtomArray: {data[0].atom_array}\")\n",
"\n",
"# Extract the first generated structure for downstream use\n",
"first_key = next(iter(outputs.keys()))\n",
"atom_array = outputs[first_key][0].atom_array\n",
"\n",
"# Visualize the generated structure\n",
"view(atom_array)"
]
},
{
"cell_type": "markdown",
"id": "1cziwe2nb26h",
"metadata": {},
"source": [
"---\n",
"\n",
"## Section 2: Sequence Design with MPNN\n",
"\n",
"Protein and Ligand MPNN (Message Passing Neural Network) designs amino acid sequences that will fold into a target backbone structure.\n",
"\n",
"**Model Options:**\n",
"- `protein_mpnn`: Original ProteinMPNN for protein-only design\n",
"- `ligand_mpnn`: Extended model supporting ligand-aware design\n",
"\n",
"**Key Parameters:**\n",
"- `batch_size`: Number of sequences to generate per structure\n",
"- `remove_waters`: Whether to exclude water molecules from context"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d07ae413",
"metadata": {},
"outputs": [],
"source": [
"from mpnn.inference_engines.mpnn import MPNNInferenceEngine\n",
"\n",
"# Configure MPNN inference engine\n",
"# See mpnn.utils.inference.MPNN_GLOBAL_INFERENCE_DEFAULTS for all options\n",
"engine_config = {\n",
" \"model_type\": \"ligand_mpnn\", # or \"protein_mpnn\" for vanilla ProteinMPNN\n",
" \"is_legacy_weights\": True, # Required for now for ligand_mpnn and protein_mpnn\n",
" \"out_directory\": None, # Return results in memory\n",
" \"write_structures\": False,\n",
" \"write_fasta\": False,\n",
"}\n",
"\n",
"# Configure per-input inference options\n",
"# See mpnn.utils.inference.MPNN_PER_INPUT_INFERENCE_DEFAULTS for all options\n",
"input_configs = [\n",
" {\n",
" \"batch_size\": 10, # Generate 10 sequences per structure\n",
" \"remove_waters\": True,\n",
" }\n",
"]\n",
"\n",
"# Run sequence design on the RFD3-generated backbone\n",
"model = MPNNInferenceEngine(**engine_config)\n",
"mpnn_outputs = model.run(input_dicts=input_configs, atom_arrays=[atom_array])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75f5e558-3323-4094-9cce-c5a19b54f4cf",
"metadata": {},
"outputs": [],
"source": [
"from biotite.structure import get_residue_starts\n",
"from biotite.sequence import ProteinSequence\n",
"\n",
"# Extract and display the designed sequences\n",
"print(f\"Generated {len(mpnn_outputs)} designed sequences:\\n\")\n",
"\n",
"for i, item in enumerate(mpnn_outputs):\n",
" res_starts = get_residue_starts(item.atom_array)\n",
" # Convert 3-letter codes to 1-letter using Biotite\n",
" seq_1letter = ''.join(\n",
" ProteinSequence.convert_letter_3to1(res_name)\n",
" for res_name in item.atom_array.res_name[res_starts]\n",
" )\n",
" print(f\"Sequence {i+1}: {seq_1letter}\")"
]
},
{
"cell_type": "markdown",
"id": "te84j9ce9rn",
"metadata": {},
"source": [
"---\n",
"\n",
"## Section 3: Structure Prediction with RF3\n",
"\n",
"RF3 (RoseTTAFold 3) predicts protein structures from sequences. By re-folding the MPNN-designed sequence, we can validate whether the design is likely to adopt the intended backbone structure.\n",
"\n",
"**Outputs:** `RF3Output` objects containing:\n",
"- `atom_array`: Predicted structure as Biotite AtomArray\n",
"- `summary_confidences`: Overall confidence metrics (pLDDT, PAE, pTM, etc.)\n",
"- `confidences`: Per-atom/residue confidence scores\n",
"\n",
"**Confidence Metrics:**\n",
"| Metric | Description |\n",
"|--------|-------------|\n",
"| pLDDT | Per-residue confidence (0-1, higher is better) |\n",
"| PAE | Predicted Aligned Error (lower is better) |\n",
"| pTM | Predicted TM-score |\n",
"| ranking_score | Overall model quality score |"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d85eda46-4dca-4e92-8946-73b69a5536e3",
"metadata": {},
"outputs": [],
"source": [
"from rf3.inference_engines.rf3 import RF3InferenceEngine\n",
"from rf3.utils.inference import InferenceInput\n",
"\n",
"\n",
"# Initialize RF3 inference engine\n",
"inference_engine = RF3InferenceEngine(ckpt_path='rf3', verbose=False)\n",
"\n",
"# Create input from the MPNN-designed structure (first design)\n",
"# This re-folds the sequence to validate it adopts the intended structure\n",
"input_structure = InferenceInput.from_atom_array(atom_array, example_id=\"example_protein\")\n",
"rf3_outputs = inference_engine.run(inputs=input_structure)\n",
"\n",
"# Outputs: dict mapping example_id -> list[RF3Output] (multiple models per input)\n",
"print(f\"Output keys: {rf3_outputs.keys()}\")\n",
"print(f\"Number of models for 'example_protein': {len(rf3_outputs['example_protein'])}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7113c03",
"metadata": {},
"outputs": [],
"source": [
"# Extract the top-ranked prediction\n",
"rf3_output = rf3_outputs[\"example_protein\"][0]\n",
"\n",
"# Inspect RF3Output structure\n",
"print(f\"RF3Output contains:\")\n",
"print(f\" - atom_array: {len(rf3_output.atom_array)} atoms\")\n",
"print(f\" - summary_confidences: {list(rf3_output.summary_confidences.keys())}\")\n",
"print(f\" - confidences: {list(rf3_output.confidences.keys()) if rf3_output.confidences else None}\")\n",
"\n",
"# Visualize the predicted structure\n",
"view(rf3_output.atom_array)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6mqnovlrygo",
"metadata": {},
"outputs": [],
"source": [
"# Summary confidences: overall model quality metrics\n",
"summary = rf3_output.summary_confidences\n",
"\n",
"print(\"=== Summary Confidences ===\")\n",
"print(f\" Overall pLDDT: {summary['overall_plddt']:.3f}\")\n",
"print(f\" Overall PAE: {summary['overall_pae']:.2f} A\")\n",
"print(f\" Overall PDE: {summary['overall_pde']:.3f}\")\n",
"print(f\" pTM: {summary['ptm']:.3f}\")\n",
"print(f\" ipTM: {summary.get('iptm', 'N/A (single chain)')}\")\n",
"print(f\" Ranking score: {summary['ranking_score']:.3f}\")\n",
"print(f\" Has clash: {summary['has_clash']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "g5gupirgy6",
"metadata": {},
"outputs": [],
"source": [
"# Detailed per-atom/residue confidences\n",
"conf = rf3_output.confidences\n",
"\n",
"print(\"=== Per-Atom/Residue Confidences ===\")\n",
"print(f\" atom_plddts: {len(conf['atom_plddts'])} values (one per atom)\")\n",
"print(f\" atom_chain_ids: {len(conf['atom_chain_ids'])} values\")\n",
"print(f\" token_chain_ids: {len(conf['token_chain_ids'])} values (one per residue)\")\n",
"print(f\" token_res_ids: {len(conf['token_res_ids'])} values\")\n",
"print(f\" PAE matrix: {len(conf['pae'])}x{len(conf['pae'][0])}\")\n",
"\n",
"# Preview first 10 atom pLDDT scores\n",
"import numpy as np\n",
"print(f\"\\nFirst 10 atom pLDDTs: {np.round(conf['atom_plddts'][:10], 2).tolist()}\")"
]
},
{
"cell_type": "markdown",
"id": "0epbqi91bv3n",
"metadata": {},
"source": [
"---\n",
"\n",
"## Section 4: Validation and Export\n",
"\n",
"The final step compares the RF3-predicted structure against the original RFD3-generated backbone. A low backbone RMSD indicates the designed sequence is likely to fold into the intended structure (high designability)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "238a0241",
"metadata": {},
"outputs": [],
"source": [
"from biotite.structure import rmsd, superimpose\n",
"from atomworks.constants import PROTEIN_BACKBONE_ATOM_NAMES\n",
"import numpy as np\n",
"\n",
"# Get structures for comparison\n",
"aa_generated = atom_array # Original RFD3 backbone (from Section 1)\n",
"aa_refolded = rf3_output.atom_array # RF3-predicted structure\n",
"\n",
"# Filter to backbone atoms (N, CA, C, O)\n",
"bb_generated = aa_generated[np.isin(aa_generated.atom_name, PROTEIN_BACKBONE_ATOM_NAMES)]\n",
"bb_refolded = aa_refolded[np.isin(aa_refolded.atom_name, PROTEIN_BACKBONE_ATOM_NAMES)]\n",
"\n",
"# Superimpose structures and calculate RMSD\n",
"bb_refolded_fitted, _ = superimpose(bb_generated, bb_refolded)\n",
"rmsd_value = rmsd(bb_generated, bb_refolded_fitted)\n",
"\n",
"print(f\"Backbone RMSD: {rmsd_value:.2f} A\")\n",
"print(f\"\\nInterpretation: {'Excellent' if rmsd_value < 1.0 else 'Good' if rmsd_value < 2.0 else 'Moderate'} designability\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9070ead8",
"metadata": {},
"outputs": [],
"source": [
"from atomworks.io.utils.io_utils import to_cif_file\n",
"\n",
"# Export structures to CIF format for visualization in PyMOL/ChimeraX\n",
"to_cif_file(aa_generated, \"generated.cif\")\n",
"to_cif_file(aa_refolded, \"refolded.cif\")\n",
"\n",
"print(\"Exported structures:\")\n",
"print(\" - generated.cif: Original RFD3 backbone\")\n",
"print(\" - refolded.cif: RF3-predicted structure\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "modelworks",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}