mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
inference fixes for legacy_input_parsing
This commit is contained in:
@@ -268,36 +268,57 @@ def fetch_motif_residue_(
|
||||
return subarray
|
||||
|
||||
|
||||
def create_diffused_residues_(n):
|
||||
def create_diffused_residues_(n, polymer_type='p'):
|
||||
from rfd3.constants import (
|
||||
ATOM23_ATOM_NAME_TO_ELEMENT,
|
||||
backbone_atoms_DNA,
|
||||
backbone_atoms_RNA,
|
||||
)
|
||||
if n <= 0:
|
||||
raise ValueError(f"Negative/null residue count ({n}) not allowed.")
|
||||
|
||||
|
||||
if polymer_type == 'P':
|
||||
res_name = 'ALA'
|
||||
bb_len = 5
|
||||
bb_atom_names = ["N", "CA", "C", "O", "CB"]
|
||||
elif polymer_type == 'R':
|
||||
res_name = 'A'
|
||||
bb_len = len(backbone_atoms_RNA)
|
||||
bb_atom_names = backbone_atoms_RNA
|
||||
elif polymer_type == 'D':
|
||||
res_name = 'DA'
|
||||
bb_len = len(backbone_atoms_DNA)
|
||||
bb_atom_names = backbone_atoms_DNA
|
||||
else:
|
||||
raise ValueError(f"invalid polymer type detected: {polymer_type}, check contig!")
|
||||
|
||||
bb_elements = [ATOM23_ATOM_NAME_TO_ELEMENT[item] for item in bb_atom_names]
|
||||
|
||||
atoms = []
|
||||
[
|
||||
atoms.extend(
|
||||
[
|
||||
struc.Atom(
|
||||
np.array([0.0, 0.0, 0.0], dtype=np.float32),
|
||||
res_name="ALA",
|
||||
res_name=res_name,
|
||||
res_id=idx,
|
||||
)
|
||||
for _ in range(5)
|
||||
for _ in range(bb_len)
|
||||
]
|
||||
)
|
||||
for idx in range(1, n + 1)
|
||||
]
|
||||
array = struc.array(atoms)
|
||||
array.set_annotation(
|
||||
"element", np.array(["N", "C", "C", "O", "C"] * n, dtype="<U2")
|
||||
"element", np.array(bb_elements * n, dtype="<U2")
|
||||
)
|
||||
array.set_annotation(
|
||||
"atom_name", np.array(["N", "CA", "C", "O", "CB"] * n, dtype="<U2")
|
||||
"atom_name", np.array(bb_atom_names * n, dtype="<U3")
|
||||
)
|
||||
array = set_default_conditioning_annotations(array, motif=False)
|
||||
array = set_common_annotations(array)
|
||||
return array
|
||||
|
||||
|
||||
def accumulate_components(
|
||||
components,
|
||||
src_atom_array,
|
||||
@@ -379,11 +400,16 @@ def accumulate_components(
|
||||
np.ones(atom_array_insert.shape[0], dtype=int),
|
||||
)
|
||||
else:
|
||||
n = int(component)
|
||||
if component[-1] in ["P", "R", "D"]: # if polymer type specified
|
||||
polymer_type = component[-1] # can be 'P'rotein, 'R'NA, 'D'NA
|
||||
n = int(component[:-1])
|
||||
else:
|
||||
polymer_type = "P"
|
||||
n = int(components)
|
||||
if n == 0 or unindexed_components_started:
|
||||
res_id += n
|
||||
continue
|
||||
atom_array_insert = create_diffused_residues(n)
|
||||
atom_array_insert = create_diffused_residues(n, polymer_type)
|
||||
for key in optional_conditions:
|
||||
atom_array_insert.set_annotation(
|
||||
key,
|
||||
|
||||
Reference in New Issue
Block a user