mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
515 lines
19 KiB
Python
515 lines
19 KiB
Python
import itertools
|
|
from typing import List
|
|
|
|
import einops
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
import tree
|
|
from beartype.typing import Any
|
|
from biotite.structure import AtomArray, AtomArrayStack
|
|
from omegaconf import DictConfig
|
|
|
|
from modelhub.chemical import NHEAVY
|
|
from modelhub.metrics.metric_utils import (
|
|
compute_mean_over_subsampled_pairs,
|
|
compute_min_over_subsampled_pairs,
|
|
create_chainwise_masks_1d,
|
|
create_chainwise_masks_2d,
|
|
create_interface_masks_2d,
|
|
spread_batch_into_dictionary,
|
|
unbin_logits,
|
|
)
|
|
|
|
|
|
def get_mean_atomwise_plddt(
    plddt_logits: torch.Tensor,
    is_real_atom: torch.Tensor,
    max_value: float,
) -> torch.Tensor:
    """Compute the mean atom-wise pLDDT for every structure in the batch.

    Args:
        plddt_logits: Tensor of shape [B, n_token, max_atoms_in_a_token * n_bins] with logits,
            where max_atoms_in_a_token is NHEAVY.
        is_real_atom: Boolean mask indicating which atoms are real (i.e., not padding). Either
            [n_token, n_atoms] (one mask shared by the whole batch) or [B, n_token, n_atoms].
            Only the first NHEAVY entries of the last dimension are used.
        max_value: Maximum value for pLDDT (assigned to the last bin)

    Returns:
        plddt: Tensor with the mean atom-wise pLDDT for each batch element (one value per
            entry of the leading dimension of the unbinned pLDDT).

    Raises:
        ValueError: If ``is_real_atom`` is neither 2D nor 3D.
    """
    assert (
        plddt_logits.ndim == 3
    ), "plddt_logits must be a 3D tensor (B, n_token, max_atoms_in_a_token * n_bins)"

    # TODO: Replace with the last dimension of is_real_atom; right now that number is too large (36) because it includes hydrogens
    max_atoms_in_a_token = NHEAVY

    # Since the pLDDT logits have the last dimension (max_atoms_in_a_token * n_bins), we can calculate n_bins directly
    assert (
        plddt_logits.shape[-1] % max_atoms_in_a_token == 0
    ), "The last dimension of plddt_logits must be divisible by max_atoms_in_a_token!"
    n_bins = plddt_logits.shape[-1] // max_atoms_in_a_token

    # ... reshape to match what unbin_logits expects (bins right after the batch dimensions)
    reshaped_plddt_logits = einops.rearrange(
        plddt_logits,
        "... n_token (max_atoms_in_a_token n_bins) -> ... n_bins n_token max_atoms_in_a_token",
        max_atoms_in_a_token=max_atoms_in_a_token,
        n_bins=n_bins,
    ).float()  # [..., n_token, max_atoms_in_a_token * n_bins] -> [..., n_bins, n_token, max_atoms_in_a_token]

    plddt = unbin_logits(
        reshaped_plddt_logits,
        max_value,
        n_bins,
    )

    is_real_atom = is_real_atom.to(device=plddt.device)

    if is_real_atom.ndim not in (2, 3):
        raise ValueError(
            "is_real_atom must have shape (n_token, n_atoms) or (B, n_token, n_atoms), "
            f"got {tuple(is_real_atom.shape)}"
        )

    # ... keep only the heavy-atom slots. Slicing the LAST axis (via `...`) works for both the
    # 2D and 3D mask layouts; the previous `[:, :k]` silently sliced the token axis of a 3D mask.
    mask = is_real_atom[..., :max_atoms_in_a_token]
    if mask.ndim == 2:
        # A single mask shared across the batch: add a broadcastable batch dimension
        mask = mask.unsqueeze(0)

    # ... masked mean over the (token, atom) dimensions
    atomwise_plddt_mean = (plddt * mask).sum(dim=(1, 2)) / mask.sum(dim=(1, 2))

    return atomwise_plddt_mean
|
|
|
|
|
|
def compile_af3_confidence_outputs(
    plddt_logits: torch.Tensor,
    pae_logits: torch.Tensor,
    pde_logits: torch.Tensor,
    chain_iid_token_lvl: torch.Tensor,
    is_real_atom: torch.Tensor,
    example_id: str,
    confidence_loss_cfg: DictConfig | dict,
) -> dict[str, Any]:
    """Unbin the confidence logits and summarise them per chain and per chain pair.

    Args:
        plddt_logits: pLDDT logits; reshaped here to (-1, plddt_logits.shape[1], NHEAVY, n_bins).
        pae_logits: pAE logits; 4D with the bin dimension last (it is permuted to position 1).
        pde_logits: pDE logits; same layout as ``pae_logits``.
        chain_iid_token_lvl: Chain instance ID of every token.
        is_real_atom: Mask of real (non-padding) atoms; only the first NHEAVY entries of the
            last dimension are used.
        example_id: Identifier written into every DataFrame row.
        confidence_loss_cfg: Config providing ``plddt``, ``pae`` and ``pde`` entries, each with
            ``max_value`` and ``n_bins``. NOTE: the entries are read via attribute access.

    Returns:
        dict[str, Any]: A dictionary containing the following:
            - confidence_df: A DataFrame with one row per (batch index, chain) followed by one
              row per (batch index, chain pair) holding the aggregate confidence metrics
            - plddt: The unbinned pLDDT values (output of ``unbin_logits``)
            - pae: The unbinned pAE values (output of ``unbin_logits``)
            - pde: The unbinned pDE values (output of ``unbin_logits``)
    """
    # TODO: Refactor to accept an AtomArray
    # TODO: Taking the confidence_loss_cfg does not align with functional programming best-practices; we should instead take the max_value and n_bins as arguments

    plddt_cfg = confidence_loss_cfg.plddt
    pae_cfg = confidence_loss_cfg.pae
    pde_cfg = confidence_loss_cfg.pde

    def _reduce_per_mask(reduce_fn, values, masks):
        # Apply `reduce_fn` once per mask and spread the batched result into a per-batch mapping
        return {
            key: spread_batch_into_dictionary(reduce_fn(values, mask))
            for key, mask in masks.items()
        }

    # Bring the bin dimension to position 1, i.e. (B, n_bins, ...), which is the layout passed to unbin_logits
    plddt_binned = plddt_logits.reshape(
        -1, plddt_logits.shape[1], NHEAVY, plddt_cfg.n_bins
    )
    plddt = unbin_logits(
        plddt_binned.permute(0, 3, 1, 2).float(),
        plddt_cfg.max_value,
        plddt_cfg.n_bins,
    )
    pae = unbin_logits(
        pae_logits.permute(0, 3, 1, 2).float(), pae_cfg.max_value, pae_cfg.n_bins
    )
    pde = unbin_logits(
        pde_logits.permute(0, 3, 1, 2).float(), pde_cfg.max_value, pde_cfg.n_bins
    )

    # Interface-level metrics (mean and min over each interface mask)
    interface_masks = create_interface_masks_2d(chain_iid_token_lvl, device=pae.device)
    pae_interface = _reduce_per_mask(
        compute_mean_over_subsampled_pairs, pae, interface_masks
    )
    pde_interface = _reduce_per_mask(
        compute_mean_over_subsampled_pairs, pde, interface_masks
    )
    pae_interface_min = _reduce_per_mask(
        compute_min_over_subsampled_pairs, pae, interface_masks
    )
    pde_interface_min = _reduce_per_mask(
        compute_min_over_subsampled_pairs, pde, interface_masks
    )

    # Chain-level metrics
    chain_masks_2d = create_chainwise_masks_2d(chain_iid_token_lvl, device=pae.device)
    pae_chainwise = _reduce_per_mask(
        compute_mean_over_subsampled_pairs, pae, chain_masks_2d
    )
    pde_chainwise = _reduce_per_mask(
        compute_mean_over_subsampled_pairs, pde, chain_masks_2d
    )

    chain_masks_1d = create_chainwise_masks_1d(
        chain_iid_token_lvl, device=is_real_atom.device
    )
    heavy_atom_mask = is_real_atom[..., :NHEAVY]
    plddt_chainwise = {
        chain: spread_batch_into_dictionary(
            compute_mean_over_subsampled_pairs(
                # restrict the real-atom mask to the tokens of this chain
                plddt,
                heavy_atom_mask * token_mask[:, None],
            )
        )
        for chain, token_mask in chain_masks_1d.items()
    }

    # Whole-complex metrics
    mean_plddt = spread_batch_into_dictionary(
        compute_mean_over_subsampled_pairs(plddt, heavy_atom_mask)
    )
    mean_pae = spread_batch_into_dictionary(pae.mean(dim=(-1, -2)))
    mean_pde = spread_batch_into_dictionary(pde.mean(dim=(-1, -2)))

    num_batches = plddt.shape[0]
    chains = np.unique(chain_iid_token_lvl)
    chain_pairs = list(itertools.combinations(chains, 2))

    # For every batch, chain, and interface (chain pair), generate a dataframe row
    chain_rows = []
    interface_rows = []
    for batch_idx in range(num_batches):
        # Columns shared by chain rows and interface rows (appended last to keep column order)
        overall = {
            "overall_plddt": mean_plddt[batch_idx],
            "overall_pde": mean_pde[batch_idx],
            "overall_pae": mean_pae[batch_idx],
            "batch_idx": batch_idx,
        }

        for chain in chains:
            chain_rows.append(
                {
                    "example_id": example_id,
                    "chain_chainwise": chain,
                    "chainwise_plddt": plddt_chainwise[chain][batch_idx],
                    "chainwise_pde": pde_chainwise[chain][batch_idx],
                    "chainwise_pae": pae_chainwise[chain][batch_idx],
                    **overall,
                }
            )

        for pair in chain_pairs:
            chain_i, chain_j = pair
            interface_rows.append(
                {
                    "example_id": example_id,
                    "chain_i_interface": chain_i,
                    "chain_j_interface": chain_j,
                    "pae_interface": pae_interface[pair][batch_idx],
                    "pde_interface": pde_interface[pair][batch_idx],
                    "min_pae_interface": pae_interface_min[pair][batch_idx],
                    "min_pde_interface": pde_interface_min[pair][batch_idx],
                    **overall,
                }
            )

    # All chain rows first, then all interface rows
    confidence_df = pd.DataFrame(itertools.chain(chain_rows, interface_rows))

    return {
        "confidence_df": confidence_df,
        "plddt": plddt,
        "pae": pae,
        "pde": pde,
    }
|
|
|
|
|
|
def compute_batch_indices_with_lowest_predicted_error(
    plddt: torch.Tensor,
    is_real_atom: torch.Tensor,
    pae: torch.Tensor,
    confidence_loss_cfg: dict | DictConfig,
    chain_iid_token_lvl: torch.Tensor,
    is_ligand: torch.Tensor,
    interfaces_to_score: list[tuple],
    pn_units_to_score: list[tuple],
    pde: torch.Tensor | None = None,
) -> dict[str, Any]:
    """Given the confidence logits, computes the index within the diffusion batch of the best predicted structure.

    Metrics include pAE, pLDDT, and pDE, among others.

    Args:
        plddt: pLDDT logits; reshaped here to (-1, plddt.shape[1], NHEAVY, n_bins).
        is_real_atom: Mask of real (non-padding) atoms; only the first NHEAVY entries of the
            last dimension are used.
        pae: pAE logits; 4D with the bin dimension last.
        confidence_loss_cfg: Config providing ``plddt``, ``pae`` and ``pde`` entries, each with
            ``max_value`` and ``n_bins`` (read via attribute access).
        chain_iid_token_lvl: Chain instance ID of every token.
        is_ligand: Per-token mask marking ligand tokens.
        interfaces_to_score: Interfaces to score; the first two entries of each tuple are the
            chain IDs forming the interface.
        pn_units_to_score: PN units to score; the first entry of each tuple is the chain ID.
        pde: pDE logits with the same layout as ``pae``. Optional for backward compatibility:
            when omitted, ``pde_idx`` is derived from the pAE logits (the legacy behaviour),
            which is NOT a true pDE ranking. Pass the pDE logits to get a meaningful ``pde_idx``.

    Returns:
        dict[str, Any]: A dictionary containing the following keys:
            - pae_idx: The index within the diffusion batch of the structure with the best overall pAE (Predicted Aligned Error)
            - pde_idx: The index within the diffusion batch of the structure with the best overall pDE (Predicted Distance Error)
            - plddt_idx: The index within the diffusion batch of the structure with the best overall pLDDT (Predicted Local Distance
                Difference Test)
            - best_chain_to_all_idx: For each scored chain, the index within the diffusion batch of the structure with the
                best pAE subsampled over any pair (i,j) where i == chain or j == chain
            - best_chain_to_self_idx: For each scored chain, the index within the diffusion batch of the structure with the
                best pAE subsampled over any pair (i,j) where i == chain and j == chain
            - best_interface_idx: For each interface between two scored PN Units, the index within the diffusion batch of the
                structure with the best mean pAE for all (i,j) where i == interface_chain or j == interface_chain and i != j
            - best_lig_ipae_idx: For each interface, the index within the diffusion batch for the best pAE subsampled over
                any pair (i,j) where i == chain or j == chain and i != j and i or j is a ligand (a placeholder 0 is stored
                for interfaces without a ligand chain)
    """
    # TODO: Have this function take an `AtomArray` as input so we quickly build masks with much less code
    # TODO: Explore how we can write this function more concisely
    return_dict = {}

    # AF3's ranking metrics work like this, but using ptm instead of ipae:
    scored_chains, interfaces, interface_chains = _select_scored_units(
        interfaces_to_score, pn_units_to_score
    )

    chain_to_all_masks = _create_chain_to_all_masks(chain_iid_token_lvl, scored_chains)
    chain_to_self_masks = _create_chain_to_self_masks(
        chain_iid_token_lvl, scored_chains
    )
    interface_masks, lig_chains = _create_interface_masks(
        chain_iid_token_lvl, interfaces, is_ligand
    )

    # Move every mask onto the device of the logits (non-tensor leaves are passed through)
    device = plddt.device

    def _to_device(leaf):
        return leaf.to(device) if hasattr(leaf, "cpu") else leaf

    chain_to_all_masks = tree.map_structure(_to_device, chain_to_all_masks)
    chain_to_self_masks = tree.map_structure(_to_device, chain_to_self_masks)
    interface_masks = tree.map_structure(_to_device, interface_masks)

    # Reshape logits to B, K, L, NHEAVY
    plddt = (
        plddt.reshape(
            -1,
            plddt.shape[1],
            NHEAVY,
            confidence_loss_cfg.plddt.n_bins,
        )
        .permute(0, 3, 1, 2)
        .float()
    )
    # Reshape the pae and pde logits to B, K, L, L
    pae_logits = pae.permute(0, 3, 1, 2).float()
    # Previously the pDE ranking was always computed from the pAE logits (copy-paste slip).
    # Use the real pDE logits when the caller supplies them; otherwise keep the legacy fallback.
    pde_source = pae if pde is None else pde
    pde_logits = pde_source.permute(0, 3, 1, 2).float()

    pae_logits_unbinned = unbin_logits(
        pae_logits, confidence_loss_cfg.pae.max_value, confidence_loss_cfg.pae.n_bins
    )
    plddt_logits_unbinned = unbin_logits(
        plddt, confidence_loss_cfg.plddt.max_value, confidence_loss_cfg.plddt.n_bins
    )
    pde_logits_unbinned = unbin_logits(
        pde_logits, confidence_loss_cfg.pde.max_value, confidence_loss_cfg.pde.n_bins
    )

    complex_pae = pae_logits_unbinned.mean(dim=(1, 2))
    complex_pde = pde_logits_unbinned.mean(dim=(1, 2))
    # Keep the mask on the same device as the unbinned pLDDT so the product below cannot
    # fail with a device mismatch (no-op when it is already there).
    heavy_atom_mask = is_real_atom[..., :NHEAVY].to(device)
    # The denominator is a single scalar shared by all batch entries, so it does not
    # influence the argmax below.
    complex_plddt = (plddt_logits_unbinned * heavy_atom_mask).sum(
        dim=(1, 2)
    ) / heavy_atom_mask.sum()

    return_dict["pae_idx"] = torch.argmin(complex_pae)
    return_dict["pde_idx"] = torch.argmin(complex_pde)
    return_dict["plddt_idx"] = torch.argmax(complex_plddt)

    chain_to_self_paes = _get_masked_error_per_chain(
        scored_chains, chain_to_self_masks, pae_logits_unbinned
    )
    chain_to_all_paes = _get_masked_error_per_chain(
        scored_chains, chain_to_all_masks, pae_logits_unbinned
    )
    interface_chain_paes = _get_masked_error_per_chain(
        interface_chains, interface_masks, pae_logits_unbinned
    )
    # average over both interfaces
    average_interface_paes = _get_average_error_per_interface(
        interfaces, lig_chains, interface_chain_paes
    )

    return_dict["best_chain_to_all_idx"] = _get_lowest_error_indices(chain_to_all_paes)
    return_dict["best_chain_to_self_idx"] = _get_lowest_error_indices(
        chain_to_self_paes
    )
    return_dict["best_interface_idx"] = _get_lowest_error_indices(
        average_interface_paes
    )
    # for ligands, we don't average the error
    return_dict["best_lig_ipae_idx"] = _get_lowest_error_ligand_indices(
        interface_chain_paes, interfaces, lig_chains
    )
    return return_dict
|
|
|
|
|
|
def annotate_atom_array_b_factor_with_plddt(
    atom_array: AtomArray | AtomArrayStack,
    plddt: torch.Tensor,
    is_real_atom: torch.Tensor,
) -> List[AtomArray]:
    """Store per-atom pLDDT values in the ``b_factor`` annotation of the given structure(s).

    Args:
        atom_array: The AtomArray or AtomArrayStack to annotate
        plddt: The pLDDT tensor of shape (B, I, NHEAVY)
        is_real_atom: A mask indicating which atoms are in the structure; only the first
            NHEAVY entries of its last dimension are used to select values from ``plddt``

    Returns:
        list[AtomArray]: One annotated AtomArray per model. A list is returned (rather than an
            AtomArrayStack) because a stack cannot hold annotation values that differ between
            models; only coordinates may differ.
    """
    # Select the pLDDT of the real atoms -> (B, n_real_atoms)
    real_atom_mask = is_real_atom[..., :NHEAVY]
    per_atom_plddt = plddt[:, real_atom_mask]
    assert per_atom_plddt.shape[1] == atom_array.array_length()

    if isinstance(atom_array, AtomArrayStack):
        # biotite's AtomArrayStack shares annotations across models, so annotate each model separately
        annotated = []
        for model_idx, model in enumerate(atom_array):
            model.set_annotation("b_factor", per_atom_plddt[model_idx].cpu().numpy())
            annotated.append(model)
    else:
        # A single AtomArray must be paired with exactly one batch entry (annotated in place)
        assert per_atom_plddt.shape[0] == 1
        atom_array.set_annotation("b_factor", per_atom_plddt[0].cpu().numpy())
        annotated = [atom_array]

    # Every atom must have received a valid (non-NaN) value
    for model in annotated:
        assert not np.isnan(model.b_factor).any()

    return annotated
|
|
|
|
|
|
def _select_scored_units(
|
|
interfaces_to_score: list[tuple], pn_units_to_score: list[tuple]
|
|
):
|
|
scored_chains = []
|
|
interfaces = []
|
|
interface_chains = []
|
|
for k in interfaces_to_score:
|
|
interfaces.append(f"{k[0]}-{k[1]}")
|
|
interface_chains.append(k[0])
|
|
interface_chains.append(k[1])
|
|
for k in pn_units_to_score:
|
|
scored_chains.append(k[0])
|
|
|
|
return scored_chains, interfaces, interface_chains
|
|
|
|
|
|
def _create_chain_to_all_masks(ch_label, chains_to_score):
|
|
unique_chains = np.unique(ch_label)
|
|
I = len(ch_label)
|
|
chain_to_all_masks = {}
|
|
for chain in unique_chains:
|
|
if chain in chains_to_score:
|
|
indices = torch.from_numpy((ch_label == chain))
|
|
mask = indices.unsqueeze(0) | indices.unsqueeze(1)
|
|
# set the diagonal to false
|
|
mask = mask & ~torch.eye(I, device=mask.device, dtype=torch.bool)
|
|
chain_to_all_masks[chain] = mask
|
|
return chain_to_all_masks
|
|
|
|
|
|
def _create_chain_to_self_masks(ch_label, chains_to_score):
|
|
unique_chains = np.unique(ch_label)
|
|
I = len(ch_label)
|
|
chain_to_self_masks = {}
|
|
for chain in unique_chains:
|
|
if chain in chains_to_score:
|
|
indices = torch.from_numpy((ch_label == chain))
|
|
mask = indices.unsqueeze(0) & indices.unsqueeze(1)
|
|
# set the diagonal to false
|
|
mask = mask & ~torch.eye(I, device=mask.device, dtype=torch.bool)
|
|
chain_to_self_masks[chain] = mask
|
|
return chain_to_self_masks
|
|
|
|
|
|
def _create_interface_masks(ch_label, interfaces, is_ligand):
|
|
interface_masks = {}
|
|
interface_chains = []
|
|
ligand_chains = []
|
|
for interface in interfaces:
|
|
interface_chains.append(interface.split("-")[0])
|
|
interface_chains.append(interface.split("-")[1])
|
|
interface_chains = set(interface_chains)
|
|
for chain in interface_chains:
|
|
chain_indices = torch.from_numpy((ch_label == chain))
|
|
|
|
to_self = chain_indices.unsqueeze(0) & chain_indices.unsqueeze(1)
|
|
to_all = chain_indices.unsqueeze(0) | chain_indices.unsqueeze(1)
|
|
interface_mask = to_all & ~to_self
|
|
interface_masks[chain] = interface_mask
|
|
|
|
if torch.all(is_ligand[chain_indices]):
|
|
ligand_chains.append(chain)
|
|
|
|
return interface_masks, ligand_chains
|
|
|
|
|
|
def _get_masked_error_per_chain(chains, masks, unbinned_logits):
|
|
error = {}
|
|
for chain in chains:
|
|
mask = masks[chain]
|
|
chain_error = compute_mean_over_subsampled_pairs(unbinned_logits, mask)
|
|
error[chain] = chain_error
|
|
|
|
return error
|
|
|
|
|
|
def _get_average_error_per_interface(interfaces, lig_chains, interface_errors):
|
|
average_error = {}
|
|
for interface in interfaces:
|
|
chain_a = interface.split("-")[0]
|
|
chain_b = interface.split("-")[1]
|
|
average_error[interface] = (
|
|
interface_errors[chain_a] + interface_errors[chain_b]
|
|
) / 2
|
|
|
|
return average_error
|
|
|
|
|
|
def _get_lowest_error_indices(errors):
|
|
lowest_error_indices = {}
|
|
for k, v in errors.items():
|
|
lowest_error_indices[k] = torch.argmin(v)
|
|
|
|
return lowest_error_indices
|
|
|
|
|
|
def _get_lowest_error_ligand_indices(errors, interfaces, lig_chains):
|
|
# ligands are a special case in AF3, where they only consider the ligand chain's error and not the average for the interface
|
|
lowest_error_indices = {}
|
|
for interface in interfaces:
|
|
chain_a = interface.split("-")[0]
|
|
chain_b = interface.split("-")[1]
|
|
if chain_a in lig_chains or chain_b in lig_chains:
|
|
if chain_a in lig_chains:
|
|
lig_chain = chain_a
|
|
elif chain_b in lig_chains:
|
|
lig_chain = chain_b
|
|
|
|
lowest_error_indices[interface] = torch.argmin(errors[lig_chain])
|
|
else:
|
|
# assign a random value to avoid key errors downstream; sorting ligand interfaces
|
|
# from other types is handles in analysis
|
|
lowest_error_indices[interface] = 0
|
|
|
|
return lowest_error_indices
|