Clean up structure cleaning - remove unused code paths

PiperOrigin-RevId: 903711715
Change-Id: I953f1b27c6a0b30cc79a6cbd3ee3503d1f24a7ec
This commit is contained in:
Augustin Zidek
2026-04-22 02:32:41 -07:00
committed by Copybara-Service
parent cfeeedd24d
commit e9adcd4dbf
4 changed files with 5 additions and 490 deletions

View File

@@ -103,8 +103,6 @@ class WholePdbPipeline:
date the model coordinates can be used as a fallback.
max_templates: The maximum number of templates to send through the network
set to 0 to switch off templates.
filter_clashes: If true then will remove clashing chains.
filter_crystal_aids: If true ligands in the cryal aid list are removed.
max_paired_sequence_per_species: The maximum number of sequences per
species that will be used for MSA pairing.
drop_ligand_leaving_atoms: Flag for handling leaving atoms for ligands.
@@ -115,8 +113,6 @@ class WholePdbPipeline:
atom_cross_att_keys_subset_size: keys subset size in atom cross attention
flatten_non_standard_residues: Whether to expand non-standard polymer
residues into flat-atom format.
remove_nonsymmetric_bonds: Whether to remove nonsymmetric bonds from
symmetric polymer chains.
deterministic_frames: Whether to use fixed-seed reference positions to
construct deterministic frames.
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
@@ -136,15 +132,12 @@ class WholePdbPipeline:
max_template_date: datetime.date | None = None
ref_max_modified_date: datetime.date | None = None
max_templates: int = 4
filter_clashes: bool = False
filter_crystal_aids: bool = False
max_paired_sequence_per_species: int = 600
drop_ligand_leaving_atoms: bool = True
average_num_atoms_per_token: int = 24
atom_cross_att_queries_subset_size: int = 32
atom_cross_att_keys_subset_size: int = 128
flatten_non_standard_residues: bool = True
remove_nonsymmetric_bonds: bool = False
deterministic_frames: bool = True
conformer_max_iterations: int | None = None
resolve_msa_overlaps: bool = True
@@ -176,13 +169,11 @@ class WholePdbPipeline:
logging.info('Processing %s', logging_name)
# Clean structure.
cleaned_struc, cleaning_metadata = structure_cleaning.clean_structure(
cleaned_struc = structure_cleaning.clean_structure(
struct,
ccd=ccd,
drop_non_standard_atoms=True,
drop_missing_sequence=True,
filter_clashes=self._config.filter_clashes,
filter_crystal_aids=self._config.filter_crystal_aids,
filter_waters=True,
filter_hydrogens=True,
filter_leaving_atoms=self._config.drop_ligand_leaving_atoms,
@@ -190,21 +181,9 @@ class WholePdbPipeline:
covalent_bonds_only=True,
remove_polymer_polymer_bonds=True,
remove_bad_bonds=True,
remove_nonsymmetric_bonds=self._config.remove_nonsymmetric_bonds,
)
num_clashing_chains_removed = cleaning_metadata[
'num_clashing_chains_removed'
]
if num_clashing_chains_removed:
logging.info(
'Removed %d clashing chains from %s',
num_clashing_chains_removed,
logging_name,
)
# No chains after fixes
# No chains after cleaning.
if cleaned_struc.num_chains == 0:
raise MmcifNumChainsError(f'{logging_name}: No chains in structure!')

View File

@@ -10,8 +10,6 @@
"""Prepare PDB structure for training or inference."""
from typing import Any
from absl import logging
from alphafold3 import structure
from alphafold3.constants import chemical_component_sets
@@ -19,8 +17,6 @@ from alphafold3.constants import chemical_components
from alphafold3.constants import mmcif_names
from alphafold3.model.atom_layout import atom_layout
from alphafold3.model.pipeline import inter_chain_bonds
from alphafold3.model.scoring import covalent_bond_cleaning
from alphafold3.structure import sterics
import numpy as np
@@ -70,9 +66,7 @@ def clean_structure(
ccd: chemical_components.Ccd,
*,
drop_missing_sequence: bool,
filter_clashes: bool,
drop_non_standard_atoms: bool,
filter_crystal_aids: bool,
filter_waters: bool,
filter_hydrogens: bool,
filter_leaving_atoms: bool,
@@ -80,17 +74,14 @@ def clean_structure(
covalent_bonds_only: bool,
remove_polymer_polymer_bonds: bool,
remove_bad_bonds: bool,
remove_nonsymmetric_bonds: bool,
) -> tuple[structure.Structure, dict[str, Any]]:
"""Cleans structure.
) -> structure.Structure:
"""Returns a cleaned version of the input structure.
Args:
struc: Structure to clean.
ccd: The chemical components dictionary.
drop_missing_sequence: Whether to drop chains without specified sequences.
filter_clashes: Whether to drop clashing chains.
drop_non_standard_atoms: Whether to drop non CCD standard atoms.
filter_crystal_aids: Whether to drop ligands in the crystal aid set.
filter_waters: Whether to drop water chains.
filter_hydrogens: Whether to drop hyrdogen atoms.
filter_leaving_atoms: Whether to drop leaving atoms based on heuristics.
@@ -99,43 +90,13 @@ def clean_structure(
covalent_bonds_only: Only include covalent bonds.
remove_polymer_polymer_bonds: Remove polymer-polymer bonds.
remove_bad_bonds: Whether to remove badly bonded ligands.
remove_nonsymmetric_bonds: Whether to remove nonsymmetric polymer-ligand
bonds from symmetric polymer chains.
Returns:
Tuple of structure and metadata dict. The metadata dict has
information about what was cleaned from the original.
"""
metadata = {}
# Crop crystallization aids.
if (
filter_crystal_aids
and struc.structure_method in mmcif_names.CRYSTALLIZATION_METHODS
):
struc = struc.filter_out(
res_name=chemical_component_sets.COMMON_CRYSTALLIZATION_AIDS
)
# Drop chains without specified sequences.
if drop_missing_sequence:
chains_with_unk_sequence = struc.find_chains_with_unknown_sequence()
num_with_unk_sequence = len(chains_with_unk_sequence)
if chains_with_unk_sequence:
struc = struc.filter_out(chain_id=chains_with_unk_sequence)
else:
num_with_unk_sequence = 0
metadata['num_with_unk_sequence'] = num_with_unk_sequence
# Remove intersecting chains.
if filter_clashes and struc.num_chains > 1:
clashing_chains = sterics.find_clashing_chains(struc)
if clashing_chains:
struc = struc.filter_out(chain_id=clashing_chains)
else:
clashing_chains = []
metadata['num_clashing_chains_removed'] = len(clashing_chains)
metadata['chains_removed'] = clashing_chains
# Drop non-standard atoms
if drop_non_standard_atoms:
@@ -268,27 +229,7 @@ def clean_structure(
new_bonds = structure.Bonds.make_empty()
struc = struc.copy_and_update(bonds=new_bonds)
if struc.bonds and remove_nonsymmetric_bonds:
# Check for asymmetric polymer-ligand bonds and remove if these exist.
polymer_ligand_bonds = inter_chain_bonds.get_polymer_ligand_bonds(
struc,
only_glycan_ligands=False,
)
if polymer_ligand_bonds:
if covalent_bond_cleaning.has_nonsymmetric_bonds_on_symmetric_polymer_chains(
struc, polymer_ligand_bonds
):
from_atom_idxs, dest_atom_idxs = struc.bonds.get_atom_indices(
struc.atom_key
)
poly_chain_types = list(mmcif_names.POLYMER_CHAIN_TYPES)
is_polymer_bond = np.logical_or(
np.isin(struc.chain_type[from_atom_idxs], poly_chain_types),
np.isin(struc.chain_type[dest_atom_idxs], poly_chain_types),
)
struc = struc.copy_and_update(bonds=struc.bonds[~is_polymer_bond])
return struc, metadata
return struc
def create_empty_output_struc_and_layout(

View File

@@ -1,265 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Some methods to compute metrics for PTMs."""
import collections
from collections.abc import Mapping
import dataclasses
from alphafold3 import structure
from alphafold3.constants import mmcif_names
from alphafold3.model.atom_layout import atom_layout
import numpy as np
@dataclasses.dataclass(frozen=True)
class ResIdMapping:
old_res_ids: np.ndarray
new_res_ids: np.ndarray
def _count_symmetric_chains(struc: structure.Structure) -> Mapping[str, int]:
"""Returns a dict with each chain ID and count."""
chain_res_name_sequence_from_chain_id = struc.chain_res_name_sequence(
include_missing_residues=True, fix_non_standard_polymer_res=False
)
counts_for_chain_res_name_sequence = collections.Counter(
chain_res_name_sequence_from_chain_id.values()
)
chain_symmetric_count = {}
for chain_id, chain_res_name in chain_res_name_sequence_from_chain_id.items():
chain_symmetric_count[chain_id] = counts_for_chain_res_name_sequence[
chain_res_name
]
return chain_symmetric_count
def has_nonsymmetric_bonds_on_symmetric_polymer_chains(
struc: structure.Structure, polymer_ligand_bonds: atom_layout.AtomLayout
) -> bool:
"""Returns true if nonsymmetric bonds found on polymer chains."""
try:
_get_polymer_dim(polymer_ligand_bonds)
except ValueError:
return True
if _has_non_polymer_ligand_ptm_bonds(polymer_ligand_bonds):
return True
if _has_multiple_polymers_bonded_to_one_ligand(polymer_ligand_bonds):
return True
combined_struc, _ = _combine_polymer_ligand_ptm_chains(
struc, polymer_ligand_bonds
)
struc = struc.filter(chain_type=mmcif_names.POLYMER_CHAIN_TYPES)
combined_struc = combined_struc.filter(
chain_type=mmcif_names.POLYMER_CHAIN_TYPES
)
return _count_symmetric_chains(struc) != _count_symmetric_chains(
combined_struc
)
def _has_non_polymer_ligand_ptm_bonds(
polymer_ligand_bonds: atom_layout.AtomLayout,
):
"""Checks if all bonds are between a polymer chain and a ligand chain type."""
for start_chain_type, end_chain_type in polymer_ligand_bonds.chain_type:
if (
start_chain_type in mmcif_names.POLYMER_CHAIN_TYPES
and end_chain_type in mmcif_names.LIGAND_CHAIN_TYPES
):
continue
elif (
start_chain_type in mmcif_names.LIGAND_CHAIN_TYPES
and end_chain_type in mmcif_names.POLYMER_CHAIN_TYPES
):
continue
else:
return True
return False
def _combine_polymer_ligand_ptm_chains(
struc: structure.Structure,
polymer_ligand_bonds: atom_layout.AtomLayout,
) -> tuple[structure.Structure, dict[tuple[str, str], ResIdMapping]]:
"""Combines the ptm polymer-ligand chains together.
This will prevent them from being permuted away from each other when chains
are matched to the ground truth. This function also returns the res_id mapping
from the separate ligand res_ids to their res_ids in the combined
polymer-ligand chain; this information is needed to later separate the
combined polymer-ligand chain.
Args:
struc: Structure to be modified.
polymer_ligand_bonds: AtomLayout with polymer-ligand bond info.
Returns:
A tuple of a Structure with each ptm polymer-ligand chain relabelled as one
chain and a dict from bond chain pair to the res_id mapping.
"""
if not _has_only_single_bond_from_each_chain(polymer_ligand_bonds):
if _has_multiple_ligands_bonded_to_one_polymer(polymer_ligand_bonds):
# For structures where a polymer chain is connected to multiple ligands,
# we need to sort the multiple bonds from the same chain by res_id to
# ensure that the combined polymer-ligand chain will always be the same
# when you have repeated symmetric polymer-ligand chains.
polymer_ligand_bonds = (
_sort_polymer_ligand_bonds_by_polymer_chain_and_res_id(
polymer_ligand_bonds
)
)
else:
raise ValueError(
'Code cannot handle multiple bonds from one chain unless'
' its several ligands bonded to a polymer.'
)
res_id_mappings_for_bond_chain_pair = dict()
for (start_chain_id, end_chain_id), (start_chain_type, end_chain_type) in zip(
polymer_ligand_bonds.chain_id, polymer_ligand_bonds.chain_type
):
poly_info, ligand_info = _get_polymer_and_ligand_chain_ids_and_types(
start_chain_id, end_chain_id, start_chain_type, end_chain_type
)
polymer_chain_id, polymer_chain_type = poly_info
ligand_chain_id, _ = ligand_info
# Join the ligand chain to the polymer chain.
ligand_res_ids = struc.filter(chain_id=ligand_chain_id).res_id
new_res_ids = ligand_res_ids + len(struc.all_residues[polymer_chain_id])
res_id_mappings_for_bond_chain_pair[(polymer_chain_id, ligand_chain_id)] = (
ResIdMapping(old_res_ids=ligand_res_ids, new_res_ids=new_res_ids)
)
chain_groups = []
chain_group_ids = []
chain_group_types = []
for chain_id, chain_type in zip(
struc.chains_table.id, struc.chains_table.type
):
if chain_id == ligand_chain_id:
continue
elif chain_id == polymer_chain_id:
chain_groups.append([polymer_chain_id, ligand_chain_id])
chain_group_ids.append(polymer_chain_id)
chain_group_types.append(polymer_chain_type)
else:
chain_groups.append([chain_id])
chain_group_ids.append(chain_id)
chain_group_types.append(chain_type)
struc = struc.merge_chains(
chain_groups=chain_groups,
chain_group_ids=chain_group_ids,
chain_group_types=chain_group_types,
)
return struc, res_id_mappings_for_bond_chain_pair
def _has_only_single_bond_from_each_chain(
polymer_ligand_bonds: atom_layout.AtomLayout,
) -> bool:
"""Checks that there is at most one bond from each chain."""
chain_ids = []
for chains in polymer_ligand_bonds.chain_id:
chain_ids.extend(chains)
if len(chain_ids) != len(set(chain_ids)):
return False
return True
def _get_polymer_and_ligand_chain_ids_and_types(
start_chain_id: str,
end_chain_id: str,
start_chain_type: str,
end_chain_type: str,
) -> tuple[tuple[str, str], tuple[str, str]]:
"""Finds polymer and ligand chain ids from chain types."""
if (
start_chain_type in mmcif_names.POLYMER_CHAIN_TYPES
and end_chain_type in mmcif_names.LIGAND_CHAIN_TYPES
):
return (start_chain_id, start_chain_type), (end_chain_id, end_chain_type)
elif (
start_chain_type in mmcif_names.LIGAND_CHAIN_TYPES
and end_chain_type in mmcif_names.POLYMER_CHAIN_TYPES
):
return (end_chain_id, end_chain_type), (start_chain_id, start_chain_type)
else:
raise ValueError(
'This code only handles PTM-bonds from polymer chain to ligands.'
)
def _get_polymer_dim(polymer_ligand_bonds: atom_layout.AtomLayout) -> int:
"""Gets polymer dimension from the polymer-ligand bond layout."""
start_chain_types = []
end_chain_types = []
for start_chain_type, end_chain_type in polymer_ligand_bonds.chain_type:
start_chain_types.append(start_chain_type)
end_chain_types.append(end_chain_type)
if set(start_chain_types).issubset(
set(mmcif_names.POLYMER_CHAIN_TYPES)
) and set(end_chain_types).issubset(set(mmcif_names.LIGAND_CHAIN_TYPES)):
return 0
elif set(start_chain_types).issubset(mmcif_names.LIGAND_CHAIN_TYPES) and set(
end_chain_types
).issubset(set(mmcif_names.POLYMER_CHAIN_TYPES)):
return 1
else:
raise ValueError(
'Polymer and ligand dimensions are not consistent within the structure.'
)
def _has_multiple_ligands_bonded_to_one_polymer(polymer_ligand_bonds):
"""Checks if there are multiple ligands bonded to one polymer."""
polymer_dim = _get_polymer_dim(polymer_ligand_bonds)
polymer_chain_ids = [
chains[polymer_dim] for chains in polymer_ligand_bonds.chain_id
]
if len(polymer_chain_ids) != len(set(polymer_chain_ids)):
return True
return False
def _has_multiple_polymers_bonded_to_one_ligand(polymer_ligand_bonds):
"""Checks if there are multiple polymer chains bonded to one ligand."""
polymer_dim = _get_polymer_dim(polymer_ligand_bonds)
ligand_dim = 1 - polymer_dim
ligand_chain_ids = [
chains[ligand_dim] for chains in polymer_ligand_bonds.chain_id
]
if len(ligand_chain_ids) != len(set(ligand_chain_ids)):
return True
return False
def _sort_polymer_ligand_bonds_by_polymer_chain_and_res_id(
polymer_ligand_bonds,
):
"""Sorts bonds by res_id (for when a polymer chain has multiple bonded ligands)."""
polymer_dim = _get_polymer_dim(polymer_ligand_bonds)
polymer_chain_ids = [
chains[polymer_dim] for chains in polymer_ligand_bonds.chain_id
]
polymer_res_ids = [res[polymer_dim] for res in polymer_ligand_bonds.res_id]
polymer_chain_and_res_id = zip(polymer_chain_ids, polymer_res_ids)
sorted_indices = [
idx
for idx, _ in sorted(
enumerate(polymer_chain_and_res_id), key=lambda x: x[1]
)
]
return polymer_ligand_bonds[sorted_indices]

View File

@@ -1,140 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Functions relating to spatial locations of atoms within a structure."""
from collections.abc import Collection, Sequence
from alphafold3 import structure
from alphafold3.structure import mmcif
import numpy as np
import scipy
def _make_atom_has_clash_mask(
kd_query_result: np.ndarray,
struc: structure.Structure,
ignore_chains: Collection[str],
) -> np.ndarray:
"""Returns a boolean NumPy array representing whether each atom has a clash.
Args:
kd_query_result: NumPy array containing N-atoms arrays, each array
containing indices to atoms that clash with the N'th atom.
struc: Structure over which clashes were detected.
ignore_chains: Collection of chains that should not be considered clashing.
A boolean NumPy array of length N atoms.
"""
atom_is_clashing = np.zeros((struc.num_atoms,), dtype=bool)
for atom_index, clashes in enumerate(kd_query_result):
chain_i = struc.chain_id[atom_index]
if chain_i in ignore_chains:
continue
islig_i = struc.is_ligand_mask[atom_index]
for clashing_atom_index in clashes:
chain_c = struc.chain_id[clashing_atom_index]
if chain_c in ignore_chains:
continue
islig_c = struc.is_ligand_mask[clashing_atom_index]
if (
clashing_atom_index == atom_index
or chain_i == chain_c
or islig_i != islig_c
):
# Ignore clashes within chain or between ligand and polymer.
continue
atom_is_clashing[atom_index] = True
return atom_is_clashing
def find_clashing_chains(
struc: structure.Structure,
clash_thresh_angstrom: float = 1.7,
clash_thresh_fraction: float = 0.3,
) -> Sequence[str]:
"""Finds chains that clash with others.
Clashes are defined by polymer backbone atoms and all ligand atoms.
Ligand-polymer clashes are not dropped.
Will not find clashes if all coordinates are 0. Coordinates are all 0s if
the structure is generated from sequences only, as done for inference in
dendro for example.
Args:
struc: The structure defining the chains and atom positions.
clash_thresh_angstrom: Below this distance, atoms are considered clashing.
clash_thresh_fraction: Chains with more than this fraction of their atoms
considered clashing will be dropped. This value should be in the range (0,
1].
Returns:
A sequence of chain ids for chains that clash.
Raises:
ValueError: If `clash_thresh_fraction` is not in range (0,1].
"""
if not 0 < clash_thresh_fraction <= 1:
raise ValueError('clash_thresh_fraction must be in range (0,1]')
struc_backbone = struc.filter_polymers_to_single_atom_per_res()
if struc_backbone.num_chains == 0:
return []
# If the coordinates are all 0, do not search for clashes.
if not np.any(struc_backbone.coords):
return []
coord_kdtree = scipy.spatial.cKDTree(struc_backbone.coords)
# For each atom coordinate, find all atoms within the clash thresh radius.
clashing_per_atom = coord_kdtree.query_ball_point(
struc_backbone.coords, r=clash_thresh_angstrom
)
chain_ids = struc_backbone.chains
if struc_backbone.atom_occupancy is not None:
chain_occupancy = np.array([
np.mean(struc_backbone.atom_occupancy[start:end])
for start, end in struc_backbone.iter_chain_ranges()
])
else:
chain_occupancy = None
# Remove chains until no more significant clashing.
chains_to_remove = set()
for _ in range(len(chain_ids)):
# Calculate maximally clashing.
atom_has_clash = _make_atom_has_clash_mask(
clashing_per_atom, struc_backbone, chains_to_remove
)
clashes_per_chain = np.array([
atom_has_clash[start:end].mean()
for start, end in struc_backbone.iter_chain_ranges()
])
max_clash = np.max(clashes_per_chain)
if max_clash <= clash_thresh_fraction:
# None of the remaining chains exceed the clash fraction threshold, so
# we can exit.
break
# Greedily remove worst with the lowest occupancy.
most_clashes = np.nonzero(clashes_per_chain == max_clash)[0]
if chain_occupancy is not None:
occupancy_clashing = chain_occupancy[most_clashes]
last_lowest_occupancy = (
len(occupancy_clashing) - np.argmin(occupancy_clashing[::-1]) - 1
)
worst_and_last = most_clashes[last_lowest_occupancy]
else:
worst_and_last = most_clashes[-1]
chains_to_remove.add(chain_ids[worst_and_last])
return sorted(chains_to_remove, key=mmcif.str_id_to_int_id)