Merge pull request #12 from RosettaCommons/fix/inference

fix: inference
This commit is contained in:
Nathaniel Corley
2025-09-19 16:19:47 -07:00
committed by GitHub
23 changed files with 135 additions and 57 deletions

View File

@@ -1,9 +1,9 @@
from os import PathLike
from pathlib import Path
from atomworks.common import parse_example_id
from beartype.typing import Any
from atomworks.common import parse_example_id
from modelhub.callbacks.base import BaseCallback
from modelhub.utils.io import (
build_stack_from_atom_array_and_batched_coords,

View File

@@ -3,10 +3,10 @@ from copy import deepcopy
from pathlib import Path
import pandas as pd
from atomworks.ml.utils import nested_dict
from beartype.typing import Any, Literal
from omegaconf import ListConfig
from atomworks.ml.utils import nested_dict
from modelhub.callbacks.base import BaseCallback
from modelhub.utils.ddp import RankedLogger
from modelhub.utils.logging import (

View File

@@ -2,6 +2,7 @@ import time
from collections import defaultdict
import pandas as pd
from atomworks.common import parse_example_id
from beartype.typing import Any
from lightning.fabric.wrappers import (
_FabricOptimizer,
@@ -12,7 +13,6 @@ from rich.table import Table
from torch import nn
from torchmetrics.aggregation import MeanMetric
from atomworks.common import parse_example_id
from modelhub.callbacks.base import BaseCallback
from modelhub.utils.ddp import RankedLogger
from modelhub.utils.logging import (

View File

@@ -1,5 +1,4 @@
import torch
from atomworks.ml.transforms._checks import (
check_contains_keys,
)

View File

@@ -1,14 +1,13 @@
import numpy as np
from beartype.typing import Any
from biotite.structure import AtomArray
from jaxtyping import Bool, Float
from atomworks.enums import GroundTruthConformerPolicy
from atomworks.ml.transforms._checks import (
check_atom_array_annotation,
check_contains_keys,
)
from atomworks.ml.transforms.base import Transform
from beartype.typing import Any
from biotite.structure import AtomArray
from jaxtyping import Bool, Float
def add_ground_truth_reference_conformer(

View File

@@ -3,11 +3,6 @@ from dataclasses import dataclass
import numpy as np
import torch
from beartype.typing import Any, Callable, Final, Sequence
from biotite.structure import AtomArray
from jaxtyping import Bool, Float, Shaped
from torch import Tensor
from atomworks.enums import ChainType
from atomworks.ml.transforms._checks import (
check_atom_array_annotation,
@@ -20,6 +15,11 @@ from atomworks.ml.utils.token import (
get_af3_token_center_masks,
get_token_starts,
)
from beartype.typing import Any, Callable, Final, Sequence
from biotite.structure import AtomArray
from jaxtyping import Bool, Float, Shaped
from torch import Tensor
from modelhub.utils.torch_utils import assert_no_nans
logger = logging.getLogger(__name__)

View File

@@ -1,11 +1,11 @@
from functools import partial
import torch
from omegaconf import DictConfig
from atomworks.enums import ChainType
from atomworks.ml.transforms._checks import check_atom_array_annotation
from atomworks.ml.transforms.crop import compute_local_hash
from omegaconf import DictConfig
from modelhub.data.ground_truth_template import (
FeaturizeNoisedGroundTruthAsTemplateDistogram,
TokenGroupNoiseScaleSampler,

View File

@@ -3,11 +3,11 @@ from abc import ABC, abstractmethod
from functools import cached_property
import hydra
from atomworks.ml.utils import error, nested_dict
from beartype.typing import Any
from omegaconf import DictConfig
from toolz import keymap
from atomworks.ml.utils import error, nested_dict
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

View File

@@ -1,14 +1,14 @@
import torch
from beartype.typing import Any
from biotite.structure import AtomArray, AtomArrayStack
from jaxtyping import Bool, Float
from atomworks.ml.transforms.af3_reference_molecule import (
get_af3_reference_molecule_features,
)
from atomworks.ml.transforms.atom_array import ensure_atom_array_stack
from atomworks.ml.transforms.chirals import add_af3_chiral_features
from atomworks.ml.transforms.rdkit_utils import get_rdkit_chiral_centers
from beartype.typing import Any
from biotite.structure import AtomArray, AtomArrayStack
from jaxtyping import Bool, Float
from modelhub.kinematics import get_dih
from modelhub.metrics.base import Metric

View File

@@ -4,12 +4,12 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from atomworks.ml.utils.token import get_af3_token_representative_idxs
from beartype.typing import Any, Literal
from biotite.structure import AtomArrayStack
from einops import rearrange, repeat
from jaxtyping import Bool, Float
from atomworks.ml.utils.token import get_af3_token_representative_idxs
from modelhub.loss.af3_losses import distogram_loss
from modelhub.metrics.base import Metric
from modelhub.utils.torch_utils import assert_no_nans

View File

@@ -1,9 +1,5 @@
import numpy as np
import torch
from beartype.typing import Any
from biotite.structure import AtomArray, AtomArrayStack, stack
from jaxtyping import Bool, Float, Int
from atomworks.ml.transforms.atom_array import (
AddGlobalTokenIdAnnotation,
ensure_atom_array_stack,
@@ -11,6 +7,10 @@ from atomworks.ml.transforms.atom_array import (
from atomworks.ml.transforms.atomize import AtomizeByCCDName
from atomworks.ml.transforms.base import Compose
from atomworks.ml.utils.token import get_token_starts
from beartype.typing import Any
from biotite.structure import AtomArray, AtomArrayStack, stack
from jaxtyping import Bool, Float, Int
from modelhub.metrics.base import Metric
from modelhub.utils.ddp import RankedLogger

View File

@@ -1,16 +1,22 @@
from itertools import combinations
from typing import Union
import numpy as np
import torch
from jaxtyping import Bool, Float
from numpy.typing import NDArray
def find_bin_midpoints(max_distance, num_bins, device="cpu"):
def find_bin_midpoints(
max_distance: float, num_bins: int, device: Union[str, torch.device] = "cpu"
) -> Float[torch.Tensor, "num_bins"]:
"""
Find the bin midpoints for a given binning scheme. Used to find expectation of values when converting binned
predictions to unbinned predictions. Assumes the minimum of the schema is 0.
Args:
max_distance: float, maximum distance
num_bins: int, number of bins
device: device to run on
Returns:
pae_midpoints: [num_bins], bin midpoints
"""
@@ -26,7 +32,9 @@ def find_bin_midpoints(max_distance, num_bins, device="cpu"):
return midpoints
def unbin_logits(logits, max_distance, num_bins):
def unbin_logits(
logits: Float[torch.Tensor, "B num_bins L X"], max_distance: float, num_bins: int
) -> Float[torch.Tensor, "B L L"]:
"""
Unbin the logits to get the matrix
Args:
@@ -42,7 +50,9 @@ def unbin_logits(logits, max_distance, num_bins):
return unbinned
def create_chainwise_masks_1d(ch_label, device="cpu"):
def create_chainwise_masks_1d(
ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu"
) -> dict[str, Bool[torch.Tensor, "L"]]:
"""
Create 1D chainwise masks for a set of chain labels
Args:
@@ -61,7 +71,9 @@ def create_chainwise_masks_1d(ch_label, device="cpu"):
return ch_masks
def create_chainwise_masks_2d(ch_label, device="cpu"):
def create_chainwise_masks_2d(
ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu"
) -> dict[str, Bool[torch.Tensor, "L L"]]:
"""
Create 2D chainwise masks for a set of chain labels
Args:
@@ -79,9 +91,16 @@ def create_chainwise_masks_2d(ch_label, device="cpu"):
return ch_masks
def create_interface_masks_2d(ch_label, device="cpu"):
def create_interface_masks_2d(
ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu"
) -> dict[tuple[str, str], Bool[torch.Tensor, "L L"]]:
"""
Create interface masks for a set of chain labels
Args:
ch_label: np.ndarray [L], chain labels
device: torch.device, device to run on
Returns:
pairs_to_score: dict mapping chain pairs to boolean masks
"""
unique_chains = np.unique(ch_label)
pairs_to_score = {}
@@ -97,12 +116,17 @@ def create_interface_masks_2d(ch_label, device="cpu"):
return pairs_to_score
def compute_mean_over_subsampled_pairs(matrix_to_mean, pairs_to_score, eps=1e-6):
def compute_mean_over_subsampled_pairs(
matrix_to_mean: Float[torch.Tensor, "B L M"],
pairs_to_score: Bool[torch.Tensor, "L M"],
eps: float = 1e-6,
) -> Float[torch.Tensor, "B"]:
"""
Compute the mean over a subsample of pairs in a 2d matrix. Returns a tensor with an element for each batch
Args:
matrix_to_mean: tensor of shape (batch, L, L)
pairs_to_score: 2d tensor of shape (L, L) with 1s where pairs should be scored and 0s elsewhere
eps: small epsilon value to avoid division by zero
Returns:
1d tensor of shape (batch,) with the mean over the subsampled pairs for each batch
"""
@@ -120,9 +144,49 @@ def compute_mean_over_subsampled_pairs(matrix_to_mean, pairs_to_score, eps=1e-6)
return batch
def spread_batch_into_dictionary(batch):
def compute_min_over_subsampled_pairs(
    matrix_to_min: Float[torch.Tensor, "B L M"],
    pairs_to_score: Bool[torch.Tensor, "L M"],
) -> Float[torch.Tensor, "B"]:
    """
    Compute the min over a subsample of pairs in a 2d matrix. Returns a tensor with an element for each batch

    Args:
        matrix_to_min: tensor of shape (batch, L, M)
        pairs_to_score: 2d tensor of shape (L, M), non-zero/True where pairs should be scored and 0/False elsewhere

    Returns:
        1d tensor of shape (batch,) with the min over the subsampled pairs for each batch.
        An entry is +inf when `pairs_to_score` selects no pairs.
    """
    # Check the rank before unpacking so a wrong-rank input reports the intended message
    # (an assert placed after the unpacking can never fail).
    assert matrix_to_min.ndim == 3, "Matrix to min should be of shape (batch, L, M)"
    B, L, M = matrix_to_min.shape
    assert pairs_to_score.shape == (L, M), "Pairs to score should be of shape (L, M)"

    # Replace every unselected pair with +inf so it can never win the min.
    # The (L, M) mask broadcasts across the batch dimension.
    excluded = ~pairs_to_score.bool()
    masked_matrix = matrix_to_min.masked_fill(excluded, float("inf"))

    # Reduce over both pair dimensions directly: no flattening (which fails for an
    # empty batch) and no unused argmin indices.
    batch = masked_matrix.amin(dim=(-2, -1))

    assert batch.shape == (B,), "Batch should be of shape (batch,)"
    return batch
def spread_batch_into_dictionary(batch: Float[torch.Tensor, "B"]) -> dict[int, float]:
    """
    Given a batch of data, create a dictionary with keys as the batch index and value as the corresponding data

    Args:
        batch: 1D tensor of shape (B,)

    Returns:
        Dictionary mapping batch indices to float values (empty for an empty batch)
    """
    assert len(batch.shape) == 1, f"Batch should be a 1d tensor, {batch}"
    # Convert the whole tensor to Python scalars in one call instead of calling
    # .item() once per element while iterating the tensor.
    return dict(enumerate(batch.tolist()))

View File

@@ -107,8 +107,9 @@ class ComputeIPTM(Metric):
protein_mask[None, :] & protein_mask[:, None] * to_calculate
)
protein_ligand_mask = (
protein_mask[None, :] & ligand_mask[:, None] * to_calculate
)
(protein_mask[None, :] & ligand_mask[:, None])
| (ligand_mask[None, :] & protein_mask[:, None])
) * to_calculate
ligand_ligand_mask = ligand_mask[None, :] & ligand_mask[:, None] * to_calculate
# calculate iptm for each interface type
iptm_protein_protein = compute_ptm(pae, protein_protein_mask)

View File

@@ -1,8 +1,8 @@
import numpy as np
from atomworks.ml.transforms.sasa import calculate_atomwise_rasa
from beartype.typing import Any
from biotite.structure import AtomArrayStack
from atomworks.ml.transforms.sasa import calculate_atomwise_rasa
from modelhub.metrics.base import Metric

View File

@@ -1,12 +1,12 @@
import numpy as np
from beartype.typing import Any
from biotite.structure import AtomArrayStack
from atomworks.ml.utils import nested_dict
from atomworks.ml.utils.selection import (
get_mask_from_atom_selection,
parse_selection_string,
)
from beartype.typing import Any
from biotite.structure import AtomArrayStack
from modelhub.metrics.base import Metric

View File

@@ -6,11 +6,10 @@ Documentation on custom resolvers:
import importlib
from atomworks.enums import ChainType, ChainTypeInfo
from beartype.typing import Any
from omegaconf import OmegaConf
from atomworks.enums import ChainType, ChainTypeInfo
from .common import run_once

View File

@@ -5,9 +5,6 @@ from typing import Any, Dict
import numpy as np
import torch
from biotite.structure import AtomArray, AtomArrayStack
from jaxtyping import Bool, Float, Int
from atomworks.ml.transforms.atom_array import (
AddGlobalTokenIdAnnotation,
ensure_atom_array_stack,
@@ -15,6 +12,9 @@ from atomworks.ml.transforms.atom_array import (
from atomworks.ml.transforms.atomize import AtomizeByCCDName
from atomworks.ml.transforms.base import Compose, convert_to_torch
from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX
from biotite.structure import AtomArray, AtomArrayStack
from jaxtyping import Bool, Float, Int
from modelhub.loss.af3_losses import (
ResidueSymmetryResolution,
SubunitSymmetryResolution,

View File

@@ -1,5 +1,12 @@
import hydra
import torch
from atomworks.ml.samplers import (
DistributedMixedSampler,
FallbackSamplerWrapper,
LazyWeightedRandomSampler,
LoadBalancedDistributedSampler,
MixedSampler,
)
from beartype.typing import Any
from omegaconf import DictConfig, ListConfig
from torch.utils.data import (
@@ -13,13 +20,6 @@ from torch.utils.data import (
)
from torch.utils.data.distributed import DistributedSampler
from atomworks.ml.samplers import (
DistributedMixedSampler,
FallbackSamplerWrapper,
LazyWeightedRandomSampler,
LoadBalancedDistributedSampler,
MixedSampler,
)
from modelhub.resolvers import register_resolvers
from modelhub.utils.ddp import RankedLogger

View File

@@ -4,10 +4,10 @@ from pathlib import Path
import numpy as np
import torch
from atomworks.io.utils.io_utils import to_cif_file
from beartype.typing import Literal
from biotite.structure import AtomArray, AtomArrayStack, stack
from atomworks.io.utils.io_utils import to_cif_file
from modelhub.alignment import weighted_rigid_align
from modelhub.utils.ddp import RankedLogger

View File

@@ -13,6 +13,7 @@ from omegaconf import DictConfig
from modelhub.chemical import NHEAVY
from modelhub.metrics.metric_utils import (
compute_mean_over_subsampled_pairs,
compute_min_over_subsampled_pairs,
create_chainwise_masks_1d,
create_chainwise_masks_2d,
create_interface_masks_2d,
@@ -131,6 +132,15 @@ def compile_af3_confidence_outputs(
for k, v in interface_masks.items()
}
pae_interface_min = {
k: spread_batch_into_dictionary(compute_min_over_subsampled_pairs(pae, v))
for k, v in interface_masks.items()
}
pde_interface_min = {
k: spread_batch_into_dictionary(compute_min_over_subsampled_pairs(pde, v))
for k, v in interface_masks.items()
}
# Calculate chainwise metrics
chain_masks_2d = create_chainwise_masks_2d(chain_iid_token_lvl, device=pae.device)
pae_chainwise = {
@@ -167,6 +177,8 @@ def compile_af3_confidence_outputs(
"chain_wise_mean_pde": pde_chainwise,
"interface_wise_mean_pae": pae_interface,
"interface_wise_mean_pde": pde_interface,
"interface_wise_min_pae": pae_interface_min,
"interface_wise_min_pde": pde_interface_min,
}
# Generate DataFrame rows
@@ -204,6 +216,12 @@ def compile_af3_confidence_outputs(
"pde_interface": confidence_data["interface_wise_mean_pde"][
(chain_i, chain_j)
][batch_idx],
"min_pae_interface": confidence_data["interface_wise_min_pae"][
(chain_i, chain_j)
][batch_idx],
"min_pde_interface": confidence_data["interface_wise_min_pde"][
(chain_i, chain_j)
][batch_idx],
"overall_plddt": confidence_data["mean_plddt"][batch_idx],
"overall_pde": confidence_data["mean_pde"][batch_idx],
"overall_pae": confidence_data["mean_pae"][batch_idx],

View File

@@ -1,7 +1,6 @@
import math
import torch
from atomworks.ml.utils.rng import create_rng_state_from_seeds, rng_state

View File

@@ -5,9 +5,9 @@ from pathlib import Path
import hydra
import numpy as np
import pytest
from atomworks.io import parse
from hydra import compose, initialize
from atomworks.io import parse
from modelhub.utils.inference import (
apply_conformer_and_template_selections,
build_file_paths_for_prediction,

View File

@@ -7,14 +7,13 @@ import numpy as np
import pandas as pd
import pytest
import torch
from conftest import TEST_DATA_DIR
from hydra import compose, initialize
from hydra.utils import instantiate
from atomworks.ml.utils.rng import (
create_rng_state_from_seeds,
rng_state,
)
from conftest import TEST_DATA_DIR
from hydra import compose, initialize
from hydra.utils import instantiate
def compare_csv_files(