mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
feat: atom23 inference changes and training fixes
This commit is contained in:
committed by
Raktim Mitra
parent
4a7aaf8793
commit
ebec466e4f
@@ -9,6 +9,6 @@ weights:
|
||||
beta: 0.5
|
||||
alphas:
|
||||
a_prot: 3.0 # 3 for AF-3
|
||||
a_nuc: 0.0 # 3 for AF-3
|
||||
a_nuc: 3.0 # 3 for AF-3
|
||||
a_ligand: 1.0 # 1 for AF-3
|
||||
a_loi: 5.0 # 5 for AF-3
|
||||
|
||||
82
models/rfd3/configs/experiment/rfd3na.yaml
Normal file
82
models/rfd3/configs/experiment/rfd3na.yaml
Normal file
@@ -0,0 +1,82 @@
|
||||
# @package _global_
|
||||
# Training configuration for RFD3
|
||||
|
||||
defaults:
|
||||
- /debug/default
|
||||
- override /model: rfd3_base
|
||||
#- override /datasets: all
|
||||
- override /logger: csv
|
||||
#- override /logger: wandb
|
||||
- _self_
|
||||
|
||||
name: train-base
|
||||
tags: [print-model]
|
||||
ckpt_path: null
|
||||
|
||||
model:
|
||||
net:
|
||||
token_initializer:
|
||||
token_1d_features:
|
||||
ref_motif_token_type: 3
|
||||
restype: 32
|
||||
is_dna_token: 1
|
||||
is_rna_token: 1
|
||||
is_protein_token: 1
|
||||
atom_1d_features:
|
||||
ref_atom_name_chars: 256
|
||||
ref_element: 128
|
||||
ref_charge: 1
|
||||
ref_mask: 1
|
||||
ref_is_motif_atom_with_fixed_coord: 1
|
||||
ref_is_motif_atom_unindexed: 1
|
||||
has_zero_occupancy: 1
|
||||
ref_pos: 3
|
||||
|
||||
# Guided features
|
||||
ref_atomwise_rasa: 3
|
||||
active_donor: 1
|
||||
active_acceptor: 1
|
||||
is_atom_level_hotspot: 1
|
||||
diffusion_module:
|
||||
n_recycle: 2
|
||||
use_local_token_attention: True
|
||||
diffusion_transformer:
|
||||
n_local_tokens: 32
|
||||
n_keys: 128
|
||||
|
||||
inference_sampler:
|
||||
num_timesteps: 100
|
||||
|
||||
|
||||
datasets:
|
||||
diffusion_batch_size_train: 16
|
||||
crop_size: 256
|
||||
max_atoms_in_crop: 2560 # ~10x crop size.
|
||||
global_transform_args:
|
||||
association_scheme: atom23
|
||||
train_conditions:
|
||||
unconditional:
|
||||
frequency: 2.0
|
||||
island:
|
||||
frequency: 2.0
|
||||
sequence_design:
|
||||
frequency: 0.5
|
||||
tipatom:
|
||||
frequency: 5.0
|
||||
ppi:
|
||||
frequency: 0.0
|
||||
train:
|
||||
# These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data.
|
||||
#pdb:
|
||||
#probability: 0.10
|
||||
#monomer_distillation:
|
||||
#probability: 0.90
|
||||
pdb:
|
||||
probability: 1.0
|
||||
|
||||
trainer:
|
||||
devices_per_node: 1
|
||||
limit_train_batches: 10
|
||||
limit_val_batches: 1
|
||||
validate_every_n_epochs: 5
|
||||
prevalidate: false
|
||||
@@ -242,42 +242,171 @@ SELECTION_NONPROTEIN = [
|
||||
"POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
|
||||
]
|
||||
|
||||
backbone_atomscheme_DNA = [' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'"]#, None]
|
||||
backbone_atomscheme_DNA = [
|
||||
" P ",
|
||||
" OP1",
|
||||
" OP2",
|
||||
" O5'",
|
||||
" C5'",
|
||||
" C4'",
|
||||
" O4'",
|
||||
" C3'",
|
||||
" O3'",
|
||||
" C2'",
|
||||
" C1'",
|
||||
] # , None]
|
||||
|
||||
backbone_atomscheme_RNA = [' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'"]
|
||||
backbone_atomscheme_RNA = [
|
||||
" P ",
|
||||
" OP1",
|
||||
" OP2",
|
||||
" O5'",
|
||||
" C5'",
|
||||
" C4'",
|
||||
" O4'",
|
||||
" C3'",
|
||||
" O3'",
|
||||
" C2'",
|
||||
" O2'",
|
||||
" C1'",
|
||||
]
|
||||
|
||||
DNA_atoms = {
|
||||
'DA': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 '],
|
||||
'DC': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 '],
|
||||
'DG': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '],
|
||||
'DT': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C7 ', ' C6 ']}
|
||||
|
||||
RNA_atoms = {
|
||||
'A': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 '],
|
||||
'C': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 '],
|
||||
'G': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '],
|
||||
'U': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C6 ']
|
||||
"DA": [
|
||||
" N9 ",
|
||||
" C8 ",
|
||||
" N7 ",
|
||||
" C5 ",
|
||||
" C6 ",
|
||||
" N6 ",
|
||||
" N1 ",
|
||||
" C2 ",
|
||||
" N3 ",
|
||||
" C4 ",
|
||||
],
|
||||
"DC": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " N4 ", " C5 ", " C6 "],
|
||||
"DG": [
|
||||
" N9 ",
|
||||
" C8 ",
|
||||
" N7 ",
|
||||
" C5 ",
|
||||
" C6 ",
|
||||
" O6 ",
|
||||
" N1 ",
|
||||
" C2 ",
|
||||
" N2 ",
|
||||
" N3 ",
|
||||
" C4 ",
|
||||
],
|
||||
"DT": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " O4 ", " C5 ", " C7 ", " C6 "],
|
||||
}
|
||||
|
||||
association_schemes['atom23'] = {}
|
||||
RNA_atoms = {
|
||||
"A": [
|
||||
" N9 ",
|
||||
" C8 ",
|
||||
" N7 ",
|
||||
" C5 ",
|
||||
" C6 ",
|
||||
" N6 ",
|
||||
" N1 ",
|
||||
" C2 ",
|
||||
" N3 ",
|
||||
" C4 ",
|
||||
],
|
||||
"C": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " N4 ", " C5 ", " C6 "],
|
||||
"G": [
|
||||
" N9 ",
|
||||
" C8 ",
|
||||
" N7 ",
|
||||
" C5 ",
|
||||
" C6 ",
|
||||
" O6 ",
|
||||
" N1 ",
|
||||
" C2 ",
|
||||
" N2 ",
|
||||
" N3 ",
|
||||
" C4 ",
|
||||
],
|
||||
"U": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " O4 ", " C5 ", " C6 "],
|
||||
}
|
||||
|
||||
association_schemes["atom23"] = {}
|
||||
for item in DNA_atoms:
|
||||
association_schemes['atom23'][item] = tuple(backbone_atomscheme_DNA + DNA_atoms[item]+ [None]*(22 - len(DNA_atoms[item] + backbone_atomscheme_DNA)))
|
||||
association_schemes["atom23"][item] = tuple(
|
||||
backbone_atomscheme_DNA
|
||||
+ DNA_atoms[item]
|
||||
+ [None] * (22 - len(DNA_atoms[item] + backbone_atomscheme_DNA))
|
||||
)
|
||||
for item in RNA_atoms:
|
||||
association_schemes['atom23'][item] = tuple(backbone_atomscheme_RNA + RNA_atoms[item]+ [None]*(23 - len(RNA_atoms[item] + backbone_atomscheme_RNA)))
|
||||
association_schemes["atom23"][item] = tuple(
|
||||
backbone_atomscheme_RNA
|
||||
+ RNA_atoms[item]
|
||||
+ [None] * (23 - len(RNA_atoms[item] + backbone_atomscheme_RNA))
|
||||
)
|
||||
|
||||
for item in association_schemes['dense']:
|
||||
association_schemes['atom23'][item] = association_schemes['dense'][item]
|
||||
for item in association_schemes["dense"]:
|
||||
association_schemes["atom23"][item] = association_schemes["dense"][item]
|
||||
|
||||
association_schemes['atom23']['DX'] = (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'", None, None, None, None, None, None, None, None, None, None, None) #rna_mask
|
||||
association_schemes['atom23']['X'] = (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'", None, None, None, None, None, None, None, None, None, None, None)#rna mask
|
||||
association_schemes["atom23"]["DX"] = (
|
||||
" P ",
|
||||
" OP1",
|
||||
" OP2",
|
||||
" O5'",
|
||||
" C5'",
|
||||
" C4'",
|
||||
" O4'",
|
||||
" C3'",
|
||||
" O3'",
|
||||
" C2'",
|
||||
" C1'",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
) # dna mask (DNA backbone only, no O2')
|
||||
association_schemes["atom23"]["X"] = (
|
||||
" P ",
|
||||
" OP1",
|
||||
" OP2",
|
||||
" O5'",
|
||||
" C5'",
|
||||
" C4'",
|
||||
" O4'",
|
||||
" C3'",
|
||||
" O3'",
|
||||
" C2'",
|
||||
" O2'",
|
||||
" C1'",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
) # rna mask
|
||||
|
||||
ATOM23_ATOM_NAMES_RNA = np.array(
|
||||
[item.strip() for item in backbone_atomscheme_RNA] + [f"V{i}" for i in range(23 - len(backbone_atomscheme_RNA))]
|
||||
[item.strip() for item in backbone_atomscheme_RNA]
|
||||
+ [f"V{i}" for i in range(23 - len(backbone_atomscheme_RNA))]
|
||||
)
|
||||
"""Atom23 atom names (e.g. CA, V1)"""
|
||||
|
||||
ATOM23_ATOM_ELEMENTS_RNA = np.array(
|
||||
["P", "O", "O", "O", "C", "C", "O", "C","O", "C", "O", "C"] + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(23 - len(backbone_atomscheme_RNA))]
|
||||
["P", "O", "O", "O", "C", "C", "O", "C", "O", "C", "O", "C"]
|
||||
+ [VIRTUAL_ATOM_ELEMENT_NAME for i in range(23 - len(backbone_atomscheme_RNA))]
|
||||
)
|
||||
"""Atom23 element names (e.g. C, VX)"""
|
||||
|
||||
@@ -285,12 +414,14 @@ ATOM23_ATOM_NAME_TO_ELEMENT = {
|
||||
name: elem for name, elem in zip(ATOM23_ATOM_NAMES_RNA, ATOM23_ATOM_ELEMENTS_RNA)
|
||||
}
|
||||
ATOM23_ATOM_NAMES_DNA = np.array(
|
||||
[item.strip() for item in backbone_atomscheme_DNA] + [f"V{i}" for i in range(22 - len(backbone_atomscheme_DNA))]
|
||||
[item.strip() for item in backbone_atomscheme_DNA]
|
||||
+ [f"V{i}" for i in range(22 - len(backbone_atomscheme_DNA))]
|
||||
)
|
||||
"""Atom23 atom names (e.g. CA, V1)"""
|
||||
|
||||
ATOM23_ATOM_ELEMENTS_DNA = np.array(
|
||||
["P", "O", "O", "O", "C", "C", "O", "C","O", "C", "C"] + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(22 - len(backbone_atomscheme_DNA))]
|
||||
["P", "O", "O", "O", "C", "C", "O", "C", "O", "C", "C"]
|
||||
+ [VIRTUAL_ATOM_ELEMENT_NAME for i in range(22 - len(backbone_atomscheme_DNA))]
|
||||
)
|
||||
"""Atom23 element names (e.g. C, VX)"""
|
||||
|
||||
@@ -307,4 +438,3 @@ association_schemes_stripped = {
|
||||
|
||||
backbone_atoms_RNA = strip_list(backbone_atomscheme_RNA)
|
||||
backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA)
|
||||
|
||||
|
||||
@@ -30,9 +30,12 @@ from rfd3.constants import (
|
||||
OPTIONAL_CONDITIONING_VALUES,
|
||||
REQUIRED_CONDITIONING_ANNOTATION_VALUES,
|
||||
REQUIRED_INFERENCE_ANNOTATIONS,
|
||||
backbone_atoms_DNA,
|
||||
backbone_atoms_RNA,
|
||||
)
|
||||
from rfd3.inference.legacy_input_parsing import (
|
||||
create_atom_array_from_design_specification_legacy,
|
||||
reorder_atoms_per_residue,
|
||||
)
|
||||
from rfd3.inference.parsing import InputSelection
|
||||
from rfd3.inference.symmetry.symmetry_utils import (
|
||||
@@ -67,7 +70,6 @@ logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Custom infer_ori functions
|
||||
#################################################################################
|
||||
@@ -505,6 +507,21 @@ class DesignInputSpecification(BaseModel):
|
||||
def build(self, return_metadata=False):
|
||||
"""Main build pipeline."""
|
||||
atom_array_input_annotated = copy.deepcopy(self.atom_array_input)
|
||||
|
||||
########## reorder NA atoms ###########
|
||||
is_dna = np.isin(atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"])
|
||||
is_rna = np.isin(atom_array_input_annotated.res_name, ["A", "C", "G", "U"])
|
||||
dna_array = atom_array_input_annotated[is_dna]
|
||||
rna_array = atom_array_input_annotated[is_rna]
|
||||
|
||||
atom_array_input_annotated[is_dna] = reorder_atoms_per_residue(
|
||||
dna_array, backbone_atoms_DNA
|
||||
)
|
||||
atom_array_input_annotated[is_rna] = reorder_atoms_per_residue(
|
||||
rna_array, backbone_atoms_RNA
|
||||
)
|
||||
#######################################
|
||||
|
||||
atom_array = self._build_init(atom_array_input_annotated)
|
||||
|
||||
# Apply post-processing
|
||||
@@ -894,31 +911,52 @@ def validator_context(validator_name: str, data: dict = None):
|
||||
raise e
|
||||
|
||||
|
||||
def create_diffused_residues(n, additional_annotations=None):
|
||||
def create_diffused_residues(n, additional_annotations=None, polymer_type="P"):
|
||||
from rfd3.constants import (
|
||||
ATOM23_ATOM_NAME_TO_ELEMENT,
|
||||
backbone_atoms_DNA,
|
||||
backbone_atoms_RNA,
|
||||
)
|
||||
|
||||
if n <= 0:
|
||||
raise ValueError(f"Negative/null residue count ({n}) not allowed.")
|
||||
|
||||
if polymer_type == "P":
|
||||
res_name = "ALA"
|
||||
bb_len = 5
|
||||
bb_atom_names = ["N", "CA", "C", "O", "CB"]
|
||||
elif polymer_type == "R":
|
||||
res_name = "A"
|
||||
bb_len = len(backbone_atoms_RNA)
|
||||
bb_atom_names = strip_list(backbone_atoms_RNA)
|
||||
elif polymer_type == "D":
|
||||
res_name = "DA"
|
||||
bb_len = len(backbone_atoms_DNA)
|
||||
bb_atom_names = strip_list(backbone_atoms_DNA)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"invalid polymer type detected: {polymer_type}, check contig!"
|
||||
)
|
||||
|
||||
bb_elements = [ATOM23_ATOM_NAME_TO_ELEMENT[item] for item in bb_atom_names]
|
||||
|
||||
atoms = []
|
||||
[
|
||||
atoms.extend(
|
||||
[
|
||||
struc.Atom(
|
||||
np.array([0.0, 0.0, 0.0], dtype=np.float32),
|
||||
res_name="ALA",
|
||||
res_name=res_name,
|
||||
res_id=idx,
|
||||
)
|
||||
for _ in range(5)
|
||||
for _ in range(bb_len)
|
||||
]
|
||||
)
|
||||
for idx in range(1, n + 1)
|
||||
]
|
||||
array = struc.array(atoms)
|
||||
array.set_annotation(
|
||||
"element", np.array(["N", "C", "C", "O", "C"] * n, dtype="<U2")
|
||||
)
|
||||
array.set_annotation(
|
||||
"atom_name", np.array(["N", "CA", "C", "O", "CB"] * n, dtype="<U2")
|
||||
)
|
||||
array.set_annotation("element", np.array(bb_elements * n, dtype="<U2"))
|
||||
array.set_annotation("atom_name", np.array(bb_atom_names * n, dtype="<U2"))
|
||||
array = set_default_conditioning_annotations(
|
||||
array, motif=False, additional=additional_annotations
|
||||
)
|
||||
@@ -1065,14 +1103,20 @@ def accumulate_components(
|
||||
np.zeros(token.shape[0], dtype=int),
|
||||
)
|
||||
else:
|
||||
n = int(component)
|
||||
if component[-1] in ["P", "R", "D"]: # if polymer type specified
|
||||
polymer_type = component[-1] # can be 'P'rotein, 'R'NA, 'D'NA
|
||||
n = int(component[:-1])
|
||||
else:
|
||||
polymer_type = "P"
|
||||
n = int(components)
|
||||
|
||||
# ... Skip if none or unindexed
|
||||
if n == 0 or unindexed_components_started:
|
||||
res_id += n
|
||||
continue
|
||||
|
||||
# ... Create diffused residues
|
||||
token = create_diffused_residues(n, all_annots)
|
||||
token = create_diffused_residues(n, all_annots, polymer_type)
|
||||
|
||||
# ... Set index of insertion
|
||||
token = set_indices(
|
||||
|
||||
@@ -5,16 +5,19 @@ from os import PathLike
|
||||
|
||||
import biotite.structure as struc
|
||||
import numpy as np
|
||||
from atomworks.constants import STANDARD_AA
|
||||
from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA
|
||||
from atomworks.io.utils.io_utils import to_cif_file
|
||||
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
||||
from atomworks.ml.utils.token import (
|
||||
get_token_starts,
|
||||
)
|
||||
from biotite.structure import AtomArray, concatenate, get_residue_starts
|
||||
from rfd3.constants import (
|
||||
INFERENCE_ANNOTATIONS,
|
||||
OPTIONAL_CONDITIONING_VALUES,
|
||||
REQUIRED_INFERENCE_ANNOTATIONS,
|
||||
backbone_atoms_DNA,
|
||||
backbone_atoms_RNA,
|
||||
)
|
||||
from rfd3.inference.symmetry.symmetry_utils import (
|
||||
center_symmetric_src_atom_array,
|
||||
@@ -183,6 +186,27 @@ def fetch_motif_residue_(
|
||||
subarray.set_annotation(
|
||||
"is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
|
||||
)
|
||||
elif redesign_motif_sidechains and res_name in (STANDARD_DNA + STANDARD_RNA):
|
||||
is_backbone = np.isin(subarray.atom_name, backbone_atoms_RNA)
|
||||
subarray.set_annotation("is_motif_atom", is_backbone)
|
||||
|
||||
subarray.set_annotation(
|
||||
"is_motif_atom", is_backbone & (subarray.atom_name != "C1'")
|
||||
)
|
||||
|
||||
subarray.set_annotation(
|
||||
"is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
|
||||
)
|
||||
if res_name in STANDARD_DNA:
|
||||
subarray.res_name = np.full_like(
|
||||
subarray.res_name, "DA", dtype=subarray.res_name.dtype
|
||||
)
|
||||
else:
|
||||
subarray.res_name = np.full_like(
|
||||
subarray.res_name, "A", dtype=subarray.res_name.dtype
|
||||
)
|
||||
subarray = subarray[is_backbone]
|
||||
|
||||
elif to_index or to_unindex:
|
||||
# If the residue is in the contig or unindexed components,
|
||||
# we set all atoms in the residue to be motif atoms
|
||||
@@ -195,6 +219,8 @@ def fetch_motif_residue_(
|
||||
f"{src_chain}{src_resid} is not found in fixed_atoms, contig or unindex contig."
|
||||
"Please check your input and contig specification."
|
||||
)
|
||||
|
||||
|
||||
if unfix_all or f"{src_chain}{src_resid}" in unfix_residues:
|
||||
subarray.set_annotation(
|
||||
"is_motif_atom_with_fixed_coord", np.zeros(subarray.shape[0], dtype=int)
|
||||
@@ -507,6 +533,17 @@ def create_atom_array_from_design_specification_legacy(
|
||||
fixed_atoms = {}
|
||||
|
||||
optional_conditions = []
|
||||
|
||||
########## reorder NA atoms ###########
|
||||
is_dna = np.isin(atom_array_input.res_name, ["DA", "DC", "DG", "DT"])
|
||||
is_rna = np.isin(atom_array_input.res_name, ["A", "C", "G", "U"])
|
||||
dna_array = atom_array_input[is_dna]
|
||||
rna_array = atom_array_input[is_rna]
|
||||
|
||||
atom_array_input[is_dna] = reorder_atoms_per_residue(dna_array, backbone_atoms_DNA)
|
||||
atom_array_input[is_rna] = reorder_atoms_per_residue(rna_array, backbone_atoms_RNA)
|
||||
#######################################
|
||||
|
||||
if exists(atomwise_rasa):
|
||||
set_atom_level_argument(atom_array_input, atomwise_rasa, "rasa_bin")
|
||||
optional_conditions.append("rasa_bin")
|
||||
@@ -731,3 +768,46 @@ def create_atom_array_from_design_specification_legacy(
|
||||
to_cif_file(atom_array, out_path, extra_fields=INFERENCE_ANNOTATIONS)
|
||||
|
||||
return atom_array
|
||||
|
||||
|
||||
def reorder_atoms_per_residue(
    atom_array: AtomArray, desired_order: list[str]
) -> AtomArray:
    """
    Reorder atoms within each residue of an AtomArray.

    Within every residue, atoms whose name appears in `desired_order` come
    first (sorted by their position in `desired_order`), followed by all other
    atoms in their original relative order. Residues keep their original order.

    The reordering is computed as one stable permutation over the whole array
    and applied with a single index operation, instead of slicing each residue
    out and concatenating the pieces again. Atoms that share a name inside one
    residue therefore keep their original relative order.
    NOTE(review): slicing residues apart presumably dropped bonds spanning two
    residues; a single permutation index should keep them -- confirm against
    biotite's AtomArray / BondList indexing semantics.

    Parameters:
    - atom_array: AtomArray to reorder.
    - desired_order: List of atom names in the desired per-residue order.
      Names are compared to `atom_array.atom_name` by exact equality.

    Returns:
    - AtomArray with reordered atoms per residue. An empty input is returned
      unchanged (the same object).
    """
    n_atoms = len(atom_array)
    if n_atoms == 0:
        return atom_array

    # Rank of every atom inside its residue: position in `desired_order`, or a
    # shared "after everything listed" rank for atoms that are not listed.
    rank_by_name = {name: rank for rank, name in enumerate(desired_order)}
    unlisted_rank = len(desired_order)
    atom_rank = np.array(
        [rank_by_name.get(name, unlisted_rank) for name in atom_array.atom_name],
        dtype=np.int64,
    )

    # Index of the residue each atom belongs to, derived from the residue
    # start offsets (sizes are the gaps between consecutive starts).
    res_starts = get_residue_starts(atom_array)
    res_sizes = np.diff(np.append(res_starts, n_atoms))
    residue_index = np.repeat(np.arange(len(res_starts)), res_sizes)

    # lexsort is stable and uses the LAST key as the primary one: sort by
    # residue first, then by rank; ties keep their original order.
    permutation = np.lexsort((atom_rank, residue_index))
    return atom_array[permutation]
|
||||
|
||||
@@ -428,9 +428,14 @@ class AADesignTrainer(FabricTrainer):
|
||||
|
||||
# ... Delete virtual atoms and assign atom names and elements
|
||||
if self.cleanup_virtual_atoms:
|
||||
atom_array = _cleanup_virtual_atoms_and_assign_atom_name_elements(
|
||||
atom_array, association_scheme=self.association_scheme
|
||||
)
|
||||
try:
|
||||
atom_array = _cleanup_virtual_atoms_and_assign_atom_name_elements(
|
||||
atom_array, association_scheme=self.association_scheme
|
||||
)
|
||||
except Exception as e:
|
||||
global_logger.warning(
|
||||
f"Failed to cleanup virtual atoms from diffusion output: {e}"
|
||||
)
|
||||
|
||||
# ... When cleaning up virtual atoms, we can also calculate native_array_metrics
|
||||
metadata_dict[i]["metrics"] |= get_all_backbone_metrics(
|
||||
|
||||
@@ -11,6 +11,8 @@ from biotite.structure import concatenate, infer_elements
|
||||
from jaxtyping import Float, Int
|
||||
from rfd3.constants import (
|
||||
ATOM14_ATOM_NAMES,
|
||||
ATOM23_ATOM_NAMES_DNA,
|
||||
ATOM23_ATOM_NAMES_RNA,
|
||||
VIRTUAL_ATOM_ELEMENT_NAME,
|
||||
association_schemes,
|
||||
association_schemes_stripped,
|
||||
@@ -252,13 +254,19 @@ def _readout_seq_from_struc(
|
||||
continue
|
||||
|
||||
# ... Find the index of virtual atom names in the standard atom14 names
|
||||
ATOM_NAMES = ATOM14_ATOM_NAMES
|
||||
if restype in STANDARD_DNA:
|
||||
ATOM_NAMES = ATOM23_ATOM_NAMES_DNA
|
||||
if restype in STANDARD_RNA:
|
||||
ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
|
||||
|
||||
atom_name_idx_in_atom14_scheme = np.array(
|
||||
[
|
||||
np.where(ATOM14_ATOM_NAMES == atom_name)[0][0]
|
||||
np.where(ATOM_NAMES == atom_name)[0][0]
|
||||
for atom_name in cur_pred_res_atom_names
|
||||
]
|
||||
) # five backbone atoms + some virtual atoms, returning e.g. [0, 1, 2, 3, 4, 11, 7]
|
||||
atom14_scheme_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
|
||||
atom14_scheme_mask = np.zeros_like(ATOM_NAMES, dtype=bool)
|
||||
atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True
|
||||
|
||||
# ... Find the matched restype by checking if all the non-None positions and None positions match
|
||||
@@ -427,12 +435,30 @@ def process_unindexed_outputs(
|
||||
else:
|
||||
join_atom = None
|
||||
|
||||
if join_atom is None:
|
||||
pass
|
||||
else:
|
||||
dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
|
||||
|
||||
elif not np.any(
|
||||
np.isin(
|
||||
token.atom_name, [item.replace(" ", "") for item in backbone_atoms_RNA]
|
||||
)
|
||||
):
|
||||
if np.sum(token.atomize) == 1:
|
||||
join_atom = np.where(token.atomize)[0][0]
|
||||
elif "C1'" in token.atom_name:
|
||||
join_atom = np.where(token.atom_name == "C1'")[0][0]
|
||||
else:
|
||||
join_atom = None
|
||||
|
||||
if join_atom is None:
|
||||
global_logger.warning(
|
||||
f"Token {token_pdb_id} does not contain backbone atoms or CB, skipping join point distance calculation {token}."
|
||||
"Skipping joint point rmsd, neither protein or NA backbone"
|
||||
)
|
||||
else:
|
||||
dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
|
||||
|
||||
metadata["join_point_rmsd_by_token"][token_pdb_id] = dist
|
||||
|
||||
metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"
|
||||
|
||||
@@ -281,7 +281,7 @@ class SampleConditioningType(Transform):
|
||||
cond = valid_conditions[i_cond]
|
||||
|
||||
cond.association_scheme = self.association_scheme
|
||||
|
||||
|
||||
data["sampled_condition"] = cond
|
||||
data["sampled_condition_name"] = cond.name
|
||||
data["sampled_condition_cls"] = cond.__class__
|
||||
@@ -299,6 +299,7 @@ class SampleConditioningFlags(Transform):
|
||||
"AssignTypes",
|
||||
"SampleConditioningType",
|
||||
] # We use is_protein in the PPI training condition
|
||||
|
||||
def __init__(self, association_scheme):
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
@@ -375,13 +376,15 @@ class UnindexFlaggedTokens(Transform):
|
||||
token.res_id = token.res_id + max_resid
|
||||
token.is_C_terminus[:] = False
|
||||
token.is_N_terminus[:] = False
|
||||
|
||||
if association_scheme is not 'atom23':
|
||||
|
||||
if not self.association_scheme == "atom23":
|
||||
assert token.is_protein.all(), f"Cannot unindex non-protein token: {token} unless using atom23 association scheme"
|
||||
token = add_representative_atom(token, central_atom=self.central_atom)
|
||||
else:
|
||||
if token.is_protein.all():
|
||||
token = add_representative_atom(token, central_atom=self.central_atom)
|
||||
token = add_representative_atom(
|
||||
token, central_atom=self.central_atom
|
||||
)
|
||||
else:
|
||||
token = add_representative_atom(token, central_atom="C1'")
|
||||
|
||||
|
||||
@@ -697,11 +697,12 @@ class AddAdditional1dFeaturesToFeats(Transform):
|
||||
token_1d_features,
|
||||
atom_1d_features,
|
||||
autofill_zeros_if_not_present_in_atomarray=False,
|
||||
association_scheme='atom14'
|
||||
association_scheme="atom14",
|
||||
):
|
||||
self.autofill = autofill_zeros_if_not_present_in_atomarray
|
||||
self.token_1d_features = token_1d_features
|
||||
self.atom_1d_features = atom_1d_features
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
def check_input(self, data) -> None:
|
||||
check_contains_keys(data, ["atom_array"])
|
||||
@@ -753,11 +754,13 @@ class AddAdditional1dFeaturesToFeats(Transform):
|
||||
"""
|
||||
if "feats" not in data.keys():
|
||||
data["feats"] = {}
|
||||
|
||||
if association_scheme == 'atom23':
|
||||
data['atom_array'].set_annotation('is_protein_token', data['atom_array'].is_protein)
|
||||
data['atom_array'].set_annotation('is_dna_token', data['atom_array'].is_dna)
|
||||
data['atom_array'].set_annotation('is_rna_token', data['atom_array'].is_rna)
|
||||
|
||||
if self.association_scheme == "atom23":
|
||||
data["atom_array"].set_annotation(
|
||||
"is_protein_token", data["atom_array"].is_protein
|
||||
)
|
||||
data["atom_array"].set_annotation("is_dna_token", data["atom_array"].is_dna)
|
||||
data["atom_array"].set_annotation("is_rna_token", data["atom_array"].is_rna)
|
||||
|
||||
for feature_name, n_dims in self.token_1d_features.items():
|
||||
data = self.generate_feature(feature_name, n_dims, data, "token")
|
||||
|
||||
@@ -383,7 +383,7 @@ def build_atom14_base_pipeline_(
|
||||
train_conditions=train_conditions,
|
||||
meta_conditioning_probabilities=meta_conditioning_probabilities,
|
||||
sequence_encoding=af3_sequence_encoding,
|
||||
association_scheme=association_scheme
|
||||
association_scheme=association_scheme,
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -423,7 +423,9 @@ def build_atom14_base_pipeline_(
|
||||
# ... Add global token features (since number of tokens is fixed after cropping)
|
||||
transforms.append(AddGlobalTokenIdAnnotation())
|
||||
# ... Create masks (NOTE: Modulates token count, and resets global token id if necessary)
|
||||
transforms.append(TrainingRoute(SampleConditioningFlags(association_scheme=association_scheme)))
|
||||
transforms.append(
|
||||
TrainingRoute(SampleConditioningFlags(association_scheme=association_scheme))
|
||||
)
|
||||
|
||||
# Post-crop transforms
|
||||
transforms.append(
|
||||
@@ -443,7 +445,9 @@ def build_atom14_base_pipeline_(
|
||||
sharding_depth=1,
|
||||
),
|
||||
# ... Fuse inference and training conditioning assignments
|
||||
UnindexFlaggedTokens(central_atom=central_atom),
|
||||
UnindexFlaggedTokens(
|
||||
central_atom=central_atom, association_scheme=association_scheme
|
||||
),
|
||||
# ... Virtual atom padding (NOTE: Last transform which modulates atom count)
|
||||
PadTokensWithVirtualAtoms(
|
||||
n_atoms_per_token=n_atoms_per_token,
|
||||
@@ -519,7 +523,7 @@ def build_atom14_base_pipeline_(
|
||||
autofill_zeros_if_not_present_in_atomarray=True,
|
||||
token_1d_features=token_1d_features,
|
||||
atom_1d_features=atom_1d_features,
|
||||
association_scheme=association_scheme
|
||||
association_scheme=association_scheme,
|
||||
),
|
||||
AddAF3TokenBondFeatures(),
|
||||
AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding),
|
||||
|
||||
@@ -58,6 +58,8 @@ class IslandCondition(TrainingCondition):
|
||||
Select islands as motif and assign conditioning strategies.
|
||||
"""
|
||||
|
||||
association_scheme = "atom14"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -70,11 +72,9 @@ class IslandCondition(TrainingCondition):
|
||||
p_fix_motif_coordinates,
|
||||
p_fix_motif_sequence,
|
||||
p_unindex_motif_tokens,
|
||||
association_scheme = 'atom14',
|
||||
):
|
||||
self.name = name
|
||||
self.frequency = frequency
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
# Token selection
|
||||
self.island_sampling_kwargs = island_sampling_kwargs
|
||||
@@ -89,15 +89,13 @@ class IslandCondition(TrainingCondition):
|
||||
self.p_fix_motif_coordinates = p_fix_motif_coordinates
|
||||
self.p_fix_motif_sequence = p_fix_motif_sequence
|
||||
self.p_unindex_motif_tokens = p_unindex_motif_tokens
|
||||
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
def is_valid_for_example(self, data) -> bool:
|
||||
is_protein = data["atom_array"].is_protein
|
||||
is_dna = data["atom_array"].is_dna
|
||||
is_rna = data["atom_array"].is_rna
|
||||
### updating this to allow other polymers
|
||||
if self.association_scheme is not 'atom23':
|
||||
if not self.association_scheme == "atom23":
|
||||
if not np.any(is_protein | is_dna | is_rna):
|
||||
return False
|
||||
else:
|
||||
@@ -113,8 +111,12 @@ class IslandCondition(TrainingCondition):
|
||||
token_level_array = atom_array[get_token_starts(atom_array)]
|
||||
|
||||
# initialize motif tokens as all non-protein tokens
|
||||
if self.association_scheme is 'atom23':
|
||||
polymer_mask = (token_level_array.is_protein | token_level_array.is_dna | token_level_array.is_rna)
|
||||
if self.association_scheme == "atom23":
|
||||
polymer_mask = (
|
||||
token_level_array.is_protein
|
||||
| token_level_array.is_dna
|
||||
| token_level_array.is_rna
|
||||
)
|
||||
is_motif_token = np.asarray(~polymer_mask, dtype=bool).copy()
|
||||
n_polymer_tokens = np.sum(polymer_mask)
|
||||
islands_mask = sample_island_tokens(
|
||||
@@ -123,13 +125,15 @@ class IslandCondition(TrainingCondition):
|
||||
)
|
||||
is_motif_token[polymer_mask] = islands_mask
|
||||
else:
|
||||
is_motif_token = np.asarray(~token_level_array.is_protein, dtype=bool).copy()
|
||||
is_motif_token = np.asarray(
|
||||
~token_level_array.is_protein, dtype=bool
|
||||
).copy()
|
||||
n_protein_tokens = np.sum(token_level_array.is_protein)
|
||||
|
||||
slands_mask = sample_island_tokens(
|
||||
_protein_tokens,
|
||||
*self.island_sampling_kwargs,
|
||||
|
||||
islands_mask = sample_island_tokens(
|
||||
n_protein_tokens,
|
||||
**self.island_sampling_kwargs,
|
||||
)
|
||||
is_motif_token[token_level_array.is_protein] = islands_mask
|
||||
|
||||
# TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this
|
||||
@@ -160,7 +164,7 @@ class IslandCondition(TrainingCondition):
|
||||
is_motif_atom = sample_motif_subgraphs(
|
||||
atom_array=atom_array,
|
||||
**self.subgraph_sampling_kwargs,
|
||||
association_scheme=self.association_scheme
|
||||
association_scheme=self.association_scheme,
|
||||
)
|
||||
|
||||
# We also only want resolved atoms to be motif
|
||||
@@ -182,7 +186,7 @@ class IslandCondition(TrainingCondition):
|
||||
p_fix_motif_sequence=self.p_fix_motif_sequence,
|
||||
p_fix_motif_coordinates=self.p_fix_motif_coordinates,
|
||||
p_unindex_motif_tokens=self.p_unindex_motif_tokens,
|
||||
association_scheme=self.association_scheme
|
||||
association_scheme=self.association_scheme,
|
||||
)
|
||||
|
||||
atom_array.set_annotation(
|
||||
@@ -202,7 +206,7 @@ class PPICondition(TrainingCondition):
|
||||
"""Get condition indicating what is motif and what is to be diffused for protein-protein interaction training."""
|
||||
|
||||
name = "ppi"
|
||||
association_scheme = 'atom14'
|
||||
association_scheme = "atom14"
|
||||
|
||||
def is_valid_for_example(self, data):
|
||||
# Extract relevant data
|
||||
@@ -301,7 +305,7 @@ class SubtypeCondition(TrainingCondition):
|
||||
"""
|
||||
|
||||
name = "subtype"
|
||||
association_scheme = 'atom14'
|
||||
association_scheme = "atom14"
|
||||
|
||||
def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False):
|
||||
self.frequency = frequency
|
||||
@@ -397,7 +401,7 @@ def sample_motif_subgraphs(
|
||||
hetatom_n_bond_expectation,
|
||||
residue_p_fix_all,
|
||||
hetatom_p_fix_all,
|
||||
association_scheme = 'atom14'
|
||||
association_scheme="atom14",
|
||||
):
|
||||
"""
|
||||
Returns a boolean mask over atoms, indicating which atoms are part of the sampled motif.
|
||||
@@ -431,10 +435,14 @@ def sample_motif_subgraphs(
|
||||
"p_fix_all": residue_p_fix_all,
|
||||
}
|
||||
|
||||
if association_scheme is 'atom23':
|
||||
clause = atom_array_subset.is_protein.all() | atom_array_subset.is_dna.all() | atom_array_subset.is_rna.all()
|
||||
if association_scheme == "atom23":
|
||||
clause = (
|
||||
atom_array_subset.is_protein.all()
|
||||
| atom_array_subset.is_dna.all()
|
||||
| atom_array_subset.is_rna.all()
|
||||
)
|
||||
else:
|
||||
clause = atom_array_subset.is_potein.all()
|
||||
clause = atom_array_subset.is_protein.all()
|
||||
|
||||
if not clause:
|
||||
args.update(
|
||||
@@ -465,12 +473,14 @@ def sample_conditioning_strategy(
|
||||
p_fix_motif_sequence,
|
||||
p_fix_motif_coordinates,
|
||||
p_unindex_motif_tokens,
|
||||
association_scheme
|
||||
association_scheme,
|
||||
):
|
||||
atom_array.set_annotation(
|
||||
"is_motif_atom_with_fixed_seq",
|
||||
sample_is_motif_atom_with_fixed_seq(
|
||||
atom_array, p_fix_motif_sequence=p_fix_motif_sequence, association_scheme=association_scheme
|
||||
atom_array,
|
||||
p_fix_motif_sequence=p_fix_motif_sequence,
|
||||
association_scheme=association_scheme,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -491,7 +501,9 @@ def sample_conditioning_strategy(
|
||||
return atom_array
|
||||
|
||||
|
||||
def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence, association_scheme):
|
||||
def sample_is_motif_atom_with_fixed_seq(
|
||||
atom_array, p_fix_motif_sequence, association_scheme
|
||||
):
|
||||
"""
|
||||
Samples what kind of conditioning to apply to motif tokens.
|
||||
|
||||
@@ -504,10 +516,11 @@ def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence, associ
|
||||
is_motif_atom_with_fixed_seq = np.zeros(atom_array.array_length(), dtype=bool)
|
||||
|
||||
# By default reveal sequence for non-protein
|
||||
|
||||
if association_scheme is not 'atom23':
|
||||
is_motif_atom_with_fixed_seq = is_motif_atom_with_fixed_seq | ~atom_array.is_protein
|
||||
|
||||
|
||||
if not association_scheme == "atom23":
|
||||
is_motif_atom_with_fixed_seq = (
|
||||
is_motif_atom_with_fixed_seq | ~atom_array.is_protein
|
||||
)
|
||||
|
||||
return is_motif_atom_with_fixed_seq
|
||||
|
||||
@@ -526,7 +539,9 @@ def sample_fix_motif_coordinates(atom_array, p_fix_motif_coordinates):
|
||||
return is_motif_atom_with_fixed_coord
|
||||
|
||||
|
||||
def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens, association_scheme='atom14'):
|
||||
def sample_unindexed_atoms(
|
||||
atom_array, p_unindex_motif_tokens, association_scheme="atom14"
|
||||
):
|
||||
"""
|
||||
Samples which atoms in motif tokens should be flagged for unindexing.
|
||||
|
||||
@@ -539,15 +554,15 @@ def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens, association_schem
|
||||
is_motif_atom_unindexed = np.zeros(atom_array.array_length(), dtype=bool)
|
||||
|
||||
# ensure non-residue atoms are not already flagged
|
||||
if association_scheme == 'atom23':
|
||||
if association_scheme == "atom23":
|
||||
is_motif_atom_unindexed = np.logical_and(
|
||||
is_motif_atom_unindexed, (atom_array.is_residue | atom_array.is_dna | atom_array.is_rna)
|
||||
) # is_residue refers to is_protein here
|
||||
is_motif_atom_unindexed,
|
||||
(atom_array.is_residue | atom_array.is_dna | atom_array.is_rna),
|
||||
) # is_residue refers to is_protein here
|
||||
else:
|
||||
is_motif_atom_unindexed = np.logical_and(
|
||||
is_motif_atom_unindexed, atom_array.is_residue
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
return is_motif_atom_unindexed
|
||||
|
||||
|
||||
@@ -10,11 +10,10 @@ from atomworks.ml.transforms.base import (
|
||||
)
|
||||
from atomworks.ml.utils.token import get_token_starts
|
||||
from rfd3.constants import (
|
||||
ATOM23_ATOM_NAME_TO_ELEMENT,
|
||||
ATOM14_ATOM_NAME_TO_ELEMENT,
|
||||
ATOM14_ATOM_NAMES,
|
||||
ATOM23_ATOM_NAMES_RNA,
|
||||
ATOM23_ATOM_NAME_TO_ELEMENT,
|
||||
ATOM23_ATOM_NAMES_DNA,
|
||||
ATOM23_ATOM_NAMES_RNA,
|
||||
VIRTUAL_ATOM_ELEMENT_NAME,
|
||||
association_schemes,
|
||||
association_schemes_stripped,
|
||||
@@ -31,7 +30,9 @@ from rfd3.transforms.util_transforms import (
|
||||
from foundry.common import exists
|
||||
|
||||
|
||||
def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="atom14", ATOM_NAMES=None):
|
||||
def map_to_association_scheme(
|
||||
atom_names: list | str, res_name: str, scheme="atom14", ATOM_NAMES=None
|
||||
):
|
||||
"""
|
||||
Maps a list of names to the atom14 naming scheme for that particular name (within a specific residue)
|
||||
NB this function is a bit more general since it is used to handle tipatoms too.
|
||||
@@ -52,6 +53,7 @@ def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="ato
|
||||
else:
|
||||
return ATOM_NAMES[idxs]
|
||||
|
||||
|
||||
def map_names_to_elements(
|
||||
atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
|
||||
) -> np.ndarray:
|
||||
@@ -61,7 +63,7 @@ def map_names_to_elements(
|
||||
then it returns the default value
|
||||
"""
|
||||
atom_names = [atom_names] if isinstance(atom_names, str) else atom_names
|
||||
elements = [ATOM14_ATOM_NAME_TO_ELEMENT.get(name, default) for name in atom_names]
|
||||
elements = [ATOM23_ATOM_NAME_TO_ELEMENT.get(name, default) for name in atom_names]
|
||||
return np.array(elements)
|
||||
|
||||
|
||||
@@ -72,6 +74,9 @@ def generate_atom_mappings_(scheme="atom14"):
|
||||
symmetry_mapping = {}
|
||||
|
||||
for aaa, atom_names in ccd_ordering_atomchar.items():
|
||||
if aaa not in scheme:
|
||||
continue
|
||||
|
||||
mapping = list(range(len(atom_names)))
|
||||
scheme_names = scheme[aaa]
|
||||
|
||||
@@ -126,10 +131,10 @@ def permute_symmetric_atom_names_(
|
||||
# With the structure-local atom attention it will not unless N_keys(n_attn_seq_neighbours) > n_atom_attn_queries.
|
||||
|
||||
## fail safe, no symmetry confusion in NA bases ##
|
||||
if (atom_names[0] == "P"):
|
||||
if atom_names[0] == "P":
|
||||
return atom_names
|
||||
##################################################
|
||||
|
||||
|
||||
if res_name in association_map:
|
||||
idx_to_swap = association_map[res_name]
|
||||
atom_names = atom_names[idx_to_swap]
|
||||
@@ -180,16 +185,17 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
token_ids = np.unique(atom_array.token_id)
|
||||
assert len(token_ids) == len(
|
||||
is_motif_atom_with_fixed_seq
|
||||
), "Token ids and token level array have different lengths!"
|
||||
), "Token ids and token level array have different lengths!"
|
||||
|
||||
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
|
||||
if self.association_scheme == 'atom23':
|
||||
is_residue = (
|
||||
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
|
||||
if self.association_scheme == "atom23":
|
||||
is_residue = (
|
||||
token_level_array.is_protein & ~token_level_array.atomize
|
||||
) | is_motif_token_unindexed
|
||||
|
||||
|
||||
is_residue_NA = (
|
||||
(token_level_array.is_dna | token_level_array.is_rna) & ~token_level_array.atomize
|
||||
(token_level_array.is_dna | token_level_array.is_rna)
|
||||
& ~token_level_array.atomize
|
||||
) | is_motif_token_unindexed
|
||||
|
||||
# Unindexed tokens are never padded, and so are treated as residues with fixed sequence.
|
||||
@@ -211,7 +217,6 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
is_non_paddable_residue = is_residue & (
|
||||
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
|
||||
)
|
||||
|
||||
|
||||
# Collect virtual atoms to insert (we will insert them all at once)
|
||||
virtual_atoms_to_insert = []
|
||||
@@ -221,7 +226,7 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
for token_id, (start, end) in enumerate(zip(starts[:-1], starts[1:])):
|
||||
if is_paddable[token_id]:
|
||||
token = atom_array[start:end]
|
||||
|
||||
|
||||
# First, pad with virtual atoms if needed
|
||||
if self.association_scheme == "atom23" and atom_array[start].is_dna:
|
||||
n_atoms_per_token = 22
|
||||
@@ -230,7 +235,7 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
else:
|
||||
n_atoms_per_token = self.n_atoms_per_token
|
||||
n_pad = n_atoms_per_token - len(token)
|
||||
|
||||
|
||||
if n_pad > 0:
|
||||
mask = get_af3_token_representative_masks(
|
||||
token, central_atom=self.atom_to_pad_from
|
||||
@@ -297,10 +302,10 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
|
||||
for token_id, (start, end) in enumerate(
|
||||
zip(starts_padded[:-1], starts_padded[1:])
|
||||
):
|
||||
if (atom_array_padded[start].is_dna):
|
||||
):
|
||||
if atom_array_padded[start].is_dna:
|
||||
ATOM_NAMES = ATOM23_ATOM_NAMES_DNA
|
||||
elif (atom_array_padded[start].is_rna):
|
||||
elif atom_array_padded[start].is_rna:
|
||||
ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
|
||||
else:
|
||||
ATOM_NAMES = ATOM14_ATOM_NAMES
|
||||
@@ -328,7 +333,10 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
)
|
||||
atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names
|
||||
atom_names = map_to_association_scheme(
|
||||
atom_names, res_name, scheme=self.association_scheme, ATOM_NAMES=ATOM_NAMES
|
||||
atom_names,
|
||||
res_name,
|
||||
scheme=self.association_scheme,
|
||||
ATOM_NAMES=ATOM_NAMES,
|
||||
)
|
||||
atom_array_padded.atom_name[start:end] = atom_names
|
||||
else:
|
||||
|
||||
@@ -96,8 +96,21 @@ def get_design_pattern_with_constraints(contig, length=None):
|
||||
fixed_parts = []
|
||||
pos_to_put_motif = []
|
||||
|
||||
suff = [] # suffixes for diffused regions P(optional),R,D
|
||||
|
||||
for part in contig_parts:
|
||||
if any(c.isalpha() for c in part): # Detect parts containing letters as fixed
|
||||
## updating to include DNA and RNA generation
|
||||
if part[-1] in ["R", "D"]: ##Detect non-fixed RNA and DNA contig part
|
||||
suff.append(part[-1])
|
||||
part = part[:-1]
|
||||
if "-" in part:
|
||||
start, end = map(int, part.split("-"))
|
||||
else:
|
||||
start = end = int(part)
|
||||
variable_ranges.append([start, end])
|
||||
pos_to_put_motif.append(0)
|
||||
|
||||
elif any(c.isalpha() for c in part): # Detect parts containing letters as fixed
|
||||
pn_unit_id, pn_unit_start, pn_unit_end = extract_pn_unit_info(part)
|
||||
fixed_parts.append([pn_unit_id, pn_unit_start, pn_unit_end])
|
||||
pos_to_put_motif.append(1)
|
||||
@@ -110,6 +123,7 @@ def get_design_pattern_with_constraints(contig, length=None):
|
||||
start = end = int(part)
|
||||
variable_ranges.append([start, end])
|
||||
pos_to_put_motif.append(0)
|
||||
suff.append("P")
|
||||
|
||||
# adjust the total length to solely for free residues
|
||||
num_motif_residues = sum([i[2] - i[1] + 1 for i in fixed_parts])
|
||||
@@ -167,7 +181,7 @@ def get_design_pattern_with_constraints(contig, length=None):
|
||||
atoms_with_motif.append(f"{pn_unit_id}{index}")
|
||||
elif pos_to_put_motif[idx] == 0:
|
||||
free_atom = num_free_atoms.pop(0)
|
||||
atoms_with_motif.append(free_atom)
|
||||
atoms_with_motif.append(str(free_atom) + suff.pop(0))
|
||||
elif pos_to_put_motif[idx] == 2:
|
||||
atoms_with_motif.append("/0")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user