chore: add tests, add sampler inference arguments

This commit is contained in:
ncorley
2025-02-03 15:57:31 -08:00
parent 2628c0efdc
commit bd5ac3ef22
11 changed files with 138 additions and 5765 deletions

View File

@@ -1,7 +1,7 @@
# Inference with AF3 Repro
# Inference with `modelhub-AF3` repository
We have reproduced AF3 and are sharing the weights with the lab to use for various tasks.
This guide provides instructions on preparing inputs and running inference for our AF3 reproduction (name TBD).
This guide provides instructions on preparing inputs and running inference for our AF3 reproduction.
Additional variations (e.g., with chirality inputs, ligand geometry conditioning, protein backbone coordinate conditioning) are in the works; however, the core inference API will not change.
@@ -13,7 +13,7 @@ We enumerate two options for preparing inputs: one with a JSON API, one by creat
### Option 1: Prepare inputs using a combination of one-letter polymer sequences, SMILES strings, CCD codes, and SDF files
Create a JSON file with each component similar to `inputs/example)from_json.json`; e.g.,
Create a JSON file with each component similar to `rf2aa/tests/data/example_from_json.json`; e.g.,
```json
[
@@ -82,8 +82,8 @@ ligand_from_ccd = {
atom_array_from_ccd = components_to_atom_array([monomer, ligand_from_ccd])
atom_array_from_smiles = components_to_atom_array([monomer, ligand_from_smiles])
to_cif_file(atom_array_from_ccd, "./inputs/example_from_ccd.cif")
to_cif_file(atom_array_from_smiles, "./inputs/example_from_smiles.cif")
to_cif_file(atom_array_from_ccd, "example_from_ccd.cif")
to_cif_file(atom_array_from_smiles, "example_from_smiles.cif")
```
## Step 2: Run `inference.py`
@@ -108,19 +108,19 @@ Example commands (to be run from the `inference` working directory):
### Using a CIF
```bash
apptainer -s run --nv /net/software/containers/users/rohith/modelhub_lab_20250124.sif /net/software/lab/modelhub/rf2aa/inference/inference.py /net/software/lab/modelhub/rf2aa/inference/examples/inputs/example_from_ccd.cif --checkpoint_path /projects/ml/RF2_allatom/weights/af3_repro_with_confidence_20250124.pt --cif_out_dir ./examples/predictions
apptainer -s run --nv /net/software/containers/users/rohith/modelhub_lab_20250124.sif /net/software/lab/modelhub/rf2aa/inference/inference.py /net/software/lab/modelhub/rf2aa/tests/data/example_from_ccd.cif --checkpoint_path /projects/ml/RF2_allatom/weights/af3_repro_with_confidence_20250124.pt --cif_out_dir ./predictions
```
### Using a PDB from MPNN, renaming clashing ligands (example from Indrek)
Note that in this PDB file, the ligand "HGS" is a custom ligand whose three-letter code overlaps with a real CCD ligand. Thus, we must rename it (via `--rename_residues`).
```bash
apptainer -s run --nv /net/software/containers/users/rohith/modelhub_lab_20250124.sif /net/software/lab/modelhub/rf2aa/inference/inference.py /net/software/lab/modelhub/rf2aa/inference/examples/inputs/example_pdb_from_indrek.cif --checkpoint_path /projects/ml/RF2_allatom/weights/af3_repro_with_confidence_20250124.pt --cif_out_dir ./examples/predictions --rename_residues '{"HGS": "L:1"}'
apptainer -s run --nv /net/software/containers/users/rohith/modelhub_lab_20250124.sif /net/software/lab/modelhub/rf2aa/inference/inference.py /net/software/lab/modelhub/rf2aa/tests/data/example_pdb_from_indrek.cif --checkpoint_path /projects/ml/RF2_allatom/weights/af3_repro_with_confidence_20250124.pt --cif_out_dir ./predictions --rename_residues '{"HGS": "L:1"}'
```
### Using the JSON
```bash
apptainer -s run --nv /net/software/containers/users/rohith/modelhub_lab_20250124.sif /net/software/lab/modelhub/rf2aa/inference/inference.py /net/software/lab/modelhub/rf2aa/inference/examples/inputs/example_from_json.json --checkpoint_path /projects/ml/RF2_allatom/weights/af3_repro_with_confidence_20250124.pt --cif_out_dir ./examples/predictions
apptainer -s run --nv /net/software/containers/users/rohith/modelhub_lab_20250124.sif /net/software/lab/modelhub/rf2aa/inference/inference.py /net/software/lab/modelhub/rf2aa/tests/data/example_from_json.json --checkpoint_path /projects/ml/RF2_allatom/weights/af3_repro_with_confidence_20250124.pt --cif_out_dir ./predictions
```
## Step 3: View the Predicted Structure(s)

View File

@@ -1,12 +0,0 @@
{
"name": "example_from_smiles(json)",
"components": [
{
"seq": "MNAKEIVVHALRLLENGDARGWCDLFHPEGVLEYPYPPPGYKTRFEGRETIWAHMRLFPEYMTIRFTDVQFYETADPDLAIGEFHGDGVHTVSGGKLAADYISVLRTRDGQILLYRLFFNPLRVLEPLGLEHHHHHH",
"chain_id": "A"
},
{
"smiles": "O=C1OCC(=C1)C5C4(C(O)CC3C(CCC2CC(O)CCC23C)C4(O)CC5)C"
}
]
}

View File

@@ -1,10 +0,0 @@
[
{
"seq": "GSGVSLGQALLILSVAALLGTTVEEAVKRALWLKTKLGVSLEQAARTLSVAAYLGTTVEEAVKRALKLKTKLGVSLEQALLILFAAAALGTTVEEAVKRALKLKTKLGVSLEQALLILWTAVELGTTVEEAVKRALKLKTKLGVSLGQAQAILVVAAELGTTVEEAVYRALKLKTKLGVSLGQALLILEVAAKLGTTVEEAVKRALKLTTKLG",
"chain_id": "A"
},
{
"path": "./examples/inputs/test_sdf.sdf"
}
]

File diff suppressed because it is too large Load Diff

View File

@@ -1,26 +0,0 @@
[
{
"name": "multiple_examples_from_json(1)",
"components": [
{
"seq": "MNAKEIVVHALRLLENGDARGWCDLFHPEGVLEYPYPPPGYKTRFEGRETIWAHMRLFPEYMTIRFTDVQFYETADPDLAIGEFHGDGVHTVSGGKLAADYISVLRTRDGQILLYRLFFNPLRVLEPLGLEHHHHHH",
"chain_id": "A"
},
{
"smiles": "O=C1OCC(=C1)C5C4(C(O)CC3C(CCC2CC(O)CCC23C)C4(O)CC5)C"
}
]
},
{
"name": "multiple_examples_from_json(2)",
"components": [
{
"seq": "GSGVSLGQALLILSVAALLGTTVEEAVKRALWLKTKLGVSLEQAARTLSVAAYLGTTVEEAVKRALKLKTKLGVSLEQALLILFAAAALGTTVEEAVKRALKLKTKLGVSLEQALLILWTAVELGTTVEEAVKRALKLKTKLGVSLGQAQAILVVAAELGTTVEEAVYRALKLKTKLGVSLGQALLILEVAAKLGTTVEEAVKRALKLTTKLG",
"chain_id": "A"
},
{
"path": "./examples/inputs/test_sdf.sdf"
}
]
}
]

View File

@@ -1,5 +1,4 @@
import hydra
import os
from os import PathLike
from pathlib import Path
@@ -28,7 +27,7 @@ logger = logging.getLogger(__name__)
encoding = AF3SequenceEncoding()
def build_stack_from_atom_array_and_batched_coords(
coords: np.ndarray, atom_array: AtomArray, annotations_to_keep: list[str] = ["chain_id", "transformation_id", "res_id", "res_name", "element"]
coords: np.ndarray, atom_array: AtomArray, annotations_to_keep: list[str] = ["chain_id", "transformation_id", "res_id", "res_name", "element", "atom_name"]
) -> AtomArrayStack:
"""Builds an AtomArrayStack from an AtomArray and a set of coordinates with a batch dimension.
@@ -41,11 +40,17 @@ def build_stack_from_atom_array_and_batched_coords(
# (Diffusion batch size will become the number of models)
n_batch = coords.shape[0]
# Remove unwanted annotations
for annotation in atom_array.get_annotation_categories():
if annotation not in annotations_to_keep:
atom_array.del_annotation(annotation)
# Build the stack and assign the coordinates
atom_array_stack = stack([atom_array for _ in range(n_batch)])
atom_array_stack.coord = coords
# Adjust chain_id if there are multiple transformations
# (Otherwise, we will have ambiguous bond annotations)
# (Otherwise, we will have ambiguous bond annotations, since only `chain_id` is used for the bond annotations)
if "transformation_id" in atom_array.get_annotation_categories() and len(np.unique(atom_array_stack.transformation_id)) > 1:
atom_array_stack.chain_id = (
atom_array_stack.chain_id + atom_array_stack.transformation_id
@@ -53,7 +58,7 @@ def build_stack_from_atom_array_and_batched_coords(
return atom_array_stack
def _spoof_cif_from_dictionary(item: dict, temp_dir: os.PathLike) -> Path:
def _spoof_cif_from_dictionary(item: dict, temp_dir: PathLike) -> Path:
"""Unpacks a dictionary to create a CIF file from its components.
Args:
@@ -85,6 +90,81 @@ def _spoof_cif_from_dictionary(item: dict, temp_dir: os.PathLike) -> Path:
return Path(save_path)
def _build_file_paths_for_prediction(inputs: list, temp_dir: PathLike) -> list[Path]:
    """Resolve user-supplied inputs into a flat list of CIF-like files for prediction.

    Each input may be a dictionary-like file (JSON or Pickle; YAML is recognized but not
    yet supported), a CIF-like structure file, or a directory containing such files.
    Directories are expanded with `_find_files`; every entry of a dictionary-like file is
    converted to a "spoofed" CIF file with `_spoof_cif_from_dictionary`.

    Args:
        inputs (list): Input paths (str or PathLike) to files or directories.
        temp_dir (PathLike): Directory handed to `_spoof_cif_from_dictionary` for the spoofed CIF files.

    Returns:
        list[Path]: Paths to the CIF-like files to predict, following the order of `inputs`.

    Raises:
        FileNotFoundError: If an input path does not exist.
        NotImplementedError: If a YAML file is provided.
        ValueError: If a file has an unsupported extension.
    """
    DICTIONARY_LIKE_EXTENSIONS = {".json", ".yaml", ".yml", ".pkl"}
    CIF_LIKE_EXTENSIONS = {".cif", ".pdb", ".bcif", ".cif.gz", ".pdb.gz", ".bcif.gz"}

    # Collect all files from inputs, handling directories and individual files
    paths_to_raw_input_files = []
    for input_path in inputs:
        input_path = Path(input_path)
        # Fail early with a clear message instead of a late error deep inside parsing
        if not input_path.exists():
            raise FileNotFoundError(f"Path {input_path} does not exist.")
        if input_path.is_dir():
            paths_to_raw_input_files.extend(
                _find_files(input_path, DICTIONARY_LIKE_EXTENSIONS | CIF_LIKE_EXTENSIONS)
            )
        else:
            paths_to_raw_input_files.append(input_path)

    # `str.endswith` needs a tuple; matching on the end of the file name (rather than on the
    # concatenation of all suffixes) keeps names such as "model.v2.cif.gz" working.
    dictionary_like_suffixes = tuple(DICTIONARY_LIKE_EXTENSIONS)
    cif_like_suffixes = tuple(CIF_LIKE_EXTENSIONS)

    paths_to_cif_like_files = []
    for path in paths_to_raw_input_files:
        file_name = path.name
        if file_name.endswith(dictionary_like_suffixes):
            # Spoof CIF files from dictionary-like formats
            if file_name.endswith((".yaml", ".yml")):
                raise NotImplementedError("YAML files are not yet supported.")
            if file_name.endswith(".pkl"):
                # SECURITY NOTE: unpickling executes arbitrary code; only pass trusted pickle files.
                with open(path, "rb") as file:
                    data = pickle.load(file)
            else:
                with open(path, "r", encoding="utf-8") as file:
                    data = json.load(file)

            if isinstance(data, dict):
                data = [data]  # Convert single dictionary to list for uniform processing

            paths_to_cif_like_files.extend(
                _spoof_cif_from_dictionary(item, temp_dir) for item in data
            )
        elif file_name.endswith(cif_like_suffixes):
            # Directly use CIF-like files
            paths_to_cif_like_files.append(path)
        else:
            raise ValueError(f"Unsupported file extension: {path.suffix} (path: {path}; paths: {paths_to_raw_input_files}).")

    return paths_to_cif_like_files
def _find_files(path: PathLike, supported_file_types: list) -> list[Path]:
"""Recursively find all files with the given extensions in the specified path.
Args:
path (PathLike): Path to the directory containing the files.
supported_file_types (list): List of supported file extensions.
Returns:
list[Path]: List of files with the given extensions.
"""
files_with_supported_types = []
path = Path(path)
# Check if the path is a directory
if path.is_dir():
# Search for files with each supported extension
for file_type in supported_file_types:
files_with_supported_types.extend(path.glob(f'*{file_type}'))
elif path.is_file() and path.suffix in supported_file_types:
# If it's a file and has a supported extension, add to the list
files_with_supported_types.append(path)
return files_with_supported_types
class EvaluateAF3:
"""Class for inference with AF3. Evaluates a trained AF3 model on a set of spoofed CIFs."""
@@ -94,7 +174,9 @@ class EvaluateAF3:
n_recycles: int,
diffusion_batch_size: int,
residue_renaming_dict: dict | None = None,
temp_dir: PathLike | None = None
temp_dir: PathLike | None = None,
num_steps: int = 200,
solver: str = "af3",
):
"""Initialize the evaluator.
@@ -106,6 +188,8 @@ class EvaluateAF3:
diffusion_batch_size (int): Diffusion batch size for AF3. Each predicted structure will be saved as a separate model within the same CIF file.
residue_renaming_dict (dict): Dictionary of residue names to rename to avoid CCD clashes, e.g., {'ALA': 'L:1'}.
temp_dir (PathLike): Temporary directory to store intermediate files. The default is None.
num_steps (int): Number of steps for sampling of the diffusion model. The default is 200; we see reasonable results with 50 steps.
solver (str): Solver to use for inference. Options are 'af3', 'simple', 'euler', and 'heun'. The default is 'af3'.
"""
# Load the checkpoint
@@ -117,6 +201,8 @@ class EvaluateAF3:
# Sampler sets diffusion batch size based on the following, not strictly on batch size in validation transform
self.config.dataset_params["diffusion_batch_size_valid"] = diffusion_batch_size
self.config.af3_inference["num_steps"] = num_steps
self.config.af3_inference["solver"] = solver
# Load the AF-3 trainer
self.trainer = trainer_factory[self.config.experiment.trainer](config=self.config)
@@ -148,35 +234,12 @@ class EvaluateAF3:
pipeline = hydra.utils.instantiate(self.config.dataset_params.val.interface.transform)
return pipeline
def find_files(self, spoofed_cif_path):
"""Find all files with the given extensions in the spoofed CIF directory.
Args:
spoofed_cif_path (Path): Path to the directory containing spoofed CIF files.
Returns:
List[Path]: List of files with the given extensions.
"""
matched_files = []
valid_extensions = [".cif", ".pdb", ".bcif", ".cif.gz", ".pdb.gz", ".bcif.gz"]
if spoofed_cif_path.is_file():
# Check if the file has one of the expected extensions
if any(spoofed_cif_path.name.endswith(ext) for ext in valid_extensions):
matched_files.append(spoofed_cif_path)
else:
# If it's a directory, search for files with the given extensions
logger.info(f"Searching for files with extensions {valid_extensions} in {spoofed_cif_path}...")
for ext in valid_extensions:
matched_files.extend(spoofed_cif_path.glob(f"*{ext}"))
return matched_files
def eval(self, files: list[os.PathLike]):
def eval(self, files: list[PathLike]):
"""Evaluate the model on a set of spoofed CIF files.
Args:
files (list[os.PathLike]): List of paths to spoofed CIF files or directories containing spoofed CIF files.
files (list[PathLike]): List of paths to spoofed CIF-like files (.cif, .pdb, .bcif, .cif.gz, .pdb.gz, .bcif.gz).
Coordinates must be present but may contain NaN values. Directories are not expanded here;
resolve them to individual files first (e.g., with `_build_file_paths_for_prediction`).
"""
@@ -193,15 +256,9 @@ class EvaluateAF3:
# Construct the AF3 inference pipeline
pipeline = self.construct_pipeline()
# Accumulate all structures to predict
structures_to_predict = []
for file in files:
assert Path(file).exists(), f"Path {file} does not exist."
structures_to_predict.extend(self.find_files(file))
logger.info(f"Found {len(files)} structures to predict: {files}.")
logger.info(f"Found {len(structures_to_predict)} structures to predict: {structures_to_predict}.")
for structure in structures_to_predict:
for structure in files:
# ... parse into an AtomArray (`parse` handles all valid formats)
logger.info(f"Parsing from path: {structure}")
example_id = structure.name.split('.')[0]
@@ -240,7 +297,6 @@ class EvaluateAF3:
with torch.no_grad():
outputs = self.trainer.sampler.sample([pipeline_output], n_cycle=self.n_recycles, use_amp=self.config.training_params.use_amp)
# Override the AtomArray with the predicted coordinates
atom_array_stack = build_stack_from_atom_array_and_batched_coords(
outputs["X_L"].cpu().numpy(), pipeline_output["atom_array"]
@@ -263,12 +319,14 @@ class EvaluateAF3:
def main():
parser = argparse.ArgumentParser(description="Evaluate AF3 using specified paths.")
parser.add_argument("inputs", nargs="+", help="List of paths to files (JSON or CIF/PDB) or directories of CIF/PDB files.")
parser.add_argument("inputs", nargs="+", help="List of paths to supported file types or directories of of supported files.")
parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the checkpoint file")
parser.add_argument("--cif_out_dir", type=str, required=True, help="Directory for output CIF files")
parser.add_argument("--n_recycles", type=int, default=10, help="Number of recycles for AF3")
parser.add_argument("--diffusion_batch_size", type=int, default=5, help="Diffusion batch size for AF3")
parser.add_argument("--rename_residues", type=str, default="", help="Dictionary of residue names to rename to avoid CCD clashes, e.g., {'ALA': 'L:1'}")
parser.add_argument("--num_steps", type=int, default=200, help="Number of steps for sampling of the diffusion model")
parser.add_argument("--solver", type=str, default="af3", help="Solver to use for inference. Options are 'af3', 'simple', 'euler', and 'heun'.")
args = parser.parse_args()
with tempfile.TemporaryDirectory() as temp_dir:
@@ -276,29 +334,9 @@ def main():
temp_dir.mkdir(parents=True, exist_ok=True)
# Prepare inputs based on the file types
file_paths_for_prediction = []
for path in args.inputs:
path = Path(path)
if path.suffix in {".json", ".yaml", ".yml", ".pkl"}:
# (Dictionary-like inputs, which will be converted to "spoofed" CIF files)
with open(path, 'r') as file:
# Load data based on file extension
if path.suffix == ".json":
data = json.load(file)
elif path.suffix in {".yaml", ".yml"}:
raise NotImplementedError("YAML files are not yet supported.")
elif path.suffix == ".pkl":
data = pickle.load(file)
if isinstance(data, dict):
data = [data] # Convert single dictionary to list for uniform processing
for item in data:
file_paths_for_prediction.append(_spoof_cif_from_dictionary(item, temp_dir))
else:
# (CIF/PDB files or directories of CIF/PDB files)
file_paths_for_prediction.append(path)
file_paths_for_prediction = _build_file_paths_for_prediction(args.inputs, temp_dir)
# Rename residues if necessary (e.g., for MPNN outputs that have ligand names that clash with the CCD)
residue_renaming_dict = json.loads(args.rename_residues) if args.rename_residues else {}
# Construct the evaluator
@@ -308,7 +346,9 @@ def main():
n_recycles=args.n_recycles,
diffusion_batch_size=args.diffusion_batch_size,
residue_renaming_dict=residue_renaming_dict,
temp_dir=temp_dir
temp_dir=temp_dir,
num_steps=args.num_steps,
solver=args.solver,
)
# Launch the evaluation

View File

@@ -0,0 +1,27 @@
from rf2aa.inference.inference import _build_file_paths_for_prediction
import pytest
from os import PathLike
from cifutils import parse
from pathlib import Path
current_file_directory = Path(__file__).parent
@pytest.mark.parametrize("file_path", [
    "data/example_from_ccd.cif",
    "data/nested_examples",
    "data/example_from_sdf.json",
    "data/example_from_smiles.cif",
    "data/multiple_examples_from_json.json"
])
def test_build_file_paths_for_prediction(file_path: str, tmp_path: Path):
    """Build prediction inputs from each supported input kind and check that they parse.

    Args:
        file_path (str): Input path relative to this test file's directory (file or directory).
        tmp_path (Path): Pytest-provided temporary directory, passed as the `temp_dir` argument.
    """
    input_paths = [current_file_directory / file_path]

    # Call the function with the file path and temporary directory
    paths = _build_file_paths_for_prediction(input_paths, tmp_path)

    # Guard against a vacuous pass: the loop below asserts nothing if no paths are returned
    assert paths, f"No files were found for input: {file_path}"

    # Iterate over the returned paths and parse them
    for path in paths:
        output = parse(path)
        assert output is not None
        assert len(output["assemblies"]["1"][0]) > 0