# Mirror of https://github.com/AngxiaoYue/ReQFlow.git
# Synced 2026-06-04 12:14:23 +08:00
# 369 lines, 14 KiB, Python
import numpy as np
|
|
import pandas as pd
|
|
import logging
|
|
import tree
|
|
import torch
|
|
import random
|
|
|
|
from torch.utils.data import Dataset
|
|
from data import utils as du
|
|
from openfold.data import data_transforms
|
|
from openfold.utils import rigid_utils
|
|
from sklearn.preprocessing import PolynomialFeatures
|
|
from sklearn.linear_model import LinearRegression
|
|
|
|
|
|
def _rog_filter(df, quantile):
    """Drop rows whose radius of gyration reaches a length-dependent cutoff.

    The cutoff curve is a degree-4 polynomial regression fitted to the
    per-length ``quantile`` of ``radius_gyration``, shifted up by 0.1.

    Args:
        df: DataFrame with ``modeled_seq_len`` and ``radius_gyration`` columns.
        quantile: Quantile handed to ``np.quantile`` for every sequence length.

    Returns:
        The rows of ``df`` whose ``radius_gyration`` is strictly below the
        cutoff looked up for their ``modeled_seq_len``.
    """
    per_length_quantile = pd.pivot_table(
        df,
        values='radius_gyration',
        index='modeled_seq_len',
        aggfunc=lambda rog: np.quantile(rog, quantile),
    )
    observed_lengths = per_length_quantile.index.to_numpy()
    observed_rog = per_length_quantile.radius_gyration.to_numpy()

    # Fit the polynomial regressor on (length, quantile) pairs.
    featurizer = PolynomialFeatures(degree=4, include_bias=True)
    train_features = featurizer.fit_transform(observed_lengths[:, None])
    regressor = LinearRegression()
    regressor.fit(train_features, observed_rog)

    # Evaluate the fitted curve at 0 .. max_len - 1 and add a small margin.
    longest = df.modeled_seq_len.max()
    grid_features = featurizer.fit_transform(np.arange(longest)[:, None])
    cutoff_curve = regressor.predict(grid_features) + 0.1

    # NOTE(review): cutoff_curve[i] is the prediction for length i, so this
    # lookup uses the value predicted for length (seq_len - 1) — confirm
    # that the one-residue offset is intended.
    row_cutoffs = df.modeled_seq_len.map(
        lambda seq_len: cutoff_curve[seq_len - 1])
    return df[df.radius_gyration < row_cutoffs]
|
|
|
|
|
|
def _length_filter(data_csv, min_res, max_res):
    """Keep rows whose ``modeled_seq_len`` lies in ``[min_res, max_res]``.

    Args:
        data_csv: DataFrame with a ``modeled_seq_len`` column.
        min_res: Smallest length to keep (inclusive).
        max_res: Largest length to keep (inclusive).

    Returns:
        The matching rows of ``data_csv``.
    """
    seq_len = data_csv.modeled_seq_len
    long_enough = seq_len >= min_res
    short_enough = seq_len <= max_res
    return data_csv[long_enough & short_enough]
|
|
|
|
|
|
def _plddt_percent_filter(data_csv, min_plddt_percent):
    """Keep rows whose ``num_confident_plddt`` exceeds ``min_plddt_percent``.

    The comparison is strict: rows equal to the threshold are dropped.
    """
    is_confident = data_csv.num_confident_plddt > min_plddt_percent
    return data_csv[is_confident]
|
|
|
|
|
|
def _max_coil_filter(data_csv, max_coil_percent):
    """Keep rows whose ``coil_percent`` is at most ``max_coil_percent``.

    The bound is inclusive: rows equal to the threshold are kept.
    """
    within_limit = data_csv.coil_percent <= max_coil_percent
    return data_csv[within_limit]
|
|
|
|
|
|
def _process_csv_row(processed_file_path):
    """Load one processed structure file and turn it into model features.

    Steps, in order:
      1. Read the pickle and parse it with ``du.parse_chain_feats``.
      2. Crop every feature to the span between the smallest and the largest
         entry of ``modeled_idx`` (both ends included).
      3. Run OpenFold's ``atom37_to_frames`` and keep frame group 0 of each
         residue as rotation matrices, quaternions and translations.
      4. Renumber residues so every chain starts at 1 and randomly permute
         the chain ids (uses the global ``random`` module).

    Args:
        processed_file_path: Path of the pickled feature file.

    Returns:
        Dict with keys ``res_plddt``, ``aatype``, ``rotmats_1``,
        ``rotquats_1``, ``trans_1``, ``res_mask``, ``chain_idx``, ``res_idx``.

    Raises:
        ValueError: If the translations or rotation matrices contain NaNs.
    """
    raw_feats = du.read_pkl(processed_file_path)
    raw_feats = du.parse_chain_feats(raw_feats)

    # Only take modeled residues.
    modeled_idx = raw_feats['modeled_idx']
    first_modeled = np.min(modeled_idx)
    last_modeled = np.max(modeled_idx)
    del raw_feats['modeled_idx']
    cropped = tree.map_structure(
        lambda arr: arr[first_modeled:(last_modeled + 1)], raw_feats)

    # Run through OpenFold data transforms.
    chain_feats = data_transforms.atom37_to_frames({
        'aatype': torch.tensor(cropped['aatype']).long(),
        'all_atom_positions': torch.tensor(cropped['atom_positions']).double(),
        'all_atom_mask': torch.tensor(cropped['atom_mask']).double(),
    })
    frames = rigid_utils.Rigid.from_tensor_4x4(
        chain_feats['rigidgroups_gt_frames'])
    rigids_1 = frames[:, 0]
    rotmats_1 = rigids_1.get_rots().get_rot_mats()
    rotquats_1 = rigids_1.get_rots().get_quats()
    trans_1 = rigids_1.get_trans()
    # Column 1 of the b-factor array is used as the per-residue pLDDT
    # (presumably the CA atom — TODO confirm against the preprocessing code).
    res_plddt = cropped['b_factors'][:, 1]
    res_mask = torch.tensor(cropped['bb_mask']).int()

    # Re-number residue indices for each chain such that it starts from 1.
    # Randomize chain indices.
    chain_idx = cropped['chain_index']
    res_idx = cropped['residue_index']
    new_res_idx = np.zeros_like(res_idx)
    new_chain_idx = np.zeros_like(res_idx)
    unique_chains = np.unique(chain_idx).tolist()
    permuted_chains = np.array(
        random.sample(unique_chains, len(unique_chains))
    ) - np.min(unique_chains) + 1
    for chain_id, replacement_chain_id in zip(unique_chains, permuted_chains):
        in_chain = (chain_idx == chain_id).astype(int)
        # Push residues of other chains up by 1e3 so that the minimum is
        # taken over this chain's residues.
        chain_start = np.min(res_idx + (1 - in_chain) * 1e3).astype(int)
        new_res_idx = new_res_idx + (res_idx - chain_start + 1) * in_chain

        # Shuffle chain_index
        new_chain_idx = new_chain_idx + replacement_chain_id * in_chain

    if torch.isnan(trans_1).any() or torch.isnan(rotmats_1).any():
        raise ValueError(f'Found NaNs in {processed_file_path}')

    return {
        'res_plddt': res_plddt,
        'aatype': chain_feats['aatype'],
        'rotmats_1': rotmats_1,
        'rotquats_1': rotquats_1,
        'trans_1': trans_1,
        'res_mask': res_mask,
        'chain_idx': new_chain_idx,
        'res_idx': new_res_idx,
    }
|
|
|
|
|
|
def _add_plddt_mask(feats, plddt_threshold):
    """Add ``feats['plddt_mask']``: 1 where ``res_plddt`` exceeds the threshold.

    The comparison is strict. ``feats`` is modified in place; nothing is
    returned.
    """
    above_threshold = feats['res_plddt'] > plddt_threshold
    feats['plddt_mask'] = torch.tensor(above_threshold).int()
|
|
|
|
|
|
def _read_clusters(cluster_path):
    """Map every PDB id in a cluster file to the index of its cluster line.

    Each line of the file is one cluster, holding whitespace-separated
    entries of the form ``<pdb>_<chain>``. The part before the first ``_``
    is upper-cased and used as the key; the value is the zero-based line
    number. A PDB id that appears on several lines keeps the last one.

    Args:
        cluster_path: Path of the cluster text file.

    Returns:
        Dict mapping upper-cased PDB id to cluster (line) index.
    """
    pdb_to_cluster = {}
    with open(cluster_path, "r") as f:
        for cluster_id, line in enumerate(f):
            # str.split() with no argument discards the trailing newline and
            # the empty tokens that blank lines or repeated spaces produce,
            # so they never end up as (or inside) dictionary keys.
            for entry in line.split():
                pdb = entry.split('_')[0]
                pdb_to_cluster[pdb.upper()] = cluster_id
    return pdb_to_cluster
|
|
|
|
|
|
class RectifyDataset(Dataset):
    """Dataset of paired ``sample`` / ``noise`` structures from a metadata CSV.

    The CSV at ``dataset_cfg.rectify_csv_path`` is reduced to PDB names that
    have both a ``sample`` and a ``noise`` row, filtered on the ``sample``
    rows, and then split for training or validation. Each item is a dict
    ``{'sample': feats, 'noise': feats}``.

    In training, items are indexed by position in ``paired_pdb_names``; in
    validation, by row position in ``csv``.
    """

    def __init__(
            self,
            *,
            dataset_cfg,
            is_training,
            task,
        ):
        """Read the metadata CSV, filter it and build the split.

        Args:
            dataset_cfg: Config object; the attributes read by this class
                include ``rectify_csv_path``, ``filter``, ``seed``,
                ``cache_num_res`` and the eval/inpainting settings.
            is_training: Whether to build the training split.
            task: ``'hallucination'`` or ``'inpainting'``.
        """
        self._log = logging.getLogger(__name__)
        self._is_training = is_training
        self._dataset_cfg = dataset_cfg
        self.task = task
        self.raw_csv = pd.read_csv(self._dataset_cfg.rectify_csv_path)
        metadata_csv = self._filter_paired_metadata()
        metadata_csv = self._filter_metadata(metadata_csv)
        metadata_csv = metadata_csv.sort_values(
            'modeled_seq_len', ascending=False)
        self._create_split(metadata_csv)  # Populates self.csv.
        self.paired_pdb_names = self.csv['pdb_name'].unique()
        self._cache = {}
        self._rng = np.random.default_rng(seed=self._dataset_cfg.seed)

    @property
    def is_training(self):
        """Whether this dataset serves the training split."""
        return self._is_training

    @property
    def dataset_cfg(self):
        """The dataset configuration object passed at construction."""
        return self._dataset_cfg

    def _filter_paired_metadata(self):
        """Keep only rows whose pdb_name has exactly the types sample and noise.

        Returns:
            A copy of the matching rows of ``self.raw_csv``.
        """
        grouped = self.raw_csv.groupby('pdb_name')
        paired_pdbs = grouped.filter(
            lambda x: set(x['type']) == {'sample', 'noise'}
        )['pdb_name'].unique()
        filtered_csv = self.raw_csv[
            self.raw_csv['pdb_name'].isin(paired_pdbs)].copy()
        self._log.info(f'Found {len(paired_pdbs)} paired pdbs.')
        return filtered_csv

    def __len__(self):
        """Number of valid indices for ``__getitem__``.

        Training items are addressed through ``paired_pdb_names`` (one item
        per PDB, covering its sample and noise rows), so the length must be
        the number of paired PDBs; using the row count would advertise
        indices that ``__getitem__`` cannot serve. Validation items are
        addressed by CSV row.
        """
        if self._is_training:
            return len(self.paired_pdb_names)
        return len(self.csv)

    def _filter_metadata(self, meta_csv: pd.DataFrame) -> pd.DataFrame:
        """Apply the configured filters to the sample rows.

        Length, chain-count, coil and radius-of-gyration filters are evaluated
        on rows of type ``sample``; every row (sample and noise) of a
        surviving pdb_name is kept.

        Returns:
            A new DataFrame with ``oligomeric_detail`` set to ``'monomeric'``.
        """
        filter_cfg = self._dataset_cfg.filter
        sample_csv = meta_csv[meta_csv.type == 'sample']
        data_csv = _length_filter(
            sample_csv,
            filter_cfg.min_num_res,
            filter_cfg.max_num_res
        )
        data_csv = data_csv[
            data_csv.num_chains.isin(filter_cfg.num_chains)]
        data_csv = _max_coil_filter(data_csv, filter_cfg.max_coil_percent)
        data_csv = _rog_filter(data_csv, filter_cfg.rog_quantile)

        valid_pdb_names = data_csv['pdb_name'].unique()
        # Copy before adding a column: assigning on a filtered slice triggers
        # pandas' SettingWithCopy warning and has ambiguous semantics.
        filtered_csv = meta_csv[
            meta_csv['pdb_name'].isin(valid_pdb_names)].copy()
        filtered_csv['oligomeric_detail'] = 'monomeric'
        return filtered_csv

    def _create_split(self, data_csv):
        """Set ``self.csv`` to the training or validation rows.

        Training keeps every row. Validation picks ``num_eval_lengths``
        lengths spread evenly over the (optionally bounded) distinct lengths
        and draws ``samples_per_eval_length`` rows per length with a fixed
        seed. An ``index`` column with the row position is added either way.
        """
        if self.is_training:
            self.csv = data_csv
            self._log.info(
                f'Training: {len(self.csv)} examples')
        else:
            min_len = self._dataset_cfg.min_eval_length
            max_len = self._dataset_cfg.max_eval_length
            eval_lengths = data_csv.modeled_seq_len
            if max_len is not None:
                eval_lengths = eval_lengths[eval_lengths <= max_len]
            if min_len is not None:
                eval_lengths = eval_lengths[eval_lengths >= min_len]

            all_lengths = np.sort(eval_lengths.unique())
            length_indices = (len(all_lengths) - 1) * np.linspace(
                0.0, 1.0, self.dataset_cfg.num_eval_lengths)
            length_indices = length_indices.astype(int)
            eval_lengths = all_lengths[length_indices]
            eval_csv = data_csv[data_csv.modeled_seq_len.isin(eval_lengths)]

            # Fix a random seed to get the same split each time.
            eval_csv = eval_csv.groupby('modeled_seq_len').sample(
                self.dataset_cfg.samples_per_eval_length,
                replace=True,
                random_state=123
            )
            eval_csv = eval_csv.sort_values('modeled_seq_len', ascending=False)
            self.csv = eval_csv
            self._log.info(
                f'Validation: {len(self.csv)} examples with lengths {eval_lengths}')
        self.csv['index'] = list(range(len(self.csv)))

    def process_csv_row(self, csv_row):
        """Load the features of one CSV row, caching large proteins.

        Rows with ``modeled_seq_len`` above ``cache_num_res`` are cached by
        path. The cache keeps its own dict and callers always receive a
        shallow copy, so keys that later processing adds or replaces on the
        returned dict never leak back into the cache.

        Args:
            csv_row: Mapping with ``processed_path`` and ``modeled_seq_len``.

        Returns:
            Feature dict as produced by ``_process_csv_row``.
        """
        path = csv_row['processed_path']
        seq_len = csv_row['modeled_seq_len']
        # Large protein files are slow to read. Cache them.
        use_cache = seq_len > self._dataset_cfg.cache_num_res
        if use_cache and path in self._cache:
            return dict(self._cache[path])
        processed_row = _process_csv_row(path)
        if use_cache:
            self._cache[path] = dict(processed_row)
        return processed_row

    def _sample_scaffold_mask(self, batch, rng):
        """Sample a scaffold mask (1 = scaffold, 0 = motif) for inpainting.

        All randomness comes from ``rng`` so that a caller passing a seeded
        generator gets a reproducible mask.

        Args:
            batch: Feature dict with ``trans_1`` (first dimension is the
                number of residues) and ``res_mask``.
            rng: ``numpy.random.Generator`` used for every draw.

        Returns:
            Tensor of length ``num_res``: the scaffold mask multiplied by
            ``batch['res_mask']``.
        """
        trans_1 = batch['trans_1']
        num_res = trans_1.shape[0]
        min_motif_size = int(self._dataset_cfg.min_motif_percent * num_res)
        max_motif_size = int(self._dataset_cfg.max_motif_percent * num_res)

        # Sample the total number of residues that will be used as the motif.
        total_motif_size = rng.integers(
            low=min_motif_size,
            high=max_motif_size
        )

        # Sample motifs at different locations.
        num_motifs = rng.integers(low=1, high=total_motif_size)

        # Attempt to sample
        attempt = 0
        while attempt < 100:
            # Sample lengths of each motif.
            motif_lengths = np.sort(
                rng.integers(
                    low=1,
                    high=max_motif_size,
                    size=(num_motifs,)
                )
            )

            # Truncate motifs to not go over the motif length.
            cumulative_lengths = np.cumsum(motif_lengths)
            motif_lengths = motif_lengths[cumulative_lengths < total_motif_size]
            if len(motif_lengths) == 0:
                attempt += 1
            else:
                break
        if len(motif_lengths) == 0:
            # Every attempt failed: fall back to one motif of the full size.
            motif_lengths = [total_motif_size]

        # Sample start location of each motif.
        seed_residues = rng.integers(
            low=0,
            high=num_res-1,
            size=(len(motif_lengths),)
        )

        # Construct the motif mask.
        motif_mask = torch.zeros(num_res)
        for motif_seed, motif_len in zip(seed_residues, motif_lengths):
            motif_mask[motif_seed:min(motif_seed+motif_len, num_res)] = 1.0
        scaffold_mask = 1 - motif_mask
        return scaffold_mask * batch['res_mask']

    def setup_inpainting(self, feats, rng):
        """Store an inpainting ``diffuse_mask`` in ``feats`` (in place).

        The sampled scaffold mask is combined with ``plddt_mask`` when that
        key is present. If nothing would be diffused, every residue is
        diffused instead.
        """
        diffuse_mask = self._sample_scaffold_mask(feats, rng)
        if 'plddt_mask' in feats:
            diffuse_mask = diffuse_mask * feats['plddt_mask']
        if torch.sum(diffuse_mask) < 1:
            # Should only happen rarely.
            diffuse_mask = torch.ones_like(diffuse_mask)
        feats['diffuse_mask'] = diffuse_mask

    def __getitem__(self, idx):
        """Return the processed sample/noise feature pair for ``idx``.

        For training, ``idx`` indexes ``paired_pdb_names``; for validation it
        is a row position in ``csv`` and the same row is used for both sides.

        Returns:
            ``{'sample': feats, 'noise': feats}``, each with a ``csv_idx``.

        Raises:
            ValueError: If no sample or noise row exists for the pdb_name.
        """
        if self.is_training:
            pdb_name = self.paired_pdb_names[idx]
            pdb_rows = self.csv[self.csv['pdb_name'] == pdb_name]
            sample_rows = pdb_rows[pdb_rows['type'] == 'sample']
            noise_rows = pdb_rows[pdb_rows['type'] == 'noise']
        else:
            pdb_name = self.csv.iloc[idx]['pdb_name']
            sample_rows = self.csv[(self.csv['pdb_name'] == pdb_name)]
            noise_rows = sample_rows
        # Check before taking .iloc[0]: on an empty frame .iloc[0] raises
        # IndexError, which would hide the descriptive error below.
        if sample_rows.empty or noise_rows.empty:
            raise ValueError(f"Missing sample or noise row for pdb_name: {pdb_name}")
        sample_row = sample_rows.iloc[0]
        noise_row = noise_rows.iloc[0]

        sample_feats = self.process_csv_row(sample_row)
        noise_feats = self.process_csv_row(noise_row)

        sample_feats = self.process_feats(sample_feats)
        noise_feats = self.process_feats(noise_feats)

        # Storing the csv index is helpful for debugging.
        if self.is_training:
            sample_feats['csv_idx'] = torch.ones(1, dtype=torch.long) * (2 * idx)
            noise_feats['csv_idx'] = torch.ones(1, dtype=torch.long) * (2 * idx + 1)
        else:
            sample_feats['csv_idx'] = torch.ones(1, dtype=torch.long) * idx
            noise_feats['csv_idx'] = torch.ones(1, dtype=torch.long) * idx

        return {'sample': sample_feats, 'noise': noise_feats}

    def process_feats(self, feats):
        """Add ``plddt_mask`` and ``diffuse_mask`` to ``feats`` for the task.

        For ``'inpainting'`` the translations are additionally shifted by the
        motif centroid. The shifted translations are written as a new tensor;
        the tensor originally stored under ``trans_1`` is left untouched so
        cached features do not drift between accesses.

        Args:
            feats: Feature dict; modified in place and also returned.

        Returns:
            ``feats`` with an int ``diffuse_mask``.

        Raises:
            ValueError: If ``self.task`` is not a known task.
        """
        if self._dataset_cfg.add_plddt_mask:
            _add_plddt_mask(feats, self._dataset_cfg.min_plddt_threshold)
        else:
            feats['plddt_mask'] = torch.ones_like(feats['res_mask'])

        if self.task == 'hallucination':
            feats['diffuse_mask'] = torch.ones_like(feats['res_mask']).bool()
        elif self.task == 'inpainting':
            if self._dataset_cfg.inpainting_percent < random.random():
                feats['diffuse_mask'] = torch.ones_like(feats['res_mask'])
            else:
                rng = self._rng if self.is_training else np.random.default_rng(seed=123)
                self.setup_inpainting(feats, rng)
                # Center based on motif locations
                motif_mask = 1 - feats['diffuse_mask']
                trans_1 = feats['trans_1']
                motif_1 = trans_1 * motif_mask[:, None]
                # The +1 keeps the division defined when there is no motif.
                motif_com = torch.sum(motif_1, dim=0) / (torch.sum(motif_mask) + 1)
                feats['trans_1'] = trans_1 - motif_com[None, :]
        else:
            raise ValueError(f'Unknown task {self.task}')
        feats['diffuse_mask'] = feats['diffuse_mask'].int()
        return feats
|