diff --git a/src/modelhub/callbacks/dump_validation_structures.py b/src/modelhub/callbacks/dump_validation_structures.py index 2989a0e..b00299e 100644 --- a/src/modelhub/callbacks/dump_validation_structures.py +++ b/src/modelhub/callbacks/dump_validation_structures.py @@ -1,9 +1,9 @@ from os import PathLike from pathlib import Path +from atomworks.common import parse_example_id from beartype.typing import Any -from atomworks.common import parse_example_id from modelhub.callbacks.base import BaseCallback from modelhub.utils.io import ( build_stack_from_atom_array_and_batched_coords, diff --git a/src/modelhub/callbacks/metrics_logging.py b/src/modelhub/callbacks/metrics_logging.py index 7e608bc..ac9a380 100755 --- a/src/modelhub/callbacks/metrics_logging.py +++ b/src/modelhub/callbacks/metrics_logging.py @@ -3,10 +3,10 @@ from copy import deepcopy from pathlib import Path import pandas as pd +from atomworks.ml.utils import nested_dict from beartype.typing import Any, Literal from omegaconf import ListConfig -from atomworks.ml.utils import nested_dict from modelhub.callbacks.base import BaseCallback from modelhub.utils.ddp import RankedLogger from modelhub.utils.logging import ( diff --git a/src/modelhub/callbacks/train_logging.py b/src/modelhub/callbacks/train_logging.py index 6abbaf8..b480a0a 100755 --- a/src/modelhub/callbacks/train_logging.py +++ b/src/modelhub/callbacks/train_logging.py @@ -2,6 +2,7 @@ import time from collections import defaultdict import pandas as pd +from atomworks.common import parse_example_id from beartype.typing import Any from lightning.fabric.wrappers import ( _FabricOptimizer, @@ -12,7 +13,6 @@ from rich.table import Table from torch import nn from torchmetrics.aggregation import MeanMetric -from atomworks.common import parse_example_id from modelhub.callbacks.base import BaseCallback from modelhub.utils.ddp import RankedLogger from modelhub.utils.logging import ( diff --git a/src/modelhub/data/extra_xforms.py b/src/modelhub/data/extra_xforms.py index ee056a3..7261364 100644 --- a/src/modelhub/data/extra_xforms.py +++ b/src/modelhub/data/extra_xforms.py @@ -1,5 +1,4 @@ import torch - from atomworks.ml.transforms._checks import ( check_contains_keys, ) diff --git a/src/modelhub/data/ground_truth_conformer.py b/src/modelhub/data/ground_truth_conformer.py index 8f028bd..82691dc 100644 --- a/src/modelhub/data/ground_truth_conformer.py +++ b/src/modelhub/data/ground_truth_conformer.py @@ -1,14 +1,13 @@ import numpy as np -from beartype.typing import Any -from biotite.structure import AtomArray -from jaxtyping import Bool, Float - from atomworks.enums import GroundTruthConformerPolicy from atomworks.ml.transforms._checks import ( check_atom_array_annotation, check_contains_keys, ) from atomworks.ml.transforms.base import Transform +from beartype.typing import Any +from biotite.structure import AtomArray +from jaxtyping import Bool, Float def add_ground_truth_reference_conformer( diff --git a/src/modelhub/data/ground_truth_template.py b/src/modelhub/data/ground_truth_template.py index d20ec2e..6f775d1 100644 --- a/src/modelhub/data/ground_truth_template.py +++ b/src/modelhub/data/ground_truth_template.py @@ -3,11 +3,6 @@ from dataclasses import dataclass import numpy as np import torch -from beartype.typing import Any, Callable, Final, Sequence -from biotite.structure import AtomArray -from jaxtyping import Bool, Float, Shaped -from torch import Tensor - from atomworks.enums import ChainType from atomworks.ml.transforms._checks import ( check_atom_array_annotation, @@ -20,6 +15,11 @@ from atomworks.ml.utils.token import ( get_af3_token_center_masks, get_token_starts, ) +from beartype.typing import Any, Callable, Final, Sequence +from biotite.structure import AtomArray +from jaxtyping import Bool, Float, Shaped +from torch import Tensor + from modelhub.utils.torch_utils import assert_no_nans logger = logging.getLogger(__name__) diff --git a/src/modelhub/data/pipeline_utils.py b/src/modelhub/data/pipeline_utils.py index 6100a15..06268cd 100644 --- a/src/modelhub/data/pipeline_utils.py +++ b/src/modelhub/data/pipeline_utils.py @@ -1,11 +1,11 @@ from functools import partial import torch -from omegaconf import DictConfig - from atomworks.enums import ChainType from atomworks.ml.transforms._checks import check_atom_array_annotation from atomworks.ml.transforms.crop import compute_local_hash +from omegaconf import DictConfig + from modelhub.data.ground_truth_template import ( FeaturizeNoisedGroundTruthAsTemplateDistogram, TokenGroupNoiseScaleSampler, diff --git a/src/modelhub/metrics/base.py b/src/modelhub/metrics/base.py index 42ce5b0..2fb8444 100644 --- a/src/modelhub/metrics/base.py +++ b/src/modelhub/metrics/base.py @@ -3,11 +3,11 @@ from abc import ABC, abstractmethod from functools import cached_property import hydra +from atomworks.ml.utils import error, nested_dict from beartype.typing import Any from omegaconf import DictConfig from toolz import keymap -from atomworks.ml.utils import error, nested_dict from modelhub.utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/src/modelhub/metrics/chiral.py b/src/modelhub/metrics/chiral.py index 868b886..7d500dc 100644 --- a/src/modelhub/metrics/chiral.py +++ b/src/modelhub/metrics/chiral.py @@ -1,14 +1,14 @@ import torch -from beartype.typing import Any -from biotite.structure import AtomArray, AtomArrayStack -from jaxtyping import Bool, Float - from atomworks.ml.transforms.af3_reference_molecule import ( get_af3_reference_molecule_features, ) from atomworks.ml.transforms.atom_array import ensure_atom_array_stack from atomworks.ml.transforms.chirals import add_af3_chiral_features from atomworks.ml.transforms.rdkit_utils import get_rdkit_chiral_centers +from beartype.typing import Any +from biotite.structure import AtomArray, AtomArrayStack +from jaxtyping import Bool, Float + from modelhub.kinematics import get_dih from modelhub.metrics.base import Metric diff --git a/src/modelhub/metrics/distogram.py b/src/modelhub/metrics/distogram.py index 51e9e1c..515609d 100644 --- a/src/modelhub/metrics/distogram.py +++ b/src/modelhub/metrics/distogram.py @@ -4,12 +4,12 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from atomworks.ml.utils.token import get_af3_token_representative_idxs from beartype.typing import Any, Literal from biotite.structure import AtomArrayStack from einops import rearrange, repeat from jaxtyping import Bool, Float -from atomworks.ml.utils.token import get_af3_token_representative_idxs from modelhub.loss.af3_losses import distogram_loss from modelhub.metrics.base import Metric from modelhub.utils.torch_utils import assert_no_nans diff --git a/src/modelhub/metrics/lddt.py b/src/modelhub/metrics/lddt.py index 9193c1b..85cb1ef 100644 --- a/src/modelhub/metrics/lddt.py +++ b/src/modelhub/metrics/lddt.py @@ -1,9 +1,5 @@ import numpy as np import torch -from beartype.typing import Any -from biotite.structure import AtomArray, AtomArrayStack, stack -from jaxtyping import Bool, Float, Int - from atomworks.ml.transforms.atom_array import ( AddGlobalTokenIdAnnotation, ensure_atom_array_stack, @@ -11,6 +7,10 @@ from atomworks.ml.transforms.atom_array import ( from atomworks.ml.transforms.atomize import AtomizeByCCDName from atomworks.ml.transforms.base import Compose from atomworks.ml.utils.token import get_token_starts +from beartype.typing import Any +from biotite.structure import AtomArray, AtomArrayStack, stack +from jaxtyping import Bool, Float, Int + from modelhub.metrics.base import Metric from modelhub.utils.ddp import RankedLogger diff --git a/src/modelhub/metrics/metric_utils.py b/src/modelhub/metrics/metric_utils.py index a8b3bf9..80f628f 100644 --- a/src/modelhub/metrics/metric_utils.py +++ b/src/modelhub/metrics/metric_utils.py @@ -1,16 +1,22 @@ from itertools import combinations +from typing import Union import numpy as np import torch +from jaxtyping import Bool, Float +from numpy.typing import NDArray -def find_bin_midpoints(max_distance, num_bins, device="cpu"): +def find_bin_midpoints( + max_distance: float, num_bins: int, device: Union[str, torch.device] = "cpu" +) -> Float[torch.Tensor, "num_bins"]: """ Find the bin midpoints for a given binning scheme. Used to find expectation of values when converting binned predictions to unbinned predictions. Assumes the minimum of the schema is 0. Args: max_distance: float, maximum distance num_bins: int, number of bins + device: device to run on Returns: pae_midpoints: [num_bins], bin midpoints """ @@ -26,7 +32,9 @@ def find_bin_midpoints(max_distance, num_bins, device="cpu"): return midpoints -def unbin_logits(logits, max_distance, num_bins): +def unbin_logits( + logits: Float[torch.Tensor, "B num_bins L X"], max_distance: float, num_bins: int +) -> Float[torch.Tensor, "B L L"]: """ Unbin the logits to get the matrix Args: @@ -42,7 +50,9 @@ def unbin_logits(logits, max_distance, num_bins): return unbinned -def create_chainwise_masks_1d(ch_label, device="cpu"): +def create_chainwise_masks_1d( + ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu" +) -> dict[str, Bool[torch.Tensor, "L"]]: """ Create 1D chainwise masks for a set of chain labels Args: @@ -61,7 +71,9 @@ def create_chainwise_masks_1d(ch_label, device="cpu"): return ch_masks -def create_chainwise_masks_2d(ch_label, device="cpu"): +def create_chainwise_masks_2d( + ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu" +) -> dict[str, Bool[torch.Tensor, "L L"]]: """ Create 2D chainwise masks for a set of chain labels Args: @@ -79,9 +91,16 @@ def create_chainwise_masks_2d(ch_label, device="cpu"): return ch_masks -def create_interface_masks_2d(ch_label, device="cpu"): +def create_interface_masks_2d( + ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu" +) -> dict[tuple[str, str], Bool[torch.Tensor, "L L"]]: """ Create interface masks for a set of chain labels + Args: + ch_label: np.ndarray [L], chain labels + device: torch.device, device to run on + Returns: + pairs_to_score: dict mapping chain pairs to boolean masks """ unique_chains = np.unique(ch_label) pairs_to_score = {} @@ -97,12 +116,17 @@ def create_interface_masks_2d(ch_label, device="cpu"): return pairs_to_score -def compute_mean_over_subsampled_pairs(matrix_to_mean, pairs_to_score, eps=1e-6): +def compute_mean_over_subsampled_pairs( + matrix_to_mean: Float[torch.Tensor, "B L M"], + pairs_to_score: Bool[torch.Tensor, "L M"], + eps: float = 1e-6, +) -> Float[torch.Tensor, "B"]: """ Compute the mean over a subsample of pairs in a 2d matrix. Returns a tensor with an element for each batch Args: matrix_to_mean: tensor of shape (batch, L, L) pairs_to_score: 2d tensor of shape (L, L) with 1s where pairs should be scored and 0s elsewhere + eps: small epsilon value to avoid division by zero Returns: 1d tensor of shape (batch,) with the mean over the subsampled pairs for each batch """ @@ -120,9 +144,49 @@ def compute_mean_over_subsampled_pairs(matrix_to_mean, pairs_to_score, eps=1e-6) return batch -def spread_batch_into_dictionary(batch): +def compute_min_over_subsampled_pairs( + matrix_to_min: Float[torch.Tensor, "B L M"], + pairs_to_score: Bool[torch.Tensor, "L M"], +) -> Float[torch.Tensor, "B"]: + """ + Compute the min over a subsample of pairs in a 2d matrix. Returns a tensor with an element for each batch + Args: + matrix_to_min: tensor of shape (batch, L, L) + pairs_to_score: 2d tensor of shape (L, L) with 1s where pairs should be scored and 0s elsewhere + Returns: + 1d tensor of shape (batch,) with the min over the subsampled pairs for each batch + """ + B, L, M = matrix_to_min.shape + assert matrix_to_min.shape == ( + B, + L, + M, + ), "Matrix to min should be of shape (batch, L, M)" + assert pairs_to_score.shape == (L, M), "Pairs to score should be of shape (L, M)" + # Use torch.where to efficiently mask without cloning the entire matrix + # This broadcasts pairs_to_score across the batch dimension + masked_matrix = torch.where( + pairs_to_score.bool(), # condition (L, M) -> broadcasts to (B, L, M) + matrix_to_min, # if True: use original values (B, L, M) + torch.tensor( + float("inf"), device=matrix_to_min.device, dtype=matrix_to_min.dtype + ), # if False: use inf + ) + + # Flatten the last two dimensions and compute min across them + batch = masked_matrix.view(B, -1).min(dim=-1)[0] + + assert batch.shape == (B,), "Batch should be of shape (batch,)" + return batch + + +def spread_batch_into_dictionary(batch: Float[torch.Tensor, "B"]) -> dict[int, float]: """ Given a batch of data, create a dictionary with keys as the batch index and value as the corresponding data + Args: + batch: 1D tensor of shape (B,) + Returns: + Dictionary mapping batch indices to float values """ assert len(batch.shape) == 1, f"Batch should be a 1d tensor, {batch}" return {i: data.item() for i, data in enumerate(batch)} diff --git a/src/modelhub/metrics/predicted_error.py b/src/modelhub/metrics/predicted_error.py index 41fff95..840b431 100644 --- a/src/modelhub/metrics/predicted_error.py +++ b/src/modelhub/metrics/predicted_error.py @@ -107,8 +107,9 @@ class ComputeIPTM(Metric): protein_mask[None, :] & protein_mask[:, None] * to_calculate ) protein_ligand_mask = ( - protein_mask[None, :] & ligand_mask[:, None] * to_calculate - ) + (protein_mask[None, :] & ligand_mask[:, None]) + | (ligand_mask[None, :] & protein_mask[:, None]) + ) * to_calculate ligand_ligand_mask = ligand_mask[None, :] & ligand_mask[:, None] * to_calculate # calculate iptm for each interface type iptm_protein_protein = compute_ptm(pae, protein_protein_mask) diff --git a/src/modelhub/metrics/rasa.py b/src/modelhub/metrics/rasa.py index 5b1d79c..deb11f7 100644 --- a/src/modelhub/metrics/rasa.py +++ b/src/modelhub/metrics/rasa.py @@ -1,8 +1,8 @@ import numpy as np +from atomworks.ml.transforms.sasa import calculate_atomwise_rasa from beartype.typing import Any from biotite.structure import AtomArrayStack -from atomworks.ml.transforms.sasa import calculate_atomwise_rasa from modelhub.metrics.base import Metric diff --git a/src/modelhub/metrics/selected_distances.py b/src/modelhub/metrics/selected_distances.py index 1f98dc0..6e54906 100644 --- a/src/modelhub/metrics/selected_distances.py +++ b/src/modelhub/metrics/selected_distances.py @@ -1,12 +1,12 @@ import numpy as np -from beartype.typing import Any -from biotite.structure import AtomArrayStack - from atomworks.ml.utils import nested_dict from atomworks.ml.utils.selection import ( get_mask_from_atom_selection, parse_selection_string, ) +from beartype.typing import Any +from biotite.structure import AtomArrayStack + from modelhub.metrics.base import Metric diff --git a/src/modelhub/resolvers.py b/src/modelhub/resolvers.py index 394dba7..0b65ae2 100644 --- a/src/modelhub/resolvers.py +++ b/src/modelhub/resolvers.py @@ -6,11 +6,10 @@ Documentation on custom resolvers: import importlib +from atomworks.enums import ChainType, ChainTypeInfo from beartype.typing import Any from omegaconf import OmegaConf -from atomworks.enums import ChainType, ChainTypeInfo - from .common import run_once diff --git a/src/modelhub/symmetry/resolve.py b/src/modelhub/symmetry/resolve.py index 2c34adb..07089d0 100644 --- a/src/modelhub/symmetry/resolve.py +++ b/src/modelhub/symmetry/resolve.py @@ -5,9 +5,6 @@ from typing import Any, Dict import numpy as np import torch -from biotite.structure import AtomArray, AtomArrayStack -from jaxtyping import Bool, Float, Int - from atomworks.ml.transforms.atom_array import ( AddGlobalTokenIdAnnotation, ensure_atom_array_stack, @@ -15,6 +12,9 @@ from atomworks.ml.transforms.atom_array import ( from atomworks.ml.transforms.atomize import AtomizeByCCDName from atomworks.ml.transforms.base import Compose, convert_to_torch from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX +from biotite.structure import AtomArray, AtomArrayStack +from jaxtyping import Bool, Float, Int + from modelhub.loss.af3_losses import ( ResidueSymmetryResolution, SubunitSymmetryResolution, diff --git a/src/modelhub/utils/datasets.py b/src/modelhub/utils/datasets.py index 5f3c916..3cf5e46 100755 --- a/src/modelhub/utils/datasets.py +++ b/src/modelhub/utils/datasets.py @@ -1,5 +1,12 @@ import hydra import torch +from atomworks.ml.samplers import ( + DistributedMixedSampler, + FallbackSamplerWrapper, + LazyWeightedRandomSampler, + LoadBalancedDistributedSampler, + MixedSampler, +) from beartype.typing import Any from omegaconf import DictConfig, ListConfig from torch.utils.data import ( @@ -13,13 +20,6 @@ from torch.utils.data import ( ) from torch.utils.data.distributed import DistributedSampler -from atomworks.ml.samplers import ( - DistributedMixedSampler, - FallbackSamplerWrapper, - LazyWeightedRandomSampler, - LoadBalancedDistributedSampler, - MixedSampler, -) from modelhub.resolvers import register_resolvers from modelhub.utils.ddp import RankedLogger diff --git a/src/modelhub/utils/io.py b/src/modelhub/utils/io.py index 434eb5c..d0e219b 100644 --- a/src/modelhub/utils/io.py +++ b/src/modelhub/utils/io.py @@ -4,10 +4,10 @@ from pathlib import Path import numpy as np import torch +from atomworks.io.utils.io_utils import to_cif_file from beartype.typing import Literal from biotite.structure import AtomArray, AtomArrayStack, stack -from atomworks.io.utils.io_utils import to_cif_file from modelhub.alignment import weighted_rigid_align from modelhub.utils.ddp import RankedLogger diff --git a/src/modelhub/utils/predicted_error.py b/src/modelhub/utils/predicted_error.py index c4c5f45..f244d21 100644 --- a/src/modelhub/utils/predicted_error.py +++ b/src/modelhub/utils/predicted_error.py @@ -13,6 +13,7 @@ from omegaconf import DictConfig from modelhub.chemical import NHEAVY from modelhub.metrics.metric_utils import ( compute_mean_over_subsampled_pairs, + compute_min_over_subsampled_pairs, create_chainwise_masks_1d, create_chainwise_masks_2d, create_interface_masks_2d, @@ -131,6 +132,15 @@ def compile_af3_confidence_outputs( for k, v in interface_masks.items() } + pae_interface_min = { + k: spread_batch_into_dictionary(compute_min_over_subsampled_pairs(pae, v)) + for k, v in interface_masks.items() + } + + pde_interface_min = { + k: spread_batch_into_dictionary(compute_min_over_subsampled_pairs(pde, v)) + for k, v in interface_masks.items() + } # Calculate chainwise metrics chain_masks_2d = create_chainwise_masks_2d(chain_iid_token_lvl, device=pae.device) pae_chainwise = { @@ -167,6 +177,8 @@ def compile_af3_confidence_outputs( "chain_wise_mean_pde": pde_chainwise, "interface_wise_mean_pae": pae_interface, "interface_wise_mean_pde": pde_interface, + "interface_wise_min_pae": pae_interface_min, + "interface_wise_min_pde": pde_interface_min, } # Generate DataFrame rows @@ -204,6 +216,12 @@ def compile_af3_confidence_outputs( "pde_interface": confidence_data["interface_wise_mean_pde"][ (chain_i, chain_j) ][batch_idx], + "min_pae_interface": confidence_data["interface_wise_min_pae"][ + (chain_i, chain_j) + ][batch_idx], + "min_pde_interface": confidence_data["interface_wise_min_pde"][ + (chain_i, chain_j) + ][batch_idx], "overall_plddt": confidence_data["mean_plddt"][batch_idx], "overall_pde": confidence_data["mean_pde"][batch_idx], "overall_pae": confidence_data["mean_pae"][batch_idx], diff --git a/src/modelhub/utils/recycling.py b/src/modelhub/utils/recycling.py index dbce108..74af4dc 100755 --- a/src/modelhub/utils/recycling.py +++ b/src/modelhub/utils/recycling.py @@ -1,7 +1,6 @@ import math import torch - from atomworks.ml.utils.rng import create_rng_state_from_seeds, rng_state diff --git a/tests/test_inference_pipelines.py b/tests/test_inference_pipelines.py index 6a69e5f..e79130b 100644 --- a/tests/test_inference_pipelines.py +++ b/tests/test_inference_pipelines.py @@ -5,9 +5,9 @@ from pathlib import Path import hydra import numpy as np import pytest +from atomworks.io import parse from hydra import compose, initialize -from atomworks.io import parse from modelhub.utils.inference import ( apply_conformer_and_template_selections, build_file_paths_for_prediction, diff --git a/tests/test_inference_regression.py b/tests/test_inference_regression.py index d073e83..b70ed0a 100755 --- a/tests/test_inference_regression.py +++ b/tests/test_inference_regression.py @@ -7,14 +7,13 @@ import numpy as np import pandas as pd import pytest import torch -from conftest import TEST_DATA_DIR -from hydra import compose, initialize -from hydra.utils import instantiate - from atomworks.ml.utils.rng import ( create_rng_state_from_seeds, rng_state, ) +from conftest import TEST_DATA_DIR +from hydra import compose, initialize +from hydra.utils import instantiate def compare_csv_files(