From 39671441e609aa178ac4ee64a161cd94bc4a8330 Mon Sep 17 00:00:00 2001 From: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> Date: Wed, 3 Jun 2026 15:13:41 -0600 Subject: [PATCH] refactor(mypy): clear 8 easy rfd3 modules off the ignore ratchet (#298) * refactor(mypy): un-ignore 5 easy-tier modules Fix each module's single pre-existing type error with a pure annotation or setattr change (no behavior change) and remove it from the [[tool.mypy.overrides]] ignore_errors list: - callbacks/train_logging: loss_trackers: dict[str, MeanMetric] - callbacks/metrics_logging: seen_examples: set[str] - common: setattr(wrapper, "_has_run", True) for the @wraps wrapper - hydra/resolvers: attribute_path: str | None (body already guards) - inference_engines/base: base_overrides: dict[str, Any] 13 modules remain on the ignore list. mypy now type-checks the 5 newly-included modules cleanly. Co-authored-by: Sergey Lyskov * refactor(mypy): un-ignore 7 medium-tier modules Resolve the type errors in and remove from the [[tool.mypy.overrides]] ignore_errors list. Mostly narrowing / annotation fixes; two deliberate type-honesty fixes flagged below. - utils/weights: lowercase `any` -> `Any` in _PatternPolicyMixin (4x); assert-narrow fallback_policy at the call site (matches get_policy idiom) - model/layers/blocks: class-level w/b: torch.Tensor for the registered buffers (avoids nn.Module's Tensor | Module __getattr__ fallback) - utils/components: is-None narrowing + tip_names local in get_name_mask's TIP branch (exists() can't narrow for mypy); drop orphaned exists import - utils/logging: str(field) for the tree key; assign to a new hparams local rather than reassigning the typed cfg param - foundry_cli/download_checkpoints: guard on `hasher is not None`; total_size = 0.0 for the float accumulation - training/schedulers: SchedulerConfig.scheduler is now a required field (was = None, but documented required and assumed non-None everywhere) - utils/xpu/xpu_accelerator: name @property -> @staticmethod to match lightning's Accelerator ABC 6 hard-tier modules remain on the ignore list. Co-authored-by: Sergey Lyskov * refactor(mypy): un-ignore metrics/metric module Fix the 11 type errors in foundry.metrics.metric and remove it from the [tool.mypy.overrides] ignore_errors list (5 hard-tier modules remain). - str(name) coercion of DictConfig.items() keys (str|bytes|int|... union) - exists() -> 'is not None' narrowing; drop orphaned atomworks import - widen compute_from_kwargs -> dict|list and kwargs_to_compute_args -> dict|None to match the actual returns / documented contract (callers already handle them) - three type: ignore[arg-type] on nested_dict.get/getitem for an upstream atomworks annotation bug (param typed dict[tuple,...] but navigated as nested dict[str,Any]); warn_unused_ignores will flag them if upstream is fixed No behavior change. All gates green (ruff, mypy 41 files, pytest 27 passed). Co-authored-by: Sergey Lyskov * refactor(mypy): un-ignore utils/{ddp,rigid,datasets} Clear the three remaining foundry.utils.* modules off the mypy ignore_errors list (47 errors: ddp 12, rigid 16, datasets 19). Type-honesty and annotation fixes only, no behavior change: narrow DictConfig|dict params to DictConfig where attribute access requires it (item access kept where a plain-dict default is real), honest int|None / Tensor|None widenings, variable renames to avoid type-reuse, str() coercion of DictConfig keys, the file's own if/elif/else narrowing pattern, and documented type: ignore / cast for genuine torch and atomworks stub limitations. Two hard-tier modules remain (callbacks/health_logging, trainers/fabric). Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> * refactor(mypy): un-ignore callbacks/health_logging Clear foundry.callbacks.health_logging off the mypy ignore_errors list by fixing its 23 type errors (annotation / type-honesty only, no behavior change): - import the stdlib 'types' module directly instead of relying on 'from typing import types' (worked at runtime but fragile/untyped) - replace 'callable'-used-as-a-type with Mapping[str, Callable[..., Any]] on the stat/histogram dict params and Callable[..., bool] | None on the filter params; annotate the two MappingProxyType default constants to match - annotate the _hooks / _temp_cache / _cache instance vars - make implicit-Optional defaults explicit (... | None) on the two plot_tensor_* helpers, matching their is-not-None guards - in plot_tensor_hist, replace two type-changing param reassignments with equivalent always-set locals (display_values, step_labels) Only trainers/fabric remains on the ignore list. Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> * refactor(mypy): un-ignore trainers/fabric (ratchet complete) Clear foundry.trainers.fabric (the last and largest module) off the mypy ignore_errors list and remove the now-empty override block. The ratchet ignore list is now empty: all of src/foundry + src/foundry_cli type-checks with no per-module exemptions. Fixes are annotation / type-honesty only, no behavior change: - annotate self.state as dict[str, Any] (a heterogeneous, dynamically- keyed training-state bag, also merged with arbitrary checkpoint keys); this collapses ~69 union-attr/operator/arg-type errors. Also annotate default_state and declare _current_train_return (set by subclass training_step implementations). - dataloader types: Fabric.setup_dataloaders is stub-typed to return DataLoader | list[DataLoader], so cast its single-loader results to DataLoader and change train_loop/validation_loop params from _FabricDataLoader to DataLoader (drop the now-unused import). - precision: widen the param to str | int | None (the body sets it None when an XPU plugin takes over), cast to the guarded Literal at the XPUMixedPrecision call, and add one documented type: ignore[arg-type] where our public API is wider than Fabric's precision Literal. - narrow the parameter-freezing guard to direct attribute access; type get_latest_checkpoint as Path | None (matching its returns) with a cast at the single caller; drop a stale type: ignore. Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> * chore(mypy): bring models/rfd3 into scope behind an ignore_errors ratchet Add models/rfd3/src/rfd3 to [tool.mypy].files so the rfd3 model package is type-checked by the standard gate (mypy now covers 99 files: foundry + rfd3). Seed a fresh [[tool.mypy.overrides]] ignore_errors ratchet listing the 32 rfd3 modules with pre-existing type errors (194 total), mirroring the original src/foundry bootstrap; the 26 already-clean rfd3 modules are type-checked immediately. Modules are cleared from the ratchet one slice at a time in follow-up work. Config only, no code changes. rfd3 is an editable install, so imports resolve without an added mypy_path. Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> * refactor(mypy): clear 8 easy rfd3 modules off the ignore ratchet First slice of clearing the models/rfd3 mypy ratchet: the 8 modules with a single error each (32 -> 24 remaining on the ignore_errors list). mypy now type-checks 34 rfd3 modules. Annotation / type-honesty only, no behavior change: - block_utils: -> torch.LongTensor -> torch.Tensor (.long() is stub-typed Tensor; sibling helpers already return Tensor) - hbonds_hbplus: corrected calculate_hbonds's stale return annotation to match the actual (AtomArray, list[dict], int) return - inference/parsing: malformed pydantic np.ndarray[np.bool_] -> np.ndarray[Any, np.dtype[np.bool_]] - chunked_pairwise: assert _sm_cached is not None in the cache fast path (populated together with the already-narrowed _sl_cached) - rasa: documented type: ignore[list-item] (atomworks Transform types requires_previous_transforms as list[str]; class refs are accepted) - ncaa_transforms / design_metrics / testing_utils: var-annotated dicts (+ a missing 'from typing import Any') Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --------- Co-authored-by: Sergey Lyskov Co-authored-by: Hope Woods --- models/rfd3/src/rfd3/inference/parsing.py | 2 +- models/rfd3/src/rfd3/metrics/design_metrics.py | 2 +- models/rfd3/src/rfd3/model/layers/block_utils.py | 2 +- models/rfd3/src/rfd3/model/layers/chunked_pairwise.py | 2 ++ models/rfd3/src/rfd3/testing/testing_utils.py | 3 ++- models/rfd3/src/rfd3/transforms/hbonds_hbplus.py | 2 +- models/rfd3/src/rfd3/transforms/ncaa_transforms.py | 2 +- models/rfd3/src/rfd3/transforms/rasa.py | 3 ++- pyproject.toml | 8 -------- 9 files changed, 11 insertions(+), 15 deletions(-) diff --git a/models/rfd3/src/rfd3/inference/parsing.py b/models/rfd3/src/rfd3/inference/parsing.py index 65d5eb4..b06a48b 100644 --- a/models/rfd3/src/rfd3/inference/parsing.py +++ b/models/rfd3/src/rfd3/inference/parsing.py @@ -33,7 +33,7 @@ class InputSelection(BaseModel): ..., description="Validated selection dictionary", exclude=True ) raw: Any = Field(..., description="Original input value") - mask: np.ndarray[np.bool_] = Field( + mask: np.ndarray[Any, np.dtype[np.bool_]] = Field( ..., description="Boolean mask over atom array", exclude=True ) tokens: Optional[Dict[ComponentStr | str, AtomArray]] = Field( diff --git a/models/rfd3/src/rfd3/metrics/design_metrics.py b/models/rfd3/src/rfd3/metrics/design_metrics.py index 5ac24fb..fedc76a 100644 --- a/models/rfd3/src/rfd3/metrics/design_metrics.py +++ b/models/rfd3/src/rfd3/metrics/design_metrics.py @@ -127,7 +127,7 @@ def get_all_backbone_metrics( The atom array coming in will be a cleaned atom array (no virtual atoms and corrected atom names) without guideposts """ - o = {} + o: dict[str, Any] = {} # ... Clash metrics o = o | get_clash_metrics( diff --git a/models/rfd3/src/rfd3/model/layers/block_utils.py b/models/rfd3/src/rfd3/model/layers/block_utils.py index 20a2b62..c291b85 100644 --- a/models/rfd3/src/rfd3/model/layers/block_utils.py +++ b/models/rfd3/src/rfd3/model/layers/block_utils.py @@ -402,7 +402,7 @@ def build_index_mask( def extend_index_mask_with_neighbours( mask: torch.Tensor, D_LL: torch.Tensor, k: int -) -> torch.LongTensor: +) -> torch.Tensor: """ Parameters ---------- diff --git a/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py b/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py index 11d6278..39da6c3 100644 --- a/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py +++ b/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py @@ -320,6 +320,8 @@ class ChunkedPairwiseEmbedder: # 3. Single embedding terms if self._sl_cached is not None: # Fast path: MLP already run at tokenisation — just index into the result. + # _sl_cached and _sm_cached are populated together (see process_single_*). + assert self._sm_cached is not None # sl_cached [L, c_atompair]: query atom l always maps to row l. single_l = self._sl_cached.unsqueeze(0).unsqueeze(2).expand(B, -1, k, -1) # sm_cached [L, c_atompair]: key atoms are given by valid_indices [B, L, k]. diff --git a/models/rfd3/src/rfd3/testing/testing_utils.py b/models/rfd3/src/rfd3/testing/testing_utils.py index c0875e1..4f45a40 100644 --- a/models/rfd3/src/rfd3/testing/testing_utils.py +++ b/models/rfd3/src/rfd3/testing/testing_utils.py @@ -6,6 +6,7 @@ import os import sys import tempfile from pathlib import Path +from typing import Any from unittest.mock import patch import hydra @@ -206,7 +207,7 @@ def build_pipelines( standardize_crop_size: bool = True, **transform_kwargs, ): - pipes = {} + pipes: dict[bool, Any] = {} for is_validation in [True, False]: if composed_config is None: config = load_train_or_val_cfg(name=cfg_name, is_val_cfg=is_validation) diff --git a/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py b/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py index 58e855c..4ade17c 100644 --- a/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py +++ b/models/rfd3/src/rfd3/transforms/hbonds_hbplus.py @@ -66,7 +66,7 @@ def calculate_hbonds( atom_array: AtomArray, cutoff_HA_dist: float = 3, cutoff_DA_distance: float = 3.5, -) -> Tuple[np.ndarray, np.ndarray, AtomArray]: +) -> Tuple[AtomArray, list[dict[str, Any]], int]: hbplus_exe = os.environ.get("HBPLUS_PATH") if hbplus_exe is None or hbplus_exe == "": diff --git a/models/rfd3/src/rfd3/transforms/ncaa_transforms.py b/models/rfd3/src/rfd3/transforms/ncaa_transforms.py index b788e64..f6c7cf8 100644 --- a/models/rfd3/src/rfd3/transforms/ncaa_transforms.py +++ b/models/rfd3/src/rfd3/transforms/ncaa_transforms.py @@ -60,7 +60,7 @@ class RandomlyMirrorInputs(Transform): if not mirror_input: return data - renamed_map = {} + renamed_map: dict[str, str] = {} res_starts = struct.get_residue_starts(atom_array) for i, r_i in enumerate(res_starts): if i == len(res_starts) - 1: diff --git a/models/rfd3/src/rfd3/transforms/rasa.py b/models/rfd3/src/rfd3/transforms/rasa.py index e2a1e15..2855b34 100644 --- a/models/rfd3/src/rfd3/transforms/rasa.py +++ b/models/rfd3/src/rfd3/transforms/rasa.py @@ -63,7 +63,8 @@ class SetZeroOccOnDeltaRASA(Transform): Used to measure if the atomwise RASA changed during cropping """ - requires_previous_transforms = [CalculateRASA] + # atomworks Transform types this list[str]; class refs are also accepted. + requires_previous_transforms = [CalculateRASA] # type: ignore[list-item] incompatible_previous_transforms = [ "PadWithVirtualAtoms", # must have the same atom names "CreateDesignReferenceFeatures", diff --git a/pyproject.toml b/pyproject.toml index 171f90f..07b0ef3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -225,15 +225,10 @@ module = [ "rfd3.inference.datasets", "rfd3.inference.input_parsing", "rfd3.inference.legacy_input_parsing", - "rfd3.inference.parsing", "rfd3.inference.symmetry.symmetry_utils", - "rfd3.metrics.design_metrics", "rfd3.model.RFD3", "rfd3.model.inference_sampler", - "rfd3.model.layers.block_utils", - "rfd3.model.layers.chunked_pairwise", "rfd3.run_inference", - "rfd3.testing.testing_utils", "rfd3.trainer.dump_validation_structures", "rfd3.trainer.fabric_trainer", "rfd3.trainer.rfd3", @@ -242,11 +237,8 @@ module = [ "rfd3.transforms.design_transforms", "rfd3.transforms.dna_crop", "rfd3.transforms.hbonds", - "rfd3.transforms.hbonds_hbplus", - "rfd3.transforms.ncaa_transforms", "rfd3.transforms.pipelines", "rfd3.transforms.ppi_transforms", - "rfd3.transforms.rasa", "rfd3.transforms.training_conditions", "rfd3.transforms.util_transforms", "rfd3.transforms.virtual_atoms",