inference fixes for legacy_input_parsing

This commit is contained in:
Raktim Mitra
2026-02-03 12:54:16 -08:00
parent f0ab0fedae
commit b5beff039e

View File

@@ -268,36 +268,57 @@ def fetch_motif_residue_(
return subarray
def create_diffused_residues_(n):
def create_diffused_residues_(n, polymer_type='p'):
from rfd3.constants import (
ATOM23_ATOM_NAME_TO_ELEMENT,
backbone_atoms_DNA,
backbone_atoms_RNA,
)
if n <= 0:
raise ValueError(f"Negative/null residue count ({n}) not allowed.")
if polymer_type == 'P':
res_name = 'ALA'
bb_len = 5
bb_atom_names = ["N", "CA", "C", "O", "CB"]
elif polymer_type == 'R':
res_name = 'A'
bb_len = len(backbone_atoms_RNA)
bb_atom_names = backbone_atoms_RNA
elif polymer_type == 'D':
res_name = 'DA'
bb_len = len(backbone_atoms_DNA)
bb_atom_names = backbone_atoms_DNA
else:
raise ValueError(f"invalid polymer type detected: {polymer_type}, check contig!")
bb_elements = [ATOM23_ATOM_NAME_TO_ELEMENT[item] for item in bb_atom_names]
atoms = []
[
atoms.extend(
[
struc.Atom(
np.array([0.0, 0.0, 0.0], dtype=np.float32),
res_name="ALA",
res_name=res_name,
res_id=idx,
)
for _ in range(5)
for _ in range(bb_len)
]
)
for idx in range(1, n + 1)
]
array = struc.array(atoms)
array.set_annotation(
"element", np.array(["N", "C", "C", "O", "C"] * n, dtype="<U2")
"element", np.array(bb_elements * n, dtype="<U2")
)
array.set_annotation(
"atom_name", np.array(["N", "CA", "C", "O", "CB"] * n, dtype="<U2")
"atom_name", np.array(bb_atom_names * n, dtype="<U3")
)
array = set_default_conditioning_annotations(array, motif=False)
array = set_common_annotations(array)
return array
def accumulate_components(
components,
src_atom_array,
@@ -379,11 +400,16 @@ def accumulate_components(
np.ones(atom_array_insert.shape[0], dtype=int),
)
else:
n = int(component)
if component[-1] in ["P", "R", "D"]: # if polymer type specified
polymer_type = component[-1] # can be 'P'rotein, 'R'NA, 'D'NA
n = int(component[:-1])
else:
polymer_type = "P"
n = int(components)
if n == 0 or unindexed_components_started:
res_id += n
continue
atom_array_insert = create_diffused_residues(n)
atom_array_insert = create_diffused_residues(n, polymer_type)
for key in optional_conditions:
atom_array_insert.set_annotation(
key,