Make everything past model_runner mandatory kwarg in process_fold_input

PiperOrigin-RevId: 856615431
Change-Id: I1e2a126e96e16dd2660f6e9fec9d9ffa02eeb7b2
This commit is contained in:
Augustin Zidek
2026-01-15 05:15:05 -08:00
committed by Copybara-Service
parent 3e2bc0ce35
commit 8d0a8db07e
2 changed files with 4 additions and 1 deletions

View File

@@ -681,6 +681,7 @@ def replace_db_dir(path_with_db_dir: str, db_dirs: Sequence[str]) -> str:
def process_fold_input(
fold_input: folding_input.Input,
data_pipeline_config: pipeline.DataPipelineConfig | None,
*,
model_runner: None,
output_dir: os.PathLike[str] | str,
buckets: Sequence[int] | None = None,
@@ -696,6 +697,7 @@ def process_fold_input(
def process_fold_input(
fold_input: folding_input.Input,
data_pipeline_config: pipeline.DataPipelineConfig | None,
*,
model_runner: ModelRunner,
output_dir: os.PathLike[str] | str,
buckets: Sequence[int] | None = None,
@@ -710,6 +712,7 @@ def process_fold_input(
def process_fold_input(
fold_input: folding_input.Input,
data_pipeline_config: pipeline.DataPipelineConfig | None,
*,
model_runner: ModelRunner | None,
output_dir: os.PathLike[str] | str,
buckets: Sequence[int] | None = None,

View File

@@ -208,7 +208,7 @@ class InferenceTest(parameterized.TestCase):
actual = run_alphafold.process_fold_input(
fold_input,
self._data_pipeline_config,
run_alphafold.ModelRunner(
model_runner=run_alphafold.ModelRunner(
config=self._model_config,
device=jax.local_devices(backend='gpu')[0],
model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value),