diff --git a/src/alphafold3/model/pipeline/pipeline.py b/src/alphafold3/model/pipeline/pipeline.py index f8c4367..39e5beb 100644 --- a/src/alphafold3/model/pipeline/pipeline.py +++ b/src/alphafold3/model/pipeline/pipeline.py @@ -103,8 +103,6 @@ class WholePdbPipeline: date the model coordinates can be used as a fallback. max_templates: The maximum number of templates to send through the network set to 0 to switch off templates. - filter_clashes: If true then will remove clashing chains. - filter_crystal_aids: If true ligands in the cryal aid list are removed. max_paired_sequence_per_species: The maximum number of sequences per species that will be used for MSA pairing. drop_ligand_leaving_atoms: Flag for handling leaving atoms for ligands. @@ -115,8 +113,6 @@ class WholePdbPipeline: atom_cross_att_keys_subset_size: keys subset size in atom cross attention flatten_non_standard_residues: Whether to expand non-standard polymer residues into flat-atom format. - remove_nonsymmetric_bonds: Whether to remove nonsymmetric bonds from - symmetric polymer chains. deterministic_frames: Whether to use fixed-seed reference positions to construct deterministic frames. resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired @@ -136,15 +132,12 @@ class WholePdbPipeline: max_template_date: datetime.date | None = None ref_max_modified_date: datetime.date | None = None max_templates: int = 4 - filter_clashes: bool = False - filter_crystal_aids: bool = False max_paired_sequence_per_species: int = 600 drop_ligand_leaving_atoms: bool = True average_num_atoms_per_token: int = 24 atom_cross_att_queries_subset_size: int = 32 atom_cross_att_keys_subset_size: int = 128 flatten_non_standard_residues: bool = True - remove_nonsymmetric_bonds: bool = False deterministic_frames: bool = True conformer_max_iterations: int | None = None resolve_msa_overlaps: bool = True @@ -176,13 +169,11 @@ class WholePdbPipeline: logging.info('Processing %s', logging_name) # Clean structure. - cleaned_struc, cleaning_metadata = structure_cleaning.clean_structure( + cleaned_struc = structure_cleaning.clean_structure( struct, ccd=ccd, drop_non_standard_atoms=True, drop_missing_sequence=True, - filter_clashes=self._config.filter_clashes, - filter_crystal_aids=self._config.filter_crystal_aids, filter_waters=True, filter_hydrogens=True, filter_leaving_atoms=self._config.drop_ligand_leaving_atoms, @@ -190,21 +181,9 @@ class WholePdbPipeline: covalent_bonds_only=True, remove_polymer_polymer_bonds=True, remove_bad_bonds=True, - remove_nonsymmetric_bonds=self._config.remove_nonsymmetric_bonds, ) - num_clashing_chains_removed = cleaning_metadata[ - 'num_clashing_chains_removed' - ] - - if num_clashing_chains_removed: - logging.info( - 'Removed %d clashing chains from %s', - num_clashing_chains_removed, - logging_name, - ) - - # No chains after fixes + # No chains after cleaning. if cleaned_struc.num_chains == 0: raise MmcifNumChainsError(f'{logging_name}: No chains in structure!') diff --git a/src/alphafold3/model/pipeline/structure_cleaning.py b/src/alphafold3/model/pipeline/structure_cleaning.py index 4043a3f..531f635 100644 --- a/src/alphafold3/model/pipeline/structure_cleaning.py +++ b/src/alphafold3/model/pipeline/structure_cleaning.py @@ -10,8 +10,6 @@ """Prepare PDB structure for training or inference.""" -from typing import Any - from absl import logging from alphafold3 import structure from alphafold3.constants import chemical_component_sets @@ -19,8 +17,6 @@ from alphafold3.constants import chemical_components from alphafold3.constants import mmcif_names from alphafold3.model.atom_layout import atom_layout from alphafold3.model.pipeline import inter_chain_bonds -from alphafold3.model.scoring import covalent_bond_cleaning -from alphafold3.structure import sterics import numpy as np @@ -70,9 +66,7 @@ def clean_structure( ccd: chemical_components.Ccd, *, drop_missing_sequence: bool, - filter_clashes: bool, drop_non_standard_atoms: bool, - filter_crystal_aids: bool, filter_waters: bool, filter_hydrogens: bool, filter_leaving_atoms: bool, @@ -80,17 +74,14 @@ def clean_structure( covalent_bonds_only: bool, remove_polymer_polymer_bonds: bool, remove_bad_bonds: bool, - remove_nonsymmetric_bonds: bool, -) -> tuple[structure.Structure, dict[str, Any]]: - """Cleans structure. +) -> structure.Structure: + """Returns a cleaned version of the input structure. Args: struc: Structure to clean. ccd: The chemical components dictionary. drop_missing_sequence: Whether to drop chains without specified sequences. - filter_clashes: Whether to drop clashing chains. drop_non_standard_atoms: Whether to drop non CCD standard atoms. - filter_crystal_aids: Whether to drop ligands in the crystal aid set. filter_waters: Whether to drop water chains. filter_hydrogens: Whether to drop hyrdogen atoms. filter_leaving_atoms: Whether to drop leaving atoms based on heuristics. @@ -99,43 +90,13 @@ def clean_structure( covalent_bonds_only: Only include covalent bonds. remove_polymer_polymer_bonds: Remove polymer-polymer bonds. remove_bad_bonds: Whether to remove badly bonded ligands. - remove_nonsymmetric_bonds: Whether to remove nonsymmetric polymer-ligand - bonds from symmetric polymer chains. - - Returns: - Tuple of structure and metadata dict. The metadata dict has - information about what was cleaned from the original. """ - metadata = {} - # Crop crystallization aids. - if ( - filter_crystal_aids - and struc.structure_method in mmcif_names.CRYSTALLIZATION_METHODS - ): - struc = struc.filter_out( - res_name=chemical_component_sets.COMMON_CRYSTALLIZATION_AIDS - ) - # Drop chains without specified sequences. if drop_missing_sequence: chains_with_unk_sequence = struc.find_chains_with_unknown_sequence() - num_with_unk_sequence = len(chains_with_unk_sequence) if chains_with_unk_sequence: struc = struc.filter_out(chain_id=chains_with_unk_sequence) - else: - num_with_unk_sequence = 0 - metadata['num_with_unk_sequence'] = num_with_unk_sequence - - # Remove intersecting chains. - if filter_clashes and struc.num_chains > 1: - clashing_chains = sterics.find_clashing_chains(struc) - if clashing_chains: - struc = struc.filter_out(chain_id=clashing_chains) - else: - clashing_chains = [] - metadata['num_clashing_chains_removed'] = len(clashing_chains) - metadata['chains_removed'] = clashing_chains # Drop non-standard atoms if drop_non_standard_atoms: @@ -268,27 +229,7 @@ def clean_structure( new_bonds = structure.Bonds.make_empty() struc = struc.copy_and_update(bonds=new_bonds) - if struc.bonds and remove_nonsymmetric_bonds: - # Check for asymmetric polymer-ligand bonds and remove if these exist. - polymer_ligand_bonds = inter_chain_bonds.get_polymer_ligand_bonds( - struc, - only_glycan_ligands=False, - ) - if polymer_ligand_bonds: - if covalent_bond_cleaning.has_nonsymmetric_bonds_on_symmetric_polymer_chains( - struc, polymer_ligand_bonds - ): - from_atom_idxs, dest_atom_idxs = struc.bonds.get_atom_indices( - struc.atom_key - ) - poly_chain_types = list(mmcif_names.POLYMER_CHAIN_TYPES) - is_polymer_bond = np.logical_or( - np.isin(struc.chain_type[from_atom_idxs], poly_chain_types), - np.isin(struc.chain_type[dest_atom_idxs], poly_chain_types), - ) - struc = struc.copy_and_update(bonds=struc.bonds[~is_polymer_bond]) - - return struc, metadata + return struc def create_empty_output_struc_and_layout( diff --git a/src/alphafold3/model/scoring/covalent_bond_cleaning.py b/src/alphafold3/model/scoring/covalent_bond_cleaning.py deleted file mode 100644 index 3763401..0000000 --- a/src/alphafold3/model/scoring/covalent_bond_cleaning.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright 2024 DeepMind Technologies Limited -# -# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of -# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ -# -# To request access to the AlphaFold 3 model parameters, follow the process set -# out at https://github.com/google-deepmind/alphafold3. You may only use these -# if received directly from Google. Use is subject to terms of use available at -# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md - -"""Some methods to compute metrics for PTMs.""" - -import collections -from collections.abc import Mapping -import dataclasses - -from alphafold3 import structure -from alphafold3.constants import mmcif_names -from alphafold3.model.atom_layout import atom_layout -import numpy as np - - -@dataclasses.dataclass(frozen=True) -class ResIdMapping: - old_res_ids: np.ndarray - new_res_ids: np.ndarray - - -def _count_symmetric_chains(struc: structure.Structure) -> Mapping[str, int]: - """Returns a dict with each chain ID and count.""" - chain_res_name_sequence_from_chain_id = struc.chain_res_name_sequence( - include_missing_residues=True, fix_non_standard_polymer_res=False - ) - counts_for_chain_res_name_sequence = collections.Counter( - chain_res_name_sequence_from_chain_id.values() - ) - chain_symmetric_count = {} - for chain_id, chain_res_name in chain_res_name_sequence_from_chain_id.items(): - chain_symmetric_count[chain_id] = counts_for_chain_res_name_sequence[ - chain_res_name - ] - return chain_symmetric_count - - -def has_nonsymmetric_bonds_on_symmetric_polymer_chains( - struc: structure.Structure, polymer_ligand_bonds: atom_layout.AtomLayout -) -> bool: - """Returns true if nonsymmetric bonds found on polymer chains.""" - try: - _get_polymer_dim(polymer_ligand_bonds) - except ValueError: - return True - if _has_non_polymer_ligand_ptm_bonds(polymer_ligand_bonds): - return True - if _has_multiple_polymers_bonded_to_one_ligand(polymer_ligand_bonds): - return True - combined_struc, _ = _combine_polymer_ligand_ptm_chains( - struc, polymer_ligand_bonds - ) - struc = struc.filter(chain_type=mmcif_names.POLYMER_CHAIN_TYPES) - combined_struc = combined_struc.filter( - chain_type=mmcif_names.POLYMER_CHAIN_TYPES - ) - return _count_symmetric_chains(struc) != _count_symmetric_chains( - combined_struc - ) - - -def _has_non_polymer_ligand_ptm_bonds( - polymer_ligand_bonds: atom_layout.AtomLayout, -): - """Checks if all bonds are between a polymer chain and a ligand chain type.""" - for start_chain_type, end_chain_type in polymer_ligand_bonds.chain_type: - if ( - start_chain_type in mmcif_names.POLYMER_CHAIN_TYPES - and end_chain_type in mmcif_names.LIGAND_CHAIN_TYPES - ): - continue - elif ( - start_chain_type in mmcif_names.LIGAND_CHAIN_TYPES - and end_chain_type in mmcif_names.POLYMER_CHAIN_TYPES - ): - continue - else: - return True - return False - - -def _combine_polymer_ligand_ptm_chains( - struc: structure.Structure, - polymer_ligand_bonds: atom_layout.AtomLayout, -) -> tuple[structure.Structure, dict[tuple[str, str], ResIdMapping]]: - """Combines the ptm polymer-ligand chains together. - - This will prevent them from being permuted away from each other when chains - are matched to the ground truth. This function also returns the res_id mapping - from the separate ligand res_ids to their res_ids in the combined - polymer-ligand chain; this information is needed to later separate the - combined polymer-ligand chain. - - Args: - struc: Structure to be modified. - polymer_ligand_bonds: AtomLayout with polymer-ligand bond info. - - Returns: - A tuple of a Structure with each ptm polymer-ligand chain relabelled as one - chain and a dict from bond chain pair to the res_id mapping. - """ - if not _has_only_single_bond_from_each_chain(polymer_ligand_bonds): - if _has_multiple_ligands_bonded_to_one_polymer(polymer_ligand_bonds): - # For structures where a polymer chain is connected to multiple ligands, - # we need to sort the multiple bonds from the same chain by res_id to - # ensure that the combined polymer-ligand chain will always be the same - # when you have repeated symmetric polymer-ligand chains. - polymer_ligand_bonds = ( - _sort_polymer_ligand_bonds_by_polymer_chain_and_res_id( - polymer_ligand_bonds - ) - ) - else: - raise ValueError( - 'Code cannot handle multiple bonds from one chain unless' - ' its several ligands bonded to a polymer.' - ) - res_id_mappings_for_bond_chain_pair = dict() - for (start_chain_id, end_chain_id), (start_chain_type, end_chain_type) in zip( - polymer_ligand_bonds.chain_id, polymer_ligand_bonds.chain_type - ): - poly_info, ligand_info = _get_polymer_and_ligand_chain_ids_and_types( - start_chain_id, end_chain_id, start_chain_type, end_chain_type - ) - polymer_chain_id, polymer_chain_type = poly_info - ligand_chain_id, _ = ligand_info - - # Join the ligand chain to the polymer chain. - ligand_res_ids = struc.filter(chain_id=ligand_chain_id).res_id - new_res_ids = ligand_res_ids + len(struc.all_residues[polymer_chain_id]) - res_id_mappings_for_bond_chain_pair[(polymer_chain_id, ligand_chain_id)] = ( - ResIdMapping(old_res_ids=ligand_res_ids, new_res_ids=new_res_ids) - ) - chain_groups = [] - chain_group_ids = [] - chain_group_types = [] - for chain_id, chain_type in zip( - struc.chains_table.id, struc.chains_table.type - ): - if chain_id == ligand_chain_id: - continue - elif chain_id == polymer_chain_id: - chain_groups.append([polymer_chain_id, ligand_chain_id]) - chain_group_ids.append(polymer_chain_id) - chain_group_types.append(polymer_chain_type) - else: - chain_groups.append([chain_id]) - chain_group_ids.append(chain_id) - chain_group_types.append(chain_type) - - struc = struc.merge_chains( - chain_groups=chain_groups, - chain_group_ids=chain_group_ids, - chain_group_types=chain_group_types, - ) - - return struc, res_id_mappings_for_bond_chain_pair - - -def _has_only_single_bond_from_each_chain( - polymer_ligand_bonds: atom_layout.AtomLayout, -) -> bool: - """Checks that there is at most one bond from each chain.""" - chain_ids = [] - for chains in polymer_ligand_bonds.chain_id: - chain_ids.extend(chains) - if len(chain_ids) != len(set(chain_ids)): - return False - return True - - -def _get_polymer_and_ligand_chain_ids_and_types( - start_chain_id: str, - end_chain_id: str, - start_chain_type: str, - end_chain_type: str, -) -> tuple[tuple[str, str], tuple[str, str]]: - """Finds polymer and ligand chain ids from chain types.""" - if ( - start_chain_type in mmcif_names.POLYMER_CHAIN_TYPES - and end_chain_type in mmcif_names.LIGAND_CHAIN_TYPES - ): - return (start_chain_id, start_chain_type), (end_chain_id, end_chain_type) - elif ( - start_chain_type in mmcif_names.LIGAND_CHAIN_TYPES - and end_chain_type in mmcif_names.POLYMER_CHAIN_TYPES - ): - return (end_chain_id, end_chain_type), (start_chain_id, start_chain_type) - else: - raise ValueError( - 'This code only handles PTM-bonds from polymer chain to ligands.' - ) - - -def _get_polymer_dim(polymer_ligand_bonds: atom_layout.AtomLayout) -> int: - """Gets polymer dimension from the polymer-ligand bond layout.""" - start_chain_types = [] - end_chain_types = [] - for start_chain_type, end_chain_type in polymer_ligand_bonds.chain_type: - start_chain_types.append(start_chain_type) - end_chain_types.append(end_chain_type) - if set(start_chain_types).issubset( - set(mmcif_names.POLYMER_CHAIN_TYPES) - ) and set(end_chain_types).issubset(set(mmcif_names.LIGAND_CHAIN_TYPES)): - return 0 - elif set(start_chain_types).issubset(mmcif_names.LIGAND_CHAIN_TYPES) and set( - end_chain_types - ).issubset(set(mmcif_names.POLYMER_CHAIN_TYPES)): - return 1 - else: - raise ValueError( - 'Polymer and ligand dimensions are not consistent within the structure.' - ) - - -def _has_multiple_ligands_bonded_to_one_polymer(polymer_ligand_bonds): - """Checks if there are multiple ligands bonded to one polymer.""" - polymer_dim = _get_polymer_dim(polymer_ligand_bonds) - polymer_chain_ids = [ - chains[polymer_dim] for chains in polymer_ligand_bonds.chain_id - ] - if len(polymer_chain_ids) != len(set(polymer_chain_ids)): - return True - return False - - -def _has_multiple_polymers_bonded_to_one_ligand(polymer_ligand_bonds): - """Checks if there are multiple polymer chains bonded to one ligand.""" - polymer_dim = _get_polymer_dim(polymer_ligand_bonds) - ligand_dim = 1 - polymer_dim - ligand_chain_ids = [ - chains[ligand_dim] for chains in polymer_ligand_bonds.chain_id - ] - if len(ligand_chain_ids) != len(set(ligand_chain_ids)): - return True - return False - - -def _sort_polymer_ligand_bonds_by_polymer_chain_and_res_id( - polymer_ligand_bonds, -): - """Sorts bonds by res_id (for when a polymer chain has multiple bonded ligands).""" - - polymer_dim = _get_polymer_dim(polymer_ligand_bonds) - - polymer_chain_ids = [ - chains[polymer_dim] for chains in polymer_ligand_bonds.chain_id - ] - polymer_res_ids = [res[polymer_dim] for res in polymer_ligand_bonds.res_id] - - polymer_chain_and_res_id = zip(polymer_chain_ids, polymer_res_ids) - sorted_indices = [ - idx - for idx, _ in sorted( - enumerate(polymer_chain_and_res_id), key=lambda x: x[1] - ) - ] - return polymer_ligand_bonds[sorted_indices] diff --git a/src/alphafold3/structure/sterics.py b/src/alphafold3/structure/sterics.py deleted file mode 100644 index 34f99a7..0000000 --- a/src/alphafold3/structure/sterics.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2024 DeepMind Technologies Limited -# -# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of -# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ -# -# To request access to the AlphaFold 3 model parameters, follow the process set -# out at https://github.com/google-deepmind/alphafold3. You may only use these -# if received directly from Google. Use is subject to terms of use available at -# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md - -"""Functions relating to spatial locations of atoms within a structure.""" - -from collections.abc import Collection, Sequence - -from alphafold3 import structure -from alphafold3.structure import mmcif -import numpy as np -import scipy - - -def _make_atom_has_clash_mask( - kd_query_result: np.ndarray, - struc: structure.Structure, - ignore_chains: Collection[str], -) -> np.ndarray: - """Returns a boolean NumPy array representing whether each atom has a clash. - - Args: - kd_query_result: NumPy array containing N-atoms arrays, each array - containing indices to atoms that clash with the N'th atom. - struc: Structure over which clashes were detected. - ignore_chains: Collection of chains that should not be considered clashing. - A boolean NumPy array of length N atoms. - """ - atom_is_clashing = np.zeros((struc.num_atoms,), dtype=bool) - for atom_index, clashes in enumerate(kd_query_result): - chain_i = struc.chain_id[atom_index] - if chain_i in ignore_chains: - continue - islig_i = struc.is_ligand_mask[atom_index] - for clashing_atom_index in clashes: - chain_c = struc.chain_id[clashing_atom_index] - if chain_c in ignore_chains: - continue - islig_c = struc.is_ligand_mask[clashing_atom_index] - if ( - clashing_atom_index == atom_index - or chain_i == chain_c - or islig_i != islig_c - ): - # Ignore clashes within chain or between ligand and polymer. - continue - atom_is_clashing[atom_index] = True - return atom_is_clashing - - -def find_clashing_chains( - struc: structure.Structure, - clash_thresh_angstrom: float = 1.7, - clash_thresh_fraction: float = 0.3, -) -> Sequence[str]: - """Finds chains that clash with others. - - Clashes are defined by polymer backbone atoms and all ligand atoms. - Ligand-polymer clashes are not dropped. - - Will not find clashes if all coordinates are 0. Coordinates are all 0s if - the structure is generated from sequences only, as done for inference in - dendro for example. - - Args: - struc: The structure defining the chains and atom positions. - clash_thresh_angstrom: Below this distance, atoms are considered clashing. - clash_thresh_fraction: Chains with more than this fraction of their atoms - considered clashing will be dropped. This value should be in the range (0, - 1]. - - Returns: - A sequence of chain ids for chains that clash. - - Raises: - ValueError: If `clash_thresh_fraction` is not in range (0,1]. - """ - if not 0 < clash_thresh_fraction <= 1: - raise ValueError('clash_thresh_fraction must be in range (0,1]') - - struc_backbone = struc.filter_polymers_to_single_atom_per_res() - if struc_backbone.num_chains == 0: - return [] - - # If the coordinates are all 0, do not search for clashes. - if not np.any(struc_backbone.coords): - return [] - - coord_kdtree = scipy.spatial.cKDTree(struc_backbone.coords) - - # For each atom coordinate, find all atoms within the clash thresh radius. - clashing_per_atom = coord_kdtree.query_ball_point( - struc_backbone.coords, r=clash_thresh_angstrom - ) - chain_ids = struc_backbone.chains - if struc_backbone.atom_occupancy is not None: - chain_occupancy = np.array([ - np.mean(struc_backbone.atom_occupancy[start:end]) - for start, end in struc_backbone.iter_chain_ranges() - ]) - else: - chain_occupancy = None - - # Remove chains until no more significant clashing. - chains_to_remove = set() - for _ in range(len(chain_ids)): - # Calculate maximally clashing. - atom_has_clash = _make_atom_has_clash_mask( - clashing_per_atom, struc_backbone, chains_to_remove - ) - clashes_per_chain = np.array([ - atom_has_clash[start:end].mean() - for start, end in struc_backbone.iter_chain_ranges() - ]) - max_clash = np.max(clashes_per_chain) - if max_clash <= clash_thresh_fraction: - # None of the remaining chains exceed the clash fraction threshold, so - # we can exit. - break - - # Greedily remove worst with the lowest occupancy. - most_clashes = np.nonzero(clashes_per_chain == max_clash)[0] - if chain_occupancy is not None: - occupancy_clashing = chain_occupancy[most_clashes] - last_lowest_occupancy = ( - len(occupancy_clashing) - np.argmin(occupancy_clashing[::-1]) - 1 - ) - worst_and_last = most_clashes[last_lowest_occupancy] - else: - worst_and_last = most_clashes[-1] - - chains_to_remove.add(chain_ids[worst_and_last]) - - return sorted(chains_to_remove, key=mmcif.str_id_to_int_id)