diff --git a/run_alphafold.py b/run_alphafold.py index 6bc53dc..4b8dd70 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -53,7 +53,6 @@ from jax import numpy as jnp import numpy as np import tokamax - _HOME_DIR = pathlib.Path.home() _DEFAULT_MODEL_DIR = _HOME_DIR / 'models' _DEFAULT_DB_DIR = _HOME_DIR / 'public_databases' @@ -127,7 +126,6 @@ DB_DIR = flags.DEFINE_multi_string( 'Path to the directory containing the databases. Can be specified multiple' ' times to search multiple directories in order.', ) - _SMALL_BFD_DATABASE_PATH = flags.DEFINE_string( 'small_bfd_database_path', '${DB_DIR}/bfd-first_non_consensus_sequences.fasta', @@ -285,6 +283,14 @@ _CONFORMER_MAX_ITERATIONS = flags.DEFINE_integer( 'conformer search.', lower_bound=0, ) +_FIX_STANDALONE_GLYCANS = flags.DEFINE_bool( + 'fix_standalone_glycans', + False, + 'AlphaFold 3 model training and evaluation filtered out leaving atoms from' + ' glycan ligands even if they were not bonded to anything ("standalone"' + ' glycans). Setting this flag to True fixes this undesirable behavior, but' + ' moves away from the regime where AlphaFold 3 was trained and evaluated.', +) # JAX inference performance tuning. _JAX_COMPILATION_CACHE_DIR = flags.DEFINE_string( @@ -513,10 +519,12 @@ class ResultsForSeed: def predict_structure( fold_input: folding_input.Input, model_runner: ModelRunner, + *, buckets: Sequence[int] | None = None, ref_max_modified_date: datetime.date | None = None, conformer_max_iterations: int | None = None, resolve_msa_overlaps: bool = True, + fix_standalone_glycans: bool = False, ) -> Sequence[ResultsForSeed]: """Runs the full inference pipeline to predict structures for each seed.""" @@ -531,6 +539,7 @@ def predict_structure( ref_max_modified_date=ref_max_modified_date, conformer_max_iterations=conformer_max_iterations, resolve_msa_overlaps=resolve_msa_overlaps, + fix_standalone_glycans=fix_standalone_glycans, ) print( f'Featurising data with {len(fold_input.rng_seeds)} seed(s) took' @@ -731,6 +740,7 @@ def process_fold_input( ref_max_modified_date: datetime.date | None = None, conformer_max_iterations: int | None = None, resolve_msa_overlaps: bool = True, + fix_standalone_glycans: bool = False, force_output_dir: bool = False, compress_large_output_files: bool = False, ) -> folding_input.Input | Sequence[ResultsForSeed]: @@ -759,6 +769,11 @@ def process_fold_input( paper. Set this to false if providing custom paired MSA using the unpaired MSA field to keep it exactly as is as deduplication against the paired MSA could break the manually crafted pairing between MSA sequences. + fix_standalone_glycans: If True, standalone glycans are preserved when + filter_leaving_atoms is True. This is False by default to match the + AlphaFold 3 paper. Note that the model has been trained with the default + setting, so setting this to True may cause non-standard behaviour of the + model. force_output_dir: If True, do not create a new output directory even if the existing one is non-empty. Instead use the existing output directory and potentially overwrite existing files. If False, create a new timestamped @@ -815,6 +830,7 @@ def process_fold_input( ref_max_modified_date=ref_max_modified_date, conformer_max_iterations=conformer_max_iterations, resolve_msa_overlaps=resolve_msa_overlaps, + fix_standalone_glycans=fix_standalone_glycans, ) print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...') write_outputs( @@ -985,6 +1001,7 @@ def main(_): ref_max_modified_date=max_template_date, conformer_max_iterations=_CONFORMER_MAX_ITERATIONS.value, resolve_msa_overlaps=_RESOLVE_MSA_OVERLAPS.value, + fix_standalone_glycans=_FIX_STANDALONE_GLYCANS.value, force_output_dir=_FORCE_OUTPUT_DIR.value, compress_large_output_files=_COMPRESS_LARGE_OUTPUT_FILES.value, ) diff --git a/src/alphafold3/data/featurisation.py b/src/alphafold3/data/featurisation.py index 98e8591..abbf970 100644 --- a/src/alphafold3/data/featurisation.py +++ b/src/alphafold3/data/featurisation.py @@ -42,6 +42,7 @@ def featurise_input( ref_max_modified_date: datetime.date | None = None, conformer_max_iterations: int | None = None, resolve_msa_overlaps: bool = True, + fix_standalone_glycans: bool = False, verbose: bool = False, ) -> Sequence[features.BatchDict]: """Featurise the folding input. @@ -66,6 +67,12 @@ def featurise_input( paper. Set this to false if providing custom paired MSA using the unpaired MSA field to keep it exactly as is as deduplication against the paired MSA could break the manually crafted pairing between MSA sequences. + fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered + out leaving atoms from glycan ligands even if they were not bonded to + anything ("standalone" glycans). Setting this flag to True fixes this + undesirable behavior, but moves away from the regime where AlphaFold 3 was + trained and evaluated. This has only an effect if filter_leaving_atoms is + True in the WholePdbPipeline.Config. verbose: Whether to print progress messages. Returns: @@ -80,6 +87,7 @@ def featurise_input( ref_max_modified_date=ref_max_modified_date, conformer_max_iterations=conformer_max_iterations, resolve_msa_overlaps=resolve_msa_overlaps, + fix_standalone_glycans=fix_standalone_glycans, ), ) diff --git a/src/alphafold3/model/atom_layout/atom_layout.py b/src/alphafold3/model/atom_layout/atom_layout.py index 541fe7f..2a1d667 100644 --- a/src/alphafold3/model/atom_layout/atom_layout.py +++ b/src/alphafold3/model/atom_layout/atom_layout.py @@ -28,7 +28,6 @@ import jax.numpy as jnp import numpy as np from rdkit import Chem - xnp_ndarray: TypeAlias = np.ndarray | jnp.ndarray # pylint: disable=invalid-name NumpyIndex: TypeAlias = Any @@ -641,6 +640,7 @@ def get_link_drop_atoms( is_end_terminus: bool, bonded_atoms: set[str], drop_ligand_leaving_atoms: bool = False, + fix_standalone_glycans: bool = False, ) -> set[str]: """Returns set of atoms that are dropped when this res_name gets linked. @@ -651,6 +651,12 @@ def get_link_drop_atoms( is_end_terminus: whether the residue is the c-terminus bonded_atoms: Names of atoms coming off this residue. drop_ligand_leaving_atoms: Flag to switch on/off leaving atoms for ligands. + fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered + out leaving atoms from glycan ligands even if they were not bonded to + anything ("standalone" glycans). Setting this flag to True fixes this + undesirable behavior, but moves away from the regime where AlphaFold 3 was + trained and evaluated. This has only an effect if + drop_ligand_leaving_atoms is True. Returns: Set of atoms that are dropped when this amino acid gets linked. @@ -677,8 +683,12 @@ def get_link_drop_atoms( *chemical_component_sets.GLYCAN_OTHER_LIGANDS, *chemical_component_sets.GLYCAN_LINKING_LIGANDS, }: - if 'O1' not in bonded_atoms: - drop_atoms.update({'O1'}) + if fix_standalone_glycans: + if bonded_atoms and 'O1' not in bonded_atoms: + drop_atoms.update({'O1'}) + else: + if 'O1' not in bonded_atoms: + drop_atoms.update({'O1'}) return drop_atoms @@ -743,6 +753,7 @@ def make_flat_atom_layout( with_hydrogens: bool = False, skip_unk_residues: bool = True, drop_ligand_leaving_atoms: bool = False, + fix_standalone_glycans: bool = False, ) -> AtomLayout: """Make a flat atom layout for given residues. @@ -761,6 +772,12 @@ def make_flat_atom_layout( compatible with the rest of AlphaFold that does not predict atoms for unknown residues drop_ligand_leaving_atoms: Flag to switch on/ off leaving atoms for ligands. + fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered + out leaving atoms from glycan ligands even if they were not bonded to + anything ("standalone" glycans). Setting this flag to True fixes this + undesirable behavior, but moves away from the regime where AlphaFold 3 was + trained and evaluated. This has only an effect if + drop_ligand_leaving_atoms is True. Returns: an `AtomLayout` object @@ -834,6 +851,7 @@ def make_flat_atom_layout( is_end_terminus=residues.is_end_terminus[idx], bonded_atoms=bonded_atoms, drop_ligand_leaving_atoms=drop_ligand_leaving_atoms, + fix_standalone_glycans=fix_standalone_glycans, ) # If deprotonation info is available, remove the specific atoms. diff --git a/src/alphafold3/model/pipeline/pipeline.py b/src/alphafold3/model/pipeline/pipeline.py index a938ebb..e6f8a53 100644 --- a/src/alphafold3/model/pipeline/pipeline.py +++ b/src/alphafold3/model/pipeline/pipeline.py @@ -121,6 +121,12 @@ class WholePdbPipeline: unpaired MSA field to keep it exactly as is as deduplication against the paired MSA could break the manually crafted pairing between MSA sequences. + fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered + out leaving atoms from glycan ligands even if they were not bonded to + anything ("standalone" glycans). Setting this flag to True fixes this + undesirable behavior, but moves away from the regime where AlphaFold 3 + was trained and evaluated. This has only an effect if + drop_ligand_leaving_atoms is True. """ max_atoms_per_token: int = 24 @@ -141,6 +147,7 @@ class WholePdbPipeline: deterministic_frames: bool = True conformer_max_iterations: int | None = None resolve_msa_overlaps: bool = True + fix_standalone_glycans: bool = False def __init__(self, *, config: Config): """Initializes WholePdb data pipeline. @@ -181,6 +188,7 @@ class WholePdbPipeline: covalent_bonds_only=True, remove_polymer_polymer_bonds=True, remove_bad_bonds=True, + fix_standalone_glycans=self._config.fix_standalone_glycans, ) # No chains after cleaning. @@ -209,6 +217,7 @@ class WholePdbPipeline: polymer_ligand_bonds=polymer_ligand_bonds, ligand_ligand_bonds=ligand_ligand_bonds, drop_ligand_leaving_atoms=self._config.drop_ligand_leaving_atoms, + fix_standalone_glycans=self._config.fix_standalone_glycans, ) ) diff --git a/src/alphafold3/model/pipeline/structure_cleaning.py b/src/alphafold3/model/pipeline/structure_cleaning.py index 531f635..545ea69 100644 --- a/src/alphafold3/model/pipeline/structure_cleaning.py +++ b/src/alphafold3/model/pipeline/structure_cleaning.py @@ -28,6 +28,7 @@ def _get_leaving_atom_mask( chain_type: str, res_id: int, res_name: str, + fix_standalone_glycans: bool, ) -> np.ndarray: """Updates a drop_leaving_atoms mask with new leaving atom locations.""" bonded_atoms = atom_layout.get_bonded_atoms( @@ -44,6 +45,7 @@ def _get_leaving_atom_mask( is_end_terminus=False, bonded_atoms=bonded_atoms, drop_ligand_leaving_atoms=True, + fix_standalone_glycans=fix_standalone_glycans, ) # Default mask where everything is false, which equates to being kept. drop_atom_filter_atoms = struc.chain_id != struc.chain_id @@ -74,6 +76,7 @@ def clean_structure( covalent_bonds_only: bool, remove_polymer_polymer_bonds: bool, remove_bad_bonds: bool, + fix_standalone_glycans: bool, ) -> structure.Structure: """Returns a cleaned version of the input structure. @@ -90,6 +93,12 @@ def clean_structure( covalent_bonds_only: Only include covalent bonds. remove_polymer_polymer_bonds: Remove polymer-polymer bonds. remove_bad_bonds: Whether to remove badly bonded ligands. + fix_standalone_glycans: AlphaFold 3 model training and evaluation filtered + out leaving atoms from glycan ligands even if they were not bonded to + anything ("standalone" glycans). Setting this flag to True fixes this + undesirable behavior, but moves away from the regime where AlphaFold 3 was + trained and evaluated. This has only an effect if filter_leaving_atoms is + True. """ # Drop chains without specified sequences. @@ -128,8 +137,10 @@ def clean_structure( *chemical_component_sets.GLYCAN_LINKING_LIGANDS, } # If only glycan ligands and no O1 atoms, we can do parallel drop. + # If we need to keep O1 for standalone glycans, iterate over each residue. if ( - only_glycan_ligands_for_leaving_atoms + not fix_standalone_glycans + and only_glycan_ligands_for_leaving_atoms and (not (ligand_ligand_bonds.atom_name == 'O1').any()) and (not (polymer_ligand_bonds.atom_name == 'O1').any()) ): @@ -153,23 +164,22 @@ def clean_structure( chain_type=res['chain_type'], res_id=res['res_id'], res_name=res_name, + fix_standalone_glycans=fix_standalone_glycans, ) drop_leaving_atoms_all = np.logical_or( drop_leaving_atoms_all, drop_atom_filter ) - num_atoms_before = struc.num_atoms - struc = struc.filter_out(drop_leaving_atoms_all) - num_atoms_after = struc.num_atoms - - if num_atoms_before > num_atoms_after: - logging.error( - 'Dropped %s atoms from GT struc: chain_id %s res_id %s res_name %s', - num_atoms_before - num_atoms_after, - struc.chain_id, - struc.res_id, - struc.res_name, + if np.any(drop_leaving_atoms_all): + logging.info( + 'Dropped %d atoms: chain_id %s, res_id %s, res_name %s, atom_name %s', + struc.num_atoms - np.sum(~drop_leaving_atoms_all), + struc.chain_id[drop_leaving_atoms_all], + struc.res_id[drop_leaving_atoms_all], + struc.res_name[drop_leaving_atoms_all], + struc.atom_name[drop_leaving_atoms_all], ) + struc = struc.filter_out(drop_leaving_atoms_all) # Can filter by bond type without having to iterate over bonds. if struc.bonds and covalent_bonds_only: @@ -241,6 +251,7 @@ def create_empty_output_struc_and_layout( polymer_ligand_bonds: atom_layout.AtomLayout | None = None, ligand_ligand_bonds: atom_layout.AtomLayout | None = None, drop_ligand_leaving_atoms: bool = False, + fix_standalone_glycans: bool = False, ) -> tuple[structure.Structure, atom_layout.AtomLayout]: """Make zero-coordinate structure from all physical residues. @@ -252,6 +263,8 @@ def create_empty_output_struc_and_layout( polymer_ligand_bonds: Bond information for polymer-ligand pairs. ligand_ligand_bonds: Bond information for ligand-ligand pairs. drop_ligand_leaving_atoms: Flag for handling leaving atoms for ligands. + fix_standalone_glycans: If True, standalone glycans are preserved when + drop_ligand_leaving_atoms is True. Returns: Tuple of structure with all bonds, physical residues and coordinates set to @@ -292,6 +305,7 @@ def create_empty_output_struc_and_layout( polymer_ligand_bonds=polymer_ligand_bonds, ligand_ligand_bonds=ligand_ligand_bonds, drop_ligand_leaving_atoms=drop_ligand_leaving_atoms, + fix_standalone_glycans=fix_standalone_glycans, ) empty_output_struc = atom_layout.make_structure(