mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2026-06-02 11:54:36 +08:00
Make everything past model_runner mandatory kwarg in process_fold_input
PiperOrigin-RevId: 856615431 Change-Id: I1e2a126e96e16dd2660f6e9fec9d9ffa02eeb7b2
This commit is contained in:
committed by
Copybara-Service
parent
3e2bc0ce35
commit
8d0a8db07e
@@ -681,6 +681,7 @@ def replace_db_dir(path_with_db_dir: str, db_dirs: Sequence[str]) -> str:
|
||||
def process_fold_input(
|
||||
fold_input: folding_input.Input,
|
||||
data_pipeline_config: pipeline.DataPipelineConfig | None,
|
||||
*,
|
||||
model_runner: None,
|
||||
output_dir: os.PathLike[str] | str,
|
||||
buckets: Sequence[int] | None = None,
|
||||
@@ -696,6 +697,7 @@ def process_fold_input(
|
||||
def process_fold_input(
|
||||
fold_input: folding_input.Input,
|
||||
data_pipeline_config: pipeline.DataPipelineConfig | None,
|
||||
*,
|
||||
model_runner: ModelRunner,
|
||||
output_dir: os.PathLike[str] | str,
|
||||
buckets: Sequence[int] | None = None,
|
||||
@@ -710,6 +712,7 @@ def process_fold_input(
|
||||
def process_fold_input(
|
||||
fold_input: folding_input.Input,
|
||||
data_pipeline_config: pipeline.DataPipelineConfig | None,
|
||||
*,
|
||||
model_runner: ModelRunner | None,
|
||||
output_dir: os.PathLike[str] | str,
|
||||
buckets: Sequence[int] | None = None,
|
||||
|
||||
@@ -208,7 +208,7 @@ class InferenceTest(parameterized.TestCase):
|
||||
actual = run_alphafold.process_fold_input(
|
||||
fold_input,
|
||||
self._data_pipeline_config,
|
||||
run_alphafold.ModelRunner(
|
||||
model_runner=run_alphafold.ModelRunner(
|
||||
config=self._model_config,
|
||||
device=jax.local_devices(backend='gpu')[0],
|
||||
model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value),
|
||||
|
||||
Reference in New Issue
Block a user