[CONTRIB]: Freewilson now keeps the coordinates passed in (#8868)

* Freewilson now keeps the coordinates passed in

* Add better 3D reconstruction test
This commit is contained in:
Brian Kelley
2025-11-01 01:28:52 -04:00
committed by GitHub
parent b9b6078137
commit c756aff4f9
3 changed files with 3016 additions and 29 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -151,7 +151,7 @@ from rdkit.Chem import rdRGroupDecomposition as rgd
logger = logging.getLogger("freewilson")
FreeWilsonPrediction = namedtuple("FreeWilsonPrediction", ['prediction', 'smiles', 'rgroups'])
FreeWilsonPrediction = namedtuple("FreeWilsonPrediction", ['prediction', 'smiles', 'rgroups', 'mol', 'is_training'])
# match dummy atoms in a smiles string to extract atom maps
dummypat = re.compile(r"\*:([0-9]+)")
@@ -159,6 +159,7 @@ dummypat = re.compile(r"\*:([0-9]+)")
# molzip doesn't handle some of the forms that the RGroupDecomposition
# returns, this solves these issues.
def molzip_smi(smiles):
"""Fix a rgroup smiles for molzip, note that the core MUST come first
in the smiles string, ala core.rgroup1.rgroup2 ...
@@ -209,6 +210,59 @@ def molzip_smi(smiles):
m.AddBond(oatom.GetIdx(), xatom.GetIdx(), Chem.BondType.SINGLE)
return Chem.molzip(m)
def molzip_mols(mols):
"""Fix a rgroup smiles for molzip, note that the core MUST come first
in the smiles string, ala core.rgroup1.rgroup2 ...
"""
if not mols: return Chem.RWMol()
m = Chem.RWMol(mols[0])
dupes = set()
for mol in mols[1:]:
s = Chem.MolToSmiles(mol)
if s.count("*") >= 1:
if s in dupes:
continue
else:
dupes.add(s)
m.InsertMol(mol)
frags = Chem.GetMolFrags(m)
core = frags[0]
atommaps = {}
counts = defaultdict(int)
for idx in core:
atommap = m.GetAtomWithIdx(idx).GetAtomMapNum()
if atommap:
atommaps[atommap] = idx
counts[atommap] += 1
next_atommap = max(atommaps) + 1
add_atommap = []
for fragment in frags[1:]:
for idx in fragment:
atommap = m.GetAtomWithIdx(idx).GetAtomMapNum()
if atommap:
count = counts[atommap] = counts[atommap] + 1
if count > 2:
m.GetAtomWithIdx(idx).SetAtomMapNum(next_atommap)
add_atommap.append((atommaps[atommap], next_atommap))
next_atommap += 1
for atomidx, atommap in add_atommap:
atom = m.GetAtomWithIdx(atomidx)
bonds = list(atom.GetBonds())
if len(bonds) == 1:
oatom = bonds[0].GetOtherAtom(atom)
xatom = Chem.Atom(0)
idx = m.AddAtom(xatom)
xatom = m.GetAtomWithIdx(idx)
xatom.SetAtomMapNum(atommap)
m.AddBond(oatom.GetIdx(), xatom.GetIdx(), Chem.BondType.SINGLE)
return Chem.molzip(m)
class RGroup:
"""FreeWilson RGroup
@@ -219,13 +273,14 @@ class RGroup:
idx - one-hot encoding for the rgroup
"""
def __init__(self, smiles, rgroup, count, coefficient, idx=None):
def __init__(self, smiles, rgroup, count, coefficient, idx=None, mol=None):
self.smiles = smiles # smiles for the sidechain (n.b. can be a core as well)
self.rgroup = rgroup # rgroup Core, R1, R2,...
self.count = count # num molecules with this rgruop
self.coefficient = coefficient # ridge coefficient
self.idx = idx # descriptor index
self.dummies = tuple([int(x) for x in sorted(dummypat.findall(smiles))])
self.mol = mol
# Assemble some additive properties
@@ -302,7 +357,11 @@ default_decomp_params.scoreMethod = rgd.RGroupScore.FingerprintVariance
# we need to keep hydrogens so molzip will work
default_decomp_params.removeHydrogensPostMatch = False
class DecompEntry:
def __init__(self, mol):
self.mol = mol
self.smiles = Chem.MolToSmiles(mol)
def FWDecompose(scaffolds, mols, scores,
decomp_params=default_decomp_params) -> FreeWilsonDecomposition:
"""
@@ -366,20 +425,24 @@ def FWDecompose(scaffolds, mols, scores,
logger.error("No scaffolds matched the input molecules")
return
decomposition = decomposer.GetRGroupsAsRows(asSmiles=True)
#d = decomposer.GetRGroupsAsRows(asSmiles=True)
d = decomposer.GetRGroupsAsRows()
decomposition = [ {rg: DecompEntry(m) for rg, m in row.items()} for row in d]
logger.info("Get unique rgroups...")
blocker = rdBase.BlockLogs()
rgroup_counts = defaultdict(int)
num_reconstructed = 0
for num_mols, (row, idx) in enumerate(zip(decomposition, matched_indices)):
row_smiles = []
for rgroup, smiles in row.items():
row_smiles.append(smiles)
rgroup_counts[smiles] += 1
if smiles not in rgroup_idx:
rgroup_idx[smiles] = len(rgroup_idx)
rgroups[rgroup].append(RGroup(smiles, rgroup, 0, 0))
for rgroup, de in row.items():
row_smiles.append(de.smiles)
rgroup_counts[de.smiles] += 1
if de.smiles not in rgroup_idx:
rgroup_idx[de.smiles] = len(rgroup_idx)
rgroups[rgroup].append(RGroup(de.smiles, rgroup, 0, 0, mol=de.mol))
row['original_idx'] = idx
reconstructed = ".".join(row_smiles)
try:
@@ -387,7 +450,8 @@ def FWDecompose(scaffolds, mols, scores,
mol = molzip_smi(reconstructed)
num_reconstructed += 1
except:
print("failed:", Chem.MolToSmiles(matched[num_mols]), reconstructed)
logging.error("failed reconstructing %s, %s", Chem.MolToSmiles(matched[num_mols]),
reconstructed)
logger.info(f"Descriptor size {len(rgroup_idx)}")
logger.info(f"Reconstructed {num_reconstructed} out of {num_mols}")
@@ -401,9 +465,10 @@ def FWDecompose(scaffolds, mols, scores,
row['molecule'] = mol
descriptor = [0] * len(rgroup_idx)
descriptors.append(descriptor)
for smiles in row.values():
if smiles in rgroup_idx:
descriptor[rgroup_idx[smiles]] = 1
for k, de in row.items():
if k == "original_idx" or k == 'molecule': continue
if de.smiles in rgroup_idx:
descriptor[rgroup_idx[de.smiles]] = 1
assert len(descriptors) == len(
matched_scores
@@ -428,7 +493,8 @@ def FWDecompose(scaffolds, mols, scores,
num_reconstructed)
def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, mol_filter=None):
def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, mol_filter=None,
keep_training_set=False):
N = fw.N
fitter = fw.fitter
num_products = 1
@@ -450,6 +516,7 @@ def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, m
max_pred = -1e10
min_pred = 1e10
delta = num_products // 10 or 1
for i, groups in tqdm(enumerate(itertools.product(*rgroups)), total=num_products):
if i and i % delta == 0:
logging.debug(
@@ -466,7 +533,11 @@ def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, m
if tuple(descriptors) in fw.descriptors:
in_training_set += 1
continue
is_training = True
if not keep_training_set:
continue
else:
is_training = False
min_mw = min(min_mw, mw)
max_mw = max(max_mw, mw)
@@ -487,13 +558,22 @@ def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, m
rejected_pred += 1
continue
good_pred += 1
smiles = set([g.smiles for g in groups]) # remove dupes
smi = ".".join(set([g.smiles for g in groups]))
try:
mol = molzip_smi(smi)
except:
rejected_bad += 1
continue
mols = [g.mol for g in groups]
if None in mols:
smiles = set([g.smiles for g in groups]) # remove dupes
smi = ".".join(set([g.smiles for g in groups]))
try:
mol = molzip_smi(smi)
except:
rejected_bad += 1
continue
else:
try:
mol = molzip_mols(mols)
except:
rejected_bad += 1
continue
rejected = False
if mol_filter and not mol_filter(mol):
@@ -501,7 +581,7 @@ def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, m
continue
out_smi = Chem.MolToSmiles(mol)
yield FreeWilsonPrediction(pred, out_smi, groups)
yield FreeWilsonPrediction(pred, out_smi, groups, mol, is_training)
wrote += 1
logging.info(
f"Wrote {wrote} results out of {num_products}\n\tIn Training set: {in_training_set}\n\tBad MW: {rejected_mw}\n\tBad Pred: {rejected_pred}\n\tBad Filters: {rejected_filters}\n\tBad smi: {rejected_bad}\n\tmin mw: {min_mw}\n\tmax mw: {max_mw}\n\tBad HVY: {rejected_hvy}\n\tBad Pred: {rejected_pred}\n\tBad Filters: {rejected_filters}\n\tBad smi: {rejected_bad}\n\tmin mw: {min_mw}\n\tmax mw: {max_mw}\n\tmin hvy: {min_hvy}\n\tmax hvy: {max_hvy}\n\t\n\tmin pred: {min_pred}\n\tmax pred: {max_pred}"
@@ -509,7 +589,7 @@ def _enumerate(rgroups, fw, mw_filter=None, hvy_filter=None, pred_filter=None, m
def FWBuild(fw: FreeWilsonDecomposition, pred_filter=None, mw_filter=None, hvy_filter=None,
mol_filter=None) -> Generator[FreeWilsonPrediction, None, None]:
mol_filter=None, keep_training_set=False) -> Generator[FreeWilsonPrediction, None, None]:
"""Enumerate the freewilson decomposition and return their predictions
:param fw: FreeWilsonDecomposition generated from FWDecompose
@@ -551,7 +631,8 @@ def FWBuild(fw: FreeWilsonDecomposition, pred_filter=None, mw_filter=None, hvy_f
rgroups = [rgroup for key, rgroup in sorted(rgroups_no_cycles.items())]
# core is always first
for res in _enumerate(rgroups, fw, pred_filter=pred_filter, mw_filter=mw_filter,
hvy_filter=hvy_filter, mol_filter=mol_filter):
hvy_filter=hvy_filter, mol_filter=mol_filter,
keep_training_set=keep_training_set):
yield res
# iterate on rgroups with cycles

View File

@@ -2,10 +2,20 @@ import csv
import io
import logging
import os
import sys
import freewilson as fw
try:
import freewilson as fw
except:
path = os.path.abspath(os.curdir)
sys.path.insert(0, os.path.join(path, ".."))
import freewilson as fw
import importlib
fw = importlib.reload(fw)
from rdkit import Chem, rdBase
from rdkit.Chem import rdFMCS, rdMolAlign
PATH = os.path.join(os.path.dirname(fw.__file__), 'data')
assert os.path.exists(PATH), PATH
@@ -33,7 +43,7 @@ def test_chembl():
with rdBase.BlockLogs():
free = fw.FWDecompose(scaffold, mols, scores)
# let's make sure the r squared is decent
assert free.r2 > 0.8
assert free.r2 > 0.8, str(free.r2)
# assert we get something
preds = list(fw.FWBuild(free))
@@ -68,3 +78,32 @@ def test_multicore():
decomp = fw.FWDecompose(scaffolds, mols, [1, 2, 3, 4, 5, 6])
s = io.StringIO()
fw.predictions_to_csv(s, decomp, fw.FWBuild(decomp))
def test_fep_benchmark_3D():
benchmark = os.path.join(PATH, "cmet_ligands.sdf")
mols = list(Chem.SDMolSupplier(benchmark, removeHs=False))
smis = {Chem.MolToSmiles(m):m for m in mols if m}
assert smis
match = rdFMCS.FindMCS(mols)
scores = [float(m.GetProp("r_exp_dg")) for m in mols]
with rdBase.BlockLogs():
free = fw.FWDecompose(match.queryMol, mols, scores)
preds2 = list(fw.FWBuild(free, keep_training_set=True))
assert preds2
count = 0
found = 0
# we use an rms of 0.7 because we are assembling rgroups from different
# ligands together which may not exactly align
for p in preds2:
if not p.is_training: continue
smi = Chem.MolToSmiles(p.mol)
count += 1
if smi not in smis: continue
rms = rdMolAlign.GetBestRMS(smis[smi], p.mol)
assert rms < 0.7