mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
ruff format
This commit is contained in:
@@ -438,114 +438,254 @@ backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA)
|
||||
|
||||
# Mapping from residue type to its backbone and sidechain atoms (for convenience)
|
||||
ATOM_REGION_BY_RESI = {
|
||||
'ALA': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB')},
|
||||
'ARG': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD','NE','CZ','NH1','NH2')},
|
||||
'ASN': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','OD1','ND2')},
|
||||
'ASP': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','OD1','OD2')},
|
||||
'CYS': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','SG')},
|
||||
'GLN': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD','OE1','NE2')},
|
||||
'GLU': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD','OE1','OE2')},
|
||||
'GLY': {'bb':('N','CA','C','O'),
|
||||
'sc':()},
|
||||
'HIS': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','ND1','CD2','CE1','NE2')},
|
||||
'ILE': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG1','CG2','CD1')},
|
||||
'LEU': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD1','CD2')},
|
||||
'LYS': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD','CE','NZ')},
|
||||
'MET': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','SD','CE')},
|
||||
'PHE': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ')},
|
||||
'PRO': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD')},
|
||||
'SER': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','OG')},
|
||||
'THR': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','OG1','CG2')},
|
||||
'TRP': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD1','CD2','CE2','CE3','NE1','CZ2','CZ3','CH2')},
|
||||
'TYR': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ','OH')},
|
||||
'VAL': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG1','CG2')},
|
||||
'UNK': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB')},
|
||||
'MAS': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB')},
|
||||
'DA': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N6')},
|
||||
'DC': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':('N1','C2','O2','N3','C4','N4','C5','C6')},
|
||||
'DG': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N2','O6')},
|
||||
'DT': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':('N1','C2','O2','N3','C4','O4','C5','C7','C6')},
|
||||
'DX': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':()},
|
||||
'A': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':('N1','C2','N3','C4','C5','C6','N6','N7','C8','N9')},
|
||||
'C': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':('N1','C2','O2','N3','C4','N4','C5','C6')},
|
||||
'G': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':('N1','C2','N2','N3','C4','C5','C6','O6','N7','C8','N9')},
|
||||
'U': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':('N1','C2','O2','N3','C4','O4','C5','C6')},
|
||||
'X': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':()},
|
||||
'HIS_D': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','NE2','CD2','CE1','ND1')},
|
||||
"ALA": {"bb": ("N", "CA", "C", "O"), "sc": ("CB")},
|
||||
"ARG": {
|
||||
"bb": ("N", "CA", "C", "O"),
|
||||
"sc": ("CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"),
|
||||
},
|
||||
"ASN": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "OD1", "ND2")},
|
||||
"ASP": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "OD1", "OD2")},
|
||||
"CYS": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "SG")},
|
||||
"GLN": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "CD", "OE1", "NE2")},
|
||||
"GLU": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "CD", "OE1", "OE2")},
|
||||
"GLY": {"bb": ("N", "CA", "C", "O"), "sc": ()},
|
||||
"HIS": {
|
||||
"bb": ("N", "CA", "C", "O"),
|
||||
"sc": ("CB", "CG", "ND1", "CD2", "CE1", "NE2"),
|
||||
},
|
||||
"ILE": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG1", "CG2", "CD1")},
|
||||
"LEU": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "CD1", "CD2")},
|
||||
"LYS": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "CD", "CE", "NZ")},
|
||||
"MET": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "SD", "CE")},
|
||||
"PHE": {
|
||||
"bb": ("N", "CA", "C", "O"),
|
||||
"sc": ("CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"),
|
||||
},
|
||||
"PRO": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG", "CD")},
|
||||
"SER": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "OG")},
|
||||
"THR": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "OG1", "CG2")},
|
||||
"TRP": {
|
||||
"bb": ("N", "CA", "C", "O"),
|
||||
"sc": ("CB", "CG", "CD1", "CD2", "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2"),
|
||||
},
|
||||
"TYR": {
|
||||
"bb": ("N", "CA", "C", "O"),
|
||||
"sc": ("CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"),
|
||||
},
|
||||
"VAL": {"bb": ("N", "CA", "C", "O"), "sc": ("CB", "CG1", "CG2")},
|
||||
"UNK": {"bb": ("N", "CA", "C", "O"), "sc": ("CB")},
|
||||
"MAS": {"bb": ("N", "CA", "C", "O"), "sc": ("CB")},
|
||||
"DA": {
|
||||
"bb": (
|
||||
"O4'",
|
||||
"C1'",
|
||||
"C2'",
|
||||
"OP1",
|
||||
"P",
|
||||
"OP2",
|
||||
"O5'",
|
||||
"C5'",
|
||||
"C4'",
|
||||
"C3'",
|
||||
"O3'",
|
||||
),
|
||||
"sc": ("N9", "C4", "N3", "C2", "N1", "C6", "C5", "N7", "C8", "N6"),
|
||||
},
|
||||
"DC": {
|
||||
"bb": (
|
||||
"O4'",
|
||||
"C1'",
|
||||
"C2'",
|
||||
"OP1",
|
||||
"P",
|
||||
"OP2",
|
||||
"O5'",
|
||||
"C5'",
|
||||
"C4'",
|
||||
"C3'",
|
||||
"O3'",
|
||||
),
|
||||
"sc": ("N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"),
|
||||
},
|
||||
"DG": {
|
||||
"bb": (
|
||||
"O4'",
|
||||
"C1'",
|
||||
"C2'",
|
||||
"OP1",
|
||||
"P",
|
||||
"OP2",
|
||||
"O5'",
|
||||
"C5'",
|
||||
"C4'",
|
||||
"C3'",
|
||||
"O3'",
|
||||
),
|
||||
"sc": ("N9", "C4", "N3", "C2", "N1", "C6", "C5", "N7", "C8", "N2", "O6"),
|
||||
},
|
||||
"DT": {
|
||||
"bb": (
|
||||
"O4'",
|
||||
"C1'",
|
||||
"C2'",
|
||||
"OP1",
|
||||
"P",
|
||||
"OP2",
|
||||
"O5'",
|
||||
"C5'",
|
||||
"C4'",
|
||||
"C3'",
|
||||
"O3'",
|
||||
),
|
||||
"sc": ("N1", "C2", "O2", "N3", "C4", "O4", "C5", "C7", "C6"),
|
||||
},
|
||||
"DX": {
|
||||
"bb": (
|
||||
"O4'",
|
||||
"C1'",
|
||||
"C2'",
|
||||
"OP1",
|
||||
"P",
|
||||
"OP2",
|
||||
"O5'",
|
||||
"C5'",
|
||||
"C4'",
|
||||
"C3'",
|
||||
"O3'",
|
||||
),
|
||||
"sc": (),
|
||||
},
|
||||
"A": {
|
||||
"bb": (
|
||||
"O4'",
|
||||
"C1'",
|
||||
"C2'",
|
||||
"OP1",
|
||||
"P",
|
||||
"OP2",
|
||||
"O5'",
|
||||
"C5'",
|
||||
"C4'",
|
||||
"C3'",
|
||||
"O3'",
|
||||
"O2'",
|
||||
),
|
||||
"sc": ("N1", "C2", "N3", "C4", "C5", "C6", "N6", "N7", "C8", "N9"),
|
||||
},
|
||||
"C": {
|
||||
"bb": (
|
||||
"O4'",
|
||||
"C1'",
|
||||
"C2'",
|
||||
"OP1",
|
||||
"P",
|
||||
"OP2",
|
||||
"O5'",
|
||||
"C5'",
|
||||
"C4'",
|
||||
"C3'",
|
||||
"O3'",
|
||||
"O2'",
|
||||
),
|
||||
"sc": ("N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"),
|
||||
},
|
||||
"G": {
|
||||
"bb": (
|
||||
"O4'",
|
||||
"C1'",
|
||||
"C2'",
|
||||
"OP1",
|
||||
"P",
|
||||
"OP2",
|
||||
"O5'",
|
||||
"C5'",
|
||||
"C4'",
|
||||
"C3'",
|
||||
"O3'",
|
||||
"O2'",
|
||||
),
|
||||
"sc": ("N1", "C2", "N2", "N3", "C4", "C5", "C6", "O6", "N7", "C8", "N9"),
|
||||
},
|
||||
"U": {
|
||||
"bb": (
|
||||
"O4'",
|
||||
"C1'",
|
||||
"C2'",
|
||||
"OP1",
|
||||
"P",
|
||||
"OP2",
|
||||
"O5'",
|
||||
"C5'",
|
||||
"C4'",
|
||||
"C3'",
|
||||
"O3'",
|
||||
"O2'",
|
||||
),
|
||||
"sc": ("N1", "C2", "O2", "N3", "C4", "O4", "C5", "C6"),
|
||||
},
|
||||
"X": {
|
||||
"bb": (
|
||||
"O4'",
|
||||
"C1'",
|
||||
"C2'",
|
||||
"OP1",
|
||||
"P",
|
||||
"OP2",
|
||||
"O5'",
|
||||
"C5'",
|
||||
"C4'",
|
||||
"C3'",
|
||||
"O3'",
|
||||
"O2'",
|
||||
),
|
||||
"sc": (),
|
||||
},
|
||||
"HIS_D": {
|
||||
"bb": ("N", "CA", "C", "O"),
|
||||
"sc": ("CB", "CG", "NE2", "CD2", "CE1", "ND1"),
|
||||
},
|
||||
}
|
||||
# Known planar sidechain atoms for each canonical residue type:
|
||||
PLANAR_ATOMS_BY_RESI = {
|
||||
'ALA': [],
|
||||
'ARG': ['NH1', 'NH2', 'CZ', 'NE', 'CD'],
|
||||
'ASN': ['OD1', 'ND2', 'CG', 'CB'],
|
||||
'ASP': ['OD1', 'OD2', 'CG', 'CB'],
|
||||
'CYS': [],
|
||||
'GLN': ['OE1', 'NE2', 'CD', 'CG'],
|
||||
'GLU': ['OE1', 'OE2', 'CD', 'CG'],
|
||||
'GLY': [],
|
||||
'HIS': ['ND1', 'CE1', 'NE2', 'CD2', 'CG', 'CB'],
|
||||
'ILE': [],
|
||||
'LEU': [],
|
||||
'LYS': [],
|
||||
'MET': [],
|
||||
'PHE': ['CZ', 'CE1', 'CE2', 'CD1', 'CD2', 'CG', 'CB'],
|
||||
'PRO': [],
|
||||
'SER': [],
|
||||
'THR': [],
|
||||
'TRP': ['CH2', 'CZ3', 'CZ2', 'CE3', 'CE2', 'CD2', 'NE1', 'CD1', 'CG', 'CB'],
|
||||
'TYR': ['OH', 'CZ', 'CE1', 'CE2', 'CD1', 'CD2', 'CG', 'CB'],
|
||||
'VAL': [],
|
||||
'UNK': [],
|
||||
'MAS': [],
|
||||
'DA': ['N6', 'C6', 'N1', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
|
||||
'DC': ['N4', 'C4', 'N3', 'O2', 'C2', 'C5', 'C6', 'N1'],
|
||||
'DG': ['O6', 'C6', 'N1', 'N2', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
|
||||
'DT': ['O4', 'O2', 'N3', 'C4', 'C2', 'C5', 'C6', 'N1', 'C7'],
|
||||
'DX': [],
|
||||
'A': ['N6', 'C6', 'N1', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
|
||||
'C': ['N4', 'C4', 'N3', 'O2', 'C2', 'C5', 'C6', 'N1'],
|
||||
'G': ['O6', 'C6', 'N1', 'N2', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
|
||||
'U': ['O4', 'O2', 'N3', 'C4', 'C2', 'C5', 'C6', 'N1'],
|
||||
'X': [],
|
||||
'HIS_D': ['ND1', 'CD2', 'CE1', 'NE2', 'CG', 'CB'],
|
||||
}
|
||||
"ALA": [],
|
||||
"ARG": ["NH1", "NH2", "CZ", "NE", "CD"],
|
||||
"ASN": ["OD1", "ND2", "CG", "CB"],
|
||||
"ASP": ["OD1", "OD2", "CG", "CB"],
|
||||
"CYS": [],
|
||||
"GLN": ["OE1", "NE2", "CD", "CG"],
|
||||
"GLU": ["OE1", "OE2", "CD", "CG"],
|
||||
"GLY": [],
|
||||
"HIS": ["ND1", "CE1", "NE2", "CD2", "CG", "CB"],
|
||||
"ILE": [],
|
||||
"LEU": [],
|
||||
"LYS": [],
|
||||
"MET": [],
|
||||
"PHE": ["CZ", "CE1", "CE2", "CD1", "CD2", "CG", "CB"],
|
||||
"PRO": [],
|
||||
"SER": [],
|
||||
"THR": [],
|
||||
"TRP": ["CH2", "CZ3", "CZ2", "CE3", "CE2", "CD2", "NE1", "CD1", "CG", "CB"],
|
||||
"TYR": ["OH", "CZ", "CE1", "CE2", "CD1", "CD2", "CG", "CB"],
|
||||
"VAL": [],
|
||||
"UNK": [],
|
||||
"MAS": [],
|
||||
"DA": ["N6", "C6", "N1", "C2", "N3", "C4", "C5", "N7", "C8", "N9"],
|
||||
"DC": ["N4", "C4", "N3", "O2", "C2", "C5", "C6", "N1"],
|
||||
"DG": ["O6", "C6", "N1", "N2", "C2", "N3", "C4", "C5", "N7", "C8", "N9"],
|
||||
"DT": ["O4", "O2", "N3", "C4", "C2", "C5", "C6", "N1", "C7"],
|
||||
"DX": [],
|
||||
"A": ["N6", "C6", "N1", "C2", "N3", "C4", "C5", "N7", "C8", "N9"],
|
||||
"C": ["N4", "C4", "N3", "O2", "C2", "C5", "C6", "N1"],
|
||||
"G": ["O6", "C6", "N1", "N2", "C2", "N3", "C4", "C5", "N7", "C8", "N9"],
|
||||
"U": ["O4", "O2", "N3", "C4", "C2", "C5", "C6", "N1"],
|
||||
"X": [],
|
||||
"HIS_D": ["ND1", "CD2", "CE1", "NE2", "CG", "CB"],
|
||||
}
|
||||
|
||||
# fix C/U symmetry
|
||||
temp = list(association_schemes['atom23']['U'])
|
||||
temp = list(association_schemes["atom23"]["U"])
|
||||
temp[19], temp[20] = temp[20], temp[19]
|
||||
association_schemes['atom23']['U'] = tuple(temp)
|
||||
association_schemes["atom23"]["U"] = tuple(temp)
|
||||
|
||||
association_schemes_stripped = {
|
||||
name: {k: strip_list(v) for k, v in scheme.items()}
|
||||
@@ -553,4 +693,6 @@ association_schemes_stripped = {
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pdb; pdb.set_trace()
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
|
||||
@@ -119,8 +119,8 @@ class DesignInputSpecification(BaseModel):
|
||||
validate_assignment=False,
|
||||
str_strip_whitespace=True,
|
||||
str_min_length=1,
|
||||
#extra="forbid", ####################################################
|
||||
extra="allow"
|
||||
# extra="forbid", ####################################################
|
||||
extra="allow",
|
||||
## for now allowing extra for rfd3na-ss purposes, can decide later ##
|
||||
)
|
||||
# fmt: off
|
||||
@@ -497,7 +497,10 @@ class DesignInputSpecification(BaseModel):
|
||||
aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like(
|
||||
is_bkbn, False, dtype=int
|
||||
)
|
||||
elif aa.res_name[start] in (STANDARD_DNA + STANDARD_RNA) and self.redesign_motif_sidechains:
|
||||
elif (
|
||||
aa.res_name[start] in (STANDARD_DNA + STANDARD_RNA)
|
||||
and self.redesign_motif_sidechains
|
||||
):
|
||||
is_bkbn = np.isin(aa.atom_name[start:end], backbone_atoms_RNA)
|
||||
aa.is_motif_atom_with_fixed_coord[start:end] = is_bkbn.astype(int)
|
||||
aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like(
|
||||
@@ -519,7 +522,9 @@ class DesignInputSpecification(BaseModel):
|
||||
|
||||
########## reorder NA atoms ###########
|
||||
if exists(atom_array_input_annotated):
|
||||
is_dna = np.isin(atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"])
|
||||
is_dna = np.isin(
|
||||
atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"]
|
||||
)
|
||||
is_rna = np.isin(atom_array_input_annotated.res_name, ["A", "C", "G", "U"])
|
||||
dna_array = atom_array_input_annotated[is_dna]
|
||||
rna_array = atom_array_input_annotated[is_rna]
|
||||
@@ -726,11 +731,11 @@ class DesignInputSpecification(BaseModel):
|
||||
ligand_array.set_annotation(
|
||||
annot, np.full(ligand_array.array_length(), default)
|
||||
)
|
||||
|
||||
chain_cand = 'X'
|
||||
|
||||
chain_cand = "X"
|
||||
while chain_cand in atom_array.chain_id.tolist():
|
||||
chain_cand = chain_cand + chain_cand
|
||||
ligand_chain = np.array([chain_cand]*len(ligand_array))
|
||||
ligand_chain = np.array([chain_cand] * len(ligand_array))
|
||||
ligand_array.chain_id = ligand_chain
|
||||
|
||||
atom_array = atom_array + ligand_array
|
||||
@@ -758,9 +763,12 @@ class DesignInputSpecification(BaseModel):
|
||||
)
|
||||
else:
|
||||
if not exists(self.ori_jitter):
|
||||
self.ori_jitter = None
|
||||
self.ori_jitter = None
|
||||
atom_array = set_com(
|
||||
atom_array, ori_token=None, infer_ori_strategy="com", ori_jitter=self.ori_jitter
|
||||
atom_array,
|
||||
ori_token=None,
|
||||
infer_ori_strategy="com",
|
||||
ori_jitter=self.ori_jitter,
|
||||
)
|
||||
else:
|
||||
# Standard: set ori token, zero out diffused atoms
|
||||
|
||||
@@ -220,7 +220,6 @@ def fetch_motif_residue_(
|
||||
"Please check your input and contig specification."
|
||||
)
|
||||
|
||||
|
||||
if unfix_all or f"{src_chain}{src_resid}" in unfix_residues:
|
||||
subarray.set_annotation(
|
||||
"is_motif_atom_with_fixed_coord", np.zeros(subarray.shape[0], dtype=int)
|
||||
@@ -232,7 +231,7 @@ def fetch_motif_residue_(
|
||||
)
|
||||
else:
|
||||
subarray.set_annotation(
|
||||
"is_motif_atom_with_fixed_coord", np.array([True]*len(subarray))
|
||||
"is_motif_atom_with_fixed_coord", np.array([True] * len(subarray))
|
||||
)
|
||||
|
||||
if flexible_backbone:
|
||||
@@ -255,9 +254,9 @@ def fetch_motif_residue_(
|
||||
subarray = subarray[subarray.is_motif_atom.astype(bool)]
|
||||
else:
|
||||
subarray.set_annotation(
|
||||
"is_motif_atom_unindexed", np.array([True]*len(subarray))
|
||||
"is_motif_atom_unindexed", np.array([True] * len(subarray))
|
||||
)
|
||||
|
||||
|
||||
# ... Relax sequence constraint if provided
|
||||
if (
|
||||
exists(unfixed_sequence_components)
|
||||
@@ -278,32 +277,35 @@ def fetch_motif_residue_(
|
||||
return subarray
|
||||
|
||||
|
||||
def create_diffused_residues_(n, polymer_type='p'):
|
||||
def create_diffused_residues_(n, polymer_type="p"):
|
||||
from rfd3.constants import (
|
||||
ATOM23_ATOM_NAME_TO_ELEMENT,
|
||||
backbone_atoms_DNA,
|
||||
backbone_atoms_RNA,
|
||||
)
|
||||
|
||||
if n <= 0:
|
||||
raise ValueError(f"Negative/null residue count ({n}) not allowed.")
|
||||
|
||||
if polymer_type == 'P':
|
||||
res_name = 'ALA'
|
||||
|
||||
if polymer_type == "P":
|
||||
res_name = "ALA"
|
||||
bb_len = 5
|
||||
bb_atom_names = ["N", "CA", "C", "O", "CB"]
|
||||
elif polymer_type == 'R':
|
||||
res_name = 'A'
|
||||
elif polymer_type == "R":
|
||||
res_name = "A"
|
||||
bb_len = len(backbone_atoms_RNA)
|
||||
bb_atom_names = backbone_atoms_RNA
|
||||
elif polymer_type == 'D':
|
||||
res_name = 'DA'
|
||||
elif polymer_type == "D":
|
||||
res_name = "DA"
|
||||
bb_len = len(backbone_atoms_DNA)
|
||||
bb_atom_names = backbone_atoms_DNA
|
||||
else:
|
||||
raise ValueError(f"invalid polymer type detected: {polymer_type}, check contig!")
|
||||
|
||||
raise ValueError(
|
||||
f"invalid polymer type detected: {polymer_type}, check contig!"
|
||||
)
|
||||
|
||||
bb_elements = [ATOM23_ATOM_NAME_TO_ELEMENT[item] for item in bb_atom_names]
|
||||
|
||||
|
||||
atoms = []
|
||||
[
|
||||
atoms.extend(
|
||||
@@ -319,16 +321,13 @@ def create_diffused_residues_(n, polymer_type='p'):
|
||||
for idx in range(1, n + 1)
|
||||
]
|
||||
array = struc.array(atoms)
|
||||
array.set_annotation(
|
||||
"element", np.array(bb_elements * n, dtype="<U2")
|
||||
)
|
||||
array.set_annotation(
|
||||
"atom_name", np.array(bb_atom_names * n, dtype="<U3")
|
||||
)
|
||||
array.set_annotation("element", np.array(bb_elements * n, dtype="<U2"))
|
||||
array.set_annotation("atom_name", np.array(bb_atom_names * n, dtype="<U3"))
|
||||
array = set_default_conditioning_annotations(array, motif=False)
|
||||
array = set_common_annotations(array)
|
||||
return array
|
||||
|
||||
|
||||
def accumulate_components(
|
||||
components,
|
||||
src_atom_array,
|
||||
@@ -578,8 +577,12 @@ def create_atom_array_from_design_specification_legacy(
|
||||
dna_array = atom_array_input[is_dna]
|
||||
rna_array = atom_array_input[is_rna]
|
||||
|
||||
atom_array_input[is_dna] = reorder_atoms_per_residue(dna_array, backbone_atoms_DNA)
|
||||
atom_array_input[is_rna] = reorder_atoms_per_residue(rna_array, backbone_atoms_RNA)
|
||||
atom_array_input[is_dna] = reorder_atoms_per_residue(
|
||||
dna_array, backbone_atoms_DNA
|
||||
)
|
||||
atom_array_input[is_rna] = reorder_atoms_per_residue(
|
||||
rna_array, backbone_atoms_RNA
|
||||
)
|
||||
#######################################
|
||||
|
||||
if exists(atomwise_rasa):
|
||||
@@ -736,10 +739,10 @@ def create_atom_array_from_design_specification_legacy(
|
||||
+ np.max(atom_array.res_id)
|
||||
+ 1
|
||||
)
|
||||
chain_cand = 'X'
|
||||
chain_cand = "X"
|
||||
while chain_cand in atom_array.chain_id.tolist():
|
||||
chain_cand = chain_cand + chain_cand
|
||||
ligand_chain = np.array([chain_cand]*len(ligand_array))
|
||||
ligand_chain = np.array([chain_cand] * len(ligand_array))
|
||||
ligand_array.chain_id = ligand_chain
|
||||
|
||||
atom_array = atom_array + ligand_array
|
||||
|
||||
@@ -4,6 +4,7 @@ from atomworks.ml.utils.token import (
|
||||
get_token_starts,
|
||||
)
|
||||
from beartype.typing import Any
|
||||
from rfd3.constants import backbone_atoms_RNA
|
||||
from rfd3.metrics.metrics_utils import (
|
||||
_flatten_dict,
|
||||
get_hotspot_contacts,
|
||||
@@ -12,10 +13,10 @@ from rfd3.metrics.metrics_utils import (
|
||||
|
||||
from foundry.common import exists
|
||||
from foundry.metrics.metric import Metric
|
||||
from rfd3.constants import backbone_atoms_RNA
|
||||
|
||||
STANDARD_CACA_DIST = 3.8
|
||||
STANDARD_P_P_DISTANCE = 6.4 ## average of B and A form 7 and 5.9
|
||||
STANDARD_P_P_DISTANCE = 6.4 ## average of B and A form 7 and 5.9
|
||||
|
||||
|
||||
def get_clash_metrics(
|
||||
atom_array,
|
||||
@@ -35,11 +36,11 @@ def get_clash_metrics(
|
||||
elif "P" in atom_array.atom_name:
|
||||
ca_atoms = atom_array[atom_array.atom_name == "P"]
|
||||
cut_off = STANDARD_P_P_DISTANCE
|
||||
|
||||
|
||||
xyz = ca_atoms.coord
|
||||
xyz = torch.from_numpy(xyz)
|
||||
ca_dists = torch.norm(xyz[1:] - xyz[:-1], dim=-1)
|
||||
deviation = torch.abs(ca_dists - STANDARD_CACA_DIST)
|
||||
deviation = torch.abs(ca_dists - cut_off)
|
||||
|
||||
# Allow leniency for expected chain breaks (e.g. PPI)
|
||||
chain_breaks = ca_atoms.chain_iid[1:] != ca_atoms.chain_iid[:-1]
|
||||
@@ -52,7 +53,9 @@ def get_clash_metrics(
|
||||
}
|
||||
|
||||
def get_interresidue_clashes(backbone_only=False):
|
||||
protein_array = atom_array[atom_array.is_protein | atom_array.is_dna | atom_array.is_rna]
|
||||
protein_array = atom_array[
|
||||
atom_array.is_protein | atom_array.is_dna | atom_array.is_rna
|
||||
]
|
||||
resid = protein_array.res_id - protein_array.res_id.min()
|
||||
xyz = protein_array.coord
|
||||
dists = np.linalg.norm(xyz[:, None] - xyz[None], axis=-1) # N_atoms x N_atoms
|
||||
@@ -65,7 +68,9 @@ def get_clash_metrics(
|
||||
|
||||
if backbone_only:
|
||||
# Block out non-backbone atoms
|
||||
backbone_mask = np.isin(protein_array.atom_name, ["N", "CA", "C"] + backbone_atoms_RNA)
|
||||
backbone_mask = np.isin(
|
||||
protein_array.atom_name, ["N", "CA", "C"] + backbone_atoms_RNA
|
||||
)
|
||||
mask = backbone_mask[:, None] & backbone_mask[None, :]
|
||||
dists[~mask] = 999
|
||||
|
||||
@@ -374,7 +379,7 @@ class BackboneMetrics(Metric):
|
||||
is_protein = is_protein[diffused_region]
|
||||
is_dna = is_dna[diffused_region]
|
||||
is_rna = is_rna[diffused_region]
|
||||
protein_idx_mask = is_ca & (is_protein)
|
||||
protein_idx_mask = is_ca & (is_protein)
|
||||
na_idx_mask = is_ca & (is_rna | is_dna)
|
||||
|
||||
if self.compute_for_diffused_region_only:
|
||||
@@ -384,30 +389,43 @@ class BackboneMetrics(Metric):
|
||||
xyz_protein = X_L.cpu()[:, protein_idx_mask]
|
||||
xyz_na = X_L.cpu()[:, na_idx_mask]
|
||||
|
||||
ca_dists_protein = torch.norm(xyz_protein[:, 1:] - xyz_protein[:, :-1], dim=-1)
|
||||
ca_dists_protein = torch.norm(
|
||||
xyz_protein[:, 1:] - xyz_protein[:, :-1], dim=-1
|
||||
)
|
||||
ca_dists_na = torch.norm(xyz_na[:, 1:] - xyz_na[:, :-1], dim=-1)
|
||||
|
||||
deviation_protein = torch.abs(ca_dists_protein - self.standard_ca_dist) # B, (I-1)
|
||||
|
||||
deviation_protein = torch.abs(
|
||||
ca_dists_protein - self.standard_ca_dist
|
||||
) # B, (I-1)
|
||||
deviation_na = torch.abs(ca_dists_na - self.standard_PP_dist) # B, (I-1)
|
||||
is_chainbreak_protein = deviation_protein > 0.75
|
||||
is_chainbreak_na = deviation_na > 1
|
||||
|
||||
try:
|
||||
o["max_ca_deviation_protein"] = float(deviation_protein.max(-1).values.mean())
|
||||
o["fraction_chainbreaks_protein"] = float(is_chainbreak_protein.float().mean(-1).mean())
|
||||
o["n_chainbreaks_protein"] = float(is_chainbreak_protein.float().sum(-1).mean())
|
||||
except:
|
||||
o["max_ca_deviation_protein"] = float(
|
||||
deviation_protein.max(-1).values.mean()
|
||||
)
|
||||
o["fraction_chainbreaks_protein"] = float(
|
||||
is_chainbreak_protein.float().mean(-1).mean()
|
||||
)
|
||||
o["n_chainbreaks_protein"] = float(
|
||||
is_chainbreak_protein.float().sum(-1).mean()
|
||||
)
|
||||
except Exception:
|
||||
print("No protein in this example, skipping protein chainbreak metrics")
|
||||
|
||||
try:
|
||||
o["max_ca_deviation_na"] = float(deviation_na.max(-1).values.mean())
|
||||
o["fraction_chainbreaks_na"] = float(is_chainbreak_na.float().mean(-1).mean())
|
||||
o["fraction_chainbreaks_na"] = float(
|
||||
is_chainbreak_na.float().mean(-1).mean()
|
||||
)
|
||||
o["n_chainbreaks_na"] = float(is_chainbreak_na.float().sum(-1).mean())
|
||||
except:
|
||||
except Exception:
|
||||
print("No NA in this example, skipping NA chainbreak metrics")
|
||||
|
||||
return o
|
||||
|
||||
|
||||
class PPIMetrics(Metric):
|
||||
"""PPI-specific metrics"""
|
||||
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
import logging
|
||||
|
||||
import bdb
|
||||
import numpy as np
|
||||
from biotite.structure import AtomArray
|
||||
from atomworks.ml.utils.token import (
|
||||
get_token_starts,
|
||||
)
|
||||
|
||||
from biotite.structure import AtomArray
|
||||
from rfd3.trainer.trainer_utils import (
|
||||
_cleanup_virtual_atoms_and_assign_atom_name_elements,
|
||||
_readout_seq_from_struc,
|
||||
)
|
||||
from rfd3.transforms.na_geom_utils import annotate_na_ss
|
||||
|
||||
from foundry.metrics.metric import Metric
|
||||
from foundry.utils.ddp import RankedLogger
|
||||
|
||||
from rfd3.trainer.trainer_utils import _readout_seq_from_struc, _cleanup_virtual_atoms_and_assign_atom_name_elements
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
||||
|
||||
@@ -70,7 +70,9 @@ def _get_candidate_token_ids(
|
||||
if hasattr(token_level_array, "is_dna")
|
||||
else np.zeros(len(token_ids), dtype=bool)
|
||||
)
|
||||
token_mask &= (is_rna | is_dna) if (is_rna.any() or is_dna.any()) else token_mask
|
||||
token_mask &= (
|
||||
(is_rna | is_dna) if (is_rna.any() or is_dna.any()) else token_mask
|
||||
)
|
||||
|
||||
if compute_for_diffused_region_only:
|
||||
if hasattr(token_level_array, "is_motif_atom"):
|
||||
@@ -175,7 +177,10 @@ def _extract_loop_and_paired_token_ids(
|
||||
|
||||
return loop_token_ids, paired_token_ids
|
||||
|
||||
def compute_from_two_arr(gt_arr, pred_arr, restrict_to_nucleic=True, compute_for_diffused_region_only = False):
|
||||
|
||||
def compute_from_two_arr(
|
||||
gt_arr, pred_arr, restrict_to_nucleic=True, compute_for_diffused_region_only=False
|
||||
):
|
||||
gt_token_ids = _get_token_ids(gt_arr)
|
||||
pred_token_ids = _get_token_ids(pred_arr)
|
||||
if len(gt_token_ids) != len(pred_token_ids):
|
||||
@@ -228,26 +233,25 @@ def compute_from_two_arr(gt_arr, pred_arr, restrict_to_nucleic=True, compute_for
|
||||
(pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight
|
||||
)
|
||||
|
||||
|
||||
return pair_f1, loop_f1, weighted_f1
|
||||
|
||||
|
||||
def get_NA_SS_F1(pred_array):
|
||||
|
||||
## save the original bop_partner annotation
|
||||
gt_array = pred_array.copy()
|
||||
|
||||
## replace by annotating again
|
||||
pred_array = annotate_na_ss(
|
||||
pred_array,
|
||||
NA_only=True,
|
||||
planar_only=True,
|
||||
overwrite=True,
|
||||
p_canonical_bp_filter=0.0,
|
||||
)
|
||||
|
||||
pred_array,
|
||||
NA_only=True,
|
||||
planar_only=True,
|
||||
overwrite=True,
|
||||
p_canonical_bp_filter=0.0,
|
||||
)
|
||||
|
||||
try:
|
||||
pair_f1, loop_f1, weighted_f1 = compute_from_two_arr(gt_array, pred_array)
|
||||
except:
|
||||
except Exception:
|
||||
# fails when returns None because expects three returns
|
||||
return {}
|
||||
|
||||
@@ -261,13 +265,13 @@ def get_NA_SS_F1(pred_array):
|
||||
class NucleicSSSimilarityMetrics(Metric):
|
||||
"""Secondary-structure similarity for nucleic acids.
|
||||
|
||||
Reports:
|
||||
- `pair_f1`: F1 over basepair edges from token-level bp-partner annotation.
|
||||
- `loop_f1`: F1 over explicitly-unpaired loop tokens (`bp_partners == []`).
|
||||
Unannotated tokens (`bp_partners is None`) are masked.
|
||||
- `weighted_f1`: GT-weighted average of `pair_f1` and `loop_f1`, weighted by
|
||||
the prevalence of paired vs loop tokens in the GT.
|
||||
"""
|
||||
Reports:
|
||||
- `pair_f1`: F1 over basepair edges from token-level bp-partner annotation.
|
||||
- `loop_f1`: F1 over explicitly-unpaired loop tokens (`bp_partners == []`).
|
||||
Unannotated tokens (`bp_partners is None`) are masked.
|
||||
- `weighted_f1`: GT-weighted average of `pair_f1` and `loop_f1`, weighted by
|
||||
the prevalence of paired vs loop tokens in the GT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -302,7 +306,9 @@ class NucleicSSSimilarityMetrics(Metric):
|
||||
|
||||
n_valid = 0
|
||||
|
||||
for gt_arr, pred_arr in zip(ground_truth_atom_array_stack, predicted_atom_array_stack):
|
||||
for gt_arr, pred_arr in zip(
|
||||
ground_truth_atom_array_stack, predicted_atom_array_stack
|
||||
):
|
||||
gt_categories = gt_arr.get_annotation_categories()
|
||||
if "bp_partners" not in gt_categories:
|
||||
continue
|
||||
@@ -326,14 +332,14 @@ class NucleicSSSimilarityMetrics(Metric):
|
||||
pred_arr,
|
||||
association_scheme="atom23",
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
# this can fail early in training
|
||||
print("could not cleanup virtuals for nucleic ss metric compute")
|
||||
pass
|
||||
# clear annotation to avoid potential info leak
|
||||
if "bp_partners" in pred_arr.get_annotation_categories():
|
||||
pred_arr.del_annotation("bp_partners")
|
||||
|
||||
|
||||
# add nucleic-ss annotations
|
||||
annotate_na_ss(
|
||||
pred_arr,
|
||||
@@ -348,8 +354,13 @@ class NucleicSSSimilarityMetrics(Metric):
|
||||
|
||||
# Basic sanity check: token counts should match for aligned comparisons
|
||||
try:
|
||||
pair_f1, loop_f1, weighted_f1 = compute_from_two_arr(gt_arr, pred_arr, restrict_to_nucleic=self.restrict_to_nucleic, compute_for_diffused_region_only = self.compute_for_diffused_region_only)
|
||||
except:
|
||||
pair_f1, loop_f1, weighted_f1 = compute_from_two_arr(
|
||||
gt_arr,
|
||||
pred_arr,
|
||||
restrict_to_nucleic=self.restrict_to_nucleic,
|
||||
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
|
||||
)
|
||||
except Exception:
|
||||
# fails when returns None because expects three returns
|
||||
continue
|
||||
|
||||
@@ -360,7 +371,7 @@ class NucleicSSSimilarityMetrics(Metric):
|
||||
|
||||
if n_valid == 0:
|
||||
return {}
|
||||
|
||||
|
||||
return {
|
||||
"pair_f1": float(np.mean(pair_f1_list)),
|
||||
"loop_f1": float(np.mean(loop_f1_list)),
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from foundry.metrics.metric import Metric
|
||||
from foundry.utils.ddp import RankedLogger
|
||||
@@ -8,9 +9,6 @@ logging.basicConfig(level=logging.INFO)
|
||||
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def calculate_ligand_contacts(
|
||||
atom_array_stack,
|
||||
cutoff_distance=4.0,
|
||||
@@ -31,13 +29,12 @@ def calculate_ligand_contacts(
|
||||
mean_contacts_per_model : float
|
||||
"""
|
||||
|
||||
cutoff_sq = cutoff_distance ** 2
|
||||
cutoff_sq = cutoff_distance**2
|
||||
contacts_per_model = []
|
||||
|
||||
n_models = len(atom_array_stack)
|
||||
|
||||
for i in range(n_models):
|
||||
|
||||
atoms = atom_array_stack[i]
|
||||
|
||||
coords = atoms.coord
|
||||
@@ -57,7 +54,7 @@ def calculate_ligand_contacts(
|
||||
|
||||
# Pairwise squared distances
|
||||
diff = non_ligand_coords[:, None, :] - ligand_coords[None, :, :]
|
||||
dist_sq = np.sum(diff ** 2, axis=-1)
|
||||
dist_sq = np.sum(diff**2, axis=-1)
|
||||
|
||||
# Any ligand within cutoff
|
||||
contact_mask = np.any(dist_sq < cutoff_sq, axis=1)
|
||||
@@ -67,7 +64,11 @@ def calculate_ligand_contacts(
|
||||
|
||||
contacts_per_model = np.array(contacts_per_model)
|
||||
|
||||
return int(np.sum(contacts_per_model)), float(np.mean(contacts_per_model)), float(np.mean(contacts_per_model))/hetero_mask.sum()
|
||||
return (
|
||||
int(np.sum(contacts_per_model)),
|
||||
float(np.mean(contacts_per_model)),
|
||||
float(np.mean(contacts_per_model)) / hetero_mask.sum(),
|
||||
)
|
||||
|
||||
|
||||
class LigandContactMetrics(Metric):
|
||||
@@ -89,12 +90,18 @@ class LigandContactMetrics(Metric):
|
||||
|
||||
def compute(self, *, predicted_atom_array_stack):
|
||||
if self.restrict_to_nucleic:
|
||||
if (predicted_atom_array_stack[0].is_rna.sum() + predicted_atom_array_stack[0].is_dna.sum()== 0):
|
||||
if (
|
||||
predicted_atom_array_stack[0].is_rna.sum()
|
||||
+ predicted_atom_array_stack[0].is_dna.sum()
|
||||
== 0
|
||||
):
|
||||
return {}
|
||||
try:
|
||||
total_contacts, mean_contacts, mean_contacts_per_atom = calculate_ligand_contacts(
|
||||
atom_array_stack=predicted_atom_array_stack,
|
||||
cutoff_distance=self.cutoff_distance,
|
||||
total_contacts, mean_contacts, mean_contacts_per_atom = (
|
||||
calculate_ligand_contacts(
|
||||
atom_array_stack=predicted_atom_array_stack,
|
||||
cutoff_distance=self.cutoff_distance,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
global_logger.error(
|
||||
@@ -106,4 +113,3 @@ class LigandContactMetrics(Metric):
|
||||
"mean_ligand_contacts_per_model": float(mean_contacts),
|
||||
"mean_ligand_contacts_per_atom": float(mean_contacts_per_atom),
|
||||
}
|
||||
|
||||
|
||||
@@ -63,8 +63,8 @@ def strip_f(
|
||||
)
|
||||
else:
|
||||
## for bp_partners default is a mask feature
|
||||
v_cropped[:,:,0] = 1
|
||||
v_cropped[:,:,1:] = 0
|
||||
v_cropped[:, :, 0] = 1
|
||||
v_cropped[:, :, 1:] = 0
|
||||
|
||||
# update the feature in the dictionary
|
||||
f_stripped[k] = v_cropped
|
||||
|
||||
@@ -210,7 +210,9 @@ def create_attention_indices(
|
||||
chain_ids is not None and len(torch.unique(chain_ids)) > 3
|
||||
): # Multi-chain structure
|
||||
# Reserve 25% of attention keys for inter-chain interactions
|
||||
k_inter_chain = min(max(32, k_actual // 4), k_actual) # At least 32 inter-chain keys
|
||||
k_inter_chain = min(
|
||||
max(32, k_actual // 4), k_actual
|
||||
) # At least 32 inter-chain keys
|
||||
k_intra_chain = k_actual - k_inter_chain
|
||||
|
||||
attn_indices = get_sparse_attention_indices_with_inter_chain(
|
||||
|
||||
@@ -143,6 +143,7 @@ class OneDFeatureEmbedder(nn.Module):
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TwoDFeatureEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds 2D features into a single vector.
|
||||
@@ -164,18 +165,22 @@ class TwoDFeatureEmbedder(nn.Module):
|
||||
for feature, n_channels in self.features.items()
|
||||
}
|
||||
)
|
||||
|
||||
def collapse2D(self, x, L):
|
||||
return x.reshape((L, L, x.numel() // (L * L)))
|
||||
|
||||
def forward(self, f, collapse_length):
|
||||
return sum(
|
||||
tuple(
|
||||
self.embedders[feature](self.collapse2D(f[feature].float(), collapse_length))
|
||||
self.embedders[feature](
|
||||
self.collapse2D(f[feature].float(), collapse_length)
|
||||
)
|
||||
for feature, n_channels in self.features.items()
|
||||
if exists(n_channels)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SinusoidalDistEmbed(nn.Module):
|
||||
"""
|
||||
Applies sinusoidal embedding to pairwise distances and projects to c_atompair.
|
||||
|
||||
@@ -11,10 +11,10 @@ from rfd3.model.layers.blocks import (
|
||||
Downcast,
|
||||
LocalAtomTransformer,
|
||||
OneDFeatureEmbedder,
|
||||
TwoDFeatureEmbedder,
|
||||
PositionPairDistEmbedder,
|
||||
RelativePositionEncodingWithIndexRemoval,
|
||||
SinusoidalDistEmbed,
|
||||
TwoDFeatureEmbedder,
|
||||
)
|
||||
from rfd3.model.layers.chunked_pairwise import (
|
||||
ChunkedPairwiseEmbedder,
|
||||
@@ -64,7 +64,7 @@ class TokenInitializer(nn.Module):
|
||||
self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
|
||||
self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
|
||||
self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)
|
||||
if token_2d_features != None:
|
||||
if token_2d_features is not None:
|
||||
self.token_2d_embedder = TwoDFeatureEmbedder(token_2d_features, c_z)
|
||||
else:
|
||||
self.token_2d_embedder = None
|
||||
@@ -209,7 +209,7 @@ class TokenInitializer(nn.Module):
|
||||
f["ref_pos"][f["is_ca"]], valid_mask
|
||||
)
|
||||
# Add extra token pair features
|
||||
if self.token_2d_embedder != None:
|
||||
if self.token_2d_embedder is not None:
|
||||
Z_init_II = Z_init_II + self.token_2d_embedder(f, I)
|
||||
|
||||
# Run a small transformer to provide position encodings to single.
|
||||
|
||||
@@ -450,15 +450,12 @@ class AADesignTrainer(FabricTrainer):
|
||||
):
|
||||
metadata_dict[i]["metrics"] |= get_hbond_metrics(atom_array)
|
||||
|
||||
if (
|
||||
"bp_partners" in atom_array.get_annotation_categories()
|
||||
):
|
||||
if not np.all(atom_array.bp_partners == None):
|
||||
if "bp_partners" in atom_array.get_annotation_categories():
|
||||
if not np.all(atom_array.bp_partners == None): # noqa: E711
|
||||
try:
|
||||
metadata_dict[i]["metrics"] |= get_NA_SS_F1(atom_array)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if "partial_t" in f:
|
||||
# Try calcualte a CA RMSD to input:
|
||||
aa_in = example["atom_array"]
|
||||
|
||||
@@ -2,6 +2,7 @@ from collections import Counter, OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from atomworks.constants import STANDARD_DNA, STANDARD_RNA
|
||||
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
||||
from atomworks.ml.utils.token import (
|
||||
get_token_starts,
|
||||
@@ -13,10 +14,10 @@ from rfd3.constants import (
|
||||
ATOM14_ATOM_NAMES,
|
||||
ATOM23_ATOM_NAMES_DNA,
|
||||
ATOM23_ATOM_NAMES_RNA,
|
||||
backbone_atoms_RNA,
|
||||
VIRTUAL_ATOM_ELEMENT_NAME,
|
||||
association_schemes,
|
||||
association_schemes_stripped,
|
||||
backbone_atoms_RNA,
|
||||
)
|
||||
from rfd3.utils.io import (
|
||||
build_stack_from_atom_array_and_batched_coords,
|
||||
@@ -25,7 +26,6 @@ from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from foundry.common import exists
|
||||
from foundry.utils.ddp import RankedLogger
|
||||
from atomworks.constants import STANDARD_DNA, STANDARD_RNA
|
||||
|
||||
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
||||
|
||||
@@ -221,7 +221,7 @@ def _readout_seq_from_struc(
|
||||
# There might be a better way to do this.
|
||||
CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
|
||||
CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
|
||||
|
||||
|
||||
if cur_res_atom_array.is_dna[0] or cur_res_atom_array.is_rna[0]:
|
||||
cur_central_atom = "C1'"
|
||||
elif np.linalg.norm(CA_coord - CB_coord) < threshold:
|
||||
@@ -229,7 +229,6 @@ def _readout_seq_from_struc(
|
||||
else:
|
||||
cur_central_atom = central_atom
|
||||
|
||||
|
||||
central_mask = cur_res_atom_array.atom_name == cur_central_atom
|
||||
|
||||
# ... Calculate the distance to the central atom
|
||||
@@ -269,7 +268,7 @@ def _readout_seq_from_struc(
|
||||
if not cur_res_atom_array.is_rna[0]:
|
||||
continue
|
||||
else:
|
||||
#ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
|
||||
# ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
|
||||
if not cur_res_atom_array.is_protein[0]:
|
||||
continue
|
||||
|
||||
@@ -425,9 +424,11 @@ def process_unindexed_outputs(
|
||||
|
||||
try:
|
||||
assert (res_id_ == res_id) & (chain_id_ == chain_id)
|
||||
except:
|
||||
global_logger.warning("Unindexed mapping did not work properly, res_id, chain_id")
|
||||
|
||||
except Exception:
|
||||
global_logger.warning(
|
||||
"Unindexed mapping did not work properly, res_id, chain_id"
|
||||
)
|
||||
|
||||
inserted_mask = np.logical_or(inserted_mask, token_match)
|
||||
|
||||
# ... Compute metrics based on the new distances
|
||||
@@ -456,11 +457,7 @@ def process_unindexed_outputs(
|
||||
else:
|
||||
dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
|
||||
|
||||
elif not np.any(
|
||||
np.isin(
|
||||
token.atom_name, backbone_atoms_RNA
|
||||
)
|
||||
):
|
||||
elif not np.any(np.isin(token.atom_name, backbone_atoms_RNA)):
|
||||
if np.sum(token.atomize) == 1:
|
||||
join_atom = np.where(token.atomize)[0][0]
|
||||
elif "C1'" in token.atom_name:
|
||||
@@ -474,10 +471,10 @@ def process_unindexed_outputs(
|
||||
)
|
||||
else:
|
||||
dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
|
||||
|
||||
|
||||
try:
|
||||
metadata["join_point_rmsd_by_token"][token_pdb_id] = dist
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"
|
||||
|
||||
@@ -243,7 +243,7 @@ class SampleConditioningType(Transform):
|
||||
)
|
||||
self.meta_conditioning_probabilities = meta_conditioning_probabilities
|
||||
self.train_conditions = train_conditions
|
||||
|
||||
|
||||
for item in self.train_conditions:
|
||||
self.train_conditions[item].association_scheme = association_scheme
|
||||
|
||||
@@ -265,15 +265,15 @@ class SampleConditioningType(Transform):
|
||||
assert "conditions" in data, "Conditioning dict not initialized"
|
||||
|
||||
def forward(self, data):
|
||||
#for item in self.train_conditions:
|
||||
# for item in self.train_conditions:
|
||||
# print(self.train_conditions[item].is_valid_for_example(data))
|
||||
|
||||
valid_conditions = [
|
||||
cond
|
||||
for cond in self.train_conditions.values()
|
||||
if cond.is_valid_for_example(data) and cond.frequency > 0
|
||||
if cond.is_valid_for_example(data) and cond.frequency > 0
|
||||
]
|
||||
|
||||
|
||||
if len(valid_conditions) == 0:
|
||||
raise InvalidSampledConditionException("No valid condition was found.")
|
||||
|
||||
@@ -288,7 +288,7 @@ class SampleConditioningType(Transform):
|
||||
cond = valid_conditions[i_cond]
|
||||
|
||||
cond.association_scheme = self.association_scheme
|
||||
|
||||
|
||||
data["sampled_condition"] = cond
|
||||
data["sampled_condition_name"] = cond.name
|
||||
data["sampled_condition_cls"] = cond.__class__
|
||||
@@ -306,11 +306,6 @@ class SampleConditioningFlags(Transform):
|
||||
"AssignTypes",
|
||||
"SampleConditioningType",
|
||||
] # We use is_protein in the PPI training condition
|
||||
def __init__(self, association_scheme):
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
def __init__(self, association_scheme):
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
def __init__(self, association_scheme):
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
@@ -38,6 +38,7 @@ from rfd3.transforms.conditioning_base import (
|
||||
UnindexFlaggedTokens,
|
||||
get_motif_features,
|
||||
)
|
||||
from rfd3.transforms.na_geom import na_ss_feats_from_annotation
|
||||
from rfd3.transforms.rasa import discretize_rasa
|
||||
from rfd3.transforms.util_transforms import (
|
||||
AssignTypes,
|
||||
@@ -45,7 +46,6 @@ from rfd3.transforms.util_transforms import (
|
||||
get_af3_token_representative_masks,
|
||||
)
|
||||
from rfd3.transforms.virtual_atoms import PadTokensWithVirtualAtoms
|
||||
from rfd3.transforms.na_geom import na_ss_feats_from_annotation
|
||||
|
||||
from foundry.utils.ddp import RankedLogger # noqa
|
||||
|
||||
@@ -71,7 +71,7 @@ class SubsampleToTypes(Transform):
|
||||
def __init__(
|
||||
self,
|
||||
allowed_types: list | str = ["is_protein"],
|
||||
association_scheme: str = 'atom14'
|
||||
association_scheme: str = "atom14",
|
||||
):
|
||||
self.allowed_types = allowed_types
|
||||
self.association_scheme = association_scheme
|
||||
@@ -106,7 +106,7 @@ class SubsampleToTypes(Transform):
|
||||
)
|
||||
)
|
||||
|
||||
if self.association_scheme != 'atom23' and atom_array.is_protein.sum() == 0:
|
||||
if self.association_scheme != "atom23" and atom_array.is_protein.sum() == 0:
|
||||
raise ValueError(
|
||||
"No protein atoms found in the atom array. Example ID: {}".format(
|
||||
data.get("example_id", "unknown")
|
||||
@@ -757,7 +757,7 @@ class AddAdditional1dFeaturesToFeats(Transform):
|
||||
"""
|
||||
if "feats" not in data.keys():
|
||||
data["feats"] = {}
|
||||
|
||||
|
||||
if self.association_scheme == "atom23":
|
||||
data["atom_array"].set_annotation(
|
||||
"is_protein_token", data["atom_array"].is_protein
|
||||
@@ -772,7 +772,6 @@ class AddAdditional1dFeaturesToFeats(Transform):
|
||||
data = self.generate_feature(feature_name, n_dims, data, "atom")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
|
||||
class AddAdditional2dFeaturesToFeats(Transform):
|
||||
@@ -791,15 +790,15 @@ class AddAdditional2dFeaturesToFeats(Transform):
|
||||
token_2d_features,
|
||||
autofill_zeros_if_not_present_in_atomarray=False,
|
||||
association_scheme="atom14",
|
||||
):
|
||||
):
|
||||
self.autofill = autofill_zeros_if_not_present_in_atomarray
|
||||
self.token_2d_features = token_2d_features
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
# Need to pre-define custom constructor functions
|
||||
# Need to pre-define custom constructor functions
|
||||
# to map from atomarray annotations to tensors.
|
||||
self.constructor_functions = {
|
||||
'bp_partners': na_ss_feats_from_annotation,
|
||||
"bp_partners": na_ss_feats_from_annotation,
|
||||
}
|
||||
|
||||
def check_input(self, data) -> None:
|
||||
@@ -807,11 +806,10 @@ class AddAdditional2dFeaturesToFeats(Transform):
|
||||
check_is_instance(data, "atom_array", AtomArray)
|
||||
|
||||
def generate_token_feature(self, feature_name, n_dims, data):
|
||||
|
||||
# Don't do this if we already have the feature
|
||||
if feature_name in data["feats"].keys():
|
||||
return data
|
||||
|
||||
|
||||
# For these, we need to use a constructor function mapping,
|
||||
# since pair features may require custom logic/conventions.
|
||||
|
||||
@@ -822,11 +820,11 @@ class AddAdditional2dFeaturesToFeats(Transform):
|
||||
raise ValueError(
|
||||
f"No constructor function found for 2d feature `{feature_name}`"
|
||||
)
|
||||
|
||||
|
||||
# We can fix shape issues here:
|
||||
if len(feature_array.shape) == 2 and n_dims == 1:
|
||||
feature_array = feature_array.unsqueeze(1)
|
||||
|
||||
|
||||
# ensure that feature_array is a 3d array with third dim == n_dims:
|
||||
if len(feature_array.shape) != 3:
|
||||
raise ValueError(
|
||||
@@ -854,7 +852,7 @@ class AddAdditional2dFeaturesToFeats(Transform):
|
||||
if "feats" not in data.keys():
|
||||
data["feats"] = {}
|
||||
# Only apply for features that the model is expecting:
|
||||
if self.token_2d_features == None:
|
||||
if self.token_2d_features is None:
|
||||
return data
|
||||
for feature_name, n_dims in self.token_2d_features.items():
|
||||
data = self.generate_token_feature(feature_name, n_dims, data)
|
||||
|
||||
@@ -248,5 +248,3 @@ def subsample_one_hot_np(array, fraction):
|
||||
new_array[i, j] = 1
|
||||
|
||||
return new_array
|
||||
|
||||
|
||||
|
||||
@@ -1,48 +1,49 @@
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from biotite.structure import AtomArray
|
||||
from atomworks.ml.transforms._checks import (
|
||||
check_atom_array_annotation,
|
||||
check_contains_keys,
|
||||
check_is_instance,
|
||||
)
|
||||
from atomworks.ml.transforms.base import Transform
|
||||
from atomworks.ml.utils.token import get_token_starts, spread_token_wise
|
||||
from biotite.structure import AtomArray
|
||||
from rfd3.transforms.conditioning_utils import sample_island_tokens
|
||||
from rfd3.transforms.na_geom_utils import (
|
||||
annotate_na_ss,
|
||||
annotate_na_ss_from_data_specification,
|
||||
DEFAULT_NA_SS_FEATURE_INFO,
|
||||
annotate_na_ss,
|
||||
annotate_na_ss_from_data_specification,
|
||||
)
|
||||
|
||||
from atomworks.ml.utils.token import spread_token_wise, get_token_starts
|
||||
|
||||
def na_ss_feats_from_annotation(atom_array: AtomArray,
|
||||
token_starts= None,
|
||||
n_tokens = None,
|
||||
return_as_onehot = True,
|
||||
) -> np.ndarray:
|
||||
def na_ss_feats_from_annotation(
|
||||
atom_array: AtomArray,
|
||||
token_starts=None,
|
||||
n_tokens=None,
|
||||
return_as_onehot=True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Takes in atom array and constucts a base pair feature matrix from annotations,
|
||||
according to to custom feature constuction + masking system.
|
||||
This featurization utilizes info from BasePairEnum to assign int values
|
||||
This featurization utilizes info from BasePairEnum to assign int values
|
||||
to paired, unpaired, and masked positions in the matrix.
|
||||
|
||||
Args:
|
||||
* atom_array: AtomArray with bp_partners annotation at atom level
|
||||
* token_starts (optional): indices of token starts in the atom array
|
||||
* n_tokens (optional): number of tokens (length of token_starts)
|
||||
* return_as_onehot (optional): if False, return integer-encoded
|
||||
* return_as_onehot (optional): if False, return integer-encoded
|
||||
matrix instead of one-hot encoded matrix
|
||||
|
||||
returns:
|
||||
* na_ss_matrix:
|
||||
If ``return_as_onehot`` is True (default):
|
||||
np.ndarray of shape (n_tokens, n_tokens, n_classes)
|
||||
np.ndarray of shape (n_tokens, n_tokens, n_classes)
|
||||
with one-hot encoded values according to BasePairEnum
|
||||
|
||||
If ``return_as_onehot`` is False :
|
||||
np.ndarray of shape (n_tokens, n_tokens)
|
||||
np.ndarray of shape (n_tokens, n_tokens)
|
||||
with int values according to BasePairEnum
|
||||
|
||||
|
||||
@@ -51,13 +52,16 @@ def na_ss_feats_from_annotation(atom_array: AtomArray,
|
||||
if (token_starts is None) or (n_tokens is None):
|
||||
token_starts = get_token_starts(atom_array)
|
||||
n_tokens = len(token_starts)
|
||||
|
||||
|
||||
# Collect token inds for paired or loop positions:
|
||||
pair_inds = []
|
||||
loop_inds = []
|
||||
token_bp_partners = atom_array.get_annotation("bp_partners")[token_starts] # get bp_partners at token level
|
||||
assert len(token_bp_partners) == n_tokens, "Length of token_bp_partners should match n_tokens"
|
||||
token_bp_partners = atom_array.get_annotation("bp_partners")[
|
||||
token_starts
|
||||
] # get bp_partners at token level
|
||||
assert (
|
||||
len(token_bp_partners) == n_tokens
|
||||
), "Length of token_bp_partners should match n_tokens"
|
||||
for i, j_list in enumerate(token_bp_partners):
|
||||
if j_list is not None:
|
||||
if len(j_list) > 0:
|
||||
@@ -68,46 +72,54 @@ def na_ss_feats_from_annotation(atom_array: AtomArray,
|
||||
|
||||
# The standard system for constructing meaningful base pair features:
|
||||
# 0). Initialize with values of UNSPECIFIED (0): int matrix of shape (n_tokens, n_tokens)
|
||||
na_ss_matrix = np.full((n_tokens, n_tokens), DEFAULT_NA_SS_FEATURE_INFO["NA_SS_MASK"], dtype=np.int64)
|
||||
na_ss_matrix = np.full(
|
||||
(n_tokens, n_tokens), DEFAULT_NA_SS_FEATURE_INFO["NA_SS_MASK"], dtype=np.int64
|
||||
)
|
||||
|
||||
# 1). Fill in with values of PAIR (1) at positions that have bp_partners annotated as a non-empty list
|
||||
for pair_i, pair_j in pair_inds:
|
||||
na_ss_matrix[pair_i, pair_j] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_PAIR"]
|
||||
na_ss_matrix[pair_j, pair_i] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_PAIR"] # ensure symmetry
|
||||
na_ss_matrix[pair_j, pair_i] = DEFAULT_NA_SS_FEATURE_INFO[
|
||||
"NA_SS_PAIR"
|
||||
] # ensure symmetry
|
||||
|
||||
# 2). Fill in with values of LOOP (2) at positions that have bp_partners annotated as an empty list (explicitly unpaired)
|
||||
# (we make full stripes across that position's row/col to indicate that NONE of those other positions are paired )
|
||||
for loop_i in loop_inds:
|
||||
na_ss_matrix[loop_i, :] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_LOOP"]
|
||||
na_ss_matrix[:, loop_i] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_LOOP"] # ensure symmetry
|
||||
|
||||
na_ss_matrix[:, loop_i] = DEFAULT_NA_SS_FEATURE_INFO[
|
||||
"NA_SS_LOOP"
|
||||
] # ensure symmetry
|
||||
|
||||
# Optional: convert NA-SS matrix to one-hot encoding according for model input:
|
||||
if return_as_onehot:
|
||||
na_ss_matrix = np.eye(len(DEFAULT_NA_SS_FEATURE_INFO), dtype=np.int64)[na_ss_matrix]
|
||||
|
||||
na_ss_matrix = np.eye(len(DEFAULT_NA_SS_FEATURE_INFO), dtype=np.int64)[
|
||||
na_ss_matrix
|
||||
]
|
||||
|
||||
return na_ss_matrix
|
||||
|
||||
|
||||
class CalculateNucleicAcidGeomFeats(Transform):
|
||||
"""
|
||||
Transform for constructing nucleic-acid conditioning features.
|
||||
Transform for constructing nucleic-acid conditioning features.
|
||||
|
||||
This transform currently produces only nucleic-acid secondary-structure (NA-SS)
|
||||
features as a 2D token-token matrix with 3 bins:
|
||||
* 0: mask / unspecified
|
||||
* 1: paired
|
||||
* 2: loop / explicitly unpaired
|
||||
This transform currently produces only nucleic-acid secondary-structure (NA-SS)
|
||||
features as a 2D token-token matrix with 3 bins:
|
||||
* 0: mask / unspecified
|
||||
* 1: paired
|
||||
* 2: loop / explicitly unpaired
|
||||
|
||||
Training:
|
||||
- Computes geometry/H-bond-based base pairs and writes them onto the AtomArray
|
||||
via the ``bp_partners`` annotation (annotation-first), then reconstructs the
|
||||
matrix (and optionally masks parts of it) before one-hot encoding.
|
||||
Training:
|
||||
- Computes geometry/H-bond-based base pairs and writes them onto the AtomArray
|
||||
via the ``bp_partners`` annotation (annotation-first), then reconstructs the
|
||||
matrix (and optionally masks parts of it) before one-hot encoding.
|
||||
|
||||
Inference:
|
||||
- Interprets user-provided secondary-structure specifications, writes the same
|
||||
``bp_partners`` annotation, then follows the same matrix + one-hot path.
|
||||
Inference:
|
||||
- Interprets user-provided secondary-structure specifications, writes the same
|
||||
``bp_partners`` annotation, then follows the same matrix + one-hot path.
|
||||
|
||||
Note: helical-parameter features are not implemented/used in this refactored path.
|
||||
Note: helical-parameter features are not implemented/used in this refactored path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -115,43 +127,46 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
is_inference,
|
||||
# Conditional sampling parameters all stored in this dict:
|
||||
meta_conditioning_probabilities: dict[str, float] = None,
|
||||
|
||||
# Mask control paramerers:
|
||||
nucleic_ss_min_shown: float = 0.2,
|
||||
nucleic_ss_max_shown: float = 1.0,
|
||||
n_islands_min: int = 1,
|
||||
n_islands_max: int = 6,
|
||||
|
||||
# USE_RF2AA_NAMES: bool = False,
|
||||
NA_only: bool = False,
|
||||
planar_only : bool = True,
|
||||
|
||||
planar_only: bool = True,
|
||||
):
|
||||
# Critical, must always have to know how to handle
|
||||
self.is_inference = is_inference
|
||||
|
||||
self.is_inference = is_inference
|
||||
|
||||
self.meta_conditioning_probabilities = meta_conditioning_probabilities or {}
|
||||
|
||||
|
||||
# Control whether we show some nucleic SS or default to full 2D mask
|
||||
self.p_is_nucleic_ss_example = self.meta_conditioning_probabilities.get("p_is_nucleic_ss_example", 0.0)
|
||||
self.p_is_nucleic_ss_example = self.meta_conditioning_probabilities.get(
|
||||
"p_is_nucleic_ss_example", 0.0
|
||||
)
|
||||
|
||||
# Control whether we define full SS or just part of it (only applies if is NA SS example)
|
||||
self.p_show_partial_feats = self.meta_conditioning_probabilities.get("p_nucleic_ss_show_partial_feats", 0.0)
|
||||
self.p_show_partial_feats = self.meta_conditioning_probabilities.get(
|
||||
"p_nucleic_ss_show_partial_feats", 0.0
|
||||
)
|
||||
|
||||
# Some frac of time default to only showing canonical base pairs
|
||||
self.p_canonical_bp_filter = self.meta_conditioning_probabilities.get("p_canonical_bp_filter", 0.5)
|
||||
|
||||
self.p_canonical_bp_filter = self.meta_conditioning_probabilities.get(
|
||||
"p_canonical_bp_filter", 0.5
|
||||
)
|
||||
|
||||
# mask patterning control to make things resemble design scenarios
|
||||
self.nucleic_ss_min_shown = nucleic_ss_min_shown
|
||||
self.nucleic_ss_max_shown = nucleic_ss_max_shown
|
||||
self.n_islands_min = n_islands_min
|
||||
self.n_islands_max = n_islands_max
|
||||
self.nucleic_ss_min_shown = nucleic_ss_min_shown
|
||||
self.nucleic_ss_max_shown = nucleic_ss_max_shown
|
||||
self.n_islands_min = n_islands_min
|
||||
self.n_islands_max = n_islands_max
|
||||
|
||||
# Filters for what can be considered a planar contact interaction
|
||||
self.NA_only = NA_only # only annotate base-like interactions for nucleic acid residues
|
||||
self.planar_only = planar_only # only consider planar atoms in sidechains for geometry calculations,
|
||||
|
||||
|
||||
self.NA_only = (
|
||||
NA_only # only annotate base-like interactions for nucleic acid residues
|
||||
)
|
||||
self.planar_only = planar_only # only consider planar atoms in sidechains for geometry calculations,
|
||||
|
||||
def check_input(self, data: dict[str, Any]) -> None:
|
||||
check_contains_keys(data, ["atom_array"])
|
||||
@@ -162,9 +177,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
def _sample_training_flags(self) -> tuple[bool, bool]:
|
||||
"""Sample booleans controlling whether/how features are shown in training."""
|
||||
is_nucleic_ss_example = bool(np.random.rand() < self.p_is_nucleic_ss_example)
|
||||
give_partial_feats = bool(
|
||||
np.random.rand() < self.p_show_partial_feats
|
||||
)
|
||||
give_partial_feats = bool(np.random.rand() < self.p_show_partial_feats)
|
||||
return is_nucleic_ss_example, give_partial_feats
|
||||
|
||||
def forward(self, data: dict) -> dict:
|
||||
@@ -177,24 +190,25 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
|
||||
# Handle the training case with ground truth and masking
|
||||
if not self.is_inference:
|
||||
|
||||
# First, annotate as usual
|
||||
is_nucleic_ss_example, give_partial_feats = self._sample_training_flags()
|
||||
|
||||
if is_nucleic_ss_example:
|
||||
|
||||
atom_array = annotate_na_ss(atom_array,
|
||||
NA_only=self.NA_only,
|
||||
planar_only=self.planar_only,
|
||||
p_canonical_bp_filter=self.p_canonical_bp_filter,
|
||||
)
|
||||
atom_array = annotate_na_ss(
|
||||
atom_array,
|
||||
NA_only=self.NA_only,
|
||||
planar_only=self.planar_only,
|
||||
p_canonical_bp_filter=self.p_canonical_bp_filter,
|
||||
)
|
||||
|
||||
# Generate symmetric partner annotations at the token level for masking purposes.
|
||||
# choice for object-consistency: if already masked/undefined: be a list mapping to self-index.
|
||||
partner_sym_map = {
|
||||
i: atom_array.bp_partners[ts_i] if atom_array.bp_partners[ts_i] is not None else [i]
|
||||
i: atom_array.bp_partners[ts_i]
|
||||
if atom_array.bp_partners[ts_i] is not None
|
||||
else [i]
|
||||
for i, ts_i in enumerate(token_starts)
|
||||
}
|
||||
}
|
||||
|
||||
# # Sample mask on token level:
|
||||
token_mask_to_show = self._sample_where_to_show_ss(
|
||||
@@ -206,17 +220,19 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
|
||||
# Spread mask to atom level
|
||||
is_ss_shown = spread_token_wise(atom_array, token_mask_to_show)
|
||||
|
||||
|
||||
# Extract the base pair annotations
|
||||
bp_partners_atom = atom_array.get_annotation("bp_partners")
|
||||
|
||||
# Remove unshown positions from bp_partners annotation
|
||||
bp_partners_atom[~is_ss_shown] = None
|
||||
|
||||
|
||||
# Reset the annotation with newly hidden positions
|
||||
atom_array.set_annotation("bp_partners", bp_partners_atom)
|
||||
else:
|
||||
atom_array.set_annotation("bp_partners", np.array([None]*len(atom_array)))
|
||||
atom_array.set_annotation(
|
||||
"bp_partners", np.array([None] * len(atom_array))
|
||||
)
|
||||
|
||||
# Inference case: create from commandline args
|
||||
else:
|
||||
@@ -239,25 +255,26 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
log_dict = data["log_dict"]
|
||||
data["log_dict"] = log_dict
|
||||
data["atom_array"] = atom_array
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _sample_where_to_show_ss(self, n_tokens: int,
|
||||
is_nucleic_ss_example: bool = True,
|
||||
give_partial_feats: bool = True,
|
||||
partner_sym_map: dict[int, list[int]] = None,
|
||||
) -> np.ndarray:
|
||||
def _sample_where_to_show_ss(
|
||||
self,
|
||||
n_tokens: int,
|
||||
is_nucleic_ss_example: bool = True,
|
||||
give_partial_feats: bool = True,
|
||||
partner_sym_map: dict[int, list[int]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Sample token-level islands indicating which SS rows/cols to reveal.
|
||||
This custom function allows for enforcing symmetry in the shown features according
|
||||
to the partner_sym_map, which encodes which tokens are partners in the SS
|
||||
matrix and thus should be masked/unmasked together to maintain consistency.
|
||||
|
||||
This custom function allows for enforcing symmetry in the shown features according
|
||||
to the partner_sym_map, which encodes which tokens are partners in the SS
|
||||
matrix and thus should be masked/unmasked together to maintain consistency.
|
||||
|
||||
"""
|
||||
# If NOT is_nucleic_ss_example, set is_shown to all False
|
||||
if not is_nucleic_ss_example:
|
||||
token_mask_to_show = np.zeros((n_tokens,), dtype=bool)
|
||||
|
||||
|
||||
# If NOT give_partial_feats, set is_shown to all True
|
||||
if not give_partial_feats:
|
||||
token_mask_to_show = np.ones((n_tokens,), dtype=bool)
|
||||
@@ -265,14 +282,19 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
# Get numerical parameters for that govern the mask pattern
|
||||
frac_shown = (
|
||||
self.nucleic_ss_min_shown
|
||||
+ (self.nucleic_ss_max_shown - self.nucleic_ss_min_shown) * np.random.rand()
|
||||
+ (self.nucleic_ss_max_shown - self.nucleic_ss_min_shown)
|
||||
* np.random.rand()
|
||||
)
|
||||
frac_shown = float(np.clip(frac_shown, 0.0, 1.0))
|
||||
max_length = int(np.ceil(frac_shown * n_tokens))
|
||||
if max_length <= 0:
|
||||
token_mask_to_show = np.zeros((n_tokens,), dtype=bool)
|
||||
island_len_min = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_max), 1)))
|
||||
island_len_max = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_min), 1)))
|
||||
island_len_min = max(
|
||||
1, int(frac_shown * n_tokens // max(int(self.n_islands_max), 1))
|
||||
)
|
||||
island_len_max = max(
|
||||
1, int(frac_shown * n_tokens // max(int(self.n_islands_min), 1))
|
||||
)
|
||||
island_len_min = min(island_len_min, n_tokens)
|
||||
island_len_max = min(island_len_max, n_tokens)
|
||||
island_len_max = max(island_len_max, island_len_min)
|
||||
@@ -287,7 +309,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
# Handle symmetry by iterating through the partner_sym_map items and setting
|
||||
# Handle symmetry by iterating through the partner_sym_map items and setting
|
||||
# `partner_mask_to_show` at partner positions to match `token_mask_to_show`
|
||||
# initialize as all shown so effect comes from hiding + logical AND condition
|
||||
partner_mask_to_show = np.ones_like(token_mask_to_show)
|
||||
@@ -299,4 +321,3 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
token_mask_to_show = token_mask_to_show & partner_mask_to_show
|
||||
|
||||
return token_mask_to_show
|
||||
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
import math
|
||||
import numpy as np
|
||||
import biotite.structure as struc
|
||||
from biotite.structure import AtomArray
|
||||
|
||||
import biotite.structure as struc
|
||||
import numpy as np
|
||||
from atomworks.constants import (
|
||||
STANDARD_AA,
|
||||
STANDARD_AA,
|
||||
STANDARD_DNA,
|
||||
STANDARD_RNA,
|
||||
)
|
||||
|
||||
from atomworks.io.utils.sequence import (
|
||||
is_purine,
|
||||
is_pyrimidine,
|
||||
)
|
||||
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
||||
from atomworks.ml.utils.token import (
|
||||
get_token_starts,
|
||||
is_glycine,
|
||||
@@ -24,15 +24,12 @@ from atomworks.ml.utils.token import (
|
||||
is_standard_aa_not_glycine,
|
||||
is_unknown_nucleotide,
|
||||
)
|
||||
from rfd3.transforms.hbonds_hbplus import save_atomarray_to_pdb
|
||||
|
||||
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
||||
|
||||
from biotite.structure import AtomArray
|
||||
from rfd3.constants import (
|
||||
ATOM_REGION_BY_RESI,
|
||||
PLANAR_ATOMS_BY_RESI,
|
||||
ATOM_REGION_BY_RESI,
|
||||
PLANAR_ATOMS_BY_RESI,
|
||||
)
|
||||
import tempfile
|
||||
from rfd3.transforms.hbonds_hbplus import save_atomarray_to_pdb
|
||||
|
||||
# Derived: True when the residue has any planar sidechain atoms
|
||||
HAS_PLANAR_SC = {res: bool(atoms) for res, atoms in PLANAR_ATOMS_BY_RESI.items()}
|
||||
@@ -43,15 +40,23 @@ DEFAULT_NA_SS_FEATURE_INFO: dict[str, int] = {
|
||||
"NA_SS_LOOP": 2,
|
||||
}
|
||||
|
||||
AA_PLANAR_ATOMS = sorted(set(
|
||||
atom for res in STANDARD_AA if res in PLANAR_ATOMS_BY_RESI
|
||||
for atom in PLANAR_ATOMS_BY_RESI[res]
|
||||
))
|
||||
AA_PLANAR_ATOMS = sorted(
|
||||
set(
|
||||
atom
|
||||
for res in STANDARD_AA
|
||||
if res in PLANAR_ATOMS_BY_RESI
|
||||
for atom in PLANAR_ATOMS_BY_RESI[res]
|
||||
)
|
||||
)
|
||||
|
||||
NA_PLANAR_ATOMS = sorted(set(
|
||||
atom for res in (*STANDARD_RNA, *STANDARD_DNA) if res in PLANAR_ATOMS_BY_RESI
|
||||
for atom in PLANAR_ATOMS_BY_RESI[res]
|
||||
))
|
||||
NA_PLANAR_ATOMS = sorted(
|
||||
set(
|
||||
atom
|
||||
for res in (*STANDARD_RNA, *STANDARD_DNA)
|
||||
if res in PLANAR_ATOMS_BY_RESI
|
||||
for atom in PLANAR_ATOMS_BY_RESI[res]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class NucMolInfo:
|
||||
@@ -62,53 +67,126 @@ class NucMolInfo:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
||||
|
||||
# Hbond interaction-class indices of the `hbond_count`` array:
|
||||
# `hbond_count`` array is (L, L, 3), where the last dimension
|
||||
# `hbond_count`` array is (L, L, 3), where the last dimension
|
||||
# encodes interaction type between tokens i & j
|
||||
self.BB_BB = 0 # backbone-backbone hbond interactions
|
||||
self.BB_SC = 1 # backbone-sidechain hbond interactions
|
||||
self.SC_SC = 2 # sidechain-sidechain hbond interactions
|
||||
self.BB_BB = 0 # backbone-backbone hbond interactions
|
||||
self.BB_SC = 1 # backbone-sidechain hbond interactions
|
||||
self.SC_SC = 2 # sidechain-sidechain hbond interactions
|
||||
|
||||
# We sum over the last dimension of the hbond_count array, scaling
|
||||
# We sum over the last dimension of the hbond_count array, scaling
|
||||
# count by the following weights to get the interaction score:
|
||||
self.bp_weight_BB_BB = 0.0
|
||||
self.bp_weight_BB_SC = 0.5
|
||||
self.bp_weight_SC_SC = 1.0
|
||||
self.bp_summation_weights = [self.bp_weight_BB_BB,
|
||||
self.bp_weight_BB_SC,
|
||||
self.bp_weight_SC_SC]
|
||||
self.bp_summation_weights = [
|
||||
self.bp_weight_BB_BB,
|
||||
self.bp_weight_BB_SC,
|
||||
self.bp_weight_SC_SC,
|
||||
]
|
||||
|
||||
# Parameters fo sigmoid function that gives us a continuous step function for
|
||||
# Parameters fo sigmoid function that gives us a continuous step function for
|
||||
# meeting basepair interaction criteria based on hbond counts alone (1st filter).
|
||||
# Calibrated such that:
|
||||
# >= 2 base-base H-bonds -> ~1.0
|
||||
# 1 base-base H-bond + 1 base-backbone H-bond -> ~0.5
|
||||
self.min_hbonds_for_bp = 2.0
|
||||
self.bp_hbond_coeff = 9.8 # determined heuristically
|
||||
self.bp_val_cutoff = 0.5 # minimum basepairing score for binarizing basepairs when needed
|
||||
self.bp_hbond_coeff = 9.8 # determined heuristically
|
||||
self.bp_val_cutoff = (
|
||||
0.5 # minimum basepairing score for binarizing basepairs when needed
|
||||
)
|
||||
|
||||
self.base_geometry_limits = {}
|
||||
self.base_geometry_limits['D_ij'] = 16.0
|
||||
self.base_geometry_limits['H_ij'] = 1.5
|
||||
self.base_geometry_limits['P_ij'] = math.pi/5
|
||||
self.base_geometry_limits['B_ij'] = math.pi/5
|
||||
self.base_geometry_limits["D_ij"] = 16.0
|
||||
self.base_geometry_limits["H_ij"] = 1.5
|
||||
self.base_geometry_limits["P_ij"] = math.pi / 5
|
||||
self.base_geometry_limits["B_ij"] = math.pi / 5
|
||||
|
||||
self.rep_atom_dict={"protein": "CA", "rna": "C1'", "dna": "C1'"}
|
||||
self.rep_atom_dict = {"protein": "CA", "rna": "C1'", "dna": "C1'"}
|
||||
|
||||
# go through self.vec_atom_dict and remove spaces from atom names (values in inner dicts), and remove spaces from keys + replace 'R' with '' in outer dict keys
|
||||
self.vec_atom_dict = {
|
||||
"DA": {"W_start":"N1", "W_stop":"N6", "H_start":"N7", "H_stop":"N6", "S_start":"C1'", "S_stop":"N3", "B_start":"C1'", "B_stop":"N9" },
|
||||
"DG": {"W_start":"N1", "W_stop":"O6", "H_start":"N7", "H_stop":"O6", "S_start":"C1'", "S_stop":"N3", "B_start":"C1'", "B_stop":"N9" },
|
||||
"DC": {"W_start":"N3", "W_stop":"N4", "H_start":"C5", "H_stop":"N4", "S_start":"C1'", "S_stop":"O2", "B_start":"C1'", "B_stop":"N1" },
|
||||
"DT": {"W_start":"N3", "W_stop":"O4", "H_start":"C5", "H_stop":"O4", "S_start":"C1'", "S_stop":"O2", "B_start":"C1'", "B_stop":"N1" },
|
||||
"A": {"W_start":"N1", "W_stop":"N6", "H_start":"N7", "H_stop":"N6", "S_start":"C1'", "S_stop":"N3", "B_start":"C1'", "B_stop":"N9" },
|
||||
"G": {"W_start":"N1", "W_stop":"O6", "H_start":"N7", "H_stop":"O6", "S_start":"C1'", "S_stop":"N3", "B_start":"C1'", "B_stop":"N9" },
|
||||
"C": {"W_start":"N3", "W_stop":"N4", "H_start":"C5", "H_stop":"N4", "S_start":"C1'", "S_stop":"O2", "B_start":"C1'", "B_stop":"N1" },
|
||||
"U": {"W_start":"N3", "W_stop":"O4", "H_start":"C5", "H_stop":"O4", "S_start":"C1'", "S_stop":"O2", "B_start":"C1'", "B_stop":"N1" },
|
||||
}
|
||||
|
||||
"DA": {
|
||||
"W_start": "N1",
|
||||
"W_stop": "N6",
|
||||
"H_start": "N7",
|
||||
"H_stop": "N6",
|
||||
"S_start": "C1'",
|
||||
"S_stop": "N3",
|
||||
"B_start": "C1'",
|
||||
"B_stop": "N9",
|
||||
},
|
||||
"DG": {
|
||||
"W_start": "N1",
|
||||
"W_stop": "O6",
|
||||
"H_start": "N7",
|
||||
"H_stop": "O6",
|
||||
"S_start": "C1'",
|
||||
"S_stop": "N3",
|
||||
"B_start": "C1'",
|
||||
"B_stop": "N9",
|
||||
},
|
||||
"DC": {
|
||||
"W_start": "N3",
|
||||
"W_stop": "N4",
|
||||
"H_start": "C5",
|
||||
"H_stop": "N4",
|
||||
"S_start": "C1'",
|
||||
"S_stop": "O2",
|
||||
"B_start": "C1'",
|
||||
"B_stop": "N1",
|
||||
},
|
||||
"DT": {
|
||||
"W_start": "N3",
|
||||
"W_stop": "O4",
|
||||
"H_start": "C5",
|
||||
"H_stop": "O4",
|
||||
"S_start": "C1'",
|
||||
"S_stop": "O2",
|
||||
"B_start": "C1'",
|
||||
"B_stop": "N1",
|
||||
},
|
||||
"A": {
|
||||
"W_start": "N1",
|
||||
"W_stop": "N6",
|
||||
"H_start": "N7",
|
||||
"H_stop": "N6",
|
||||
"S_start": "C1'",
|
||||
"S_stop": "N3",
|
||||
"B_start": "C1'",
|
||||
"B_stop": "N9",
|
||||
},
|
||||
"G": {
|
||||
"W_start": "N1",
|
||||
"W_stop": "O6",
|
||||
"H_start": "N7",
|
||||
"H_stop": "O6",
|
||||
"S_start": "C1'",
|
||||
"S_stop": "N3",
|
||||
"B_start": "C1'",
|
||||
"B_stop": "N9",
|
||||
},
|
||||
"C": {
|
||||
"W_start": "N3",
|
||||
"W_stop": "N4",
|
||||
"H_start": "C5",
|
||||
"H_stop": "N4",
|
||||
"S_start": "C1'",
|
||||
"S_stop": "O2",
|
||||
"B_start": "C1'",
|
||||
"B_stop": "N1",
|
||||
},
|
||||
"U": {
|
||||
"W_start": "N3",
|
||||
"W_stop": "O4",
|
||||
"H_start": "C5",
|
||||
"H_stop": "O4",
|
||||
"S_start": "C1'",
|
||||
"S_stop": "O2",
|
||||
"B_start": "C1'",
|
||||
"B_stop": "N1",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def calculate_hb_counts(
|
||||
@@ -117,7 +195,7 @@ def calculate_hb_counts(
|
||||
mol_info: NucMolInfo,
|
||||
cutoff_HA_dist: float = 2.5,
|
||||
cutoff_DA_dist: float = 3.9,
|
||||
):
|
||||
):
|
||||
"""Count hydrogen bonds between residue pairs using HBPLUS.
|
||||
|
||||
Args:
|
||||
@@ -195,14 +273,16 @@ def calculate_hb_counts(
|
||||
)
|
||||
# d_atm = atom_array[d_mask]
|
||||
# d_idx = d_atm.token_id
|
||||
d_idx = token_level_data["resi2index"].get(f"{d_chain_iid}__{d_resi}", None)
|
||||
d_idx = token_level_data["resi2index"].get(
|
||||
f"{d_chain_iid}__{d_resi}", None
|
||||
)
|
||||
if d_idx is None:
|
||||
continue
|
||||
|
||||
# Handle standard polymer residues for donor atom:
|
||||
if d_resn in ATOM_REGION_BY_RESI.keys():
|
||||
d_is_sc = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['sc'])
|
||||
d_is_bb = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['bb'])
|
||||
d_is_sc = d_atom_name in ATOM_REGION_BY_RESI[d_resn]["sc"]
|
||||
d_is_bb = d_atom_name in ATOM_REGION_BY_RESI[d_resn]["bb"]
|
||||
else:
|
||||
# If non-polymer, define any ligand HBonding atom as backbone:
|
||||
if d_mask.sum() > 0:
|
||||
@@ -219,45 +299,51 @@ def calculate_hb_counts(
|
||||
& (atom_array.res_id == a_resi)
|
||||
& (atom_array.chain_iid == a_chain_iid)
|
||||
)
|
||||
a_idx = token_level_data["resi2index"].get(f"{a_chain_iid}__{a_resi}", None)
|
||||
a_idx = token_level_data["resi2index"].get(
|
||||
f"{a_chain_iid}__{a_resi}", None
|
||||
)
|
||||
if a_idx is None:
|
||||
continue
|
||||
|
||||
# Handle standard polymer residues for acceptor atom:
|
||||
if a_resn in ATOM_REGION_BY_RESI.keys():
|
||||
a_is_sc = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['sc'])
|
||||
a_is_bb = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['bb'])
|
||||
a_is_sc = a_atom_name in ATOM_REGION_BY_RESI[a_resn]["sc"]
|
||||
a_is_bb = a_atom_name in ATOM_REGION_BY_RESI[a_resn]["bb"]
|
||||
else:
|
||||
# If non-polymer, define any ligand HBonding atom as backbone:
|
||||
if a_mask.sum() > 0:
|
||||
a_is_bb = atom_array[a_mask][0].is_ligand
|
||||
|
||||
# 0 -> both backbone (BB-BB)
|
||||
hbond_count[a_idx, d_idx, 0] += (a_is_bb * d_is_bb)
|
||||
hbond_count[d_idx, a_idx, 0] += (d_is_bb * a_is_bb)
|
||||
hbond_count[a_idx, d_idx, 0] += a_is_bb * d_is_bb
|
||||
hbond_count[d_idx, a_idx, 0] += d_is_bb * a_is_bb
|
||||
|
||||
# 1 -> one backbone, one sidechain (BB-SC)
|
||||
hbond_count[a_idx, d_idx, 1] += (a_is_bb * d_is_sc) | (a_is_sc * d_is_bb)
|
||||
hbond_count[d_idx, a_idx, 1] += (d_is_bb * a_is_sc) | (d_is_sc * a_is_bb)
|
||||
hbond_count[a_idx, d_idx, 1] += (a_is_bb * d_is_sc) | (
|
||||
a_is_sc * d_is_bb
|
||||
)
|
||||
hbond_count[d_idx, a_idx, 1] += (d_is_bb * a_is_sc) | (
|
||||
d_is_sc * a_is_bb
|
||||
)
|
||||
|
||||
# 2 -> both sidechain (SC-SC)
|
||||
hbond_count[a_idx, d_idx, 2] += (a_is_sc * d_is_sc)
|
||||
hbond_count[d_idx, a_idx, 2] += (d_is_sc * a_is_sc)
|
||||
'''
|
||||
hbond_count[a_idx, d_idx, 2] += a_is_sc * d_is_sc
|
||||
hbond_count[d_idx, a_idx, 2] += d_is_sc * a_is_sc
|
||||
"""
|
||||
try:
|
||||
os.remove(pdb_path)
|
||||
os.remove(hb2_path)
|
||||
except:
|
||||
print("temp pdb/hb already removed or not created to begin with")
|
||||
'''
|
||||
"""
|
||||
return hbond_count
|
||||
|
||||
|
||||
def find_planar_positions(
|
||||
atom_array: AtomArray,
|
||||
mol_info: NucMolInfo,
|
||||
tol: float = 1e-2,
|
||||
) -> Dict:
|
||||
atom_array: AtomArray,
|
||||
mol_info: NucMolInfo,
|
||||
tol: float = 1e-2,
|
||||
) -> Dict:
|
||||
"""Identify residues with planar sidechains via known atom lists or PCA plane-fitting.
|
||||
|
||||
For canonical residues the planar atoms are looked up from ``mol_info``;
|
||||
@@ -285,11 +371,10 @@ def find_planar_positions(
|
||||
|
||||
# for chain_iid, res_id in unique_positions_list:
|
||||
for chain_iid, res_id, res_name in unique_positions_list:
|
||||
|
||||
mask = (
|
||||
(atom_array.chain_iid == chain_iid) &
|
||||
(atom_array.res_id == res_id) &
|
||||
(atom_array.res_name == res_name)
|
||||
(atom_array.chain_iid == chain_iid)
|
||||
& (atom_array.res_id == res_id)
|
||||
& (atom_array.res_name == res_name)
|
||||
)
|
||||
res_atoms = atom_array[mask]
|
||||
|
||||
@@ -297,9 +382,9 @@ def find_planar_positions(
|
||||
if res_name in PLANAR_ATOMS_BY_RESI.keys():
|
||||
# Shared atoms between residue and known planar atoms for that residue type:
|
||||
planar_atom_list = list(
|
||||
set([atm.atom_name for atm in res_atoms]) &
|
||||
set(PLANAR_ATOMS_BY_RESI[res_name])
|
||||
)
|
||||
set([atm.atom_name for atm in res_atoms])
|
||||
& set(PLANAR_ATOMS_BY_RESI[res_name])
|
||||
)
|
||||
planar_atom_list_dict[(chain_iid, res_id)] = planar_atom_list
|
||||
|
||||
# If unknown or noncanonical residue, compute planar atoms geometrically:
|
||||
@@ -337,7 +422,9 @@ def find_planar_positions(
|
||||
all_quad_centered = coords - quad_center
|
||||
quad_centered = quad_coords - quad_center
|
||||
# covariance matrix
|
||||
quad_cov = (quad_centered.T @ quad_centered) / max(quad_coords.shape[0] - 1, 1)
|
||||
quad_cov = (quad_centered.T @ quad_centered) / max(
|
||||
quad_coords.shape[0] - 1, 1
|
||||
)
|
||||
# eigen decomposition
|
||||
_, quad_eigvecs = np.linalg.eigh(quad_cov)
|
||||
quad_normal = quad_eigvecs[:, 0] # eigenvector with smallest eigenvalue
|
||||
@@ -348,34 +435,39 @@ def find_planar_positions(
|
||||
quad_valid_mask = quad_dists <= tol
|
||||
|
||||
# Filter for if we have a valid plane in the first place:
|
||||
valid_plane_filter = (np.nanmax(quad_dists[:4]) < tol)
|
||||
valid_plane_filter = np.nanmax(quad_dists[:4]) < tol
|
||||
# Filter for if we have enough atoms in the plane:
|
||||
plane_atom_filter = (int(np.sum(quad_valid_mask)) >= 4)
|
||||
plane_atom_filter = int(np.sum(quad_valid_mask)) >= 4
|
||||
if valid_plane_filter and plane_atom_filter:
|
||||
# Set the planar atom list for this position to those that are within tol of the plane:
|
||||
# Set the planar atom list for this position to those that are within tol of the plane:
|
||||
# using quad_valid_mask and candidate_planar_atm_names:
|
||||
planar_atom_list = [n for n, keep in zip(candidate_planar_atm_names, quad_valid_mask.tolist()) if keep]
|
||||
|
||||
planar_atom_list = [
|
||||
n
|
||||
for n, keep in zip(
|
||||
candidate_planar_atm_names, quad_valid_mask.tolist()
|
||||
)
|
||||
if keep
|
||||
]
|
||||
|
||||
# not enough atoms close to a common plane
|
||||
else:
|
||||
planar_atom_list = []
|
||||
|
||||
else:
|
||||
|
||||
# need at least 4 atoms to define a robust plane
|
||||
planar_atom_list = []
|
||||
|
||||
planar_atom_list_dict[(chain_iid, res_id)] = planar_atom_list
|
||||
|
||||
planar_atom_list_dict[(chain_iid, res_id)] = planar_atom_list
|
||||
|
||||
return planar_atom_list_dict
|
||||
|
||||
|
||||
def make_coord_list(atom_array: AtomArray,
|
||||
residue_list: list[str],
|
||||
chain_list: list[str],
|
||||
atom_list: list[str],
|
||||
) -> list[list[str]]:
|
||||
def make_coord_list(
|
||||
atom_array: AtomArray,
|
||||
residue_list: list[str],
|
||||
chain_list: list[str],
|
||||
atom_list: list[str],
|
||||
) -> list[list[str]]:
|
||||
"""Extract per-residue representative coordinates from an AtomArray.
|
||||
|
||||
All three input lists must have the same length. Missing atoms are
|
||||
@@ -393,24 +485,20 @@ def make_coord_list(atom_array: AtomArray,
|
||||
"""
|
||||
coord_list = []
|
||||
for res_id, chain_id, atom_name in zip(residue_list, chain_list, atom_list):
|
||||
|
||||
# Check if the residue exists in the atom array
|
||||
if atom_name == "atomized":
|
||||
# Check for atomized residue, in which case we take the first atom of the residue
|
||||
# full mask should be length-1 if atomized
|
||||
mask = (
|
||||
(atom_array.chain_id == chain_id) &
|
||||
(atom_array.res_id == res_id)
|
||||
)
|
||||
mask = (atom_array.chain_id == chain_id) & (atom_array.res_id == res_id)
|
||||
else:
|
||||
# General case for non-atomized residues
|
||||
# should have a unique solution, but we take the first entry either way.
|
||||
mask = (
|
||||
(atom_array.chain_id == chain_id) &
|
||||
(atom_array.res_id == res_id) &
|
||||
(atom_array.atom_name == atom_name)
|
||||
)
|
||||
|
||||
(atom_array.chain_id == chain_id)
|
||||
& (atom_array.res_id == res_id)
|
||||
& (atom_array.atom_name == atom_name)
|
||||
)
|
||||
|
||||
# Get the coordinates for the masked atoms
|
||||
coords = atom_array.coord[mask][0:1]
|
||||
|
||||
@@ -428,8 +516,8 @@ def get_token_level_metadata(
|
||||
*,
|
||||
NA_only: bool = False,
|
||||
planar_only: bool = True,
|
||||
seq_cutoff = 2,
|
||||
gap_length = 200
|
||||
seq_cutoff=2,
|
||||
gap_length=200,
|
||||
) -> dict:
|
||||
"""Build lightweight token-level metadata (no coordinate geometry).
|
||||
|
||||
@@ -468,9 +556,21 @@ def get_token_level_metadata(
|
||||
# molecule type flags
|
||||
# Instantiate encoding locally to avoid retaining large arrays at module scope.
|
||||
sequence_encoding = AF3SequenceEncoding()
|
||||
is_protein = np.isin(token_level_array.res_name, sequence_encoding.all_res_names[sequence_encoding.is_aa_like])
|
||||
is_rna = np.isin(token_level_array.res_name, sequence_encoding.all_res_names[sequence_encoding.is_rna_like])
|
||||
is_dna = np.isin(token_level_array.res_name, sequence_encoding.all_res_names[sequence_encoding.is_dna_like])
|
||||
|
||||
###################
|
||||
# is_protein = np.isin(
|
||||
# token_level_array.res_name,
|
||||
# sequence_encoding.all_res_names[sequence_encoding.is_aa_like],
|
||||
# )
|
||||
##################
|
||||
is_rna = np.isin(
|
||||
token_level_array.res_name,
|
||||
sequence_encoding.all_res_names[sequence_encoding.is_rna_like],
|
||||
)
|
||||
is_dna = np.isin(
|
||||
token_level_array.res_name,
|
||||
sequence_encoding.all_res_names[sequence_encoding.is_dna_like],
|
||||
)
|
||||
|
||||
is_na_arr = (is_dna | is_rna).astype(bool)
|
||||
|
||||
@@ -500,7 +600,7 @@ def get_token_level_metadata(
|
||||
sc_planarity_list.append(False)
|
||||
|
||||
# representative & sugar-edge atoms
|
||||
if (is_glycine(atm.res_name) | is_protein_unknown(atm.res_name)):
|
||||
if is_glycine(atm.res_name) | is_protein_unknown(atm.res_name):
|
||||
rep_atom_i = "CA"
|
||||
S_start_atom_i = None
|
||||
S_stop_atom_i = None
|
||||
@@ -534,7 +634,9 @@ def get_token_level_metadata(
|
||||
S_stop_atom_list.append(S_stop_atom_i)
|
||||
|
||||
# residue index <-> token index map
|
||||
resi2index = {f"{c}__{r}": i for c, r, i in zip(chain_iid_list, resi_list, ind_list)}
|
||||
resi2index = {
|
||||
f"{c}__{r}": i for c, r, i in zip(chain_iid_list, resi_list, ind_list)
|
||||
}
|
||||
|
||||
# relative sequence positions w/ chain gaps
|
||||
rel_pos_list: list[int] = []
|
||||
@@ -547,9 +649,7 @@ def get_token_level_metadata(
|
||||
rel_pos_list.append(int(r + chn_bias))
|
||||
|
||||
rel_pos = np.asarray(rel_pos_list, dtype=np.int64)
|
||||
seq_neighbors = (
|
||||
np.abs(rel_pos[:, None] - rel_pos[None, :]) <= int(seq_cutoff)
|
||||
)
|
||||
seq_neighbors = np.abs(rel_pos[:, None] - rel_pos[None, :]) <= int(seq_cutoff)
|
||||
|
||||
na_inds = np.nonzero(is_na_arr)[0].tolist()
|
||||
na_tensor_inds = {na_i: i for i, na_i in enumerate(na_inds)}
|
||||
@@ -647,12 +747,16 @@ def add_token_level_geometry_data(
|
||||
S_start_atom_list: list[str | None] = token_level_data["S_start_atom_list"]
|
||||
S_stop_atom_list: list[str | None] = token_level_data["S_stop_atom_list"]
|
||||
|
||||
planar_atom_list_dict = find_planar_positions(atom_array, mol_info) # {(chain_iid, res_id): [atom_name, ...]}
|
||||
planar_atom_list_dict = find_planar_positions(
|
||||
atom_array, mol_info
|
||||
) # {(chain_iid, res_id): [atom_name, ...]}
|
||||
has_planar_sc: list[bool] = []
|
||||
|
||||
xyz_planar: list[list[list[float]]] = [] # list[I] of [K_i, 3] (K_i varies per residue)
|
||||
xyz_S_start: list[list[float]] = [] # list[I] of [3]
|
||||
xyz_S_stop: list[list[float]] = [] # list[I] of [3]
|
||||
xyz_planar: list[
|
||||
list[list[float]]
|
||||
] = [] # list[I] of [K_i, 3] (K_i varies per residue)
|
||||
xyz_S_start: list[list[float]] = [] # list[I] of [3]
|
||||
xyz_S_stop: list[list[float]] = [] # list[I] of [3]
|
||||
|
||||
for c, r, S_start_atm, S_stop_atm in zip(
|
||||
chain_iid_list,
|
||||
@@ -663,7 +767,9 @@ def add_token_level_geometry_data(
|
||||
planar_atoms_i = planar_atom_list_dict[(c, r)]
|
||||
has_planar_sc.append(bool(len(planar_atoms_i) >= 4))
|
||||
|
||||
atom_array_i = atom_array[(atom_array.chain_iid == c) & (atom_array.res_id == r)]
|
||||
atom_array_i = atom_array[
|
||||
(atom_array.chain_iid == c) & (atom_array.res_id == r)
|
||||
]
|
||||
|
||||
planar_coords_i: list[list[float]] = []
|
||||
for pl_atm_name_j in planar_atoms_i:
|
||||
@@ -673,7 +779,9 @@ def add_token_level_geometry_data(
|
||||
else:
|
||||
planar_coords_i.append(pl_atom_array_ij[0].coord)
|
||||
|
||||
xyz_planar.append(planar_coords_i if len(planar_coords_i) > 3 else [[float("nan")] * 3])
|
||||
xyz_planar.append(
|
||||
planar_coords_i if len(planar_coords_i) > 3 else [[float("nan")] * 3]
|
||||
)
|
||||
|
||||
if S_start_atm is None:
|
||||
xyz_S_start.append([float("nan"), float("nan"), float("nan")])
|
||||
@@ -698,21 +806,26 @@ def add_token_level_geometry_data(
|
||||
del atom_array_i
|
||||
|
||||
# frame coordinates and backbone direction
|
||||
frame_xyz = np.asarray( # [I, 3] representative-atom coordinates
|
||||
frame_xyz = np.asarray( # [I, 3] representative-atom coordinates
|
||||
make_coord_list(atom_array, resi_list, chain_list, rep_atom_list),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
padded_centers = np.concatenate([frame_xyz[:1], frame_xyz, frame_xyz[-1:]], axis=0) # [I+2, 3]
|
||||
M_i = ( # [I, 3] smoothed backbone-direction vectors
|
||||
(padded_centers[1:-1] - padded_centers[:-2])
|
||||
+ (padded_centers[2:] - padded_centers[1:-1])
|
||||
) / 2.0
|
||||
padded_centers = np.concatenate(
|
||||
[frame_xyz[:1], frame_xyz, frame_xyz[-1:]], axis=0
|
||||
) # [I+2, 3]
|
||||
M_i = (
|
||||
( # [I, 3] smoothed backbone-direction vectors
|
||||
(padded_centers[1:-1] - padded_centers[:-2])
|
||||
+ (padded_centers[2:] - padded_centers[1:-1])
|
||||
)
|
||||
/ 2.0
|
||||
)
|
||||
|
||||
is_planar_arr = np.asarray(has_planar_sc, dtype=bool) # [I]
|
||||
is_planar_arr = np.asarray(has_planar_sc, dtype=bool) # [I]
|
||||
token_level_data["is_planar"] = is_planar_arr
|
||||
|
||||
is_na_arr = np.asarray(token_level_data["is_na"], dtype=bool) # [I]
|
||||
is_na_arr = np.asarray(token_level_data["is_na"], dtype=bool) # [I]
|
||||
if NA_only and planar_only:
|
||||
filter_mask = is_na_arr & is_planar_arr
|
||||
elif NA_only and (not planar_only):
|
||||
@@ -721,7 +834,7 @@ def add_token_level_geometry_data(
|
||||
filter_mask = is_planar_arr.copy()
|
||||
else:
|
||||
filter_mask = np.ones_like(is_na_arr, dtype=bool)
|
||||
token_level_data["filter_mask"] = filter_mask # [I] bool
|
||||
token_level_data["filter_mask"] = filter_mask # [I] bool
|
||||
|
||||
token_level_data.update(
|
||||
{
|
||||
@@ -777,7 +890,7 @@ def _compute_local_frames(
|
||||
n_tokens = len(xyz_planar)
|
||||
|
||||
# Mean-centre the planar atoms per residue
|
||||
centered_points = [ # list[I] of [K_i, 3]
|
||||
centered_points = [ # list[I] of [K_i, 3]
|
||||
np.asarray(xyz_i, dtype=np.float32) - cen_i
|
||||
for xyz_i, cen_i in zip(xyz_planar, planar_centers)
|
||||
]
|
||||
@@ -857,19 +970,23 @@ def _compute_pairwise_geometry(
|
||||
``"Z_ij"`` [I, I, 3], and optionally ``"O_ij"`` [I, I].
|
||||
"""
|
||||
# Orientation-selected pairwise Z-axis
|
||||
Z_sum = Z_i[:, None, :] + Z_i[None, :, :] # [I, I, 3]
|
||||
Z_diff = Z_i[:, None, :] - Z_i[None, :, :] # [I, I, 3]
|
||||
Z_sum = Z_i[:, None, :] + Z_i[None, :, :] # [I, I, 3]
|
||||
Z_diff = Z_i[:, None, :] - Z_i[None, :, :] # [I, I, 3]
|
||||
Z_ij_oris = 0.5 * np.stack((Z_sum, Z_diff), axis=0) # [2, I, I, 3]
|
||||
|
||||
base_ori_ij = ( # [I, I] 0=parallel, 1=antiparallel
|
||||
np.linalg.norm(Z_ij_oris[1], axis=-1) > np.linalg.norm(Z_ij_oris[0], axis=-1)
|
||||
).astype(np.int64)
|
||||
|
||||
Z_ij = np.where(base_ori_ij[..., None] == 0, Z_ij_oris[0], Z_ij_oris[1]) # [I, I, 3]
|
||||
Z_ij = np.where(
|
||||
base_ori_ij[..., None] == 0, Z_ij_oris[0], Z_ij_oris[1]
|
||||
) # [I, I, 3]
|
||||
Z_ij = Z_ij / (np.linalg.norm(Z_ij, axis=-1, keepdims=True) + eps)
|
||||
|
||||
# Pairwise Y (inter-residue direction) and X axes
|
||||
Y_ij = frame_D_ij_vec / (np.linalg.norm(frame_D_ij_vec, axis=-1, keepdims=True) + eps) # [I, I, 3]
|
||||
Y_ij = frame_D_ij_vec / (
|
||||
np.linalg.norm(frame_D_ij_vec, axis=-1, keepdims=True) + eps
|
||||
) # [I, I, 3]
|
||||
X_ij = np.cross(Z_ij, Y_ij) # [I, I, 3]
|
||||
X_ij = X_ij / (np.linalg.norm(X_ij, axis=-1, keepdims=True) + eps)
|
||||
|
||||
@@ -882,22 +999,30 @@ def _compute_pairwise_geometry(
|
||||
np.sum(Z_i[:, None, :] * Y_ij, axis=-1, keepdims=True) * Y_ij
|
||||
+ np.sum(Z_i[:, None, :] * Z_ij, axis=-1, keepdims=True) * Z_ij
|
||||
)
|
||||
proj_Z_i_YZ_norm = proj_Z_i_YZ / (np.linalg.norm(proj_Z_i_YZ, axis=-1, keepdims=True) + eps)
|
||||
cos_buckle = np.sum(proj_Z_i_YZ_norm * (-proj_Z_i_YZ_norm.swapaxes(0, 1)), axis=-1) # [I, I]
|
||||
proj_Z_i_YZ_norm = proj_Z_i_YZ / (
|
||||
np.linalg.norm(proj_Z_i_YZ, axis=-1, keepdims=True) + eps
|
||||
)
|
||||
cos_buckle = np.sum(
|
||||
proj_Z_i_YZ_norm * (-proj_Z_i_YZ_norm.swapaxes(0, 1)), axis=-1
|
||||
) # [I, I]
|
||||
|
||||
# Propeller (P_ij)
|
||||
proj_Z_i_ZX = ( # [I, I, 3]
|
||||
np.sum(Z_i[:, None, :] * Z_ij, axis=-1, keepdims=True) * Z_ij
|
||||
+ np.sum(Z_i[:, None, :] * X_ij, axis=-1, keepdims=True) * X_ij
|
||||
)
|
||||
proj_Z_i_ZX_norm = proj_Z_i_ZX / (np.linalg.norm(proj_Z_i_ZX, axis=-1, keepdims=True) + eps)
|
||||
cos_propeller = np.sum(proj_Z_i_ZX_norm * (-proj_Z_i_ZX_norm.swapaxes(0, 1)), axis=-1) # [I, I]
|
||||
proj_Z_i_ZX_norm = proj_Z_i_ZX / (
|
||||
np.linalg.norm(proj_Z_i_ZX, axis=-1, keepdims=True) + eps
|
||||
)
|
||||
cos_propeller = np.sum(
|
||||
proj_Z_i_ZX_norm * (-proj_Z_i_ZX_norm.swapaxes(0, 1)), axis=-1
|
||||
) # [I, I]
|
||||
|
||||
if clamp:
|
||||
cos_buckle = np.clip(cos_buckle, -1.0, 1.0)
|
||||
cos_propeller = np.clip(cos_propeller, -1.0, 1.0)
|
||||
|
||||
B_ij = np.arccos(cos_buckle) # [I, I]
|
||||
B_ij = np.arccos(cos_buckle) # [I, I]
|
||||
P_ij = np.arccos(cos_propeller) # [I, I]
|
||||
|
||||
result: dict[str, np.ndarray] = {
|
||||
@@ -920,8 +1045,12 @@ def _compute_pairwise_geometry(
|
||||
np.sum(X_i[:, None, :] * X_ij, axis=-1, keepdims=True) * X_ij
|
||||
+ np.sum(X_i[:, None, :] * Y_ij, axis=-1, keepdims=True) * Y_ij
|
||||
)
|
||||
proj_X_i_XY_norm = proj_X_i_XY / (np.linalg.norm(proj_X_i_XY, axis=-1, keepdims=True) + eps)
|
||||
cos_opening = np.sum(proj_X_i_XY_norm * proj_X_i_XY_norm.swapaxes(0, 1), axis=-1) # [I, I]
|
||||
proj_X_i_XY_norm = proj_X_i_XY / (
|
||||
np.linalg.norm(proj_X_i_XY, axis=-1, keepdims=True) + eps
|
||||
)
|
||||
cos_opening = np.sum(
|
||||
proj_X_i_XY_norm * proj_X_i_XY_norm.swapaxes(0, 1), axis=-1
|
||||
) # [I, I]
|
||||
if clamp:
|
||||
cos_opening = np.clip(cos_opening, -1.0, 1.0)
|
||||
result["O_ij"] = np.arccos(cos_opening) # [I, I]
|
||||
@@ -989,13 +1118,14 @@ def _compute_basepair_mask(
|
||||
| (P_ij >= math.pi - mol_info.base_geometry_limits["P_ij"])
|
||||
)
|
||||
|
||||
D_ij_filter = (D_ij <= mol_info.base_geometry_limits["D_ij"])
|
||||
D_ij_filter = D_ij <= mol_info.base_geometry_limits["D_ij"]
|
||||
|
||||
bp_geom_filter = H_ij_filter & B_ij_filter & P_ij_filter & D_ij_filter # [I, I]
|
||||
|
||||
if bool_only:
|
||||
basepairs_bool_ij = ( # [I, I]
|
||||
(~seq_neighbors) & bp_geom_filter
|
||||
(~seq_neighbors)
|
||||
& bp_geom_filter
|
||||
& (bp_preds >= float(mol_info.bp_val_cutoff))
|
||||
)
|
||||
return basepairs_bool_ij
|
||||
@@ -1015,7 +1145,7 @@ def _compute_basepair_mask(
|
||||
|
||||
|
||||
def compute_nucleic_ss(
|
||||
mol_info,
|
||||
mol_info,
|
||||
token_level_data,
|
||||
hbond_count,
|
||||
clamp_pairwise_params=True,
|
||||
@@ -1061,14 +1191,24 @@ def compute_nucleic_ss(
|
||||
mask_1d = np.asarray(token_level_data["filter_mask"], dtype=bool) # [I_full]
|
||||
|
||||
# --- Unpack and filter token-level data ----------------------
|
||||
M_i = np.asarray(token_level_data["M_i"], dtype=np.float32)[mask_1d] # [I, 3]
|
||||
frame_xyz = np.asarray(token_level_data["frame_xyz"], dtype=np.float32)[mask_1d] # [I, 3]
|
||||
xyz_S_start = [v for v, k in zip(token_level_data["xyz_S_start"], mask_1d) if k] # list[I] of [3]
|
||||
xyz_S_stop = [v for v, k in zip(token_level_data["xyz_S_stop"], mask_1d) if k] # list[I] of [3]
|
||||
xyz_planar = [v for v, k in zip(token_level_data["xyz_planar"], mask_1d) if k] # list[I] of [K_i, 3]
|
||||
M_i = np.asarray(token_level_data["M_i"], dtype=np.float32)[mask_1d] # [I, 3]
|
||||
frame_xyz = np.asarray(token_level_data["frame_xyz"], dtype=np.float32)[
|
||||
mask_1d
|
||||
] # [I, 3]
|
||||
xyz_S_start = [
|
||||
v for v, k in zip(token_level_data["xyz_S_start"], mask_1d) if k
|
||||
] # list[I] of [3]
|
||||
xyz_S_stop = [
|
||||
v for v, k in zip(token_level_data["xyz_S_stop"], mask_1d) if k
|
||||
] # list[I] of [3]
|
||||
xyz_planar = [
|
||||
v for v, k in zip(token_level_data["xyz_planar"], mask_1d) if k
|
||||
] # list[I] of [K_i, 3]
|
||||
|
||||
hbond_count = np.asarray(hbond_count)[mask_1d, :][:, mask_1d] # [I, I, 3]
|
||||
seq_neighbors = np.asarray(token_level_data["seq_neighbors"], dtype=bool)[mask_1d, :][:, mask_1d] # [I, I]
|
||||
hbond_count = np.asarray(hbond_count)[mask_1d, :][:, mask_1d] # [I, I, 3]
|
||||
seq_neighbors = np.asarray(token_level_data["seq_neighbors"], dtype=bool)[
|
||||
mask_1d, :
|
||||
][:, mask_1d] # [I, I]
|
||||
|
||||
# Nothing passed NA/planar filtering for this structure.
|
||||
# Return empty outputs instead of failing downstream on np.stack([]).
|
||||
@@ -1107,12 +1247,15 @@ def compute_nucleic_ss(
|
||||
|
||||
# --- Precompute centroids and displacement vectors -----------
|
||||
planar_centers = np.stack( # [I, 3]
|
||||
[np.nanmean(np.asarray(xyz_i, dtype=np.float32), axis=0) for xyz_i in xyz_planar],
|
||||
[
|
||||
np.nanmean(np.asarray(xyz_i, dtype=np.float32), axis=0)
|
||||
for xyz_i in xyz_planar
|
||||
],
|
||||
axis=0,
|
||||
).astype(np.float32)
|
||||
|
||||
frame_D_ij_vec = frame_xyz[None, :, :] - frame_xyz[:, None, :] # [I, I, 3]
|
||||
sc_D_ij_vec = planar_centers[None, :, :] - planar_centers[:, None, :] # [I, I, 3]
|
||||
frame_D_ij_vec = frame_xyz[None, :, :] - frame_xyz[:, None, :] # [I, I, 3]
|
||||
sc_D_ij_vec = planar_centers[None, :, :] - planar_centers[:, None, :] # [I, I, 3]
|
||||
|
||||
# --- CALC I: per-residue local coordinate frames -------------
|
||||
need_full_frame = return_local_params or return_opening_angle
|
||||
@@ -1125,8 +1268,8 @@ def compute_nucleic_ss(
|
||||
compute_full_frame=need_full_frame,
|
||||
eps=eps,
|
||||
)
|
||||
Z_i = local_frames["Z_i"] # [I, 3]
|
||||
X_i = local_frames.get("X_i") # [I, 3] or None
|
||||
Z_i = local_frames["Z_i"] # [I, 3]
|
||||
X_i = local_frames.get("X_i") # [I, 3] or None
|
||||
|
||||
# --- CALC II: pairwise base-step geometry --------------------
|
||||
pw_geom = _compute_pairwise_geometry(
|
||||
@@ -1187,7 +1330,6 @@ def compute_nucleic_ss(
|
||||
return nucleic_ss_data
|
||||
|
||||
|
||||
|
||||
def annotate_na_ss(
|
||||
atom_array: AtomArray,
|
||||
*,
|
||||
@@ -1258,10 +1400,10 @@ def annotate_na_ss(
|
||||
NA_only=NA_only,
|
||||
planar_only=planar_only,
|
||||
)
|
||||
# Note: this mask gives positions that are *chemically valid* for forming
|
||||
# Note: this mask gives positions that are *chemically valid* for forming
|
||||
# base pairs, which is different from custom mask-generation for features
|
||||
mask_1d = np.asarray(token_level_data["filter_mask"], dtype=bool)
|
||||
|
||||
|
||||
subset_idxs = np.nonzero(mask_1d)[0]
|
||||
|
||||
is_na_full = np.asarray(token_level_data["is_na"], dtype=bool)
|
||||
@@ -1295,15 +1437,19 @@ def annotate_na_ss(
|
||||
if planar_only:
|
||||
n_tokens = bp_bool.shape[0]
|
||||
has_planar_sc = np.asarray(
|
||||
token_level_data.get("has_planar_sc", np.ones(n_tokens, dtype=bool)), dtype=bool
|
||||
token_level_data.get("has_planar_sc", np.ones(n_tokens, dtype=bool)),
|
||||
dtype=bool,
|
||||
)
|
||||
bp_bool &= has_planar_sc[:, None]
|
||||
bp_bool &= has_planar_sc[None, :]
|
||||
|
||||
# Optional: filter to canonical Watson-Crick basepairs only.
|
||||
# Sampled probabilistically to allow mixed supervision during training.
|
||||
do_canonical_filter = bool(p_canonical_bp_filter and (np.random.rand() < float(p_canonical_bp_filter)))
|
||||
do_canonical_filter = bool(
|
||||
p_canonical_bp_filter and (np.random.rand() < float(p_canonical_bp_filter))
|
||||
)
|
||||
if do_canonical_filter:
|
||||
|
||||
def _base_letter(res_name: str) -> str | None:
|
||||
rn = str(res_name).strip().upper()
|
||||
if rn in STANDARD_RNA:
|
||||
@@ -1313,11 +1459,16 @@ def annotate_na_ss(
|
||||
return None
|
||||
|
||||
allowed_pairs = {
|
||||
("A", "U"), ("U", "A"),
|
||||
("A", "T"), ("T", "A"),
|
||||
("G", "C"), ("C", "G"),
|
||||
("A", "U"),
|
||||
("U", "A"),
|
||||
("A", "T"),
|
||||
("T", "A"),
|
||||
("G", "C"),
|
||||
("C", "G"),
|
||||
}
|
||||
base_letters_full: list[str | None] = [_base_letter(rn) for rn in token_res_names]
|
||||
base_letters_full: list[str | None] = [
|
||||
_base_letter(rn) for rn in token_res_names
|
||||
]
|
||||
|
||||
bp_bool = np.asarray(bp_bool, dtype=bool)
|
||||
bp_rows_tmp, bp_cols_tmp = np.nonzero(bp_bool)
|
||||
@@ -1401,7 +1552,9 @@ def annotate_na_ss(
|
||||
continue
|
||||
# A residue is treated as atomized if any atom in the residue carries atomize=True.
|
||||
if "atomize" in atom_array.get_annotation_categories():
|
||||
residue_is_atomized = bool(np.any(np.asarray(atom_array.atomize[int(start):stop], dtype=bool)))
|
||||
residue_is_atomized = bool(
|
||||
np.any(np.asarray(atom_array.atomize[int(start) : stop], dtype=bool))
|
||||
)
|
||||
else:
|
||||
residue_is_atomized = False
|
||||
if residue_is_atomized:
|
||||
|
||||
@@ -216,7 +216,9 @@ def get_crop_transform(
|
||||
), "Crop center cutoff distance must be greater than 0"
|
||||
|
||||
pre_crop_transforms = [
|
||||
SubsampleToTypes(allowed_types=allowed_types, association_scheme=association_scheme),
|
||||
SubsampleToTypes(
|
||||
allowed_types=allowed_types, association_scheme=association_scheme
|
||||
),
|
||||
]
|
||||
|
||||
cropping_transform = RandomRoute(
|
||||
@@ -361,11 +363,9 @@ def build_atom14_base_pipeline_(
|
||||
max_ss_frac_to_provide: float,
|
||||
min_ss_island_len: int,
|
||||
max_ss_island_len: int,
|
||||
|
||||
## Nucleic acid features #####
|
||||
#add_na_pair_features: bool,
|
||||
# add_na_pair_features: bool,
|
||||
## This should not be necessary, controlled through feature names in model, and meta conditioning probabilities, inference behavior handled in transform itself #####
|
||||
|
||||
**_, # dump additional kwargs (e.g. msa stuff)
|
||||
):
|
||||
"""
|
||||
@@ -373,7 +373,7 @@ def build_atom14_base_pipeline_(
|
||||
"""
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
|
||||
|
||||
# Add any data necessary for downstream transforms
|
||||
transforms = [
|
||||
AddData(
|
||||
@@ -415,7 +415,7 @@ def build_atom14_base_pipeline_(
|
||||
max_binder_length=max_binder_length,
|
||||
max_atoms_in_crop=max_atoms_in_crop,
|
||||
allowed_types=allowed_types,
|
||||
association_scheme=association_scheme
|
||||
association_scheme=association_scheme,
|
||||
)
|
||||
|
||||
if zero_occ_on_exposure_after_cropping:
|
||||
@@ -448,7 +448,7 @@ def build_atom14_base_pipeline_(
|
||||
)
|
||||
)
|
||||
# Add nucleic acid geometry features
|
||||
#if add_na_pair_features:
|
||||
# if add_na_pair_features:
|
||||
transforms.append(
|
||||
CalculateNucleicAcidGeomFeats(
|
||||
is_inference,
|
||||
@@ -639,8 +639,8 @@ def build_atom14_base_pipeline(
|
||||
kwargs.setdefault("min_ss_island_len", 0)
|
||||
kwargs.setdefault("max_ss_island_len", 999)
|
||||
kwargs.setdefault("max_binder_length", 999)
|
||||
# This should not be necessary.
|
||||
#kwargs.setdefault("add_na_pair_features", False)
|
||||
# This should not be necessary.
|
||||
# kwargs.setdefault("add_na_pair_features", False)
|
||||
|
||||
kwargs.setdefault("b_factor_min", None)
|
||||
kwargs.setdefault("zero_occ_on_exposure_after_cropping", False)
|
||||
|
||||
@@ -13,14 +13,13 @@ from atomworks.ml.utils.token import (
|
||||
spread_token_wise,
|
||||
)
|
||||
from biotite.structure import AtomArray, get_residue_starts
|
||||
from rfd3.constants import backbone_atoms_RNA
|
||||
from rfd3.transforms.conditioning_utils import (
|
||||
random_condition,
|
||||
sample_island_tokens,
|
||||
sample_subgraph_atoms,
|
||||
)
|
||||
|
||||
from rfd3.constants import backbone_atoms_RNA
|
||||
|
||||
nx.from_numpy_matrix = nx.from_numpy_array
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -74,7 +73,7 @@ class IslandCondition(TrainingCondition):
|
||||
p_fix_motif_coordinates,
|
||||
p_fix_motif_sequence,
|
||||
p_unindex_motif_tokens,
|
||||
association_scheme = 'atom14',
|
||||
association_scheme="atom14",
|
||||
):
|
||||
self.name = name
|
||||
self.frequency = frequency
|
||||
@@ -93,7 +92,7 @@ class IslandCondition(TrainingCondition):
|
||||
self.p_fix_motif_coordinates = p_fix_motif_coordinates
|
||||
self.p_fix_motif_sequence = p_fix_motif_sequence
|
||||
self.p_unindex_motif_tokens = p_unindex_motif_tokens
|
||||
|
||||
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
def is_valid_for_example(self, data) -> bool:
|
||||
@@ -107,7 +106,7 @@ class IslandCondition(TrainingCondition):
|
||||
else:
|
||||
if np.any(is_protein):
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
def sample_motif_tokens(self, atom_array):
|
||||
@@ -162,7 +161,7 @@ class IslandCondition(TrainingCondition):
|
||||
if random_condition(self.p_diffuse_motif_sidechains):
|
||||
backbone_atoms = backbone_atoms_RNA.copy()
|
||||
backbone_atoms.remove("C1'")
|
||||
backbone_atoms = ["N", "C", "CA"] + backbone_atoms #covers DNA also
|
||||
backbone_atoms = ["N", "C", "CA"] + backbone_atoms # covers DNA also
|
||||
if random_condition(self.p_include_oxygen_in_backbone_mask):
|
||||
backbone_atoms.append("O")
|
||||
is_motif_atom = is_motif_atom & np.isin(
|
||||
@@ -500,7 +499,9 @@ def sample_conditioning_strategy(
|
||||
atom_array.set_annotation(
|
||||
"is_motif_atom_unindexed",
|
||||
sample_unindexed_atoms(
|
||||
atom_array, p_unindex_motif_tokens=p_unindex_motif_tokens, association_scheme=association_scheme
|
||||
atom_array,
|
||||
p_unindex_motif_tokens=p_unindex_motif_tokens,
|
||||
association_scheme=association_scheme,
|
||||
),
|
||||
)
|
||||
return atom_array
|
||||
@@ -526,7 +527,6 @@ def sample_is_motif_atom_with_fixed_seq(
|
||||
is_motif_atom_with_fixed_seq = (
|
||||
is_motif_atom_with_fixed_seq | ~atom_array.is_protein
|
||||
)
|
||||
|
||||
|
||||
return is_motif_atom_with_fixed_seq
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ def map_to_association_scheme(
|
||||
else:
|
||||
return ATOM_NAMES[idxs]
|
||||
|
||||
|
||||
def map_names_to_elements(
|
||||
atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
|
||||
) -> np.ndarray:
|
||||
@@ -179,7 +180,7 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
token_ids = np.unique(atom_array.token_id)
|
||||
assert len(token_ids) == len(
|
||||
is_motif_atom_with_fixed_seq
|
||||
), "Token ids and token level array have different lengths!"
|
||||
), "Token ids and token level array have different lengths!"
|
||||
|
||||
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
|
||||
if self.association_scheme == "atom23":
|
||||
@@ -230,14 +231,14 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
else:
|
||||
n_atoms_per_token = self.n_atoms_per_token
|
||||
central_atom = self.atom_to_pad_from
|
||||
|
||||
|
||||
n_pad = n_atoms_per_token - len(token)
|
||||
|
||||
if n_pad > 0:
|
||||
mask = get_af3_token_representative_masks(
|
||||
token, central_atom=central_atom
|
||||
)
|
||||
|
||||
|
||||
assert_single_representative(token, central_atom=central_atom)
|
||||
|
||||
# ... Create virtual atoms
|
||||
|
||||
@@ -452,7 +452,10 @@ as input and return a three-element list or numpy array of floats.
|
||||
|
||||
|
||||
def set_com(
|
||||
atom_array, ori_token: list | None = None, infer_ori_strategy: str | None = None, ori_jitter: float | None = None
|
||||
atom_array,
|
||||
ori_token: list | None = None,
|
||||
infer_ori_strategy: str | None = None,
|
||||
ori_jitter: float | None = None,
|
||||
):
|
||||
if exists(ori_token):
|
||||
center = np.array([float(x) for x in ori_token], dtype=atom_array.coord.dtype)
|
||||
@@ -512,7 +515,7 @@ def set_com(
|
||||
|
||||
# Random length (mean ~ scale)
|
||||
length = np.random.exponential(scale=scale)
|
||||
jittered_offset = direction*length
|
||||
jittered_offset = direction * length
|
||||
|
||||
atom_array.coord -= jittered_offset
|
||||
|
||||
|
||||
Reference in New Issue
Block a user