Bring model config creation in one place by simplifying num_ensemble setting

PiperOrigin-RevId: 811346971
Change-Id: Idbd6e84ab372dad81d8f3eb6e1a1af12daf4baeb
This commit is contained in:
Harsh Tiku
2025-09-25 07:42:28 -07:00
committed by Copybara-Service
parent 1ff6388c67
commit ecdc85103f
2 changed files with 11 additions and 9 deletions

View File

@@ -991,6 +991,15 @@ class AlphaFoldConfig(base_config.BaseConfig):
data: Optional[Data] = None
def _set_num_ensembles(cfg: Any, name: str):
"""Sets the number of ensembles based on the model name."""
num_ensembles = 8 if name in MODEL_PRESETS['monomer_casp14'] else 1
if 'multimer' in name:
cfg.model.num_ensemble_eval = num_ensembles
else:
cfg.data.eval.num_ensemble = num_ensembles
def model_config(name: str) -> ml_collections.ConfigDict:
"""Get the ConfigDict of a CASP14 model."""
@@ -1001,6 +1010,7 @@ def model_config(name: str) -> ml_collections.ConfigDict:
else:
cfg = copy.deepcopy(CONFIG)
cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
_set_num_ensembles(cfg, name)
return cfg
@@ -1016,6 +1026,7 @@ def get_model_config(name: str, frozen: bool = True) -> AlphaFoldConfig:
)
apply_diff_op = CONFIG_DIFF_OPS[name]
apply_diff_op(cfg)
_set_num_ensembles(cfg, name)
if frozen:
cfg.freeze()
return cfg

View File

@@ -474,11 +474,6 @@ def main(argv):
_check_flag('uniprot_database_path', 'model_preset',
should_be_set=run_multimer_system)
if FLAGS.model_preset == 'monomer_casp14':
num_ensemble = 8
else:
num_ensemble = 1
# Check for duplicate FASTA file names.
fasta_names = [pathlib.Path(p).stem for p in FLAGS.fasta_paths]
if len(fasta_names) != len(set(fasta_names)):
@@ -540,10 +535,6 @@ def main(argv):
model_names = config.MODEL_PRESETS[FLAGS.model_preset]
for model_name in model_names:
model_config = config.model_config(model_name)
if run_multimer_system:
model_config.model.num_ensemble_eval = num_ensemble
else:
model_config.data.eval.num_ensemble = num_ensemble
model_params = data.get_model_haiku_params(
model_name=model_name, data_dir=FLAGS.data_dir)
model_runner = model.RunModel(model_config, model_params)