mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2026-06-02 11:54:36 +08:00
Fix glycan O1 leaving-atom cleanup for standalone glycans
Inspired by https://github.com/google-deepmind/alphafold3/issues/650 and https://github.com/google-deepmind/alphafold3/pull/633.
PiperOrigin-RevId: 911350518
Change-Id: Ia68e36e3580fb6092fdaeb62352df7e06991b58c
This commit is contained in:
committed by
Copybara-Service
parent
97639fff6f
commit
62a93afa5d
@@ -53,7 +53,6 @@ from jax import numpy as jnp
|
||||
import numpy as np
|
||||
import tokamax
|
||||
|
||||
|
||||
_HOME_DIR = pathlib.Path.home()
|
||||
_DEFAULT_MODEL_DIR = _HOME_DIR / 'models'
|
||||
_DEFAULT_DB_DIR = _HOME_DIR / 'public_databases'
|
||||
@@ -127,7 +126,6 @@ DB_DIR = flags.DEFINE_multi_string(
|
||||
'Path to the directory containing the databases. Can be specified multiple'
|
||||
' times to search multiple directories in order.',
|
||||
)
|
||||
|
||||
_SMALL_BFD_DATABASE_PATH = flags.DEFINE_string(
|
||||
'small_bfd_database_path',
|
||||
'${DB_DIR}/bfd-first_non_consensus_sequences.fasta',
|
||||
@@ -285,6 +283,14 @@ _CONFORMER_MAX_ITERATIONS = flags.DEFINE_integer(
|
||||
'conformer search.',
|
||||
lower_bound=0,
|
||||
)
|
||||
_FIX_STANDALONE_GLYCANS = flags.DEFINE_bool(
|
||||
'fix_standalone_glycans',
|
||||
False,
|
||||
'AlphaFold 3 model training and evaluation filtered out leaving atoms from'
|
||||
' glycan ligands even if they were not bonded to anything ("standalone"'
|
||||
' glycans). Setting this flag to True fixes this undesirable behavior, but'
|
||||
' moves away from the regime where AlphaFold 3 was trained and evaluated.',
|
||||
)
|
||||
|
||||
# JAX inference performance tuning.
|
||||
_JAX_COMPILATION_CACHE_DIR = flags.DEFINE_string(
|
||||
@@ -513,10 +519,12 @@ class ResultsForSeed:
|
||||
def predict_structure(
|
||||
fold_input: folding_input.Input,
|
||||
model_runner: ModelRunner,
|
||||
*,
|
||||
buckets: Sequence[int] | None = None,
|
||||
ref_max_modified_date: datetime.date | None = None,
|
||||
conformer_max_iterations: int | None = None,
|
||||
resolve_msa_overlaps: bool = True,
|
||||
fix_standalone_glycans: bool = False,
|
||||
) -> Sequence[ResultsForSeed]:
|
||||
"""Runs the full inference pipeline to predict structures for each seed."""
|
||||
|
||||
@@ -531,6 +539,7 @@ def predict_structure(
|
||||
ref_max_modified_date=ref_max_modified_date,
|
||||
conformer_max_iterations=conformer_max_iterations,
|
||||
resolve_msa_overlaps=resolve_msa_overlaps,
|
||||
fix_standalone_glycans=fix_standalone_glycans,
|
||||
)
|
||||
print(
|
||||
f'Featurising data with {len(fold_input.rng_seeds)} seed(s) took'
|
||||
@@ -731,6 +740,7 @@ def process_fold_input(
|
||||
ref_max_modified_date: datetime.date | None = None,
|
||||
conformer_max_iterations: int | None = None,
|
||||
resolve_msa_overlaps: bool = True,
|
||||
fix_standalone_glycans: bool = False,
|
||||
force_output_dir: bool = False,
|
||||
compress_large_output_files: bool = False,
|
||||
) -> folding_input.Input | Sequence[ResultsForSeed]:
|
||||
@@ -759,6 +769,11 @@ def process_fold_input(
|
||||
paper. Set this to false if providing custom paired MSA using the unpaired
|
||||
MSA field to keep it exactly as is as deduplication against the paired MSA
|
||||
could break the manually crafted pairing between MSA sequences.
|
||||
fix_standalone_glycans: If True, the O1 leaving atom of standalone (unbonded)
|
||||
glycans is kept when filter_leaving_atoms is True. This is False by default to match the
|
||||
AlphaFold 3 paper. Note that the model has been trained with the default
|
||||
setting, so setting this to True may cause non-standard behaviour of the
|
||||
model.
|
||||
force_output_dir: If True, do not create a new output directory even if the
|
||||
existing one is non-empty. Instead use the existing output directory and
|
||||
potentially overwrite existing files. If False, create a new timestamped
|
||||
@@ -815,6 +830,7 @@ def process_fold_input(
|
||||
ref_max_modified_date=ref_max_modified_date,
|
||||
conformer_max_iterations=conformer_max_iterations,
|
||||
resolve_msa_overlaps=resolve_msa_overlaps,
|
||||
fix_standalone_glycans=fix_standalone_glycans,
|
||||
)
|
||||
print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...')
|
||||
write_outputs(
|
||||
@@ -985,6 +1001,7 @@ def main(_):
|
||||
ref_max_modified_date=max_template_date,
|
||||
conformer_max_iterations=_CONFORMER_MAX_ITERATIONS.value,
|
||||
resolve_msa_overlaps=_RESOLVE_MSA_OVERLAPS.value,
|
||||
fix_standalone_glycans=_FIX_STANDALONE_GLYCANS.value,
|
||||
force_output_dir=_FORCE_OUTPUT_DIR.value,
|
||||
compress_large_output_files=_COMPRESS_LARGE_OUTPUT_FILES.value,
|
||||
)
|
||||
|
||||
@@ -42,6 +42,7 @@ def featurise_input(
|
||||
ref_max_modified_date: datetime.date | None = None,
|
||||
conformer_max_iterations: int | None = None,
|
||||
resolve_msa_overlaps: bool = True,
|
||||
fix_standalone_glycans: bool = False,
|
||||
verbose: bool = False,
|
||||
) -> Sequence[features.BatchDict]:
|
||||
"""Featurise the folding input.
|
||||
@@ -66,6 +67,12 @@ def featurise_input(
|
||||
paper. Set this to false if providing custom paired MSA using the unpaired
|
||||
MSA field to keep it exactly as is as deduplication against the paired MSA
|
||||
could break the manually crafted pairing between MSA sequences.
|
||||
fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered
|
||||
out leaving atoms from glycan ligands even if they were not bonded to
|
||||
anything ("standalone" glycans). Setting this flag to True fixes this
|
||||
undesirable behavior, but moves away from the regime where AlphaFold 3 was
|
||||
trained and evaluated. This has only an effect if filter_leaving_atoms is
|
||||
True in the WholePdbPipeline.Config.
|
||||
verbose: Whether to print progress messages.
|
||||
|
||||
Returns:
|
||||
@@ -80,6 +87,7 @@ def featurise_input(
|
||||
ref_max_modified_date=ref_max_modified_date,
|
||||
conformer_max_iterations=conformer_max_iterations,
|
||||
resolve_msa_overlaps=resolve_msa_overlaps,
|
||||
fix_standalone_glycans=fix_standalone_glycans,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from rdkit import Chem
|
||||
|
||||
|
||||
xnp_ndarray: TypeAlias = np.ndarray | jnp.ndarray # pylint: disable=invalid-name
|
||||
NumpyIndex: TypeAlias = Any
|
||||
|
||||
@@ -641,6 +640,7 @@ def get_link_drop_atoms(
|
||||
is_end_terminus: bool,
|
||||
bonded_atoms: set[str],
|
||||
drop_ligand_leaving_atoms: bool = False,
|
||||
fix_standalone_glycans: bool = False,
|
||||
) -> set[str]:
|
||||
"""Returns set of atoms that are dropped when this res_name gets linked.
|
||||
|
||||
@@ -651,6 +651,12 @@ def get_link_drop_atoms(
|
||||
is_end_terminus: whether the residue is the c-terminus
|
||||
bonded_atoms: Names of atoms coming off this residue.
|
||||
drop_ligand_leaving_atoms: Flag to switch on/off leaving atoms for ligands.
|
||||
fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered
|
||||
out leaving atoms from glycan ligands even if they were not bonded to
|
||||
anything ("standalone" glycans). Setting this flag to True fixes this
|
||||
undesirable behavior, but moves away from the regime where AlphaFold 3 was
|
||||
trained and evaluated. This has only an effect if
|
||||
drop_ligand_leaving_atoms is True.
|
||||
|
||||
Returns:
|
||||
Set of atoms that are dropped when this amino acid gets linked.
|
||||
@@ -677,8 +683,12 @@ def get_link_drop_atoms(
|
||||
*chemical_component_sets.GLYCAN_OTHER_LIGANDS,
|
||||
*chemical_component_sets.GLYCAN_LINKING_LIGANDS,
|
||||
}:
|
||||
- if 'O1' not in bonded_atoms:
|
||||
- drop_atoms.update({'O1'})
|
||||
+ if fix_standalone_glycans:
|
||||
+ if bonded_atoms and 'O1' not in bonded_atoms:
|
||||
+ drop_atoms.update({'O1'})
|
||||
+ else:
|
||||
+ if 'O1' not in bonded_atoms:
|
||||
+ drop_atoms.update({'O1'})
|
||||
return drop_atoms
|
||||
|
||||
|
||||
@@ -743,6 +753,7 @@ def make_flat_atom_layout(
|
||||
with_hydrogens: bool = False,
|
||||
skip_unk_residues: bool = True,
|
||||
drop_ligand_leaving_atoms: bool = False,
|
||||
fix_standalone_glycans: bool = False,
|
||||
) -> AtomLayout:
|
||||
"""Make a flat atom layout for given residues.
|
||||
|
||||
@@ -761,6 +772,12 @@ def make_flat_atom_layout(
|
||||
compatible with the rest of AlphaFold that does not predict atoms for
|
||||
unknown residues
|
||||
drop_ligand_leaving_atoms: Flag to switch on/off leaving atoms for ligands.
|
||||
fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered
|
||||
out leaving atoms from glycan ligands even if they were not bonded to
|
||||
anything ("standalone" glycans). Setting this flag to True fixes this
|
||||
undesirable behavior, but moves away from the regime where AlphaFold 3 was
|
||||
trained and evaluated. This has only an effect if
|
||||
drop_ligand_leaving_atoms is True.
|
||||
|
||||
Returns:
|
||||
an `AtomLayout` object
|
||||
@@ -834,6 +851,7 @@ def make_flat_atom_layout(
|
||||
is_end_terminus=residues.is_end_terminus[idx],
|
||||
bonded_atoms=bonded_atoms,
|
||||
drop_ligand_leaving_atoms=drop_ligand_leaving_atoms,
|
||||
fix_standalone_glycans=fix_standalone_glycans,
|
||||
)
|
||||
|
||||
# If deprotonation info is available, remove the specific atoms.
|
||||
|
||||
@@ -121,6 +121,12 @@ class WholePdbPipeline:
|
||||
unpaired MSA field to keep it exactly as is as deduplication against
|
||||
the paired MSA could break the manually crafted pairing between MSA
|
||||
sequences.
|
||||
fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered
|
||||
out leaving atoms from glycan ligands even if they were not bonded to
|
||||
anything ("standalone" glycans). Setting this flag to True fixes this
|
||||
undesirable behavior, but moves away from the regime where AlphaFold 3
|
||||
was trained and evaluated. This has only an effect if
|
||||
drop_ligand_leaving_atoms is True.
|
||||
"""
|
||||
|
||||
max_atoms_per_token: int = 24
|
||||
@@ -141,6 +147,7 @@ class WholePdbPipeline:
|
||||
deterministic_frames: bool = True
|
||||
conformer_max_iterations: int | None = None
|
||||
resolve_msa_overlaps: bool = True
|
||||
fix_standalone_glycans: bool = False
|
||||
|
||||
def __init__(self, *, config: Config):
|
||||
"""Initializes WholePdb data pipeline.
|
||||
@@ -181,6 +188,7 @@ class WholePdbPipeline:
|
||||
covalent_bonds_only=True,
|
||||
remove_polymer_polymer_bonds=True,
|
||||
remove_bad_bonds=True,
|
||||
fix_standalone_glycans=self._config.fix_standalone_glycans,
|
||||
)
|
||||
|
||||
# No chains after cleaning.
|
||||
@@ -209,6 +217,7 @@ class WholePdbPipeline:
|
||||
polymer_ligand_bonds=polymer_ligand_bonds,
|
||||
ligand_ligand_bonds=ligand_ligand_bonds,
|
||||
drop_ligand_leaving_atoms=self._config.drop_ligand_leaving_atoms,
|
||||
fix_standalone_glycans=self._config.fix_standalone_glycans,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ def _get_leaving_atom_mask(
|
||||
chain_type: str,
|
||||
res_id: int,
|
||||
res_name: str,
|
||||
fix_standalone_glycans: bool,
|
||||
) -> np.ndarray:
|
||||
"""Updates a drop_leaving_atoms mask with new leaving atom locations."""
|
||||
bonded_atoms = atom_layout.get_bonded_atoms(
|
||||
@@ -44,6 +45,7 @@ def _get_leaving_atom_mask(
|
||||
is_end_terminus=False,
|
||||
bonded_atoms=bonded_atoms,
|
||||
drop_ligand_leaving_atoms=True,
|
||||
fix_standalone_glycans=fix_standalone_glycans,
|
||||
)
|
||||
# Default mask where everything is false, which equates to being kept.
|
||||
drop_atom_filter_atoms = struc.chain_id != struc.chain_id
|
||||
@@ -74,6 +76,7 @@ def clean_structure(
|
||||
covalent_bonds_only: bool,
|
||||
remove_polymer_polymer_bonds: bool,
|
||||
remove_bad_bonds: bool,
|
||||
fix_standalone_glycans: bool,
|
||||
) -> structure.Structure:
|
||||
"""Returns a cleaned version of the input structure.
|
||||
|
||||
@@ -90,6 +93,12 @@ def clean_structure(
|
||||
covalent_bonds_only: Only include covalent bonds.
|
||||
remove_polymer_polymer_bonds: Remove polymer-polymer bonds.
|
||||
remove_bad_bonds: Whether to remove badly bonded ligands.
|
||||
fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered
|
||||
out leaving atoms from glycan ligands even if they were not bonded to
|
||||
anything ("standalone" glycans). Setting this flag to True fixes this
|
||||
undesirable behavior, but moves away from the regime where AlphaFold 3 was
|
||||
trained and evaluated. This has only an effect if filter_leaving_atoms is
|
||||
True.
|
||||
"""
|
||||
|
||||
# Drop chains without specified sequences.
|
||||
@@ -128,8 +137,10 @@ def clean_structure(
|
||||
*chemical_component_sets.GLYCAN_LINKING_LIGANDS,
|
||||
}
|
||||
# If only glycan ligands and no O1 atoms, we can do parallel drop.
|
||||
+ # If we need to keep O1 for standalone glycans, iterate over each residue.
|
||||
if (
|
||||
- only_glycan_ligands_for_leaving_atoms
|
||||
+ not fix_standalone_glycans
|
||||
+ and only_glycan_ligands_for_leaving_atoms
|
||||
and (not (ligand_ligand_bonds.atom_name == 'O1').any())
|
||||
and (not (polymer_ligand_bonds.atom_name == 'O1').any())
|
||||
):
|
||||
@@ -153,23 +164,22 @@ def clean_structure(
|
||||
chain_type=res['chain_type'],
|
||||
res_id=res['res_id'],
|
||||
res_name=res_name,
|
||||
fix_standalone_glycans=fix_standalone_glycans,
|
||||
)
|
||||
drop_leaving_atoms_all = np.logical_or(
|
||||
drop_leaving_atoms_all, drop_atom_filter
|
||||
)
|
||||
|
||||
- num_atoms_before = struc.num_atoms
|
||||
- struc = struc.filter_out(drop_leaving_atoms_all)
|
||||
- num_atoms_after = struc.num_atoms
|
||||
|
||||
- if num_atoms_before > num_atoms_after:
|
||||
- logging.error(
|
||||
- 'Dropped %s atoms from GT struc: chain_id %s res_id %s res_name %s',
|
||||
- num_atoms_before - num_atoms_after,
|
||||
- struc.chain_id,
|
||||
- struc.res_id,
|
||||
- struc.res_name,
|
||||
+ if np.any(drop_leaving_atoms_all):
|
||||
+ logging.info(
|
||||
+ 'Dropped %d atoms: chain_id %s, res_id %s, res_name %s, atom_name %s',
|
||||
+ struc.num_atoms - np.sum(~drop_leaving_atoms_all),
|
||||
+ struc.chain_id[drop_leaving_atoms_all],
|
||||
+ struc.res_id[drop_leaving_atoms_all],
|
||||
+ struc.res_name[drop_leaving_atoms_all],
|
||||
+ struc.atom_name[drop_leaving_atoms_all],
|
||||
)
|
||||
+ struc = struc.filter_out(drop_leaving_atoms_all)
|
||||
|
||||
# Can filter by bond type without having to iterate over bonds.
|
||||
if struc.bonds and covalent_bonds_only:
|
||||
@@ -241,6 +251,7 @@ def create_empty_output_struc_and_layout(
|
||||
polymer_ligand_bonds: atom_layout.AtomLayout | None = None,
|
||||
ligand_ligand_bonds: atom_layout.AtomLayout | None = None,
|
||||
drop_ligand_leaving_atoms: bool = False,
|
||||
fix_standalone_glycans: bool = False,
|
||||
) -> tuple[structure.Structure, atom_layout.AtomLayout]:
|
||||
"""Make zero-coordinate structure from all physical residues.
|
||||
|
||||
@@ -252,6 +263,8 @@ def create_empty_output_struc_and_layout(
|
||||
polymer_ligand_bonds: Bond information for polymer-ligand pairs.
|
||||
ligand_ligand_bonds: Bond information for ligand-ligand pairs.
|
||||
drop_ligand_leaving_atoms: Flag for handling leaving atoms for ligands.
|
||||
fix_standalone_glycans: If True, the O1 leaving atom of standalone (unbonded)
|
||||
glycans is kept when drop_ligand_leaving_atoms is True.
|
||||
|
||||
Returns:
|
||||
Tuple of structure with all bonds, physical residues and coordinates set to
|
||||
@@ -292,6 +305,7 @@ def create_empty_output_struc_and_layout(
|
||||
polymer_ligand_bonds=polymer_ligand_bonds,
|
||||
ligand_ligand_bonds=ligand_ligand_bonds,
|
||||
drop_ligand_leaving_atoms=drop_ligand_leaving_atoms,
|
||||
fix_standalone_glycans=fix_standalone_glycans,
|
||||
)
|
||||
|
||||
empty_output_struc = atom_layout.make_structure(
|
||||
|
||||
Reference in New Issue
Block a user