From 7b816f4035a3db22da88d17d24540c5283316ab9 Mon Sep 17 00:00:00 2001 From: Augustin Zidek Date: Tue, 2 Sep 2025 04:11:09 -0700 Subject: [PATCH] Add support for protein/DNA/RNA/ligand descriptions Suggested in https://github.com/google-deepmind/alphafold3/issues/496. PiperOrigin-RevId: 802081348 Change-Id: I666466fd6a770b6f4a891ed33e6a26651d600c4a --- docs/input.md | 34 ++++++-- src/alphafold3/common/folding_input.py | 108 +++++++++++++++++++++---- 2 files changed, 120 insertions(+), 22 deletions(-) diff --git a/docs/input.md b/docs/input.md index 5b647d7..d5fed0d 100644 --- a/docs/input.md +++ b/docs/input.md @@ -117,7 +117,7 @@ The top-level structure of the input JSON is: "userCCD": "...", # Optional, mutually exclusive with userCCDPath. "userCCDPath": "...", # Optional, mutually exclusive with userCCD. "dialect": "alphafold3", # Required. - "version": 3 # Required. + "version": 4 # Required. } ``` @@ -166,6 +166,8 @@ The top-level `version` field (for the `alphafold3` dialect) can be either `1`, added fields `unpairedMsaPath`, `pairedMsaPath`, and `mmcifPath`. * `3`: added the option of specifying external user-provided CCD using newly added field `userCCDPath`. +* `4`: added the option of specifying textual `description` of protein chains, + RNA chains, DNA chains, or ligands. ## Sequences @@ -186,6 +188,7 @@ Specifies a single protein chain. {"ptmType": "HY3", "ptmPosition": 1}, {"ptmType": "P1L", "ptmPosition": 5} ], + "description": ..., # Optional. "unpairedMsa": ..., # Mutually exclusive with unpairedMsaPath. "unpairedMsaPath": ..., # Mutually exclusive with unpairedMsa. "pairedMsa": ..., # Mutually exclusive with pairedMsaPath. @@ -207,6 +210,9 @@ The fields specify the following: post-translational modifications. Each modification is specified using its CCD code and 1-based residue position. In the example above, we see that the first residue won't be a proline (`P`) but instead `HY3`. +* `description: str`: An optional textual description of this chain. This + field will is only used in the JSON format and serves as a comment + describing this chain. * `unpairedMsa: str`: An optional multiple sequence alignment for this chain. This is specified using the A3M format (equivalent to the FASTA format, but also allows gaps denoted by the hyphen `-` character). See more details @@ -239,6 +245,7 @@ Specifies a single RNA chain. {"modificationType": "2MG", "basePosition": 1}, {"modificationType": "5MC", "basePosition": 4} ], + "description": ..., # Optional. "unpairedMsa": ..., # Mutually exclusive with unpairedMsaPath. "unpairedMsaPath": ... # Mutually exclusive with unpairedMsa. } @@ -255,6 +262,9 @@ The fields specify the following: letters `A`, `C`, `G`, `U`. * `modifications: list[RnaModification]`: An optional list of modifications. Each modification is specified using its CCD code and 1-based base position. +* `description: str`: An optional textual description of this chain. This + field will is only used in the JSON format and serves as a comment + describing this chain. * `unpairedMsa: str`: An optional multiple sequence alignment for this chain. This is specified using the A3M format. See more details below. * `unpairedMsaPath: str`: An optional path to a file that contains the @@ -275,7 +285,8 @@ Specifies a single DNA chain. "modifications": [ {"modificationType": "6OG", "basePosition": 1}, {"modificationType": "6MA", "basePosition": 2} - ] + ], + "description": ... # Optional. } } ``` @@ -290,6 +301,9 @@ The fields specify the following: letters `A`, `C`, `G`, `T`. * `modifications: list[DnaModification]`: An optional list of modifications. Each modification is specified using its CCD code and 1-based base position. +* `description: str`: An optional textual description of this chain. This + field will is only used in the JSON format and serves as a comment + describing this chain. ### Ligands @@ -314,19 +328,22 @@ Specifies a single ligand. Ligands can be specified using 3 different formats: { "ligand": { "id": ["G", "H", "I"], - "ccdCodes": ["ATP"] + "ccdCodes": ["ATP"], + "description": ... # Optional. } }, { "ligand": { "id": "J", - "ccdCodes": ["LIG-1337"] + "ccdCodes": ["LIG-1337"], + "description": ... # Optional. } }, { "ligand": { "id": "K", - "smiles": "CC(=O)OC1C[NH+]2CCC1CC2" + "smiles": "CC(=O)OC1C[NH+]2CCC1CC2", + "description": ... # Optional. } } ``` @@ -342,6 +359,9 @@ The fields specify the following: [user-provided CCD](#user-provided-ccd). * `smiles: str`: An optional string defining the ligand using a SMILES string. The SMILES string must be correctly JSON-escaped. +* `description: str`: An optional textual description of this chain. This + field will is only used in the JSON format and serves as a comment + describing this ligand. Each ligand may be specified using CCD codes or SMILES but not both, i.e. for a given ligand, the `ccdCodes` and `smiles` fields are mutually exclusive. @@ -919,6 +939,7 @@ certain fields and the sequences are not biologically meaningful. {"ptmType": "HY3", "ptmPosition": 1}, {"ptmType": "P1L", "ptmPosition": 5} ], + "description": "10-residue protein with 2 modifications", "unpairedMsa": ..., "pairedMsa": "" } @@ -982,7 +1003,6 @@ certain fields and the sequences are not biologically meaningful. ], "userCCD": ..., "dialect": "alphafold3", - "version": 3 + "version": 4 } - ``` diff --git a/src/alphafold3/common/folding_input.py b/src/alphafold3/common/folding_input.py index ab8c06f..d44c48d 100644 --- a/src/alphafold3/common/folding_input.py +++ b/src/alphafold3/common/folding_input.py @@ -36,7 +36,7 @@ import zstandard as zstd BondAtomId: TypeAlias = tuple[str, int, str] JSON_DIALECT: Final[str] = 'alphafold3' -JSON_VERSIONS: Final[tuple[int, ...]] = (1, 2, 3) +JSON_VERSIONS: Final[tuple[int, ...]] = (1, 2, 3, 4) JSON_VERSION: Final[int] = JSON_VERSIONS[-1] ALPHAFOLDSERVER_JSON_DIALECT: Final[str] = 'alphafoldserver' @@ -127,6 +127,7 @@ class ProteinChain: '_id', '_sequence', '_ptms', + '_description', '_paired_msa', '_unpaired_msa', '_templates', @@ -138,6 +139,7 @@ class ProteinChain: id: str, # pylint: disable=redefined-builtin sequence: str, ptms: Sequence[tuple[str, int]], + description: str | None = None, paired_msa: str | None = None, unpaired_msa: str | None = None, templates: Sequence[Template] | None = None, @@ -149,6 +151,7 @@ class ProteinChain: sequence: The amino acid sequence of the chain. ptms: A list of tuples containing the post-translational modification type and the (1-based) residue index where the modification is applied. + description: An optional textual description of the protein chain. paired_msa: Paired A3M-formatted MSA for this chain. This MSA is not deduplicated and will be used to compute paired features. If None, this field is unset and must be filled in by the data pipeline before @@ -175,6 +178,7 @@ class ProteinChain: self._id = id self._sequence = sequence self._ptms = tuple(ptms) + self._description = description self._paired_msa = paired_msa self._unpaired_msa = unpaired_msa self._templates = tuple(templates) if templates is not None else None @@ -198,6 +202,10 @@ class ProteinChain: def ptms(self) -> Sequence[tuple[str, int]]: return self._ptms + @property + def description(self) -> str | None: + return self._description + @property def paired_msa(self) -> str | None: return self._paired_msa @@ -218,6 +226,7 @@ class ProteinChain: self._id == other._id and self._sequence == other._sequence and self._ptms == other._ptms + and self._description == other._description and self._paired_msa == other._paired_msa and self._unpaired_msa == other._unpaired_msa and self._templates == other._templates @@ -228,6 +237,7 @@ class ProteinChain: self._id, self._sequence, self._ptms, + self._description, self._paired_msa, self._unpaired_msa, self._templates, @@ -238,6 +248,7 @@ class ProteinChain: return hash(( self._sequence, self._ptms, + self._description, self._paired_msa, self._unpaired_msa, self._templates, @@ -298,6 +309,7 @@ class ProteinChain: 'id', 'sequence', 'modifications', + 'description', 'unpairedMsa', 'unpairedMsaPath', 'pairedMsa', @@ -368,6 +380,7 @@ class ProteinChain: id=seq_id or json_dict['id'], sequence=sequence, ptms=ptms, + description=json_dict.get('description', None), paired_msa=paired_msa, unpaired_msa=unpaired_msa, templates=templates, @@ -400,6 +413,8 @@ class ProteinChain: 'pairedMsa': self._paired_msa, 'templates': templates, } + if self._description is not None: + contents['description'] = self._description return {'protein': contents} def to_ccd_sequence(self) -> Sequence[str]: @@ -418,6 +433,7 @@ class ProteinChain: id=self.id, sequence=self._sequence, ptms=self._ptms, + description=self._description, unpaired_msa=self._unpaired_msa or '', paired_msa=self._paired_msa or '', templates=self._templates or [], @@ -427,7 +443,13 @@ class ProteinChain: class RnaChain: """RNA chain input.""" - __slots__ = ('_id', '_sequence', '_modifications', '_unpaired_msa') + __slots__ = ( + '_id', + '_sequence', + '_modifications', + '_description', + '_unpaired_msa', + ) def __init__( self, @@ -435,6 +457,7 @@ class RnaChain: id: str, # pylint: disable=redefined-builtin sequence: str, modifications: Sequence[tuple[str, int]], + description: str | None = None, unpaired_msa: str | None = None, ): """Initializes a single strand RNA chain input. @@ -444,6 +467,7 @@ class RnaChain: sequence: The RNA sequence of the chain. modifications: A list of tuples containing the modification type and the (1-based) residue index where the modification is applied. + description: An optional textual description of the RNA chain. unpaired_msa: Unpaired A3M-formatted MSA for this chain. This will be deduplicated and used to compute unpaired features. If None, this field is unset and must be filled in by the data pipeline before @@ -463,6 +487,7 @@ class RnaChain: self._sequence = sequence # Use hashable container for modifications. self._modifications = tuple(modifications) + self._description = description self._unpaired_msa = unpaired_msa @property @@ -484,6 +509,10 @@ class RnaChain: def modifications(self) -> Sequence[tuple[str, int]]: return self._modifications + @property + def description(self) -> str | None: + return self._description + @property def unpaired_msa(self) -> str | None: return self._unpaired_msa @@ -496,17 +525,27 @@ class RnaChain: self._id == other._id and self._sequence == other._sequence and self._modifications == other._modifications + and self._description == other._description and self._unpaired_msa == other._unpaired_msa ) def __hash__(self) -> int: - return hash( - (self._id, self._sequence, self._modifications, self._unpaired_msa) - ) + return hash(( + self._id, + self._sequence, + self._modifications, + self._description, + self._unpaired_msa, + )) def hash_without_id(self) -> int: """Returns a hash ignoring the ID - useful for deduplication.""" - return hash((self._sequence, self._modifications, self._unpaired_msa)) + return hash(( + self._sequence, + self._modifications, + self._description, + self._unpaired_msa, + )) @classmethod def from_alphafoldserver_dict( @@ -532,7 +571,14 @@ class RnaChain: json_dict = json_dict['rna'] _validate_keys( json_dict.keys(), - {'id', 'sequence', 'unpairedMsa', 'unpairedMsaPath', 'modifications'}, + { + 'id', + 'sequence', + 'modifications', + 'description', + 'unpairedMsa', + 'unpairedMsaPath', + }, ) sequence = json_dict['sequence'] modifications = [ @@ -559,6 +605,7 @@ class RnaChain: id=seq_id or json_dict['id'], sequence=sequence, modifications=modifications, + description=json_dict.get('description', None), unpaired_msa=unpaired_msa, ) @@ -575,6 +622,8 @@ class RnaChain: ], 'unpairedMsa': self._unpaired_msa, } + if self._description is not None: + contents['description'] = self._description return {'rna': contents} def to_ccd_sequence(self) -> Sequence[str]: @@ -600,7 +649,7 @@ class RnaChain: class DnaChain: """Single strand DNA chain input.""" - __slots__ = ('_id', '_sequence', '_modifications') + __slots__ = ('_id', '_sequence', '_modifications', '_description') def __init__( self, @@ -608,6 +657,7 @@ class DnaChain: id: str, # pylint: disable=redefined-builtin sequence: str, modifications: Sequence[tuple[str, int]], + description: str | None = None, ): """Initializes a single strand DNA chain input. @@ -616,6 +666,7 @@ class DnaChain: sequence: The DNA sequence of the chain. modifications: A list of tuples containing the modification type and the (1-based) residue index where the modification is applied. + description: An optional textual description of the DNA chain. """ if not all(res.isalpha() for res in sequence): raise ValueError(f'DNA must contain only letters, got "{sequence}"') @@ -630,6 +681,7 @@ class DnaChain: self._sequence = sequence # Use hashable container for modifications. self._modifications = tuple(modifications) + self._description = description @property def id(self) -> str: @@ -646,6 +698,10 @@ class DnaChain: for r in self.to_ccd_sequence() ]) + @property + def description(self) -> str | None: + return self._description + def __len__(self) -> int: return len(self._sequence) @@ -654,17 +710,20 @@ class DnaChain: self._id == other._id and self._sequence == other._sequence and self._modifications == other._modifications + and self._description == other._description ) def __hash__(self) -> int: - return hash((self._id, self._sequence, self._modifications)) + return hash( + (self._id, self._sequence, self._modifications, self._description) + ) def modifications(self) -> Sequence[tuple[str, int]]: return self._modifications def hash_without_id(self) -> int: """Returns a hash ignoring the ID - useful for deduplication.""" - return hash((self._sequence, self._modifications)) + return hash((self._sequence, self._modifications, self._description)) @classmethod def from_alphafoldserver_dict( @@ -685,7 +744,9 @@ class DnaChain: ) -> Self: """Constructs DnaChain from the AlphaFold JSON dict.""" json_dict = json_dict['dna'] - _validate_keys(json_dict.keys(), {'id', 'sequence', 'modifications'}) + _validate_keys( + json_dict.keys(), {'id', 'sequence', 'modifications', 'description'} + ) sequence = json_dict['sequence'] modifications = [ (mod['modificationType'], mod['basePosition']) @@ -695,6 +756,7 @@ class DnaChain: id=seq_id or json_dict['id'], sequence=sequence, modifications=modifications, + description=json_dict.get('description', None), ) def to_dict( @@ -709,6 +771,8 @@ class DnaChain: for mod in self._modifications ], } + if self._description is not None: + contents['description'] = self._description return {'dna': contents} def to_ccd_sequence(self) -> Sequence[str]: @@ -734,11 +798,13 @@ class Ligand: a bond linking these components should be added to the bonded_atom_pairs Input field. smiles: The SMILES representation of the ligand. + description: An optional textual description of the ligand. """ id: str ccd_ids: Sequence[str] | None = None smiles: str | None = None + description: str | None = None def __post_init__(self): if (self.ccd_ids is None) == (self.smiles is None): @@ -761,7 +827,7 @@ class Ligand: def hash_without_id(self) -> int: """Returns a hash ignoring the ID - useful for deduplication.""" - return hash((self.ccd_ids, self.smiles)) + return hash((self.ccd_ids, self.smiles, self.description)) @classmethod def from_alphafoldserver_dict( @@ -783,7 +849,9 @@ class Ligand: ) -> Self: """Constructs Ligand from the AlphaFold JSON dict.""" json_dict = json_dict['ligand'] - _validate_keys(json_dict.keys(), {'id', 'ccdCodes', 'smiles'}) + _validate_keys( + json_dict.keys(), {'id', 'ccdCodes', 'smiles', 'description'} + ) if json_dict.get('ccdCodes') and json_dict.get('smiles'): raise ValueError( 'Ligand cannot have both CCD code and SMILES set at the same time, ' @@ -797,9 +865,17 @@ class Ligand: 'CCD codes must be a list of strings, got ' f'{type(ccd_codes).__name__} instead: {ccd_codes}' ) - return cls(id=seq_id or json_dict['id'], ccd_ids=ccd_codes) + return cls( + id=seq_id or json_dict['id'], + ccd_ids=ccd_codes, + description=json_dict.get('description', None), + ) elif 'smiles' in json_dict: - return cls(id=seq_id or json_dict['id'], smiles=json_dict['smiles']) + return cls( + id=seq_id or json_dict['id'], + smiles=json_dict['smiles'], + description=json_dict.get('description', None), + ) else: raise ValueError(f'Unknown ligand type: {json_dict}') @@ -812,6 +888,8 @@ class Ligand: contents['ccdCodes'] = self.ccd_ids if self.smiles is not None: contents['smiles'] = self.smiles + if self.description is not None: + contents['description'] = self.description return {'ligand': contents}