ligand chain fix; add ori_jitter option in new dialect

This commit is contained in:
Raktim Mitra
2026-02-10 18:25:01 -08:00
committed by Raktim Mitra
parent cff03801ed
commit d7fe89b8f5
3 changed files with 29 additions and 2 deletions

View File

@@ -184,6 +184,7 @@ class DesignInputSpecification(BaseModel):
symmetry: Optional[SymmetryConfig] = Field(None, description="Symmetry specification, see docs/symmetry.md")
# Centering & COM guidance
ori_token: Optional[list[float]] = Field(None, description="Origin coordinates")
ori_jitter: Optional[float] = Field(None, description="Jitter ori in a random direction and use ori_jitter to sample distance via exponential distribution")
infer_ori_strategy: Optional[str] = Field(None, description="Strategy for inferring origin; `com` or `hotspots`")
# Additional global conditioning
plddt_enhanced: Optional[bool] = Field(True, description="Enable pLDDT enhancement")
@@ -725,6 +726,13 @@ class DesignInputSpecification(BaseModel):
ligand_array.set_annotation(
annot, np.full(ligand_array.array_length(), default)
)
chain_cand = 'X'
while chain_cand in atom_array.chain_id.tolist():
chain_cand = chain_cand + chain_cand
ligand_chain = np.array([chain_cand]*len(ligand_array))
ligand_array.chain_id = ligand_chain
atom_array = atom_array + ligand_array
return atom_array
@@ -749,8 +757,10 @@ class DesignInputSpecification(BaseModel):
"Partial diffusion with symmetry: skipping COM centering to preserve chain spacing"
)
else:
if not exists(self.ori_jitter):
self.ori_jitter = None
atom_array = set_com(
atom_array, ori_token=None, infer_ori_strategy="com"
atom_array, ori_token=None, infer_ori_strategy="com", ori_jitter=self.ori_jitter
)
else:
# Standard: set ori token, zero out diffused atoms

View File

@@ -736,6 +736,12 @@ def create_atom_array_from_design_specification_legacy(
+ np.max(atom_array.res_id)
+ 1
)
chain_cand = 'X'
while chain_cand in atom_array.chain_id.tolist():
chain_cand = chain_cand + chain_cand
ligand_chain = np.array([chain_cand]*len(ligand_array))
ligand_array.chain_id = ligand_chain
atom_array = atom_array + ligand_array
# ... Apply symmetry if it exists ahead of any other processing

View File

@@ -452,7 +452,7 @@ as input and return a three-element list or numpy array of floats.
def set_com(
atom_array, ori_token: list | None = None, infer_ori_strategy: str | None = None
atom_array, ori_token: list | None = None, infer_ori_strategy: str | None = None, ori_jitter: float | None = None
):
if exists(ori_token):
center = np.array([float(x) for x in ori_token], dtype=atom_array.coord.dtype)
@@ -505,6 +505,17 @@ def set_com(
atom_array.coord = np.zeros_like(
atom_array.coord, dtype=atom_array.coord.dtype
)
if ori_jitter is not None:
# randomly jitter ori with given scale
direction = np.random.normal(size=3)
direction /= np.linalg.norm(direction)
# Random length (mean ~ scale)
length = np.random.exponential(scale=scale)
jittered_offset = direction*length
atom_array.coord -= jittered_offset
return atom_array