mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
Added tests for squeeze_features.
This commit is contained in:
@@ -145,7 +145,10 @@ def squeeze_features(protein):
|
||||
if k in protein:
|
||||
final_dim = protein[k].shape[-1]
|
||||
if isinstance(final_dim, int) and final_dim == 1:
|
||||
protein[k] = torch.squeeze(protein[k], dim=-1)
|
||||
if torch.is_tensor(protein[k]):
|
||||
protein[k] = torch.squeeze(protein[k], dim=-1)
|
||||
else:
|
||||
protein[k] = np.squeeze(protein[k], axis=-1)
|
||||
|
||||
for k in ["seq_length", "num_alignments"]:
|
||||
if k in protein:
|
||||
|
||||
Binary file not shown.
@@ -5,12 +5,12 @@ import os
|
||||
|
||||
import pickle
|
||||
|
||||
import numpy
|
||||
import numpy as np
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
|
||||
correct_msa_restypes
|
||||
correct_msa_restypes, squeeze_features
|
||||
from openfold.config import model_config
|
||||
|
||||
|
||||
@@ -65,6 +65,38 @@ class TestDataTransforms(unittest.TestCase):
|
||||
print(protein)
|
||||
assert torch.all(torch.eq(torch.tensor(features['msa'].shape), torch.tensor(protein['msa'].shape)))
|
||||
|
||||
def test_squeeze_features(self):
|
||||
with open("../test_data/features.pkl", "rb") as file:
|
||||
features = pickle.load(file)
|
||||
print(os.path.realpath(file.name), 'Keys: ', features.keys())
|
||||
|
||||
features_list = [
|
||||
'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
|
||||
'superfamily', 'deletion_matrix', 'resolution',
|
||||
'between_segment_residues', 'residue_index', 'template_all_atom_mask']
|
||||
|
||||
protein = {'aatype': torch.tensor(features['aatype'])}
|
||||
for k in features_list:
|
||||
if k in features:
|
||||
print(k, features[k].dtype)
|
||||
if k in ['domain_name', 'sequence']:
|
||||
protein[k] = np.expand_dims(features[k], -1)
|
||||
else:
|
||||
protein[k] = torch.tensor(features[k]).unsqueeze(-1)
|
||||
|
||||
for k in ['seq_length', 'num_alignments']:
|
||||
if k in protein:
|
||||
protein[k] = torch.tensor(protein[k]).unsqueeze(0)
|
||||
|
||||
protein_squeezed = squeeze_features(protein)
|
||||
print(protein)
|
||||
for k in features_list:
|
||||
if k in protein:
|
||||
print(k, protein_squeezed[k].shape, features[k].shape)
|
||||
assert protein_squeezed[k].shape == features[k].shape
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user