mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2026-06-02 11:54:36 +08:00
Clean up structure cleaning - remove unused code paths
PiperOrigin-RevId: 903711715 Change-Id: I953f1b27c6a0b30cc79a6cbd3ee3503d1f24a7ec
This commit is contained in:
committed by
Copybara-Service
parent
cfeeedd24d
commit
e9adcd4dbf
@@ -103,8 +103,6 @@ class WholePdbPipeline:
|
||||
date the model coordinates can be used as a fallback.
|
||||
max_templates: The maximum number of templates to send through the network
|
||||
set to 0 to switch off templates.
|
||||
filter_clashes: If true then will remove clashing chains.
|
||||
filter_crystal_aids: If true ligands in the cryal aid list are removed.
|
||||
max_paired_sequence_per_species: The maximum number of sequences per
|
||||
species that will be used for MSA pairing.
|
||||
drop_ligand_leaving_atoms: Flag for handling leaving atoms for ligands.
|
||||
@@ -115,8 +113,6 @@ class WholePdbPipeline:
|
||||
atom_cross_att_keys_subset_size: keys subset size in atom cross attention
|
||||
flatten_non_standard_residues: Whether to expand non-standard polymer
|
||||
residues into flat-atom format.
|
||||
remove_nonsymmetric_bonds: Whether to remove nonsymmetric bonds from
|
||||
symmetric polymer chains.
|
||||
deterministic_frames: Whether to use fixed-seed reference positions to
|
||||
construct deterministic frames.
|
||||
resolve_msa_overlaps: Whether to deduplicate unpaired MSA against paired
|
||||
@@ -136,15 +132,12 @@ class WholePdbPipeline:
|
||||
max_template_date: datetime.date | None = None
|
||||
ref_max_modified_date: datetime.date | None = None
|
||||
max_templates: int = 4
|
||||
filter_clashes: bool = False
|
||||
filter_crystal_aids: bool = False
|
||||
max_paired_sequence_per_species: int = 600
|
||||
drop_ligand_leaving_atoms: bool = True
|
||||
average_num_atoms_per_token: int = 24
|
||||
atom_cross_att_queries_subset_size: int = 32
|
||||
atom_cross_att_keys_subset_size: int = 128
|
||||
flatten_non_standard_residues: bool = True
|
||||
remove_nonsymmetric_bonds: bool = False
|
||||
deterministic_frames: bool = True
|
||||
conformer_max_iterations: int | None = None
|
||||
resolve_msa_overlaps: bool = True
|
||||
@@ -176,13 +169,11 @@ class WholePdbPipeline:
|
||||
logging.info('Processing %s', logging_name)
|
||||
|
||||
# Clean structure.
|
||||
cleaned_struc, cleaning_metadata = structure_cleaning.clean_structure(
|
||||
cleaned_struc = structure_cleaning.clean_structure(
|
||||
struct,
|
||||
ccd=ccd,
|
||||
drop_non_standard_atoms=True,
|
||||
drop_missing_sequence=True,
|
||||
filter_clashes=self._config.filter_clashes,
|
||||
filter_crystal_aids=self._config.filter_crystal_aids,
|
||||
filter_waters=True,
|
||||
filter_hydrogens=True,
|
||||
filter_leaving_atoms=self._config.drop_ligand_leaving_atoms,
|
||||
@@ -190,21 +181,9 @@ class WholePdbPipeline:
|
||||
covalent_bonds_only=True,
|
||||
remove_polymer_polymer_bonds=True,
|
||||
remove_bad_bonds=True,
|
||||
remove_nonsymmetric_bonds=self._config.remove_nonsymmetric_bonds,
|
||||
)
|
||||
|
||||
num_clashing_chains_removed = cleaning_metadata[
|
||||
'num_clashing_chains_removed'
|
||||
]
|
||||
|
||||
if num_clashing_chains_removed:
|
||||
logging.info(
|
||||
'Removed %d clashing chains from %s',
|
||||
num_clashing_chains_removed,
|
||||
logging_name,
|
||||
)
|
||||
|
||||
# No chains after fixes
|
||||
# No chains after cleaning.
|
||||
if cleaned_struc.num_chains == 0:
|
||||
raise MmcifNumChainsError(f'{logging_name}: No chains in structure!')
|
||||
|
||||
|
||||
@@ -10,8 +10,6 @@
|
||||
|
||||
"""Prepare PDB structure for training or inference."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from absl import logging
|
||||
from alphafold3 import structure
|
||||
from alphafold3.constants import chemical_component_sets
|
||||
@@ -19,8 +17,6 @@ from alphafold3.constants import chemical_components
|
||||
from alphafold3.constants import mmcif_names
|
||||
from alphafold3.model.atom_layout import atom_layout
|
||||
from alphafold3.model.pipeline import inter_chain_bonds
|
||||
from alphafold3.model.scoring import covalent_bond_cleaning
|
||||
from alphafold3.structure import sterics
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -70,9 +66,7 @@ def clean_structure(
|
||||
ccd: chemical_components.Ccd,
|
||||
*,
|
||||
drop_missing_sequence: bool,
|
||||
filter_clashes: bool,
|
||||
drop_non_standard_atoms: bool,
|
||||
filter_crystal_aids: bool,
|
||||
filter_waters: bool,
|
||||
filter_hydrogens: bool,
|
||||
filter_leaving_atoms: bool,
|
||||
@@ -80,17 +74,14 @@ def clean_structure(
|
||||
covalent_bonds_only: bool,
|
||||
remove_polymer_polymer_bonds: bool,
|
||||
remove_bad_bonds: bool,
|
||||
remove_nonsymmetric_bonds: bool,
|
||||
) -> tuple[structure.Structure, dict[str, Any]]:
|
||||
"""Cleans structure.
|
||||
) -> structure.Structure:
|
||||
"""Returns a cleaned version of the input structure.
|
||||
|
||||
Args:
|
||||
struc: Structure to clean.
|
||||
ccd: The chemical components dictionary.
|
||||
drop_missing_sequence: Whether to drop chains without specified sequences.
|
||||
filter_clashes: Whether to drop clashing chains.
|
||||
drop_non_standard_atoms: Whether to drop non CCD standard atoms.
|
||||
filter_crystal_aids: Whether to drop ligands in the crystal aid set.
|
||||
filter_waters: Whether to drop water chains.
|
||||
filter_hydrogens: Whether to drop hyrdogen atoms.
|
||||
filter_leaving_atoms: Whether to drop leaving atoms based on heuristics.
|
||||
@@ -99,43 +90,13 @@ def clean_structure(
|
||||
covalent_bonds_only: Only include covalent bonds.
|
||||
remove_polymer_polymer_bonds: Remove polymer-polymer bonds.
|
||||
remove_bad_bonds: Whether to remove badly bonded ligands.
|
||||
remove_nonsymmetric_bonds: Whether to remove nonsymmetric polymer-ligand
|
||||
bonds from symmetric polymer chains.
|
||||
|
||||
Returns:
|
||||
Tuple of structure and metadata dict. The metadata dict has
|
||||
information about what was cleaned from the original.
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
# Crop crystallization aids.
|
||||
if (
|
||||
filter_crystal_aids
|
||||
and struc.structure_method in mmcif_names.CRYSTALLIZATION_METHODS
|
||||
):
|
||||
struc = struc.filter_out(
|
||||
res_name=chemical_component_sets.COMMON_CRYSTALLIZATION_AIDS
|
||||
)
|
||||
|
||||
# Drop chains without specified sequences.
|
||||
if drop_missing_sequence:
|
||||
chains_with_unk_sequence = struc.find_chains_with_unknown_sequence()
|
||||
num_with_unk_sequence = len(chains_with_unk_sequence)
|
||||
if chains_with_unk_sequence:
|
||||
struc = struc.filter_out(chain_id=chains_with_unk_sequence)
|
||||
else:
|
||||
num_with_unk_sequence = 0
|
||||
metadata['num_with_unk_sequence'] = num_with_unk_sequence
|
||||
|
||||
# Remove intersecting chains.
|
||||
if filter_clashes and struc.num_chains > 1:
|
||||
clashing_chains = sterics.find_clashing_chains(struc)
|
||||
if clashing_chains:
|
||||
struc = struc.filter_out(chain_id=clashing_chains)
|
||||
else:
|
||||
clashing_chains = []
|
||||
metadata['num_clashing_chains_removed'] = len(clashing_chains)
|
||||
metadata['chains_removed'] = clashing_chains
|
||||
|
||||
# Drop non-standard atoms
|
||||
if drop_non_standard_atoms:
|
||||
@@ -268,27 +229,7 @@ def clean_structure(
|
||||
new_bonds = structure.Bonds.make_empty()
|
||||
struc = struc.copy_and_update(bonds=new_bonds)
|
||||
|
||||
if struc.bonds and remove_nonsymmetric_bonds:
|
||||
# Check for asymmetric polymer-ligand bonds and remove if these exist.
|
||||
polymer_ligand_bonds = inter_chain_bonds.get_polymer_ligand_bonds(
|
||||
struc,
|
||||
only_glycan_ligands=False,
|
||||
)
|
||||
if polymer_ligand_bonds:
|
||||
if covalent_bond_cleaning.has_nonsymmetric_bonds_on_symmetric_polymer_chains(
|
||||
struc, polymer_ligand_bonds
|
||||
):
|
||||
from_atom_idxs, dest_atom_idxs = struc.bonds.get_atom_indices(
|
||||
struc.atom_key
|
||||
)
|
||||
poly_chain_types = list(mmcif_names.POLYMER_CHAIN_TYPES)
|
||||
is_polymer_bond = np.logical_or(
|
||||
np.isin(struc.chain_type[from_atom_idxs], poly_chain_types),
|
||||
np.isin(struc.chain_type[dest_atom_idxs], poly_chain_types),
|
||||
)
|
||||
struc = struc.copy_and_update(bonds=struc.bonds[~is_polymer_bond])
|
||||
|
||||
return struc, metadata
|
||||
return struc
|
||||
|
||||
|
||||
def create_empty_output_struc_and_layout(
|
||||
|
||||
@@ -1,265 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Some methods to compute metrics for PTMs."""
|
||||
|
||||
import collections
|
||||
from collections.abc import Mapping
|
||||
import dataclasses
|
||||
|
||||
from alphafold3 import structure
|
||||
from alphafold3.constants import mmcif_names
|
||||
from alphafold3.model.atom_layout import atom_layout
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ResIdMapping:
|
||||
old_res_ids: np.ndarray
|
||||
new_res_ids: np.ndarray
|
||||
|
||||
|
||||
def _count_symmetric_chains(struc: structure.Structure) -> Mapping[str, int]:
|
||||
"""Returns a dict with each chain ID and count."""
|
||||
chain_res_name_sequence_from_chain_id = struc.chain_res_name_sequence(
|
||||
include_missing_residues=True, fix_non_standard_polymer_res=False
|
||||
)
|
||||
counts_for_chain_res_name_sequence = collections.Counter(
|
||||
chain_res_name_sequence_from_chain_id.values()
|
||||
)
|
||||
chain_symmetric_count = {}
|
||||
for chain_id, chain_res_name in chain_res_name_sequence_from_chain_id.items():
|
||||
chain_symmetric_count[chain_id] = counts_for_chain_res_name_sequence[
|
||||
chain_res_name
|
||||
]
|
||||
return chain_symmetric_count
|
||||
|
||||
|
||||
def has_nonsymmetric_bonds_on_symmetric_polymer_chains(
|
||||
struc: structure.Structure, polymer_ligand_bonds: atom_layout.AtomLayout
|
||||
) -> bool:
|
||||
"""Returns true if nonsymmetric bonds found on polymer chains."""
|
||||
try:
|
||||
_get_polymer_dim(polymer_ligand_bonds)
|
||||
except ValueError:
|
||||
return True
|
||||
if _has_non_polymer_ligand_ptm_bonds(polymer_ligand_bonds):
|
||||
return True
|
||||
if _has_multiple_polymers_bonded_to_one_ligand(polymer_ligand_bonds):
|
||||
return True
|
||||
combined_struc, _ = _combine_polymer_ligand_ptm_chains(
|
||||
struc, polymer_ligand_bonds
|
||||
)
|
||||
struc = struc.filter(chain_type=mmcif_names.POLYMER_CHAIN_TYPES)
|
||||
combined_struc = combined_struc.filter(
|
||||
chain_type=mmcif_names.POLYMER_CHAIN_TYPES
|
||||
)
|
||||
return _count_symmetric_chains(struc) != _count_symmetric_chains(
|
||||
combined_struc
|
||||
)
|
||||
|
||||
|
||||
def _has_non_polymer_ligand_ptm_bonds(
|
||||
polymer_ligand_bonds: atom_layout.AtomLayout,
|
||||
):
|
||||
"""Checks if all bonds are between a polymer chain and a ligand chain type."""
|
||||
for start_chain_type, end_chain_type in polymer_ligand_bonds.chain_type:
|
||||
if (
|
||||
start_chain_type in mmcif_names.POLYMER_CHAIN_TYPES
|
||||
and end_chain_type in mmcif_names.LIGAND_CHAIN_TYPES
|
||||
):
|
||||
continue
|
||||
elif (
|
||||
start_chain_type in mmcif_names.LIGAND_CHAIN_TYPES
|
||||
and end_chain_type in mmcif_names.POLYMER_CHAIN_TYPES
|
||||
):
|
||||
continue
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _combine_polymer_ligand_ptm_chains(
|
||||
struc: structure.Structure,
|
||||
polymer_ligand_bonds: atom_layout.AtomLayout,
|
||||
) -> tuple[structure.Structure, dict[tuple[str, str], ResIdMapping]]:
|
||||
"""Combines the ptm polymer-ligand chains together.
|
||||
|
||||
This will prevent them from being permuted away from each other when chains
|
||||
are matched to the ground truth. This function also returns the res_id mapping
|
||||
from the separate ligand res_ids to their res_ids in the combined
|
||||
polymer-ligand chain; this information is needed to later separate the
|
||||
combined polymer-ligand chain.
|
||||
|
||||
Args:
|
||||
struc: Structure to be modified.
|
||||
polymer_ligand_bonds: AtomLayout with polymer-ligand bond info.
|
||||
|
||||
Returns:
|
||||
A tuple of a Structure with each ptm polymer-ligand chain relabelled as one
|
||||
chain and a dict from bond chain pair to the res_id mapping.
|
||||
"""
|
||||
if not _has_only_single_bond_from_each_chain(polymer_ligand_bonds):
|
||||
if _has_multiple_ligands_bonded_to_one_polymer(polymer_ligand_bonds):
|
||||
# For structures where a polymer chain is connected to multiple ligands,
|
||||
# we need to sort the multiple bonds from the same chain by res_id to
|
||||
# ensure that the combined polymer-ligand chain will always be the same
|
||||
# when you have repeated symmetric polymer-ligand chains.
|
||||
polymer_ligand_bonds = (
|
||||
_sort_polymer_ligand_bonds_by_polymer_chain_and_res_id(
|
||||
polymer_ligand_bonds
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Code cannot handle multiple bonds from one chain unless'
|
||||
' its several ligands bonded to a polymer.'
|
||||
)
|
||||
res_id_mappings_for_bond_chain_pair = dict()
|
||||
for (start_chain_id, end_chain_id), (start_chain_type, end_chain_type) in zip(
|
||||
polymer_ligand_bonds.chain_id, polymer_ligand_bonds.chain_type
|
||||
):
|
||||
poly_info, ligand_info = _get_polymer_and_ligand_chain_ids_and_types(
|
||||
start_chain_id, end_chain_id, start_chain_type, end_chain_type
|
||||
)
|
||||
polymer_chain_id, polymer_chain_type = poly_info
|
||||
ligand_chain_id, _ = ligand_info
|
||||
|
||||
# Join the ligand chain to the polymer chain.
|
||||
ligand_res_ids = struc.filter(chain_id=ligand_chain_id).res_id
|
||||
new_res_ids = ligand_res_ids + len(struc.all_residues[polymer_chain_id])
|
||||
res_id_mappings_for_bond_chain_pair[(polymer_chain_id, ligand_chain_id)] = (
|
||||
ResIdMapping(old_res_ids=ligand_res_ids, new_res_ids=new_res_ids)
|
||||
)
|
||||
chain_groups = []
|
||||
chain_group_ids = []
|
||||
chain_group_types = []
|
||||
for chain_id, chain_type in zip(
|
||||
struc.chains_table.id, struc.chains_table.type
|
||||
):
|
||||
if chain_id == ligand_chain_id:
|
||||
continue
|
||||
elif chain_id == polymer_chain_id:
|
||||
chain_groups.append([polymer_chain_id, ligand_chain_id])
|
||||
chain_group_ids.append(polymer_chain_id)
|
||||
chain_group_types.append(polymer_chain_type)
|
||||
else:
|
||||
chain_groups.append([chain_id])
|
||||
chain_group_ids.append(chain_id)
|
||||
chain_group_types.append(chain_type)
|
||||
|
||||
struc = struc.merge_chains(
|
||||
chain_groups=chain_groups,
|
||||
chain_group_ids=chain_group_ids,
|
||||
chain_group_types=chain_group_types,
|
||||
)
|
||||
|
||||
return struc, res_id_mappings_for_bond_chain_pair
|
||||
|
||||
|
||||
def _has_only_single_bond_from_each_chain(
|
||||
polymer_ligand_bonds: atom_layout.AtomLayout,
|
||||
) -> bool:
|
||||
"""Checks that there is at most one bond from each chain."""
|
||||
chain_ids = []
|
||||
for chains in polymer_ligand_bonds.chain_id:
|
||||
chain_ids.extend(chains)
|
||||
if len(chain_ids) != len(set(chain_ids)):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _get_polymer_and_ligand_chain_ids_and_types(
|
||||
start_chain_id: str,
|
||||
end_chain_id: str,
|
||||
start_chain_type: str,
|
||||
end_chain_type: str,
|
||||
) -> tuple[tuple[str, str], tuple[str, str]]:
|
||||
"""Finds polymer and ligand chain ids from chain types."""
|
||||
if (
|
||||
start_chain_type in mmcif_names.POLYMER_CHAIN_TYPES
|
||||
and end_chain_type in mmcif_names.LIGAND_CHAIN_TYPES
|
||||
):
|
||||
return (start_chain_id, start_chain_type), (end_chain_id, end_chain_type)
|
||||
elif (
|
||||
start_chain_type in mmcif_names.LIGAND_CHAIN_TYPES
|
||||
and end_chain_type in mmcif_names.POLYMER_CHAIN_TYPES
|
||||
):
|
||||
return (end_chain_id, end_chain_type), (start_chain_id, start_chain_type)
|
||||
else:
|
||||
raise ValueError(
|
||||
'This code only handles PTM-bonds from polymer chain to ligands.'
|
||||
)
|
||||
|
||||
|
||||
def _get_polymer_dim(polymer_ligand_bonds: atom_layout.AtomLayout) -> int:
|
||||
"""Gets polymer dimension from the polymer-ligand bond layout."""
|
||||
start_chain_types = []
|
||||
end_chain_types = []
|
||||
for start_chain_type, end_chain_type in polymer_ligand_bonds.chain_type:
|
||||
start_chain_types.append(start_chain_type)
|
||||
end_chain_types.append(end_chain_type)
|
||||
if set(start_chain_types).issubset(
|
||||
set(mmcif_names.POLYMER_CHAIN_TYPES)
|
||||
) and set(end_chain_types).issubset(set(mmcif_names.LIGAND_CHAIN_TYPES)):
|
||||
return 0
|
||||
elif set(start_chain_types).issubset(mmcif_names.LIGAND_CHAIN_TYPES) and set(
|
||||
end_chain_types
|
||||
).issubset(set(mmcif_names.POLYMER_CHAIN_TYPES)):
|
||||
return 1
|
||||
else:
|
||||
raise ValueError(
|
||||
'Polymer and ligand dimensions are not consistent within the structure.'
|
||||
)
|
||||
|
||||
|
||||
def _has_multiple_ligands_bonded_to_one_polymer(polymer_ligand_bonds):
|
||||
"""Checks if there are multiple ligands bonded to one polymer."""
|
||||
polymer_dim = _get_polymer_dim(polymer_ligand_bonds)
|
||||
polymer_chain_ids = [
|
||||
chains[polymer_dim] for chains in polymer_ligand_bonds.chain_id
|
||||
]
|
||||
if len(polymer_chain_ids) != len(set(polymer_chain_ids)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _has_multiple_polymers_bonded_to_one_ligand(polymer_ligand_bonds):
|
||||
"""Checks if there are multiple polymer chains bonded to one ligand."""
|
||||
polymer_dim = _get_polymer_dim(polymer_ligand_bonds)
|
||||
ligand_dim = 1 - polymer_dim
|
||||
ligand_chain_ids = [
|
||||
chains[ligand_dim] for chains in polymer_ligand_bonds.chain_id
|
||||
]
|
||||
if len(ligand_chain_ids) != len(set(ligand_chain_ids)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _sort_polymer_ligand_bonds_by_polymer_chain_and_res_id(
|
||||
polymer_ligand_bonds,
|
||||
):
|
||||
"""Sorts bonds by res_id (for when a polymer chain has multiple bonded ligands)."""
|
||||
|
||||
polymer_dim = _get_polymer_dim(polymer_ligand_bonds)
|
||||
|
||||
polymer_chain_ids = [
|
||||
chains[polymer_dim] for chains in polymer_ligand_bonds.chain_id
|
||||
]
|
||||
polymer_res_ids = [res[polymer_dim] for res in polymer_ligand_bonds.res_id]
|
||||
|
||||
polymer_chain_and_res_id = zip(polymer_chain_ids, polymer_res_ids)
|
||||
sorted_indices = [
|
||||
idx
|
||||
for idx, _ in sorted(
|
||||
enumerate(polymer_chain_and_res_id), key=lambda x: x[1]
|
||||
)
|
||||
]
|
||||
return polymer_ligand_bonds[sorted_indices]
|
||||
@@ -1,140 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Functions relating to spatial locations of atoms within a structure."""
|
||||
|
||||
from collections.abc import Collection, Sequence
|
||||
|
||||
from alphafold3 import structure
|
||||
from alphafold3.structure import mmcif
|
||||
import numpy as np
|
||||
import scipy
|
||||
|
||||
|
||||
def _make_atom_has_clash_mask(
|
||||
kd_query_result: np.ndarray,
|
||||
struc: structure.Structure,
|
||||
ignore_chains: Collection[str],
|
||||
) -> np.ndarray:
|
||||
"""Returns a boolean NumPy array representing whether each atom has a clash.
|
||||
|
||||
Args:
|
||||
kd_query_result: NumPy array containing N-atoms arrays, each array
|
||||
containing indices to atoms that clash with the N'th atom.
|
||||
struc: Structure over which clashes were detected.
|
||||
ignore_chains: Collection of chains that should not be considered clashing.
|
||||
A boolean NumPy array of length N atoms.
|
||||
"""
|
||||
atom_is_clashing = np.zeros((struc.num_atoms,), dtype=bool)
|
||||
for atom_index, clashes in enumerate(kd_query_result):
|
||||
chain_i = struc.chain_id[atom_index]
|
||||
if chain_i in ignore_chains:
|
||||
continue
|
||||
islig_i = struc.is_ligand_mask[atom_index]
|
||||
for clashing_atom_index in clashes:
|
||||
chain_c = struc.chain_id[clashing_atom_index]
|
||||
if chain_c in ignore_chains:
|
||||
continue
|
||||
islig_c = struc.is_ligand_mask[clashing_atom_index]
|
||||
if (
|
||||
clashing_atom_index == atom_index
|
||||
or chain_i == chain_c
|
||||
or islig_i != islig_c
|
||||
):
|
||||
# Ignore clashes within chain or between ligand and polymer.
|
||||
continue
|
||||
atom_is_clashing[atom_index] = True
|
||||
return atom_is_clashing
|
||||
|
||||
|
||||
def find_clashing_chains(
|
||||
struc: structure.Structure,
|
||||
clash_thresh_angstrom: float = 1.7,
|
||||
clash_thresh_fraction: float = 0.3,
|
||||
) -> Sequence[str]:
|
||||
"""Finds chains that clash with others.
|
||||
|
||||
Clashes are defined by polymer backbone atoms and all ligand atoms.
|
||||
Ligand-polymer clashes are not dropped.
|
||||
|
||||
Will not find clashes if all coordinates are 0. Coordinates are all 0s if
|
||||
the structure is generated from sequences only, as done for inference in
|
||||
dendro for example.
|
||||
|
||||
Args:
|
||||
struc: The structure defining the chains and atom positions.
|
||||
clash_thresh_angstrom: Below this distance, atoms are considered clashing.
|
||||
clash_thresh_fraction: Chains with more than this fraction of their atoms
|
||||
considered clashing will be dropped. This value should be in the range (0,
|
||||
1].
|
||||
|
||||
Returns:
|
||||
A sequence of chain ids for chains that clash.
|
||||
|
||||
Raises:
|
||||
ValueError: If `clash_thresh_fraction` is not in range (0,1].
|
||||
"""
|
||||
if not 0 < clash_thresh_fraction <= 1:
|
||||
raise ValueError('clash_thresh_fraction must be in range (0,1]')
|
||||
|
||||
struc_backbone = struc.filter_polymers_to_single_atom_per_res()
|
||||
if struc_backbone.num_chains == 0:
|
||||
return []
|
||||
|
||||
# If the coordinates are all 0, do not search for clashes.
|
||||
if not np.any(struc_backbone.coords):
|
||||
return []
|
||||
|
||||
coord_kdtree = scipy.spatial.cKDTree(struc_backbone.coords)
|
||||
|
||||
# For each atom coordinate, find all atoms within the clash thresh radius.
|
||||
clashing_per_atom = coord_kdtree.query_ball_point(
|
||||
struc_backbone.coords, r=clash_thresh_angstrom
|
||||
)
|
||||
chain_ids = struc_backbone.chains
|
||||
if struc_backbone.atom_occupancy is not None:
|
||||
chain_occupancy = np.array([
|
||||
np.mean(struc_backbone.atom_occupancy[start:end])
|
||||
for start, end in struc_backbone.iter_chain_ranges()
|
||||
])
|
||||
else:
|
||||
chain_occupancy = None
|
||||
|
||||
# Remove chains until no more significant clashing.
|
||||
chains_to_remove = set()
|
||||
for _ in range(len(chain_ids)):
|
||||
# Calculate maximally clashing.
|
||||
atom_has_clash = _make_atom_has_clash_mask(
|
||||
clashing_per_atom, struc_backbone, chains_to_remove
|
||||
)
|
||||
clashes_per_chain = np.array([
|
||||
atom_has_clash[start:end].mean()
|
||||
for start, end in struc_backbone.iter_chain_ranges()
|
||||
])
|
||||
max_clash = np.max(clashes_per_chain)
|
||||
if max_clash <= clash_thresh_fraction:
|
||||
# None of the remaining chains exceed the clash fraction threshold, so
|
||||
# we can exit.
|
||||
break
|
||||
|
||||
# Greedily remove worst with the lowest occupancy.
|
||||
most_clashes = np.nonzero(clashes_per_chain == max_clash)[0]
|
||||
if chain_occupancy is not None:
|
||||
occupancy_clashing = chain_occupancy[most_clashes]
|
||||
last_lowest_occupancy = (
|
||||
len(occupancy_clashing) - np.argmin(occupancy_clashing[::-1]) - 1
|
||||
)
|
||||
worst_and_last = most_clashes[last_lowest_occupancy]
|
||||
else:
|
||||
worst_and_last = most_clashes[-1]
|
||||
|
||||
chains_to_remove.add(chain_ids[worst_and_last])
|
||||
|
||||
return sorted(chains_to_remove, key=mmcif.str_id_to_int_id)
|
||||
Reference in New Issue
Block a user