mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
* documentation for release draft start * trajectory.png * update readme to rf3-lab paths, annotate TODOs * add input_pdbs, demo.json * Update README.md example pngs * tasks pngs * Update README.md - restructure pngs and application links * Update README.md mc * Update README.md add ipynb kernel export instruction * mpnn all.ipynb * open and edit tutorial.zip * Update run_inf_tutorial.sh * remove outputs * cleanup * rename * soft code hbplus executable * rename modelforge to foundry (rfd3) README * fix: enabled running rfd3, mpnn inline * cleanup * remove todos, one remaining * clear outputs --------- Co-authored-by: Raktim Mitra <raktim@digs> Co-authored-by: Raktim Mitra <raktim@localhost> Co-authored-by: Rohith Krishna <rohith@localhost> Co-authored-by: Raktim Mitra <raktim@digs.ipd.uw.edu>
151 lines
4.1 KiB
Plaintext
151 lines
4.1 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "819e8193",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# For now, we still need to manually add atomworks dependency\n",
|
|
"import sys;\n",
|
|
"modelhub_test = \"/home/rohith/modelhub/release\"\n",
|
|
"sys.path.append(f'{modelhub_test}/models/rfd3/src')\n",
|
|
"sys.path.append(f'{modelhub_test}/models/mpnn/src')\n",
|
|
"sys.path.append(f'{modelhub_test}/src')\n",
|
|
"\n",
|
|
"from atomworks.io.utils.visualize import view"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d16cb95b-3f4c-4167-952b-278bdb561bf7",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Running RFD3\n",
|
|
"from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine\n",
|
|
"conf = RFD3InferenceConfig(\n",
|
|
" ckpt_path='/projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt',\n",
|
|
" specification={\n",
|
|
" 'length': 10\n",
|
|
" },\n",
|
|
" diffusion_batch_size=2,\n",
|
|
")\n",
|
|
"model = RFD3InferenceEngine(**conf)\n",
|
|
"outputs = model.run(\n",
|
|
" inputs=None,\n",
|
|
" out_dir=None,\n",
|
|
" n_batches=1,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "35253ef3-2ce1-4dca-958a-63b359b70d73",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"for idx, data in outputs.items():\n",
|
|
" print(f\"Output type for batch {idx}: {type(data)}[0] = {type(data[0])}\")\n",
|
|
" print(f\"Output atom_array: {data[0].atom_array}\")\n",
|
|
" atom_array = data[0].atom_array\n",
|
|
"view(atom_array)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d07ae413",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Running MPNN\n",
|
|
"from mpnn.inference_engines.mpnn import MPNNInferenceEngine\n",
|
|
"\n",
|
|
"# Global defaults: see MPNN_GLOBAL_INFERENCE_DEFAULTS in mpnn.utils.inference\n",
|
|
"conf = {\n",
|
|
" ## protein_mpnn options\n",
|
|
" #\"model_type\":\"protein_mpnn\",\n",
|
|
" #\"checkpoint_path\":\"/databases/mpnn/vanilla_model_weights/v_48_020.pt\",\n",
|
|
" ## ligand_mpnn options\n",
|
|
" \"model_type\":\"ligand_mpnn\",\n",
|
|
" \"checkpoint_path\":\"/databases/mpnn/ligand_mpnn_model_weights/s25_r010_t300_p.pt\",\n",
|
|
" \"is_legacy_weights\": True,\n",
|
|
" \"out_directory\": None,\n",
|
|
" \"write_structures\": False,\n",
|
|
" \"write_fasta\": False\n",
|
|
"}\n",
|
|
"\n",
|
|
"# Per-input defaults: see MPNN_PER_INPUT_INFERENCE_DEFAULTS in mpnn.utils.inference\n",
|
|
"inputs = [\n",
|
|
" {\n",
|
|
" \"batch_size\": 10,\n",
|
|
" \"remove_waters\":True\n",
|
|
" }\n",
|
|
" ]\n",
|
|
"\n",
|
|
"model = MPNNInferenceEngine(**conf)\n",
|
|
"\n",
|
|
"## Run MPNN on the atom_array taken from the RFD3 outputs in the previous cell\n",
|
|
"outputs = model.run(input_dicts=inputs, atom_arrays=[atom_array])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "75f5e558-3323-4094-9cce-c5a19b54f4cf",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from biotite.structure import get_residue_starts\n",
|
|
"from atomworks.constants import STANDARD_AA\n",
|
|
"\n",
|
|
"for item in outputs:\n",
|
|
" atom_array = item.atom_array\n",
|
|
" res_starts = get_residue_starts(atom_array)\n",
|
|
" protein_seq = []\n",
|
|
" for res_name in atom_array.res_name[res_starts]:\n",
|
|
" if res_name in STANDARD_AA:\n",
|
|
" protein_seq.append(str(res_name))\n",
|
|
" print(protein_seq)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d85eda46-4dca-4e92-8946-73b69a5536e3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "release",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|