mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
Merge pull request #12 from RosettaCommons/fix/inference
fix: inference
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
|
||||
from atomworks.common import parse_example_id
|
||||
from beartype.typing import Any
|
||||
|
||||
from atomworks.common import parse_example_id
|
||||
from modelhub.callbacks.base import BaseCallback
|
||||
from modelhub.utils.io import (
|
||||
build_stack_from_atom_array_and_batched_coords,
|
||||
|
||||
@@ -3,10 +3,10 @@ from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from atomworks.ml.utils import nested_dict
|
||||
from beartype.typing import Any, Literal
|
||||
from omegaconf import ListConfig
|
||||
|
||||
from atomworks.ml.utils import nested_dict
|
||||
from modelhub.callbacks.base import BaseCallback
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from modelhub.utils.logging import (
|
||||
|
||||
@@ -2,6 +2,7 @@ import time
|
||||
from collections import defaultdict
|
||||
|
||||
import pandas as pd
|
||||
from atomworks.common import parse_example_id
|
||||
from beartype.typing import Any
|
||||
from lightning.fabric.wrappers import (
|
||||
_FabricOptimizer,
|
||||
@@ -12,7 +13,6 @@ from rich.table import Table
|
||||
from torch import nn
|
||||
from torchmetrics.aggregation import MeanMetric
|
||||
|
||||
from atomworks.common import parse_example_id
|
||||
from modelhub.callbacks.base import BaseCallback
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from modelhub.utils.logging import (
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import torch
|
||||
|
||||
from atomworks.ml.transforms._checks import (
|
||||
check_contains_keys,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import numpy as np
|
||||
from beartype.typing import Any
|
||||
from biotite.structure import AtomArray
|
||||
from jaxtyping import Bool, Float
|
||||
|
||||
from atomworks.enums import GroundTruthConformerPolicy
|
||||
from atomworks.ml.transforms._checks import (
|
||||
check_atom_array_annotation,
|
||||
check_contains_keys,
|
||||
)
|
||||
from atomworks.ml.transforms.base import Transform
|
||||
from beartype.typing import Any
|
||||
from biotite.structure import AtomArray
|
||||
from jaxtyping import Bool, Float
|
||||
|
||||
|
||||
def add_ground_truth_reference_conformer(
|
||||
|
||||
@@ -3,11 +3,6 @@ from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from beartype.typing import Any, Callable, Final, Sequence
|
||||
from biotite.structure import AtomArray
|
||||
from jaxtyping import Bool, Float, Shaped
|
||||
from torch import Tensor
|
||||
|
||||
from atomworks.enums import ChainType
|
||||
from atomworks.ml.transforms._checks import (
|
||||
check_atom_array_annotation,
|
||||
@@ -20,6 +15,11 @@ from atomworks.ml.utils.token import (
|
||||
get_af3_token_center_masks,
|
||||
get_token_starts,
|
||||
)
|
||||
from beartype.typing import Any, Callable, Final, Sequence
|
||||
from biotite.structure import AtomArray
|
||||
from jaxtyping import Bool, Float, Shaped
|
||||
from torch import Tensor
|
||||
|
||||
from modelhub.utils.torch_utils import assert_no_nans
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from atomworks.enums import ChainType
|
||||
from atomworks.ml.transforms._checks import check_atom_array_annotation
|
||||
from atomworks.ml.transforms.crop import compute_local_hash
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from modelhub.data.ground_truth_template import (
|
||||
FeaturizeNoisedGroundTruthAsTemplateDistogram,
|
||||
TokenGroupNoiseScaleSampler,
|
||||
|
||||
@@ -3,11 +3,11 @@ from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
|
||||
import hydra
|
||||
from atomworks.ml.utils import error, nested_dict
|
||||
from beartype.typing import Any
|
||||
from omegaconf import DictConfig
|
||||
from toolz import keymap
|
||||
|
||||
from atomworks.ml.utils import error, nested_dict
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import torch
|
||||
from beartype.typing import Any
|
||||
from biotite.structure import AtomArray, AtomArrayStack
|
||||
from jaxtyping import Bool, Float
|
||||
|
||||
from atomworks.ml.transforms.af3_reference_molecule import (
|
||||
get_af3_reference_molecule_features,
|
||||
)
|
||||
from atomworks.ml.transforms.atom_array import ensure_atom_array_stack
|
||||
from atomworks.ml.transforms.chirals import add_af3_chiral_features
|
||||
from atomworks.ml.transforms.rdkit_utils import get_rdkit_chiral_centers
|
||||
from beartype.typing import Any
|
||||
from biotite.structure import AtomArray, AtomArrayStack
|
||||
from jaxtyping import Bool, Float
|
||||
|
||||
from modelhub.kinematics import get_dih
|
||||
from modelhub.metrics.base import Metric
|
||||
|
||||
|
||||
@@ -4,12 +4,12 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from atomworks.ml.utils.token import get_af3_token_representative_idxs
|
||||
from beartype.typing import Any, Literal
|
||||
from biotite.structure import AtomArrayStack
|
||||
from einops import rearrange, repeat
|
||||
from jaxtyping import Bool, Float
|
||||
|
||||
from atomworks.ml.utils.token import get_af3_token_representative_idxs
|
||||
from modelhub.loss.af3_losses import distogram_loss
|
||||
from modelhub.metrics.base import Metric
|
||||
from modelhub.utils.torch_utils import assert_no_nans
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from beartype.typing import Any
|
||||
from biotite.structure import AtomArray, AtomArrayStack, stack
|
||||
from jaxtyping import Bool, Float, Int
|
||||
|
||||
from atomworks.ml.transforms.atom_array import (
|
||||
AddGlobalTokenIdAnnotation,
|
||||
ensure_atom_array_stack,
|
||||
@@ -11,6 +7,10 @@ from atomworks.ml.transforms.atom_array import (
|
||||
from atomworks.ml.transforms.atomize import AtomizeByCCDName
|
||||
from atomworks.ml.transforms.base import Compose
|
||||
from atomworks.ml.utils.token import get_token_starts
|
||||
from beartype.typing import Any
|
||||
from biotite.structure import AtomArray, AtomArrayStack, stack
|
||||
from jaxtyping import Bool, Float, Int
|
||||
|
||||
from modelhub.metrics.base import Metric
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
from itertools import combinations
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from jaxtyping import Bool, Float
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
def find_bin_midpoints(max_distance, num_bins, device="cpu"):
|
||||
def find_bin_midpoints(
|
||||
max_distance: float, num_bins: int, device: Union[str, torch.device] = "cpu"
|
||||
) -> Float[torch.Tensor, "num_bins"]:
|
||||
"""
|
||||
Find the bin midpoints for a given binning scheme. Used to find expectation of values when converting binned
|
||||
predictions to unbinned predictions. Assumes the minimum of the schema is 0.
|
||||
Args:
|
||||
max_distance: float, maximum distance
|
||||
num_bins: int, number of bins
|
||||
device: device to run on
|
||||
Returns:
|
||||
pae_midpoints: [num_bins], bin midpoints
|
||||
"""
|
||||
@@ -26,7 +32,9 @@ def find_bin_midpoints(max_distance, num_bins, device="cpu"):
|
||||
return midpoints
|
||||
|
||||
|
||||
def unbin_logits(logits, max_distance, num_bins):
|
||||
def unbin_logits(
|
||||
logits: Float[torch.Tensor, "B num_bins L X"], max_distance: float, num_bins: int
|
||||
) -> Float[torch.Tensor, "B L L"]:
|
||||
"""
|
||||
Unbin the logits to get the matrix
|
||||
Args:
|
||||
@@ -42,7 +50,9 @@ def unbin_logits(logits, max_distance, num_bins):
|
||||
return unbinned
|
||||
|
||||
|
||||
def create_chainwise_masks_1d(ch_label, device="cpu"):
|
||||
def create_chainwise_masks_1d(
|
||||
ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu"
|
||||
) -> dict[str, Bool[torch.Tensor, "L"]]:
|
||||
"""
|
||||
Create 1D chainwise masks for a set of chain labels
|
||||
Args:
|
||||
@@ -61,7 +71,9 @@ def create_chainwise_masks_1d(ch_label, device="cpu"):
|
||||
return ch_masks
|
||||
|
||||
|
||||
def create_chainwise_masks_2d(ch_label, device="cpu"):
|
||||
def create_chainwise_masks_2d(
|
||||
ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu"
|
||||
) -> dict[str, Bool[torch.Tensor, "L L"]]:
|
||||
"""
|
||||
Create 2D chainwise masks for a set of chain labels
|
||||
Args:
|
||||
@@ -79,9 +91,16 @@ def create_chainwise_masks_2d(ch_label, device="cpu"):
|
||||
return ch_masks
|
||||
|
||||
|
||||
def create_interface_masks_2d(ch_label, device="cpu"):
|
||||
def create_interface_masks_2d(
|
||||
ch_label: NDArray[np.str_], device: Union[str, torch.device] = "cpu"
|
||||
) -> dict[tuple[str, str], Bool[torch.Tensor, "L L"]]:
|
||||
"""
|
||||
Create interface masks for a set of chain labels
|
||||
Args:
|
||||
ch_label: np.ndarray [L], chain labels
|
||||
device: torch.device, device to run on
|
||||
Returns:
|
||||
pairs_to_score: dict mapping chain pairs to boolean masks
|
||||
"""
|
||||
unique_chains = np.unique(ch_label)
|
||||
pairs_to_score = {}
|
||||
@@ -97,12 +116,17 @@ def create_interface_masks_2d(ch_label, device="cpu"):
|
||||
return pairs_to_score
|
||||
|
||||
|
||||
def compute_mean_over_subsampled_pairs(matrix_to_mean, pairs_to_score, eps=1e-6):
|
||||
def compute_mean_over_subsampled_pairs(
|
||||
matrix_to_mean: Float[torch.Tensor, "B L M"],
|
||||
pairs_to_score: Bool[torch.Tensor, "L M"],
|
||||
eps: float = 1e-6,
|
||||
) -> Float[torch.Tensor, "B"]:
|
||||
"""
|
||||
Compute the mean over a subsample of pairs in a 2d matrix. Returns a tensor with an element for each batch
|
||||
Args:
|
||||
matrix_to_mean: tensor of shape (batch, L, L)
|
||||
pairs_to_score: 2d tensor of shape (L, L) with 1s where pairs should be scored and 0s elsewhere
|
||||
eps: small epsilon value to avoid division by zero
|
||||
Returns:
|
||||
1d tensor of shape (batch,) with the mean over the subsampled pairs for each batch
|
||||
"""
|
||||
@@ -120,9 +144,49 @@ def compute_mean_over_subsampled_pairs(matrix_to_mean, pairs_to_score, eps=1e-6)
|
||||
return batch
|
||||
|
||||
|
||||
def spread_batch_into_dictionary(batch):
|
||||
def compute_min_over_subsampled_pairs(
|
||||
matrix_to_min: Float[torch.Tensor, "B L M"],
|
||||
pairs_to_score: Bool[torch.Tensor, "L M"],
|
||||
) -> Float[torch.Tensor, "B"]:
|
||||
"""
|
||||
Compute the min over a subsample of pairs in a 2d matrix. Returns a tensor with an element for each batch
|
||||
Args:
|
||||
matrix_to_min: tensor of shape (batch, L, L)
|
||||
pairs_to_score: 2d tensor of shape (L, L) with 1s where pairs should be scored and 0s elsewhere
|
||||
Returns:
|
||||
1d tensor of shape (batch,) with the min over the subsampled pairs for each batch
|
||||
"""
|
||||
B, L, M = matrix_to_min.shape
|
||||
assert matrix_to_min.shape == (
|
||||
B,
|
||||
L,
|
||||
M,
|
||||
), "Matrix to min should be of shape (batch, L, M)"
|
||||
assert pairs_to_score.shape == (L, M), "Pairs to score should be of shape (L, M)"
|
||||
# Use torch.where to efficiently mask without cloning the entire matrix
|
||||
# This broadcasts pairs_to_score across the batch dimension
|
||||
masked_matrix = torch.where(
|
||||
pairs_to_score.bool(), # condition (L, M) -> broadcasts to (B, L, M)
|
||||
matrix_to_min, # if True: use original values (B, L, M)
|
||||
torch.tensor(
|
||||
float("inf"), device=matrix_to_min.device, dtype=matrix_to_min.dtype
|
||||
), # if False: use inf
|
||||
)
|
||||
|
||||
# Flatten the last two dimensions and compute min across them
|
||||
batch = masked_matrix.view(B, -1).min(dim=-1)[0]
|
||||
|
||||
assert batch.shape == (B,), "Batch should be of shape (batch,)"
|
||||
return batch
|
||||
|
||||
|
||||
def spread_batch_into_dictionary(batch: Float[torch.Tensor, "B"]) -> dict[int, float]:
|
||||
"""
|
||||
Given a batch of data, create a dictionary with keys as the batch index and value as the corresponding data
|
||||
Args:
|
||||
batch: 1D tensor of shape (B,)
|
||||
Returns:
|
||||
Dictionary mapping batch indices to float values
|
||||
"""
|
||||
assert len(batch.shape) == 1, f"Batch should be a 1d tensor, {batch}"
|
||||
return {i: data.item() for i, data in enumerate(batch)}
|
||||
|
||||
@@ -107,8 +107,9 @@ class ComputeIPTM(Metric):
|
||||
protein_mask[None, :] & protein_mask[:, None] * to_calculate
|
||||
)
|
||||
protein_ligand_mask = (
|
||||
protein_mask[None, :] & ligand_mask[:, None] * to_calculate
|
||||
)
|
||||
(protein_mask[None, :] & ligand_mask[:, None])
|
||||
| (ligand_mask[None, :] & protein_mask[:, None])
|
||||
) * to_calculate
|
||||
ligand_ligand_mask = ligand_mask[None, :] & ligand_mask[:, None] * to_calculate
|
||||
# calculate iptm for each interface type
|
||||
iptm_protein_protein = compute_ptm(pae, protein_protein_mask)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import numpy as np
|
||||
from atomworks.ml.transforms.sasa import calculate_atomwise_rasa
|
||||
from beartype.typing import Any
|
||||
from biotite.structure import AtomArrayStack
|
||||
|
||||
from atomworks.ml.transforms.sasa import calculate_atomwise_rasa
|
||||
from modelhub.metrics.base import Metric
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import numpy as np
|
||||
from beartype.typing import Any
|
||||
from biotite.structure import AtomArrayStack
|
||||
|
||||
from atomworks.ml.utils import nested_dict
|
||||
from atomworks.ml.utils.selection import (
|
||||
get_mask_from_atom_selection,
|
||||
parse_selection_string,
|
||||
)
|
||||
from beartype.typing import Any
|
||||
from biotite.structure import AtomArrayStack
|
||||
|
||||
from modelhub.metrics.base import Metric
|
||||
|
||||
|
||||
|
||||
@@ -6,11 +6,10 @@ Documentation on custom resolvers:
|
||||
|
||||
import importlib
|
||||
|
||||
from atomworks.enums import ChainType, ChainTypeInfo
|
||||
from beartype.typing import Any
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from atomworks.enums import ChainType, ChainTypeInfo
|
||||
|
||||
from .common import run_once
|
||||
|
||||
|
||||
|
||||
@@ -5,9 +5,6 @@ from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from biotite.structure import AtomArray, AtomArrayStack
|
||||
from jaxtyping import Bool, Float, Int
|
||||
|
||||
from atomworks.ml.transforms.atom_array import (
|
||||
AddGlobalTokenIdAnnotation,
|
||||
ensure_atom_array_stack,
|
||||
@@ -15,6 +12,9 @@ from atomworks.ml.transforms.atom_array import (
|
||||
from atomworks.ml.transforms.atomize import AtomizeByCCDName
|
||||
from atomworks.ml.transforms.base import Compose, convert_to_torch
|
||||
from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX
|
||||
from biotite.structure import AtomArray, AtomArrayStack
|
||||
from jaxtyping import Bool, Float, Int
|
||||
|
||||
from modelhub.loss.af3_losses import (
|
||||
ResidueSymmetryResolution,
|
||||
SubunitSymmetryResolution,
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
import hydra
|
||||
import torch
|
||||
from atomworks.ml.samplers import (
|
||||
DistributedMixedSampler,
|
||||
FallbackSamplerWrapper,
|
||||
LazyWeightedRandomSampler,
|
||||
LoadBalancedDistributedSampler,
|
||||
MixedSampler,
|
||||
)
|
||||
from beartype.typing import Any
|
||||
from omegaconf import DictConfig, ListConfig
|
||||
from torch.utils.data import (
|
||||
@@ -13,13 +20,6 @@ from torch.utils.data import (
|
||||
)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from atomworks.ml.samplers import (
|
||||
DistributedMixedSampler,
|
||||
FallbackSamplerWrapper,
|
||||
LazyWeightedRandomSampler,
|
||||
LoadBalancedDistributedSampler,
|
||||
MixedSampler,
|
||||
)
|
||||
from modelhub.resolvers import register_resolvers
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from atomworks.io.utils.io_utils import to_cif_file
|
||||
from beartype.typing import Literal
|
||||
from biotite.structure import AtomArray, AtomArrayStack, stack
|
||||
|
||||
from atomworks.io.utils.io_utils import to_cif_file
|
||||
from modelhub.alignment import weighted_rigid_align
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from omegaconf import DictConfig
|
||||
from modelhub.chemical import NHEAVY
|
||||
from modelhub.metrics.metric_utils import (
|
||||
compute_mean_over_subsampled_pairs,
|
||||
compute_min_over_subsampled_pairs,
|
||||
create_chainwise_masks_1d,
|
||||
create_chainwise_masks_2d,
|
||||
create_interface_masks_2d,
|
||||
@@ -131,6 +132,15 @@ def compile_af3_confidence_outputs(
|
||||
for k, v in interface_masks.items()
|
||||
}
|
||||
|
||||
pae_interface_min = {
|
||||
k: spread_batch_into_dictionary(compute_min_over_subsampled_pairs(pae, v))
|
||||
for k, v in interface_masks.items()
|
||||
}
|
||||
|
||||
pde_interface_min = {
|
||||
k: spread_batch_into_dictionary(compute_min_over_subsampled_pairs(pde, v))
|
||||
for k, v in interface_masks.items()
|
||||
}
|
||||
# Calculate chainwise metrics
|
||||
chain_masks_2d = create_chainwise_masks_2d(chain_iid_token_lvl, device=pae.device)
|
||||
pae_chainwise = {
|
||||
@@ -167,6 +177,8 @@ def compile_af3_confidence_outputs(
|
||||
"chain_wise_mean_pde": pde_chainwise,
|
||||
"interface_wise_mean_pae": pae_interface,
|
||||
"interface_wise_mean_pde": pde_interface,
|
||||
"interface_wise_min_pae": pae_interface_min,
|
||||
"interface_wise_min_pde": pde_interface_min,
|
||||
}
|
||||
|
||||
# Generate DataFrame rows
|
||||
@@ -204,6 +216,12 @@ def compile_af3_confidence_outputs(
|
||||
"pde_interface": confidence_data["interface_wise_mean_pde"][
|
||||
(chain_i, chain_j)
|
||||
][batch_idx],
|
||||
"min_pae_interface": confidence_data["interface_wise_min_pae"][
|
||||
(chain_i, chain_j)
|
||||
][batch_idx],
|
||||
"min_pde_interface": confidence_data["interface_wise_min_pde"][
|
||||
(chain_i, chain_j)
|
||||
][batch_idx],
|
||||
"overall_plddt": confidence_data["mean_plddt"][batch_idx],
|
||||
"overall_pde": confidence_data["mean_pde"][batch_idx],
|
||||
"overall_pae": confidence_data["mean_pae"][batch_idx],
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from atomworks.ml.utils.rng import create_rng_state_from_seeds, rng_state
|
||||
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ from pathlib import Path
|
||||
import hydra
|
||||
import numpy as np
|
||||
import pytest
|
||||
from atomworks.io import parse
|
||||
from hydra import compose, initialize
|
||||
|
||||
from atomworks.io import parse
|
||||
from modelhub.utils.inference import (
|
||||
apply_conformer_and_template_selections,
|
||||
build_file_paths_for_prediction,
|
||||
|
||||
@@ -7,14 +7,13 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import torch
|
||||
from conftest import TEST_DATA_DIR
|
||||
from hydra import compose, initialize
|
||||
from hydra.utils import instantiate
|
||||
|
||||
from atomworks.ml.utils.rng import (
|
||||
create_rng_state_from_seeds,
|
||||
rng_state,
|
||||
)
|
||||
from conftest import TEST_DATA_DIR
|
||||
from hydra import compose, initialize
|
||||
from hydra.utils import instantiate
|
||||
|
||||
|
||||
def compare_csv_files(
|
||||
|
||||
Reference in New Issue
Block a user