# Mirror of https://github.com/RosettaCommons/foundry.git
# (synced 2026-06-04 13:24:22 +08:00; 535 lines, 21 KiB, Python)
import os
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import pickle
|
|
import tempfile
|
|
from collections.abc import Mapping
|
|
from os import PathLike
|
|
from pathlib import Path
|
|
|
|
import hydra
|
|
import numpy as np
|
|
import torch
|
|
import yaml
|
|
from biotite.structure import AtomArray, AtomArrayStack, stack
|
|
from cifutils import parse
|
|
from cifutils.tools.inference import (
|
|
build_msa_paths_by_chain_id_from_component_list,
|
|
components_to_atom_array,
|
|
)
|
|
from cifutils.utils.io_utils import to_cif_file
|
|
from datahub.encoding_definitions import AF3SequenceEncoding
|
|
import omegaconf
|
|
from omegaconf import OmegaConf
|
|
|
|
from rf2aa.metrics.predicted_error import WriteAF3Confidence
|
|
from rf2aa.trainer_base import trainer_factory
|
|
|
|
# Configure root logging at INFO level and create the logger used throughout this module
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define the sequence encoding; needed to decode the restypes when saving to CIF
encoding = AF3SequenceEncoding()
|
|
|
|
|
|
def build_stack_from_atom_array_and_batched_coords(
    coords: np.ndarray,
    atom_array: AtomArray,
    annotations_to_keep: list[str] | tuple[str, ...] = (
        "chain_id",
        "transformation_id",
        "res_id",
        "res_name",
        "element",
        "atom_name",
    ),
) -> AtomArrayStack:
    """Builds an AtomArrayStack from an AtomArray and a set of coordinates with a batch dimension.

    Additionally, handles the case where the AtomArray contains multiple transformations and we must adjust the chain_id.

    The input `atom_array` is not modified: annotations are pruned on a copy.

    Args:
        coords (np.ndarray): The coordinates to be assigned to the AtomArrayStack. Must have shape (nbatch, n_atoms, 3).
        atom_array (AtomArray): The AtomArray to be stacked. Must have shape (n_atoms,)
        annotations_to_keep (list[str] | tuple[str, ...]): Annotation categories retained in the stack;
            every other category is dropped.

    Returns:
        AtomArrayStack: One model per batch entry of `coords`, all sharing the retained annotations.
    """
    # (Diffusion batch size will become the number of models)
    n_batch = coords.shape[0]

    # Prune on a copy so the caller's AtomArray keeps all of its annotations
    atom_array = atom_array.copy()
    keep = set(annotations_to_keep)
    for annotation in list(atom_array.get_annotation_categories()):
        if annotation not in keep:
            atom_array.del_annotation(annotation)

    # Build the stack and assign the coordinates
    atom_array_stack = stack([atom_array] * n_batch)
    atom_array_stack.coord = coords

    # Adjust chain_id if there are multiple transformations
    # (Otherwise, we will have ambiguous bond annotations, since only `chain_id` is used for the bond annotations)
    has_multiple_transformations = (
        "transformation_id" in atom_array.get_annotation_categories()
        and len(np.unique(atom_array_stack.transformation_id)) > 1
    )
    if has_multiple_transformations:
        # NOTE(review): relies on `+` concatenating the two annotation arrays element-wise
        # (as in the original code) -- confirm the dtypes produced upstream support this.
        atom_array_stack.chain_id = (
            atom_array_stack.chain_id + atom_array_stack.transformation_id
        )

    return atom_array_stack
|
|
|
|
|
|
def _spoof_cif_from_dictionary(item: dict, temp_dir: PathLike) -> Path:
    """Unpacks a dictionary to create a CIF file from its components.

    Args:
        item (dict): A dictionary containing 'name' and 'components', optionally 'bonds'
            (forwarded to `components_to_atom_array`).
        temp_dir (PathLike): Path to the temporary directory for storing CIF files.

    Returns:
        Path: The path to the created CIF file (as reported by `to_cif_file`), saved in the temporary directory.

    Raises:
        ValueError: If `item` is not a mapping, or if 'name' or 'components' are missing from it.
    """
    # Validate the dictionary structure ("name" and "components" are required, "bonds" is optional).
    # Explicit raises (rather than `assert`) so the check survives `python -O`.
    if not isinstance(item, Mapping):
        raise ValueError(
            f"The input must be a dictionary with 'name' and 'components' keys, got {type(item).__name__}."
        )
    missing_keys = [key for key in ("name", "components") if key not in item]
    if missing_keys:
        raise ValueError(
            f"The input dictionary must contain 'name' and 'components' keys (missing: {missing_keys})."
        )

    # Build components
    atom_array, component_list = components_to_atom_array(
        item["components"], return_components=True, bonds=item.get("bonds", None)
    )
    msa_paths_by_chain_id = build_msa_paths_by_chain_id_from_component_list(
        component_list
    )

    # Only attach the MSA category when there is something to record
    extra_categories = (
        {"msa_paths_by_chain_id": msa_paths_by_chain_id}
        if msa_paths_by_chain_id
        else None
    )

    # Create a temporary CIF file from the JSON data
    cif_path = Path(temp_dir) / f"{item['name']}.cif"
    save_path = to_cif_file(
        atom_array,
        cif_path,
        extra_categories=extra_categories,
        file_type="cif",  # Not zipped for efficiency (as it's a temporary directory anyways)
    )

    return Path(save_path)
|
|
|
|
|
|
def _build_file_paths_for_prediction(inputs: list, temp_dir: PathLike) -> list[Path]:
|
|
"""Prepare files for prediction based on the input paths.
|
|
|
|
Input paths may be dictionary-like format (e.g., JSON, YAML, Pickle), CIF/PDB files, or directories containing these files.
|
|
Processes directories to find supported file types and converts dictionary-like formats to CIF files.
|
|
|
|
Args:
|
|
inputs (list): List of input paths (JSON, YAML, Pickle, or CIF/PDB).
|
|
temp_dir (Path): Path to the temporary directory for storing CIF files.
|
|
|
|
Returns:
|
|
list[Path]: List of file paths for prediction.
|
|
"""
|
|
DICTIONARY_LIKE_EXTENSIONS = {".json", ".yaml", ".yml", ".pkl"}
|
|
CIF_LIKE_EXTENSIONS = {".cif", ".pdb", ".bcif", ".cif.gz", ".pdb.gz", ".bcif.gz"}
|
|
|
|
# Collect all files from inputs, handling directories and individual files
|
|
paths_to_raw_input_files = []
|
|
for input_path in inputs:
|
|
if Path(input_path).is_dir():
|
|
paths_to_raw_input_files.extend(
|
|
_find_files(
|
|
input_path, DICTIONARY_LIKE_EXTENSIONS | CIF_LIKE_EXTENSIONS
|
|
)
|
|
)
|
|
else:
|
|
paths_to_raw_input_files.append(Path(input_path))
|
|
|
|
paths_to_cif_like_files = []
|
|
for path in paths_to_raw_input_files:
|
|
#concatenated_suffix = "".join(path.suffixes)
|
|
concatenated_suffix = path.suffixes[-1]
|
|
if concatenated_suffix in DICTIONARY_LIKE_EXTENSIONS:
|
|
# Spoof CIF files from dictionary-like formats
|
|
with open(path, "rb" if path.suffix == ".pkl" else "r") as file:
|
|
# Load data based on file extension
|
|
if path.suffix == ".json":
|
|
data = json.load(file)
|
|
elif path.suffix in {".yaml", ".yml"}:
|
|
raise NotImplementedError("YAML files are not yet supported.")
|
|
elif path.suffix == ".pkl":
|
|
data = pickle.load(file)
|
|
|
|
if isinstance(data, dict):
|
|
data = [
|
|
data
|
|
] # Convert single dictionary to list for uniform processing
|
|
|
|
for item in data:
|
|
paths_to_cif_like_files.append(
|
|
_spoof_cif_from_dictionary(item, temp_dir)
|
|
)
|
|
elif concatenated_suffix in CIF_LIKE_EXTENSIONS:
|
|
# Directly use CIF-like files
|
|
paths_to_cif_like_files.append(path)
|
|
else:
|
|
raise ValueError(
|
|
f"Unsupported file extension: {path.suffix} (path: {path}; paths: {paths_to_raw_input_files})."
|
|
)
|
|
|
|
return paths_to_cif_like_files
|
|
|
|
|
|
def _find_files(path: PathLike, supported_file_types: list) -> list[Path]:
    """Find all files with the given extensions directly inside the specified path (non-recursive).

    Args:
        path (PathLike): Path to the directory containing the files, or to a single file.
        supported_file_types (list): Supported file extensions (any iterable of strings such as
            ".cif" or ".cif.gz"); matched against the end of the file name.

    Returns:
        list[Path]: Sorted list of files with the given extensions. Empty if `path` does not exist
            or nothing matches.
    """
    path = Path(path)
    # `str.endswith` accepts a tuple, so one call covers every extension (including multi-part ones)
    extensions = tuple(supported_file_types)

    if path.is_dir():
        # Sorted so the processing order does not depend on filesystem or set iteration order
        return sorted(
            child
            for child in path.iterdir()
            if child.is_file() and child.name.endswith(extensions)
        )

    # If it's a file and has a supported extension, return just that file
    if path.is_file() and path.name.endswith(extensions):
        return [path]

    return []
|
|
|
|
|
|
def _update_nested_dictconfig(d: Mapping, u: Mapping, depth: int = 0) -> Mapping:
    """Recursive function to overwrite contents of one nested omegaconf dictconfig with another.

    Neither input is modified; a new container is returned.

    Args:
        d: dictionary of dictconfigs whose contents will be overwritten
        u: dictionary of items which will overwrite or add to values in d
        depth: depth of recursion: a non-negative integer:
            -used to keep the outermost layer of the config as a dict instead of DictConfig.
            -set to 1 or higher to return only DictConfig.

    Returns:
        d updated to contain values in u (a plain dict at depth 0, a DictConfig otherwise)
    """
    merged = dict(d)
    for key, value in dict(u).items():
        if isinstance(value, Mapping):
            # If the existing entry is absent or not a mapping (e.g. a scalar or None),
            # the override replaces it outright instead of failing on `dict(<non-mapping>)`.
            existing = merged.get(key)
            if not isinstance(existing, Mapping):
                existing = {}
            merged[key] = _update_nested_dictconfig(existing, value, depth=depth + 1)
        else:
            merged[key] = value

    if depth == 0:
        return merged
    return omegaconf.dictconfig.DictConfig(merged)
|
|
|
|
|
|
class EvaluateAF3:
|
|
"""Class for inference with AF3. Evaluates a trained AF3 model on a set of spoofed CIFs."""
|
|
|
|
def __init__(
|
|
self,
|
|
checkpoint_path: PathLike,
|
|
cif_out_dir: PathLike,
|
|
n_recycles: int,
|
|
diffusion_batch_size: int,
|
|
config_override_path: PathLike | None = None,
|
|
residue_renaming_dict: dict | None = None,
|
|
temp_dir: PathLike | None = None,
|
|
num_steps: int = 200,
|
|
solver: str = "af3",
|
|
overwrite: bool = False
|
|
):
|
|
"""Initialize the evaluator.
|
|
|
|
Args:
|
|
checkpoint_path (PathLike): Path to the checkpoint file, e.g., /path/to/checkpoint.pt.
|
|
cif_out_dir (PathLike): Directory to save the output (predicted) CIF files.
|
|
config_override_path (PathLike): Path to a yaml file with config options to override those in the checkpoint file.
|
|
world_size (int): Number of GPUs to use for evaluation.
|
|
n_recycles (int): Number of recycles for AF3. The default is 10.
|
|
diffusion_batch_size (int): Diffusion batch size for AF3. Each predicted structure will be saved as a separate model within the same CIF file.
|
|
residue_renaming_dict (dict): Dictionary of residue names to rename to avoid CCD clashes, e.g., {'ALA': 'L:1'}.
|
|
temp_dir (PathLike): Temporary directory to store intermediate files. The default is None.
|
|
num_steps (int): Number of steps for sampling of the diffusion model. The default is 200; we see reasonable results with 50 steps.
|
|
solver (str): Solver to use for inference. Options are 'af3', 'simple', 'euler', and 'heun'. The default is 'af3'.
|
|
"""
|
|
|
|
# Load the checkpoint
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
|
|
|
|
# Load the config
|
|
self.config = OmegaConf.create(checkpoint["training_config"])
|
|
|
|
if config_override_path is not None:
|
|
with open(config_override_path, 'r') as fs:
|
|
config_override_dict = yaml.load(fs, yaml.FullLoader)
|
|
self.config = _update_nested_dictconfig(self.config, config_override_dict)
|
|
self.config = OmegaConf.create(self.config)
|
|
|
|
# Make sure we aren't using the version with a bug in plddt
|
|
if (
|
|
self.config.experiment.name
|
|
== "rf2aa-af3-repro-rollout_nmw_from_scratch_af3_style_no_cb_normal_crop_cont_3"
|
|
):
|
|
raise ValueError(
|
|
"These weights are outdated and the plddt metric may be inaccurate. Please update to the latest available weights."
|
|
)
|
|
|
|
# Sampler sets diffusion batch size based on the following, not strictly on batch size in vaildation transform
|
|
self.config.dataset_params["diffusion_batch_size_valid"] = diffusion_batch_size
|
|
self.config.af3_inference["num_steps"] = num_steps
|
|
self.config.af3_inference["solver"] = solver
|
|
|
|
# Load the AF-3 trainer
|
|
self.trainer = trainer_factory[self.config.experiment.trainer](
|
|
config=self.config
|
|
)
|
|
self.trainer.checkpoint = checkpoint
|
|
|
|
# Set the output directory for the CIF files (e.g., predicted structures)
|
|
self.cif_out_dir = Path(cif_out_dir) if cif_out_dir else Path("./")
|
|
|
|
# Model parameters
|
|
self.n_recycles = n_recycles
|
|
self.diffusion_batch_size = diffusion_batch_size
|
|
if "confidence_loss" in self.config.loss:
|
|
self.confidence_writer = WriteAF3Confidence(
|
|
**self.config.loss.confidence_loss
|
|
)
|
|
else:
|
|
self.confidence_writer = None
|
|
|
|
# Rename residues
|
|
self.residue_renaming_dict = residue_renaming_dict
|
|
self.temp_dir = Path(temp_dir)
|
|
|
|
self.overwrite = overwrite
|
|
|
|
def construct_pipeline(self):
|
|
"""Construct the AF3 inference pipeline."""
|
|
self.config.dataset_params.val.interface.transform.n_recycles = self.n_recycles
|
|
self.config.dataset_params.val.interface.transform.diffusion_batch_size = (
|
|
self.diffusion_batch_size
|
|
)
|
|
self.config.dataset_params.val.interface.transform.return_atom_array = (
|
|
True # Required for `to_cif`
|
|
)
|
|
|
|
assert (
|
|
self.config.dataset_params.val.interface.transform.n_recycles
|
|
== self.n_recycles
|
|
), "Number of recycles not set correctly."
|
|
assert (
|
|
self.config.dataset_params.val.interface.transform.diffusion_batch_size
|
|
== self.diffusion_batch_size
|
|
), "Diffusion batch size not set correctly."
|
|
pipeline = hydra.utils.instantiate(
|
|
self.config.dataset_params.val.interface.transform
|
|
)
|
|
return pipeline
|
|
|
|
def eval(self, files: list[PathLike]):
|
|
"""Evaluate the model on a set of spoofed CIF files.
|
|
|
|
Args:
|
|
files (list[PathLike]): List of paths to spoofed CIF files or directories containing spoofed CIF files.
|
|
Coordinates must be present but may contain NaN values. If a directory is provided,
|
|
all files with the extensions .cif, .pdb, .bcif, .cif.gz, .pdb.gz, .bcif.gz will be processed.
|
|
"""
|
|
# Construct the model and load the checkpoint
|
|
gpu = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
self.trainer.construct_model(device=gpu, inference=True)
|
|
self.trainer.load_model()
|
|
|
|
# Set the model to evaluation mode
|
|
self.trainer.model.eval()
|
|
|
|
logger.info("Building Transform pipeline...")
|
|
|
|
# Construct the AF3 inference pipeline
|
|
pipeline = self.construct_pipeline()
|
|
|
|
logger.info(f"Found {len(files)} structures to predict: {files}.")
|
|
|
|
for structure in files:
|
|
# ... parse into an AtomArray (`parse` handles all valid formats)
|
|
logger.info(f"Parsing from path: {structure}")
|
|
#example_id = structure.name.split(".")[0]
|
|
example_id = ".".join(structure.name.split(".")[:-1])
|
|
|
|
# optionally, skip if output already exists
|
|
cif_output_path = example_id + '.cif'
|
|
cif_output_path = self.cif_out_dir / cif_output_path
|
|
if os.path.exists(cif_output_path) and not self.overwrite:
|
|
logger.info(f"Existing output for {example_id} found at {cif_output_path}. Skipping this example. Set --overwrite to not skip examples with existing output")
|
|
continue
|
|
|
|
# If we're renaming residues, we do a brute-force replacement in the CIF file
|
|
if self.residue_renaming_dict:
|
|
logger.info(
|
|
f"Renaming residues in {structure} with brute-force find and replace: {self.residue_renaming_dict}"
|
|
)
|
|
with open(structure, "r") as f:
|
|
content = f.read()
|
|
for old_res, new_res in self.residue_renaming_dict.items():
|
|
content = content.replace(old_res, new_res)
|
|
structure = Path(self.temp_dir / structure.name)
|
|
with open(structure, "w") as f:
|
|
f.write(content)
|
|
|
|
out = parse(structure, remove_hydrogens=True)
|
|
|
|
# ... get the atom array and set NaN coordinates to random
|
|
atom_array = (
|
|
out["assemblies"]["1"][0]
|
|
if "assemblies" in out
|
|
else out["asym_unit"][0]
|
|
)
|
|
|
|
# HACK: Set NaN coordinates to random values to avoid unexpected behavior in the pipeline
|
|
atom_array.coord[np.isnan(atom_array.coord)] = np.random.rand(
|
|
*atom_array.coord[np.isnan(atom_array.coord)].shape
|
|
)
|
|
|
|
# ... assemble the pipeline input in a format compatible with the DataHub pipeline
|
|
pipeline_input = {
|
|
"example_id": example_id,
|
|
"atom_array": atom_array,
|
|
"chain_info": out["chain_info"],
|
|
}
|
|
|
|
# ... run dataloading and featurization
|
|
pipeline_output = pipeline(pipeline_input)
|
|
|
|
# Model inference
|
|
with torch.no_grad():
|
|
outputs = self.trainer.sampler.sample(
|
|
[pipeline_output],
|
|
n_cycle=self.n_recycles,
|
|
use_amp=self.config.training_params.use_amp,
|
|
)
|
|
|
|
# Override the AtomArray with the predited coordinates
|
|
atom_array_stack = build_stack_from_atom_array_and_batched_coords(
|
|
outputs["X_L"].cpu().numpy(), pipeline_output["atom_array"]
|
|
)
|
|
|
|
# Write the atom array to a CIF file
|
|
# NOTE: To make the secondary structure appear, run `dss` in PyMol (see: https://biology.stackexchange.com/questions/70143/can-pymol-show-cartoon-secondary-structure-for-a-pdb-of-multiple-frames)
|
|
out_path = to_cif_file(
|
|
atom_array_stack, self.cif_out_dir / example_id, file_type="cif"
|
|
)
|
|
logger.info(f"Prediction for {example_id} written to {out_path}.")
|
|
|
|
if "confidence" in outputs:
|
|
loss_input = {
|
|
"example_id": example_id,
|
|
"is_real_atom": pipeline_output["confidence_feats"]["is_real_atom"],
|
|
}
|
|
logger.info(f"Writing {example_id}.score to {self.cif_out_dir}")
|
|
df = self.confidence_writer(None, outputs, loss_input)
|
|
df.to_csv(self.cif_out_dir / f"{example_id}.score", index=False)
|
|
logger.info(
|
|
f"Confidence metrics for {example_id}.cif written to {self.cif_out_dir / example_id}.score."
|
|
)
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="Evaluate AF3 using specified paths.")
|
|
parser.add_argument(
|
|
"inputs",
|
|
nargs="+",
|
|
help="List of paths to supported file types or directories of of supported files.",
|
|
)
|
|
parser.add_argument(
|
|
"--checkpoint_path", type=str, required=True, help="Path to the checkpoint file"
|
|
)
|
|
parser.add_argument(
|
|
"--cif_out_dir", type=str, required=True, help="Directory for output CIF files"
|
|
)
|
|
parser.add_argument(
|
|
"--config_override_path",
|
|
type=str,
|
|
required=False,
|
|
help="Path to a yaml file with configs to override those in the checkpoint file",
|
|
)
|
|
parser.add_argument(
|
|
"--n_recycles", type=int, default=10, help="Number of recycles for AF3"
|
|
)
|
|
parser.add_argument(
|
|
"--diffusion_batch_size",
|
|
type=int,
|
|
default=5,
|
|
help="Diffusion batch size for AF3",
|
|
)
|
|
parser.add_argument(
|
|
"--rename_residues",
|
|
type=str,
|
|
default="",
|
|
help="Dictionary of residue names to rename to avoid CCD clashes, e.g., {'ALA': 'L:1'}",
|
|
)
|
|
parser.add_argument(
|
|
"--num_steps",
|
|
type=int,
|
|
default=200,
|
|
help="Number of steps for sampling of the diffusion model",
|
|
)
|
|
parser.add_argument(
|
|
"--solver",
|
|
type=str,
|
|
default="af3",
|
|
help="Solver to use for inference. Options are 'af3', 'simple', 'euler', and 'heun'.",
|
|
)
|
|
parser.add_argument(
|
|
"--overwrite",
|
|
default=False,
|
|
action="store_true",
|
|
help="Overwrite existing .cif outputs with new runs.",
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
temp_dir = Path(temp_dir)
|
|
temp_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Prepare inputs based on the file types
|
|
file_paths_for_prediction = _build_file_paths_for_prediction(
|
|
args.inputs, temp_dir
|
|
)
|
|
|
|
# Rename residues if necessary (e.g., for MPNN outputs that have ligand names that clash with the CCD)
|
|
residue_renaming_dict = (
|
|
json.loads(args.rename_residues) if args.rename_residues else {}
|
|
)
|
|
|
|
# Construct the evaluator
|
|
evaluator = EvaluateAF3(
|
|
checkpoint_path=args.checkpoint_path,
|
|
cif_out_dir=args.cif_out_dir,
|
|
config_override_path=args.config_override_path,
|
|
n_recycles=args.n_recycles,
|
|
diffusion_batch_size=args.diffusion_batch_size,
|
|
residue_renaming_dict=residue_renaming_dict,
|
|
temp_dir=temp_dir,
|
|
num_steps=args.num_steps,
|
|
solver=args.solver,
|
|
overwrite=args.overwrite
|
|
)
|
|
|
|
# Launch the evaluation
|
|
evaluator.eval(files=file_paths_for_prediction)
|
|
|
|
|
|
# Run the command-line interface only when executed as a script (not on import)
if __name__ == "__main__":
    main()
|