Modernization of some substructure code (#8450)

* use std::span for substruct match callbacks

This removes a copy from every evaluation of potential matches

* some cleanup/modernization

* some modernization

* deprecate chiralAtomCompat

* small optimization

* remove naked pointers

* improve new_timings.py script

* changes suggested in review

* response to review

* response to review
This commit is contained in:
Greg Landrum
2025-05-12 06:33:25 +02:00
committed by GitHub
parent 123875aadd
commit a9477d2694
12 changed files with 61 additions and 66 deletions

View File

@@ -657,7 +657,7 @@ bool RAtomMatcher(const ROMol &mol, const Atom &atom,
} // namespace Matchers
bool genericAtomMatcher(const ROMol &mol, const ROMol &query,
const std::vector<unsigned int> &match) {
const std::span<const unsigned int> &match) {
boost::dynamic_bitset<> ignore(mol.getNumAtoms());
for (const auto idx : match) {
ignore.set(idx);

View File

@@ -16,6 +16,7 @@
#include <vector>
#include <functional>
#include <map>
#include <span>
#include <boost/dynamic_bitset.hpp>
namespace RDKit {
@@ -603,7 +604,7 @@ RDKIT_GENERICGROUPS_EXPORT ROMol *adjustQueryPropertiesWithGenericGroups(
/// the current match
RDKIT_GENERICGROUPS_EXPORT bool genericAtomMatcher(
const ROMol &mol, const ROMol &query,
const std::vector<unsigned int> &match);
const std::span<const unsigned int> &match);
//! sets the apropriate generic query tags based on atom labels and/or SGroups
/*

View File

@@ -172,12 +172,12 @@ M END
}
namespace {
bool no_match(const ROMol &mol, const std::vector<unsigned int> &ids) {
bool no_match(const ROMol &mol, const std::span<const unsigned int> ids) {
RDUNUSED_PARAM(mol);
RDUNUSED_PARAM(ids);
return false;
}
bool always_match(const ROMol &mol, const std::vector<unsigned int> &ids) {
bool always_match(const ROMol &mol, const std::span<const unsigned int> ids) {
RDUNUSED_PARAM(mol);
RDUNUSED_PARAM(ids);
return true;

View File

@@ -1,5 +1,5 @@
//
// Copyright (C) 2001-2021 Greg Landrum and other RDKit contributors
// Copyright (C) 2001-2025 Greg Landrum and other RDKit contributors
//
// @@ All Rights Reserved @@
// This file is part of the RDKit.
@@ -21,10 +21,7 @@
#include <GraphMol/GenericGroups/GenericGroups.h>
#include <boost/smart_ptr.hpp>
#include <map>
#if BOOST_VERSION == 106400
#include <boost/serialization/array_wrapper.hpp>
#endif
#include <span>
#ifdef RDK_BUILD_THREADSAFE_SSS
#include <mutex>
@@ -34,8 +31,6 @@
#include "vf2.hpp"
using boost::make_iterator_range;
namespace RDKit {
namespace detail {
@@ -57,14 +52,14 @@ bool enhancedStereoIsOK(
// If the query has stereo groups:
// * OR only matches AND or OR (not absolute)
// * AND only matches OR
for (auto &&sg : query.getStereoGroups()) {
for (const auto &sg : query.getStereoGroups()) {
if (sg.getGroupType() == StereoGroupType::STEREO_ABSOLUTE) {
continue;
}
// StereoGroup const* matched_mol_group = nullptr;
const bool is_and = sg.getGroupType() == StereoGroupType::STEREO_AND;
for (auto &&a : sg.getAtoms()) {
auto mol_group = molStereoGroups.find(q_to_mol[a->getIdx()]);
for (const auto a : sg.getAtoms()) {
const auto mol_group = molStereoGroups.find(q_to_mol[a->getIdx()]);
if (mol_group == molStereoGroups.end()) {
// group matching absolute. not ok.
return false;
@@ -81,15 +76,15 @@ bool enhancedStereoIsOK(
// If the mol has stereo groups:
// * All atoms must either be the same or opposite, you can't mix
// * Only one stereogroup must cover all matched atoms in the mol stereo group
for (auto &&sg : mol.getStereoGroups()) {
for (const auto &sg : mol.getStereoGroups()) {
if (sg.getGroupType() == StereoGroupType::STEREO_ABSOLUTE) {
continue;
}
bool doesMatch;
bool doesMatch = false;
bool seen = false;
StereoGroup const *QGroup = nullptr;
for (auto &&a : sg.getAtoms()) {
for (const auto &a : sg.getAtoms()) {
auto thisDoesMatch = matches.find(a->getIdx());
if (thisDoesMatch == matches.end()) {
// not matched
@@ -199,8 +194,7 @@ MolMatchFinalCheckFunctor::MolMatchFinalCheckFunctor(
bool MolMatchFinalCheckFunctor::operator()(const std::uint32_t q_c[],
const std::uint32_t m_c[]) {
if (d_params.extraFinalCheck || d_params.useGenericMatchers) {
// EFF: we can no-doubt do better than this
std::vector<unsigned int> aids(m_c, m_c + d_query.getNumAtoms());
const std::span<const std::uint32_t> aids(m_c, d_query.getNumAtoms());
if (d_params.useGenericMatchers &&
!GenericGroups::genericAtomMatcher(d_mol, d_query, aids)) {
return false;
@@ -276,8 +270,8 @@ bool MolMatchFinalCheckFunctor::operator()(const std::uint32_t q_c[],
mOrder.insert(mOrder.end(), unmatchedNeighbors, -1);
INT_LIST moOrder;
for (const auto &bond : make_iterator_range(d_mol.getAtomBonds(mAt))) {
int dbidx = d_mol[bond]->getIdx();
for (const auto &bond : d_mol.atomBonds(mAt)) {
const int dbidx = bond->getIdx();
if (std::find(mOrder.begin(), mOrder.end(), dbidx) != mOrder.end()) {
moOrder.push_back(dbidx);
} else {
@@ -285,7 +279,7 @@ bool MolMatchFinalCheckFunctor::operator()(const std::uint32_t q_c[],
}
}
int mPermCount =
const int mPermCount =
static_cast<int>(countSwapsToInterconvert(moOrder, mOrder));
const bool requireMatch = qPermCount % 2 == mPermCount % 2;
@@ -293,7 +287,7 @@ bool MolMatchFinalCheckFunctor::operator()(const std::uint32_t q_c[],
const bool matchOK = requireMatch == labelsMatch;
// if this is not part of a stereogroup and doesn't match, return false
auto msg = d_molStereoGroups.find(m_c[i]);
const auto msg = d_molStereoGroups.find(m_c[i]);
if (msg == d_molStereoGroups.end()) {
if (!matchOK) {
return false;
@@ -451,7 +445,7 @@ void ResSubstructMatchHelper_(const ResSubstructMatchHelperArgs_ &args,
unsigned int ei) {
for (unsigned int i = bi;
(matches->size() < args.params.maxMatches) && (i < ei); ++i) {
ROMol *mol = args.resMolSupplier[i];
std::unique_ptr<ROMol> mol{args.resMolSupplier[i]};
std::vector<MatchVectType> matchesTmp =
SubstructMatch(*mol, args.query, args.params);
for (const auto &match : matchesTmp) {
@@ -459,7 +453,6 @@ void ResSubstructMatchHelper_(const ResSubstructMatchHelperArgs_ &args,
break;
}
}
delete mol;
}
};
@@ -516,7 +509,7 @@ std::vector<MatchVectType> SubstructMatch(
boost::vf2_all(query.getTopology(), mol.getTopology(), atomLabeler,
bondLabeler, matchChecker, pms, params.maxMatches);
if (found) {
unsigned int nQueryAtoms = query.getNumAtoms();
const unsigned int nQueryAtoms = query.getNumAtoms();
matches.reserve(pms.size());
MatchVectType matchVect(nQueryAtoms);
for (const auto &pairs : pms) {
@@ -533,7 +526,7 @@ std::vector<MatchVectType> SubstructMatch(
const MolBundle &bundle, const ROMol &query,
const SubstructMatchParameters &params) {
std::vector<MatchVectType> res;
for (unsigned int i = 0; i < bundle.size() && !res.size(); ++i) {
for (unsigned int i = 0; i < bundle.size() && res.empty(); ++i) {
res = SubstructMatch(*bundle[i], query, params);
}
return res;
@@ -543,7 +536,7 @@ std::vector<MatchVectType> SubstructMatch(
const ROMol &mol, const MolBundle &query,
const SubstructMatchParameters &params) {
std::vector<MatchVectType> res;
for (unsigned int i = 0; i < query.size() && !res.size(); ++i) {
for (unsigned int i = 0; i < query.size() && res.empty(); ++i) {
res = SubstructMatch(mol, *query[i], params);
}
return res;
@@ -553,8 +546,8 @@ std::vector<MatchVectType> SubstructMatch(
const MolBundle &mol, const MolBundle &query,
const SubstructMatchParameters &params) {
std::vector<MatchVectType> res;
for (unsigned int i = 0; i < mol.size() && !res.size(); ++i) {
for (unsigned int j = 0; j < query.size() && !res.size(); ++j) {
for (unsigned int i = 0; i < mol.size() && res.empty(); ++i) {
for (unsigned int j = 0; j < query.size() && res.empty(); ++j) {
res = SubstructMatch(*mol[i], *query[j], params);
}
}
@@ -580,19 +573,19 @@ std::vector<MatchVectType> SubstructMatch(
#ifdef RDK_BUILD_THREADSAFE_SSS
else {
std::vector<std::future<void>> tg;
std::vector<std::set<MatchVectType> *> matchesThread(nt);
std::vector<std::unique_ptr<std::set<MatchVectType>>> matchesThread(nt);
unsigned int ei = 0;
double dpt =
static_cast<double>(resMolSupplier.length()) / static_cast<double>(nt);
double dc = 0.0;
for (unsigned int ti = 0; ti < nt; ++ti) {
matchesThread[ti] = new std::set<MatchVectType>();
matchesThread[ti] = std::make_unique<std::set<MatchVectType>>();
unsigned int bi = ei;
dc += dpt;
ei = static_cast<unsigned int>(floor(dc));
tg.emplace_back(std::async(std::launch::async,
detail::ResSubstructMatchHelper_, args,
matchesThread[ti], bi, ei));
matchesThread[ti].get(), bi, ei));
}
for (auto &fut : tg) {
fut.get();
@@ -604,7 +597,6 @@ std::vector<MatchVectType> SubstructMatch(
break;
}
}
delete matchesThread[ti];
}
}
#endif
@@ -673,8 +665,6 @@ void MatchSubqueries(const ROMol &mol, QueryAtom::QUERYATOM_QUERY *query,
SUBQUERY_MAP &subqueryMap,
std::vector<RecursiveStructureQuery *> &locked) {
PRECONDITION(query, "bad query");
// std::cout << "*-*-* MS: " << query << std::endl;
// std::cout << "\t\t" << typeid(*query).name() << std::endl;
if (query->getDescription() == "RecursiveStructure") {
auto *rsq = (RecursiveStructureQuery *)query;
#ifdef RDK_BUILD_THREADSAFE_SSS
@@ -694,8 +684,6 @@ void MatchSubqueries(const ROMol &mol, QueryAtom::QUERYATOM_QUERY *query,
++setIter) {
rsq->insert(*setIter);
}
// std::cerr<<" copying results for query serial number:
// "<<rsq->getSerialNumber()<<std::endl;
}
if (!matchDone) {
@@ -713,13 +701,8 @@ void MatchSubqueries(const ROMol &mol, QueryAtom::QUERYATOM_QUERY *query,
}
if (rsq->getSerialNumber()) {
subqueryMap[rsq->getSerialNumber()] = query;
// std::cerr << " storing results for query serial number: "
// << rsq->getSerialNumber() << " " << rsq->size() <<
// std::endl;
}
}
} else {
// std::cout << "\tmsq1: ";
}
// now recurse over our children (these things can be nested)

View File

@@ -1,5 +1,5 @@
//
// Copyright (C) 2001-2020 Greg Landrum and Rational Discovery LLC
// Copyright (C) 2001-2025 Greg Landrum and other RDKit contributors
//
// @@ All Rights Reserved @@
// This file is part of the RDKit.
@@ -19,6 +19,7 @@
#include <unordered_map>
#include <cstdint>
#include <string>
#include <span>
#include <boost/dynamic_bitset.hpp>
#if BOOST_VERSION >= 107100
@@ -62,7 +63,7 @@ struct RDKIT_SUBSTRUCTMATCH_EXPORT SubstructMatchParameters {
std::vector<std::string> bondProperties; //!< bond properties that must be
//!< equivalent in order to match
std::function<bool(const ROMol &mol,
const std::vector<unsigned int> &match)>
std::span<const unsigned int> match)>
extraFinalCheck; //!< a function to be called at the end to validate a
//!< match
unsigned int maxRecursiveMatches =

View File

@@ -1,6 +1,5 @@
//
// Copyright (C) 2003-2021 Greg Landrum and Rational Discovery LLC
//
// Copyright (C) 2003-2025 Greg Landrum and other RDKit contributors
// @@ All Rights Reserved @@
// This file is part of the RDKit.
// The contents are covered by the terms of the BSD license
@@ -146,6 +145,7 @@ bool atomCompat(const Atom *a1, const Atom *a2,
}
bool chiralAtomCompat(const Atom *&a1, const Atom *&a2) {
/// DEPRECATED
PRECONDITION(a1, "bad atom");
PRECONDITION(a2, "bad atom");
bool res = a1->Match(a2);
@@ -232,18 +232,18 @@ void removeDuplicates(std::vector<MatchVectType> &matches,
// that the 4 paths are equivalent in the semantics of the query.
// Also, OELib returns the same results
//
std::set<boost::dynamic_bitset<>> seen;
std::unordered_set<std::string> seen;
std::vector<MatchVectType> res;
res.reserve(matches.size());
for (auto &&match : matches) {
boost::dynamic_bitset<> val(nAtoms);
seen.reserve(matches.size());
for (const auto &match : matches) {
std::string val(nAtoms, '0');
for (const auto &ci : match) {
val.set(ci.second);
val[ci.second] = '1';
}
auto pos = seen.lower_bound(val);
if (pos == seen.end() || *pos != val) {
res.push_back(std::move(match));
seen.insert(pos, std::move(val));
const bool inserted = seen.insert(std::move(val)).second;
if (inserted) {
res.push_back(match);
}
}
res.shrink_to_fit();

View File

@@ -1,5 +1,5 @@
//
// Copyright (C) 2003-2019 Greg Landrum and Rational Discovery LLC
// Copyright (C) 2003-2025 Greg Landrum and other RDKit contributors
//
// @@ All Rights Reserved @@
// This file is part of the RDKit.
@@ -27,6 +27,7 @@ RDKIT_SUBSTRUCTMATCH_EXPORT bool propertyCompat(
const std::vector<std::string> &properties);
RDKIT_SUBSTRUCTMATCH_EXPORT bool atomCompat(const Atom *a1, const Atom *a2,
const SubstructMatchParameters &ps);
[[deprecated("chiralAtomCompat is deprecated and should not be used")]]
RDKIT_SUBSTRUCTMATCH_EXPORT bool chiralAtomCompat(const Atom *a1,
const Atom *a2);
RDKIT_SUBSTRUCTMATCH_EXPORT bool bondCompat(const Bond *b1, const Bond *b2,

View File

@@ -202,17 +202,17 @@ TEST_CASE("substructure parameters", "[substruct]") {
}
namespace {
bool no_match(const ROMol &mol, const std::vector<unsigned int> &ids) {
bool no_match(const ROMol &mol, const std::span<const unsigned int> &ids) {
RDUNUSED_PARAM(mol);
RDUNUSED_PARAM(ids);
return false;
}
bool always_match(const ROMol &mol, const std::vector<unsigned int> &ids) {
bool always_match(const ROMol &mol, const std::span<const unsigned int> &ids) {
RDUNUSED_PARAM(mol);
RDUNUSED_PARAM(ids);
return true;
}
bool bigger(const ROMol &mol, const std::vector<unsigned int> &ids) {
bool bigger(const ROMol &mol, const std::span<const unsigned int> &ids) {
RDUNUSED_PARAM(mol);
return std::accumulate(ids.begin(), ids.end(), 0) > 5;
}

View File

@@ -12,6 +12,8 @@
#include <functional>
#include <set>
#include <utility>
#include <span>
#include <GraphMol/RDKitBase.h>
#include <GraphMol/MolStandardize/Tautomer.h>
#include <GraphMol/Bond.h>
@@ -101,7 +103,7 @@ class TautomerQueryMatcher {
d_params(params),
d_matchingTautomers(matchingTautomers) {}
bool match(const ROMol &mol, const std::vector<unsigned int> &match) {
bool match(const ROMol &mol, const std::span<const unsigned int> &match) {
#ifdef VERBOSE
std::cout << "Checking template match" << std::endl;
#endif
@@ -192,7 +194,7 @@ TautomerQuery *TautomerQuery::fromMol(
bool TautomerQuery::matchTautomer(
const ROMol &mol, const ROMol &tautomer,
const std::vector<unsigned int> &match,
const std::span<const unsigned int> &match,
const SubstructMatchParameters &params) const {
for (auto idx : d_modifiedAtoms) {
const auto queryAtom = tautomer.getAtomWithIdx(idx);
@@ -255,7 +257,7 @@ std::vector<MatchVectType> TautomerQuery::substructOf(
// use this functor as a final check to see if any tautomer matches the target
auto checker = [&tautomerQueryMatcher](
const ROMol &mol,
const std::vector<unsigned int> &match) mutable {
const std::span<const unsigned int> &match) mutable {
return tautomerQueryMatcher.match(mol, match);
};
templateParams.extraFinalCheck = checker;

View File

@@ -16,6 +16,7 @@
#include <GraphMol/ROMol.h>
#include <GraphMol/MolPickler.h>
#include <vector>
#include <span>
#include <GraphMol/Substruct/SubstructMatch.h>
#include <DataStructs/ExplicitBitVect.h>
@@ -45,7 +46,7 @@ class RDKIT_TAUTOMERQUERY_EXPORT TautomerQuery {
// tests if a match to the template matches a specific tautomer
bool matchTautomer(const ROMol &mol, const ROMol &tautomer,
const std::vector<unsigned int> &match,
const std::span<const unsigned int> &match,
const SubstructMatchParameters &params) const;
public:

View File

@@ -1,5 +1,5 @@
// Copyright (C) 2003-2017 Greg Landrum and Rational Discovery LLC
// Copyright (C) 2003-2025 Greg Landrum and other RDKit contributors
//
// @@ All Rights Reserved @@
// This file is part of the RDKit.
@@ -10,6 +10,7 @@
#define NO_IMPORT_ARRAY
#include <RDBoost/python.h>
#include <string>
#include <span>
#include "props.hpp"
#include "rdchem.h"
@@ -156,8 +157,11 @@ class pyobjFunctor {
public:
pyobjFunctor(python::object obj) : dp_obj(std::move(obj)) {}
~pyobjFunctor() = default;
bool operator()(const ROMol &m, const std::vector<unsigned int> &match) {
return python::extract<bool>(dp_obj(boost::ref(m), boost::ref(match)));
bool operator()(const ROMol &m, std::span<const unsigned int> match) {
// boost::python doesn't handle std::span, so we need to convert the span to
// a vector before calling into python:
std::vector<unsigned int> matchVec(match.begin(), match.end());
return python::extract<bool>(dp_obj(boost::ref(m), boost::ref(matchVec)));
}
private:

View File

@@ -17,6 +17,7 @@ def data(fname):
logger = logger()
logger.setLevel(1)
tests = [1] * 1001
if len(sys.argv) > 1:
@@ -80,6 +81,7 @@ if tests[3] or tests[4] or tests[5]:
logger.info('patterns from smiles')
patts = []
nMols = 0
nBad = 0
t1 = time.time()
for line in pattData:
m = Chem.MolFromSmarts(line)