ruff format

This commit is contained in:
Raktim Mitra
2026-03-27 11:22:23 -07:00
parent fdf3f98a9f
commit b9e050c35b
21 changed files with 877 additions and 519 deletions

View File

@@ -438,114 +438,254 @@ backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA)
# Backbone atom names shared by every entry of a polymer class. They are kept
# as immutable tuples so that sharing one object between entries is safe.
_BB_PROTEIN = ("N", "CA", "C", "O")
_BB_DNA = (
    "O4'",
    "C1'",
    "C2'",
    "OP1",
    "P",
    "OP2",
    "O5'",
    "C5'",
    "C4'",
    "C3'",
    "O3'",
)
# RNA uses the DNA backbone plus the 2'-hydroxyl oxygen.
_BB_RNA = _BB_DNA + ("O2'",)

# Mapping from residue type to its backbone ("bb") and sidechain ("sc") atom
# names (for convenience). Both values are always tuples of atom-name strings.
# NOTE: a one-element sidechain must be written ("CB",) -- ("CB") is just the
# string "CB", which iterates as "C", "B" and has len() == 2.
ATOM_REGION_BY_RESI = {
    # --- amino acids ---
    "ALA": {"bb": _BB_PROTEIN, "sc": ("CB",)},
    "ARG": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD", "NE", "CZ", "NH1", "NH2")},
    "ASN": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "OD1", "ND2")},
    "ASP": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "OD1", "OD2")},
    "CYS": {"bb": _BB_PROTEIN, "sc": ("CB", "SG")},
    "GLN": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD", "OE1", "NE2")},
    "GLU": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD", "OE1", "OE2")},
    "GLY": {"bb": _BB_PROTEIN, "sc": ()},
    "HIS": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "ND1", "CD2", "CE1", "NE2")},
    "ILE": {"bb": _BB_PROTEIN, "sc": ("CB", "CG1", "CG2", "CD1")},
    "LEU": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD1", "CD2")},
    "LYS": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD", "CE", "NZ")},
    "MET": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "SD", "CE")},
    "PHE": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ")},
    "PRO": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "CD")},
    "SER": {"bb": _BB_PROTEIN, "sc": ("CB", "OG")},
    "THR": {"bb": _BB_PROTEIN, "sc": ("CB", "OG1", "CG2")},
    "TRP": {
        "bb": _BB_PROTEIN,
        "sc": ("CB", "CG", "CD1", "CD2", "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2"),
    },
    "TYR": {
        "bb": _BB_PROTEIN,
        "sc": ("CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"),
    },
    "VAL": {"bb": _BB_PROTEIN, "sc": ("CB", "CG1", "CG2")},
    "UNK": {"bb": _BB_PROTEIN, "sc": ("CB",)},
    "MAS": {"bb": _BB_PROTEIN, "sc": ("CB",)},
    # --- DNA ---
    "DA": {
        "bb": _BB_DNA,
        "sc": ("N9", "C4", "N3", "C2", "N1", "C6", "C5", "N7", "C8", "N6"),
    },
    "DC": {"bb": _BB_DNA, "sc": ("N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6")},
    "DG": {
        "bb": _BB_DNA,
        "sc": ("N9", "C4", "N3", "C2", "N1", "C6", "C5", "N7", "C8", "N2", "O6"),
    },
    "DT": {
        "bb": _BB_DNA,
        "sc": ("N1", "C2", "O2", "N3", "C4", "O4", "C5", "C7", "C6"),
    },
    "DX": {"bb": _BB_DNA, "sc": ()},
    # --- RNA ---
    "A": {
        "bb": _BB_RNA,
        "sc": ("N1", "C2", "N3", "C4", "C5", "C6", "N6", "N7", "C8", "N9"),
    },
    "C": {"bb": _BB_RNA, "sc": ("N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6")},
    "G": {
        "bb": _BB_RNA,
        "sc": ("N1", "C2", "N2", "N3", "C4", "C5", "C6", "O6", "N7", "C8", "N9"),
    },
    "U": {"bb": _BB_RNA, "sc": ("N1", "C2", "O2", "N3", "C4", "O4", "C5", "C6")},
    "X": {"bb": _BB_RNA, "sc": ()},
    # --- alternate histidine entry (same atoms as HIS, different order) ---
    "HIS_D": {"bb": _BB_PROTEIN, "sc": ("CB", "CG", "NE2", "CD2", "CE1", "ND1")},
}
# Known planar sidechain atoms for each canonical residue type:
PLANAR_ATOMS_BY_RESI = {
'ALA': [],
'ARG': ['NH1', 'NH2', 'CZ', 'NE', 'CD'],
'ASN': ['OD1', 'ND2', 'CG', 'CB'],
'ASP': ['OD1', 'OD2', 'CG', 'CB'],
'CYS': [],
'GLN': ['OE1', 'NE2', 'CD', 'CG'],
'GLU': ['OE1', 'OE2', 'CD', 'CG'],
'GLY': [],
'HIS': ['ND1', 'CE1', 'NE2', 'CD2', 'CG', 'CB'],
'ILE': [],
'LEU': [],
'LYS': [],
'MET': [],
'PHE': ['CZ', 'CE1', 'CE2', 'CD1', 'CD2', 'CG', 'CB'],
'PRO': [],
'SER': [],
'THR': [],
'TRP': ['CH2', 'CZ3', 'CZ2', 'CE3', 'CE2', 'CD2', 'NE1', 'CD1', 'CG', 'CB'],
'TYR': ['OH', 'CZ', 'CE1', 'CE2', 'CD1', 'CD2', 'CG', 'CB'],
'VAL': [],
'UNK': [],
'MAS': [],
'DA': ['N6', 'C6', 'N1', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
'DC': ['N4', 'C4', 'N3', 'O2', 'C2', 'C5', 'C6', 'N1'],
'DG': ['O6', 'C6', 'N1', 'N2', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
'DT': ['O4', 'O2', 'N3', 'C4', 'C2', 'C5', 'C6', 'N1', 'C7'],
'DX': [],
'A': ['N6', 'C6', 'N1', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
'C': ['N4', 'C4', 'N3', 'O2', 'C2', 'C5', 'C6', 'N1'],
'G': ['O6', 'C6', 'N1', 'N2', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
'U': ['O4', 'O2', 'N3', 'C4', 'C2', 'C5', 'C6', 'N1'],
'X': [],
'HIS_D': ['ND1', 'CD2', 'CE1', 'NE2', 'CG', 'CB'],
}
"ALA": [],
"ARG": ["NH1", "NH2", "CZ", "NE", "CD"],
"ASN": ["OD1", "ND2", "CG", "CB"],
"ASP": ["OD1", "OD2", "CG", "CB"],
"CYS": [],
"GLN": ["OE1", "NE2", "CD", "CG"],
"GLU": ["OE1", "OE2", "CD", "CG"],
"GLY": [],
"HIS": ["ND1", "CE1", "NE2", "CD2", "CG", "CB"],
"ILE": [],
"LEU": [],
"LYS": [],
"MET": [],
"PHE": ["CZ", "CE1", "CE2", "CD1", "CD2", "CG", "CB"],
"PRO": [],
"SER": [],
"THR": [],
"TRP": ["CH2", "CZ3", "CZ2", "CE3", "CE2", "CD2", "NE1", "CD1", "CG", "CB"],
"TYR": ["OH", "CZ", "CE1", "CE2", "CD1", "CD2", "CG", "CB"],
"VAL": [],
"UNK": [],
"MAS": [],
"DA": ["N6", "C6", "N1", "C2", "N3", "C4", "C5", "N7", "C8", "N9"],
"DC": ["N4", "C4", "N3", "O2", "C2", "C5", "C6", "N1"],
"DG": ["O6", "C6", "N1", "N2", "C2", "N3", "C4", "C5", "N7", "C8", "N9"],
"DT": ["O4", "O2", "N3", "C4", "C2", "C5", "C6", "N1", "C7"],
"DX": [],
"A": ["N6", "C6", "N1", "C2", "N3", "C4", "C5", "N7", "C8", "N9"],
"C": ["N4", "C4", "N3", "O2", "C2", "C5", "C6", "N1"],
"G": ["O6", "C6", "N1", "N2", "C2", "N3", "C4", "C5", "N7", "C8", "N9"],
"U": ["O4", "O2", "N3", "C4", "C2", "C5", "C6", "N1"],
"X": [],
"HIS_D": ["ND1", "CD2", "CE1", "NE2", "CG", "CB"],
}
# fix C/U symmetry
temp = list(association_schemes['atom23']['U'])
temp = list(association_schemes["atom23"]["U"])
temp[19], temp[20] = temp[20], temp[19]
association_schemes['atom23']['U'] = tuple(temp)
association_schemes["atom23"]["U"] = tuple(temp)
association_schemes_stripped = {
name: {k: strip_list(v) for k, v in scheme.items()}
@@ -553,4 +693,6 @@ association_schemes_stripped = {
}
if __name__ == "__main__":
import pdb; pdb.set_trace()
import pdb
pdb.set_trace()

View File

@@ -119,8 +119,8 @@ class DesignInputSpecification(BaseModel):
validate_assignment=False,
str_strip_whitespace=True,
str_min_length=1,
#extra="forbid", ####################################################
extra="allow"
# extra="forbid", ####################################################
extra="allow",
## for now allowing extra for rfd3na-ss purposes, can decide later ##
)
# fmt: off
@@ -497,7 +497,10 @@ class DesignInputSpecification(BaseModel):
aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like(
is_bkbn, False, dtype=int
)
elif aa.res_name[start] in (STANDARD_DNA + STANDARD_RNA) and self.redesign_motif_sidechains:
elif (
aa.res_name[start] in (STANDARD_DNA + STANDARD_RNA)
and self.redesign_motif_sidechains
):
is_bkbn = np.isin(aa.atom_name[start:end], backbone_atoms_RNA)
aa.is_motif_atom_with_fixed_coord[start:end] = is_bkbn.astype(int)
aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like(
@@ -519,7 +522,9 @@ class DesignInputSpecification(BaseModel):
########## reorder NA atoms ###########
if exists(atom_array_input_annotated):
is_dna = np.isin(atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"])
is_dna = np.isin(
atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"]
)
is_rna = np.isin(atom_array_input_annotated.res_name, ["A", "C", "G", "U"])
dna_array = atom_array_input_annotated[is_dna]
rna_array = atom_array_input_annotated[is_rna]
@@ -726,11 +731,11 @@ class DesignInputSpecification(BaseModel):
ligand_array.set_annotation(
annot, np.full(ligand_array.array_length(), default)
)
chain_cand = 'X'
chain_cand = "X"
while chain_cand in atom_array.chain_id.tolist():
chain_cand = chain_cand + chain_cand
ligand_chain = np.array([chain_cand]*len(ligand_array))
ligand_chain = np.array([chain_cand] * len(ligand_array))
ligand_array.chain_id = ligand_chain
atom_array = atom_array + ligand_array
@@ -758,9 +763,12 @@ class DesignInputSpecification(BaseModel):
)
else:
if not exists(self.ori_jitter):
self.ori_jitter = None
self.ori_jitter = None
atom_array = set_com(
atom_array, ori_token=None, infer_ori_strategy="com", ori_jitter=self.ori_jitter
atom_array,
ori_token=None,
infer_ori_strategy="com",
ori_jitter=self.ori_jitter,
)
else:
# Standard: set ori token, zero out diffused atoms

View File

@@ -220,7 +220,6 @@ def fetch_motif_residue_(
"Please check your input and contig specification."
)
if unfix_all or f"{src_chain}{src_resid}" in unfix_residues:
subarray.set_annotation(
"is_motif_atom_with_fixed_coord", np.zeros(subarray.shape[0], dtype=int)
@@ -232,7 +231,7 @@ def fetch_motif_residue_(
)
else:
subarray.set_annotation(
"is_motif_atom_with_fixed_coord", np.array([True]*len(subarray))
"is_motif_atom_with_fixed_coord", np.array([True] * len(subarray))
)
if flexible_backbone:
@@ -255,9 +254,9 @@ def fetch_motif_residue_(
subarray = subarray[subarray.is_motif_atom.astype(bool)]
else:
subarray.set_annotation(
"is_motif_atom_unindexed", np.array([True]*len(subarray))
"is_motif_atom_unindexed", np.array([True] * len(subarray))
)
# ... Relax sequence constraint if provided
if (
exists(unfixed_sequence_components)
@@ -278,32 +277,35 @@ def fetch_motif_residue_(
return subarray
def create_diffused_residues_(n, polymer_type='p'):
def create_diffused_residues_(n, polymer_type="p"):
from rfd3.constants import (
ATOM23_ATOM_NAME_TO_ELEMENT,
backbone_atoms_DNA,
backbone_atoms_RNA,
)
if n <= 0:
raise ValueError(f"Negative/null residue count ({n}) not allowed.")
if polymer_type == 'P':
res_name = 'ALA'
if polymer_type == "P":
res_name = "ALA"
bb_len = 5
bb_atom_names = ["N", "CA", "C", "O", "CB"]
elif polymer_type == 'R':
res_name = 'A'
elif polymer_type == "R":
res_name = "A"
bb_len = len(backbone_atoms_RNA)
bb_atom_names = backbone_atoms_RNA
elif polymer_type == 'D':
res_name = 'DA'
elif polymer_type == "D":
res_name = "DA"
bb_len = len(backbone_atoms_DNA)
bb_atom_names = backbone_atoms_DNA
else:
raise ValueError(f"invalid polymer type detected: {polymer_type}, check contig!")
raise ValueError(
f"invalid polymer type detected: {polymer_type}, check contig!"
)
bb_elements = [ATOM23_ATOM_NAME_TO_ELEMENT[item] for item in bb_atom_names]
atoms = []
[
atoms.extend(
@@ -319,16 +321,13 @@ def create_diffused_residues_(n, polymer_type='p'):
for idx in range(1, n + 1)
]
array = struc.array(atoms)
array.set_annotation(
"element", np.array(bb_elements * n, dtype="<U2")
)
array.set_annotation(
"atom_name", np.array(bb_atom_names * n, dtype="<U3")
)
array.set_annotation("element", np.array(bb_elements * n, dtype="<U2"))
array.set_annotation("atom_name", np.array(bb_atom_names * n, dtype="<U3"))
array = set_default_conditioning_annotations(array, motif=False)
array = set_common_annotations(array)
return array
def accumulate_components(
components,
src_atom_array,
@@ -578,8 +577,12 @@ def create_atom_array_from_design_specification_legacy(
dna_array = atom_array_input[is_dna]
rna_array = atom_array_input[is_rna]
atom_array_input[is_dna] = reorder_atoms_per_residue(dna_array, backbone_atoms_DNA)
atom_array_input[is_rna] = reorder_atoms_per_residue(rna_array, backbone_atoms_RNA)
atom_array_input[is_dna] = reorder_atoms_per_residue(
dna_array, backbone_atoms_DNA
)
atom_array_input[is_rna] = reorder_atoms_per_residue(
rna_array, backbone_atoms_RNA
)
#######################################
if exists(atomwise_rasa):
@@ -736,10 +739,10 @@ def create_atom_array_from_design_specification_legacy(
+ np.max(atom_array.res_id)
+ 1
)
chain_cand = 'X'
chain_cand = "X"
while chain_cand in atom_array.chain_id.tolist():
chain_cand = chain_cand + chain_cand
ligand_chain = np.array([chain_cand]*len(ligand_array))
ligand_chain = np.array([chain_cand] * len(ligand_array))
ligand_array.chain_id = ligand_chain
atom_array = atom_array + ligand_array

View File

@@ -4,6 +4,7 @@ from atomworks.ml.utils.token import (
get_token_starts,
)
from beartype.typing import Any
from rfd3.constants import backbone_atoms_RNA
from rfd3.metrics.metrics_utils import (
_flatten_dict,
get_hotspot_contacts,
@@ -12,10 +13,10 @@ from rfd3.metrics.metrics_utils import (
from foundry.common import exists
from foundry.metrics.metric import Metric
from rfd3.constants import backbone_atoms_RNA
STANDARD_CACA_DIST = 3.8
STANDARD_P_P_DISTANCE = 6.4 ## average of B and A form 7 and 5.9
STANDARD_P_P_DISTANCE = 6.4  # approx. mean of the B-form (7.0) and A-form (5.9) adjacent P-P distances
def get_clash_metrics(
atom_array,
@@ -35,11 +36,11 @@ def get_clash_metrics(
elif "P" in atom_array.atom_name:
ca_atoms = atom_array[atom_array.atom_name == "P"]
cut_off = STANDARD_P_P_DISTANCE
xyz = ca_atoms.coord
xyz = torch.from_numpy(xyz)
ca_dists = torch.norm(xyz[1:] - xyz[:-1], dim=-1)
deviation = torch.abs(ca_dists - STANDARD_CACA_DIST)
deviation = torch.abs(ca_dists - cut_off)
# Allow leniency for expected chain breaks (e.g. PPI)
chain_breaks = ca_atoms.chain_iid[1:] != ca_atoms.chain_iid[:-1]
@@ -52,7 +53,9 @@ def get_clash_metrics(
}
def get_interresidue_clashes(backbone_only=False):
protein_array = atom_array[atom_array.is_protein | atom_array.is_dna | atom_array.is_rna]
protein_array = atom_array[
atom_array.is_protein | atom_array.is_dna | atom_array.is_rna
]
resid = protein_array.res_id - protein_array.res_id.min()
xyz = protein_array.coord
dists = np.linalg.norm(xyz[:, None] - xyz[None], axis=-1) # N_atoms x N_atoms
@@ -65,7 +68,9 @@ def get_clash_metrics(
if backbone_only:
# Block out non-backbone atoms
backbone_mask = np.isin(protein_array.atom_name, ["N", "CA", "C"] + backbone_atoms_RNA)
backbone_mask = np.isin(
protein_array.atom_name, ["N", "CA", "C"] + backbone_atoms_RNA
)
mask = backbone_mask[:, None] & backbone_mask[None, :]
dists[~mask] = 999
@@ -374,7 +379,7 @@ class BackboneMetrics(Metric):
is_protein = is_protein[diffused_region]
is_dna = is_dna[diffused_region]
is_rna = is_rna[diffused_region]
protein_idx_mask = is_ca & (is_protein)
protein_idx_mask = is_ca & (is_protein)
na_idx_mask = is_ca & (is_rna | is_dna)
if self.compute_for_diffused_region_only:
@@ -384,30 +389,43 @@ class BackboneMetrics(Metric):
xyz_protein = X_L.cpu()[:, protein_idx_mask]
xyz_na = X_L.cpu()[:, na_idx_mask]
ca_dists_protein = torch.norm(xyz_protein[:, 1:] - xyz_protein[:, :-1], dim=-1)
ca_dists_protein = torch.norm(
xyz_protein[:, 1:] - xyz_protein[:, :-1], dim=-1
)
ca_dists_na = torch.norm(xyz_na[:, 1:] - xyz_na[:, :-1], dim=-1)
deviation_protein = torch.abs(ca_dists_protein - self.standard_ca_dist) # B, (I-1)
deviation_protein = torch.abs(
ca_dists_protein - self.standard_ca_dist
) # B, (I-1)
deviation_na = torch.abs(ca_dists_na - self.standard_PP_dist) # B, (I-1)
is_chainbreak_protein = deviation_protein > 0.75
is_chainbreak_na = deviation_na > 1
try:
o["max_ca_deviation_protein"] = float(deviation_protein.max(-1).values.mean())
o["fraction_chainbreaks_protein"] = float(is_chainbreak_protein.float().mean(-1).mean())
o["n_chainbreaks_protein"] = float(is_chainbreak_protein.float().sum(-1).mean())
except:
o["max_ca_deviation_protein"] = float(
deviation_protein.max(-1).values.mean()
)
o["fraction_chainbreaks_protein"] = float(
is_chainbreak_protein.float().mean(-1).mean()
)
o["n_chainbreaks_protein"] = float(
is_chainbreak_protein.float().sum(-1).mean()
)
except Exception:
print("No protein in this example, skipping protein chainbreak metrics")
try:
o["max_ca_deviation_na"] = float(deviation_na.max(-1).values.mean())
o["fraction_chainbreaks_na"] = float(is_chainbreak_na.float().mean(-1).mean())
o["fraction_chainbreaks_na"] = float(
is_chainbreak_na.float().mean(-1).mean()
)
o["n_chainbreaks_na"] = float(is_chainbreak_na.float().sum(-1).mean())
except:
except Exception:
print("No NA in this example, skipping NA chainbreak metrics")
return o
class PPIMetrics(Metric):
"""PPI-specific metrics"""

View File

@@ -1,19 +1,19 @@
import logging
import bdb
import numpy as np
from biotite.structure import AtomArray
from atomworks.ml.utils.token import (
get_token_starts,
)
from biotite.structure import AtomArray
from rfd3.trainer.trainer_utils import (
_cleanup_virtual_atoms_and_assign_atom_name_elements,
_readout_seq_from_struc,
)
from rfd3.transforms.na_geom_utils import annotate_na_ss
from foundry.metrics.metric import Metric
from foundry.utils.ddp import RankedLogger
from rfd3.trainer.trainer_utils import _readout_seq_from_struc, _cleanup_virtual_atoms_and_assign_atom_name_elements
logging.basicConfig(level=logging.INFO)
global_logger = RankedLogger(__name__, rank_zero_only=False)
@@ -70,7 +70,9 @@ def _get_candidate_token_ids(
if hasattr(token_level_array, "is_dna")
else np.zeros(len(token_ids), dtype=bool)
)
token_mask &= (is_rna | is_dna) if (is_rna.any() or is_dna.any()) else token_mask
token_mask &= (
(is_rna | is_dna) if (is_rna.any() or is_dna.any()) else token_mask
)
if compute_for_diffused_region_only:
if hasattr(token_level_array, "is_motif_atom"):
@@ -175,7 +177,10 @@ def _extract_loop_and_paired_token_ids(
return loop_token_ids, paired_token_ids
def compute_from_two_arr(gt_arr, pred_arr, restrict_to_nucleic=True, compute_for_diffused_region_only = False):
def compute_from_two_arr(
gt_arr, pred_arr, restrict_to_nucleic=True, compute_for_diffused_region_only=False
):
gt_token_ids = _get_token_ids(gt_arr)
pred_token_ids = _get_token_ids(pred_arr)
if len(gt_token_ids) != len(pred_token_ids):
@@ -228,26 +233,25 @@ def compute_from_two_arr(gt_arr, pred_arr, restrict_to_nucleic=True, compute_for
(pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight
)
return pair_f1, loop_f1, weighted_f1
def get_NA_SS_F1(pred_array):
## save the original bop_partner annotation
gt_array = pred_array.copy()
## replace by annotating again
pred_array = annotate_na_ss(
pred_array,
NA_only=True,
planar_only=True,
overwrite=True,
p_canonical_bp_filter=0.0,
)
pred_array,
NA_only=True,
planar_only=True,
overwrite=True,
p_canonical_bp_filter=0.0,
)
try:
pair_f1, loop_f1, weighted_f1 = compute_from_two_arr(gt_array, pred_array)
except:
except Exception:
# Typically hit when compute_from_two_arr returns None, so unpacking into three values fails
return {}
@@ -261,13 +265,13 @@ def get_NA_SS_F1(pred_array):
class NucleicSSSimilarityMetrics(Metric):
"""Secondary-structure similarity for nucleic acids.
Reports:
- `pair_f1`: F1 over basepair edges from token-level bp-partner annotation.
- `loop_f1`: F1 over explicitly-unpaired loop tokens (`bp_partners == []`).
Unannotated tokens (`bp_partners is None`) are masked.
- `weighted_f1`: GT-weighted average of `pair_f1` and `loop_f1`, weighted by
the prevalence of paired vs loop tokens in the GT.
"""
Reports:
- `pair_f1`: F1 over basepair edges from token-level bp-partner annotation.
- `loop_f1`: F1 over explicitly-unpaired loop tokens (`bp_partners == []`).
Unannotated tokens (`bp_partners is None`) are masked.
- `weighted_f1`: GT-weighted average of `pair_f1` and `loop_f1`, weighted by
the prevalence of paired vs loop tokens in the GT.
"""
def __init__(
self,
@@ -302,7 +306,9 @@ class NucleicSSSimilarityMetrics(Metric):
n_valid = 0
for gt_arr, pred_arr in zip(ground_truth_atom_array_stack, predicted_atom_array_stack):
for gt_arr, pred_arr in zip(
ground_truth_atom_array_stack, predicted_atom_array_stack
):
gt_categories = gt_arr.get_annotation_categories()
if "bp_partners" not in gt_categories:
continue
@@ -326,14 +332,14 @@ class NucleicSSSimilarityMetrics(Metric):
pred_arr,
association_scheme="atom23",
)
except:
except Exception:
# this can fail early in training
print("could not cleanup virtuals for nucleic ss metric compute")
pass
# clear annotation to avoid potential info leak
if "bp_partners" in pred_arr.get_annotation_categories():
pred_arr.del_annotation("bp_partners")
# add nucleic-ss annotations
annotate_na_ss(
pred_arr,
@@ -348,8 +354,13 @@ class NucleicSSSimilarityMetrics(Metric):
# Basic sanity check: token counts should match for aligned comparisons
try:
pair_f1, loop_f1, weighted_f1 = compute_from_two_arr(gt_arr, pred_arr, restrict_to_nucleic=self.restrict_to_nucleic, compute_for_diffused_region_only = self.compute_for_diffused_region_only)
except:
pair_f1, loop_f1, weighted_f1 = compute_from_two_arr(
gt_arr,
pred_arr,
restrict_to_nucleic=self.restrict_to_nucleic,
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
)
except Exception:
# Typically hit when compute_from_two_arr returns None, so unpacking into three values fails
continue
@@ -360,7 +371,7 @@ class NucleicSSSimilarityMetrics(Metric):
if n_valid == 0:
return {}
return {
"pair_f1": float(np.mean(pair_f1_list)),
"loop_f1": float(np.mean(loop_f1_list)),

View File

@@ -1,5 +1,6 @@
import logging
import numpy as np
from foundry.metrics.metric import Metric
from foundry.utils.ddp import RankedLogger
@@ -8,9 +9,6 @@ logging.basicConfig(level=logging.INFO)
global_logger = RankedLogger(__name__, rank_zero_only=False)
import numpy as np
def calculate_ligand_contacts(
atom_array_stack,
cutoff_distance=4.0,
@@ -31,13 +29,12 @@ def calculate_ligand_contacts(
mean_contacts_per_model : float
"""
cutoff_sq = cutoff_distance ** 2
cutoff_sq = cutoff_distance**2
contacts_per_model = []
n_models = len(atom_array_stack)
for i in range(n_models):
atoms = atom_array_stack[i]
coords = atoms.coord
@@ -57,7 +54,7 @@ def calculate_ligand_contacts(
# Pairwise squared distances
diff = non_ligand_coords[:, None, :] - ligand_coords[None, :, :]
dist_sq = np.sum(diff ** 2, axis=-1)
dist_sq = np.sum(diff**2, axis=-1)
# Any ligand within cutoff
contact_mask = np.any(dist_sq < cutoff_sq, axis=1)
@@ -67,7 +64,11 @@ def calculate_ligand_contacts(
contacts_per_model = np.array(contacts_per_model)
return int(np.sum(contacts_per_model)), float(np.mean(contacts_per_model)), float(np.mean(contacts_per_model))/hetero_mask.sum()
return (
int(np.sum(contacts_per_model)),
float(np.mean(contacts_per_model)),
float(np.mean(contacts_per_model)) / hetero_mask.sum(),
)
class LigandContactMetrics(Metric):
@@ -89,12 +90,18 @@ class LigandContactMetrics(Metric):
def compute(self, *, predicted_atom_array_stack):
if self.restrict_to_nucleic:
if (predicted_atom_array_stack[0].is_rna.sum() + predicted_atom_array_stack[0].is_dna.sum()== 0):
if (
predicted_atom_array_stack[0].is_rna.sum()
+ predicted_atom_array_stack[0].is_dna.sum()
== 0
):
return {}
try:
total_contacts, mean_contacts, mean_contacts_per_atom = calculate_ligand_contacts(
atom_array_stack=predicted_atom_array_stack,
cutoff_distance=self.cutoff_distance,
total_contacts, mean_contacts, mean_contacts_per_atom = (
calculate_ligand_contacts(
atom_array_stack=predicted_atom_array_stack,
cutoff_distance=self.cutoff_distance,
)
)
except Exception as e:
global_logger.error(
@@ -106,4 +113,3 @@ class LigandContactMetrics(Metric):
"mean_ligand_contacts_per_model": float(mean_contacts),
"mean_ligand_contacts_per_atom": float(mean_contacts_per_atom),
}

View File

@@ -63,8 +63,8 @@ def strip_f(
)
else:
## for bp_partners default is a mask feature
v_cropped[:,:,0] = 1
v_cropped[:,:,1:] = 0
v_cropped[:, :, 0] = 1
v_cropped[:, :, 1:] = 0
# update the feature in the dictionary
f_stripped[k] = v_cropped

View File

@@ -210,7 +210,9 @@ def create_attention_indices(
chain_ids is not None and len(torch.unique(chain_ids)) > 3
): # Multi-chain structure
# Reserve 25% of attention keys for inter-chain interactions
k_inter_chain = min(max(32, k_actual // 4), k_actual) # At least 32 inter-chain keys
k_inter_chain = min(
max(32, k_actual // 4), k_actual
) # At least 32 inter-chain keys
k_intra_chain = k_actual - k_inter_chain
attn_indices = get_sparse_attention_indices_with_inter_chain(

View File

@@ -143,6 +143,7 @@ class OneDFeatureEmbedder(nn.Module):
)
)
class TwoDFeatureEmbedder(nn.Module):
"""
Embeds 2D features into a single vector.
@@ -164,18 +165,22 @@ class TwoDFeatureEmbedder(nn.Module):
for feature, n_channels in self.features.items()
}
)
def collapse2D(self, x, L):
return x.reshape((L, L, x.numel() // (L * L)))
def forward(self, f, collapse_length):
return sum(
tuple(
self.embedders[feature](self.collapse2D(f[feature].float(), collapse_length))
self.embedders[feature](
self.collapse2D(f[feature].float(), collapse_length)
)
for feature, n_channels in self.features.items()
if exists(n_channels)
)
)
class SinusoidalDistEmbed(nn.Module):
"""
Applies sinusoidal embedding to pairwise distances and projects to c_atompair.

View File

@@ -11,10 +11,10 @@ from rfd3.model.layers.blocks import (
Downcast,
LocalAtomTransformer,
OneDFeatureEmbedder,
TwoDFeatureEmbedder,
PositionPairDistEmbedder,
RelativePositionEncodingWithIndexRemoval,
SinusoidalDistEmbed,
TwoDFeatureEmbedder,
)
from rfd3.model.layers.chunked_pairwise import (
ChunkedPairwiseEmbedder,
@@ -64,7 +64,7 @@ class TokenInitializer(nn.Module):
self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)
if token_2d_features != None:
if token_2d_features is not None:
self.token_2d_embedder = TwoDFeatureEmbedder(token_2d_features, c_z)
else:
self.token_2d_embedder = None
@@ -209,7 +209,7 @@ class TokenInitializer(nn.Module):
f["ref_pos"][f["is_ca"]], valid_mask
)
# Add extra token pair features
if self.token_2d_embedder != None:
if self.token_2d_embedder is not None:
Z_init_II = Z_init_II + self.token_2d_embedder(f, I)
# Run a small transformer to provide position encodings to single.

View File

@@ -450,15 +450,12 @@ class AADesignTrainer(FabricTrainer):
):
metadata_dict[i]["metrics"] |= get_hbond_metrics(atom_array)
if (
"bp_partners" in atom_array.get_annotation_categories()
):
if not np.all(atom_array.bp_partners == None):
if "bp_partners" in atom_array.get_annotation_categories():
if not np.all(atom_array.bp_partners == None): # noqa: E711
try:
metadata_dict[i]["metrics"] |= get_NA_SS_F1(atom_array)
except:
except Exception:
pass
if "partial_t" in f:
# Try to calculate a CA RMSD to the input:
aa_in = example["atom_array"]

View File

@@ -2,6 +2,7 @@ from collections import Counter, OrderedDict
import numpy as np
import torch
from atomworks.constants import STANDARD_DNA, STANDARD_RNA
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from atomworks.ml.utils.token import (
get_token_starts,
@@ -13,10 +14,10 @@ from rfd3.constants import (
ATOM14_ATOM_NAMES,
ATOM23_ATOM_NAMES_DNA,
ATOM23_ATOM_NAMES_RNA,
backbone_atoms_RNA,
VIRTUAL_ATOM_ELEMENT_NAME,
association_schemes,
association_schemes_stripped,
backbone_atoms_RNA,
)
from rfd3.utils.io import (
build_stack_from_atom_array_and_batched_coords,
@@ -25,7 +26,6 @@ from scipy.optimize import linear_sum_assignment
from foundry.common import exists
from foundry.utils.ddp import RankedLogger
from atomworks.constants import STANDARD_DNA, STANDARD_RNA
global_logger = RankedLogger(__name__, rank_zero_only=False)
@@ -221,7 +221,7 @@ def _readout_seq_from_struc(
# There might be a better way to do this.
CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
if cur_res_atom_array.is_dna[0] or cur_res_atom_array.is_rna[0]:
cur_central_atom = "C1'"
elif np.linalg.norm(CA_coord - CB_coord) < threshold:
@@ -229,7 +229,6 @@ def _readout_seq_from_struc(
else:
cur_central_atom = central_atom
central_mask = cur_res_atom_array.atom_name == cur_central_atom
# ... Calculate the distance to the central atom
@@ -269,7 +268,7 @@ def _readout_seq_from_struc(
if not cur_res_atom_array.is_rna[0]:
continue
else:
#ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
# ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
if not cur_res_atom_array.is_protein[0]:
continue
@@ -425,9 +424,11 @@ def process_unindexed_outputs(
try:
assert (res_id_ == res_id) & (chain_id_ == chain_id)
except:
global_logger.warning("Unindexed mapping did not work properly, res_id, chain_id")
except Exception:
global_logger.warning(
"Unindexed mapping did not work properly, res_id, chain_id"
)
inserted_mask = np.logical_or(inserted_mask, token_match)
# ... Compute metrics based on the new distances
@@ -456,11 +457,7 @@ def process_unindexed_outputs(
else:
dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
elif not np.any(
np.isin(
token.atom_name, backbone_atoms_RNA
)
):
elif not np.any(np.isin(token.atom_name, backbone_atoms_RNA)):
if np.sum(token.atomize) == 1:
join_atom = np.where(token.atomize)[0][0]
elif "C1'" in token.atom_name:
@@ -474,10 +471,10 @@ def process_unindexed_outputs(
)
else:
dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
try:
metadata["join_point_rmsd_by_token"][token_pdb_id] = dist
except:
except Exception:
pass
metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"

View File

@@ -243,7 +243,7 @@ class SampleConditioningType(Transform):
)
self.meta_conditioning_probabilities = meta_conditioning_probabilities
self.train_conditions = train_conditions
for item in self.train_conditions:
self.train_conditions[item].association_scheme = association_scheme
@@ -265,15 +265,15 @@ class SampleConditioningType(Transform):
assert "conditions" in data, "Conditioning dict not initialized"
def forward(self, data):
#for item in self.train_conditions:
# for item in self.train_conditions:
# print(self.train_conditions[item].is_valid_for_example(data))
valid_conditions = [
cond
for cond in self.train_conditions.values()
if cond.is_valid_for_example(data) and cond.frequency > 0
if cond.is_valid_for_example(data) and cond.frequency > 0
]
if len(valid_conditions) == 0:
raise InvalidSampledConditionException("No valid condition was found.")
@@ -288,7 +288,7 @@ class SampleConditioningType(Transform):
cond = valid_conditions[i_cond]
cond.association_scheme = self.association_scheme
data["sampled_condition"] = cond
data["sampled_condition_name"] = cond.name
data["sampled_condition_cls"] = cond.__class__
@@ -306,11 +306,6 @@ class SampleConditioningFlags(Transform):
"AssignTypes",
"SampleConditioningType",
] # We use is_protein in the PPI training condition
def __init__(self, association_scheme):
self.association_scheme = association_scheme
def __init__(self, association_scheme):
self.association_scheme = association_scheme
def __init__(self, association_scheme):
self.association_scheme = association_scheme

View File

@@ -38,6 +38,7 @@ from rfd3.transforms.conditioning_base import (
UnindexFlaggedTokens,
get_motif_features,
)
from rfd3.transforms.na_geom import na_ss_feats_from_annotation
from rfd3.transforms.rasa import discretize_rasa
from rfd3.transforms.util_transforms import (
AssignTypes,
@@ -45,7 +46,6 @@ from rfd3.transforms.util_transforms import (
get_af3_token_representative_masks,
)
from rfd3.transforms.virtual_atoms import PadTokensWithVirtualAtoms
from rfd3.transforms.na_geom import na_ss_feats_from_annotation
from foundry.utils.ddp import RankedLogger # noqa
@@ -71,7 +71,7 @@ class SubsampleToTypes(Transform):
def __init__(
self,
allowed_types: list | str = ["is_protein"],
association_scheme: str = 'atom14'
association_scheme: str = "atom14",
):
self.allowed_types = allowed_types
self.association_scheme = association_scheme
@@ -106,7 +106,7 @@ class SubsampleToTypes(Transform):
)
)
if self.association_scheme != 'atom23' and atom_array.is_protein.sum() == 0:
if self.association_scheme != "atom23" and atom_array.is_protein.sum() == 0:
raise ValueError(
"No protein atoms found in the atom array. Example ID: {}".format(
data.get("example_id", "unknown")
@@ -757,7 +757,7 @@ class AddAdditional1dFeaturesToFeats(Transform):
"""
if "feats" not in data.keys():
data["feats"] = {}
if self.association_scheme == "atom23":
data["atom_array"].set_annotation(
"is_protein_token", data["atom_array"].is_protein
@@ -772,7 +772,6 @@ class AddAdditional1dFeaturesToFeats(Transform):
data = self.generate_feature(feature_name, n_dims, data, "atom")
return data
class AddAdditional2dFeaturesToFeats(Transform):
@@ -791,15 +790,15 @@ class AddAdditional2dFeaturesToFeats(Transform):
token_2d_features,
autofill_zeros_if_not_present_in_atomarray=False,
association_scheme="atom14",
):
):
self.autofill = autofill_zeros_if_not_present_in_atomarray
self.token_2d_features = token_2d_features
self.association_scheme = association_scheme
# Need to pre-define custom constructor functions
# Need to pre-define custom constructor functions
# to map from atomarray annotations to tensors.
self.constructor_functions = {
'bp_partners': na_ss_feats_from_annotation,
"bp_partners": na_ss_feats_from_annotation,
}
def check_input(self, data) -> None:
@@ -807,11 +806,10 @@ class AddAdditional2dFeaturesToFeats(Transform):
check_is_instance(data, "atom_array", AtomArray)
def generate_token_feature(self, feature_name, n_dims, data):
# Don't do this if we already have the feature
if feature_name in data["feats"].keys():
return data
# For these, we need to use a constructor function mapping,
# since pair features may require custom logic/conventions.
@@ -822,11 +820,11 @@ class AddAdditional2dFeaturesToFeats(Transform):
raise ValueError(
f"No constructor function found for 2d feature `{feature_name}`"
)
# We can fix shape issues here:
if len(feature_array.shape) == 2 and n_dims == 1:
feature_array = feature_array.unsqueeze(1)
# ensure that feature_array is a 3d array with third dim == n_dims:
if len(feature_array.shape) != 3:
raise ValueError(
@@ -854,7 +852,7 @@ class AddAdditional2dFeaturesToFeats(Transform):
if "feats" not in data.keys():
data["feats"] = {}
# Only apply for features that the model is expecting:
if self.token_2d_features == None:
if self.token_2d_features is None:
return data
for feature_name, n_dims in self.token_2d_features.items():
data = self.generate_token_feature(feature_name, n_dims, data)

View File

@@ -248,5 +248,3 @@ def subsample_one_hot_np(array, fraction):
new_array[i, j] = 1
return new_array

View File

@@ -1,48 +1,49 @@
from typing import Any
import numpy as np
from functools import partial
from biotite.structure import AtomArray
from atomworks.ml.transforms._checks import (
check_atom_array_annotation,
check_contains_keys,
check_is_instance,
)
from atomworks.ml.transforms.base import Transform
from atomworks.ml.utils.token import get_token_starts, spread_token_wise
from biotite.structure import AtomArray
from rfd3.transforms.conditioning_utils import sample_island_tokens
from rfd3.transforms.na_geom_utils import (
annotate_na_ss,
annotate_na_ss_from_data_specification,
DEFAULT_NA_SS_FEATURE_INFO,
annotate_na_ss,
annotate_na_ss_from_data_specification,
)
from atomworks.ml.utils.token import spread_token_wise, get_token_starts
def na_ss_feats_from_annotation(atom_array: AtomArray,
token_starts= None,
n_tokens = None,
return_as_onehot = True,
) -> np.ndarray:
def na_ss_feats_from_annotation(
atom_array: AtomArray,
token_starts=None,
n_tokens=None,
return_as_onehot=True,
) -> np.ndarray:
"""
Takes in an atom array and constructs a base pair feature matrix from annotations,
according to a custom feature construction + masking system.
This featurization utilizes info from BasePairEnum to assign int values
This featurization utilizes info from BasePairEnum to assign int values
to paired, unpaired, and masked positions in the matrix.
Args:
* atom_array: AtomArray with bp_partners annotation at atom level
* token_starts (optional): indices of token starts in the atom array
* n_tokens (optional): number of tokens (length of token_starts)
* return_as_onehot (optional): if False, return integer-encoded
* return_as_onehot (optional): if False, return integer-encoded
matrix instead of one-hot encoded matrix
returns:
* na_ss_matrix:
If ``return_as_onehot`` is True (default):
np.ndarray of shape (n_tokens, n_tokens, n_classes)
np.ndarray of shape (n_tokens, n_tokens, n_classes)
with one-hot encoded values according to BasePairEnum
If ``return_as_onehot`` is False :
np.ndarray of shape (n_tokens, n_tokens)
np.ndarray of shape (n_tokens, n_tokens)
with int values according to BasePairEnum
@@ -51,13 +52,16 @@ def na_ss_feats_from_annotation(atom_array: AtomArray,
if (token_starts is None) or (n_tokens is None):
token_starts = get_token_starts(atom_array)
n_tokens = len(token_starts)
# Collect token inds for paired or loop positions:
pair_inds = []
loop_inds = []
token_bp_partners = atom_array.get_annotation("bp_partners")[token_starts] # get bp_partners at token level
assert len(token_bp_partners) == n_tokens, "Length of token_bp_partners should match n_tokens"
token_bp_partners = atom_array.get_annotation("bp_partners")[
token_starts
] # get bp_partners at token level
assert (
len(token_bp_partners) == n_tokens
), "Length of token_bp_partners should match n_tokens"
for i, j_list in enumerate(token_bp_partners):
if j_list is not None:
if len(j_list) > 0:
@@ -68,46 +72,54 @@ def na_ss_feats_from_annotation(atom_array: AtomArray,
# The standard system for constructing meaningful base pair features:
# 0). Initialize with values of UNSPECIFIED (0): int matrix of shape (n_tokens, n_tokens)
na_ss_matrix = np.full((n_tokens, n_tokens), DEFAULT_NA_SS_FEATURE_INFO["NA_SS_MASK"], dtype=np.int64)
na_ss_matrix = np.full(
(n_tokens, n_tokens), DEFAULT_NA_SS_FEATURE_INFO["NA_SS_MASK"], dtype=np.int64
)
# 1). Fill in with values of PAIR (1) at positions that have bp_partners annotated as a non-empty list
for pair_i, pair_j in pair_inds:
na_ss_matrix[pair_i, pair_j] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_PAIR"]
na_ss_matrix[pair_j, pair_i] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_PAIR"] # ensure symmetry
na_ss_matrix[pair_j, pair_i] = DEFAULT_NA_SS_FEATURE_INFO[
"NA_SS_PAIR"
] # ensure symmetry
# 2). Fill in with values of LOOP (2) at positions that have bp_partners annotated as an empty list (explicitly unpaired)
# (we make full stripes across that position's row/col to indicate that NONE of those other positions are paired )
for loop_i in loop_inds:
na_ss_matrix[loop_i, :] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_LOOP"]
na_ss_matrix[:, loop_i] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_LOOP"] # ensure symmetry
na_ss_matrix[:, loop_i] = DEFAULT_NA_SS_FEATURE_INFO[
"NA_SS_LOOP"
] # ensure symmetry
# Optional: convert NA-SS matrix to one-hot encoding for model input:
if return_as_onehot:
na_ss_matrix = np.eye(len(DEFAULT_NA_SS_FEATURE_INFO), dtype=np.int64)[na_ss_matrix]
na_ss_matrix = np.eye(len(DEFAULT_NA_SS_FEATURE_INFO), dtype=np.int64)[
na_ss_matrix
]
return na_ss_matrix
class CalculateNucleicAcidGeomFeats(Transform):
"""
Transform for constructing nucleic-acid conditioning features.
Transform for constructing nucleic-acid conditioning features.
This transform currently produces only nucleic-acid secondary-structure (NA-SS)
features as a 2D token-token matrix with 3 bins:
* 0: mask / unspecified
* 1: paired
* 2: loop / explicitly unpaired
This transform currently produces only nucleic-acid secondary-structure (NA-SS)
features as a 2D token-token matrix with 3 bins:
* 0: mask / unspecified
* 1: paired
* 2: loop / explicitly unpaired
Training:
- Computes geometry/H-bond-based base pairs and writes them onto the AtomArray
via the ``bp_partners`` annotation (annotation-first), then reconstructs the
matrix (and optionally masks parts of it) before one-hot encoding.
Training:
- Computes geometry/H-bond-based base pairs and writes them onto the AtomArray
via the ``bp_partners`` annotation (annotation-first), then reconstructs the
matrix (and optionally masks parts of it) before one-hot encoding.
Inference:
- Interprets user-provided secondary-structure specifications, writes the same
``bp_partners`` annotation, then follows the same matrix + one-hot path.
Inference:
- Interprets user-provided secondary-structure specifications, writes the same
``bp_partners`` annotation, then follows the same matrix + one-hot path.
Note: helical-parameter features are not implemented/used in this refactored path.
Note: helical-parameter features are not implemented/used in this refactored path.
"""
def __init__(
@@ -115,43 +127,46 @@ class CalculateNucleicAcidGeomFeats(Transform):
is_inference,
# Conditional sampling parameters all stored in this dict:
meta_conditioning_probabilities: dict[str, float] = None,
# Mask control parameters:
nucleic_ss_min_shown: float = 0.2,
nucleic_ss_max_shown: float = 1.0,
n_islands_min: int = 1,
n_islands_max: int = 6,
# USE_RF2AA_NAMES: bool = False,
NA_only: bool = False,
planar_only : bool = True,
planar_only: bool = True,
):
# Critical, must always have to know how to handle
self.is_inference = is_inference
self.is_inference = is_inference
self.meta_conditioning_probabilities = meta_conditioning_probabilities or {}
# Control whether we show some nucleic SS or default to full 2D mask
self.p_is_nucleic_ss_example = self.meta_conditioning_probabilities.get("p_is_nucleic_ss_example", 0.0)
self.p_is_nucleic_ss_example = self.meta_conditioning_probabilities.get(
"p_is_nucleic_ss_example", 0.0
)
# Control whether we define full SS or just part of it (only applies if is NA SS example)
self.p_show_partial_feats = self.meta_conditioning_probabilities.get("p_nucleic_ss_show_partial_feats", 0.0)
self.p_show_partial_feats = self.meta_conditioning_probabilities.get(
"p_nucleic_ss_show_partial_feats", 0.0
)
# Some frac of time default to only showing canonical base pairs
self.p_canonical_bp_filter = self.meta_conditioning_probabilities.get("p_canonical_bp_filter", 0.5)
self.p_canonical_bp_filter = self.meta_conditioning_probabilities.get(
"p_canonical_bp_filter", 0.5
)
# mask patterning control to make things resemble design scenarios
self.nucleic_ss_min_shown = nucleic_ss_min_shown
self.nucleic_ss_max_shown = nucleic_ss_max_shown
self.n_islands_min = n_islands_min
self.n_islands_max = n_islands_max
self.nucleic_ss_min_shown = nucleic_ss_min_shown
self.nucleic_ss_max_shown = nucleic_ss_max_shown
self.n_islands_min = n_islands_min
self.n_islands_max = n_islands_max
# Filters for what can be considered a planar contact interaction
self.NA_only = NA_only # only annotate base-like interactions for nucleic acid residues
self.planar_only = planar_only # only consider planar atoms in sidechains for geometry calculations,
self.NA_only = (
NA_only # only annotate base-like interactions for nucleic acid residues
)
self.planar_only = planar_only # only consider planar atoms in sidechains for geometry calculations,
def check_input(self, data: dict[str, Any]) -> None:
check_contains_keys(data, ["atom_array"])
@@ -162,9 +177,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
def _sample_training_flags(self) -> tuple[bool, bool]:
"""Sample booleans controlling whether/how features are shown in training."""
is_nucleic_ss_example = bool(np.random.rand() < self.p_is_nucleic_ss_example)
give_partial_feats = bool(
np.random.rand() < self.p_show_partial_feats
)
give_partial_feats = bool(np.random.rand() < self.p_show_partial_feats)
return is_nucleic_ss_example, give_partial_feats
def forward(self, data: dict) -> dict:
@@ -177,24 +190,25 @@ class CalculateNucleicAcidGeomFeats(Transform):
# Handle the training case with ground truth and masking
if not self.is_inference:
# First, annotate as usual
is_nucleic_ss_example, give_partial_feats = self._sample_training_flags()
if is_nucleic_ss_example:
atom_array = annotate_na_ss(atom_array,
NA_only=self.NA_only,
planar_only=self.planar_only,
p_canonical_bp_filter=self.p_canonical_bp_filter,
)
atom_array = annotate_na_ss(
atom_array,
NA_only=self.NA_only,
planar_only=self.planar_only,
p_canonical_bp_filter=self.p_canonical_bp_filter,
)
# Generate symmetric partner annotations at the token level for masking purposes.
# choice for object-consistency: if already masked/undefined: be a list mapping to self-index.
partner_sym_map = {
i: atom_array.bp_partners[ts_i] if atom_array.bp_partners[ts_i] is not None else [i]
i: atom_array.bp_partners[ts_i]
if atom_array.bp_partners[ts_i] is not None
else [i]
for i, ts_i in enumerate(token_starts)
}
}
# # Sample mask on token level:
token_mask_to_show = self._sample_where_to_show_ss(
@@ -206,17 +220,19 @@ class CalculateNucleicAcidGeomFeats(Transform):
# Spread mask to atom level
is_ss_shown = spread_token_wise(atom_array, token_mask_to_show)
# Extract the base pair annotations
bp_partners_atom = atom_array.get_annotation("bp_partners")
# Remove unshown positions from bp_partners annotation
bp_partners_atom[~is_ss_shown] = None
# Reset the annotation with newly hidden positions
atom_array.set_annotation("bp_partners", bp_partners_atom)
else:
atom_array.set_annotation("bp_partners", np.array([None]*len(atom_array)))
atom_array.set_annotation(
"bp_partners", np.array([None] * len(atom_array))
)
# Inference case: create from commandline args
else:
@@ -239,25 +255,26 @@ class CalculateNucleicAcidGeomFeats(Transform):
log_dict = data["log_dict"]
data["log_dict"] = log_dict
data["atom_array"] = atom_array
return data
def _sample_where_to_show_ss(self, n_tokens: int,
is_nucleic_ss_example: bool = True,
give_partial_feats: bool = True,
partner_sym_map: dict[int, list[int]] = None,
) -> np.ndarray:
def _sample_where_to_show_ss(
self,
n_tokens: int,
is_nucleic_ss_example: bool = True,
give_partial_feats: bool = True,
partner_sym_map: dict[int, list[int]] = None,
) -> np.ndarray:
"""Sample token-level islands indicating which SS rows/cols to reveal.
This custom function allows for enforcing symmetry in the shown features according
to the partner_sym_map, which encodes which tokens are partners in the SS
matrix and thus should be masked/unmasked together to maintain consistency.
This custom function allows for enforcing symmetry in the shown features according
to the partner_sym_map, which encodes which tokens are partners in the SS
matrix and thus should be masked/unmasked together to maintain consistency.
"""
# If NOT is_nucleic_ss_example, set is_shown to all False
if not is_nucleic_ss_example:
token_mask_to_show = np.zeros((n_tokens,), dtype=bool)
# If NOT give_partial_feats, set is_shown to all True
if not give_partial_feats:
token_mask_to_show = np.ones((n_tokens,), dtype=bool)
@@ -265,14 +282,19 @@ class CalculateNucleicAcidGeomFeats(Transform):
# Get numerical parameters that govern the mask pattern
frac_shown = (
self.nucleic_ss_min_shown
+ (self.nucleic_ss_max_shown - self.nucleic_ss_min_shown) * np.random.rand()
+ (self.nucleic_ss_max_shown - self.nucleic_ss_min_shown)
* np.random.rand()
)
frac_shown = float(np.clip(frac_shown, 0.0, 1.0))
max_length = int(np.ceil(frac_shown * n_tokens))
if max_length <= 0:
token_mask_to_show = np.zeros((n_tokens,), dtype=bool)
island_len_min = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_max), 1)))
island_len_max = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_min), 1)))
island_len_min = max(
1, int(frac_shown * n_tokens // max(int(self.n_islands_max), 1))
)
island_len_max = max(
1, int(frac_shown * n_tokens // max(int(self.n_islands_min), 1))
)
island_len_min = min(island_len_min, n_tokens)
island_len_max = min(island_len_max, n_tokens)
island_len_max = max(island_len_max, island_len_min)
@@ -287,7 +309,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
max_length=max_length,
)
# Handle symmetry by iterating through the partner_sym_map items and setting
# Handle symmetry by iterating through the partner_sym_map items and setting
# `partner_mask_to_show` at partner positions to match `token_mask_to_show`
# initialize as all shown so effect comes from hiding + logical AND condition
partner_mask_to_show = np.ones_like(token_mask_to_show)
@@ -299,4 +321,3 @@ class CalculateNucleicAcidGeomFeats(Transform):
token_mask_to_show = token_mask_to_show & partner_mask_to_show
return token_mask_to_show

View File

@@ -1,22 +1,22 @@
import math
import os
import subprocess
import tempfile
from datetime import datetime
from typing import Dict, Optional
import math
import numpy as np
import biotite.structure as struc
from biotite.structure import AtomArray
import biotite.structure as struc
import numpy as np
from atomworks.constants import (
STANDARD_AA,
STANDARD_AA,
STANDARD_DNA,
STANDARD_RNA,
)
from atomworks.io.utils.sequence import (
is_purine,
is_pyrimidine,
)
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from atomworks.ml.utils.token import (
get_token_starts,
is_glycine,
@@ -24,15 +24,12 @@ from atomworks.ml.utils.token import (
is_standard_aa_not_glycine,
is_unknown_nucleotide,
)
from rfd3.transforms.hbonds_hbplus import save_atomarray_to_pdb
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from biotite.structure import AtomArray
from rfd3.constants import (
ATOM_REGION_BY_RESI,
PLANAR_ATOMS_BY_RESI,
ATOM_REGION_BY_RESI,
PLANAR_ATOMS_BY_RESI,
)
import tempfile
from rfd3.transforms.hbonds_hbplus import save_atomarray_to_pdb
# Derived: True when the residue has any planar sidechain atoms
HAS_PLANAR_SC = {res: bool(atoms) for res, atoms in PLANAR_ATOMS_BY_RESI.items()}
@@ -43,15 +40,23 @@ DEFAULT_NA_SS_FEATURE_INFO: dict[str, int] = {
"NA_SS_LOOP": 2,
}
AA_PLANAR_ATOMS = sorted(set(
atom for res in STANDARD_AA if res in PLANAR_ATOMS_BY_RESI
for atom in PLANAR_ATOMS_BY_RESI[res]
))
AA_PLANAR_ATOMS = sorted(
set(
atom
for res in STANDARD_AA
if res in PLANAR_ATOMS_BY_RESI
for atom in PLANAR_ATOMS_BY_RESI[res]
)
)
NA_PLANAR_ATOMS = sorted(set(
atom for res in (*STANDARD_RNA, *STANDARD_DNA) if res in PLANAR_ATOMS_BY_RESI
for atom in PLANAR_ATOMS_BY_RESI[res]
))
NA_PLANAR_ATOMS = sorted(
set(
atom
for res in (*STANDARD_RNA, *STANDARD_DNA)
if res in PLANAR_ATOMS_BY_RESI
for atom in PLANAR_ATOMS_BY_RESI[res]
)
)
class NucMolInfo:
@@ -62,53 +67,126 @@ class NucMolInfo:
"""
def __init__(self) -> None:
# Hbond interaction-class indices of the `hbond_count` array:
# `hbond_count` array is (L, L, 3), where the last dimension
# `hbond_count` array is (L, L, 3), where the last dimension
# encodes interaction type between tokens i & j
self.BB_BB = 0 # backbone-backbone hbond interactions
self.BB_SC = 1 # backbone-sidechain hbond interactions
self.SC_SC = 2 # sidechain-sidechain hbond interactions
self.BB_BB = 0 # backbone-backbone hbond interactions
self.BB_SC = 1 # backbone-sidechain hbond interactions
self.SC_SC = 2 # sidechain-sidechain hbond interactions
# We sum over the last dimension of the hbond_count array, scaling
# We sum over the last dimension of the hbond_count array, scaling
# count by the following weights to get the interaction score:
self.bp_weight_BB_BB = 0.0
self.bp_weight_BB_SC = 0.5
self.bp_weight_SC_SC = 1.0
self.bp_summation_weights = [self.bp_weight_BB_BB,
self.bp_weight_BB_SC,
self.bp_weight_SC_SC]
self.bp_summation_weights = [
self.bp_weight_BB_BB,
self.bp_weight_BB_SC,
self.bp_weight_SC_SC,
]
# Parameters of the sigmoid function that gives us a continuous step function for
# Parameters of the sigmoid function that gives us a continuous step function for
# meeting basepair interaction criteria based on hbond counts alone (1st filter).
# Calibrated such that:
# >= 2 base-base H-bonds -> ~1.0
# 1 base-base H-bond + 1 base-backbone H-bond -> ~0.5
self.min_hbonds_for_bp = 2.0
self.bp_hbond_coeff = 9.8 # determined heuristically
self.bp_val_cutoff = 0.5 # minimum basepairing score for binarizing basepairs when needed
self.bp_hbond_coeff = 9.8 # determined heuristically
self.bp_val_cutoff = (
0.5 # minimum basepairing score for binarizing basepairs when needed
)
self.base_geometry_limits = {}
self.base_geometry_limits['D_ij'] = 16.0
self.base_geometry_limits['H_ij'] = 1.5
self.base_geometry_limits['P_ij'] = math.pi/5
self.base_geometry_limits['B_ij'] = math.pi/5
self.base_geometry_limits["D_ij"] = 16.0
self.base_geometry_limits["H_ij"] = 1.5
self.base_geometry_limits["P_ij"] = math.pi / 5
self.base_geometry_limits["B_ij"] = math.pi / 5
self.rep_atom_dict={"protein": "CA", "rna": "C1'", "dna": "C1'"}
self.rep_atom_dict = {"protein": "CA", "rna": "C1'", "dna": "C1'"}
# go through self.vec_atom_dict and remove spaces from atom names (values in inner dicts), and remove spaces from keys + replace 'R' with '' in outer dict keys
self.vec_atom_dict = {
"DA": {"W_start":"N1", "W_stop":"N6", "H_start":"N7", "H_stop":"N6", "S_start":"C1'", "S_stop":"N3", "B_start":"C1'", "B_stop":"N9" },
"DG": {"W_start":"N1", "W_stop":"O6", "H_start":"N7", "H_stop":"O6", "S_start":"C1'", "S_stop":"N3", "B_start":"C1'", "B_stop":"N9" },
"DC": {"W_start":"N3", "W_stop":"N4", "H_start":"C5", "H_stop":"N4", "S_start":"C1'", "S_stop":"O2", "B_start":"C1'", "B_stop":"N1" },
"DT": {"W_start":"N3", "W_stop":"O4", "H_start":"C5", "H_stop":"O4", "S_start":"C1'", "S_stop":"O2", "B_start":"C1'", "B_stop":"N1" },
"A": {"W_start":"N1", "W_stop":"N6", "H_start":"N7", "H_stop":"N6", "S_start":"C1'", "S_stop":"N3", "B_start":"C1'", "B_stop":"N9" },
"G": {"W_start":"N1", "W_stop":"O6", "H_start":"N7", "H_stop":"O6", "S_start":"C1'", "S_stop":"N3", "B_start":"C1'", "B_stop":"N9" },
"C": {"W_start":"N3", "W_stop":"N4", "H_start":"C5", "H_stop":"N4", "S_start":"C1'", "S_stop":"O2", "B_start":"C1'", "B_stop":"N1" },
"U": {"W_start":"N3", "W_stop":"O4", "H_start":"C5", "H_stop":"O4", "S_start":"C1'", "S_stop":"O2", "B_start":"C1'", "B_stop":"N1" },
}
"DA": {
"W_start": "N1",
"W_stop": "N6",
"H_start": "N7",
"H_stop": "N6",
"S_start": "C1'",
"S_stop": "N3",
"B_start": "C1'",
"B_stop": "N9",
},
"DG": {
"W_start": "N1",
"W_stop": "O6",
"H_start": "N7",
"H_stop": "O6",
"S_start": "C1'",
"S_stop": "N3",
"B_start": "C1'",
"B_stop": "N9",
},
"DC": {
"W_start": "N3",
"W_stop": "N4",
"H_start": "C5",
"H_stop": "N4",
"S_start": "C1'",
"S_stop": "O2",
"B_start": "C1'",
"B_stop": "N1",
},
"DT": {
"W_start": "N3",
"W_stop": "O4",
"H_start": "C5",
"H_stop": "O4",
"S_start": "C1'",
"S_stop": "O2",
"B_start": "C1'",
"B_stop": "N1",
},
"A": {
"W_start": "N1",
"W_stop": "N6",
"H_start": "N7",
"H_stop": "N6",
"S_start": "C1'",
"S_stop": "N3",
"B_start": "C1'",
"B_stop": "N9",
},
"G": {
"W_start": "N1",
"W_stop": "O6",
"H_start": "N7",
"H_stop": "O6",
"S_start": "C1'",
"S_stop": "N3",
"B_start": "C1'",
"B_stop": "N9",
},
"C": {
"W_start": "N3",
"W_stop": "N4",
"H_start": "C5",
"H_stop": "N4",
"S_start": "C1'",
"S_stop": "O2",
"B_start": "C1'",
"B_stop": "N1",
},
"U": {
"W_start": "N3",
"W_stop": "O4",
"H_start": "C5",
"H_stop": "O4",
"S_start": "C1'",
"S_stop": "O2",
"B_start": "C1'",
"B_stop": "N1",
},
}
def calculate_hb_counts(
@@ -117,7 +195,7 @@ def calculate_hb_counts(
mol_info: NucMolInfo,
cutoff_HA_dist: float = 2.5,
cutoff_DA_dist: float = 3.9,
):
):
"""Count hydrogen bonds between residue pairs using HBPLUS.
Args:
@@ -195,14 +273,16 @@ def calculate_hb_counts(
)
# d_atm = atom_array[d_mask]
# d_idx = d_atm.token_id
d_idx = token_level_data["resi2index"].get(f"{d_chain_iid}__{d_resi}", None)
d_idx = token_level_data["resi2index"].get(
f"{d_chain_iid}__{d_resi}", None
)
if d_idx is None:
continue
# Handle standard polymer residues for donor atom:
if d_resn in ATOM_REGION_BY_RESI.keys():
d_is_sc = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['sc'])
d_is_bb = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['bb'])
d_is_sc = d_atom_name in ATOM_REGION_BY_RESI[d_resn]["sc"]
d_is_bb = d_atom_name in ATOM_REGION_BY_RESI[d_resn]["bb"]
else:
# If non-polymer, define any ligand HBonding atom as backbone:
if d_mask.sum() > 0:
@@ -219,45 +299,51 @@ def calculate_hb_counts(
& (atom_array.res_id == a_resi)
& (atom_array.chain_iid == a_chain_iid)
)
a_idx = token_level_data["resi2index"].get(f"{a_chain_iid}__{a_resi}", None)
a_idx = token_level_data["resi2index"].get(
f"{a_chain_iid}__{a_resi}", None
)
if a_idx is None:
continue
# Handle standard polymer residues for acceptor atom:
if a_resn in ATOM_REGION_BY_RESI.keys():
a_is_sc = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['sc'])
a_is_bb = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['bb'])
a_is_sc = a_atom_name in ATOM_REGION_BY_RESI[a_resn]["sc"]
a_is_bb = a_atom_name in ATOM_REGION_BY_RESI[a_resn]["bb"]
else:
# If non-polymer, define any ligand HBonding atom as backbone:
if a_mask.sum() > 0:
a_is_bb = atom_array[a_mask][0].is_ligand
# 0 -> both backbone (BB-BB)
hbond_count[a_idx, d_idx, 0] += (a_is_bb * d_is_bb)
hbond_count[d_idx, a_idx, 0] += (d_is_bb * a_is_bb)
hbond_count[a_idx, d_idx, 0] += a_is_bb * d_is_bb
hbond_count[d_idx, a_idx, 0] += d_is_bb * a_is_bb
# 1 -> one backbone, one sidechain (BB-SC)
hbond_count[a_idx, d_idx, 1] += (a_is_bb * d_is_sc) | (a_is_sc * d_is_bb)
hbond_count[d_idx, a_idx, 1] += (d_is_bb * a_is_sc) | (d_is_sc * a_is_bb)
hbond_count[a_idx, d_idx, 1] += (a_is_bb * d_is_sc) | (
a_is_sc * d_is_bb
)
hbond_count[d_idx, a_idx, 1] += (d_is_bb * a_is_sc) | (
d_is_sc * a_is_bb
)
# 2 -> both sidechain (SC-SC)
hbond_count[a_idx, d_idx, 2] += (a_is_sc * d_is_sc)
hbond_count[d_idx, a_idx, 2] += (d_is_sc * a_is_sc)
'''
hbond_count[a_idx, d_idx, 2] += a_is_sc * d_is_sc
hbond_count[d_idx, a_idx, 2] += d_is_sc * a_is_sc
"""
try:
os.remove(pdb_path)
os.remove(hb2_path)
except:
print("temp pdb/hb already removed or not created to begin with")
'''
"""
return hbond_count
def find_planar_positions(
atom_array: AtomArray,
mol_info: NucMolInfo,
tol: float = 1e-2,
) -> Dict:
atom_array: AtomArray,
mol_info: NucMolInfo,
tol: float = 1e-2,
) -> Dict:
"""Identify residues with planar sidechains via known atom lists or PCA plane-fitting.
For canonical residues the planar atoms are looked up from ``mol_info``;
@@ -285,11 +371,10 @@ def find_planar_positions(
# for chain_iid, res_id in unique_positions_list:
for chain_iid, res_id, res_name in unique_positions_list:
mask = (
(atom_array.chain_iid == chain_iid) &
(atom_array.res_id == res_id) &
(atom_array.res_name == res_name)
(atom_array.chain_iid == chain_iid)
& (atom_array.res_id == res_id)
& (atom_array.res_name == res_name)
)
res_atoms = atom_array[mask]
@@ -297,9 +382,9 @@ def find_planar_positions(
if res_name in PLANAR_ATOMS_BY_RESI.keys():
# Shared atoms between residue and known planar atoms for that residue type:
planar_atom_list = list(
set([atm.atom_name for atm in res_atoms]) &
set(PLANAR_ATOMS_BY_RESI[res_name])
)
set([atm.atom_name for atm in res_atoms])
& set(PLANAR_ATOMS_BY_RESI[res_name])
)
planar_atom_list_dict[(chain_iid, res_id)] = planar_atom_list
# If unknown or noncanonical residue, compute planar atoms geometrically:
@@ -337,7 +422,9 @@ def find_planar_positions(
all_quad_centered = coords - quad_center
quad_centered = quad_coords - quad_center
# covariance matrix
quad_cov = (quad_centered.T @ quad_centered) / max(quad_coords.shape[0] - 1, 1)
quad_cov = (quad_centered.T @ quad_centered) / max(
quad_coords.shape[0] - 1, 1
)
# eigen decomposition
_, quad_eigvecs = np.linalg.eigh(quad_cov)
quad_normal = quad_eigvecs[:, 0] # eigenvector with smallest eigenvalue
@@ -348,34 +435,39 @@ def find_planar_positions(
quad_valid_mask = quad_dists <= tol
# Filter for if we have a valid plane in the first place:
valid_plane_filter = (np.nanmax(quad_dists[:4]) < tol)
valid_plane_filter = np.nanmax(quad_dists[:4]) < tol
# Filter for if we have enough atoms in the plane:
plane_atom_filter = (int(np.sum(quad_valid_mask)) >= 4)
plane_atom_filter = int(np.sum(quad_valid_mask)) >= 4
if valid_plane_filter and plane_atom_filter:
# Set the planar atom list for this position to those that are within tol of the plane:
# Set the planar atom list for this position to those that are within tol of the plane:
# using quad_valid_mask and candidate_planar_atm_names:
planar_atom_list = [n for n, keep in zip(candidate_planar_atm_names, quad_valid_mask.tolist()) if keep]
planar_atom_list = [
n
for n, keep in zip(
candidate_planar_atm_names, quad_valid_mask.tolist()
)
if keep
]
# not enough atoms close to a common plane
else:
planar_atom_list = []
else:
# need at least 4 atoms to define a robust plane
planar_atom_list = []
planar_atom_list_dict[(chain_iid, res_id)] = planar_atom_list
planar_atom_list_dict[(chain_iid, res_id)] = planar_atom_list
return planar_atom_list_dict
def make_coord_list(atom_array: AtomArray,
residue_list: list[str],
chain_list: list[str],
atom_list: list[str],
) -> list[list[str]]:
def make_coord_list(
atom_array: AtomArray,
residue_list: list[str],
chain_list: list[str],
atom_list: list[str],
) -> list[list[str]]:
"""Extract per-residue representative coordinates from an AtomArray.
All three input lists must have the same length. Missing atoms are
@@ -393,24 +485,20 @@ def make_coord_list(atom_array: AtomArray,
"""
coord_list = []
for res_id, chain_id, atom_name in zip(residue_list, chain_list, atom_list):
# Check if the residue exists in the atom array
if atom_name == "atomized":
# Check for atomized residue, in which case we take the first atom of the residue
# full mask should be length-1 if atomized
mask = (
(atom_array.chain_id == chain_id) &
(atom_array.res_id == res_id)
)
mask = (atom_array.chain_id == chain_id) & (atom_array.res_id == res_id)
else:
# General case for non-atomized residues
# should have a unique solution, but we take the first entry either way.
mask = (
(atom_array.chain_id == chain_id) &
(atom_array.res_id == res_id) &
(atom_array.atom_name == atom_name)
)
(atom_array.chain_id == chain_id)
& (atom_array.res_id == res_id)
& (atom_array.atom_name == atom_name)
)
# Get the coordinates for the masked atoms
coords = atom_array.coord[mask][0:1]
@@ -428,8 +516,8 @@ def get_token_level_metadata(
*,
NA_only: bool = False,
planar_only: bool = True,
seq_cutoff = 2,
gap_length = 200
seq_cutoff=2,
gap_length=200,
) -> dict:
"""Build lightweight token-level metadata (no coordinate geometry).
@@ -468,9 +556,21 @@ def get_token_level_metadata(
# molecule type flags
# Instantiate encoding locally to avoid retaining large arrays at module scope.
sequence_encoding = AF3SequenceEncoding()
is_protein = np.isin(token_level_array.res_name, sequence_encoding.all_res_names[sequence_encoding.is_aa_like])
is_rna = np.isin(token_level_array.res_name, sequence_encoding.all_res_names[sequence_encoding.is_rna_like])
is_dna = np.isin(token_level_array.res_name, sequence_encoding.all_res_names[sequence_encoding.is_dna_like])
###################
# is_protein = np.isin(
# token_level_array.res_name,
# sequence_encoding.all_res_names[sequence_encoding.is_aa_like],
# )
##################
is_rna = np.isin(
token_level_array.res_name,
sequence_encoding.all_res_names[sequence_encoding.is_rna_like],
)
is_dna = np.isin(
token_level_array.res_name,
sequence_encoding.all_res_names[sequence_encoding.is_dna_like],
)
is_na_arr = (is_dna | is_rna).astype(bool)
@@ -500,7 +600,7 @@ def get_token_level_metadata(
sc_planarity_list.append(False)
# representative & sugar-edge atoms
if (is_glycine(atm.res_name) | is_protein_unknown(atm.res_name)):
if is_glycine(atm.res_name) | is_protein_unknown(atm.res_name):
rep_atom_i = "CA"
S_start_atom_i = None
S_stop_atom_i = None
@@ -534,7 +634,9 @@ def get_token_level_metadata(
S_stop_atom_list.append(S_stop_atom_i)
# residue index <-> token index map
resi2index = {f"{c}__{r}": i for c, r, i in zip(chain_iid_list, resi_list, ind_list)}
resi2index = {
f"{c}__{r}": i for c, r, i in zip(chain_iid_list, resi_list, ind_list)
}
# relative sequence positions w/ chain gaps
rel_pos_list: list[int] = []
@@ -547,9 +649,7 @@ def get_token_level_metadata(
rel_pos_list.append(int(r + chn_bias))
rel_pos = np.asarray(rel_pos_list, dtype=np.int64)
seq_neighbors = (
np.abs(rel_pos[:, None] - rel_pos[None, :]) <= int(seq_cutoff)
)
seq_neighbors = np.abs(rel_pos[:, None] - rel_pos[None, :]) <= int(seq_cutoff)
na_inds = np.nonzero(is_na_arr)[0].tolist()
na_tensor_inds = {na_i: i for i, na_i in enumerate(na_inds)}
@@ -647,12 +747,16 @@ def add_token_level_geometry_data(
S_start_atom_list: list[str | None] = token_level_data["S_start_atom_list"]
S_stop_atom_list: list[str | None] = token_level_data["S_stop_atom_list"]
planar_atom_list_dict = find_planar_positions(atom_array, mol_info) # {(chain_iid, res_id): [atom_name, ...]}
planar_atom_list_dict = find_planar_positions(
atom_array, mol_info
) # {(chain_iid, res_id): [atom_name, ...]}
has_planar_sc: list[bool] = []
xyz_planar: list[list[list[float]]] = [] # list[I] of [K_i, 3] (K_i varies per residue)
xyz_S_start: list[list[float]] = [] # list[I] of [3]
xyz_S_stop: list[list[float]] = [] # list[I] of [3]
xyz_planar: list[
list[list[float]]
] = [] # list[I] of [K_i, 3] (K_i varies per residue)
xyz_S_start: list[list[float]] = [] # list[I] of [3]
xyz_S_stop: list[list[float]] = [] # list[I] of [3]
for c, r, S_start_atm, S_stop_atm in zip(
chain_iid_list,
@@ -663,7 +767,9 @@ def add_token_level_geometry_data(
planar_atoms_i = planar_atom_list_dict[(c, r)]
has_planar_sc.append(bool(len(planar_atoms_i) >= 4))
atom_array_i = atom_array[(atom_array.chain_iid == c) & (atom_array.res_id == r)]
atom_array_i = atom_array[
(atom_array.chain_iid == c) & (atom_array.res_id == r)
]
planar_coords_i: list[list[float]] = []
for pl_atm_name_j in planar_atoms_i:
@@ -673,7 +779,9 @@ def add_token_level_geometry_data(
else:
planar_coords_i.append(pl_atom_array_ij[0].coord)
xyz_planar.append(planar_coords_i if len(planar_coords_i) > 3 else [[float("nan")] * 3])
xyz_planar.append(
planar_coords_i if len(planar_coords_i) > 3 else [[float("nan")] * 3]
)
if S_start_atm is None:
xyz_S_start.append([float("nan"), float("nan"), float("nan")])
@@ -698,21 +806,26 @@ def add_token_level_geometry_data(
del atom_array_i
# frame coordinates and backbone direction
frame_xyz = np.asarray( # [I, 3] representative-atom coordinates
frame_xyz = np.asarray( # [I, 3] representative-atom coordinates
make_coord_list(atom_array, resi_list, chain_list, rep_atom_list),
dtype=np.float32,
)
padded_centers = np.concatenate([frame_xyz[:1], frame_xyz, frame_xyz[-1:]], axis=0) # [I+2, 3]
M_i = ( # [I, 3] smoothed backbone-direction vectors
(padded_centers[1:-1] - padded_centers[:-2])
+ (padded_centers[2:] - padded_centers[1:-1])
) / 2.0
padded_centers = np.concatenate(
[frame_xyz[:1], frame_xyz, frame_xyz[-1:]], axis=0
) # [I+2, 3]
M_i = (
( # [I, 3] smoothed backbone-direction vectors
(padded_centers[1:-1] - padded_centers[:-2])
+ (padded_centers[2:] - padded_centers[1:-1])
)
/ 2.0
)
is_planar_arr = np.asarray(has_planar_sc, dtype=bool) # [I]
is_planar_arr = np.asarray(has_planar_sc, dtype=bool) # [I]
token_level_data["is_planar"] = is_planar_arr
is_na_arr = np.asarray(token_level_data["is_na"], dtype=bool) # [I]
is_na_arr = np.asarray(token_level_data["is_na"], dtype=bool) # [I]
if NA_only and planar_only:
filter_mask = is_na_arr & is_planar_arr
elif NA_only and (not planar_only):
@@ -721,7 +834,7 @@ def add_token_level_geometry_data(
filter_mask = is_planar_arr.copy()
else:
filter_mask = np.ones_like(is_na_arr, dtype=bool)
token_level_data["filter_mask"] = filter_mask # [I] bool
token_level_data["filter_mask"] = filter_mask # [I] bool
token_level_data.update(
{
@@ -777,7 +890,7 @@ def _compute_local_frames(
n_tokens = len(xyz_planar)
# Mean-centre the planar atoms per residue
centered_points = [ # list[I] of [K_i, 3]
centered_points = [ # list[I] of [K_i, 3]
np.asarray(xyz_i, dtype=np.float32) - cen_i
for xyz_i, cen_i in zip(xyz_planar, planar_centers)
]
@@ -857,19 +970,23 @@ def _compute_pairwise_geometry(
``"Z_ij"`` [I, I, 3], and optionally ``"O_ij"`` [I, I].
"""
# Orientation-selected pairwise Z-axis
Z_sum = Z_i[:, None, :] + Z_i[None, :, :] # [I, I, 3]
Z_diff = Z_i[:, None, :] - Z_i[None, :, :] # [I, I, 3]
Z_sum = Z_i[:, None, :] + Z_i[None, :, :] # [I, I, 3]
Z_diff = Z_i[:, None, :] - Z_i[None, :, :] # [I, I, 3]
Z_ij_oris = 0.5 * np.stack((Z_sum, Z_diff), axis=0) # [2, I, I, 3]
base_ori_ij = ( # [I, I] 0=parallel, 1=antiparallel
np.linalg.norm(Z_ij_oris[1], axis=-1) > np.linalg.norm(Z_ij_oris[0], axis=-1)
).astype(np.int64)
Z_ij = np.where(base_ori_ij[..., None] == 0, Z_ij_oris[0], Z_ij_oris[1]) # [I, I, 3]
Z_ij = np.where(
base_ori_ij[..., None] == 0, Z_ij_oris[0], Z_ij_oris[1]
) # [I, I, 3]
Z_ij = Z_ij / (np.linalg.norm(Z_ij, axis=-1, keepdims=True) + eps)
# Pairwise Y (inter-residue direction) and X axes
Y_ij = frame_D_ij_vec / (np.linalg.norm(frame_D_ij_vec, axis=-1, keepdims=True) + eps) # [I, I, 3]
Y_ij = frame_D_ij_vec / (
np.linalg.norm(frame_D_ij_vec, axis=-1, keepdims=True) + eps
) # [I, I, 3]
X_ij = np.cross(Z_ij, Y_ij) # [I, I, 3]
X_ij = X_ij / (np.linalg.norm(X_ij, axis=-1, keepdims=True) + eps)
@@ -882,22 +999,30 @@ def _compute_pairwise_geometry(
np.sum(Z_i[:, None, :] * Y_ij, axis=-1, keepdims=True) * Y_ij
+ np.sum(Z_i[:, None, :] * Z_ij, axis=-1, keepdims=True) * Z_ij
)
proj_Z_i_YZ_norm = proj_Z_i_YZ / (np.linalg.norm(proj_Z_i_YZ, axis=-1, keepdims=True) + eps)
cos_buckle = np.sum(proj_Z_i_YZ_norm * (-proj_Z_i_YZ_norm.swapaxes(0, 1)), axis=-1) # [I, I]
proj_Z_i_YZ_norm = proj_Z_i_YZ / (
np.linalg.norm(proj_Z_i_YZ, axis=-1, keepdims=True) + eps
)
cos_buckle = np.sum(
proj_Z_i_YZ_norm * (-proj_Z_i_YZ_norm.swapaxes(0, 1)), axis=-1
) # [I, I]
# Propeller (P_ij)
proj_Z_i_ZX = ( # [I, I, 3]
np.sum(Z_i[:, None, :] * Z_ij, axis=-1, keepdims=True) * Z_ij
+ np.sum(Z_i[:, None, :] * X_ij, axis=-1, keepdims=True) * X_ij
)
proj_Z_i_ZX_norm = proj_Z_i_ZX / (np.linalg.norm(proj_Z_i_ZX, axis=-1, keepdims=True) + eps)
cos_propeller = np.sum(proj_Z_i_ZX_norm * (-proj_Z_i_ZX_norm.swapaxes(0, 1)), axis=-1) # [I, I]
proj_Z_i_ZX_norm = proj_Z_i_ZX / (
np.linalg.norm(proj_Z_i_ZX, axis=-1, keepdims=True) + eps
)
cos_propeller = np.sum(
proj_Z_i_ZX_norm * (-proj_Z_i_ZX_norm.swapaxes(0, 1)), axis=-1
) # [I, I]
if clamp:
cos_buckle = np.clip(cos_buckle, -1.0, 1.0)
cos_propeller = np.clip(cos_propeller, -1.0, 1.0)
B_ij = np.arccos(cos_buckle) # [I, I]
B_ij = np.arccos(cos_buckle) # [I, I]
P_ij = np.arccos(cos_propeller) # [I, I]
result: dict[str, np.ndarray] = {
@@ -920,8 +1045,12 @@ def _compute_pairwise_geometry(
np.sum(X_i[:, None, :] * X_ij, axis=-1, keepdims=True) * X_ij
+ np.sum(X_i[:, None, :] * Y_ij, axis=-1, keepdims=True) * Y_ij
)
proj_X_i_XY_norm = proj_X_i_XY / (np.linalg.norm(proj_X_i_XY, axis=-1, keepdims=True) + eps)
cos_opening = np.sum(proj_X_i_XY_norm * proj_X_i_XY_norm.swapaxes(0, 1), axis=-1) # [I, I]
proj_X_i_XY_norm = proj_X_i_XY / (
np.linalg.norm(proj_X_i_XY, axis=-1, keepdims=True) + eps
)
cos_opening = np.sum(
proj_X_i_XY_norm * proj_X_i_XY_norm.swapaxes(0, 1), axis=-1
) # [I, I]
if clamp:
cos_opening = np.clip(cos_opening, -1.0, 1.0)
result["O_ij"] = np.arccos(cos_opening) # [I, I]
@@ -989,13 +1118,14 @@ def _compute_basepair_mask(
| (P_ij >= math.pi - mol_info.base_geometry_limits["P_ij"])
)
D_ij_filter = (D_ij <= mol_info.base_geometry_limits["D_ij"])
D_ij_filter = D_ij <= mol_info.base_geometry_limits["D_ij"]
bp_geom_filter = H_ij_filter & B_ij_filter & P_ij_filter & D_ij_filter # [I, I]
if bool_only:
basepairs_bool_ij = ( # [I, I]
(~seq_neighbors) & bp_geom_filter
(~seq_neighbors)
& bp_geom_filter
& (bp_preds >= float(mol_info.bp_val_cutoff))
)
return basepairs_bool_ij
@@ -1015,7 +1145,7 @@ def _compute_basepair_mask(
def compute_nucleic_ss(
mol_info,
mol_info,
token_level_data,
hbond_count,
clamp_pairwise_params=True,
@@ -1061,14 +1191,24 @@ def compute_nucleic_ss(
mask_1d = np.asarray(token_level_data["filter_mask"], dtype=bool) # [I_full]
# --- Unpack and filter token-level data ----------------------
M_i = np.asarray(token_level_data["M_i"], dtype=np.float32)[mask_1d] # [I, 3]
frame_xyz = np.asarray(token_level_data["frame_xyz"], dtype=np.float32)[mask_1d] # [I, 3]
xyz_S_start = [v for v, k in zip(token_level_data["xyz_S_start"], mask_1d) if k] # list[I] of [3]
xyz_S_stop = [v for v, k in zip(token_level_data["xyz_S_stop"], mask_1d) if k] # list[I] of [3]
xyz_planar = [v for v, k in zip(token_level_data["xyz_planar"], mask_1d) if k] # list[I] of [K_i, 3]
M_i = np.asarray(token_level_data["M_i"], dtype=np.float32)[mask_1d] # [I, 3]
frame_xyz = np.asarray(token_level_data["frame_xyz"], dtype=np.float32)[
mask_1d
] # [I, 3]
xyz_S_start = [
v for v, k in zip(token_level_data["xyz_S_start"], mask_1d) if k
] # list[I] of [3]
xyz_S_stop = [
v for v, k in zip(token_level_data["xyz_S_stop"], mask_1d) if k
] # list[I] of [3]
xyz_planar = [
v for v, k in zip(token_level_data["xyz_planar"], mask_1d) if k
] # list[I] of [K_i, 3]
hbond_count = np.asarray(hbond_count)[mask_1d, :][:, mask_1d] # [I, I, 3]
seq_neighbors = np.asarray(token_level_data["seq_neighbors"], dtype=bool)[mask_1d, :][:, mask_1d] # [I, I]
hbond_count = np.asarray(hbond_count)[mask_1d, :][:, mask_1d] # [I, I, 3]
seq_neighbors = np.asarray(token_level_data["seq_neighbors"], dtype=bool)[
mask_1d, :
][:, mask_1d] # [I, I]
# Nothing passed NA/planar filtering for this structure.
# Return empty outputs instead of failing downstream on np.stack([]).
@@ -1107,12 +1247,15 @@ def compute_nucleic_ss(
# --- Precompute centroids and displacement vectors -----------
planar_centers = np.stack( # [I, 3]
[np.nanmean(np.asarray(xyz_i, dtype=np.float32), axis=0) for xyz_i in xyz_planar],
[
np.nanmean(np.asarray(xyz_i, dtype=np.float32), axis=0)
for xyz_i in xyz_planar
],
axis=0,
).astype(np.float32)
frame_D_ij_vec = frame_xyz[None, :, :] - frame_xyz[:, None, :] # [I, I, 3]
sc_D_ij_vec = planar_centers[None, :, :] - planar_centers[:, None, :] # [I, I, 3]
frame_D_ij_vec = frame_xyz[None, :, :] - frame_xyz[:, None, :] # [I, I, 3]
sc_D_ij_vec = planar_centers[None, :, :] - planar_centers[:, None, :] # [I, I, 3]
# --- CALC I: per-residue local coordinate frames -------------
need_full_frame = return_local_params or return_opening_angle
@@ -1125,8 +1268,8 @@ def compute_nucleic_ss(
compute_full_frame=need_full_frame,
eps=eps,
)
Z_i = local_frames["Z_i"] # [I, 3]
X_i = local_frames.get("X_i") # [I, 3] or None
Z_i = local_frames["Z_i"] # [I, 3]
X_i = local_frames.get("X_i") # [I, 3] or None
# --- CALC II: pairwise base-step geometry --------------------
pw_geom = _compute_pairwise_geometry(
@@ -1187,7 +1330,6 @@ def compute_nucleic_ss(
return nucleic_ss_data
def annotate_na_ss(
atom_array: AtomArray,
*,
@@ -1258,10 +1400,10 @@ def annotate_na_ss(
NA_only=NA_only,
planar_only=planar_only,
)
# Note: this mask gives positions that are *chemically valid* for forming
# Note: this mask gives positions that are *chemically valid* for forming
# base pairs, which is different from custom mask-generation for features
mask_1d = np.asarray(token_level_data["filter_mask"], dtype=bool)
subset_idxs = np.nonzero(mask_1d)[0]
is_na_full = np.asarray(token_level_data["is_na"], dtype=bool)
@@ -1295,15 +1437,19 @@ def annotate_na_ss(
if planar_only:
n_tokens = bp_bool.shape[0]
has_planar_sc = np.asarray(
token_level_data.get("has_planar_sc", np.ones(n_tokens, dtype=bool)), dtype=bool
token_level_data.get("has_planar_sc", np.ones(n_tokens, dtype=bool)),
dtype=bool,
)
bp_bool &= has_planar_sc[:, None]
bp_bool &= has_planar_sc[None, :]
# Optional: filter to canonical Watson-Crick basepairs only.
# Sampled probabilistically to allow mixed supervision during training.
do_canonical_filter = bool(p_canonical_bp_filter and (np.random.rand() < float(p_canonical_bp_filter)))
do_canonical_filter = bool(
p_canonical_bp_filter and (np.random.rand() < float(p_canonical_bp_filter))
)
if do_canonical_filter:
def _base_letter(res_name: str) -> str | None:
rn = str(res_name).strip().upper()
if rn in STANDARD_RNA:
@@ -1313,11 +1459,16 @@ def annotate_na_ss(
return None
allowed_pairs = {
("A", "U"), ("U", "A"),
("A", "T"), ("T", "A"),
("G", "C"), ("C", "G"),
("A", "U"),
("U", "A"),
("A", "T"),
("T", "A"),
("G", "C"),
("C", "G"),
}
base_letters_full: list[str | None] = [_base_letter(rn) for rn in token_res_names]
base_letters_full: list[str | None] = [
_base_letter(rn) for rn in token_res_names
]
bp_bool = np.asarray(bp_bool, dtype=bool)
bp_rows_tmp, bp_cols_tmp = np.nonzero(bp_bool)
@@ -1401,7 +1552,9 @@ def annotate_na_ss(
continue
# A residue is treated as atomized if any atom in the residue carries atomize=True.
if "atomize" in atom_array.get_annotation_categories():
residue_is_atomized = bool(np.any(np.asarray(atom_array.atomize[int(start):stop], dtype=bool)))
residue_is_atomized = bool(
np.any(np.asarray(atom_array.atomize[int(start) : stop], dtype=bool))
)
else:
residue_is_atomized = False
if residue_is_atomized:

View File

@@ -216,7 +216,9 @@ def get_crop_transform(
), "Crop center cutoff distance must be greater than 0"
pre_crop_transforms = [
SubsampleToTypes(allowed_types=allowed_types, association_scheme=association_scheme),
SubsampleToTypes(
allowed_types=allowed_types, association_scheme=association_scheme
),
]
cropping_transform = RandomRoute(
@@ -361,11 +363,9 @@ def build_atom14_base_pipeline_(
max_ss_frac_to_provide: float,
min_ss_island_len: int,
max_ss_island_len: int,
## Nucleic acid features #####
#add_na_pair_features: bool,
# add_na_pair_features: bool,
## This should not be necessary, controlled through feature names in model, and meta conditioning probabilities, inference behavior handled in transform itself #####
**_, # dump additional kwargs (e.g. msa stuff)
):
"""
@@ -373,7 +373,7 @@ def build_atom14_base_pipeline_(
"""
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Add any data necessary for downstream transforms
transforms = [
AddData(
@@ -415,7 +415,7 @@ def build_atom14_base_pipeline_(
max_binder_length=max_binder_length,
max_atoms_in_crop=max_atoms_in_crop,
allowed_types=allowed_types,
association_scheme=association_scheme
association_scheme=association_scheme,
)
if zero_occ_on_exposure_after_cropping:
@@ -448,7 +448,7 @@ def build_atom14_base_pipeline_(
)
)
# Add nucleic acid geometry features
#if add_na_pair_features:
# if add_na_pair_features:
transforms.append(
CalculateNucleicAcidGeomFeats(
is_inference,
@@ -639,8 +639,8 @@ def build_atom14_base_pipeline(
kwargs.setdefault("min_ss_island_len", 0)
kwargs.setdefault("max_ss_island_len", 999)
kwargs.setdefault("max_binder_length", 999)
# This should not be necessary.
#kwargs.setdefault("add_na_pair_features", False)
# This should not be necessary.
# kwargs.setdefault("add_na_pair_features", False)
kwargs.setdefault("b_factor_min", None)
kwargs.setdefault("zero_occ_on_exposure_after_cropping", False)

View File

@@ -13,14 +13,13 @@ from atomworks.ml.utils.token import (
spread_token_wise,
)
from biotite.structure import AtomArray, get_residue_starts
from rfd3.constants import backbone_atoms_RNA
from rfd3.transforms.conditioning_utils import (
random_condition,
sample_island_tokens,
sample_subgraph_atoms,
)
from rfd3.constants import backbone_atoms_RNA
nx.from_numpy_matrix = nx.from_numpy_array
logger = logging.getLogger(__name__)
@@ -74,7 +73,7 @@ class IslandCondition(TrainingCondition):
p_fix_motif_coordinates,
p_fix_motif_sequence,
p_unindex_motif_tokens,
association_scheme = 'atom14',
association_scheme="atom14",
):
self.name = name
self.frequency = frequency
@@ -93,7 +92,7 @@ class IslandCondition(TrainingCondition):
self.p_fix_motif_coordinates = p_fix_motif_coordinates
self.p_fix_motif_sequence = p_fix_motif_sequence
self.p_unindex_motif_tokens = p_unindex_motif_tokens
self.association_scheme = association_scheme
def is_valid_for_example(self, data) -> bool:
@@ -107,7 +106,7 @@ class IslandCondition(TrainingCondition):
else:
if np.any(is_protein):
return True
return False
def sample_motif_tokens(self, atom_array):
@@ -162,7 +161,7 @@ class IslandCondition(TrainingCondition):
if random_condition(self.p_diffuse_motif_sidechains):
backbone_atoms = backbone_atoms_RNA.copy()
backbone_atoms.remove("C1'")
backbone_atoms = ["N", "C", "CA"] + backbone_atoms #covers DNA also
backbone_atoms = ["N", "C", "CA"] + backbone_atoms # covers DNA also
if random_condition(self.p_include_oxygen_in_backbone_mask):
backbone_atoms.append("O")
is_motif_atom = is_motif_atom & np.isin(
@@ -500,7 +499,9 @@ def sample_conditioning_strategy(
atom_array.set_annotation(
"is_motif_atom_unindexed",
sample_unindexed_atoms(
atom_array, p_unindex_motif_tokens=p_unindex_motif_tokens, association_scheme=association_scheme
atom_array,
p_unindex_motif_tokens=p_unindex_motif_tokens,
association_scheme=association_scheme,
),
)
return atom_array
@@ -526,7 +527,6 @@ def sample_is_motif_atom_with_fixed_seq(
is_motif_atom_with_fixed_seq = (
is_motif_atom_with_fixed_seq | ~atom_array.is_protein
)
return is_motif_atom_with_fixed_seq

View File

@@ -53,6 +53,7 @@ def map_to_association_scheme(
else:
return ATOM_NAMES[idxs]
def map_names_to_elements(
atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
) -> np.ndarray:
@@ -179,7 +180,7 @@ class PadTokensWithVirtualAtoms(Transform):
token_ids = np.unique(atom_array.token_id)
assert len(token_ids) == len(
is_motif_atom_with_fixed_seq
), "Token ids and token level array have different lengths!"
), "Token ids and token level array have different lengths!"
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
if self.association_scheme == "atom23":
@@ -230,14 +231,14 @@ class PadTokensWithVirtualAtoms(Transform):
else:
n_atoms_per_token = self.n_atoms_per_token
central_atom = self.atom_to_pad_from
n_pad = n_atoms_per_token - len(token)
if n_pad > 0:
mask = get_af3_token_representative_masks(
token, central_atom=central_atom
)
assert_single_representative(token, central_atom=central_atom)
# ... Create virtual atoms

View File

@@ -452,7 +452,10 @@ as input and return a three-element list or numpy array of floats.
def set_com(
atom_array, ori_token: list | None = None, infer_ori_strategy: str | None = None, ori_jitter: float | None = None
atom_array,
ori_token: list | None = None,
infer_ori_strategy: str | None = None,
ori_jitter: float | None = None,
):
if exists(ori_token):
center = np.array([float(x) for x in ori_token], dtype=atom_array.coord.dtype)
@@ -512,7 +515,7 @@ def set_com(
# Random length (mean ~ scale)
length = np.random.exponential(scale=scale)
jittered_offset = direction*length
jittered_offset = direction * length
atom_array.coord -= jittered_offset