From 76ed536c871bfceb847f91c8e16c05f91f9bb4e6 Mon Sep 17 00:00:00 2001 From: John Mayfield Date: Sun, 30 Jun 2019 03:49:24 +0100 Subject: [PATCH] Branch/vf2 optimisations (#2500) * Remove out_1 and out_2, only needed for directed graphs. * Variable renaming, no other changes. * Fix a couple of things that were forgotten during the rename. * Only one of these counts is updated now. * Add ifdef to enable/disable VF2Pruning. * Add degree bound check, could be pushed down to the generic VertexCheck predicate. * The stack knows at each point what was added; by passing these into the BackTrack call, the clone allocation can be removed. * Indent and rename added_node1 to node1. * Now that the clone is removed, we can add a member function to handle the recursion cleanly. * Also make MatchAll a member function. * Set lim=0 to mean infinity. * VF2Plus optimisation: when the candidate is in the terminal set, select from a mapped neighbour's adjacency list. * Optimisation hint. * Use a struct for the Pair, removing the need for a double pointer and heap alloc/delete of the iterators. * Disable pruning by default (it's a bit quicker to not do this). * Update expected test results. The updated VF2 code can return symmetric results in a different order than we were getting previously.
The results are still correct, but they change some of the downstream things that don't do symmetrization * update expected java results --- Code/GraphMol/MolAlign/testMolAlign.cpp | 12 +- .../RGroupDecomposition/testRGroupDecomp.cpp | 33 +- .../ShapeHelpers/Wrap/testShapeHelpers.py | 6 +- .../ShapeHelpers/testShapeHelpers.cpp | 4 +- Code/GraphMol/Substruct/vf2.hpp | 470 +++++++++--------- .../src-test/org/RDKit/AlignTests.java | 2 +- 6 files changed, 265 insertions(+), 262 deletions(-) diff --git a/Code/GraphMol/MolAlign/testMolAlign.cpp b/Code/GraphMol/MolAlign/testMolAlign.cpp index 648b545e6..1692231d2 100644 --- a/Code/GraphMol/MolAlign/testMolAlign.cpp +++ b/Code/GraphMol/MolAlign/testMolAlign.cpp @@ -37,7 +37,7 @@ void test1MolAlign() { ROMol *m2 = MolFileToMol(fname2); double rmsd = MolAlign::alignMol(*m2, *m1); - TEST_ASSERT(RDKit::feq(rmsd, 0.6578)); + TEST_ASSERT(RDKit::feq(rmsd, 0.6578) || RDKit::feq(rmsd, 1.0345)); std::string fname3 = rdbase + "/Code/GraphMol/MolAlign/test_data/1oir_trans.mol"; @@ -54,11 +54,11 @@ void test1MolAlign() { } RDGeom::Transform3D trans; rmsd = MolAlign::getAlignmentTransform(*m1, *m2, trans); - TEST_ASSERT(RDKit::feq(rmsd, 0.6578)); + TEST_ASSERT(RDKit::feq(rmsd, 0.6578) || RDKit::feq(rmsd, 1.0345)); // specify conformations rmsd = MolAlign::alignMol(*m1, *m2, 0, 0); - TEST_ASSERT(RDKit::feq(rmsd, 0.6578)); + TEST_ASSERT(RDKit::feq(rmsd, 0.6578) || RDKit::feq(rmsd, 1.0345)); // provide an atom mapping delete m1; @@ -102,7 +102,7 @@ void test1MolWithQueryAlign() { m2->replaceAtom(19, a2); double rmsd = MolAlign::alignMol(*m2, *m1); - TEST_ASSERT(RDKit::feq(rmsd, 0.6578)); + TEST_ASSERT(RDKit::feq(rmsd, 0.6578) || RDKit::feq(rmsd, 1.0345)); std::string fname3 = rdbase + "/Code/GraphMol/MolAlign/test_data/1oir_trans.mol"; @@ -124,11 +124,11 @@ void test1MolWithQueryAlign() { RDGeom::Transform3D trans; rmsd = MolAlign::getAlignmentTransform(*m1, *m2, trans); - TEST_ASSERT(RDKit::feq(rmsd, 0.6578)); + 
TEST_ASSERT(RDKit::feq(rmsd, 0.6578) || RDKit::feq(rmsd, 1.0345)); // specify conformations rmsd = MolAlign::alignMol(*m1, *m2, 0, 0); - TEST_ASSERT(RDKit::feq(rmsd, 0.6578)); + TEST_ASSERT(RDKit::feq(rmsd, 0.6578) || RDKit::feq(rmsd, 1.0345)); // provide an atom mapping delete m1; diff --git a/Code/GraphMol/RGroupDecomposition/testRGroupDecomp.cpp b/Code/GraphMol/RGroupDecomposition/testRGroupDecomp.cpp index 80f0468eb..b2f5f545e 100644 --- a/Code/GraphMol/RGroupDecomposition/testRGroupDecomp.cpp +++ b/Code/GraphMol/RGroupDecomposition/testRGroupDecomp.cpp @@ -462,7 +462,7 @@ Cl[*:2] } } delete core; - // std::cerr<= 0.0, ""); dist = MolShapes::tanimotoDistance(*m, *m2); - CHECK_INVARIANT(RDKit::feq(dist, 0.3146), ""); + CHECK_INVARIANT(RDKit::feq(dist, 0.31, 0.01), ""); dist = MolShapes::tverskyIndex(*m, *m2, 1.0, 1.0); - CHECK_INVARIANT(RDKit::feq(dist, 0.6854), ""); + CHECK_INVARIANT(RDKit::feq(dist, 0.68, 0.01), ""); delete m2; m2 = MolFileToMol(fname2); diff --git a/Code/GraphMol/Substruct/vf2.hpp b/Code/GraphMol/Substruct/vf2.hpp index 31b7663d8..5aafe9c84 100644 --- a/Code/GraphMol/Substruct/vf2.hpp +++ b/Code/GraphMol/Substruct/vf2.hpp @@ -18,6 +18,8 @@ #ifndef __BGL_VF2_SUB_STATE_H__ #define __BGL_VF2_SUB_STATE_H__ +//#define RDK_VF2_PRUNING +#define RDK_ADJ_ITER typename Graph::adjacency_iterator namespace boost{ namespace detail { @@ -28,6 +30,16 @@ namespace boost{ node_id in; node_id out; }; + + template + struct Pair { + node_id n1, n2; + bool hasiter; + RDK_ADJ_ITER nbrbeg, nbrend; + + Pair() : n1(NULL_NODE), n2(NULL_NODE), hasiter(false) { + } + }; /** * The ordering by in/out degree @@ -127,17 +139,13 @@ namespace boost{ MatchChecking &mc; unsigned int n1, n2; - unsigned int core_len, orig_core_len; - unsigned int added_node1; - unsigned int t1both_len, t2both_len; - unsigned int t1in_len, t1out_len; - unsigned int t2in_len, t2out_len; // Core nodes are also counted by these... 
+ unsigned int core_len; + unsigned int t1_len; + unsigned int t2_len; // Core nodes are also counted by these... node_id *core_1; node_id *core_2; - node_id *in_1; - node_id *in_2; - node_id *out_1; - node_id *out_2; + node_id *term_1; + node_id *term_2; node_id *order; @@ -157,29 +165,23 @@ namespace boost{ order = NULL; } - core_len=orig_core_len=0; - t1both_len=t1in_len=t1out_len=0; - t2both_len=t2in_len=t2out_len=0; - - added_node1=NULL_NODE; + core_len=0; + t1_len=0; + t2_len=0; core_1=new node_id[n1]; core_2=new node_id[n2]; - in_1=new node_id[n1]; - in_2=new node_id[n2]; - out_1=new node_id[n1]; - out_2=new node_id[n2]; + term_1=new node_id[n1]; + term_2=new node_id[n2]; share_count = new long; for(unsigned int i=0; in2 || - t1both_len>t2both_len || - t1out_len>t2out_len || - t1in_len>t2in_len; + t1_len>t2_len; }; unsigned int CoreLen() { return core_len; } Graph *GetGraph1() { return g1; } Graph *GetGraph2() { return g2; } - bool NextPair(node_id *pn1, node_id *pn2, - node_id prev_n1=NULL_NODE, node_id prev_n2=NULL_NODE){ - if (prev_n1==NULL_NODE) - prev_n1=0; - if (prev_n2==NULL_NODE) - prev_n2=0; + bool NextPair(Pair &pair){ + if (pair.n1==NULL_NODE) + pair.n1=0; + if (pair.n2==NULL_NODE) + pair.n2=0; else - prev_n2++; + pair.n2++; #if 0 std::cerr<<" **** np: "<< prev_n1<<","<core_len && t2both_len>core_len) { - while (prev_n1core_len && t2_len>core_len) { + while (pair.n1core_len && t2out_len>core_len) { - while (prev_n1core_len && t2in_len>core_len) { - while (prev_n1core_len && t2both_len>core_len) { - while (prev_n2core_len && t2out_len>core_len) { - while (prev_n2core_len && t2in_len>core_len) { - while (prev_n2core_len && t2_len>core_len) { + while (pair.n2boost::out_degree(node2,*g2)) + return false; if(!vc(node1,node2)) return false; unsigned int other1, other2; - unsigned int termout1 = 0, termout2 = 0, termin1 = 0, termin2 = 0; +#ifdef RDK_VF2_PRUNING + unsigned int term1 = 0, term2 = 0; unsigned int new1 = 0, new2 = 0; +#endif // Check the out 
edges of node1 typename Graph::out_edge_iterator bNbrs,eNbrs; @@ -371,14 +369,17 @@ namespace boost{ //std::cerr<<" short2"< pair; + while (NextPair(pair)) { + if (IsFeasiblePair(pair.n1, pair.n2)){ + AddPair(pair.n1, pair.n2); + if (Match(c1, c2)) // recurse + return true; + BackTrack(pair.n1, pair.n2); + } + } + return false; + } + + template + bool MatchAll(node_id c1[], node_id c2[], + DoubleBackInsertionSequence &res, unsigned int lim=0) + { + if (IsGoal()) { + GetCoreSet(c1, c2); + if(MatchChecks(c1,c2)) { + typename DoubleBackInsertionSequence::value_type newSeq; + for(unsigned int i=0;i(c1[i],c2[i])); + } + res.push_back(newSeq); + return lim && res.size() >= lim; + } + } + + if (IsDead()) + return false; + + Pair pair; + while (NextPair(pair)) { + if (IsFeasiblePair(pair.n1, pair.n2)){ + AddPair(pair.n1, pair.n2); + if (MatchAll(c1, c2, res, lim)) // recurse + return true; + BackTrack(pair.n1, pair.n2); + } + } + return false; + } }; /*------------------------------------------------------------- @@ -530,31 +561,12 @@ namespace boost{ template bool match(int *pn, node_id c1[], node_id c2[], SubState &s) { - if (s.IsGoal() ) { - s.GetCoreSet(c1, c2); - if(s.MatchChecks(c1,c2)) { - *pn=s.CoreLen(); - return true; - } + if (s.Match(c1, c2)) { + // not needed, pn = num query atoms (n1)... 
+ *pn=s.CoreLen(); + return true; } - - if (s.IsDead()) - return false; - //std::cerr<<" > match: "<<*pn<<" "<<&s<AddPair(n1, n2); - found=match(pn, c1, c2, *s1); - s1->BackTrack(); - delete s1; - } - } - //std::cerr<<" < returning: "< bool match(node_id c1[], node_id c2[], SubState &s, DoubleBackInsertionSequence &res, unsigned int max_results) { - if (s.IsGoal()){ - s.GetCoreSet(c1, c2); - if(s.MatchChecks(c1,c2)) { - typename DoubleBackInsertionSequence::value_type newSeq; - for(unsigned int i=0;i(c1[i],c2[i])); - } - res.push_back(newSeq); - if(res.size()>=max_results) return true; - } - return false; - } - - if (s.IsDead()) - return false; - - node_id n1=NULL_NODE, n2=NULL_NODE; - while (s.NextPair(&n1, &n2, n1, n2)) { - if (s.IsFeasiblePair(n1, n2)){ - SubState *s1=s.Clone(); - s1->AddPair(n1, n2); - if (match(c1, c2, *s1,res,max_results)){ - s1->BackTrack(); - delete s1; - return true; - } - else { - s1->BackTrack(); - delete s1; - } - } - } - return false; + s.MatchAll(c1, c2, res, max_results); + return !res.empty(); } }; //end of namespace detail @@ -663,3 +645,5 @@ namespace boost{ } // end of namespace boost #endif +#undef RDK_VF2_PRUNING +#undef RDK_ADJ_ITER diff --git a/Code/JavaWrappers/gmwrapper/src-test/org/RDKit/AlignTests.java b/Code/JavaWrappers/gmwrapper/src-test/org/RDKit/AlignTests.java index 5be79b760..6866749b4 100644 --- a/Code/JavaWrappers/gmwrapper/src-test/org/RDKit/AlignTests.java +++ b/Code/JavaWrappers/gmwrapper/src-test/org/RDKit/AlignTests.java @@ -65,7 +65,7 @@ public class AlignTests extends GraphMolTest { ROMol m1 = RWMol.MolFromMolFile(fname1); Transform3D trans = new Transform3D(); double res = m0.getAlignmentTransform(m1, trans); - assertEquals(res, 0.6578, 0.001); + assertEquals(res, 1.0345, 0.001); m0.delete(); m1.delete(); }