import itertools from typing import List import einops import numpy as np import pandas as pd import torch import tree from beartype.typing import Any from biotite.structure import AtomArray, AtomArrayStack from omegaconf import DictConfig from modelhub.chemical import NHEAVY from modelhub.metrics.metric_utils import ( compute_mean_over_subsampled_pairs, compute_min_over_subsampled_pairs, create_chainwise_masks_1d, create_chainwise_masks_2d, create_interface_masks_2d, spread_batch_into_dictionary, unbin_logits, ) def get_mean_atomwise_plddt( plddt_logits: torch.Tensor, is_real_atom: torch.Tensor, max_value: float, ) -> torch.Tensor: """Aggregate plddts. Args: plddt_logits: Tensor of shape [B, n_token, max_atoms_in_a_token * n_bin] with logits is_real_atom: Boolean mask of shape [B, n_token, max_atoms_in_a_token] indicating which atoms are real (i.e., not padding) max_value: Maximum value for pLDDT (assigned to the last bin) Returns: plddt: Tensor of shape [B,] with the mean atom-wise pLDDT for each batch """ assert ( plddt_logits.ndim == 3 ), "plddt_logits must be a 3D tensor (B, n_token, max_atoms_in_a_token * n_bins)" # TODO: Replace with the last dimension of is_real_atom; right now that number is too large (36) because it includes hydrogens max_atoms_in_a_token = NHEAVY # Since the pLDDT logits have the last dimension (max_atoms_in_a_token * n_bins), we can calculate n_bins directly assert ( plddt_logits.shape[-1] % max_atoms_in_a_token == 0 ), "The last dimension of plddt_logits must be divisible by max_atoms_in_a_token!" n_bins = plddt_logits.shape[-1] // max_atoms_in_a_token # ... reshape to match what unbin_logits expects reshaped_plddt_logits = einops.rearrange( plddt_logits, "... n_token (max_atoms_in_a_token n_bins) -> ... n_bins n_token max_atoms_in_a_token", max_atoms_in_a_token=max_atoms_in_a_token, n_bins=n_bins, ).float() # [..., n_token, n_bins * max_atoms_in_a_token] -> [ ..., n_bins, n_token, max_atoms_in_a_token] plddt = unbin_logits( reshaped_plddt_logits, max_value, n_bins, ) is_real_atom = is_real_atom.to(device=plddt.device) # ... create mask indicating which atoms are "real" (i.e., not padding) and calculate the mean mask = is_real_atom[:, :max_atoms_in_a_token].unsqueeze(0) atomwise_plddt_mean = (plddt * mask).sum(dim=(1, 2)) / mask.sum(dim=(1, 2)) return atomwise_plddt_mean def compile_af3_confidence_outputs( plddt_logits: torch.Tensor, pae_logits: torch.Tensor, pde_logits: torch.Tensor, chain_iid_token_lvl: torch.Tensor, is_real_atom: torch.Tensor, example_id: str, confidence_loss_cfg: DictConfig | dict, ) -> dict[str, Any]: # TODO: Refactor to accept an AtomArray # TODO: Taking the confidence_loss_cfg does not align with functional programming best-practices; we should instead take the max_value and n_bins as arguments """Given the confidence logits, computes the confidence metrics for the model's predictions. Returns: dict[str, Any]: A dictionary containing the following: - confidence_df: A DataFrame containing the aggregate confidence metrics at the chain- and interface-level - plddt: The pLDDT logits - pae: The pAE logits - pde: The pDE logits """ # Reorder the input tensors to be in (B, n_bins, ...) format for unbinning plddt = unbin_logits( plddt_logits.reshape( -1, plddt_logits.shape[1], NHEAVY, confidence_loss_cfg.plddt.n_bins, ) .permute(0, 3, 1, 2) .float(), confidence_loss_cfg.plddt.max_value, confidence_loss_cfg.plddt.n_bins, ) # Unbin the pae and pde logits pae = unbin_logits( pae_logits.permute(0, 3, 1, 2).float(), confidence_loss_cfg.pae.max_value, confidence_loss_cfg.pae.n_bins, ) pde = unbin_logits( pde_logits.permute(0, 3, 1, 2).float(), confidence_loss_cfg.pde.max_value, confidence_loss_cfg.pde.n_bins, ) # Calculate interface metrics interface_masks = create_interface_masks_2d(chain_iid_token_lvl, device=pae.device) pae_interface = { k: spread_batch_into_dictionary(compute_mean_over_subsampled_pairs(pae, v)) for k, v in interface_masks.items() } pde_interface = { k: spread_batch_into_dictionary(compute_mean_over_subsampled_pairs(pde, v)) for k, v in interface_masks.items() } pae_interface_min = { k: spread_batch_into_dictionary(compute_min_over_subsampled_pairs(pae, v)) for k, v in interface_masks.items() } pde_interface_min = { k: spread_batch_into_dictionary(compute_min_over_subsampled_pairs(pde, v)) for k, v in interface_masks.items() } # Calculate chainwise metrics chain_masks_2d = create_chainwise_masks_2d(chain_iid_token_lvl, device=pae.device) pae_chainwise = { k: spread_batch_into_dictionary(compute_mean_over_subsampled_pairs(pae, v)) for k, v in chain_masks_2d.items() } pde_chainwise = { k: spread_batch_into_dictionary(compute_mean_over_subsampled_pairs(pde, v)) for k, v in chain_masks_2d.items() } chain_masks_1d = create_chainwise_masks_1d( chain_iid_token_lvl, device=is_real_atom.device ) plddt_chainwise = { k: spread_batch_into_dictionary( compute_mean_over_subsampled_pairs( plddt, is_real_atom[..., :NHEAVY] * v[:, None] ) ) for k, v in chain_masks_1d.items() } # Aggregate confidence data confidence_data = { "example_id": example_id, "mean_plddt": spread_batch_into_dictionary( compute_mean_over_subsampled_pairs(plddt, is_real_atom[..., :NHEAVY]) ), "mean_pae": spread_batch_into_dictionary(pae.mean(dim=(-1, -2))), "mean_pde": spread_batch_into_dictionary(pde.mean(dim=(-1, -2))), "chain_wise_mean_plddt": plddt_chainwise, "chain_wise_mean_pae": pae_chainwise, "chain_wise_mean_pde": pde_chainwise, "interface_wise_mean_pae": pae_interface, "interface_wise_mean_pde": pde_interface, "interface_wise_min_pae": pae_interface_min, "interface_wise_min_pde": pde_interface_min, } # Generate DataFrame rows num_batches = plddt.shape[0] chains = np.unique(chain_iid_token_lvl) chain_pairs = list(itertools.combinations(chains, 2)) # For every batch, chain, and interface (chain pair), generate a dataframe row chain_rows = [ { "example_id": example_id, "chain_chainwise": chain, "chainwise_plddt": confidence_data["chain_wise_mean_plddt"][chain][ batch_idx ], "chainwise_pde": confidence_data["chain_wise_mean_pde"][chain][batch_idx], "chainwise_pae": confidence_data["chain_wise_mean_pae"][chain][batch_idx], "overall_plddt": confidence_data["mean_plddt"][batch_idx], "overall_pde": confidence_data["mean_pde"][batch_idx], "overall_pae": confidence_data["mean_pae"][batch_idx], "batch_idx": batch_idx, } for batch_idx in range(num_batches) for chain in chains ] interface_rows = [ { "example_id": example_id, "chain_i_interface": chain_i, "chain_j_interface": chain_j, "pae_interface": confidence_data["interface_wise_mean_pae"][ (chain_i, chain_j) ][batch_idx], "pde_interface": confidence_data["interface_wise_mean_pde"][ (chain_i, chain_j) ][batch_idx], "min_pae_interface": confidence_data["interface_wise_min_pae"][ (chain_i, chain_j) ][batch_idx], "min_pde_interface": confidence_data["interface_wise_min_pde"][ (chain_i, chain_j) ][batch_idx], "overall_plddt": confidence_data["mean_plddt"][batch_idx], "overall_pde": confidence_data["mean_pde"][batch_idx], "overall_pae": confidence_data["mean_pae"][batch_idx], "batch_idx": batch_idx, } for batch_idx in range(num_batches) for (chain_i, chain_j) in chain_pairs ] return { "confidence_df": pd.DataFrame(itertools.chain([*chain_rows, *interface_rows])), "plddt": plddt, "pae": pae, "pde": pde, } def compute_batch_indices_with_lowest_predicted_error( plddt: torch.Tensor, is_real_atom: torch.Tensor, pae: torch.Tensor, confidence_loss_cfg: dict | DictConfig, chain_iid_token_lvl: torch.Tensor, is_ligand: torch.Tensor, interfaces_to_score: list[tuple], pn_units_to_score: list[tuple], ) -> dict[str, Any]: """Given the confidence logits, computes the index within the diffusion batch of the best predicted structure. Metrics include pAE, pLDDT, and pDE, among others. Returns: dict[str, Any]: A dictionary containing the following keys: - pae_idx: The index within the diffusion batch of the structure with the best overall pAE (Predicted Aligned Error) - pde_idx: The index within the diffusion batch of the structure with the best overall pDE (Predicted Distance Error) - plddt_idx: The index within the diffusion batch of the structure with the best overall pLDDT (Predicted Local Distance Difference Test) - best_chain_to_all_idx: The index within the diffusion batch of the structure with the best pAE subsampled over any pair (i,j) where i == chain or j == chain - best_chain_to_self_idx: The index within the diffusion batch of the structure with the best pAE subsampled over any pair (i,j) where i == chain and j == chain - best_interface_idx: For each interface between two scored PN Units, the index within the diffusion batch of the structure with the best mean pAE for all (i,j) where i == interface_chain or j == interface_chain and i != j - best_lig_ipae_idx: The index within the diffusion batch for the best pAE subsambled over any pair (i,j) where i == chain or j == chain and i != j and i or j is a ligand """ # TODO: Have this function take an `AtomArray` as input so we quickly build masks with much less code # TODO: Explore how we can write this function more concisely return_dict = {} # AF3's ranking metrics work like this, but using ptm instead of ipae: scored_chains, interfaces, interface_chains = _select_scored_units( interfaces_to_score, pn_units_to_score ) chain_to_all_masks = _create_chain_to_all_masks(chain_iid_token_lvl, scored_chains) chain_to_self_masks = _create_chain_to_self_masks( chain_iid_token_lvl, scored_chains ) interface_masks, lig_chains = _create_interface_masks( chain_iid_token_lvl, interfaces, is_ligand ) # map everything to gpu gpu = plddt.device chain_to_all_masks = tree.map_structure( lambda x: x.to(gpu) if hasattr(x, "cpu") else x, chain_to_all_masks ) chain_to_self_masks = tree.map_structure( lambda x: x.to(gpu) if hasattr(x, "cpu") else x, chain_to_self_masks ) interface_masks = tree.map_structure( lambda x: x.to(gpu) if hasattr(x, "cpu") else x, interface_masks ) # Reshape logits to B, K, L, NHEAVY plddt = ( plddt.reshape( -1, plddt.shape[1], NHEAVY, confidence_loss_cfg.plddt.n_bins, ) .permute(0, 3, 1, 2) .float() ) # Reshape the pae and pde logits to B, K, L, L pae_logits = pae.permute(0, 3, 1, 2).float() pde_logits = pae.permute(0, 3, 1, 2).float() pae_logits_unbinned = unbin_logits( pae_logits, confidence_loss_cfg.pae.max_value, confidence_loss_cfg.pae.n_bins ) plddt_logits_unbinned = unbin_logits( plddt, confidence_loss_cfg.plddt.max_value, confidence_loss_cfg.plddt.n_bins ) pde_logits_unbinned = unbin_logits( pde_logits, confidence_loss_cfg.pde.max_value, confidence_loss_cfg.pde.n_bins ) complex_pae = pae_logits_unbinned.mean(dim=(1, 2)) complex_pde = pde_logits_unbinned.mean(dim=(1, 2)) complex_plddt = (plddt_logits_unbinned * is_real_atom[..., :NHEAVY]).sum( dim=(1, 2) ) / is_real_atom[..., :NHEAVY].sum() return_dict["pae_idx"] = torch.argmin(complex_pae) return_dict["pde_idx"] = torch.argmin(complex_pde) return_dict["plddt_idx"] = torch.argmax(complex_plddt) chain_to_self_paes = _get_masked_error_per_chain( scored_chains, chain_to_self_masks, pae_logits_unbinned ) chain_to_all_paes = _get_masked_error_per_chain( scored_chains, chain_to_all_masks, pae_logits_unbinned ) interface_chain_paes = _get_masked_error_per_chain( interface_chains, interface_masks, pae_logits_unbinned ) # average over both interfaces average_interface_paes = _get_average_error_per_interface( interfaces, lig_chains, interface_chain_paes ) return_dict["best_chain_to_all_idx"] = _get_lowest_error_indices(chain_to_all_paes) return_dict["best_chain_to_self_idx"] = _get_lowest_error_indices( chain_to_self_paes ) return_dict["best_interface_idx"] = _get_lowest_error_indices( average_interface_paes ) # for ligands, we don't average the error return_dict["best_lig_ipae_idx"] = _get_lowest_error_ligand_indices( interface_chain_paes, interfaces, lig_chains ) return return_dict def annotate_atom_array_b_factor_with_plddt( atom_array: AtomArray | AtomArrayStack, plddt: torch.Tensor, is_real_atom: torch.Tensor, ) -> List[AtomArray]: """Annotates the b_factor of an AtomArray with the pLDDT values in the occupancy field. Args: atom_array: The AtomArray or AtomArrayStack to annotate plddt: The pLDDT tensor of shape (B, I, NHEAVY) is_real_atom: A mask indicating which atoms are in the structure of shape (I, NHEAVY) Returns: list[AtomArray]: The annotated list of AtomArrays. We must return a list of AtomArrays because the AtomArray class does not support setting different values as annotations other than the coordinate feature. """ atom_wise_plddt = plddt[:, is_real_atom[..., :NHEAVY]] assert atom_wise_plddt.shape[1] == atom_array.array_length() atom_array_list = [] # bitotite's AtomArray does not support setting different values as annotations other than # the coordinate feature, so we convert atom_array to a list of AtomArrays if isinstance(atom_array, AtomArrayStack): for i, aa in enumerate(atom_array): aa.set_annotation("b_factor", atom_wise_plddt[i].cpu().numpy()) atom_array_list.append(aa) else: assert atom_wise_plddt.shape[0] == 1 atom_array.set_annotation("b_factor", atom_wise_plddt[0].cpu().numpy()) atom_array_list.append(atom_array) for aa in atom_array_list: assert np.isnan(aa.b_factor).sum() == 0 return atom_array_list def _select_scored_units( interfaces_to_score: list[tuple], pn_units_to_score: list[tuple] ): scored_chains = [] interfaces = [] interface_chains = [] for k in interfaces_to_score: interfaces.append(f"{k[0]}-{k[1]}") interface_chains.append(k[0]) interface_chains.append(k[1]) for k in pn_units_to_score: scored_chains.append(k[0]) return scored_chains, interfaces, interface_chains def _create_chain_to_all_masks(ch_label, chains_to_score): unique_chains = np.unique(ch_label) I = len(ch_label) chain_to_all_masks = {} for chain in unique_chains: if chain in chains_to_score: indices = torch.from_numpy((ch_label == chain)) mask = indices.unsqueeze(0) | indices.unsqueeze(1) # set the diagonal to false mask = mask & ~torch.eye(I, device=mask.device, dtype=torch.bool) chain_to_all_masks[chain] = mask return chain_to_all_masks def _create_chain_to_self_masks(ch_label, chains_to_score): unique_chains = np.unique(ch_label) I = len(ch_label) chain_to_self_masks = {} for chain in unique_chains: if chain in chains_to_score: indices = torch.from_numpy((ch_label == chain)) mask = indices.unsqueeze(0) & indices.unsqueeze(1) # set the diagonal to false mask = mask & ~torch.eye(I, device=mask.device, dtype=torch.bool) chain_to_self_masks[chain] = mask return chain_to_self_masks def _create_interface_masks(ch_label, interfaces, is_ligand): interface_masks = {} interface_chains = [] ligand_chains = [] for interface in interfaces: interface_chains.append(interface.split("-")[0]) interface_chains.append(interface.split("-")[1]) interface_chains = set(interface_chains) for chain in interface_chains: chain_indices = torch.from_numpy((ch_label == chain)) to_self = chain_indices.unsqueeze(0) & chain_indices.unsqueeze(1) to_all = chain_indices.unsqueeze(0) | chain_indices.unsqueeze(1) interface_mask = to_all & ~to_self interface_masks[chain] = interface_mask if torch.all(is_ligand[chain_indices]): ligand_chains.append(chain) return interface_masks, ligand_chains def _get_masked_error_per_chain(chains, masks, unbinned_logits): error = {} for chain in chains: mask = masks[chain] chain_error = compute_mean_over_subsampled_pairs(unbinned_logits, mask) error[chain] = chain_error return error def _get_average_error_per_interface(interfaces, lig_chains, interface_errors): average_error = {} for interface in interfaces: chain_a = interface.split("-")[0] chain_b = interface.split("-")[1] average_error[interface] = ( interface_errors[chain_a] + interface_errors[chain_b] ) / 2 return average_error def _get_lowest_error_indices(errors): lowest_error_indices = {} for k, v in errors.items(): lowest_error_indices[k] = torch.argmin(v) return lowest_error_indices def _get_lowest_error_ligand_indices(errors, interfaces, lig_chains): # ligands are a special case in AF3, where they only consider the ligand chain's error and not the average for the interface lowest_error_indices = {} for interface in interfaces: chain_a = interface.split("-")[0] chain_b = interface.split("-")[1] if chain_a in lig_chains or chain_b in lig_chains: if chain_a in lig_chains: lig_chain = chain_a elif chain_b in lig_chains: lig_chain = chain_b lowest_error_indices[interface] = torch.argmin(errors[lig_chain]) else: # assign a random value to avoid key errors downstream; sorting ligand interfaces # from other types is handles in analysis lowest_error_indices[interface] = 0 return lowest_error_indices