mirror of
https://github.com/AngxiaoYue/ReQFlow.git
synced 2026-06-04 12:14:23 +08:00
117 lines
4.3 KiB
Python
117 lines
4.3 KiB
Python
import numpy as np
|
|
import pandas as pd
|
|
import logging
|
|
import tree
|
|
import torch
|
|
import random
|
|
|
|
from torch.utils.data import Dataset
|
|
from data import utils as du
|
|
from openfold.data import data_transforms
|
|
from openfold.utils import rigid_utils
|
|
|
|
from openfold.utils.rigid_utils import rot_to_quat, quat_to_rot
|
|
|
|
class RectifyProtDataset(Dataset):
|
|
def __init__(
|
|
self,
|
|
*,
|
|
dataset_cfg,
|
|
task,
|
|
is_training):
|
|
self._log = logging.getLogger(__name__)
|
|
self._dataset_cfg = dataset_cfg
|
|
self.is_training = is_training
|
|
data = pd.read_csv(self._dataset_cfg.rectify_csv_path)
|
|
self._create_split(data)
|
|
|
|
|
|
|
|
def _create_split(self, data_csv):
|
|
if self.is_training:
|
|
self.csv = data_csv
|
|
self._log.info(f'Loaded training dataset with {len(self.csv)} samples.')
|
|
else:
|
|
if self._dataset_cfg.max_eval_length is None and self._dataset_cfg.min_eval_length is None:
|
|
# min and max are empty
|
|
eval_lengths = data_csv.length
|
|
elif self._dataset_cfg.max_eval_length is None:
|
|
# min is not empty, max is empty
|
|
eval_lengths = data_csv.length[
|
|
data_csv.length >= self._dataset_cfg.min_eval_length
|
|
]
|
|
else:
|
|
# max is not empty, min is empty
|
|
eval_lengths = data_csv.length[
|
|
data_csv.length <= self._dataset_cfg.max_eval_length
|
|
]
|
|
if self._dataset_cfg.min_eval_length is not None:
|
|
eval_lengths = eval_lengths[eval_lengths >= self._dataset_cfg.min_eval_length]
|
|
|
|
all_lengths = np.sort(eval_lengths.unique())
|
|
length_indices = (len(all_lengths) - 1) * np.linspace(
|
|
0.0, 1.0, self._dataset_cfg.num_eval_lengths)
|
|
length_indices = length_indices.astype(int)
|
|
eval_lengths = all_lengths[length_indices]
|
|
eval_csv = data_csv[data_csv.length.isin(eval_lengths)]
|
|
eval_csv = eval_csv.groupby('length').sample(
|
|
self._dataset_cfg.samples_per_eval_length,
|
|
replace=True,
|
|
random_state=123
|
|
)
|
|
eval_csv = eval_csv.sort_values('length', ascending=False)
|
|
self.csv = eval_csv
|
|
self._log.info(
|
|
f'Validation: {len(self.csv)} examples with lengths {eval_lengths}')
|
|
|
|
self.csv['index'] = list(range(len(self.csv)))
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
return len(self.csv)
|
|
|
|
def __getitem__(self, idx):
|
|
row = self.csv.iloc[idx]
|
|
prot_traj = torch.load(row['prot_path'], map_location='cpu')
|
|
num_res = length = row['length']
|
|
sample_batch = {}
|
|
noise_batch = {}
|
|
|
|
trans_0 = prot_traj[0][0]
|
|
rotmats_0 = prot_traj[0][1]
|
|
gt_trans_1 = prot_traj[1][0]
|
|
gt_rotmats_1 = prot_traj[1][1]
|
|
|
|
rotquats_0 = rot_to_quat(rotmats_0)
|
|
gt_rotquats_1 = rot_to_quat(gt_rotmats_1)
|
|
|
|
sample_batch['trans_1'] = gt_trans_1.squeeze(dim=0) # (1, num_res, 3) -> (num_res, 3)
|
|
sample_batch['rotmats_1'] = gt_rotmats_1.squeeze(dim=0) # (1, num_res, 3, 3) -> (num_res, 3, 3)
|
|
sample_batch['rotquats_1'] = gt_rotquats_1.squeeze(dim=0) # (1, num_res, 4) -> (num_res, 4)
|
|
sample_batch['res_mask'] = torch.ones(num_res, dtype=torch.int)
|
|
sample_batch['diffuse_mask'] = torch.ones(num_res, dtype=torch.int)
|
|
sample_batch['res_idx'] = torch.arange(num_res, dtype=torch.int)
|
|
sample_batch['csv_idx'] = torch.tensor([idx], dtype=torch.long)
|
|
sample_batch['chain_idx'] = torch.zeros(1, dtype=torch.int)
|
|
|
|
noise_batch['trans_1'] = trans_0.squeeze(dim=0)
|
|
noise_batch['rotmats_1'] = rotmats_0.squeeze(dim=0)
|
|
noise_batch['rotquats_1'] = rotquats_0.squeeze(dim=0)
|
|
noise_batch['res_mask'] = torch.ones(num_res, dtype=torch.int)
|
|
noise_batch['diffuse_mask'] = torch.ones(num_res, dtype=torch.int)
|
|
noise_batch['res_idx'] = torch.arange(num_res, dtype=torch.int)
|
|
noise_batch['csv_idx'] = torch.tensor([idx], dtype=torch.long)
|
|
noise_batch['chain_idx'] = torch.zeros(1, dtype=torch.int)
|
|
|
|
|
|
|
|
|
|
|
|
return {'sample': sample_batch, 'noise': noise_batch}
|
|
|
|
|
|
|
|
|