From b5beff039eecbb8f2a03c5eedbff2df64e908db9 Mon Sep 17 00:00:00 2001 From: Raktim Mitra Date: Tue, 3 Feb 2026 12:54:16 -0800 Subject: [PATCH] inference fixes for legacy_input_parsing --- .../rfd3/inference/legacy_input_parsing.py | 44 +++++++++++++++---- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/models/rfd3/src/rfd3/inference/legacy_input_parsing.py b/models/rfd3/src/rfd3/inference/legacy_input_parsing.py index 0d344fd..7bf1a50 100644 --- a/models/rfd3/src/rfd3/inference/legacy_input_parsing.py +++ b/models/rfd3/src/rfd3/inference/legacy_input_parsing.py @@ -268,36 +268,57 @@ def fetch_motif_residue_( return subarray -def create_diffused_residues_(n): +def create_diffused_residues_(n, polymer_type='p'): + from rfd3.constants import ( + ATOM23_ATOM_NAME_TO_ELEMENT, + backbone_atoms_DNA, + backbone_atoms_RNA, + ) if n <= 0: raise ValueError(f"Negative/null residue count ({n}) not allowed.") - + + if polymer_type == 'P': + res_name = 'ALA' + bb_len = 5 + bb_atom_names = ["N", "CA", "C", "O", "CB"] + elif polymer_type == 'R': + res_name = 'A' + bb_len = len(backbone_atoms_RNA) + bb_atom_names = backbone_atoms_RNA + elif polymer_type == 'D': + res_name = 'DA' + bb_len = len(backbone_atoms_DNA) + bb_atom_names = backbone_atoms_DNA + else: + raise ValueError(f"invalid polymer type detected: {polymer_type}, check contig!") + + bb_elements = [ATOM23_ATOM_NAME_TO_ELEMENT[item] for item in bb_atom_names] + atoms = [] [ atoms.extend( [ struc.Atom( np.array([0.0, 0.0, 0.0], dtype=np.float32), - res_name="ALA", + res_name=res_name, res_id=idx, ) - for _ in range(5) + for _ in range(bb_len) ] ) for idx in range(1, n + 1) ] array = struc.array(atoms) array.set_annotation( - "element", np.array(["N", "C", "C", "O", "C"] * n, dtype="