diff --git a/models/rfd3/src/rfd3/inference/parsing.py b/models/rfd3/src/rfd3/inference/parsing.py index 65d5eb4..b06a48b 100644 --- a/models/rfd3/src/rfd3/inference/parsing.py +++ b/models/rfd3/src/rfd3/inference/parsing.py @@ -33,7 +33,7 @@ class InputSelection(BaseModel): ..., description="Validated selection dictionary", exclude=True ) raw: Any = Field(..., description="Original input value") - mask: np.ndarray[np.bool_] = Field( + mask: np.ndarray[Any, np.dtype[np.bool_]] = Field( ..., description="Boolean mask over atom array", exclude=True ) tokens: Optional[Dict[ComponentStr | str, AtomArray]] = Field( diff --git a/models/rfd3/src/rfd3/metrics/design_metrics.py b/models/rfd3/src/rfd3/metrics/design_metrics.py index 5ac24fb..fedc76a 100644 --- a/models/rfd3/src/rfd3/metrics/design_metrics.py +++ b/models/rfd3/src/rfd3/metrics/design_metrics.py @@ -127,7 +127,7 @@ def get_all_backbone_metrics( The atom array coming in will be a cleaned atom array (no virtual atoms and corrected atom names) without guideposts """ - o = {} + o: dict[str, Any] = {} # ... Clash metrics o = o | get_clash_metrics( diff --git a/models/rfd3/src/rfd3/model/layers/block_utils.py b/models/rfd3/src/rfd3/model/layers/block_utils.py index 20a2b62..c291b85 100644 --- a/models/rfd3/src/rfd3/model/layers/block_utils.py +++ b/models/rfd3/src/rfd3/model/layers/block_utils.py @@ -402,7 +402,7 @@ def build_index_mask( def extend_index_mask_with_neighbours( mask: torch.Tensor, D_LL: torch.Tensor, k: int -) -> torch.LongTensor: +) -> torch.Tensor: """ Parameters ---------- diff --git a/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py b/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py index 11d6278..39da6c3 100644 --- a/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py +++ b/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py @@ -320,6 +320,8 @@ class ChunkedPairwiseEmbedder: # 3. Single embedding terms if self._sl_cached is not None: # Fast path: MLP already run at tokenisation — just index into the result. + # _sl_cached and _sm_cached are populated together (see process_single_*). + assert self._sm_cached is not None # sl_cached [L, c_atompair]: query atom l always maps to row l. single_l = self._sl_cached.unsqueeze(0).unsqueeze(2).expand(B, -1, k, -1) # sm_cached [L, c_atompair]: key atoms are given by valid_indices [B, L, k]. diff --git a/models/rfd3/src/rfd3/testing/testing_utils.py b/models/rfd3/src/rfd3/testing/testing_utils.py index c0875e1..4f45a40 100644 --- a/models/rfd3/src/rfd3/testing/testing_utils.py +++ b/models/rfd3/src/rfd3/testing/testing_utils.py @@ -6,6 +6,7 @@ import os import sys import tempfile from pathlib import Path +from typing import Any from unittest.mock import patch import hydra @@ -206,7 +207,7 @@ def build_pipelines( standardize_crop_size: bool = True, **transform_kwargs, ): - pipes = {} + pipes: dict[bool, Any] = {} for is_validation in [True, False]: if composed_config is None: config = load_train_or_val_cfg(name=cfg_name, is_val_cfg=is_validation) diff --git a/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py b/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py index 58e855c..4ade17c 100644 --- a/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py +++ b/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py @@ -66,7 +66,7 @@ def calculate_hbonds( atom_array: AtomArray, cutoff_HA_dist: float = 3, cutoff_DA_distance: float = 3.5, -) -> Tuple[np.ndarray, np.ndarray, AtomArray]: +) -> Tuple[AtomArray, list[dict[str, Any]], int]: hbplus_exe = os.environ.get("HBPLUS_PATH") if hbplus_exe is None or hbplus_exe == "": diff --git a/models/rfd3/src/rfd3/transforms/ncaa_transforms.py b/models/rfd3/src/rfd3/transforms/ncaa_transforms.py index b788e64..f6c7cf8 100644 --- a/models/rfd3/src/rfd3/transforms/ncaa_transforms.py +++ b/models/rfd3/src/rfd3/transforms/ncaa_transforms.py @@ -60,7 +60,7 @@ class RandomlyMirrorInputs(Transform): if not mirror_input: return data - renamed_map = {} + renamed_map: dict[str, str] = {} res_starts = struct.get_residue_starts(atom_array) for i, r_i in enumerate(res_starts): if i == len(res_starts) - 1: diff --git a/models/rfd3/src/rfd3/transforms/rasa.py b/models/rfd3/src/rfd3/transforms/rasa.py index e2a1e15..2855b34 100644 --- a/models/rfd3/src/rfd3/transforms/rasa.py +++ b/models/rfd3/src/rfd3/transforms/rasa.py @@ -63,7 +63,8 @@ class SetZeroOccOnDeltaRASA(Transform): Used to measure if the atomwise RASA changed during cropping """ - requires_previous_transforms = [CalculateRASA] + # atomworks Transform types this list[str]; class refs are also accepted. + requires_previous_transforms = [CalculateRASA] # type: ignore[list-item] incompatible_previous_transforms = [ "PadWithVirtualAtoms", # must have the same atom names "CreateDesignReferenceFeatures", diff --git a/pyproject.toml b/pyproject.toml index 171f90f..07b0ef3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -225,15 +225,10 @@ module = [ "rfd3.inference.datasets", "rfd3.inference.input_parsing", "rfd3.inference.legacy_input_parsing", - "rfd3.inference.parsing", "rfd3.inference.symmetry.symmetry_utils", - "rfd3.metrics.design_metrics", "rfd3.model.RFD3", "rfd3.model.inference_sampler", - "rfd3.model.layers.block_utils", - "rfd3.model.layers.chunked_pairwise", "rfd3.run_inference", - "rfd3.testing.testing_utils", "rfd3.trainer.dump_validation_structures", "rfd3.trainer.fabric_trainer", "rfd3.trainer.rfd3", @@ -242,11 +237,8 @@ module = [ "rfd3.transforms.design_transforms", "rfd3.transforms.dna_crop", "rfd3.transforms.hbonds", - "rfd3.transforms.hbonds_hbplus", - "rfd3.transforms.ncaa_transforms", "rfd3.transforms.pipelines", "rfd3.transforms.ppi_transforms", - "rfd3.transforms.rasa", "rfd3.transforms.training_conditions", "rfd3.transforms.util_transforms", "rfd3.transforms.virtual_atoms",