Files
Delete/utils/transforms.py
HaotianZhang 7652568ecc update
2024-11-25 09:59:13 -08:00

1241 lines
56 KiB
Python

import copy
import os
import sys
sys.path.append('.')
import random
import time
import uuid
from itertools import compress
import torch
import torch.nn.functional as F
import numpy as np
from torch_geometric.nn.pool import knn_graph
from torch_geometric.transforms import Compose
from torch_geometric.utils import subgraph
from torch_geometric.nn import knn, radius
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter_add
from rdkit import Chem
from rdkit.Chem import rdMMPA
from scipy.spatial import distance_matrix
try:
from .data import ProteinLigandData
from .datasets import *
from .misc import *
from .train import inf_iterator
from .protein_ligand import ATOM_FAMILIES
from .chem import remove_dummys_mol, check_linkers, Murcko_decompose
except:
from utils.data import ProteinLigandData
from utils.datasets import *
from utils.misc import *
from utils.train import inf_iterator
from utils.protein_ligand import ATOM_FAMILIES
from utils.chem import remove_dummys_mol, check_linkers, Murcko_decompose
import argparse
import logging
def k_nearest_neighbors(query_pos, all_pos, constraint_id=None, k_query=10):
    """
    Find the k nearest neighbors of each query point among a subset of atoms.

    Args:
        query_pos: (n_sample, 3) tensor of query positions.
        all_pos: (n_all, 3) tensor of all atom positions.
        constraint_id: optional 1-D long tensor of indices into ``all_pos``
            selecting the candidate atoms. ``None`` means every row of
            ``all_pos`` is a candidate.
        k_query: number of neighbors requested per query point. It is clamped
            to the number of candidates, because ``torch.topk`` raises when
            asked for more entries than exist.

    Returns:
        indices: (n_sample, k) tensor of indices into ``all_pos``, ordered
            nearest first, where ``k = min(k_query, n_candidates)``.
        distances: (n_sample, k) tensor of the matching Euclidean distances.

    Note:
        No self-exclusion is performed: if a query point is itself one of the
        candidates it is returned as its own nearest neighbor at distance 0.
    """
    # Select the atoms of interest using the constraint_id
    if constraint_id is None:
        constraint_id = torch.arange(all_pos.size(0), device=all_pos.device)
    constrained_pos = all_pos[constraint_id]

    # Squared distances via broadcasting:
    # (n_sample, 1, 3) - (1, n_candidates, 3) -> (n_sample, n_candidates, 3)
    diff = query_pos.unsqueeze(1) - constrained_pos.unsqueeze(0)
    dist_squared = (diff ** 2).sum(dim=2)

    # Never request more neighbors than there are candidates.
    k = min(k_query, constrained_pos.size(0))
    distances, indices = torch.topk(dist_squared, k, largest=False, sorted=True)

    # Map positions within the candidate subset back to indices into all_pos.
    actual_indices = constraint_id[indices]
    return actual_indices, torch.sqrt(distances)
def neighbors_within_distance(query_pos, all_pos, constraint_id=None, distance_threshold=5.0):
    """
    Collect, for every query point, the candidate atoms lying within a radius.

    Args:
        query_pos: (n_sample, 3) tensor of query positions.
        all_pos: tensor of atom positions, indexed by ``constraint_id``.
        constraint_id: optional 1-D tensor of indices into ``all_pos`` that
            restricts the candidates; ``None`` uses every row of ``all_pos``.
        distance_threshold: inclusive maximum Euclidean distance.

    Returns:
        indices: list with one 1-D tensor per query point holding the indices
            (into ``all_pos``) of the candidates within the threshold.
        distances: list with one 1-D tensor per query point holding the
            corresponding Euclidean distances.
    """
    if constraint_id is None:
        constraint_id = torch.arange(all_pos.size(0), device=all_pos.device)
    candidates = all_pos[constraint_id]

    # Pairwise squared distances, shape (n_sample, n_candidates).
    offsets = query_pos.unsqueeze(1) - candidates.unsqueeze(0)
    sq_dist = (offsets ** 2).sum(dim=2)

    # Compare in squared space so the square root is only taken for the hits.
    in_range = sq_dist <= distance_threshold ** 2

    indices, distances = [], []
    for row in range(query_pos.size(0)):
        hits = in_range[row]
        indices.append(constraint_id[hits])
        distances.append(torch.sqrt(sq_dist[row][hits]))
    return indices, distances
def compress_relations(index_lists, distance_lists):
    """
    Flatten per-group index/distance tensors and record each entry's group.

    Args:
        index_lists: list of tensors, one per group.
        distance_lists: list of tensors aligned with ``index_lists``.

    Returns:
        flat_indices: concatenation of ``index_lists``.
        flat_distances: concatenation of ``distance_lists``.
        mask: long tensor shaped like ``flat_indices`` whose entries give the
            position (0-based) of the group each element came from.
    """
    flat_indices = torch.cat(index_lists)
    flat_distances = torch.cat(distance_lists)

    # Fill contiguous slices of the mask with the group number they belong to.
    mask = torch.zeros_like(flat_indices, dtype=torch.long)
    start = 0
    for group_id, group in enumerate(index_lists):
        stop = start + len(group)
        mask[start:stop] = group_id
        start = stop
    return flat_indices, flat_distances, mask
class Protein_ligand_relation(object):
    """Transform that attaches protein atoms lying near the generated positions."""

    def __init__(self):
        super().__init__()
        # Distance cutoff handed to neighbors_within_distance.
        # Interaction ranges noted by the original author:
        #   Hydrogen bond interaction: 3.5-4.0 A
        #   Van der Waals interaction: 4.0-6.0 A
        #   Electrostatic interaction: 6.0-8.0 A
        self.cutoff = 8.0

    def __call__(self, data):
        """
        Store flattened neighbor relations on ``data`` and return it.

        Reads ``data['pos_generate']``, ``data['compose_pos']`` and
        ``data['idx_protein_in_compose']``; writes ``pl_relation_idx``,
        ``pl_relation_dist`` and ``pl_relation_mask``.
        """
        neighbor_idx, neighbor_dist = neighbors_within_distance(
            data['pos_generate'],
            data['compose_pos'],
            data['idx_protein_in_compose'],
            self.cutoff,
        )
        flat_idx, flat_dist, group_mask = compress_relations(neighbor_idx, neighbor_dist)
        data.pl_relation_idx = flat_idx
        data.pl_relation_dist = flat_dist
        data.pl_relation_mask = group_mask
        return data
class RefineData(object):
    """Transform that strips hydrogen atoms (element == 1) from the ligand.

    Atom tensors, the neighbor list and the bond tensors on ``data`` are all
    filtered and re-indexed so they refer to the remaining heavy atoms.
    """
    def __init__(self):
        super().__init__()
    def __call__(self, data):
        # delete H atom of pocket
        # NOTE(review): the protein feature is only read here; nothing on the
        # protein side is actually filtered in this method.
        protein_feature = data.protein_feature
        # delete H atom of ligand
        ligand_element = data.ligand_element
        is_H_ligand = (ligand_element == 1)
        if torch.sum(is_H_ligand) > 0:
            not_H_ligand = ~is_H_ligand
            data.ligand_atom_feature = data.ligand_atom_feature[not_H_ligand]
            data.ligand_element = data.ligand_element[not_H_ligand]
            data.ligand_pos = data.ligand_pos[not_H_ligand]
            # nbh
            index_atom_H = torch.nonzero(is_H_ligand)[:, 0]
            # Old-index -> new-index lookup: kept atoms get 0..n_kept-1,
            # removed (hydrogen) atoms stay at -1.
            index_changer = -np.ones(len(not_H_ligand), dtype=np.int64)
            index_changer[not_H_ligand] = np.arange(torch.sum(not_H_ligand))
            # Keep the neighbor lists of heavy atoms only.
            # assumes ligand_nbh_list.values() iterates in atom-index order -- TODO confirm
            new_nbh_list = [value for ind_this, value in zip(not_H_ligand, data.ligand_nbh_list.values()) if ind_this]
            # Drop hydrogen neighbors and translate the rest to the new indices.
            data.ligand_nbh_list = {i:[index_changer[node] for node in neigh if node not in index_atom_H] for i, neigh in enumerate(new_nbh_list)}
            # bond
            # A bond is discarded when either endpoint is a hydrogen.
            ind_bond_with_H = np.array([(bond_i in index_atom_H) | (bond_j in index_atom_H) for bond_i, bond_j in zip(*data.ligand_bond_index)])
            ind_bond_without_H = ~ind_bond_with_H
            old_ligand_bond_index = data.ligand_bond_index[:, ind_bond_without_H]
            # Re-express the surviving bonds in the new atom numbering.
            data.ligand_bond_index = torch.tensor(index_changer)[old_ligand_bond_index]
            data.ligand_bond_type = data.ligand_bond_type[ind_bond_without_H]
        return data
class FeaturizeProteinAtom(object):
    """Transform that builds the protein feature matrix with an is-ligand flag."""

    def __init__(self):
        super().__init__()
        # Earlier variant also listed hydrogen: [1, 6, 7, 8, 16, 34] (H, C, N, O, S, Se)
        self.atomic_numbers = torch.LongTensor([6, 7, 8, 16, 34])  # C, N, O, S, Se
        self.max_num_aa = 20

    @property
    def feature_dim(self):
        """Fixed feature width reported by this featurizer."""
        # Previously: self.atomic_numbers.size(0) + self.max_num_aa + 1 + 1
        return 5

    def __call__(self, data:ProteinLigandData):
        """Append an all-zero "is molecule atom" column to ``data.protein_feature``.

        The result is stored in ``data.protein_surf_feature``; ``data`` is
        returned.
        """
        protein_feature = data.protein_feature
        num_rows = protein_feature.shape[0]
        # Protein atoms are flagged 0 in the extra column.
        ligand_flag = torch.zeros(num_rows, dtype=torch.long).unsqueeze(-1)
        data.protein_surf_feature = torch.cat([protein_feature, ligand_flag], dim=-1)
        return data
class FeaturizeLigandAtom(object):
    """Transform that assembles the full per-atom ligand feature matrix.

    Column layout of the produced matrix:
    element one-hot (7) | is-molecule flag (1) | number of neighbors (1) |
    valence (1) | bond counts per bond type (3).
    """

    def __init__(self):
        super().__init__()
        # Earlier variant also listed hydrogen: [1,6,7,8,9,15,16,17] (H C N O F P S Cl)
        self.atomic_numbers = torch.LongTensor([6, 7, 8, 9, 15, 16, 17])  # C N O F P S Cl
        # change_features_of_neigh hard-codes the number of element columns.
        assert len(self.atomic_numbers) == 7, NotImplementedError('fix the staticmethod: chagne_bond')

    @property
    def feature_dim(self):
        """Total feature width: elements + (flag, n_neigh, n_valence) + 3 bond counts."""
        num_elements = self.atomic_numbers.size(0)
        return num_elements + (1 + 1 + 1) + 3

    def __call__(self, data:ProteinLigandData):
        """Build ``data.ligand_atom_feature_full`` and return ``data``."""
        # (N_atoms, N_elements) boolean one-hot of the element type.
        element_onehot = data.ligand_element.view(-1, 1) == self.atomic_numbers.view(1, -1)
        # Ligand atoms are flagged 1.
        ligand_flag = torch.ones([len(element_onehot), 1], dtype=torch.long)
        neighbor_count = data.ligand_num_neighbors.view(-1, 1)
        valence = data.ligand_atom_valence.view(-1, 1)
        bond_counts = data.ligand_atom_num_bonds
        data.ligand_atom_feature_full = torch.cat(
            [element_onehot, ligand_flag, neighbor_count, valence, bond_counts],
            dim=-1,
        )
        return data

    @staticmethod
    def change_features_of_neigh(ligand_feature_full, new_num_neigh, new_num_valence, ligand_atom_num_bonds):
        """Overwrite the bond-dependent columns of ``ligand_feature_full`` in place.

        Returns the same (mutated) tensor.
        """
        # 7 element columns + 1 flag column precede the neighbor-count column.
        col_neigh = 7 + 1
        col_valence = col_neigh + 1
        col_bonds = col_valence + 1
        ligand_feature_full[:, col_neigh] = new_num_neigh.long()
        ligand_feature_full[:, col_valence] = new_num_valence.long()
        ligand_feature_full[:, col_bonds:col_bonds + 3] = ligand_atom_num_bonds.long()
        return ligand_feature_full
class FeaturizeLigandBond(object):
    """Transform that one-hot encodes ligand bond types."""

    def __init__(self):
        super().__init__()

    def __call__(self, data:ProteinLigandData):
        """Store a 3-class one-hot of ``data.ligand_bond_type`` in ``data.ligand_bond_feature``."""
        # Bond types (1, 2, 3) are shifted to class indices (0, 1, 2).
        zero_based_type = data.ligand_bond_type - 1
        data.ligand_bond_feature = F.one_hot(zero_based_type, num_classes=3)
        return data
class LigandCountNeighbors(object):
    """Transform that derives per-atom bond statistics from the ligand bond list."""

    @staticmethod
    def count_neighbors(edge_index, symmetry, valence=None, num_nodes=None):
        """Sum a per-edge weight onto the source node of every edge.

        Args:
            edge_index: (2, n_edges) tensor; row 0 is used as the scatter index.
            symmetry: must be truthy-equal to ``True``; only symmetric edge
                lists are supported.
            valence: optional per-edge weight; defaults to 1 for every edge,
                which yields the neighbor count.
            num_nodes: number of output rows; inferred from ``edge_index``
                when omitted.

        Returns:
            Long tensor with ``num_nodes`` entries.
        """
        assert symmetry == True, 'Only support symmetrical edges.'

        if num_nodes is None:
            num_nodes = maybe_num_nodes(edge_index)

        num_edges = edge_index.size(1)
        if valence is None:
            valence = torch.ones([num_edges], device=edge_index.device)
        edge_weight = valence.view(num_edges)

        summed = scatter_add(edge_weight, index=edge_index[0], dim=0, dim_size=num_nodes)
        return summed.long()

    def __init__(self):
        super().__init__()

    def __call__(self, data):
        """Write neighbor count, valence and per-bond-type counts onto ``data``."""
        data.ligand_num_neighbors = self.count_neighbors(
            data.ligand_bond_index,
            symmetry=True,
            num_nodes=data.ligand_element.size(0),
        )
        # Weighting each edge by its bond type gives the summed valence.
        data.ligand_atom_valence = self.count_neighbors(
            data.ligand_bond_index,
            symmetry=True,
            valence=data.ligand_bond_type,
            num_nodes=data.ligand_element.size(0),
        )
        # One column per bond type 1, 2, 3.
        per_type_counts = []
        for bond_type in [1, 2, 3]:
            per_type_counts.append(
                self.count_neighbors(
                    data.ligand_bond_index,
                    symmetry=True,
                    valence=(data.ligand_bond_type == bond_type).long(),
                    num_nodes=data.ligand_element.size(0),
                )
            )
        data.ligand_atom_num_bonds = torch.stack(per_type_counts, dim=-1)
        return data
class LigandRandomMask(object):
    """Transform that masks a uniformly random subset of ligand atoms.

    The remaining atoms form the "context"; their bond-dependent features are
    recomputed from the bonds that survive inside the context.
    """

    def __init__(self, min_ratio=0.0, max_ratio=1.2, min_num_masked=1, min_num_unmasked=0):
        super().__init__()
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
        self.min_num_masked = min_num_masked
        self.min_num_unmasked = min_num_unmasked

    def __call__(self, data:ProteinLigandData):
        """Split ligand atoms into masked/context sets and annotate ``data``."""
        # The sampled ratio is clipped to [0, 1] before use.
        ratio = np.clip(random.uniform(self.min_ratio, self.max_ratio), 0.0, 1.0)
        num_atoms = data.ligand_element.size(0)
        # Enforce the lower bound on masked atoms, then the lower bound on
        # unmasked atoms (the latter wins if the two conflict).
        num_masked = max(int(num_atoms * ratio), self.min_num_masked)
        num_masked = min(num_masked, num_atoms - self.min_num_unmasked)

        shuffled = np.arange(num_atoms)
        np.random.shuffle(shuffled)
        shuffled = torch.LongTensor(shuffled)
        masked_idx, context_idx = shuffled[:num_masked], shuffled[num_masked:]

        data.context_idx = context_idx  # for change bond index
        data.masked_idx = masked_idx

        # Masked ligand atoms: element and position.
        data.ligand_masked_element = data.ligand_element[masked_idx]
        data.ligand_masked_pos = data.ligand_pos[masked_idx]

        # Context ligand atoms: element, full features (model input), position.
        # The neighbor/valence columns are refreshed further below.
        data.ligand_context_element = data.ligand_element[context_idx]
        data.ligand_context_feature_full = data.ligand_atom_feature_full[context_idx]
        data.ligand_context_pos = data.ligand_pos[context_idx]

        # Bonds restricted to (and relabelled for) the context atoms.
        if data.ligand_bond_index.size(1) == 0:
            ctx_bond_index = torch.empty([2, 0], dtype=torch.long)
            ctx_bond_type = torch.empty([0], dtype=torch.long)
        else:
            ctx_bond_index, ctx_bond_type = subgraph(
                context_idx,
                data.ligand_bond_index,
                edge_attr=data.ligand_bond_type,
                relabel_nodes=True,
            )
        data.ligand_context_bond_index = ctx_bond_index
        data.ligand_context_bond_type = ctx_bond_type

        # Bond-dependent statistics of the context atoms.
        num_context = context_idx.size(0)
        data.ligand_context_num_neighbors = LigandCountNeighbors.count_neighbors(
            ctx_bond_index,
            symmetry=True,
            num_nodes=num_context,
        )
        data.ligand_context_valence = LigandCountNeighbors.count_neighbors(
            ctx_bond_index,
            symmetry=True,
            valence=ctx_bond_type,
            num_nodes=num_context,
        )
        per_type_counts = []
        for bond_type in [1, 2, 3]:
            per_type_counts.append(
                LigandCountNeighbors.count_neighbors(
                    ctx_bond_index,
                    symmetry=True,
                    valence=(ctx_bond_type == bond_type).long(),
                    num_nodes=num_context,
                )
            )
        data.ligand_context_num_bonds = torch.stack(per_type_counts, dim=-1)

        # Refresh the bond-dependent columns of the context feature matrix.
        data.ligand_context_feature_full = FeaturizeLigandAtom.change_features_of_neigh(
            data.ligand_context_feature_full,
            data.ligand_context_num_neighbors,
            data.ligand_context_valence,
            data.ligand_context_num_bonds,
        )
        # A context atom is on the frontier when it lost at least one neighbor.
        data.ligand_frontier = data.ligand_context_num_neighbors < data.ligand_num_neighbors[context_idx]
        data._mask = 'random'
        return data
class LigandBFSMask(object):
    """Transform that masks ligand atoms along a random breadth-first order.

    With ``inverse=False`` the atoms visited last are masked; with
    ``inverse=True`` the atoms visited first are masked.
    """

    def __init__(self, min_ratio=0.0, max_ratio=1.2, min_num_masked=1, min_num_unmasked=0, inverse=False):
        super().__init__()
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
        self.min_num_masked = min_num_masked
        self.min_num_unmasked = min_num_unmasked
        self.inverse = inverse

    @staticmethod
    def get_bfs_perm(nbh_list):
        """Run a randomised BFS over ``nbh_list`` from a random start node.

        Args:
            nbh_list: mapping ``node -> list of neighbor nodes`` indexable by
                ``0 .. len(nbh_list) - 1``.

        Returns:
            bfs_perm: long tensor with the visiting order. Only nodes reachable
                from the start node appear in it.
            bfs_next_list: dict mapping each visited node to a copy of the
                queue right after that node was expanded.
            num_remains: list of tensors; entry ``t`` holds, per node, how many
                of its neighbors have not been visited after ``t`` steps.
        """
        num_nodes = len(nbh_list)
        num_neighbors = torch.LongTensor([len(nbh_list[i]) for i in range(num_nodes)])

        bfs_queue = [random.randint(0, num_nodes-1)]
        bfs_perm = []
        num_remains = [num_neighbors.clone()]
        bfs_next_list = {}
        visited = {bfs_queue[0]}

        num_nbh_remain = num_neighbors.clone()

        while len(bfs_queue) > 0:
            current = bfs_queue.pop(0)
            # Visiting `current` removes one unvisited neighbor from each of its neighbors.
            for nbh in nbh_list[current]:
                num_nbh_remain[nbh] -= 1
            bfs_perm.append(current)
            num_remains.append(num_nbh_remain.clone())
            next_candid = []
            for nxt in nbh_list[current]:
                if nxt in visited:
                    continue
                next_candid.append(nxt)
                visited.add(nxt)

            # Shuffle so that sibling order differs between calls.
            random.shuffle(next_candid)
            bfs_queue += next_candid
            bfs_next_list[current] = copy.copy(bfs_queue)

        return torch.LongTensor(bfs_perm), bfs_next_list, num_remains

    @staticmethod
    def _split_permutation(bfs_perm, num_masked, inverse):
        """Split a BFS order into ``(masked_idx, context_idx)``.

        An explicit split position is used instead of negative slicing:
        ``perm[-0:]`` is the whole tensor, so ``num_masked == 0`` would
        otherwise mask every atom and leave the context empty.
        """
        if inverse:
            return bfs_perm[:num_masked], bfs_perm[num_masked:]
        split = max(bfs_perm.size(0) - num_masked, 0)
        return bfs_perm[split:], bfs_perm[:split]

    def __call__(self, data):
        """Split ligand atoms into masked/context sets and annotate ``data``."""
        bfs_perm, bfs_next_list, num_remaining_nbs = self.get_bfs_perm(data.ligand_nbh_list)

        # The sampled ratio is clipped to [0, 1] before use.
        ratio = np.clip(random.uniform(self.min_ratio, self.max_ratio), 0.0, 1.0)
        num_atoms = data.ligand_element.size(0)
        num_masked = int(num_atoms * ratio)
        if num_masked < self.min_num_masked:
            num_masked = self.min_num_masked
        if (num_atoms - num_masked) < self.min_num_unmasked:
            num_masked = num_atoms - self.min_num_unmasked

        masked_idx, context_idx = self._split_permutation(bfs_perm, num_masked, self.inverse)

        data.context_idx = context_idx  # for change bond index
        data.masked_idx = masked_idx

        # Masked ligand atoms: element and position.
        data.ligand_masked_element = data.ligand_element[masked_idx]
        data.ligand_masked_pos = data.ligand_pos[masked_idx]

        # Context ligand atoms: element, full features (model input), position.
        # The neighbor/valence columns are refreshed further below.
        data.ligand_context_element = data.ligand_element[context_idx]
        data.ligand_context_feature_full = data.ligand_atom_feature_full[context_idx]
        data.ligand_context_pos = data.ligand_pos[context_idx]

        # Bonds restricted to (and relabelled for) the context atoms.
        if data.ligand_bond_index.size(1) != 0:
            data.ligand_context_bond_index, data.ligand_context_bond_type = subgraph(
                context_idx,
                data.ligand_bond_index,
                edge_attr=data.ligand_bond_type,
                relabel_nodes=True,
            )
        else:
            data.ligand_context_bond_index = torch.empty([2, 0], dtype=torch.long)
            data.ligand_context_bond_type = torch.empty([0], dtype=torch.long)

        # Re-calculate atom features that relate to bonds.
        data.ligand_context_num_neighbors = LigandCountNeighbors.count_neighbors(
            data.ligand_context_bond_index,
            symmetry=True,
            num_nodes=context_idx.size(0),
        )
        data.ligand_context_valence = LigandCountNeighbors.count_neighbors(
            data.ligand_context_bond_index,
            symmetry=True,
            valence=data.ligand_context_bond_type,
            num_nodes=context_idx.size(0),
        )
        # The boolean mask is cast to long so the scatter-add counts bonds
        # (as the other maskers do) instead of operating on a bool tensor.
        data.ligand_context_num_bonds = torch.stack([
            LigandCountNeighbors.count_neighbors(
                data.ligand_context_bond_index,
                symmetry=True,
                valence=(data.ligand_context_bond_type == i).long(),
                num_nodes=context_idx.size(0),
            ) for i in [1, 2, 3]
        ], dim=-1)

        # Refresh the bond-dependent columns of the context feature matrix.
        data.ligand_context_feature_full = FeaturizeLigandAtom.change_features_of_neigh(
            data.ligand_context_feature_full,
            data.ligand_context_num_neighbors,
            data.ligand_context_valence,
            data.ligand_context_num_bonds,
        )
        # A context atom is on the frontier when it lost at least one neighbor.
        data.ligand_frontier = data.ligand_context_num_neighbors < data.ligand_num_neighbors[context_idx]
        data._mask = 'invbfs' if self.inverse else 'bfs'
        return data
class LigandMaskAll(LigandRandomMask):
    """Random mask whose sampled ratio is never below 1.0, i.e. all atoms are masked."""

    def __init__(self):
        # Only the lower ratio bound is overridden; other defaults are inherited.
        lower_bound = 1.0
        super().__init__(min_ratio=lower_bound)
class LigandMaskZero(LigandRandomMask):
    """Random mask configured with a zero ratio and no minimum masked count."""

    def __init__(self):
        # Upper ratio bound 0 and no forced minimum: nothing gets masked.
        overrides = {'max_ratio': 0.0, 'min_num_masked': 0}
        super().__init__(**overrides)
class LigandMaskSpatial(object):
    """Transform that splits ligand atoms by distance to a random centre atom.

    Args:
        threshold: radius used to separate the two atom sets.
        random_spatial: when true, ``threshold`` is replaced by a value drawn
            once, at construction time, from ``uniform(lower, upper)``.
        lower, upper: bounds of that draw.
    """

    def __init__(self, threshold=3, random_spatial=False, lower=2, upper=5):
        super().__init__()
        self.threshold = threshold
        if random_spatial:
            self.threshold = random.uniform(lower, upper)

    @staticmethod
    def _pick_center(num_atoms):
        """Return a random integer atom index in ``[0, num_atoms)``.

        ``random.uniform`` yields a float, which cannot index the distance
        matrix, so an integer draw is required here.
        """
        return random.randrange(num_atoms)

    def __call__(self, data):
        """Split ligand atoms into masked/context sets and annotate ``data``."""
        # masking maker
        mol = data.ligand_mol
        Chem.SanitizeMol(mol)
        num_atoms = mol.GetNumAtoms()
        center_id = self._pick_center(num_atoms)
        coords = data.ligand_pos
        dist_mat = distance_matrix(coords, coords, p=2)
        context_id = dist_mat[center_id] < self.threshold
        masked_id = ~context_id
        context_id = np.nonzero(context_id)[0]
        masked_id = np.nonzero(masked_id)[0]
        # NOTE(review): the names are crossed on purpose or by accident --
        # atoms within the threshold end up in masked_idx, the others in
        # context_idx. Kept as in the original; confirm the intent.
        context_idx = torch.LongTensor(masked_id)
        masked_idx = torch.LongTensor(context_id)

        data.context_idx = context_idx
        data.masked_idx = masked_idx

        # masked element and feature maker
        data.ligand_masked_element = data.ligand_element[masked_idx]
        data.ligand_masked_pos = data.ligand_pos[masked_idx]

        data.ligand_context_element = data.ligand_element[context_idx]
        data.ligand_context_feature_full = data.ligand_atom_feature_full[context_idx]  # For Input
        data.ligand_context_pos = data.ligand_pos[context_idx]

        # Bonds restricted to (and relabelled for) the context atoms.
        if data.ligand_bond_index.size(1) != 0:
            data.ligand_context_bond_index, data.ligand_context_bond_type = subgraph(
                context_idx,
                data.ligand_bond_index,
                edge_attr=data.ligand_bond_type,
                relabel_nodes=True,
            )
        else:
            data.ligand_context_bond_index = torch.empty([2, 0], dtype=torch.long)
            data.ligand_context_bond_type = torch.empty([0], dtype=torch.long)

        # change context atom features that relate to bonds
        data.ligand_context_num_neighbors = LigandCountNeighbors.count_neighbors(
            data.ligand_context_bond_index,
            symmetry=True,
            num_nodes=context_idx.size(0),
        )
        data.ligand_context_valence = LigandCountNeighbors.count_neighbors(
            data.ligand_context_bond_index,
            symmetry=True,
            valence=data.ligand_context_bond_type,
            num_nodes=context_idx.size(0),
        )
        data.ligand_context_num_bonds = torch.stack([
            LigandCountNeighbors.count_neighbors(
                data.ligand_context_bond_index,
                symmetry=True,
                valence=(data.ligand_context_bond_type == i).long(),
                num_nodes=context_idx.size(0),
            ) for i in [1, 2, 3]
        ], dim=-1)

        # Refresh the bond-dependent columns of the context feature matrix.
        data.ligand_context_feature_full = FeaturizeLigandAtom.change_features_of_neigh(
            data.ligand_context_feature_full,
            data.ligand_context_num_neighbors,
            data.ligand_context_valence,
            data.ligand_context_num_bonds,
        )
        # A context atom is on the frontier when it lost at least one neighbor.
        data.ligand_frontier = data.ligand_context_num_neighbors < data.ligand_num_neighbors[context_idx]
        data._mask = 'spatial'
        return data
class LigandMaskFrag(object):
    """Transform that splits ligand atoms according to a chemical decomposition.

    ``masker`` selects the decomposition: 'frag', 'linker',
    'linker_double_frag', 'linker_signle_frag' (spelling as used by the code),
    'scaffold' or 'side_chain'. ``pattern`` is forwarded to
    ``rdMMPA.FragmentMol`` for the cut-based modes.

    If anything inside the decomposition path raises, the exception is printed
    and the sample is masked with ``LigandRandomMask`` instead.
    """
    def __init__(self, masker, pattern="[#6+0;!$(*=,#[!#6])]!@!=!#[*]"):
        super().__init__()
        self.masker = masker
        self.pattern = pattern
    def __call__(self, data):
        mol = data.ligand_mol
        Chem.SanitizeMol(mol)
        num_atoms = mol.GetNumAtoms()
        try:
            if self.masker == 'frag':
                # Single cut; one of the two resulting pieces is chosen at random.
                fragmentations = rdMMPA.FragmentMol(mol, minCuts=1, maxCuts=1, maxCutBonds=100, pattern=self.pattern, resultsAsMols=False)
                fragmentation = random.choice(fragmentations)[1].replace('.',',').split(',') #no core
                id = random.randint(0,1)
                masked_frag = remove_dummys_mol(fragmentation[id])[0]
            elif self.masker == 'linker':
                # Double cut; the "core" part of the fragmentation is selected.
                fragmentations = rdMMPA.FragmentMol(mol, minCuts=2, maxCuts=2, maxCutBonds=100, pattern=self.pattern, resultsAsMols=False)
                fragmentations = check_linkers(fragmentations)
                fragmentation = random.choice(fragmentations)
                core, chains = fragmentation
                # remove_dummys_mol is applied twice
                # (presumably once per dummy atom -- TODO confirm against utils.chem).
                masked_frag = remove_dummys_mol(core)[0]
                masked_frag = remove_dummys_mol(masked_frag)[0]
            elif self.masker == 'linker_double_frag':
                # Double cut; the "chains" part (both pieces together) is selected.
                fragmentations = rdMMPA.FragmentMol(mol, minCuts=2, maxCuts=2, maxCutBonds=100, pattern=self.pattern, resultsAsMols=False)
                fragmentations = check_linkers(fragmentations)
                fragmentation = random.choice(fragmentations)
                core, chains = fragmentation
                masked_frag = remove_dummys_mol(chains)[0]
                masked_frag = remove_dummys_mol(masked_frag)[0]
            elif self.masker == 'linker_signle_frag':
                # Double cut; one of the two '.'-separated chain pieces is selected at random.
                fragmentations = rdMMPA.FragmentMol(mol, minCuts=2, maxCuts=2, maxCutBonds=100, pattern=self.pattern, resultsAsMols=False)
                fragmentations = check_linkers(fragmentations)
                fragmentation = random.choice(fragmentations)
                core, chains = fragmentation
                frag = chains.split('.')
                id = random.randint(0,1)
                masked_frag = remove_dummys_mol(frag[id])[0]
            elif self.masker == 'scaffold':
                scaffold, side_chains = Murcko_decompose(mol)
                # Without side chains there is nothing to separate; fall back via the except branch.
                if len(side_chains) == 0:
                    raise ValueError('Side Chains decomposition is None')
                masked_frag = scaffold
            elif self.masker == 'side_chain':
                scaffold, side_chains = Murcko_decompose(mol)
                if len(side_chains) == 0:
                    raise ValueError('Side Chains decomposition is None')
                # The scaffold is matched instead; the complement is derived below.
                masked_frag = None
                kept_frag = scaffold
            else:
                raise NotImplementedError('Please choose the supported masker type')
            # NOTE(review): in both branches below the local names are crossed --
            # the atoms collected in `masked_id` are stored as context_idx and
            # the atoms in `context_id` as masked_idx. Confirm this is intended.
            if masked_frag is not None:
                masked_id = mol.GetSubstructMatch(masked_frag)
                context_id = list(set(list(range(num_atoms))) - set(masked_id))
                context_idx = torch.LongTensor(masked_id)
                masked_idx = torch.LongTensor(context_id)
            else:
                context_id = mol.GetSubstructMatch(kept_frag)
                masked_id = list(set(list(range(num_atoms))) - set(context_id))
                context_idx = torch.LongTensor(masked_id)
                masked_idx = torch.LongTensor(context_id)
            data.context_idx = context_idx
            data.masked_idx = masked_idx
            # masked ligand atom element/feature/pos.
            data.ligand_masked_element = data.ligand_element[masked_idx]
            # data.ligand_masked_feature = data.ligand_atom_feature[masked_idx] # For Prediction. these features are chem properties
            data.ligand_masked_pos = data.ligand_pos[masked_idx]
            # context ligand atom element/full features/pos. Note: num_neigh and num_valence features should be changed
            data.ligand_context_element = data.ligand_element[context_idx]
            data.ligand_context_feature_full = data.ligand_atom_feature_full[context_idx] # For Input
            data.ligand_context_pos = data.ligand_pos[context_idx]
            # new bond with ligand context atoms
            # Unlike the other maskers there is no special case for an empty bond list here.
            data.ligand_context_bond_index, data.ligand_context_bond_type = subgraph(
                context_idx,
                data.ligand_bond_index,
                edge_attr = data.ligand_bond_type,
                relabel_nodes = True,
            )
            # change context atom features that relate to bonds
            data.ligand_context_num_neighbors = LigandCountNeighbors.count_neighbors(
                data.ligand_context_bond_index,
                symmetry=True,
                num_nodes = context_idx.size(0),
            )
            data.ligand_context_valence = LigandCountNeighbors.count_neighbors(
                data.ligand_context_bond_index,
                symmetry=True,
                valence=data.ligand_context_bond_type,
                num_nodes=context_idx.size(0)
            )
            # One column per bond type 1, 2, 3.
            data.ligand_context_num_bonds = torch.stack([
                LigandCountNeighbors.count_neighbors(
                    data.ligand_context_bond_index,
                    symmetry=True,
                    valence=(data.ligand_context_bond_type == i).long(),
                    num_nodes=context_idx.size(0),
                ) for i in [1, 2, 3]
            ], dim = -1)
            # re-calculate ligand_context_feature_full
            data.ligand_context_feature_full = FeaturizeLigandAtom.change_features_of_neigh(
                data.ligand_context_feature_full,
                data.ligand_context_num_neighbors,
                data.ligand_context_valence,
                data.ligand_context_num_bonds
            )
            # A context atom is on the frontier when it lost at least one neighbor.
            data.ligand_frontier = data.ligand_context_num_neighbors < data.ligand_num_neighbors[context_idx]
            data._mask = self.masker
        except Exception as e:
            # Best-effort fallback: report the failure and mask randomly instead.
            print(e)
            masking = LigandRandomMask(min_ratio=0.0, max_ratio=1.1, min_num_masked=1, min_num_unmasked=0)
            masking(data)
            data._mask = 'frag_decom_random'
        return data
class LigandMixedMask(object):
    """Transform that applies one of random / BFS / inverse-BFS masking per call.

    ``p_random``, ``p_bfs`` and ``p_invbfs`` are the selection weights.
    """

    def __init__(self, min_ratio=0.0, max_ratio=1.2, min_num_masked=1, min_num_unmasked=0, p_random=0.5, p_bfs=0.25, p_invbfs=0.25):
        super().__init__()
        # Every candidate masker shares the same ratio/count limits.
        limits = (min_ratio, max_ratio, min_num_masked, min_num_unmasked)
        self.t = [
            LigandRandomMask(*limits),
            LigandBFSMask(*limits, inverse=False),
            LigandBFSMask(*limits, inverse=True),
        ]
        self.p = [p_random, p_bfs, p_invbfs]

    def __call__(self, data):
        """Draw one masker according to the weights and apply it to ``data``."""
        chosen = random.choices(self.t, weights=self.p, k=1)[0]
        return chosen(data)
class LigandMixedMaskLinker(object):
    """Transform mixing random / BFS / inverse-BFS masking with 'linker' masking.

    ``p_random``, ``p_bfs``, ``p_invbfs`` and ``p_linker`` are the selection
    weights.
    """

    def __init__(self, min_ratio=0.0, max_ratio=1.2, min_num_masked=1, min_num_unmasked=0, p_random=0.4, p_bfs=0.2, p_invbfs=0.2,p_linker=0.2):
        super().__init__()
        # The ratio/count limits apply to the first three maskers only.
        limits = (min_ratio, max_ratio, min_num_masked, min_num_unmasked)
        self.t = [
            LigandRandomMask(*limits),
            LigandBFSMask(*limits, inverse=False),
            LigandBFSMask(*limits, inverse=True),
            LigandMaskFrag(masker='linker'),
        ]
        self.p = [p_random, p_bfs, p_invbfs, p_linker]

    def __call__(self, data):
        """Draw one masker according to the weights and apply it to ``data``."""
        chosen = random.choices(self.t, weights=self.p, k=1)[0]
        return chosen(data)
class LigandMixedMaskFrag(object):
    """Transform mixing random / BFS / inverse-BFS masking with 'frag' masking.

    ``p_random``, ``p_bfs``, ``p_invbfs`` and ``p_fragment`` are the selection
    weights.
    """

    def __init__(self, min_ratio=0.0, max_ratio=1.2, min_num_masked=1, min_num_unmasked=0, p_random=0.3, p_bfs=0.2, p_invbfs=0.2, p_fragment=0.3):
        super().__init__()
        # The ratio/count limits apply to the first three maskers only.
        limits = (min_ratio, max_ratio, min_num_masked, min_num_unmasked)
        self.t = [
            LigandRandomMask(*limits),
            LigandBFSMask(*limits, inverse=False),
            LigandBFSMask(*limits, inverse=True),
            LigandMaskFrag(masker='frag'),
        ]
        self.p = [p_random, p_bfs, p_invbfs, p_fragment]

    def __call__(self, data):
        """Draw one masker according to the weights and apply it to ``data``."""
        chosen = random.choices(self.t, weights=self.p, k=1)[0]
        return chosen(data)
class LigandMixedMaskScaffold(object):
    """Transform mixing random / BFS / inverse-BFS masking with 'scaffold' masking.

    ``p_random``, ``p_bfs``, ``p_invbfs`` and ``p_scaffold`` are the selection
    weights.
    """

    def __init__(self, min_ratio=0.0, max_ratio=1.2, min_num_masked=1, min_num_unmasked=0, p_random=0.3, p_bfs=0.2, p_invbfs=0.2, p_scaffold=0.3):
        super().__init__()
        # The ratio/count limits apply to the first three maskers only.
        limits = (min_ratio, max_ratio, min_num_masked, min_num_unmasked)
        self.t = [
            LigandRandomMask(*limits),
            LigandBFSMask(*limits, inverse=False),
            LigandBFSMask(*limits, inverse=True),
            LigandMaskFrag(masker='scaffold'),
        ]
        self.p = [p_random, p_bfs, p_invbfs, p_scaffold]

    def __call__(self, data):
        """Draw one masker according to the weights and apply it to ``data``."""
        chosen = random.choices(self.t, weights=self.p, k=1)[0]
        return chosen(data)
class LigandMixedMaskSideChain(object):
    """Transform mixing random / BFS / inverse-BFS masking with 'side_chain' masking.

    ``p_random``, ``p_bfs``, ``p_invbfs`` and ``p_side_chain`` are the
    selection weights.
    """

    def __init__(self, min_ratio=0.0, max_ratio=1.2, min_num_masked=1, min_num_unmasked=0, p_random=0.3, p_bfs=0.2, p_invbfs=0.2, p_side_chain=0.3):
        super().__init__()
        # The ratio/count limits apply to the first three maskers only.
        limits = (min_ratio, max_ratio, min_num_masked, min_num_unmasked)
        self.t = [
            LigandRandomMask(*limits),
            LigandBFSMask(*limits, inverse=False),
            LigandBFSMask(*limits, inverse=True),
            LigandMaskFrag(masker='side_chain'),
        ]
        self.p = [p_random, p_bfs, p_invbfs, p_side_chain]

    def __call__(self, data):
        """Draw one masker according to the weights and apply it to ``data``."""
        chosen = random.choices(self.t, weights=self.p, k=1)[0]
        return chosen(data)
def get_mask(cfg):
    """
    Build the masking transform selected by ``cfg.type``.

    Supported types: 'bfs', 'random', 'mixed', 'all', 'linker',
    'fragmentation', 'scaffold', 'side_chain'. Only the config attributes
    needed by the selected type are read.

    Raises:
        NotImplementedError: if ``cfg.type`` is none of the above.
    """
    def ratio_kwargs():
        # Limits shared by every configurable masker.
        return dict(
            min_ratio=cfg.min_ratio,
            max_ratio=cfg.max_ratio,
            min_num_masked=cfg.min_num_masked,
            min_num_unmasked=cfg.min_num_unmasked,
        )

    def mixed_kwargs():
        # Limits plus the three base selection weights of the mixed maskers.
        return dict(
            ratio_kwargs(),
            p_random=cfg.p_random,
            p_bfs=cfg.p_bfs,
            p_invbfs=cfg.p_invbfs,
        )

    mask_type = cfg.type
    if mask_type == 'bfs':
        return LigandBFSMask(**ratio_kwargs())
    if mask_type == 'random':
        return LigandRandomMask(**ratio_kwargs())
    if mask_type == 'mixed':
        return LigandMixedMask(**mixed_kwargs())
    if mask_type == 'all':
        return LigandMaskAll()
    if mask_type == 'linker':
        return LigandMixedMaskLinker(**mixed_kwargs(), p_linker=cfg.p_linker)
    if mask_type == 'fragmentation':
        return LigandMixedMaskFrag(**mixed_kwargs(), p_fragment=cfg.p_fragment)
    if mask_type == 'scaffold':
        return LigandMixedMaskScaffold(**mixed_kwargs(), p_scaffold=cfg.p_scaffold)
    if mask_type == 'side_chain':
        return LigandMixedMaskSideChain(**mixed_kwargs(), p_side_chain=cfg.p_side_chain)
    raise NotImplementedError('Unknown mask: %s' % cfg.type)
class ContrastiveSample(object):
    """Attach positive ("real") and negative ("fake") query positions to a sample.

    Positive queries are noisy copies of masked ligand atoms listed in
    ``data.idx_generated_in_ligand_masked``; negative queries are noisy copies
    of frontier atoms. For the positive queries the transform also copies the
    query-to-context edges produced earlier in the pipeline
    (``data.mask_ctx_edge_*``), builds the pairwise ("triangle") index over the
    edges of each query, and finally connects every query to its nearest
    compose atoms.

    Args:
        num_real: number of positive queries drawn (with replacement).
        num_fake: number of negative queries drawn (with replacement).
        pos_real_std: std of the Gaussian noise added to positive positions.
        pos_fake_std: noise scale for negatives; the applied std is
            ``pos_fake_std / 2``.
        knn: number of compose neighbours fetched per query.
        elements: atomic numbers used as element classes. Defaults to
            ``[6, 7, 8, 9, 15, 16, 17]``.
    """

    def __init__(self, num_real=50, num_fake=50, pos_real_std=0.05, pos_fake_std=2.0, knn=32, elements=None):
        super().__init__()
        self.num_real = num_real
        self.num_fake = num_fake
        self.pos_real_std = pos_real_std
        self.pos_fake_std = pos_fake_std
        self.knn = knn
        if elements is None:
            elements = [6, 7, 8, 9, 15, 16, 17]  # atomic numbers: C, N, O, F, P, S, Cl
        self.elements = torch.LongTensor(elements)

    @property
    def num_elements(self):
        """Number of element classes."""
        return self.elements.size(0)

    def __call__(self, data:ProteinLigandData):
        """Populate ``data`` with the sampled queries and return it.

        Reads ``ligand_masked_pos``, ``ligand_masked_element``,
        ``idx_generated_in_ligand_masked``, ``mask_ctx_edge_*``,
        ``idx_ligand_ctx_in_compose``, ``ligand_context_*``, ``compose_pos``
        and either ``ligand_frontier`` or ``protein_pos``/``y_protein_frontier``.

        Raises:
            AssertionError: if a masked atom's element is not in ``self.elements``.
        """
        # ------------------------------------------------------------------
        # Positive samples
        # ------------------------------------------------------------------
        pos_real_mode = data.ligand_masked_pos
        element_real = data.ligand_masked_element
        # One-hot (bool) element class per masked atom.
        cls_real = data.ligand_masked_element.view(-1, 1) == self.elements.view(1, -1)
        assert (cls_real.sum(-1) > 0).all(), 'Unexpected elements.'

        # Sample uniformly, with replacement, among the atoms flagged as "generated".
        p = np.zeros(len(pos_real_mode), dtype=np.float32)
        p[data.idx_generated_in_ligand_masked] = 1.
        real_sample_idx = np.random.choice(np.arange(pos_real_mode.size(0)), size=self.num_real, p=p/p.sum())
        data.pos_real = pos_real_mode[real_sample_idx]
        data.pos_real += torch.randn_like(data.pos_real) * self.pos_real_std
        data.element_real = element_real[real_sample_idx]
        data.cls_real = cls_real[real_sample_idx]

        # Copy the mask->context edges of every sampled atom, re-indexing the
        # source end to the position of the query inside the sampled batch.
        mask_ctx_edge_index_0 = data.mask_ctx_edge_index_0
        mask_ctx_edge_index_1 = data.mask_ctx_edge_index_1
        mask_ctx_edge_type = data.mask_ctx_edge_type
        real_ctx_edge_idx_0_list, real_ctx_edge_idx_1_list, real_ctx_edge_type_list = [], [], []
        for new_idx, real_node in enumerate(real_sample_idx):
            idx_edge = (mask_ctx_edge_index_0 == real_node)
            real_ctx_edge_idx_1 = mask_ctx_edge_index_1[idx_edge]
            real_ctx_edge_type = mask_ctx_edge_type[idx_edge]
            real_ctx_edge_idx_0 = new_idx * torch.ones(idx_edge.sum(), dtype=torch.long)
            real_ctx_edge_idx_0_list.append(real_ctx_edge_idx_0)
            real_ctx_edge_idx_1_list.append(real_ctx_edge_idx_1)
            real_ctx_edge_type_list.append(real_ctx_edge_type)
        data.real_ctx_edge_index_0 = torch.cat(real_ctx_edge_idx_0_list, dim=-1)
        data.real_ctx_edge_index_1 = torch.cat(real_ctx_edge_idx_1_list, dim=-1)
        data.real_ctx_edge_type = torch.cat(real_ctx_edge_type_list, dim=-1)

        data.real_compose_edge_index_0 = data.real_ctx_edge_index_0
        data.real_compose_edge_index_1 = data.idx_ligand_ctx_in_compose[data.real_ctx_edge_index_1]  # actually are the same
        data.real_compose_edge_type = data.real_ctx_edge_type

        # "Triangle" edges: for each query, every ordered pair of its edges.
        # The edges of one query are contiguous in `row` because the lists
        # above were concatenated in query order, so a running offset suffices.
        row, col = data.real_compose_edge_index_0, data.real_compose_edge_index_1
        acc_num_edges = 0
        index_real_cps_edge_i_list, index_real_cps_edge_j_list = [], []  # index of real-ctx edge (for attention)
        for node in torch.arange(data.pos_real.size(0)):
            num_edges = (row == node).sum()
            index_edge_i = torch.arange(num_edges, dtype=torch.long, ) + acc_num_edges
            # 'ij' is what the deprecated implicit default produced; passing it
            # explicitly avoids the deprecation warning and future breakage.
            index_edge_i, index_edge_j = torch.meshgrid(index_edge_i, index_edge_i, indexing='ij')
            index_edge_i, index_edge_j = index_edge_i.flatten(), index_edge_j.flatten()
            index_real_cps_edge_i_list.append(index_edge_i)
            index_real_cps_edge_j_list.append(index_edge_j)
            acc_num_edges += num_edges
        index_real_cps_edge_i = torch.cat(index_real_cps_edge_i_list, dim=0)  # add len(real_compose_edge_index) in the dataloader for batch
        index_real_cps_edge_j = torch.cat(index_real_cps_edge_j_list, dim=0)

        node_a_cps_tri_edge = col[index_real_cps_edge_i]  # the node of triangle edge for the edge attention (in the compose)
        node_b_cps_tri_edge = col[index_real_cps_edge_j]
        # Context adjacency: -1 on the diagonal, 0 for unbonded pairs, bond type otherwise.
        n_context = len(data.ligand_context_pos)
        adj_mat = torch.zeros([n_context, n_context], dtype=torch.long) - torch.eye(n_context, dtype=torch.long)
        adj_mat[data.ligand_context_bond_index[0], data.ligand_context_bond_index[1]] = data.ligand_context_bond_type
        tri_edge_type = adj_mat[node_a_cps_tri_edge, node_b_cps_tri_edge]
        # One-hot over the five possible adjacency values.
        tri_edge_feat = (tri_edge_type.view([-1, 1]) == torch.tensor([[-1, 0, 1, 2, 3]])).long()

        data.index_real_cps_edge_for_atten = torch.stack([
            index_real_cps_edge_i, index_real_cps_edge_j  # plus len(real_compose_edge_index_0) for dataloader batch
        ], dim=0)
        data.tri_edge_index = torch.stack([
            node_a_cps_tri_edge, node_b_cps_tri_edge  # plus len(compose_pos) for dataloader batch
        ], dim=0)
        data.tri_edge_feat = tri_edge_feat

        # ------------------------------------------------------------------
        # Negative samples
        # ------------------------------------------------------------------
        if len(data.ligand_context_pos) != 0:
            pos_fake_mode = data.ligand_context_pos[data.ligand_frontier]
        else:
            # No ligand context (everything masked): fall back to protein frontier atoms.
            pos_fake_mode = data.protein_pos[data.y_protein_frontier]
        fake_sample_idx = np.random.choice(np.arange(pos_fake_mode.size(0)), size=self.num_fake)
        pos_fake = pos_fake_mode[fake_sample_idx]
        data.pos_fake = pos_fake + torch.randn_like(pos_fake) * self.pos_fake_std / 2.

        # ------------------------------------------------------------------
        # kNN edges from every query to the compose atoms
        # ------------------------------------------------------------------
        real_compose_knn_edge_index = knn(x=data.compose_pos, y=data.pos_real, k=self.knn, num_workers=16)
        data.real_compose_knn_edge_index_0, data.real_compose_knn_edge_index_1 = real_compose_knn_edge_index
        fake_compose_knn_edge_index = knn(x=data.compose_pos, y=data.pos_fake, k=self.knn, num_workers=16)
        data.fake_compose_knn_edge_index_0, data.fake_compose_knn_edge_index_1 = fake_compose_knn_edge_index

        return data
# def get_contrastive_sampler(cfg):
# return ContrastiveSample(
# num_real = cfg.num_real,
# num_fake = cfg.num_fake,
# pos_real_std = cfg.pos_real_std,
# pos_fake_std = cfg.pos_fake_std,
# )
class AtomComposer(object):
    """Stack ligand-context atoms and protein atoms into a single "compose" set.

    Ligand-context atoms come first, protein atoms second. The transform then
    builds a kNN graph over the stacked positions and labels those kNN edges
    that coincide with ligand-context bonds.

    Args:
        protein_dim: feature width of protein atoms.
        ligand_dim: feature width of ligand atoms.
        knn: number of neighbours per compose atom in the kNN graph.
    """

    def __init__(self, protein_dim, ligand_dim, knn):
        super().__init__()
        self.protein_dim = protein_dim
        self.ligand_dim = ligand_dim
        self.knn = knn  # knn of compose atoms

    def __call__(self, data:ProteinLigandData):
        """Write ``compose_pos``, ``compose_feature``, index maps and kNN edges onto ``data``."""
        ctx_pos = data.ligand_context_pos
        ctx_feat = data.ligand_context_feature_full
        prot_pos = data.protein_pos
        prot_feat = data.protein_surf_feature
        n_ctx = len(ctx_pos)
        n_prot = len(prot_pos)
        n_total = n_ctx + n_prot

        # Positions: ligand context first, protein afterwards.
        data.compose_pos = torch.cat([ctx_pos, prot_pos], dim=0)

        # Protein features gain (ligand_dim - protein_dim) zero columns on the
        # right so both feature blocks can be stacked row-wise.
        # NOTE(review): presumably prot_feat has protein_dim columns and
        # ctx_feat has ligand_dim columns -- confirm against the featurizers.
        zero_pad = torch.zeros([n_prot, self.ligand_dim - self.protein_dim], dtype=torch.long)
        prot_feat_padded = torch.cat([prot_feat, zero_pad], dim=1)
        data.compose_feature = torch.cat([ctx_feat, prot_feat_padded], dim=0)

        # Where each group lives inside the compose arrays.
        data.idx_ligand_ctx_in_compose = torch.arange(n_ctx, dtype=torch.long)
        data.idx_protein_in_compose = torch.arange(n_prot, dtype=torch.long) + n_ctx

        # kNN graph plus bond-type labels for its edges.
        return self.get_knn_graph(data, self.knn, n_ctx, n_total, num_workers=16)

    @staticmethod
    def get_knn_graph(data:ProteinLigandData, knn, len_ligand_ctx, len_compose, num_workers=1, ):
        """Build the compose kNN graph and mark edges that are ligand-context bonds.

        Sets ``compose_knn_edge_index``, ``compose_knn_edge_type`` (0 unless the
        edge matches a ligand-context bond, else that bond's type) and
        ``compose_knn_edge_feature`` (4-way one-hot of the edge type).
        """
        edge_index = knn_graph(data.compose_pos, knn, flow='target_to_source', num_workers=num_workers)
        data.compose_knn_edge_index = edge_index
        n_edges = len(edge_index[0])

        # Only the first len_ligand_ctx * knn edges are searched for bonds.
        # Each (src, dst) pair is folded into one integer key for comparison.
        n_searched = len_ligand_ctx * knn
        knn_keys = edge_index[0, :n_searched] * len_compose + edge_index[1, :n_searched]
        bond_index = data.ligand_context_bond_index
        bond_keys = bond_index[0] * len_compose + bond_index[1]

        # For every bond: position of the matching kNN edge, or -1 when absent.
        located = []
        for key in bond_keys:
            hit = torch.nonzero(knn_keys == key)
            located.append(hit.squeeze() if len(hit) > 0 else torch.tensor(-1))
        bond_to_edge = torch.tensor(located, dtype=torch.long)

        found = bond_to_edge >= 0
        matched_edges = bond_to_edge[found]
        matched_types = data.ligand_context_bond_type[found]

        # Edge type used for the encoder edge embedding.
        edge_type = torch.zeros(n_edges, dtype=torch.long)
        edge_type[matched_edges] = matched_types
        data.compose_knn_edge_type = edge_type

        # One-hot edge feature: column 0 set by default, overwritten for bonds.
        edge_feature = torch.zeros([n_edges, 4], dtype=torch.long)
        edge_feature[:, 0] = 1
        edge_feature[matched_edges] = F.one_hot(matched_types, num_classes=4)  # 0 (1,2,3)-onehot
        data.compose_knn_edge_feature = edge_feature
        return data
class FocalBuilder(object):
    """Identify focal atoms and the atoms to be generated next to them.

    With ligand context present, focal atoms are the context ends of bonds
    that connect a masked atom to a context atom ("bridge" bonds), and the
    generated atoms are the masked ends. Without any context, focal atoms are
    protein atoms near the masked ligand atoms.

    Args:
        close_threshold: stored on the instance; not read by ``__call__``.
        max_bond_length: stored on the instance; not read by ``__call__``.
    """

    def __init__(self, close_threshold=0.8, max_bond_length=2.4):
        self.close_threshold = close_threshold
        self.max_bond_length = max_bond_length
        super().__init__()

    def __call__(self, data:ProteinLigandData):
        """Set ``idx_focal_in_compose``, ``pos_generate``,
        ``idx_generated_in_ligand_masked``, ``idx_protein_all_mask`` and
        ``y_protein_frontier`` on ``data`` and return it.
        """
        ligand_masked_pos = data.ligand_masked_pos
        protein_pos = data.protein_pos
        context_idx = data.context_idx
        masked_idx = data.masked_idx
        old_bond_index = data.ligand_bond_index

        has_unmask_atoms = context_idx.nelement() > 0
        if has_unmask_atoms:
            # Bridge bonds: row 0 is a masked atom and row 1 is a context atom.
            # Broadcast comparison gives the same mask as testing each bond
            # with `in`, without a Python-level loop over bonds.
            src_is_masked = (old_bond_index[0].unsqueeze(-1) == masked_idx.unsqueeze(0)).any(dim=-1)
            dst_is_context = (old_bond_index[1].unsqueeze(-1) == context_idx.unsqueeze(0)).any(dim=-1)
            bridge_bond_index = old_bond_index[:, src_is_masked & dst_is_context]
            idx_generated_in_whole_ligand = bridge_bond_index[0]
            idx_focal_in_whole_ligand = bridge_bond_index[1]

            # Lookup table: whole-ligand atom index -> index within the masked subset.
            index_changer_masked = torch.zeros(masked_idx.max()+1, dtype=torch.int64)
            index_changer_masked[masked_idx] = torch.arange(len(masked_idx))
            idx_generated_in_ligand_masked = index_changer_masked[idx_generated_in_whole_ligand]
            pos_generate = ligand_masked_pos[idx_generated_in_ligand_masked]

            data.idx_generated_in_ligand_masked = idx_generated_in_ligand_masked
            data.pos_generate = pos_generate

            # Lookup table: whole-ligand atom index -> index within the context subset.
            index_changer_context = torch.zeros(context_idx.max()+1, dtype=torch.int64)
            index_changer_context[context_idx] = torch.arange(len(context_idx))
            idx_focal_in_ligand_context = index_changer_context[idx_focal_in_whole_ligand]
            # Valid only because ligand-context atoms precede protein atoms in the compose.
            idx_focal_in_compose = idx_focal_in_ligand_context
            data.idx_focal_in_compose = idx_focal_in_compose

            data.idx_protein_all_mask = torch.empty(0, dtype=torch.long)  # no use if has context
            data.y_protein_frontier = torch.empty(0, dtype=torch.bool)  # no use if has context

        else:
            # Initial atom: pair protein atoms with masked ligand atoms within 4.0.
            assign_index = radius(x=ligand_masked_pos, y=protein_pos, r=4., num_workers=16)
            if assign_index.size(1) == 0:
                # Nothing within the radius: fall back to the single closest
                # (protein, ligand) pair.
                dist = torch.norm(data.protein_pos.unsqueeze(1) - data.ligand_masked_pos.unsqueeze(0), p=2, dim=-1)
                assign_index = torch.nonzero(dist <= torch.min(dist)+1e-5)[0:1].transpose(0, 1)
            idx_focal_in_protein = assign_index[0]
            data.idx_focal_in_compose = idx_focal_in_protein  # no ligand context, so all composes are protein atoms
            data.pos_generate = ligand_masked_pos[assign_index[1]]
            data.idx_generated_in_ligand_masked = torch.unique(assign_index[1])  # for real of the contrastive transform

            data.idx_protein_all_mask = data.idx_protein_in_compose  # for input of initial frontier prediction
            y_protein_frontier = torch.zeros_like(data.idx_protein_all_mask, dtype=torch.bool)  # for label of initial frontier prediction
            y_protein_frontier[torch.unique(idx_focal_in_protein)] = True
            data.y_protein_frontier = y_protein_frontier

        return data
class EdgeSample(object):
    """Sample labelled edges between masked ligand atoms and context atoms.

    Positive edges are the real bonds that join a masked atom to a context
    atom; negative edges (type 0) are short masked-to-context pairs taken from
    a kNN search that are not real bonds.

    Args:
        cfg: config object; only ``cfg.k`` (neighbours per masked atom) is read.
        num_bond_types: stored on the instance; not read by ``__call__``.
    """

    def __init__(self, cfg, num_bond_types=3):
        super().__init__()
        # self.neg_pos_ratio = cfg.neg_pos_ratio
        self.k = cfg.k
        # self.r = cfg.r  # only needed by the disabled radius-based sampler
        self.num_bond_types = num_bond_types

    @staticmethod
    def _isin(values, pool):
        """Return a bool mask: which entries of 1-D ``values`` occur in 1-D ``pool``.

        Always yields a bool tensor, including for empty inputs.
        """
        return (values.unsqueeze(-1) == pool.unsqueeze(0)).any(dim=-1)

    def _knn_negative_edges(self, ligand_context_pos, ligand_masked_pos, pos_edge_index, n_context):
        """Return kNN-based (masked, context) candidate edges that are not positive edges."""
        edge_index_knn = knn(ligand_context_pos, ligand_masked_pos, k=self.k, num_workers=16)
        dist = torch.norm(ligand_masked_pos[edge_index_knn[0]] - ligand_context_pos[edge_index_knn[1]], p=2, dim=-1)
        idx_sort = torch.argsort(dist)  # choose negative edges as short as possible
        num_neg_edges = min(len(ligand_masked_pos) * (self.k // 2) + len(pos_edge_index[0]), len(idx_sort))
        # Keep the shortest edges plus evenly spaced entries of the kNN edge
        # list. The even spacing is intended to keep a candidate for every
        # masked atom.
        # NOTE(review): that assumes knn returns equally many edges per masked
        # atom, grouped by atom -- confirm.
        idx_keep = torch.unique(
            torch.cat([
                idx_sort[:num_neg_edges],
                torch.linspace(0, len(idx_sort), len(ligand_masked_pos)+1, dtype=torch.long)[:-1]
            ], dim=0)
        )
        edge_index_knn = edge_index_knn[:, idx_keep]
        # Fold each (masked, context) pair into one integer id, then drop the
        # candidates that are real bonds (false negatives).
        id_edge_knn = edge_index_knn[0] * n_context + edge_index_knn[1]
        id_edge_pos = pos_edge_index[0] * n_context + pos_edge_index[1]
        is_real_edge = self._isin(id_edge_knn, id_edge_pos)
        return edge_index_knn[:, ~is_real_edge]

    def _radius_negative_edges(self, data, ligand_context_pos, ligand_masked_pos, pos_edge_index, n_context):
        """Alternative negative sampler: radius neighbours plus ring-closing candidates.

        Currently disabled (``neg_version`` is fixed to 0 in ``__call__``). It
        reads ``self.r``, which ``__init__`` does not set, so that attribute
        must be restored before this path is enabled.
        """
        id_edge_pos = pos_edge_index[0] * n_context + pos_edge_index[1]
        # 1. all masked-context pairs within the radius
        edge_index_radius = radius(ligand_context_pos, ligand_masked_pos, r=self.r, num_workers=16)  # r = 3
        id_edge_radius = edge_index_radius[0] * n_context + edge_index_radius[1]
        not_pos_in_radius = ~self._isin(id_edge_radius, id_edge_pos)
        # 2. keep true negatives, randomly thinned with a distance-dependent probability
        if not_pos_in_radius.size(0) > 0:
            edge_index_neg = edge_index_radius[:, not_pos_in_radius]
            dist = torch.norm(ligand_masked_pos[edge_index_neg[0]] - ligand_context_pos[edge_index_neg[1]], p=2, dim=-1)
            probs = torch.clip(0.8 * (dist ** 2) - 4.8 * dist + 7.3 + 0.4, min=0.5, max=0.95)
            values = torch.rand(len(dist))
            choice = values < probs
            edge_index_neg = edge_index_neg[:, choice]
        else:
            edge_index_neg = torch.empty([2, 0], dtype=torch.long)
        # 3. pairs (i, k) where k is bonded to a context atom j that i bonds to
        bond_index_ctx = data.ligand_context_bond_index
        ring_src, ring_dst = [], []
        for node_i, node_j in zip(*pos_edge_index):
            node_k_all = bond_index_ctx[1, bond_index_ctx[0] == node_j]
            ring_src.append(torch.ones_like(node_k_all) * node_i)
            ring_dst.append(node_k_all)
        ring_src = torch.cat(ring_src, dim=0)
        ring_dst = torch.cat(ring_dst, dim=0)
        id_ring_candidate = ring_src * n_context + ring_dst
        edge_index_ring_candidate = torch.stack([ring_src, ring_dst], dim=0)
        not_pos_in_ring = ~self._isin(id_ring_candidate, id_edge_pos)
        if not_pos_in_ring.size(0) > 0:
            edge_index_ring = edge_index_ring_candidate[:, not_pos_in_ring]
            dist = torch.norm(ligand_masked_pos[edge_index_ring[0]] - ligand_context_pos[edge_index_ring[1]], p=2, dim=-1)
            edge_index_ring = edge_index_ring[:, dist < 4.0]
        else:
            edge_index_ring = torch.empty([2, 0], dtype=torch.long)
        # 4. concatenate both kinds of negatives
        return torch.cat([edge_index_neg, edge_index_ring], dim=-1)

    def __call__(self, data:ProteinLigandData):
        """Set ``mask_ctx_edge_*`` and ``mask_compose_edge_*`` on ``data`` and return it.

        Edge index 0 refers to atoms within the masked subset, index 1 to atoms
        within the context subset. With no context atoms all outputs are empty.
        """
        ligand_context_pos = data.ligand_context_pos
        ligand_masked_pos = data.ligand_masked_pos
        context_idx = data.context_idx
        masked_idx = data.masked_idx
        old_bond_index = data.ligand_bond_index
        old_bond_types = data.ligand_bond_type

        # Candidate edges: bonds whose row-0 atom is masked and row-1 atom is context.
        is_candidate = self._isin(old_bond_index[0], masked_idx) & self._isin(old_bond_index[1], context_idx)
        candidate_bond_index = old_bond_index[:, is_candidate]
        candidate_bond_types = old_bond_types[is_candidate]

        # Lookup table: whole-ligand atom index -> index within the masked subset.
        index_changer_masked = torch.zeros(masked_idx.max()+1, dtype=torch.int64)
        index_changer_masked[masked_idx] = torch.arange(len(masked_idx))

        has_unmask_atoms = context_idx.nelement() > 0
        if has_unmask_atoms:
            # Lookup table: whole-ligand atom index -> index within the context subset.
            index_changer_context = torch.zeros(context_idx.max()+1, dtype=torch.int64)
            index_changer_context[context_idx] = torch.arange(len(context_idx))

            # Positive edges, re-indexed into the masked / context subsets.
            new_edge_index_0 = index_changer_masked[candidate_bond_index[0]]
            new_edge_index_1 = index_changer_context[candidate_bond_index[1]]
            new_edge_index = torch.stack([new_edge_index_0, new_edge_index_1])
            new_edge_type = candidate_bond_types

            neg_version = 0  # 0: kNN negatives (active); 1: radius + ring negatives (disabled)
            if neg_version == 1:
                false_edge_index = self._radius_negative_edges(
                    data, ligand_context_pos, ligand_masked_pos, new_edge_index, len(context_idx))
            elif neg_version == 0:
                false_edge_index = self._knn_negative_edges(
                    ligand_context_pos, ligand_masked_pos, new_edge_index, len(context_idx))
            false_edge_types = torch.zeros(len(false_edge_index[0]), dtype=torch.int64)

            # Positives first, then negatives.
            new_edge_index = torch.cat([new_edge_index, false_edge_index], dim=-1)
            new_edge_type = torch.cat([new_edge_type, false_edge_types], dim=0)

            data.mask_ctx_edge_index_0 = new_edge_index[0]
            data.mask_ctx_edge_index_1 = new_edge_index[1]
            data.mask_ctx_edge_type = new_edge_type
            data.mask_compose_edge_index_0 = data.mask_ctx_edge_index_0
            data.mask_compose_edge_index_1 = data.idx_ligand_ctx_in_compose[data.mask_ctx_edge_index_1]  # actually are the same
            data.mask_compose_edge_type = new_edge_type
        else:
            data.mask_ctx_edge_index_0 = torch.empty([0], dtype=torch.int64)
            data.mask_ctx_edge_index_1 = torch.empty([0], dtype=torch.int64)
            data.mask_ctx_edge_type = torch.empty([0], dtype=torch.int64)
            data.mask_compose_edge_index_0 = torch.empty([0], dtype=torch.int64)
            data.mask_compose_edge_index_1 = torch.empty([0], dtype=torch.int64)
            data.mask_compose_edge_type = torch.empty([0], dtype=torch.int64)

        return data