mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
ligand chain fix; add ori_jitter option in new dialect
This commit is contained in:
committed by
Raktim Mitra
parent
cff03801ed
commit
d7fe89b8f5
@@ -184,6 +184,7 @@ class DesignInputSpecification(BaseModel):
|
||||
symmetry: Optional[SymmetryConfig] = Field(None, description="Symmetry specification, see docs/symmetry.md")
|
||||
# Centering & COM guidance
|
||||
ori_token: Optional[list[float]] = Field(None, description="Origin coordinates")
|
||||
ori_jitter: Optional[float] = Field(None, description="Jitter ori in a random direction and use ori_jitter to sample distance via exponential distribution")
|
||||
infer_ori_strategy: Optional[str] = Field(None, description="Strategy for inferring origin; `com` or `hotspots`")
|
||||
# Additional global conditioning
|
||||
plddt_enhanced: Optional[bool] = Field(True, description="Enable pLDDT enhancement")
|
||||
@@ -725,6 +726,13 @@ class DesignInputSpecification(BaseModel):
|
||||
ligand_array.set_annotation(
|
||||
annot, np.full(ligand_array.array_length(), default)
|
||||
)
|
||||
|
||||
chain_cand = 'X'
|
||||
while chain_cand in atom_array.chain_id.tolist():
|
||||
chain_cand = chain_cand + chain_cand
|
||||
ligand_chain = np.array([chain_cand]*len(ligand_array))
|
||||
ligand_array.chain_id = ligand_chain
|
||||
|
||||
atom_array = atom_array + ligand_array
|
||||
return atom_array
|
||||
|
||||
@@ -749,8 +757,10 @@ class DesignInputSpecification(BaseModel):
|
||||
"Partial diffusion with symmetry: skipping COM centering to preserve chain spacing"
|
||||
)
|
||||
else:
|
||||
if not exists(self.ori_jitter):
|
||||
self.ori_jitter = None
|
||||
atom_array = set_com(
|
||||
atom_array, ori_token=None, infer_ori_strategy="com"
|
||||
atom_array, ori_token=None, infer_ori_strategy="com", ori_jitter=self.ori_jitter
|
||||
)
|
||||
else:
|
||||
# Standard: set ori token, zero out diffused atoms
|
||||
|
||||
@@ -736,6 +736,12 @@ def create_atom_array_from_design_specification_legacy(
|
||||
+ np.max(atom_array.res_id)
|
||||
+ 1
|
||||
)
|
||||
chain_cand = 'X'
|
||||
while chain_cand in atom_array.chain_id.tolist():
|
||||
chain_cand = chain_cand + chain_cand
|
||||
ligand_chain = np.array([chain_cand]*len(ligand_array))
|
||||
ligand_array.chain_id = ligand_chain
|
||||
|
||||
atom_array = atom_array + ligand_array
|
||||
|
||||
# ... Apply symmetry if it exists ahead of any other processing
|
||||
|
||||
@@ -452,7 +452,7 @@ as input and return a three-element list or numpy array of floats.
|
||||
|
||||
|
||||
def set_com(
|
||||
atom_array, ori_token: list | None = None, infer_ori_strategy: str | None = None
|
||||
atom_array, ori_token: list | None = None, infer_ori_strategy: str | None = None, ori_jitter: float | None = None
|
||||
):
|
||||
if exists(ori_token):
|
||||
center = np.array([float(x) for x in ori_token], dtype=atom_array.coord.dtype)
|
||||
@@ -505,6 +505,17 @@ def set_com(
|
||||
atom_array.coord = np.zeros_like(
|
||||
atom_array.coord, dtype=atom_array.coord.dtype
|
||||
)
|
||||
if ori_jitter is not None:
|
||||
# randomly jitter ori with given scale
|
||||
direction = np.random.normal(size=3)
|
||||
direction /= np.linalg.norm(direction)
|
||||
|
||||
# Random length (mean ~ scale)
|
||||
length = np.random.exponential(scale=scale)
|
||||
jittered_offset = direction*length
|
||||
|
||||
atom_array.coord -= jittered_offset
|
||||
|
||||
return atom_array
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user