refactor: improve readability and maintainability of AAP similarity code (#9277)

* refactor: clean up AAP similarity logic and add type hints

* Refactor AAP similarity implementation for clarity and maintainability

* Refactor AAP similarity implementation for clarity and maintainability

* ran yapf over the code
This commit is contained in:
reza bagheri alashti
2026-05-14 08:22:12 +03:30
committed by greg landrum
parent 3c81879dc9
commit 8120d418fb

View File

@@ -9,11 +9,10 @@ import unittest
import numpy
from scipy.optimize import linear_sum_assignment
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, rdmolops
from rdkit.Chem.Fingerprints import FingerprintMols
from rdkit import Chem
from rdkit.Chem import rdmolops
_BK_ = {
_bondTypeCodes = {
Chem.rdchem.BondType.SINGLE: 1,
Chem.rdchem.BondType.DOUBLE: 2,
Chem.rdchem.BondType.TRIPLE: 3,
@@ -21,9 +20,9 @@ _BK_ = {
}
_BONDSYMBOL_ = {1: '-', 2: '=', 3: '#', 4: ':'}
#_nAT_ = 217 # 108*2+1
_nAT_ = 223 # Gobbi code actually uses the first prime higher than 217, not 217 itself
_nBT_ = 5
#_atomHashModulus = 217 # 108*2+1
_atomHashModulus = 223 # Gobbi code actually uses the first prime higher than 217, not 217 itself
_bondHashModulus = 5
#def FindAllPathsOfLengthN_Gobbi(mol, length, rootedAtAtom=-1, uniquepaths=True):
# return FindAllPathsOfLengthMToN(mol, length, length, rootedAtAtom=rootedAtAtom, uniquepaths=uniquepaths)
@@ -61,12 +60,12 @@ def _FindAllPathsOfLengthMToN_Gobbi(atom, path, minlength, maxlength, visited, p
if len(path) >= minlength and len(path) <= maxlength:
paths.append(tuple(path))
if len(path) < maxlength:
a1 = bond.GetBeginAtom()
a2 = bond.GetEndAtom()
if a1.GetIdx() == atom.GetIdx():
nextatom = a2
beginAtom = bond.GetBeginAtom()
endAtom = bond.GetEndAtom()
if beginAtom.GetIdx() == atom.GetIdx():
nextatom = endAtom
else:
nextatom = a1
nextatom = beginAtom
nextatomidx = nextatom.GetIdx()
if nextatomidx not in visited:
visited.add(nextatomidx)
@@ -79,7 +78,7 @@ def getpathintegers(m1, uptolength=7):
'''returns a list of integers describing the paths for molecule m1. This uses numpy 16 bit unsigned integers to reproduce the data in the Gobbi paper. The returned list is sorted'''
bondtypelookup = {}
for b in m1.GetBonds():
bondtypelookup[b.GetIdx()] = _BK_[b.GetBondType()], b.GetBeginAtom(), b.GetEndAtom()
bondtypelookup[b.GetIdx()] = _bondTypeCodes[b.GetBondType()], b.GetBeginAtom(), b.GetEndAtom()
pathintegers = {}
for a in m1.GetAtoms():
idx = a.GetIdx()
@@ -88,40 +87,32 @@ def getpathintegers(m1, uptolength=7):
# for path in rdmolops.FindAllPathsOfLengthN(m1, pathlength, rootedAtAtom=idx):
for ipath, path in enumerate(
FindAllPathsOfLengthMToN_Gobbi(m1, 1, uptolength, rootedAtAtom=idx, uniquepaths=False)):
strpath = []
currentidx = idx
res = []
for ip, p in enumerate(path):
bk, a1, a2 = bondtypelookup[p]
strpath.append(_BONDSYMBOL_[bk])
if a1.GetIdx() == currentidx:
a = a2
bondTypeCode, beginAtom, endAtom = bondtypelookup[p]
if beginAtom.GetIdx() == currentidx:
a = endAtom
else:
a = a1
ak = a.GetAtomicNum()
a = beginAtom
neighborAtomCode = a.GetAtomicNum()
if a.GetIsAromatic():
ak += 108
neighborAtomCode += 108
#trying to get the same behaviour as the Gobbi test code - it looks like a circular path includes the bond, but not the closure atom - this fix works
if a.GetIdx() == idx:
ak = None
if ak is not None:
astr = a.GetSymbol()
if a.GetIsAromatic():
strpath.append(astr.lower())
else:
strpath.append(astr)
res.append((bk, ak))
neighborAtomCode = None
res.append((bondTypeCode, neighborAtomCode))
currentidx = a.GetIdx()
pathuniqueint = numpy.ushort(0) # work with 16 bit unsigned integers and ignore overflow...
for ires, (bi, ai) in enumerate(res):
for ires, (bondValue, atomValue) in enumerate(res):
#use 16 bit unsigned integer arithmetic to reproduce the Gobbi ints
# pathuniqueint = ((pathuniqueint+bi)*_nAT_+ai)*_nBT_
val1 = pathuniqueint + numpy.ushort(bi)
val2 = val1 * numpy.ushort(_nAT_)
# pathuniqueint = ((pathuniqueint+bondValue)*_atomHashModulus+atomValue)*_bondHashModulus
val1 = pathuniqueint + numpy.ushort(bondValue)
val2 = val1 * numpy.ushort(_atomHashModulus)
#trying to get the same behaviour as the Gobbi test code - it looks like a circular path includes the bond, but not the closure atom - this fix works
if ai is not None:
val3 = val2 + numpy.ushort(ai)
val4 = val3 * numpy.ushort(_nBT_)
if atomValue is not None:
val3 = val2 + numpy.ushort(atomValue)
val4 = val3 * numpy.ushort(_bondHashModulus)
else:
val4 = val2
pathuniqueint = val4
@@ -140,14 +131,14 @@ def getcommon(l1, ll1, l2, ll2):
ix1 = 0
ix2 = 0
while (ix1 < ll1) and (ix2 < ll2):
a1 = l1[ix1]
a2 = l2[ix2]
#a1 is < or > more often that ==
if a1 < a2:
beginAtom = l1[ix1]
endAtom = l2[ix2]
#beginAtom is < or > more often than ==
if beginAtom < endAtom:
ix1 += 1
elif a1 > a2:
elif beginAtom > endAtom:
ix2 += 1
else: # a1 == a2:
else: # beginAtom == endAtom:
ncommon += 1
ix1 += 1
ix2 += 1
@@ -206,19 +197,20 @@ def getsimab(mappings, simmatrixdict):
def getsimmatrix(m1, m1pathintegers, m2, m2pathintegers):
'''generate a matrix of atom atom similarities. See Figure 4'''
aidata = [((ai.GetAtomicNum(), ai.GetIsAromatic()), ai.GetIdx()) for ai in m1.GetAtoms()]
aidata = [((atomValue.GetAtomicNum(), atomValue.GetIsAromatic()), atomValue.GetIdx())
for atomValue in m1.GetAtoms()]
bjdata = [((bj.GetAtomicNum(), bj.GetIsAromatic()), bj.GetIdx()) for bj in m2.GetAtoms()]
simmatrixarray = numpy.zeros((len(aidata), len(bjdata)))
for ai, (aitype, aiidx) in enumerate(aidata):
for atomValue, (aitype, aiidx) in enumerate(aidata):
aipaths = m1pathintegers[aiidx]
naipaths = len(aipaths)
for bj, (bjtype, bjidx) in enumerate(bjdata):
if aitype == bjtype:
bjpaths = m2pathintegers[bjidx]
nbjpaths = len(bjpaths)
simmatrixarray[ai][bj] = getsimaibj(aipaths, bjpaths, naipaths, nbjpaths)
simmatrixarray[atomValue][bj] = getsimaibj(aipaths, bjpaths, naipaths, nbjpaths)
return simmatrixarray