Fix glycan O1 leaving-atom cleanup for standalone glycans

Inspired by https://github.com/google-deepmind/alphafold3/issues/650 and
https://github.com/google-deepmind/alphafold3/pull/633.

PiperOrigin-RevId: 911350518
Change-Id: Ia68e36e3580fb6092fdaeb62352df7e06991b58c
This commit is contained in:
Augustin Zidek
2026-05-06 08:23:42 -07:00
committed by Copybara-Service
parent 97639fff6f
commit 62a93afa5d
5 changed files with 83 additions and 17 deletions

View File

@@ -53,7 +53,6 @@ from jax import numpy as jnp
import numpy as np
import tokamax
_HOME_DIR = pathlib.Path.home()
_DEFAULT_MODEL_DIR = _HOME_DIR / 'models'
_DEFAULT_DB_DIR = _HOME_DIR / 'public_databases'
@@ -127,7 +126,6 @@ DB_DIR = flags.DEFINE_multi_string(
'Path to the directory containing the databases. Can be specified multiple'
' times to search multiple directories in order.',
)
_SMALL_BFD_DATABASE_PATH = flags.DEFINE_string(
'small_bfd_database_path',
'${DB_DIR}/bfd-first_non_consensus_sequences.fasta',
@@ -285,6 +283,14 @@ _CONFORMER_MAX_ITERATIONS = flags.DEFINE_integer(
'conformer search.',
lower_bound=0,
)
_FIX_STANDALONE_GLYCANS = flags.DEFINE_bool(
'fix_standalone_glycans',
False,
'AlphaFold 3 model training and evaluation filtered out leaving atoms from'
' glycan ligands even if they were not bonded to anything ("standalone"'
' glycans). Setting this flag to True fixes this undesirable behavior, but'
' moves away from the regime where AlphaFold 3 was trained and evaluated.',
)
# JAX inference performance tuning.
_JAX_COMPILATION_CACHE_DIR = flags.DEFINE_string(
@@ -513,10 +519,12 @@ class ResultsForSeed:
def predict_structure(
fold_input: folding_input.Input,
model_runner: ModelRunner,
*,
buckets: Sequence[int] | None = None,
ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
fix_standalone_glycans: bool = False,
) -> Sequence[ResultsForSeed]:
"""Runs the full inference pipeline to predict structures for each seed."""
@@ -531,6 +539,7 @@ def predict_structure(
ref_max_modified_date=ref_max_modified_date,
conformer_max_iterations=conformer_max_iterations,
resolve_msa_overlaps=resolve_msa_overlaps,
fix_standalone_glycans=fix_standalone_glycans,
)
print(
f'Featurising data with {len(fold_input.rng_seeds)} seed(s) took'
@@ -731,6 +740,7 @@ def process_fold_input(
ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
fix_standalone_glycans: bool = False,
force_output_dir: bool = False,
compress_large_output_files: bool = False,
) -> folding_input.Input | Sequence[ResultsForSeed]:
@@ -759,6 +769,11 @@ def process_fold_input(
paper. Set this to false if providing custom paired MSA using the unpaired
MSA field to keep it exactly as is as deduplication against the paired MSA
could break the manually crafted pairing between MSA sequences.
fix_standalone_glycans: If True, the O1 leaving atom of standalone glycans
(glycans that are not bonded to anything) is kept when
filter_leaving_atoms is True. This is False by default to match the
AlphaFold 3 paper. Note that the model has been trained with the default
setting, so setting this to True may cause non-standard behavior of the
model.
force_output_dir: If True, do not create a new output directory even if the
existing one is non-empty. Instead use the existing output directory and
potentially overwrite existing files. If False, create a new timestamped
@@ -815,6 +830,7 @@ def process_fold_input(
ref_max_modified_date=ref_max_modified_date,
conformer_max_iterations=conformer_max_iterations,
resolve_msa_overlaps=resolve_msa_overlaps,
fix_standalone_glycans=fix_standalone_glycans,
)
print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...')
write_outputs(
@@ -985,6 +1001,7 @@ def main(_):
ref_max_modified_date=max_template_date,
conformer_max_iterations=_CONFORMER_MAX_ITERATIONS.value,
resolve_msa_overlaps=_RESOLVE_MSA_OVERLAPS.value,
fix_standalone_glycans=_FIX_STANDALONE_GLYCANS.value,
force_output_dir=_FORCE_OUTPUT_DIR.value,
compress_large_output_files=_COMPRESS_LARGE_OUTPUT_FILES.value,
)

View File

@@ -42,6 +42,7 @@ def featurise_input(
ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None,
resolve_msa_overlaps: bool = True,
fix_standalone_glycans: bool = False,
verbose: bool = False,
) -> Sequence[features.BatchDict]:
"""Featurise the folding input.
@@ -66,6 +67,12 @@ def featurise_input(
paper. Set this to false if providing custom paired MSA using the unpaired
MSA field to keep it exactly as is as deduplication against the paired MSA
could break the manually crafted pairing between MSA sequences.
fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered
out leaving atoms from glycan ligands even if they were not bonded to
anything ("standalone" glycans). Setting this flag to True fixes this
undesirable behavior, but moves away from the regime where AlphaFold 3 was
trained and evaluated. This has only an effect if filter_leaving_atoms is
True in the WholePdbPipeline.Config.
verbose: Whether to print progress messages.
Returns:
@@ -80,6 +87,7 @@ def featurise_input(
ref_max_modified_date=ref_max_modified_date,
conformer_max_iterations=conformer_max_iterations,
resolve_msa_overlaps=resolve_msa_overlaps,
fix_standalone_glycans=fix_standalone_glycans,
),
)

View File

@@ -28,7 +28,6 @@ import jax.numpy as jnp
import numpy as np
from rdkit import Chem
xnp_ndarray: TypeAlias = np.ndarray | jnp.ndarray # pylint: disable=invalid-name
NumpyIndex: TypeAlias = Any
@@ -641,6 +640,7 @@ def get_link_drop_atoms(
is_end_terminus: bool,
bonded_atoms: set[str],
drop_ligand_leaving_atoms: bool = False,
fix_standalone_glycans: bool = False,
) -> set[str]:
"""Returns set of atoms that are dropped when this res_name gets linked.
@@ -651,6 +651,12 @@ def get_link_drop_atoms(
is_end_terminus: whether the residue is the c-terminus
bonded_atoms: Names of atoms coming off this residue.
drop_ligand_leaving_atoms: Flag to switch on/off leaving atoms for ligands.
fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered
out leaving atoms from glycan ligands even if they were not bonded to
anything ("standalone" glycans). Setting this flag to True fixes this
undesirable behavior, but moves away from the regime where AlphaFold 3 was
trained and evaluated. This has only an effect if
drop_ligand_leaving_atoms is True.
Returns:
Set of atoms that are dropped when this amino acid gets linked.
@@ -677,8 +683,12 @@ def get_link_drop_atoms(
*chemical_component_sets.GLYCAN_OTHER_LIGANDS,
*chemical_component_sets.GLYCAN_LINKING_LIGANDS,
}:
if 'O1' not in bonded_atoms:
drop_atoms.update({'O1'})
if fix_standalone_glycans:
if bonded_atoms and 'O1' not in bonded_atoms:
drop_atoms.update({'O1'})
else:
if 'O1' not in bonded_atoms:
drop_atoms.update({'O1'})
return drop_atoms
@@ -743,6 +753,7 @@ def make_flat_atom_layout(
with_hydrogens: bool = False,
skip_unk_residues: bool = True,
drop_ligand_leaving_atoms: bool = False,
fix_standalone_glycans: bool = False,
) -> AtomLayout:
"""Make a flat atom layout for given residues.
@@ -761,6 +772,12 @@ def make_flat_atom_layout(
compatible with the rest of AlphaFold that does not predict atoms for
unknown residues
drop_ligand_leaving_atoms: Flag to switch on/off leaving atoms for ligands.
fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered
out leaving atoms from glycan ligands even if they were not bonded to
anything ("standalone" glycans). Setting this flag to True fixes this
undesirable behavior, but moves away from the regime where AlphaFold 3 was
trained and evaluated. This has only an effect if
drop_ligand_leaving_atoms is True.
Returns:
an `AtomLayout` object
@@ -834,6 +851,7 @@ def make_flat_atom_layout(
is_end_terminus=residues.is_end_terminus[idx],
bonded_atoms=bonded_atoms,
drop_ligand_leaving_atoms=drop_ligand_leaving_atoms,
fix_standalone_glycans=fix_standalone_glycans,
)
# If deprotonation info is available, remove the specific atoms.

View File

@@ -121,6 +121,12 @@ class WholePdbPipeline:
unpaired MSA field to keep it exactly as is as deduplication against
the paired MSA could break the manually crafted pairing between MSA
sequences.
fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered
out leaving atoms from glycan ligands even if they were not bonded to
anything ("standalone" glycans). Setting this flag to True fixes this
undesirable behavior, but moves away from the regime where AlphaFold 3
was trained and evaluated. This has only an effect if
drop_ligand_leaving_atoms is True.
"""
max_atoms_per_token: int = 24
@@ -141,6 +147,7 @@ class WholePdbPipeline:
deterministic_frames: bool = True
conformer_max_iterations: int | None = None
resolve_msa_overlaps: bool = True
fix_standalone_glycans: bool = False
def __init__(self, *, config: Config):
"""Initializes WholePdb data pipeline.
@@ -181,6 +188,7 @@ class WholePdbPipeline:
covalent_bonds_only=True,
remove_polymer_polymer_bonds=True,
remove_bad_bonds=True,
fix_standalone_glycans=self._config.fix_standalone_glycans,
)
# No chains after cleaning.
@@ -209,6 +217,7 @@ class WholePdbPipeline:
polymer_ligand_bonds=polymer_ligand_bonds,
ligand_ligand_bonds=ligand_ligand_bonds,
drop_ligand_leaving_atoms=self._config.drop_ligand_leaving_atoms,
fix_standalone_glycans=self._config.fix_standalone_glycans,
)
)

View File

@@ -28,6 +28,7 @@ def _get_leaving_atom_mask(
chain_type: str,
res_id: int,
res_name: str,
fix_standalone_glycans: bool,
) -> np.ndarray:
"""Updates a drop_leaving_atoms mask with new leaving atom locations."""
bonded_atoms = atom_layout.get_bonded_atoms(
@@ -44,6 +45,7 @@ def _get_leaving_atom_mask(
is_end_terminus=False,
bonded_atoms=bonded_atoms,
drop_ligand_leaving_atoms=True,
fix_standalone_glycans=fix_standalone_glycans,
)
# Default mask where everything is false, which equates to being kept.
drop_atom_filter_atoms = struc.chain_id != struc.chain_id
@@ -74,6 +76,7 @@ def clean_structure(
covalent_bonds_only: bool,
remove_polymer_polymer_bonds: bool,
remove_bad_bonds: bool,
fix_standalone_glycans: bool,
) -> structure.Structure:
"""Returns a cleaned version of the input structure.
@@ -90,6 +93,12 @@ def clean_structure(
covalent_bonds_only: Only include covalent bonds.
remove_polymer_polymer_bonds: Remove polymer-polymer bonds.
remove_bad_bonds: Whether to remove badly bonded ligands.
fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered
out leaving atoms from glycan ligands even if they were not bonded to
anything ("standalone" glycans). Setting this flag to True fixes this
undesirable behavior, but moves away from the regime where AlphaFold 3 was
trained and evaluated. This has only an effect if filter_leaving_atoms is
True.
"""
# Drop chains without specified sequences.
@@ -128,8 +137,10 @@ def clean_structure(
*chemical_component_sets.GLYCAN_LINKING_LIGANDS,
}
# If only glycan ligands and no O1 atoms, we can do parallel drop.
# If we need to keep O1 for standalone glycans, iterate over each residue.
if (
only_glycan_ligands_for_leaving_atoms
not fix_standalone_glycans
and only_glycan_ligands_for_leaving_atoms
and (not (ligand_ligand_bonds.atom_name == 'O1').any())
and (not (polymer_ligand_bonds.atom_name == 'O1').any())
):
@@ -153,23 +164,22 @@ def clean_structure(
chain_type=res['chain_type'],
res_id=res['res_id'],
res_name=res_name,
fix_standalone_glycans=fix_standalone_glycans,
)
drop_leaving_atoms_all = np.logical_or(
drop_leaving_atoms_all, drop_atom_filter
)
num_atoms_before = struc.num_atoms
struc = struc.filter_out(drop_leaving_atoms_all)
num_atoms_after = struc.num_atoms
if num_atoms_before > num_atoms_after:
logging.error(
'Dropped %s atoms from GT struc: chain_id %s res_id %s res_name %s',
num_atoms_before - num_atoms_after,
struc.chain_id,
struc.res_id,
struc.res_name,
if np.any(drop_leaving_atoms_all):
logging.info(
'Dropped %d atoms: chain_id %s, res_id %s, res_name %s, atom_name %s',
struc.num_atoms - np.sum(~drop_leaving_atoms_all),
struc.chain_id[drop_leaving_atoms_all],
struc.res_id[drop_leaving_atoms_all],
struc.res_name[drop_leaving_atoms_all],
struc.atom_name[drop_leaving_atoms_all],
)
struc = struc.filter_out(drop_leaving_atoms_all)
# Can filter by bond type without having to iterate over bonds.
if struc.bonds and covalent_bonds_only:
@@ -241,6 +251,7 @@ def create_empty_output_struc_and_layout(
polymer_ligand_bonds: atom_layout.AtomLayout | None = None,
ligand_ligand_bonds: atom_layout.AtomLayout | None = None,
drop_ligand_leaving_atoms: bool = False,
fix_standalone_glycans: bool = False,
) -> tuple[structure.Structure, atom_layout.AtomLayout]:
"""Make zero-coordinate structure from all physical residues.
@@ -252,6 +263,8 @@ def create_empty_output_struc_and_layout(
polymer_ligand_bonds: Bond information for polymer-ligand pairs.
ligand_ligand_bonds: Bond information for ligand-ligand pairs.
drop_ligand_leaving_atoms: Flag for handling leaving atoms for ligands.
fix_standalone_glycans: If True, the O1 leaving atom of standalone glycans
(glycans that are not bonded to anything) is kept when
drop_ligand_leaving_atoms is True.
Returns:
Tuple of structure with all bonds, physical residues and coordinates set to
@@ -292,6 +305,7 @@ def create_empty_output_struc_and_layout(
polymer_ligand_bonds=polymer_ligand_bonds,
ligand_ligand_bonds=ligand_ligand_bonds,
drop_ligand_leaving_atoms=drop_ligand_leaving_atoms,
fix_standalone_glycans=fix_standalone_glycans,
)
empty_output_struc = atom_layout.make_structure(