mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
128 lines
6.3 KiB
Python
128 lines
6.3 KiB
Python
import os
|
|
import logging
|
|
from absl.testing import parameterized
|
|
import shutil
|
|
import tempfile
|
|
from os.path import join
|
|
import gzip
|
|
import json
|
|
import pickle
|
|
from pathlib import Path
|
|
from alphapulldown.utils.post_modelling import post_prediction_process
|
|
|
|
class TestPostPrediction(parameterized.TestCase):
    """End-to-end tests for ``post_prediction_process``.

    Each parameterized case does the following:
    - copies one bundled prediction directory into a fresh temporary directory;
    - runs ``post_prediction_process`` on the copy with one combination of flags;
    - inspects which result files remain, whether they were gzip-compressed,
      and whether the large keys were stripped from the stored results.
    """

    # Keys that ``remove_keys=True`` must strip from every stored result, and
    # that must still be present when ``remove_keys=False``.
    REMOVABLE_KEYS = ('aligned_confidence_probs', 'distogram', 'masked_msa')

    def setUp(self) -> None:
        """Locate the bundled prediction fixtures and configure logging."""
        super().setUp()
        # parents[2] of this file's resolved path is taken as the repository root.
        repo_root = Path(__file__).resolve().parents[2]
        self.input_dir = join(repo_root, "test/test_data/predictions")
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    @staticmethod
    def _load_result(path):
        """Unpickle the result file at *path*, gunzipping it first if it ends in ``.gz``."""
        opener = gzip.open if path.endswith('.gz') else open
        with opener(path, 'rb') as f:
            return pickle.load(f)

    @parameterized.parameters(
        ('TEST', False, False, False),
        ('TEST', True, False, False),
        ('TEST', True, True, False),
        ('TEST', False, True, False),
        ('TEST_and_TEST', False, False, False),
        ('TEST_and_TEST', True, False, False),
        ('TEST_and_TEST', True, True, False),
        ('TEST_and_TEST', False, True, False),
        ('TEST', False, False, True),
        ('TEST', True, False, True),
        ('TEST', True, True, True),
        ('TEST', False, True, True),
        ('TEST_and_TEST', False, False, True),
        ('TEST_and_TEST', True, False, True),
        ('TEST_and_TEST', True, True, True),
        ('TEST_and_TEST', False, True, True)
    )
    def test_files(self, prediction_dir, compress_pickles, remove_pickles, remove_keys):
        """Run post-processing on a copy of *prediction_dir* and verify the outputs.

        Args:
            prediction_dir: Name of a fixture directory under ``self.input_dir``.
            compress_pickles: Forwarded to ``post_prediction_process``.
            remove_pickles: Forwarded to ``post_prediction_process``.
            remove_keys: Forwarded to ``post_prediction_process``.

        On an assertion failure, the ``.gz`` / ``.pkl`` files present in the
        output directory are logged before the error is re-raised.
        """
        logging.info(f"Running test for prediction_dir='{prediction_dir}', "
                     f"compress_pickles={compress_pickles}, remove_pickles={remove_pickles}, remove_keys={remove_keys}")
        # The context manager removes the temporary copy even if a step raises.
        with tempfile.TemporaryDirectory() as temp_dir_path:
            output_dir = join(temp_dir_path, prediction_dir)
            shutil.copytree(join(self.input_dir, prediction_dir), output_dir)
            try:
                # Remove existing gz files so that any .gz present afterwards
                # was produced by post_prediction_process itself.
                for name in os.listdir(output_dir):
                    if name.endswith('.gz'):
                        os.remove(join(output_dir, name))

                # Run the postprocessing
                post_prediction_process(output_dir,
                                        compress_pickles,
                                        remove_pickles,
                                        remove_keys)

                self._check_outputs(output_dir, compress_pickles, remove_pickles, remove_keys)
            except AssertionError as e:
                logging.error(f"AssertionError: {e}")
                all_files = os.listdir(output_dir)
                relevant_files = [f for f in all_files if f.endswith('.gz') or f.endswith('.pkl')]
                logging.error(f".gz and .pkl files in {output_dir}: {relevant_files}")
                raise

    def _check_outputs(self, output_dir, compress_pickles, remove_pickles, remove_keys):
        """Assert the post-processed contents of *output_dir* match the given flags."""
        # Identify the best model from the ranking written by the prediction run.
        with open(join(output_dir, 'ranking_debug.json')) as f:
            best_model = json.load(f)['order'][0]
        best_result_pickle = join(output_dir, f"result_{best_model}.pkl")

        # Gather .pkl and .gz files with a single directory listing.
        all_files = os.listdir(output_dir)
        pickle_files = [f for f in all_files if f.endswith('.pkl')]
        gz_files = [f for f in all_files if f.endswith('.gz')]

        # Check whether the specified keys were removed or kept. Compressed
        # result pickles are included; otherwise the check would be vacuous
        # whenever compress_pickles=True, since no plain .pkl files remain.
        result_files = pickle_files + [f for f in gz_files if f.endswith('.pkl.gz')]
        for pf in result_files:
            data = self._load_result(join(output_dir, pf))
            for key in self.REMOVABLE_KEYS:
                if remove_keys:
                    self.assertNotIn(key, data, f"Key '{key}' was not removed from {pf}")
                else:
                    self.assertIn(key, data, f"Key '{key}' was unexpectedly removed from {pf}")

        # Now check file counts / compressions.
        if not remove_pickles:
            if not compress_pickles:
                # All five fixture results stay as plain .pkl files, and none are compressed.
                self.assertEqual(len(pickle_files), 5,
                                 f"Expected 5 pickle files, found {len(pickle_files)}.")
                self.assertEqual(len(gz_files), 0,
                                 f"Expected 0 gz files, found {len(gz_files)}.")
            else:
                # Every result is compressed, and no plain .pkl files remain.
                self.assertEqual(len(pickle_files), 0,
                                 f"Expected 0 pickle files, found {len(pickle_files)}.")
                self.assertEqual(len(gz_files), 5,
                                 f"Expected 5 gz files, found {len(gz_files)}.")
        else:
            if not compress_pickles:
                # Only the best model's pickle remains, uncompressed.
                self.assertEqual(len(pickle_files), 1,
                                 f"Expected 1 pickle file, found {len(pickle_files)}.")
                self.assertEqual(len(gz_files), 0,
                                 f"Expected 0 gz files, found {len(gz_files)}.")
                self.assertTrue(os.path.exists(best_result_pickle),
                                f"Best result pickle file does not exist: {best_result_pickle}")
            else:
                # Only the best model's pickle remains, and it is compressed.
                self.assertEqual(len(pickle_files), 0,
                                 f"Expected 0 pickle files, found {len(pickle_files)}.")
                self.assertEqual(len(gz_files), 1,
                                 f"Expected 1 gz file, found {len(gz_files)}.")
                self.assertTrue(os.path.exists(best_result_pickle + ".gz"),
                                f"Best result pickle file not compressed: {best_result_pickle}.gz")

        if compress_pickles:
            # Validate that every gz file is readable gzip data.
            for gz_file in gz_files:
                with gzip.open(join(output_dir, gz_file), 'rb') as f:
                    f.read(1)
|