mirror of
https://github.com/rdkit/rdkit.git
synced 2026-06-03 21:44:30 +08:00
refactor: improve readability and maintainability of AAP similarity code (#9277)
* refactor: clean up AAP similarity logic and add type hints * Refactor AAP similarity implementation for clarity and maintainability * Refactor AAP similarity implementation for clarity and maintainability * ran yapf over the code
This commit is contained in:
committed by
GitHub
parent
2f6bbe03b0
commit
24f0007757
@@ -9,11 +9,10 @@ import unittest
|
||||
import numpy
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from rdkit import Chem, DataStructs
|
||||
from rdkit.Chem import AllChem, rdmolops
|
||||
from rdkit.Chem.Fingerprints import FingerprintMols
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import rdmolops
|
||||
|
||||
_BK_ = {
|
||||
_bondTypeCodes = {
|
||||
Chem.rdchem.BondType.SINGLE: 1,
|
||||
Chem.rdchem.BondType.DOUBLE: 2,
|
||||
Chem.rdchem.BondType.TRIPLE: 3,
|
||||
@@ -21,9 +20,9 @@ _BK_ = {
|
||||
}
|
||||
_BONDSYMBOL_ = {1: '-', 2: '=', 3: '#', 4: ':'}
|
||||
|
||||
#_nAT_ = 217 # 108*2+1
|
||||
_nAT_ = 223 # Gobbi code actually uses the first prime higher than 217, not 217 itself
|
||||
_nBT_ = 5
|
||||
#_atomHashModulus = 217 # 108*2+1
|
||||
_atomHashModulus = 223 # Gobbi code actually uses the first prime higher than 217, not 217 itself
|
||||
_bondHashModulus = 5
|
||||
|
||||
#def FindAllPathsOfLengthN_Gobbi(mol, length, rootedAtAtom=-1, uniquepaths=True):
|
||||
# return FindAllPathsOfLengthMToN(mol, length, length, rootedAtAtom=rootedAtAtom, uniquepaths=uniquepaths)
|
||||
@@ -61,12 +60,12 @@ def _FindAllPathsOfLengthMToN_Gobbi(atom, path, minlength, maxlength, visited, p
|
||||
if len(path) >= minlength and len(path) <= maxlength:
|
||||
paths.append(tuple(path))
|
||||
if len(path) < maxlength:
|
||||
a1 = bond.GetBeginAtom()
|
||||
a2 = bond.GetEndAtom()
|
||||
if a1.GetIdx() == atom.GetIdx():
|
||||
nextatom = a2
|
||||
beginAtom = bond.GetBeginAtom()
|
||||
endAtom = bond.GetEndAtom()
|
||||
if beginAtom.GetIdx() == atom.GetIdx():
|
||||
nextatom = endAtom
|
||||
else:
|
||||
nextatom = a1
|
||||
nextatom = beginAtom
|
||||
nextatomidx = nextatom.GetIdx()
|
||||
if nextatomidx not in visited:
|
||||
visited.add(nextatomidx)
|
||||
@@ -79,7 +78,7 @@ def getpathintegers(m1, uptolength=7):
|
||||
'''returns a list of integers describing the paths for molecule m1. This uses numpy 16 bit unsigned integers to reproduce the data in the Gobbi paper. The returned list is sorted'''
|
||||
bondtypelookup = {}
|
||||
for b in m1.GetBonds():
|
||||
bondtypelookup[b.GetIdx()] = _BK_[b.GetBondType()], b.GetBeginAtom(), b.GetEndAtom()
|
||||
bondtypelookup[b.GetIdx()] = _bondTypeCodes[b.GetBondType()], b.GetBeginAtom(), b.GetEndAtom()
|
||||
pathintegers = {}
|
||||
for a in m1.GetAtoms():
|
||||
idx = a.GetIdx()
|
||||
@@ -88,40 +87,32 @@ def getpathintegers(m1, uptolength=7):
|
||||
# for path in rdmolops.FindAllPathsOfLengthN(m1, pathlength, rootedAtAtom=idx):
|
||||
for ipath, path in enumerate(
|
||||
FindAllPathsOfLengthMToN_Gobbi(m1, 1, uptolength, rootedAtAtom=idx, uniquepaths=False)):
|
||||
strpath = []
|
||||
currentidx = idx
|
||||
res = []
|
||||
for ip, p in enumerate(path):
|
||||
bk, a1, a2 = bondtypelookup[p]
|
||||
strpath.append(_BONDSYMBOL_[bk])
|
||||
if a1.GetIdx() == currentidx:
|
||||
a = a2
|
||||
bondTypeCode, beginAtom, endAtom = bondtypelookup[p]
|
||||
if beginAtom.GetIdx() == currentidx:
|
||||
a = endAtom
|
||||
else:
|
||||
a = a1
|
||||
ak = a.GetAtomicNum()
|
||||
a = beginAtom
|
||||
neighborAtomCode = a.GetAtomicNum()
|
||||
if a.GetIsAromatic():
|
||||
ak += 108
|
||||
neighborAtomCode += 108
|
||||
#trying to get the same behaviour as the Gobbi test code - it looks like a circular path includes the bond, but not the closure atom - this fix works
|
||||
if a.GetIdx() == idx:
|
||||
ak = None
|
||||
if ak is not None:
|
||||
astr = a.GetSymbol()
|
||||
if a.GetIsAromatic():
|
||||
strpath.append(astr.lower())
|
||||
else:
|
||||
strpath.append(astr)
|
||||
res.append((bk, ak))
|
||||
neighborAtomCode = None
|
||||
res.append((bondTypeCode, neighborAtomCode))
|
||||
currentidx = a.GetIdx()
|
||||
pathuniqueint = numpy.ushort(0) # work with 16 bit unsigned integers and ignore overflow...
|
||||
for ires, (bi, ai) in enumerate(res):
|
||||
for ires, (bondValue, atomValue) in enumerate(res):
|
||||
#use 16 bit unsigned integer arithmetic to reproduce the Gobbi ints
|
||||
# pathuniqueint = ((pathuniqueint+bi)*_nAT_+ai)*_nBT_
|
||||
val1 = pathuniqueint + numpy.ushort(bi)
|
||||
val2 = val1 * numpy.ushort(_nAT_)
|
||||
# pathuniqueint = ((pathuniqueint+bondValue)*_atomHashModulus+atomValue)*_bondHashModulus
|
||||
val1 = pathuniqueint + numpy.ushort(bondValue)
|
||||
val2 = val1 * numpy.ushort(_atomHashModulus)
|
||||
#trying to get the same behaviour as the Gobbi test code - it looks like a circular path includes the bond, but not the closure atom - this fix works
|
||||
if ai is not None:
|
||||
val3 = val2 + numpy.ushort(ai)
|
||||
val4 = val3 * numpy.ushort(_nBT_)
|
||||
if atomValue is not None:
|
||||
val3 = val2 + numpy.ushort(atomValue)
|
||||
val4 = val3 * numpy.ushort(_bondHashModulus)
|
||||
else:
|
||||
val4 = val2
|
||||
pathuniqueint = val4
|
||||
@@ -140,14 +131,14 @@ def getcommon(l1, ll1, l2, ll2):
|
||||
ix1 = 0
|
||||
ix2 = 0
|
||||
while (ix1 < ll1) and (ix2 < ll2):
|
||||
a1 = l1[ix1]
|
||||
a2 = l2[ix2]
|
||||
#a1 is < or > more often than ==
|
||||
if a1 < a2:
|
||||
beginAtom = l1[ix1]
|
||||
endAtom = l2[ix2]
|
||||
#beginAtom is < or > more often than ==
|
||||
if beginAtom < endAtom:
|
||||
ix1 += 1
|
||||
elif a1 > a2:
|
||||
elif beginAtom > endAtom:
|
||||
ix2 += 1
|
||||
else: # a1 == a2:
|
||||
else: # beginAtom == endAtom:
|
||||
ncommon += 1
|
||||
ix1 += 1
|
||||
ix2 += 1
|
||||
@@ -206,19 +197,20 @@ def getsimab(mappings, simmatrixdict):
|
||||
def getsimmatrix(m1, m1pathintegers, m2, m2pathintegers):
|
||||
'''Generate a matrix of atom-atom similarities. See Figure 4.'''
|
||||
|
||||
aidata = [((ai.GetAtomicNum(), ai.GetIsAromatic()), ai.GetIdx()) for ai in m1.GetAtoms()]
|
||||
aidata = [((atomValue.GetAtomicNum(), atomValue.GetIsAromatic()), atomValue.GetIdx())
|
||||
for atomValue in m1.GetAtoms()]
|
||||
bjdata = [((bj.GetAtomicNum(), bj.GetIsAromatic()), bj.GetIdx()) for bj in m2.GetAtoms()]
|
||||
|
||||
simmatrixarray = numpy.zeros((len(aidata), len(bjdata)))
|
||||
|
||||
for ai, (aitype, aiidx) in enumerate(aidata):
|
||||
for atomValue, (aitype, aiidx) in enumerate(aidata):
|
||||
aipaths = m1pathintegers[aiidx]
|
||||
naipaths = len(aipaths)
|
||||
for bj, (bjtype, bjidx) in enumerate(bjdata):
|
||||
if aitype == bjtype:
|
||||
bjpaths = m2pathintegers[bjidx]
|
||||
nbjpaths = len(bjpaths)
|
||||
simmatrixarray[ai][bj] = getsimaibj(aipaths, bjpaths, naipaths, nbjpaths)
|
||||
simmatrixarray[atomValue][bj] = getsimaibj(aipaths, bjpaths, naipaths, nbjpaths)
|
||||
return simmatrixarray
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user