fix(#42): add RF annotation to skip-MSA stockholm

This commit is contained in:
Dima
2026-04-10 12:14:14 +02:00
parent 392d3f7ec2
commit 2c715001a9
2 changed files with 69 additions and 1 deletions

View File

@@ -34,7 +34,13 @@ def _query_only_a3m(sequence: str, query_id: str = "query") -> str:
def _query_only_stockholm(sequence: str, query_id: str = "query") -> str:
"""Return a single-sequence Stockholm alignment string."""
return f"# STOCKHOLM 1.0\n{query_id} {sequence}\n//\n"
rf_annotation = "x" * len(sequence)
return (
"# STOCKHOLM 1.0\n"
f"{query_id} {sequence}\n"
f"#=GC RF {rf_annotation}\n"
"//\n"
)
class MonomericObject:

View File

@@ -278,6 +278,68 @@ def test_make_features_skip_msa_builds_query_only_features_and_templates(
assert monomer.feature_dict["template_domain_names"].tolist() == [b"1abc_A"]
def test_make_features_skip_msa_builds_stockholm_with_rf_for_hmmsearch(
    monkeypatch, tmp_path
):
    """With skip_msa=True and a Stockholm-input template searcher, the
    alignment handed to the searcher must carry a ``#=GC RF`` line."""
    recorded = {}

    class FakeTemplateSearcher:
        # Declares Stockholm for both input and output.
        input_format = "sto"
        output_format = "sto"

        def query(self, alignment):
            # Capture the alignment text so it can be inspected afterwards.
            recorded["template_query"] = alignment
            return "template_hits"

        def get_template_hits(self, output_string, input_sequence):
            return ["hitA"]

    class FakeTemplateFeaturizer:
        def get_templates(self, query_sequence, hits):
            # A single template over a four-residue query.
            template_features = {
                "template_aatype": np.ones((1, 4, 22), dtype=np.float32),
                "template_all_atom_masks": np.ones((1, 4, 37), dtype=np.float32),
                "template_all_atom_positions": np.ones(
                    (1, 4, 37, 3), dtype=np.float32
                ),
                "template_domain_names": np.asarray([b"1abc_A"], dtype=object),
                "template_sequence": np.asarray([b"ACDE"], dtype=object),
                "template_sum_probs": np.asarray([0.5], dtype=np.float32),
            }
            return SimpleNamespace(features=template_features)

    class FakePipeline:
        template_searcher = FakeTemplateSearcher()
        template_featurizer = FakeTemplateFeaturizer()

        def process(self, *_args, **_kwargs):
            # The skip-MSA path must never reach the full pipeline.
            raise AssertionError("skip_msa should bypass pipeline.process")

    # Replace the MSA file helpers on MonomericObject with inert stand-ins.
    file_helper_stubs = (
        ("unzip_msa_files", staticmethod(lambda _path: False)),
        (
            "remove_msa_files",
            staticmethod(lambda msa_output_path=None, **_kwargs: None),
        ),
        ("zip_msa_files", staticmethod(lambda _path: None)),
    )
    for helper_name, stub in file_helper_stubs:
        monkeypatch.setattr(MonomericObject, helper_name, stub)

    subject = MonomericObject("proteinA", "ACDE")
    subject.make_features(
        pipeline=FakePipeline(),
        output_dir=str(tmp_path),
        save_msa=False,
        skip_msa=True,
    )

    submitted_alignment = recorded["template_query"]
    assert "#=GC RF xxxx" in submitted_alignment
    assert submitted_alignment.startswith("# STOCKHOLM 1.0\nquery ACDE\n")
def test_make_mmseq_features_builds_all_seq_features_and_writes_a3m(
monkeypatch, tmp_path
):