diff --git a/alphapulldown/scripts/create_individual_features.py b/alphapulldown/scripts/create_individual_features.py index a1cfd326..b8fed124 100644 --- a/alphapulldown/scripts/create_individual_features.py +++ b/alphapulldown/scripts/create_individual_features.py @@ -68,7 +68,7 @@ AF3_DATABASES = { "rna_central": "rnacentral_active_seq_id_90_cov_80_linclust.fasta", } -AF2_DATABASE_FLAGS = { +AF2_FULL_DATABASE_FLAGS = { "uniref90_database_path": "uniref90", "uniref30_database_path": "uniref30", "mgnify_database_path": "mgnify", @@ -81,6 +81,16 @@ AF2_DATABASE_FLAGS = { "obsolete_pdbs_path": "obsolete_pdbs", } +AF2_REDUCED_DATABASE_FLAGS = { + "uniref90_database_path": "uniref90", + "mgnify_database_path": "mgnify", + "small_bfd_database_path": "small_bfd", + "uniprot_database_path": "uniprot", + "pdb_seqres_database_path": "pdb_seqres", + "template_mmcif_dir": "template_mmcif_dir", + "obsolete_pdbs_path": "obsolete_pdbs", +} + AF3_DATABASE_FLAGS = { "uniref90_database_path": "uniref90", "mgnify_database_path": "mgnify", @@ -90,7 +100,11 @@ AF3_DATABASE_FLAGS = { "template_mmcif_dir": "template_mmcif_dir", } -DATABASE_PATH_FLAGS = frozenset(AF2_DATABASE_FLAGS) | frozenset(AF3_DATABASE_FLAGS) +DATABASE_PATH_FLAGS = ( + frozenset(AF2_FULL_DATABASE_FLAGS) + | frozenset(AF2_REDUCED_DATABASE_FLAGS) + | frozenset(AF3_DATABASE_FLAGS) +) # =================== Flags =================== flags.DEFINE_enum( @@ -180,9 +194,7 @@ def create_arguments(local_custom_template_db=None): Optionally override template paths with a local custom template DB.""" validate_data_pipeline_flags() - required_database_flags = ( - AF3_DATABASE_FLAGS if FLAGS.data_pipeline == 'alphafold3' else AF2_DATABASE_FLAGS - ) + required_database_flags = get_required_database_flags() # When using MMseqs2 (current implementation uses remote servers), database paths are not needed # Note: Current MMseqs2 implementation uses remote servers via DEFAULT_API_SERVER @@ -202,6 +214,20 @@ def create_arguments(local_custom_template_db=None): FLAGS.template_mmcif_dir = os.path.join(local_custom_template_db, "pdb_mmcif", "mmcif_files") FLAGS.obsolete_pdbs_path = os.path.join(local_custom_template_db, "pdb_mmcif", "obsolete.dat") + +def get_required_database_flags(): + """Return the database flags required by the selected pipeline and preset.""" + if FLAGS.data_pipeline == "alphafold3": + return AF3_DATABASE_FLAGS + + if FLAGS.db_preset == "reduced_dbs": + required_flags = dict(AF2_REDUCED_DATABASE_FLAGS) + if FLAGS.use_hhsearch: + required_flags["pdb70_database_path"] = "pdb70" + return required_flags + + return AF2_FULL_DATABASE_FLAGS + def check_template_date(): """Check if the max_template_date is provided.""" if not FLAGS.max_template_date: diff --git a/test/integration/test_create_individual_features.py b/test/integration/test_create_individual_features.py index 4f94778e..3f1297af 100644 --- a/test/integration/test_create_individual_features.py +++ b/test/integration/test_create_individual_features.py @@ -760,6 +760,66 @@ class TestCreateIndividualFeaturesComprehensive: assert FLAGS.obsolete_pdbs_path is None logger.info("AF3 argument creation only kept AF3-relevant database paths") + def test_create_arguments_reduced_dbs_clears_unused_af2_databases(self): + """Test that reduced_dbs only sets the AF2 paths it actually needs.""" + logger.info("Testing reduced_dbs argument creation without full-db leftovers") + + from absl import flags + FLAGS = flags.FLAGS + FLAGS(['test']) + + FLAGS.use_mmseqs2 = False + FLAGS.data_pipeline = "alphafold2" + FLAGS.db_preset = "reduced_dbs" + FLAGS.use_hhsearch = False + FLAGS.data_dir = "/test/db" + FLAGS.uniref90_database_path = None + FLAGS.mgnify_database_path = None + FLAGS.small_bfd_database_path = None + FLAGS.uniprot_database_path = None + FLAGS.pdb_seqres_database_path = None + FLAGS.template_mmcif_dir = None + FLAGS.obsolete_pdbs_path = None + FLAGS.uniref30_database_path = "/stale/uniref30" + FLAGS.bfd_database_path = "/stale/bfd" + FLAGS.pdb70_database_path = "/stale/pdb70" + + create_features.create_arguments() + + assert FLAGS.uniref90_database_path == "/test/db/uniref90/uniref90.fasta" + assert FLAGS.mgnify_database_path == "/test/db/mgnify/mgy_clusters_2022_05.fa" + assert FLAGS.small_bfd_database_path == "/test/db/small_bfd/bfd-first_non_consensus_sequences.fasta" + assert FLAGS.uniprot_database_path == "/test/db/uniprot/uniprot.fasta" + assert FLAGS.pdb_seqres_database_path == "/test/db/pdb_seqres/pdb_seqres.txt" + assert FLAGS.template_mmcif_dir == "/test/db/pdb_mmcif/mmcif_files" + assert FLAGS.obsolete_pdbs_path == "/test/db/pdb_mmcif/obsolete.dat" + assert FLAGS.uniref30_database_path is None + assert FLAGS.bfd_database_path is None + assert FLAGS.pdb70_database_path is None + logger.info("Reduced-dbs argument creation cleared unused full-database paths") + + def test_create_arguments_reduced_dbs_keeps_pdb70_for_hhsearch(self): + """Test that reduced_dbs still sets pdb70 when HHsearch templates are requested.""" + logger.info("Testing reduced_dbs HHsearch argument creation") + + from absl import flags + FLAGS = flags.FLAGS + FLAGS(['test']) + + FLAGS.use_mmseqs2 = False + FLAGS.data_pipeline = "alphafold2" + FLAGS.db_preset = "reduced_dbs" + FLAGS.use_hhsearch = True + FLAGS.data_dir = "/test/db" + FLAGS.pdb70_database_path = None + + create_features.create_arguments() + + assert FLAGS.pdb70_database_path == "/test/db/pdb70/pdb70" + assert FLAGS.bfd_database_path is None + assert FLAGS.uniref30_database_path is None + logger.info("Reduced-dbs HHsearch argument creation kept pdb70 without restoring full BFD") + def test_mmseqs2_without_data_dir(self): """Test that MMseqs2 works without data_dir flag.""" logger.info("Testing MMseqs2 without data_dir flag")