Implement relevant improvements from #615

PiperOrigin-RevId: 868589660
Change-Id: Iac82ddf73f9f82118b935550afbe0ea13f6cd2eb
This commit is contained in:
Augustin Zidek
2026-02-11 03:27:10 -08:00
committed by Copybara-Service
parent ceb32296c7
commit e2b8ffd6a7

View File

@@ -39,10 +39,9 @@ import numpy as np
ModelResult: TypeAlias = Mapping[str, Any]
_ScalarNumberOrArray: TypeAlias = Mapping[str, float | int | np.ndarray]
@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, kw_only=True)
class InferenceResult:
"""Postprocessed model result.
@@ -58,8 +57,12 @@ class InferenceResult:
"""
predicted_structure: structure.Structure = dataclasses.field()
numerical_data: _ScalarNumberOrArray = dataclasses.field(default_factory=dict)
metadata: _ScalarNumberOrArray = dataclasses.field(default_factory=dict)
numerical_data: Mapping[str, float | int | np.ndarray] = dataclasses.field(
default_factory=dict
)
metadata: Mapping[str, float | int | np.ndarray] = dataclasses.field(
default_factory=dict
)
debug_outputs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
model_id: bytes = b''
@@ -464,9 +467,8 @@ class Model(hk.Module):
# Computing solvent accessible area with dssp can be slow for large
# structures with lots of chains, so we parallelize the call.
pred_structures = pred_structure.unstack()
num_workers = len(pred_structures)
with concurrent.futures.ThreadPoolExecutor(
max_workers=num_workers
max_workers=min(len(pred_structures), 32)
) as executor:
has_clash = list(executor.map(confidences.has_clash, pred_structures))
fraction_disordered = list(