import numpy as np
import pandas as pd
import logging
import tree
import torch
import random
from torch.utils.data import Dataset
from data import utils as du
from openfold.data import data_transforms
from openfold.utils import rigid_utils
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression


def _rog_filter(df, quantile):
    """Drop rows whose radius of gyration is large for their sequence length.

    A degree-4 polynomial is fit to the per-length ``quantile`` of
    ``radius_gyration`` (grouped by ``modeled_seq_len``). A row is kept only if
    its ``radius_gyration`` is strictly below that fitted curve plus 0.1.

    Args:
        df: Metadata frame with ``radius_gyration`` and ``modeled_seq_len``
            columns.
        quantile: Quantile passed to ``np.quantile`` for each length group.

    Returns:
        The subset of ``df`` that passes the cutoff.
    """
    y_quant = pd.pivot_table(
        df,
        values='radius_gyration',
        index='modeled_seq_len',
        aggfunc=lambda x: np.quantile(x, quantile),
    )
    x_quant = y_quant.index.to_numpy()
    y_quant = y_quant.radius_gyration.to_numpy()

    # Fit polynomial regressor.
    poly = PolynomialFeatures(degree=4, include_bias=True)
    poly_features = poly.fit_transform(x_quant[:, None])
    poly_reg_model = LinearRegression()
    poly_reg_model.fit(poly_features, y_quant)

    # Calculate the cutoff for every sequence length 1..max_len. Entry i of
    # pred_y corresponds to length i + 1, so the lookup pred_y[x - 1] below
    # evaluates the fit at length x itself (not at x - 1).
    max_len = df.modeled_seq_len.max()
    pred_poly_features = poly.fit_transform(np.arange(1, max_len + 1)[:, None])
    # Add a little more.
    pred_y = poly_reg_model.predict(pred_poly_features) + 0.1

    row_rog_cutoffs = df.modeled_seq_len.map(lambda x: pred_y[x - 1])
    return df[df.radius_gyration < row_rog_cutoffs]


def _length_filter(data_csv, min_res, max_res):
    """Keep rows whose ``modeled_seq_len`` lies in ``[min_res, max_res]``."""
    return data_csv[
        (data_csv.modeled_seq_len >= min_res)
        & (data_csv.modeled_seq_len <= max_res)
    ]


def _plddt_percent_filter(data_csv, min_plddt_percent):
    """Keep rows whose ``num_confident_plddt`` is strictly above the threshold."""
    return data_csv[data_csv.num_confident_plddt > min_plddt_percent]


def _max_coil_filter(data_csv, max_coil_percent):
    """Keep rows whose ``coil_percent`` is at most ``max_coil_percent``."""
    return data_csv[data_csv.coil_percent <= max_coil_percent]


def _process_csv_row(processed_file_path):
    """Load one processed structure pickle and build per-residue features.

    The structure is cropped to the span between the first and last modeled
    residue and converted to backbone frames through OpenFold. Residue
    indices are renumbered to start at 1 within each chain, and chain ids are
    randomly permuted.

    Args:
        processed_file_path: Path readable by ``du.read_pkl``.

    Returns:
        Dict with keys ``res_plddt``, ``aatype``, ``rotmats_1``,
        ``rotquats_1``, ``trans_1``, ``res_mask``, ``chain_idx`` and
        ``res_idx``.

    Raises:
        ValueError: If the frame translations or rotations contain NaNs.
    """
    processed_feats = du.read_pkl(processed_file_path)
    processed_feats = du.parse_chain_feats(processed_feats)

    # Only take modeled residues.
    modeled_idx = processed_feats['modeled_idx']
    min_idx = np.min(modeled_idx)
    max_idx = np.max(modeled_idx)
    del processed_feats['modeled_idx']
    processed_feats = tree.map_structure(
        lambda x: x[min_idx:(max_idx + 1)], processed_feats)

    # Run through OpenFold data transforms.
    chain_feats = {
        'aatype': torch.tensor(processed_feats['aatype']).long(),
        'all_atom_positions': torch.tensor(processed_feats['atom_positions']).double(),
        'all_atom_mask': torch.tensor(processed_feats['atom_mask']).double(),
    }
    chain_feats = data_transforms.atom37_to_frames(chain_feats)
    # Group 0 of the rigid groups is used as the per-residue backbone frame.
    rigids_1 = rigid_utils.Rigid.from_tensor_4x4(
        chain_feats['rigidgroups_gt_frames'])[:, 0]
    rotmats_1 = rigids_1.get_rots().get_rot_mats()
    rotquats_1 = rigids_1.get_rots().get_quats()
    trans_1 = rigids_1.get_trans()
    # NOTE(review): column 1 of b_factors is presumably the CA-atom slot
    # holding per-residue pLDDT -- confirm against the preprocessing code.
    res_plddt = processed_feats['b_factors'][:, 1]
    res_mask = torch.tensor(processed_feats['bb_mask']).int()

    # Re-number residue indices for each chain such that it starts from 1.
    # Randomize chain indices.
    chain_idx = processed_feats['chain_index']
    res_idx = processed_feats['residue_index']
    new_res_idx = np.zeros_like(res_idx)
    new_chain_idx = np.zeros_like(res_idx)
    all_chain_idx = np.unique(chain_idx).tolist()
    shuffled_chain_idx = np.array(
        random.sample(all_chain_idx, len(all_chain_idx))
    ) - np.min(all_chain_idx) + 1
    for i, chain_id in enumerate(all_chain_idx):
        chain_sel = chain_idx == chain_id
        chain_mask = chain_sel.astype(int)
        # Minimum over this chain's own residues; an additive sentinel on the
        # other chains would pick the wrong minimum for large residue indices.
        chain_min_idx = int(np.min(res_idx[chain_sel]))
        new_res_idx = new_res_idx + (res_idx - chain_min_idx + 1) * chain_mask

        # Shuffle chain_index
        replacement_chain_id = shuffled_chain_idx[i]
        new_chain_idx = new_chain_idx + replacement_chain_id * chain_mask

    if torch.isnan(trans_1).any() or torch.isnan(rotmats_1).any():
        raise ValueError(f'Found NaNs in {processed_file_path}')
    return {
        'res_plddt': res_plddt,
        'aatype': chain_feats['aatype'],
        'rotmats_1': rotmats_1,
        'rotquats_1': rotquats_1,
        'trans_1': trans_1,
        'res_mask': res_mask,
        'chain_idx': new_chain_idx,
        'res_idx': new_res_idx,
    }


def _add_plddt_mask(feats, plddt_threshold):
    """Set ``feats['plddt_mask']`` to 1 where ``res_plddt`` exceeds the threshold."""
    feats['plddt_mask'] = torch.tensor(
        feats['res_plddt'] > plddt_threshold).int()


def _read_clusters(cluster_path):
    """Map each PDB id in a cluster file to its 0-based line (cluster) number.

    Each line holds whitespace-separated tokens. Only the text before the
    first ``_`` is kept and upper-cased (presumably ``<pdb>_<chain>`` tokens).
    If an id occurs on several lines, the last occurrence wins.
    """
    pdb_to_cluster = {}
    with open(cluster_path, "r") as f:
        for i, line in enumerate(f):
            # split() (rather than split(' ')) also drops the trailing newline
            # and the empty tokens produced by repeated spaces.
            for chain in line.split():
                pdb = chain.split('_')[0]
                pdb_to_cluster[pdb.upper()] = i
    return pdb_to_cluster


class RectifyDataset(Dataset):
    """Dataset of paired ``sample`` / ``noise`` structures.

    Metadata is read from ``dataset_cfg.rectify_csv_path``. Only PDB names
    having exactly the row types {'sample', 'noise'} are kept. In training,
    one item is produced per paired PDB name. In validation, one item is
    produced per row of the length-stratified evaluation csv.
    """

    def __init__(
            self,
            *,
            dataset_cfg,
            is_training,
            task,
        ):
        self._log = logging.getLogger(__name__)
        self._is_training = is_training
        self._dataset_cfg = dataset_cfg
        self.task = task
        self.raw_csv = pd.read_csv(self._dataset_cfg.rectify_csv_path)
        metadata_csv = self._filter_paired_metadata()
        metadata_csv = self._filter_metadata(metadata_csv)
        metadata_csv = metadata_csv.sort_values(
            'modeled_seq_len', ascending=False)
        self._create_split(metadata_csv)  # Sets self.csv.
        self.paired_pdb_names = self.csv['pdb_name'].unique()
        # Processed features of large proteins, keyed by processed_path.
        self._cache = {}
        self._rng = np.random.default_rng(seed=self._dataset_cfg.seed)

    @property
    def is_training(self):
        return self._is_training

    @property
    def dataset_cfg(self):
        return self._dataset_cfg

    def _filter_paired_metadata(self):
        """Keep only pdb_names that have both a sample row and a noise row."""
        grouped = self.raw_csv.groupby('pdb_name')
        paired_pdbs = grouped.filter(
            lambda x: set(x['type']) == {'sample', 'noise'})['pdb_name'].unique()
        filtered_csv = self.raw_csv[self.raw_csv['pdb_name'].isin(paired_pdbs)].copy()
        self._log.info(f'Found {len(paired_pdbs)} paired pdbs.')
        return filtered_csv

    def __len__(self):
        # Training __getitem__ indexes paired_pdb_names (one item per PDB,
        # while self.csv holds >= 2 rows per PDB); validation indexes csv rows.
        if self.is_training:
            return len(self.paired_pdb_names)
        return len(self.csv)

    def _filter_metadata(self, meta_csv: pd.DataFrame) -> pd.DataFrame:
        """Filter PDBs using their ``sample`` rows; keep all rows of survivors.

        Filters applied to the sample rows are: length range, allowed chain
        counts, maximum coil percentage and radius-of-gyration cutoff.
        """
        filter_cfg = self._dataset_cfg.filter
        sample_csv = meta_csv[meta_csv.type == 'sample']
        data_csv = _length_filter(
            sample_csv,
            filter_cfg.min_num_res,
            filter_cfg.max_num_res
        )
        data_csv = data_csv[
            data_csv.num_chains.isin(filter_cfg.num_chains)]
        data_csv = _max_coil_filter(data_csv, filter_cfg.max_coil_percent)
        data_csv = _rog_filter(data_csv, filter_cfg.rog_quantile)
        valid_pdb_names = data_csv['pdb_name'].unique()
        # .copy() so the column assignment below targets an independent frame
        # rather than a slice of meta_csv.
        filtered_csv = meta_csv[meta_csv['pdb_name'].isin(valid_pdb_names)].copy()
        filtered_csv['oligomeric_detail'] = 'monomeric'
        return filtered_csv

    def _create_split(self, data_csv):
        """Set ``self.csv`` for training (all rows) or validation (subsample).

        Validation uses ``num_eval_lengths`` evenly spaced lengths within the
        optional ``[min_eval_length, max_eval_length]`` range, and samples
        ``samples_per_eval_length`` rows per length with a fixed seed.

        Raises:
            ValueError: In validation, if no example lies in the length range.
        """
        # Training or validation specific logic.
        if self.is_training:
            self.csv = data_csv
            self._log.info(
                f'Training: {len(self.csv)} examples')
        else:
            eval_lengths = data_csv.modeled_seq_len
            if self._dataset_cfg.max_eval_length is not None:
                eval_lengths = eval_lengths[
                    eval_lengths <= self._dataset_cfg.max_eval_length]
            if self._dataset_cfg.min_eval_length is not None:
                eval_lengths = eval_lengths[
                    eval_lengths >= self._dataset_cfg.min_eval_length]
            all_lengths = np.sort(eval_lengths.unique())
            if len(all_lengths) == 0:
                raise ValueError(
                    'No validation examples within the configured eval length range.')
            length_indices = (len(all_lengths) - 1) * np.linspace(
                0.0, 1.0, self.dataset_cfg.num_eval_lengths)
            length_indices = length_indices.astype(int)
            eval_lengths = all_lengths[length_indices]
            eval_csv = data_csv[data_csv.modeled_seq_len.isin(eval_lengths)]
            # Fix a random seed to get the same split each time.
            eval_csv = eval_csv.groupby('modeled_seq_len').sample(
                self.dataset_cfg.samples_per_eval_length,
                replace=True,
                random_state=123
            )
            eval_csv = eval_csv.sort_values('modeled_seq_len', ascending=False)
            self.csv = eval_csv
            self._log.info(
                f'Validation: {len(self.csv)} examples with lengths {eval_lengths}')
        self.csv['index'] = list(range(len(self.csv)))

    def process_csv_row(self, csv_row):
        """Return processed features for one csv row, caching large proteins.

        Always returns a fresh (shallow-copied) dict so callers may add or
        replace keys without altering the cached entry.
        """
        path = csv_row['processed_path']
        seq_len = csv_row['modeled_seq_len']
        # Large protein files are slow to read. Cache them.
        use_cache = seq_len > self._dataset_cfg.cache_num_res
        if use_cache and path in self._cache:
            return dict(self._cache[path])
        processed_row = _process_csv_row(path)
        if use_cache:
            self._cache[path] = processed_row
        return dict(processed_row)

    def _sample_scaffold_mask(self, batch, rng):
        """Sample random contiguous motif segments; return the scaffold mask.

        The returned mask is 1 for scaffold residues (outside every sampled
        motif segment) multiplied by ``batch['res_mask']``.

        Args:
            batch: Feature dict containing ``trans_1`` and ``res_mask``.
            rng: numpy Generator used for every random draw here.
        """
        trans_1 = batch['trans_1']
        num_res = trans_1.shape[0]
        min_motif_size = int(self._dataset_cfg.min_motif_percent * num_res)
        max_motif_size = int(self._dataset_cfg.max_motif_percent * num_res)

        # Sample the total number of residues that will be used as the motif.
        # Draw from the caller's rng (seeded in validation) so masks are
        # reproducible. Every `high` below is kept strictly greater than its
        # `low`, because Generator.integers raises ValueError otherwise.
        total_motif_size = rng.integers(
            low=min_motif_size,
            high=max(max_motif_size, min_motif_size + 1),
        )

        # Sample motifs at different locations.
        num_motifs = rng.integers(low=1, high=max(total_motif_size, 2))

        # Attempt to sample
        attempt = 0
        while attempt < 100:
            # Sample lengths of each motif.
            motif_lengths = np.sort(
                rng.integers(
                    low=1,
                    high=max(max_motif_size, 2),
                    size=(num_motifs,)
                )
            )

            # Truncate motifs to not go over the motif length.
            cumulative_lengths = np.cumsum(motif_lengths)
            motif_lengths = motif_lengths[cumulative_lengths < total_motif_size]
            if len(motif_lengths) == 0:
                attempt += 1
            else:
                break
        if len(motif_lengths) == 0:
            motif_lengths = [total_motif_size]

        # Sample start location of each motif.
        seed_residues = rng.integers(
            low=0,
            high=max(num_res - 1, 1),
            size=(len(motif_lengths),)
        )

        # Construct the motif mask.
        motif_mask = torch.zeros(num_res)
        for motif_seed, motif_len in zip(seed_residues, motif_lengths):
            motif_mask[motif_seed:min(motif_seed + motif_len, num_res)] = 1.0
        scaffold_mask = 1 - motif_mask
        return scaffold_mask * batch['res_mask']

    def setup_inpainting(self, feats, rng):
        """Set ``feats['diffuse_mask']`` from a sampled scaffold mask.

        The mask is combined with ``plddt_mask`` when present. It falls back
        to all ones if nothing would be diffused.
        """
        diffuse_mask = self._sample_scaffold_mask(feats, rng)
        if 'plddt_mask' in feats:
            diffuse_mask = diffuse_mask * feats['plddt_mask']
        if torch.sum(diffuse_mask) < 1:
            # Should only happen rarely.
            diffuse_mask = torch.ones_like(diffuse_mask)
        feats['diffuse_mask'] = diffuse_mask

    def __getitem__(self, idx):
        """Return ``{'sample': feats, 'noise': feats}`` for item ``idx``.

        In training, ``idx`` indexes ``paired_pdb_names``. In validation, it
        indexes ``self.csv``, and the first csv row with that pdb_name is used
        for both entries.

        Raises:
            ValueError: If the sample or noise row cannot be found.
        """
        if self.is_training:
            pdb_name = self.paired_pdb_names[idx]
            pdb_rows = self.csv[self.csv['pdb_name'] == pdb_name]
            sample_rows = pdb_rows[pdb_rows['type'] == 'sample']
            noise_rows = pdb_rows[pdb_rows['type'] == 'noise']
        else:
            pdb_name = self.csv.iloc[idx]['pdb_name']
            sample_rows = self.csv[self.csv['pdb_name'] == pdb_name]
            noise_rows = sample_rows

        # Check before .iloc[0], which would otherwise raise IndexError first.
        if sample_rows.empty or noise_rows.empty:
            raise ValueError(f"Missing sample or noise row for pdb_name: {pdb_name}")
        sample_row = sample_rows.iloc[0]
        noise_row = noise_rows.iloc[0]

        sample_feats = self.process_csv_row(sample_row)
        noise_feats = self.process_csv_row(noise_row)

        sample_feats = self.process_feats(sample_feats)
        noise_feats = self.process_feats(noise_feats)

        # Storing the csv index is helpful for debugging.
        if self.is_training:
            sample_feats['csv_idx'] = torch.ones(1, dtype=torch.long) * (2 * idx)
            noise_feats['csv_idx'] = torch.ones(1, dtype=torch.long) * (2 * idx + 1)
        else:
            sample_feats['csv_idx'] = torch.ones(1, dtype=torch.long) * idx
            noise_feats['csv_idx'] = torch.ones(1, dtype=torch.long) * idx

        return {'sample': sample_feats, 'noise': noise_feats}

    def process_feats(self, feats):
        """Add ``plddt_mask`` and ``diffuse_mask`` (as int) to ``feats``.

        For the 'inpainting' task, translations may be re-centred on the
        motif. The tensor is replaced, not modified in place.

        Raises:
            ValueError: If ``self.task`` is not 'hallucination' or 'inpainting'.
        """
        if self._dataset_cfg.add_plddt_mask:
            _add_plddt_mask(feats, self._dataset_cfg.min_plddt_threshold)
        else:
            feats['plddt_mask'] = torch.ones_like(feats['res_mask'])

        if self.task == 'hallucination':
            feats['diffuse_mask'] = torch.ones_like(feats['res_mask']).bool()
        elif self.task == 'inpainting':
            if self._dataset_cfg.inpainting_percent < random.random():
                feats['diffuse_mask'] = torch.ones_like(feats['res_mask'])
            else:
                rng = self._rng if self.is_training else np.random.default_rng(seed=123)
                self.setup_inpainting(feats, rng)
                # Center based on motif locations
                motif_mask = 1 - feats['diffuse_mask']
                trans_1 = feats['trans_1']
                motif_1 = trans_1 * motif_mask[:, None]
                # The +1 avoids division by zero when the motif is empty.
                motif_com = torch.sum(motif_1, dim=0) / (torch.sum(motif_mask) + 1)
                # Out-of-place: trans_1 may be shared with a cached entry.
                feats['trans_1'] = trans_1 - motif_com[None, :]
        else:
            raise ValueError(f'Unknown task {self.task}')
        feats['diffuse_mask'] = feats['diffuse_mask'].int()
        return feats