mirror of
https://github.com/rdkit/rdkit.git
synced 2026-06-03 21:44:30 +08:00
[CONTRIB]: Freewilson now keeps the coordinates passed in (#8868)
* Freewilson now keeps the coordinates passed in * Add better 3D reconstruction test
This commit is contained in:
committed by
greg landrum
parent
037b2b227c
commit
1b94273ba9
2867
Contrib/FreeWilson/data/cmet_ligands.sdf
Normal file
2867
Contrib/FreeWilson/data/cmet_ligands.sdf
Normal file
File diff suppressed because it is too large
Load Diff
@@ -151,7 +151,7 @@ from rdkit.Chem import rdRGroupDecomposition as rgd
|
||||
|
||||
logger = logging.getLogger("freewilson")
|
||||
|
||||
FreeWilsonPrediction = namedtuple("FreeWilsonPrediction", ['prediction', 'smiles', 'rgroups'])
|
||||
FreeWilsonPrediction = namedtuple("FreeWilsonPrediction", ['prediction', 'smiles', 'rgroups', 'mol', 'is_training'])
|
||||
|
||||
# match dummy atoms in a smiles string to extract atom maps
|
||||
dummypat = re.compile(r"\*:([0-9]+)")
|
||||
@@ -159,6 +159,7 @@ dummypat = re.compile(r"\*:([0-9]+)")
|
||||
|
||||
# molzip doesn't handle some of the forms that the RGroupDecomposition
|
||||
# returns, this solves these issues.
|
||||
|
||||
def molzip_smi(smiles):
|
||||
"""Fix a rgroup smiles for molzip, note that the core MUST come first
|
||||
in the smiles string, ala core.rgroup1.rgroup2 ...
|
||||
@@ -209,6 +210,59 @@ def molzip_smi(smiles):
|
||||
m.AddBond(oatom.GetIdx(), xatom.GetIdx(), Chem.BondType.SINGLE)
|
||||
return Chem.molzip(m)
|
||||
|
||||
def molzip_mols(mols):
    """Zip a core and its rgroups, given as molecules, into a single molecule.

    Molecule-based counterpart of ``molzip_smi``: the fragments are merged
    with ``RWMol.InsertMol`` instead of being joined and re-parsed as a
    SMILES string.
    NOTE(review): per the commit message this is meant to keep the
    coordinates passed in -- confirm that InsertMol carries the conformers.

    The core MUST come first, ala [core, rgroup1, rgroup2, ...]: the first
    fragment of the combined molecule is treated as the core.

    :param mols: sequence of molecules, core first; attachment points are
                 atoms carrying a non-zero atom map number
    :return: the result of ``Chem.molzip`` on the combined molecule, or an
             empty ``Chem.RWMol`` when ``mols`` is empty
    :raises ValueError: if the core has no atom-mapped atoms
    """
    if not mols:
        return Chem.RWMol()

    m = Chem.RWMol(mols[0])

    # An rgroup holding dummy atoms may be listed more than once (compared
    # by SMILES); only the first copy is inserted.
    seen = set()
    for mol in mols[1:]:
        s = Chem.MolToSmiles(mol)
        if "*" in s:
            if s in seen:
                continue
            seen.add(s)
        m.InsertMol(mol)

    frags = Chem.GetMolFrags(m)
    core = frags[0]

    # atom map number -> index of the core atom carrying it, and how many
    # times each atom map number has been seen so far
    atommaps = {}
    counts = defaultdict(int)
    for idx in core:
        atommap = m.GetAtomWithIdx(idx).GetAtomMapNum()
        if atommap:
            atommaps[atommap] = idx
            counts[atommap] += 1

    if not atommaps:
        # max() on an empty dict raises a ValueError as well, but with a
        # message that says nothing about the actual problem
        raise ValueError("molzip_mols: the core (first molecule) has no atom mapped atoms")

    next_atommap = max(atommaps) + 1

    # molzip pairs atom maps one-to-one, so the third and later uses of a map
    # number are moved to a fresh number, remembering which core atom the
    # original number belonged to
    add_atommap = []
    for fragment in frags[1:]:
        for idx in fragment:
            atommap = m.GetAtomWithIdx(idx).GetAtomMapNum()
            if atommap:
                count = counts[atommap] = counts[atommap] + 1
                if count > 2:
                    m.GetAtomWithIdx(idx).SetAtomMapNum(next_atommap)
                    add_atommap.append((atommaps[atommap], next_atommap))
                    next_atommap += 1

    # give the core a matching extra dummy atom for every renumbered map,
    # bonded to the same neighbor as the original attachment point
    for atomidx, atommap in add_atommap:
        atom = m.GetAtomWithIdx(atomidx)
        bonds = list(atom.GetBonds())
        if len(bonds) == 1:
            oatom = bonds[0].GetOtherAtom(atom)
            xatom = Chem.Atom(0)
            idx = m.AddAtom(xatom)
            xatom = m.GetAtomWithIdx(idx)
            xatom.SetAtomMapNum(atommap)
            m.AddBond(oatom.GetIdx(), xatom.GetIdx(), Chem.BondType.SINGLE)

    return Chem.molzip(m)
|
||||
|
||||
|
||||
class RGroup:
|
||||
"""FreeWilson RGroup
|
||||
@@ -219,13 +273,14 @@ class RGroup:
|
||||
idx - one-hot encoding for the rgroup
|
||||
"""
|
||||
|
||||
def __init__(self, smiles, rgroup, count, coefficient, idx=None):
|
||||
def __init__(self, smiles, rgroup, count, coefficient, idx=None, mol=None):
|
||||
self.smiles = smiles # smiles for the sidechain (n.b. can be a core as well)
|
||||
self.rgroup = rgroup # rgroup Core, R1, R2,...
|
||||
self.count = count # num molecules with this rgroup
|
||||
self.coefficient = coefficient # ridge coefficient
|
||||
self.idx = idx # descriptor index
|
||||
self.dummies = tuple([int(x) for x in sorted(dummypat.findall(smiles))])
|
||||
self.mol = mol
|
||||
|
||||
# Assemble some additive properties
|
||||
|
||||
@@ -302,7 +357,11 @@ default_decomp_params.scoreMethod = rgd.RGroupScore.FingerprintVariance
|
||||
# we need to keep hydrogens so molzip will work
|
||||
default_decomp_params.removeHydrogensPostMatch = False
|
||||
|
||||
|
||||
class DecompEntry:
    """One entry of a decomposition row: a molecule together with its SMILES.

    mol - the molecule handed to the constructor, stored as-is
    smiles - Chem.MolToSmiles(mol), computed once at construction
    """

    def __init__(self, mol):
        smiles = Chem.MolToSmiles(mol)
        self.mol = mol
        self.smiles = smiles
|
||||
|
||||
def FWDecompose(scaffolds, mols, scores,
|
||||
decomp_params=default_decomp_params) -> FreeWilsonDecomposition:
|
||||
"""
|
||||
@@ -366,20 +425,24 @@ def FWDecompose(scaffolds, mols, scores,
|
||||
logger.error("No scaffolds matched the input molecules")
|
||||
return
|
||||
|
||||
decomposition = decomposer.GetRGroupsAsRows(asSmiles=True)
|
||||
|
||||
#d = decomposer.GetRGroupsAsRows(asSmiles=True)
|
||||
d = decomposer.GetRGroupsAsRows()
|
||||
|
||||
decomposition = [ {rg: DecompEntry(m) for rg, m in row.items()} for row in d]
|
||||
|
||||
logger.info("Get unique rgroups...")
|
||||
blocker = rdBase.BlockLogs()
|
||||
rgroup_counts = defaultdict(int)
|
||||
num_reconstructed = 0
|
||||
|
||||
for num_mols, (row, idx) in enumerate(zip(decomposition, matched_indices)):
|
||||
row_smiles = []
|
||||
for rgroup, smiles in row.items():
|
||||
row_smiles.append(smiles)
|
||||
rgroup_counts[smiles] += 1
|
||||
if smiles not in rgroup_idx:
|
||||
rgroup_idx[smiles] = len(rgroup_idx)
|
||||
rgroups[rgroup].append(RGroup(smiles, rgroup, 0, 0))
|
||||
for rgroup, de in row.items():
|
||||
row_smiles.append(de.smiles)
|
||||
rgroup_counts[de.smiles] += 1
|
||||
if de.smiles not in rgroup_idx:
|
||||
rgroup_idx[de.smiles] = len(rgroup_idx)
|
||||
rgroups[rgroup].append(RGroup(de.smiles, rgroup, 0, 0, mol=de.mol))
|
||||
row['original_idx'] = idx
|
||||
reconstructed = ".".join(row_smiles)
|
||||
try:
|
||||
@@ -387,7 +450,8 @@ def FWDecompose(scaffolds, mols, scores,
|
||||
mol = molzip_smi(reconstructed)
|
||||
num_reconstructed += 1
|
||||
except:
|
||||
print("failed:", Chem.MolToSmiles(matched[num_mols]), reconstructed)
|
||||
logging.error("failed reconstructing %s, %s", Chem.MolToSmiles(matched[num_mols]),
|
||||
reconstructed)
|
||||
|
||||
logger.info(f"Descriptor size {len(rgroup_idx)}")
|
||||
logger.info(f"Reconstructed {num_reconstructed} out of {num_mols}")
|
||||
@@ -401,9 +465,10 @@ def FWDecompose(scaffolds, mols, scores,
|
||||
row['molecule'] = mol
|
||||
descriptor = [0] * len(rgroup_idx)
|
||||
descriptors.append(descriptor)
|
||||
for smiles in row.values():
|
||||
if smiles in rgroup_idx:
|
||||
descriptor[rgroup_idx[smiles]] = 1
|
||||
for k, de in row.items():
|
||||
if k == "original_idx" or k == 'molecule': continue
|
||||
if de.smiles in rgroup_idx:
|
||||
descriptor[rgroup_idx[de.smiles]] = 1
|
||||
|
||||
assert len(descriptors) == len(
|
||||
matched_scores
|
||||
@@ -428,7 +493,8 @@ def FWDecompose(scaffolds, mols, scores,
|
||||
num_reconstructed)
|
||||
|
||||
|
||||
def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, mol_filter=None):
|
||||
def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, mol_filter=None,
|
||||
keep_training_set=False):
|
||||
N = fw.N
|
||||
fitter = fw.fitter
|
||||
num_products = 1
|
||||
@@ -450,6 +516,7 @@ def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, m
|
||||
max_pred = -1e10
|
||||
min_pred = 1e10
|
||||
delta = num_products // 10 or 1
|
||||
|
||||
for i, groups in tqdm(enumerate(itertools.product(*rgroups)), total=num_products):
|
||||
if i and i % delta == 0:
|
||||
logging.debug(
|
||||
@@ -466,7 +533,11 @@ def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, m
|
||||
|
||||
if tuple(descriptors) in fw.descriptors:
|
||||
in_training_set += 1
|
||||
continue
|
||||
is_training = True
|
||||
if not keep_training_set:
|
||||
continue
|
||||
else:
|
||||
is_training = False
|
||||
|
||||
min_mw = min(min_mw, mw)
|
||||
max_mw = max(max_mw, mw)
|
||||
@@ -487,13 +558,22 @@ def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, m
|
||||
rejected_pred += 1
|
||||
continue
|
||||
good_pred += 1
|
||||
smiles = set([g.smiles for g in groups]) # remove dupes
|
||||
smi = ".".join(set([g.smiles for g in groups]))
|
||||
try:
|
||||
mol = molzip_smi(smi)
|
||||
except:
|
||||
rejected_bad += 1
|
||||
continue
|
||||
|
||||
mols = [g.mol for g in groups]
|
||||
if None in mols:
|
||||
smiles = set([g.smiles for g in groups]) # remove dupes
|
||||
smi = ".".join(set([g.smiles for g in groups]))
|
||||
try:
|
||||
mol = molzip_smi(smi)
|
||||
except:
|
||||
rejected_bad += 1
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
mol = molzip_mols(mols)
|
||||
except:
|
||||
rejected_bad += 1
|
||||
continue
|
||||
|
||||
rejected = False
|
||||
if mol_filter and not mol_filter(mol):
|
||||
@@ -501,7 +581,7 @@ def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, m
|
||||
continue
|
||||
|
||||
out_smi = Chem.MolToSmiles(mol)
|
||||
yield FreeWilsonPrediction(pred, out_smi, groups)
|
||||
yield FreeWilsonPrediction(pred, out_smi, groups, mol, is_training)
|
||||
wrote += 1
|
||||
logging.info(
|
||||
f"Wrote {wrote} results out of {num_products}\n\tIn Training set: {in_training_set}\n\tBad MW: {rejected_mw}\n\tBad Pred: {rejected_pred}\n\tBad Filters: {rejected_filters}\n\tBad smi: {rejected_bad}\n\tmin mw: {min_mw}\n\tmax mw: {max_mw}\n\tBad HVY: {rejected_hvy}\n\tBad Pred: {rejected_pred}\n\tBad Filters: {rejected_filters}\n\tBad smi: {rejected_bad}\n\tmin mw: {min_mw}\n\tmax mw: {max_mw}\n\tmin hvy: {min_hvy}\n\tmax hvy: {max_hvy}\n\t\n\tmin pred: {min_pred}\n\tmax pred: {max_pred}"
|
||||
@@ -509,7 +589,7 @@ def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, m
|
||||
|
||||
|
||||
def FWBuild(fw: FreeWilsonDecomposition, pred_filter=None, mw_filter=None, hvy_filter=None,
|
||||
mol_filter=None) -> Generator[FreeWilsonPrediction, None, None]:
|
||||
mol_filter=None, keep_training_set=False) -> Generator[FreeWilsonPrediction, None, None]:
|
||||
"""Enumerate the freewilson decomposition and return their predictions
|
||||
|
||||
:param fw: FreeWilsonDecomposition generated from FWDecompose
|
||||
@@ -551,7 +631,8 @@ def FWBuild(fw: FreeWilsonDecomposition, pred_filter=None, mw_filter=None, hvy_f
|
||||
rgroups = [rgroup for key, rgroup in sorted(rgroups_no_cycles.items())]
|
||||
# core is always first
|
||||
for res in _enumerate(rgroups, fw, pred_filter=pred_filter, mw_filter=mw_filter,
|
||||
hvy_filter=hvy_filter, mol_filter=mol_filter):
|
||||
hvy_filter=hvy_filter, mol_filter=mol_filter,
|
||||
keep_training_set=keep_training_set):
|
||||
yield res
|
||||
|
||||
# iterate on rgroups with cycles
|
||||
|
||||
@@ -2,10 +2,20 @@ import csv
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import freewilson as fw
|
||||
try:
|
||||
import freewilson as fw
|
||||
except:
|
||||
path = os.path.abspath(os.curdir)
|
||||
sys.path.insert(0, os.path.join(path, ".."))
|
||||
import freewilson as fw
|
||||
|
||||
import importlib
|
||||
fw = importlib.reload(fw)
|
||||
|
||||
from rdkit import Chem, rdBase
|
||||
from rdkit.Chem import rdFMCS, rdMolAlign
|
||||
|
||||
PATH = os.path.join(os.path.dirname(fw.__file__), 'data')
|
||||
assert os.path.exists(PATH), PATH
|
||||
@@ -33,7 +43,7 @@ def test_chembl():
|
||||
with rdBase.BlockLogs():
|
||||
free = fw.FWDecompose(scaffold, mols, scores)
|
||||
# let's make sure the r squared is decent
|
||||
assert free.r2 > 0.8
|
||||
assert free.r2 > 0.8, str(free.r2)
|
||||
|
||||
# assert we get something
|
||||
preds = list(fw.FWBuild(free))
|
||||
@@ -68,3 +78,32 @@ def test_multicore():
|
||||
decomp = fw.FWDecompose(scaffolds, mols, [1, 2, 3, 4, 5, 6])
|
||||
s = io.StringIO()
|
||||
fw.predictions_to_csv(s, decomp, fw.FWBuild(decomp))
|
||||
|
||||
def test_fep_benchmark_3D():
    """Training ligands rebuilt by FWBuild should align with the input ligands.

    Decomposes the ligands in cmet_ligands.sdf on their MCS, enumerates with
    keep_training_set=True and, for every training-set prediction whose
    SMILES matches an input ligand, checks the best RMS between the two.
    """
    benchmark = os.path.join(PATH, "cmet_ligands.sdf")
    # The supplier yields None for records it cannot parse; drop them once so
    # the MCS search, the scores and the decomposition all see the same list.
    mols = [m for m in Chem.SDMolSupplier(benchmark, removeHs=False) if m is not None]
    smis = {Chem.MolToSmiles(m): m for m in mols}
    assert smis
    match = rdFMCS.FindMCS(mols)

    scores = [float(m.GetProp("r_exp_dg")) for m in mols]
    with rdBase.BlockLogs():
        free = fw.FWDecompose(match.queryMol, mols, scores)
    preds2 = list(fw.FWBuild(free, keep_training_set=True))
    assert preds2

    # we use an rms of 0.7 because we are assembling rgroups from different
    # ligands together which may not exactly align
    for p in preds2:
        if not p.is_training:
            continue

        smi = Chem.MolToSmiles(p.mol)
        if smi not in smis:
            continue
        rms = rdMolAlign.GetBestRMS(smis[smi], p.mol)
        assert rms < 0.7, f"{smi}: {rms}"
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user