diff --git a/alphafold/model/config.py b/alphafold/model/config.py index 01d6e72..43b0aab 100644 --- a/alphafold/model/config.py +++ b/alphafold/model/config.py @@ -991,6 +991,15 @@ class AlphaFoldConfig(base_config.BaseConfig): data: Optional[Data] = None +def _set_num_ensembles(cfg: Any, name: str): + """Sets the number of ensembles based on the model name.""" + num_ensembles = 8 if name in MODEL_PRESETS['monomer_casp14'] else 1 + if 'multimer' in name: + cfg.model.num_ensemble_eval = num_ensembles + else: + cfg.data.eval.num_ensemble = num_ensembles + + def model_config(name: str) -> ml_collections.ConfigDict: """Get the ConfigDict of a CASP14 model.""" @@ -1001,6 +1010,7 @@ def model_config(name: str) -> ml_collections.ConfigDict: else: cfg = copy.deepcopy(CONFIG) cfg.update_from_flattened_dict(CONFIG_DIFFS[name]) + _set_num_ensembles(cfg, name) return cfg @@ -1016,6 +1026,7 @@ def get_model_config(name: str, frozen: bool = True) -> AlphaFoldConfig: ) apply_diff_op = CONFIG_DIFF_OPS[name] apply_diff_op(cfg) + _set_num_ensembles(cfg, name) if frozen: cfg.freeze() return cfg diff --git a/run_alphafold.py b/run_alphafold.py index 0f82d85..343883a 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -474,11 +474,6 @@ def main(argv): _check_flag('uniprot_database_path', 'model_preset', should_be_set=run_multimer_system) - if FLAGS.model_preset == 'monomer_casp14': - num_ensemble = 8 - else: - num_ensemble = 1 - # Check for duplicate FASTA file names. fasta_names = [pathlib.Path(p).stem for p in FLAGS.fasta_paths] if len(fasta_names) != len(set(fasta_names)): @@ -540,10 +535,6 @@ def main(argv): model_names = config.MODEL_PRESETS[FLAGS.model_preset] for model_name in model_names: model_config = config.model_config(model_name) - if run_multimer_system: - model_config.model.num_ensemble_eval = num_ensemble - else: - model_config.data.eval.num_ensemble = num_ensemble model_params = data.get_model_haiku_params( model_name=model_name, data_dir=FLAGS.data_dir) model_runner = model.RunModel(model_config, model_params)