diff --git a/models/rfd3/src/rfd3/inference/input_parsing.py b/models/rfd3/src/rfd3/inference/input_parsing.py index 17032e4..f309180 100644 --- a/models/rfd3/src/rfd3/inference/input_parsing.py +++ b/models/rfd3/src/rfd3/inference/input_parsing.py @@ -184,6 +184,7 @@ class DesignInputSpecification(BaseModel): symmetry: Optional[SymmetryConfig] = Field(None, description="Symmetry specification, see docs/symmetry.md") # Centering & COM guidance ori_token: Optional[list[float]] = Field(None, description="Origin coordinates") + ori_jitter: Optional[float] = Field(None, description="Jitter ori in a random direction and use ori_jitter to sample distance via exponential distribution") infer_ori_strategy: Optional[str] = Field(None, description="Strategy for inferring origin; `com` or `hotspots`") # Additional global conditioning plddt_enhanced: Optional[bool] = Field(True, description="Enable pLDDT enhancement") @@ -725,6 +726,13 @@ class DesignInputSpecification(BaseModel): ligand_array.set_annotation( annot, np.full(ligand_array.array_length(), default) ) + + chain_cand = 'X' + while chain_cand in atom_array.chain_id.tolist(): + chain_cand = chain_cand + chain_cand + ligand_chain = np.array([chain_cand]*len(ligand_array)) + ligand_array.chain_id = ligand_chain + atom_array = atom_array + ligand_array return atom_array @@ -749,8 +757,10 @@ class DesignInputSpecification(BaseModel): "Partial diffusion with symmetry: skipping COM centering to preserve chain spacing" ) else: + if not exists(self.ori_jitter): + self.ori_jitter = None atom_array = set_com( - atom_array, ori_token=None, infer_ori_strategy="com" + atom_array, ori_token=None, infer_ori_strategy="com", ori_jitter=self.ori_jitter ) else: # Standard: set ori token, zero out diffused atoms diff --git a/models/rfd3/src/rfd3/inference/legacy_input_parsing.py b/models/rfd3/src/rfd3/inference/legacy_input_parsing.py index aed3cc9..fbee346 100644 --- a/models/rfd3/src/rfd3/inference/legacy_input_parsing.py +++ b/models/rfd3/src/rfd3/inference/legacy_input_parsing.py @@ -736,6 +736,12 @@ def create_atom_array_from_design_specification_legacy( + np.max(atom_array.res_id) + 1 ) + chain_cand = 'X' + while chain_cand in atom_array.chain_id.tolist(): + chain_cand = chain_cand + chain_cand + ligand_chain = np.array([chain_cand]*len(ligand_array)) + ligand_array.chain_id = ligand_chain + atom_array = atom_array + ligand_array # ... Apply symmetry if it exists ahead of any other processing diff --git a/models/rfd3/src/rfd3/utils/inference.py b/models/rfd3/src/rfd3/utils/inference.py index 2d04e14..426eace 100644 --- a/models/rfd3/src/rfd3/utils/inference.py +++ b/models/rfd3/src/rfd3/utils/inference.py @@ -452,7 +452,7 @@ as input and return a three-element list or numpy array of floats. def set_com( - atom_array, ori_token: list | None = None, infer_ori_strategy: str | None = None + atom_array, ori_token: list | None = None, infer_ori_strategy: str | None = None, ori_jitter: float | None = None ): if exists(ori_token): center = np.array([float(x) for x in ori_token], dtype=atom_array.coord.dtype) @@ -505,6 +505,17 @@ def set_com( atom_array.coord = np.zeros_like( atom_array.coord, dtype=atom_array.coord.dtype ) + if ori_jitter is not None: + # randomly jitter ori with given scale + direction = np.random.normal(size=3) + direction /= np.linalg.norm(direction) + + # Random length (mean ~ scale) + length = np.random.exponential(scale=scale) + jittered_offset = direction*length + + atom_array.coord -= jittered_offset + return atom_array