fix(#293): skip full-db paths for reduced_dbs

This commit is contained in:
Dima
2026-04-09 14:09:37 +02:00
parent 894c4aa355
commit 34c04d1ade
2 changed files with 91 additions and 5 deletions

View File

@@ -68,7 +68,7 @@ AF3_DATABASES = {
"rna_central": "rnacentral_active_seq_id_90_cov_80_linclust.fasta",
}
AF2_DATABASE_FLAGS = {
AF2_FULL_DATABASE_FLAGS = {
"uniref90_database_path": "uniref90",
"uniref30_database_path": "uniref30",
"mgnify_database_path": "mgnify",
@@ -81,6 +81,16 @@ AF2_DATABASE_FLAGS = {
"obsolete_pdbs_path": "obsolete_pdbs",
}
AF2_REDUCED_DATABASE_FLAGS = {
"uniref90_database_path": "uniref90",
"mgnify_database_path": "mgnify",
"small_bfd_database_path": "small_bfd",
"uniprot_database_path": "uniprot",
"pdb_seqres_database_path": "pdb_seqres",
"template_mmcif_dir": "template_mmcif_dir",
"obsolete_pdbs_path": "obsolete_pdbs",
}
AF3_DATABASE_FLAGS = {
"uniref90_database_path": "uniref90",
"mgnify_database_path": "mgnify",
@@ -90,7 +100,11 @@ AF3_DATABASE_FLAGS = {
"template_mmcif_dir": "template_mmcif_dir",
}
DATABASE_PATH_FLAGS = frozenset(AF2_DATABASE_FLAGS) | frozenset(AF3_DATABASE_FLAGS)
DATABASE_PATH_FLAGS = (
frozenset(AF2_FULL_DATABASE_FLAGS)
| frozenset(AF2_REDUCED_DATABASE_FLAGS)
| frozenset(AF3_DATABASE_FLAGS)
)
# =================== Flags ===================
flags.DEFINE_enum(
@@ -180,9 +194,7 @@ def create_arguments(local_custom_template_db=None):
Optionally override template paths with a local custom template DB."""
validate_data_pipeline_flags()
required_database_flags = (
AF3_DATABASE_FLAGS if FLAGS.data_pipeline == 'alphafold3' else AF2_DATABASE_FLAGS
)
required_database_flags = get_required_database_flags()
# When using MMseqs2 (current implementation uses remote servers), database paths are not needed
# Note: Current MMseqs2 implementation uses remote servers via DEFAULT_API_SERVER
@@ -202,6 +214,20 @@ def create_arguments(local_custom_template_db=None):
FLAGS.template_mmcif_dir = os.path.join(local_custom_template_db, "pdb_mmcif", "mmcif_files")
FLAGS.obsolete_pdbs_path = os.path.join(local_custom_template_db, "pdb_mmcif", "obsolete.dat")
def get_required_database_flags():
"""Return the database flags required by the selected pipeline and preset."""
if FLAGS.data_pipeline == "alphafold3":
return AF3_DATABASE_FLAGS
if FLAGS.db_preset == "reduced_dbs":
required_flags = dict(AF2_REDUCED_DATABASE_FLAGS)
if FLAGS.use_hhsearch:
required_flags["pdb70_database_path"] = "pdb70"
return required_flags
return AF2_FULL_DATABASE_FLAGS
def check_template_date():
"""Check if the max_template_date is provided."""
if not FLAGS.max_template_date:

View File

@@ -760,6 +760,66 @@ class TestCreateIndividualFeaturesComprehensive:
assert FLAGS.obsolete_pdbs_path is None
logger.info("AF3 argument creation only kept AF3-relevant database paths")
def test_create_arguments_reduced_dbs_clears_unused_af2_databases(self):
"""Test that reduced_dbs only sets the AF2 paths it actually needs."""
logger.info("Testing reduced_dbs argument creation without full-db leftovers")
from absl import flags
FLAGS = flags.FLAGS
FLAGS(['test'])
FLAGS.use_mmseqs2 = False
FLAGS.data_pipeline = "alphafold2"
FLAGS.db_preset = "reduced_dbs"
FLAGS.use_hhsearch = False
FLAGS.data_dir = "/test/db"
FLAGS.uniref90_database_path = None
FLAGS.mgnify_database_path = None
FLAGS.small_bfd_database_path = None
FLAGS.uniprot_database_path = None
FLAGS.pdb_seqres_database_path = None
FLAGS.template_mmcif_dir = None
FLAGS.obsolete_pdbs_path = None
FLAGS.uniref30_database_path = "/stale/uniref30"
FLAGS.bfd_database_path = "/stale/bfd"
FLAGS.pdb70_database_path = "/stale/pdb70"
create_features.create_arguments()
assert FLAGS.uniref90_database_path == "/test/db/uniref90/uniref90.fasta"
assert FLAGS.mgnify_database_path == "/test/db/mgnify/mgy_clusters_2022_05.fa"
assert FLAGS.small_bfd_database_path == "/test/db/small_bfd/bfd-first_non_consensus_sequences.fasta"
assert FLAGS.uniprot_database_path == "/test/db/uniprot/uniprot.fasta"
assert FLAGS.pdb_seqres_database_path == "/test/db/pdb_seqres/pdb_seqres.txt"
assert FLAGS.template_mmcif_dir == "/test/db/pdb_mmcif/mmcif_files"
assert FLAGS.obsolete_pdbs_path == "/test/db/pdb_mmcif/obsolete.dat"
assert FLAGS.uniref30_database_path is None
assert FLAGS.bfd_database_path is None
assert FLAGS.pdb70_database_path is None
logger.info("Reduced-dbs argument creation cleared unused full-database paths")
def test_create_arguments_reduced_dbs_keeps_pdb70_for_hhsearch(self):
"""Test that reduced_dbs still sets pdb70 when HHsearch templates are requested."""
logger.info("Testing reduced_dbs HHsearch argument creation")
from absl import flags
FLAGS = flags.FLAGS
FLAGS(['test'])
FLAGS.use_mmseqs2 = False
FLAGS.data_pipeline = "alphafold2"
FLAGS.db_preset = "reduced_dbs"
FLAGS.use_hhsearch = True
FLAGS.data_dir = "/test/db"
FLAGS.pdb70_database_path = None
create_features.create_arguments()
assert FLAGS.pdb70_database_path == "/test/db/pdb70/pdb70"
assert FLAGS.bfd_database_path is None
assert FLAGS.uniref30_database_path is None
logger.info("Reduced-dbs HHsearch argument creation kept pdb70 without restoring full BFD")
def test_mmseqs2_without_data_dir(self):
"""Test that MMseqs2 works without data_dir flag."""
logger.info("Testing MMseqs2 without data_dir flag")