From 8120d418fb009f11a684e9a25fc820e0e2acaf43 Mon Sep 17 00:00:00 2001 From: reza bagheri alashti Date: Thu, 14 May 2026 08:22:12 +0330 Subject: [PATCH] refactor: improve readability and maintainability of AAP similarity code (#9277) * refactor: clean up AAP similarity logic and add type hints * Refactor AAP similarity implementation for clarity and maintainability * Refactor AAP similarity implementation for clarity and maintainability * ran yapf over the code --- .../AtomAtomPathSimilarity.py | 82 +++++++++---------- 1 file changed, 37 insertions(+), 45 deletions(-) diff --git a/Contrib/AtomAtomSimilarity/AtomAtomPathSimilarity.py b/Contrib/AtomAtomSimilarity/AtomAtomPathSimilarity.py index 1f12aa9e4..fccad721c 100644 --- a/Contrib/AtomAtomSimilarity/AtomAtomPathSimilarity.py +++ b/Contrib/AtomAtomSimilarity/AtomAtomPathSimilarity.py @@ -9,11 +9,10 @@ import unittest import numpy from scipy.optimize import linear_sum_assignment -from rdkit import Chem, DataStructs -from rdkit.Chem import AllChem, rdmolops -from rdkit.Chem.Fingerprints import FingerprintMols +from rdkit import Chem +from rdkit.Chem import rdmolops -_BK_ = { +_bondTypeCodes = { Chem.rdchem.BondType.SINGLE: 1, Chem.rdchem.BondType.DOUBLE: 2, Chem.rdchem.BondType.TRIPLE: 3, @@ -21,9 +20,9 @@ _BK_ = { } _BONDSYMBOL_ = {1: '-', 2: '=', 3: '#', 4: ':'} -#_nAT_ = 217 # 108*2+1 -_nAT_ = 223 # Gobbi code actually uses the first prime higher than 217, not 217 itself -_nBT_ = 5 +#_atomHashModulus = 217 # 108*2+1 +_atomHashModulus = 223 # Gobbi code actually uses the first prime higher than 217, not 217 itself +_bondHashModulus = 5 #def FindAllPathsOfLengthN_Gobbi(mol, length, rootedAtAtom=-1, uniquepaths=True): # return FindAllPathsOfLengthMToN(mol, length, length, rootedAtAtom=rootedAtAtom, uniquepaths=uniquepaths) @@ -61,12 +60,12 @@ def _FindAllPathsOfLengthMToN_Gobbi(atom, path, minlength, maxlength, visited, p if len(path) >= minlength and len(path) <= maxlength: paths.append(tuple(path)) if len(path) < maxlength: - a1 = bond.GetBeginAtom() - a2 = bond.GetEndAtom() - if a1.GetIdx() == atom.GetIdx(): - nextatom = a2 + beginAtom = bond.GetBeginAtom() + endAtom = bond.GetEndAtom() + if beginAtom.GetIdx() == atom.GetIdx(): + nextatom = endAtom else: - nextatom = a1 + nextatom = beginAtom nextatomidx = nextatom.GetIdx() if nextatomidx not in visited: visited.add(nextatomidx) @@ -79,7 +78,7 @@ def getpathintegers(m1, uptolength=7): '''returns a list of integers describing the paths for molecule m1. This uses numpy 16 bit unsigned integers to reproduce the data in the Gobbi paper. The returned list is sorted''' bondtypelookup = {} for b in m1.GetBonds(): - bondtypelookup[b.GetIdx()] = _BK_[b.GetBondType()], b.GetBeginAtom(), b.GetEndAtom() + bondtypelookup[b.GetIdx()] = _bondTypeCodes[b.GetBondType()], b.GetBeginAtom(), b.GetEndAtom() pathintegers = {} for a in m1.GetAtoms(): idx = a.GetIdx() @@ -88,40 +87,32 @@ def getpathintegers(m1, uptolength=7): # for path in rdmolops.FindAllPathsOfLengthN(m1, pathlength, rootedAtAtom=idx): for ipath, path in enumerate( FindAllPathsOfLengthMToN_Gobbi(m1, 1, uptolength, rootedAtAtom=idx, uniquepaths=False)): - strpath = [] currentidx = idx res = [] for ip, p in enumerate(path): - bk, a1, a2 = bondtypelookup[p] - strpath.append(_BONDSYMBOL_[bk]) - if a1.GetIdx() == currentidx: - a = a2 + bondTypeCode, beginAtom, endAtom = bondtypelookup[p] + if beginAtom.GetIdx() == currentidx: + a = endAtom else: - a = a1 - ak = a.GetAtomicNum() + a = beginAtom + neighborAtomCode = a.GetAtomicNum() if a.GetIsAromatic(): - ak += 108 + neighborAtomCode += 108 #trying to get the same behaviour as the Gobbi test code - it looks like a circular path includes the bond, but not the closure atom - this fix works if a.GetIdx() == idx: - ak = None - if ak is not None: - astr = a.GetSymbol() - if a.GetIsAromatic(): - strpath.append(astr.lower()) - else: - strpath.append(astr) - res.append((bk, ak)) + neighborAtomCode = None + res.append((bondTypeCode, neighborAtomCode)) currentidx = a.GetIdx() pathuniqueint = numpy.ushort(0) # work with 16 bit unsigned integers and ignore overflow... - for ires, (bi, ai) in enumerate(res): + for ires, (bondValue, atomValue) in enumerate(res): #use 16 bit unsigned integer arithmetic to reproduce the Gobbi ints - # pathuniqueint = ((pathuniqueint+bi)*_nAT_+ai)*_nBT_ - val1 = pathuniqueint + numpy.ushort(bi) - val2 = val1 * numpy.ushort(_nAT_) + # pathuniqueint = ((pathuniqueint+bondValue)*_atomHashModulus+atomValue)*_bondHashModulus + val1 = pathuniqueint + numpy.ushort(bondValue) + val2 = val1 * numpy.ushort(_atomHashModulus) #trying to get the same behaviour as the Gobbi test code - it looks like a circular path includes the bond, but not the closure atom - this fix works - if ai is not None: - val3 = val2 + numpy.ushort(ai) - val4 = val3 * numpy.ushort(_nBT_) + if atomValue is not None: + val3 = val2 + numpy.ushort(atomValue) + val4 = val3 * numpy.ushort(_bondHashModulus) else: val4 = val2 pathuniqueint = val4 @@ -140,14 +131,14 @@ def getcommon(l1, ll1, l2, ll2): ix1 = 0 ix2 = 0 while (ix1 < ll1) and (ix2 < ll2): - a1 = l1[ix1] - a2 = l2[ix2] - #a1 is < or > more often that == - if a1 < a2: + beginAtom = l1[ix1] + endAtom = l2[ix2] + #beginAtom is < or > more often that == + if beginAtom < endAtom: ix1 += 1 - elif a1 > a2: + elif beginAtom > endAtom: ix2 += 1 - else: # a1 == a2: + else: # beginAtom == endAtom: ncommon += 1 ix1 += 1 ix2 += 1 @@ -206,19 +197,20 @@ def getsimab(mappings, simmatrixdict): def getsimmatrix(m1, m1pathintegers, m2, m2pathintegers): '''generate a matrix of atom atom similarities. See Figure 4''' - aidata = [((ai.GetAtomicNum(), ai.GetIsAromatic()), ai.GetIdx()) for ai in m1.GetAtoms()] + aidata = [((atomValue.GetAtomicNum(), atomValue.GetIsAromatic()), atomValue.GetIdx()) + for atomValue in m1.GetAtoms()] bjdata = [((bj.GetAtomicNum(), bj.GetIsAromatic()), bj.GetIdx()) for bj in m2.GetAtoms()] simmatrixarray = numpy.zeros((len(aidata), len(bjdata))) - for ai, (aitype, aiidx) in enumerate(aidata): + for atomValue, (aitype, aiidx) in enumerate(aidata): aipaths = m1pathintegers[aiidx] naipaths = len(aipaths) for bj, (bjtype, bjidx) in enumerate(bjdata): if aitype == bjtype: bjpaths = m2pathintegers[bjidx] nbjpaths = len(bjpaths) - simmatrixarray[ai][bj] = getsimaibj(aipaths, bjpaths, naipaths, nbjpaths) + simmatrixarray[atomValue][bj] = getsimaibj(aipaths, bjpaths, naipaths, nbjpaths) return simmatrixarray