mirror of
https://github.com/rdkit/rdkit.git
synced 2026-06-07 22:44:25 +08:00
SubstructLibrary improvements (#4403)
* support MolBundles in the SubstructLibrary * support SubstructMatchParameters in SubstructLibrary * expose those methods to the python wrapper * fix non-threadsafe build problem * be explicit about sorting * add the option to set the search order * add example functions for setting the search order
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
// Copyright (c) 2017-2019, Novartis Institutes for BioMedical Research Inc.
|
||||
// Copyright (c) 2017-2021, Novartis Institutes for BioMedical Research Inc.
|
||||
// and other RDKit contributors
|
||||
//
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
@@ -36,6 +38,7 @@
|
||||
#endif
|
||||
|
||||
#include <GraphMol/Substruct/SubstructMatch.h>
|
||||
#include <boost/dynamic_bitset.hpp>
|
||||
|
||||
namespace RDKit {
|
||||
|
||||
@@ -50,16 +53,11 @@ bool SubstructLibraryCanSerialize() {
|
||||
struct Bits {
|
||||
const ExplicitBitVect *queryBits;
|
||||
const FPHolderBase *fps;
|
||||
bool recursionPossible;
|
||||
bool useChirality;
|
||||
bool useQueryQueryMatches;
|
||||
SubstructMatchParameters params;
|
||||
|
||||
Bits(const FPHolderBase *fps, const ROMol &m, bool recursionPossible,
|
||||
bool useChirality, bool useQueryQueryMatches)
|
||||
: fps(fps),
|
||||
recursionPossible(recursionPossible),
|
||||
useChirality(useChirality),
|
||||
useQueryQueryMatches(useQueryQueryMatches) {
|
||||
Bits(const FPHolderBase *fps, const ROMol &m,
|
||||
const SubstructMatchParameters &ssparams)
|
||||
: fps(fps), params(ssparams) {
|
||||
if (fps) {
|
||||
queryBits = fps->makeFingerprint(m);
|
||||
} else {
|
||||
@@ -68,11 +66,8 @@ struct Bits {
|
||||
}
|
||||
|
||||
Bits(const FPHolderBase *fingerprints, const TautomerQuery &m,
|
||||
bool recursionPossible, bool useChirality, bool useQueryQueryMatches)
|
||||
: fps(nullptr),
|
||||
recursionPossible(recursionPossible),
|
||||
useChirality(useChirality),
|
||||
useQueryQueryMatches(useQueryQueryMatches) {
|
||||
const SubstructMatchParameters &ssparams)
|
||||
: fps(nullptr), params(ssparams) {
|
||||
if (fingerprints) {
|
||||
const TautomerPatternHolder *tp =
|
||||
dynamic_cast<const TautomerPatternHolder *>(fingerprints);
|
||||
@@ -143,16 +138,23 @@ void SubSearcher(const Query &in_query, const Bits &bits,
|
||||
const MolHolderBase &mols, unsigned int start,
|
||||
unsigned int &end, unsigned int numThreads,
|
||||
const bool needs_rings, int &counter, const int maxResults,
|
||||
boost::dynamic_bitset<> &found,
|
||||
const std::vector<unsigned int> &searchOrder,
|
||||
std::vector<unsigned int> *idxs) {
|
||||
PRECONDITION(searchOrder.empty() || searchOrder.size() >= end,
|
||||
"bad searchOrder data");
|
||||
Query query(in_query);
|
||||
MatchVectType matchVect;
|
||||
for (unsigned int idx = start; idx < end; idx += numThreads) {
|
||||
if (!bits.check(idx)) {
|
||||
unsigned int sidx = idx;
|
||||
if (!searchOrder.empty()) {
|
||||
sidx = searchOrder[idx];
|
||||
}
|
||||
if (!bits.check(sidx) || found[sidx]) {
|
||||
continue;
|
||||
}
|
||||
// need shared_ptr as it (may) control the lifespan of the
|
||||
// returned molecule!
|
||||
const boost::shared_ptr<ROMol> &m = mols.getMol(idx);
|
||||
const boost::shared_ptr<ROMol> &m = mols.getMol(sidx);
|
||||
ROMol *mol = m.get();
|
||||
if (!mol) {
|
||||
continue;
|
||||
@@ -162,11 +164,11 @@ void SubSearcher(const Query &in_query, const Bits &bits,
|
||||
MolOps::symmetrizeSSSR(*mol);
|
||||
}
|
||||
|
||||
if (SubstructMatch(*mol, query, matchVect, bits.recursionPossible,
|
||||
bits.useChirality, bits.useQueryQueryMatches)) {
|
||||
if (!SubstructMatch(*mol, query, bits.params).empty()) {
|
||||
++counter;
|
||||
found.set(sidx);
|
||||
if (idxs) {
|
||||
idxs->push_back(idx);
|
||||
idxs->push_back(sidx);
|
||||
if (maxResults > 0 && counter == maxResults) {
|
||||
// if we reached maxResults, record the last idx we processed and bail
|
||||
// out
|
||||
@@ -181,11 +183,14 @@ void SubSearcher(const Query &in_query, const Bits &bits,
|
||||
template <class Query>
|
||||
int internalGetMatches(const Query &query, MolHolderBase &mols,
|
||||
const FPHolderBase *fps, unsigned int startIdx,
|
||||
unsigned int endIdx, bool recursionPossible,
|
||||
bool useChirality, bool useQueryQueryMatches,
|
||||
int numThreads = -1, int maxResults = 1000,
|
||||
std::vector<unsigned int> *idxs = nullptr) {
|
||||
unsigned int endIdx,
|
||||
const SubstructMatchParameters ¶ms, int numThreads,
|
||||
int maxResults, boost::dynamic_bitset<> &found,
|
||||
const std::vector<unsigned int> &searchOrder,
|
||||
std::vector<unsigned int> *idxs) {
|
||||
PRECONDITION(startIdx < mols.size(), "startIdx out of bounds");
|
||||
PRECONDITION(searchOrder.empty() || startIdx < searchOrder.size(),
|
||||
"startIdx out of bounds");
|
||||
PRECONDITION(endIdx > startIdx, "endIdx > startIdx");
|
||||
|
||||
// do not do any work if no results were requested
|
||||
@@ -194,12 +199,15 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
|
||||
}
|
||||
|
||||
endIdx = std::min(mols.size(), endIdx);
|
||||
if (!searchOrder.empty()) {
|
||||
endIdx = std::min(static_cast<unsigned int>(searchOrder.size()), endIdx);
|
||||
}
|
||||
|
||||
numThreads = static_cast<int>(getNumThreadsToUse(numThreads));
|
||||
numThreads = std::min(numThreads, static_cast<int>(endIdx));
|
||||
|
||||
bool needs_rings = query_needs_rings(query);
|
||||
Bits bits(fps, query, recursionPossible, useChirality, useQueryQueryMatches);
|
||||
Bits bits(fps, query, params);
|
||||
int counter = 0;
|
||||
|
||||
#ifdef RDK_THREADSAFE_SSS
|
||||
@@ -219,20 +227,23 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
|
||||
}
|
||||
std::vector<std::future<void>> thread_group;
|
||||
std::vector<std::vector<unsigned int>> internal_results;
|
||||
std::vector<boost::dynamic_bitset<>> internal_found(numThreads);
|
||||
if (idxs) {
|
||||
internal_results.resize(numThreads);
|
||||
}
|
||||
int thread_group_idx;
|
||||
for (thread_group_idx = 0; thread_group_idx < numThreads;
|
||||
++thread_group_idx) {
|
||||
internal_found[thread_group_idx] = found;
|
||||
// need to use boost::ref otherwise things are passed by value
|
||||
thread_group.emplace_back(
|
||||
std::async(std::launch::async, SubSearcher<Query>, std::ref(query),
|
||||
bits, std::ref(mols), startIdx + thread_group_idx,
|
||||
std::ref(endIdxVect[thread_group_idx]), numThreads,
|
||||
needs_rings, std::ref(counterVect[thread_group_idx]),
|
||||
maxResultsVect[thread_group_idx],
|
||||
idxs ? &internal_results[thread_group_idx] : nullptr));
|
||||
thread_group.emplace_back(std::async(
|
||||
std::launch::async, SubSearcher<Query>, std::ref(query), bits,
|
||||
std::ref(mols), startIdx + thread_group_idx,
|
||||
std::ref(endIdxVect[thread_group_idx]), numThreads, needs_rings,
|
||||
std::ref(counterVect[thread_group_idx]),
|
||||
maxResultsVect[thread_group_idx],
|
||||
std::ref(internal_found[thread_group_idx]), std::ref(searchOrder),
|
||||
idxs ? &internal_results[thread_group_idx] : nullptr));
|
||||
}
|
||||
unsigned int maxEndIdx;
|
||||
if (maxResults > 0) {
|
||||
@@ -259,12 +270,14 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
|
||||
if (endIdxVect[thread_group_idx] >= maxEndIdx) {
|
||||
continue;
|
||||
}
|
||||
internal_found[thread_group_idx] = found;
|
||||
// need to use boost::ref otherwise things are passed by value
|
||||
thread_group.emplace_back(std::async(
|
||||
std::launch::async, SubSearcher<Query>, std::ref(query), bits,
|
||||
std::ref(mols), endIdxVect[thread_group_idx] + numThreads,
|
||||
std::ref(maxEndIdx), numThreads, needs_rings,
|
||||
std::ref(counterVect[thread_group_idx]), -1,
|
||||
std::ref(internal_found[thread_group_idx]), std::ref(searchOrder),
|
||||
&internal_results[thread_group_idx]));
|
||||
}
|
||||
}
|
||||
@@ -277,6 +290,7 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
|
||||
idxs->insert(idxs->end(), internal_results[thread_group_idx].begin(),
|
||||
internal_results[thread_group_idx].end());
|
||||
}
|
||||
found |= internal_found[thread_group_idx];
|
||||
// If there was no maxResults, we still need to count, otherwise
|
||||
// this has already been done previously
|
||||
if (maxResults < 0) {
|
||||
@@ -287,11 +301,19 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
|
||||
// if this is running single-threaded, no need to suffer the overhead of
|
||||
// std::async
|
||||
SubSearcher(query, bits, mols, startIdx, endIdx, 1, needs_rings, counter,
|
||||
maxResults, idxs);
|
||||
maxResults, found, searchOrder, idxs);
|
||||
}
|
||||
if (idxs) {
|
||||
if (idxs && numThreads > 1) {
|
||||
// the sort is necessary to ensure consistency across runs with different
|
||||
// numbers of threads
|
||||
if (!searchOrder.empty()) {
|
||||
std::transform(idxs->begin(), idxs->end(), idxs->begin(),
|
||||
[searchOrder](unsigned int v) -> unsigned int {
|
||||
return std::find(searchOrder.begin(), searchOrder.end(),
|
||||
v) -
|
||||
searchOrder.begin();
|
||||
});
|
||||
}
|
||||
std::sort(idxs->begin(), idxs->end());
|
||||
// we may have actually accumulated more results than maxResults due
|
||||
// to the top up above, so trim the results down if that's the case
|
||||
@@ -299,10 +321,16 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
|
||||
idxs->size() > static_cast<unsigned int>(maxResults)) {
|
||||
idxs->resize(maxResults);
|
||||
}
|
||||
if (!searchOrder.empty()) {
|
||||
std::transform(idxs->begin(), idxs->end(), idxs->begin(),
|
||||
[searchOrder](unsigned int v) -> unsigned int {
|
||||
return searchOrder[v];
|
||||
});
|
||||
}
|
||||
}
|
||||
#else
|
||||
SubSearcher(query, bits, mols, startIdx, endIdx, 1, needs_rings, counter,
|
||||
maxResults, idxs);
|
||||
maxResults, found, searchOrder, idxs);
|
||||
#endif
|
||||
|
||||
delete bits.queryBits;
|
||||
@@ -310,66 +338,102 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
|
||||
return counter;
|
||||
}
|
||||
|
||||
int molbundleGetMatches(const MolBundle &query, MolHolderBase &mols,
|
||||
const FPHolderBase *fps, unsigned int startIdx,
|
||||
unsigned int endIdx,
|
||||
const SubstructMatchParameters ¶ms, int numThreads,
|
||||
int maxResults,
|
||||
const std::vector<unsigned int> &searchOrder,
|
||||
std::vector<unsigned int> *idxs) {
|
||||
int res = 0;
|
||||
boost::dynamic_bitset<> found(mols.size());
|
||||
for (const auto qmol : query.getMols()) {
|
||||
maxResults -= res;
|
||||
res += internalGetMatches(*qmol, mols, fps, startIdx, endIdx, params,
|
||||
numThreads, maxResults, found, searchOrder, idxs);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<unsigned int> SubstructLibrary::getMatches(
|
||||
const ROMol &query, unsigned int startIdx, unsigned int endIdx,
|
||||
bool recursionPossible, bool useChirality, bool useQueryQueryMatches,
|
||||
int numThreads, int maxResults) const {
|
||||
const SubstructMatchParameters ¶ms, int numThreads,
|
||||
int maxResults) const {
|
||||
std::vector<unsigned int> idxs;
|
||||
internalGetMatches(query, *mols, fps, startIdx, endIdx, recursionPossible,
|
||||
useChirality, useQueryQueryMatches, numThreads, maxResults,
|
||||
&idxs);
|
||||
boost::dynamic_bitset<> found(mols->size());
|
||||
internalGetMatches(query, *mols, fps, startIdx, endIdx, params, numThreads,
|
||||
maxResults, found, searchOrder, &idxs);
|
||||
return idxs;
|
||||
}
|
||||
|
||||
std::vector<unsigned int> SubstructLibrary::getMatches(
|
||||
const TautomerQuery &query, unsigned int startIdx, unsigned int endIdx,
|
||||
bool recursionPossible, bool useChirality, bool useQueryQueryMatches,
|
||||
int numThreads, int maxResults) const {
|
||||
const SubstructMatchParameters ¶ms, int numThreads,
|
||||
int maxResults) const {
|
||||
std::vector<unsigned int> idxs;
|
||||
internalGetMatches(query, *mols, fps, startIdx, endIdx, recursionPossible,
|
||||
useChirality, useQueryQueryMatches, numThreads, maxResults,
|
||||
&idxs);
|
||||
boost::dynamic_bitset<> found(mols->size());
|
||||
internalGetMatches(query, *mols, fps, startIdx, endIdx, params, numThreads,
|
||||
maxResults, found, searchOrder, &idxs);
|
||||
return idxs;
|
||||
}
|
||||
|
||||
std::vector<unsigned int> SubstructLibrary::getMatches(
|
||||
const MolBundle &query, unsigned int startIdx, unsigned int endIdx,
|
||||
const SubstructMatchParameters ¶ms, int numThreads,
|
||||
int maxResults) const {
|
||||
std::vector<unsigned int> idxs;
|
||||
molbundleGetMatches(query, *mols, fps, startIdx, endIdx, params, numThreads,
|
||||
maxResults, searchOrder, &idxs);
|
||||
return idxs;
|
||||
}
|
||||
|
||||
unsigned int SubstructLibrary::countMatches(
|
||||
const ROMol &query, unsigned int startIdx, unsigned int endIdx,
|
||||
bool recursionPossible, bool useChirality, bool useQueryQueryMatches,
|
||||
int numThreads) const {
|
||||
return internalGetMatches(query, *mols, fps, startIdx, endIdx,
|
||||
recursionPossible, useChirality,
|
||||
useQueryQueryMatches, numThreads, -1);
|
||||
const SubstructMatchParameters ¶ms, int numThreads) const {
|
||||
boost::dynamic_bitset<> found(mols->size());
|
||||
return internalGetMatches(query, *mols, fps, startIdx, endIdx, params,
|
||||
numThreads, -1, found, searchOrder, nullptr);
|
||||
}
|
||||
|
||||
unsigned int SubstructLibrary::countMatches(
|
||||
const TautomerQuery &query, unsigned int startIdx, unsigned int endIdx,
|
||||
bool recursionPossible, bool useChirality, bool useQueryQueryMatches,
|
||||
int numThreads) const {
|
||||
return internalGetMatches(query, *mols, fps, startIdx, endIdx,
|
||||
recursionPossible, useChirality,
|
||||
useQueryQueryMatches, numThreads, -1);
|
||||
const SubstructMatchParameters ¶ms, int numThreads) const {
|
||||
boost::dynamic_bitset<> found(mols->size());
|
||||
return internalGetMatches(query, *mols, fps, startIdx, endIdx, params,
|
||||
numThreads, -1, found, searchOrder, nullptr);
|
||||
}
|
||||
unsigned int SubstructLibrary::countMatches(
|
||||
const MolBundle &query, unsigned int startIdx, unsigned int endIdx,
|
||||
const SubstructMatchParameters ¶ms, int numThreads) const {
|
||||
return molbundleGetMatches(query, *mols, fps, startIdx, endIdx, params,
|
||||
numThreads, -1, searchOrder, nullptr);
|
||||
}
|
||||
|
||||
bool SubstructLibrary::hasMatch(const ROMol &query, unsigned int startIdx,
|
||||
unsigned int endIdx, bool recursionPossible,
|
||||
bool useChirality, bool useQueryQueryMatches,
|
||||
unsigned int endIdx,
|
||||
const SubstructMatchParameters ¶ms,
|
||||
int numThreads) const {
|
||||
const int maxResults = 1;
|
||||
return getMatches(query, startIdx, endIdx, recursionPossible, useChirality,
|
||||
useQueryQueryMatches, numThreads, maxResults)
|
||||
return getMatches(query, startIdx, endIdx, params, numThreads, maxResults)
|
||||
.size() > 0;
|
||||
}
|
||||
|
||||
bool SubstructLibrary::hasMatch(const TautomerQuery &query,
|
||||
unsigned int startIdx, unsigned int endIdx,
|
||||
bool recursionPossible, bool useChirality,
|
||||
bool useQueryQueryMatches,
|
||||
const SubstructMatchParameters ¶ms,
|
||||
int numThreads) const {
|
||||
const int maxResults = 1;
|
||||
return getMatches(query, startIdx, endIdx, recursionPossible, useChirality,
|
||||
useQueryQueryMatches, numThreads, maxResults)
|
||||
return getMatches(query, startIdx, endIdx, params, numThreads, maxResults)
|
||||
.size() > 0;
|
||||
}
|
||||
bool SubstructLibrary::hasMatch(const MolBundle &query, unsigned int startIdx,
|
||||
unsigned int endIdx,
|
||||
const SubstructMatchParameters ¶ms,
|
||||
int numThreads) const {
|
||||
const int maxResults = 1;
|
||||
return getMatches(query, startIdx, endIdx, params, numThreads, maxResults)
|
||||
.size() > 0;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user