Parallelize rc-foundry install and checkpoint downloads for colab notebook (#157)

load ckpts and foundry in parallel to save time
This commit is contained in:
Magnus Bauer
2026-01-07 16:31:06 -08:00
committed by GitHub
parent 8bfde2381a
commit 94eb95a004

View File

@@ -20,40 +20,73 @@
"metadata": {},
"outputs": [],
"source": [
"# Install dependencies (skip if already done)\n",
"import os\n",
"import os, sys, subprocess\n",
"from pathlib import Path\n",
"\n",
"# Set environment variables\n",
"os.environ['CCD_MIRROR_PATH'] = ''\n",
"os.environ['PDB_MIRROR_PATH'] = ''\n",
"os.environ[\"CCD_MIRROR_PATH\"] = \"\"\n",
"os.environ[\"PDB_MIRROR_PATH\"] = \"\"\n",
"\n",
"if not os.path.isfile(\"FOUNDRY_READY\"):\n",
" print(\"Installing rc-foundry...\")\n",
" \n",
" # Uninstall torchvision first to avoid operator conflicts\n",
" os.system(\"pip uninstall -y torchvision\")\n",
" \n",
" # Install rc-foundry\n",
" os.system(\"pip install -q 'rc-foundry[all]'\")\n",
" \n",
" # Mark as ready\n",
" os.system(\"touch FOUNDRY_READY\")\n",
" \n",
" print(\"Done!\")\n",
"CKPT_DIR = Path(\"/root/.foundry/checkpoints/\")\n",
"CKPT_DIR.mkdir(parents=True, exist_ok=True)\n",
"READY_FLAG = Path(\"FOUNDRY_READY\")\n",
"LOG = Path(\"foundry_setup.log\")\n",
"\n",
"CHECKPOINTS = {\n",
" \"rfd3\": {\n",
" \"url\": \"https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt\",\n",
" \"filename\": \"rfd3_latest.ckpt\",\n",
" },\n",
" \"ligandmpnn\": {\n",
" \"url\": \"https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt\",\n",
" \"filename\": \"ligandmpnn_v_32_010_25.pt\",\n",
" },\n",
" \"rf3\": {\n",
" \"url\": \"https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt\",\n",
" \"filename\": \"rf3_foundry_01_24_latest_remapped.ckpt\",\n",
" },\n",
"}\n",
"\n",
"# Always remove torchvision first\n",
"subprocess.check_call([sys.executable, \"-m\", \"pip\", \"uninstall\", \"-y\", \"torchvision\"])\n",
"\n",
"# Start rc-foundry install in the background (if not done)\n",
"pip_proc = None\n",
"if not READY_FLAG.exists():\n",
" print(\"Installing rc-foundry (background)...\")\n",
" pip_proc = subprocess.Popen(\n",
" [sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"rc-foundry[all]\"],\n",
" stdout=LOG.open(\"ab\"),\n",
" stderr=subprocess.STDOUT,\n",
" )\n",
"else:\n",
" print(\"rc-foundry already installed.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "colab-download-weights",
"metadata": {},
"outputs": [],
"source": [
"# Download model weights (skips already-downloaded models automatically)\n",
"# In total, ~6GB (3GB for RFD3, 3GB for RF3, <100MB for MPNN); may take a few minutes depending on your connection speed\n",
"os.system(\"foundry install rfd3 ligandmpnn rf3\")"
" print(\"rc-foundry already installed.\")\n",
"\n",
"# Start checkpoint downloads in parallel with correct filenames\n",
"dl_procs = []\n",
"for name, info in CHECKPOINTS.items():\n",
" dest = CKPT_DIR / info[\"filename\"]\n",
" if dest.exists():\n",
" continue\n",
" print(f\"Downloading {name} -> {dest} (background)...\")\n",
" dl_procs.append(\n",
" subprocess.Popen(\n",
" [\"curl\", \"-L\", \"-o\", str(dest), info[\"url\"]],\n",
" stdout=LOG.open(\"ab\"),\n",
" stderr=subprocess.STDOUT,\n",
" )\n",
" )\n",
"\n",
"# Wait when you actually need everything ready\n",
"if pip_proc:\n",
" rc = pip_proc.wait()\n",
" if rc == 0:\n",
" READY_FLAG.touch()\n",
" else:\n",
" print(f\"pip install failed with code {rc}\")\n",
"for p in dl_procs:\n",
" p.wait()\n",
"\n",
"print(\"Setup steps finished (see foundry_setup.log).\")"
]
},
{
@@ -298,7 +331,7 @@
"# Create input from the MPNN-designed structure (first design)\n",
"# This re-folds the sequence to validate it adopts the intended structure\n",
"input_structure = InferenceInput.from_atom_array(atom_array, example_id=\"example_protein\")\n",
"rf3_outputs = inference_engine.run(inputs=input_structure)\n",
"rf3_outputs = inference_engine.run(inputs=input_structure, annotate_b_factor_with_plddt=True)\n",
"\n",
"# Outputs: dict mapping example_id -> list[RF3Output] (multiple models per input)\n",
"print(f\"Output keys: {rf3_outputs.keys()}\")\n",