mirror of
https://github.com/rdkit/rdkit.git
synced 2026-06-04 21:54:27 +08:00
A few improvements to SubstructLibrary (#3557)
* - enable SubstructLibrary to build without threading support - enforce consistency across single- and multi-threaded runs - improve performance on single-threaded runs avoiding overhead of spawning threads - consolidate internalCountMatches and internalGetMatches into one function to reduce code duplication - fix a bug in Python CountMatches whjich would run on 1000 threads * reverted Code/GraphMol/SubstructLibrary/Wrap/SubstructLibraryWrap.cpp as it is now in its own PR (#3558) * - added missing cast - removed unnecessary std::make_move_iterator - added Brian's test Co-authored-by: Greg Landrum <greg.landrum@gmail.com>
This commit is contained in:
@@ -35,7 +35,6 @@
|
||||
#include <future>
|
||||
#endif
|
||||
|
||||
#include <atomic>
|
||||
#include <GraphMol/Substruct/SubstructMatch.h>
|
||||
|
||||
|
||||
@@ -94,30 +93,28 @@ bool query_needs_rings(const ROMol &in_query) {
|
||||
for (auto &atom: in_query.atoms()) {
|
||||
if(atom->hasQuery()) {
|
||||
if (describeQuery(atom).find("Ring") != std::string::npos) {
|
||||
return true;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto &bond: in_query.bonds()) {
|
||||
if(bond->hasQuery()) {
|
||||
if (describeQuery(bond).find("Ring") != std::string::npos) {
|
||||
return true;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// end is exclusive here
|
||||
void SubSearcher(const ROMol &in_query, const Bits &bits,
|
||||
const MolHolderBase &mols, std::vector<unsigned int> &idxs,
|
||||
unsigned int start, unsigned int end, unsigned int numThreads,
|
||||
std::atomic<int> &counter, const int maxResults, const bool needs_rings) {
|
||||
const MolHolderBase &mols, unsigned int start,
|
||||
unsigned int &end, unsigned int numThreads,
|
||||
const bool needs_rings, int &counter, const int maxResults,
|
||||
std::vector<unsigned int> *idxs) {
|
||||
ROMol query(in_query);
|
||||
MatchVectType matchVect;
|
||||
for (unsigned int idx = start;
|
||||
idx < end && (maxResults == -1 || counter < maxResults);
|
||||
idx += numThreads) {
|
||||
for (unsigned int idx = start; idx < end; idx += numThreads) {
|
||||
if (!bits.check(idx)) {
|
||||
continue;
|
||||
}
|
||||
@@ -126,144 +123,155 @@ void SubSearcher(const ROMol &in_query, const Bits &bits,
|
||||
const boost::shared_ptr<ROMol> &m = mols.getMol(idx);
|
||||
ROMol *mol = m.get();
|
||||
if (needs_rings && (!mol->getRingInfo() || !mol->getRingInfo()->isInitialized())) {
|
||||
// I have no idea what happens when symmetrizeSSSR gets called
|
||||
// on the same molecule twice in two threads.
|
||||
// This most likely WILL NOT HAPPEN since only one molholder
|
||||
// likely needs ring info.
|
||||
MolOps::symmetrizeSSSR(*mol);
|
||||
}
|
||||
|
||||
if (SubstructMatch(*mol, query, matchVect, bits.recursionPossible,
|
||||
bits.useChirality, bits.useQueryQueryMatches)) {
|
||||
// this is squishy when updating the counter. While incrementing is
|
||||
// atomic
|
||||
// several substructure runs can update the counter beyond the maxResults
|
||||
// This okay: if we get one or two extra, we can fix it on the way out
|
||||
if (maxResults != -1 && counter >= maxResults) {
|
||||
break;
|
||||
}
|
||||
idxs.push_back(idx);
|
||||
if (maxResults != -1) {
|
||||
counter++;
|
||||
++counter;
|
||||
if (idxs) {
|
||||
idxs->push_back(idx);
|
||||
if (maxResults > 0 && counter == maxResults) {
|
||||
// if we reached maxResults, record the last idx we processed and bail
|
||||
// out
|
||||
end = idx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// end is inclusive here
|
||||
void SubSearchMatchCounter(const ROMol &in_query, const Bits &bits,
|
||||
const MolHolderBase &mols, unsigned int start,
|
||||
unsigned int end, int numThreads,
|
||||
std::atomic<int> &counter, bool needs_rings) {
|
||||
ROMol query(in_query);
|
||||
MatchVectType matchVect;
|
||||
for (unsigned int idx = start; idx < end; idx += numThreads) {
|
||||
if (!bits.check(idx)) {
|
||||
continue;
|
||||
}
|
||||
// need shared_ptr as it (may) controls the lifespan of the
|
||||
// returned molecule!
|
||||
const boost::shared_ptr<ROMol> &m = mols.getMol(idx);
|
||||
ROMol *mol = m.get();
|
||||
if (needs_rings && (!mol->getRingInfo() || !mol->getRingInfo()->isInitialized())) {
|
||||
// I have no idea what happens when symmetrizeSSSR gets called
|
||||
// on the same molecule twice in two threads.
|
||||
// This most likely WILL NOT HAPPEN since only one molholder
|
||||
// likely needs ring info.
|
||||
MolOps::symmetrizeSSSR(*mol);
|
||||
}
|
||||
if (SubstructMatch(*mol, query, matchVect, bits.recursionPossible,
|
||||
bits.useChirality, bits.useQueryQueryMatches)) {
|
||||
counter++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<unsigned int> internalGetMatches(
|
||||
int internalGetMatches(
|
||||
const ROMol &query, MolHolderBase &mols, const FPHolderBase *fps,
|
||||
unsigned int startIdx, unsigned int endIdx, bool recursionPossible,
|
||||
bool useChirality, bool useQueryQueryMatches, int numThreads = -1,
|
||||
int maxResults = 1000) {
|
||||
PRECONDITION(startIdx < mols.size(), "startIdx out of bounds");
|
||||
PRECONDITION(endIdx > startIdx, "endIdx > startIdx");
|
||||
numThreads = (int)getNumThreadsToUse(numThreads);
|
||||
|
||||
endIdx = std::min(mols.size(), endIdx);
|
||||
if (endIdx < static_cast<unsigned int>(numThreads)) {
|
||||
numThreads = endIdx;
|
||||
}
|
||||
|
||||
std::vector<std::future<void>> thread_group;
|
||||
std::atomic<int> counter(0);
|
||||
std::vector<std::vector<unsigned int>> internal_results(numThreads);
|
||||
|
||||
bool needs_rings = query_needs_rings(query);
|
||||
Bits bits(fps, query, recursionPossible, useChirality, useQueryQueryMatches);
|
||||
|
||||
for (int thread_group_idx = 0; thread_group_idx < numThreads;
|
||||
++thread_group_idx) {
|
||||
// need to use boost::ref otherwise things are passed by value
|
||||
thread_group.emplace_back(
|
||||
std::async(std::launch::async, SubSearcher, std::ref(query), bits,
|
||||
std::ref(mols), std::ref(internal_results[thread_group_idx]),
|
||||
startIdx + thread_group_idx, endIdx, numThreads,
|
||||
std::ref(counter), maxResults, needs_rings));
|
||||
}
|
||||
for (auto &fut : thread_group) {
|
||||
fut.get();
|
||||
}
|
||||
delete bits.queryBits;
|
||||
|
||||
std::vector<unsigned int> results;
|
||||
for (int thread_group_idx = 0; thread_group_idx < numThreads;
|
||||
++thread_group_idx) {
|
||||
results.insert(results.end(), internal_results[thread_group_idx].begin(),
|
||||
internal_results[thread_group_idx].end());
|
||||
}
|
||||
|
||||
// this is so we don't really have to do locking on the atomic counter...
|
||||
if (maxResults != -1 && rdcast<int>(results.size()) > maxResults) {
|
||||
results.resize(maxResults);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
int internalMatchCounter(const ROMol &query, MolHolderBase &mols,
|
||||
const FPHolderBase *fps, unsigned int startIdx,
|
||||
unsigned int endIdx, bool recursionPossible,
|
||||
bool useChirality, bool useQueryQueryMatches,
|
||||
int numThreads = -1) {
|
||||
bool useChirality, bool useQueryQueryMatches,
|
||||
int numThreads = -1, int maxResults = 1000,
|
||||
std::vector<unsigned int> *idxs = nullptr) {
|
||||
PRECONDITION(startIdx < mols.size(), "startIdx out of bounds");
|
||||
PRECONDITION(endIdx > startIdx, "endIdx > startIdx");
|
||||
|
||||
// do not do any work if no results were requested
|
||||
if (maxResults == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
endIdx = std::min(mols.size(), endIdx);
|
||||
|
||||
numThreads = (int)getNumThreadsToUse(numThreads);
|
||||
numThreads = static_cast<int>(getNumThreadsToUse(numThreads));
|
||||
numThreads = std::min(numThreads, static_cast<int>(endIdx));
|
||||
|
||||
if (endIdx < static_cast<unsigned int>(numThreads)) {
|
||||
numThreads = endIdx;
|
||||
}
|
||||
|
||||
std::vector<std::future<void>> thread_group;
|
||||
std::atomic<int> counter(0);
|
||||
bool needs_rings = query_needs_rings(query);
|
||||
|
||||
Bits bits(fps, query, recursionPossible, useChirality, useQueryQueryMatches);
|
||||
for (int thread_group_idx = 0; thread_group_idx < numThreads;
|
||||
++thread_group_idx) {
|
||||
// need to use boost::ref otherwise things are passed by value
|
||||
thread_group.emplace_back(
|
||||
std::async(std::launch::async, SubSearchMatchCounter, std::ref(query),
|
||||
bits, std::ref(mols), startIdx + thread_group_idx, endIdx,
|
||||
numThreads, std::ref(counter), needs_rings));
|
||||
int counter = 0;
|
||||
|
||||
#ifdef RDK_THREADSAFE_SSS
|
||||
if (numThreads > 1) {
|
||||
std::vector<int> counterVect(numThreads, 0);
|
||||
int maxResultsPerThread = maxResults;
|
||||
if (maxResults > 0) {
|
||||
maxResultsPerThread /= numThreads;
|
||||
}
|
||||
std::vector<int> maxResultsVect(numThreads, maxResultsPerThread);
|
||||
std::vector<unsigned int> endIdxVect(numThreads, endIdx);
|
||||
if (maxResults > 0) {
|
||||
int excess = maxResults % numThreads;
|
||||
for (int i = 0; i < excess; ++i) {
|
||||
++maxResultsVect[i];
|
||||
}
|
||||
}
|
||||
std::vector<std::future<void>> thread_group;
|
||||
std::vector<std::vector<unsigned int>> internal_results;
|
||||
if (idxs) {
|
||||
internal_results.resize(numThreads);
|
||||
}
|
||||
int thread_group_idx;
|
||||
for (thread_group_idx = 0; thread_group_idx < numThreads;
|
||||
++thread_group_idx) {
|
||||
// need to use boost::ref otherwise things are passed by value
|
||||
thread_group.emplace_back(
|
||||
std::async(std::launch::async, SubSearcher, std::ref(query), bits,
|
||||
std::ref(mols), startIdx + thread_group_idx,
|
||||
std::ref(endIdxVect[thread_group_idx]), numThreads,
|
||||
needs_rings, std::ref(counterVect[thread_group_idx]),
|
||||
maxResultsVect[thread_group_idx],
|
||||
idxs ? &internal_results[thread_group_idx] : nullptr));
|
||||
}
|
||||
if (maxResults > 0) {
|
||||
// If we are running with maxResults in a multi-threaded settings,
|
||||
// some threads may have screened more molecules than others.
|
||||
// If maxResults was close to the theoretical maximum, some threads
|
||||
// might have even run out of molecules to screen without reaching
|
||||
// maxResults so we need to make sure that all threads have screened as
|
||||
// many molecules as the most productive thread if we want multi-threaded
|
||||
// runs to yield the same results independently from the number of
|
||||
// threads.
|
||||
thread_group_idx = 0;
|
||||
for (auto &fut : thread_group) {
|
||||
fut.get();
|
||||
counter += counterVect[thread_group_idx++];
|
||||
}
|
||||
thread_group.clear();
|
||||
// Find out out the max number of molecules that was screened by the most
|
||||
// productive thread and do the same in all other threads, unless the
|
||||
// max number of molecules was reached
|
||||
unsigned int maxEndIdx =
|
||||
*std::max_element(endIdxVect.begin(), endIdxVect.end());
|
||||
for (thread_group_idx = 0; thread_group_idx < numThreads;
|
||||
++thread_group_idx) {
|
||||
if (endIdxVect[thread_group_idx] >= maxEndIdx) {
|
||||
continue;
|
||||
}
|
||||
// need to use boost::ref otherwise things are passed by value
|
||||
thread_group.emplace_back(std::async(
|
||||
std::launch::async, SubSearcher, std::ref(query), bits,
|
||||
std::ref(mols), endIdxVect[thread_group_idx] + numThreads,
|
||||
std::ref(maxEndIdx), numThreads, needs_rings,
|
||||
std::ref(counterVect[thread_group_idx]), -1,
|
||||
&internal_results[thread_group_idx]));
|
||||
}
|
||||
}
|
||||
for (auto &fut : thread_group) {
|
||||
fut.get();
|
||||
}
|
||||
for (thread_group_idx = 0; thread_group_idx < numThreads;
|
||||
++thread_group_idx) {
|
||||
if (idxs) {
|
||||
idxs->insert(idxs->end(), internal_results[thread_group_idx].begin(),
|
||||
internal_results[thread_group_idx].end());
|
||||
}
|
||||
// If there was no maxResults, we still need to count, otherwise
|
||||
// this has already been done previously
|
||||
if (maxResults < 0) {
|
||||
counter += counterVect[thread_group_idx];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// if this is running single-threaded, no need to suffer the overhead of
|
||||
// std::async
|
||||
SubSearcher(query, bits, mols, startIdx, endIdx, 1, needs_rings, counter,
|
||||
maxResults, idxs);
|
||||
}
|
||||
for (auto &thread : thread_group) {
|
||||
thread.get();
|
||||
if (idxs) {
|
||||
// the sort is necessary to ensure consistency across runs with different
|
||||
// numbers of threads
|
||||
std::sort(idxs->begin(), idxs->end());
|
||||
// we may have actually accumulated more results than maxResults due
|
||||
// to the top up above, so trim the results down if that's the case
|
||||
if (maxResults > 0 &&
|
||||
idxs->size() > static_cast<unsigned int>(maxResults)) {
|
||||
idxs->resize(maxResults);
|
||||
}
|
||||
}
|
||||
#else
|
||||
SubSearcher(query, bits, mols, startIdx, endIdx, 1, needs_rings, counter, maxResults, idxs);
|
||||
#endif
|
||||
|
||||
delete bits.queryBits;
|
||||
return (int)counter;
|
||||
|
||||
return counter;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
std::vector<unsigned int> SubstructLibrary::getMatches(
|
||||
@@ -277,9 +285,11 @@ std::vector<unsigned int> SubstructLibrary::getMatches(
|
||||
const ROMol &query, unsigned int startIdx, unsigned int endIdx,
|
||||
bool recursionPossible, bool useChirality, bool useQueryQueryMatches,
|
||||
int numThreads, int maxResults) {
|
||||
return internalGetMatches(query, *mols, fps, startIdx, endIdx,
|
||||
recursionPossible, useChirality,
|
||||
useQueryQueryMatches, numThreads, maxResults);
|
||||
std::vector<unsigned int> idxs;
|
||||
internalGetMatches(query, *mols, fps, startIdx, endIdx,
|
||||
recursionPossible, useChirality,
|
||||
useQueryQueryMatches, numThreads, maxResults, &idxs);
|
||||
return idxs;
|
||||
}
|
||||
|
||||
unsigned int SubstructLibrary::countMatches(const ROMol &query,
|
||||
@@ -295,9 +305,9 @@ unsigned int SubstructLibrary::countMatches(
|
||||
const ROMol &query, unsigned int startIdx, unsigned int endIdx,
|
||||
bool recursionPossible, bool useChirality, bool useQueryQueryMatches,
|
||||
int numThreads) {
|
||||
return internalMatchCounter(query, *mols, fps, startIdx, endIdx,
|
||||
recursionPossible, useChirality,
|
||||
useQueryQueryMatches, numThreads);
|
||||
return internalGetMatches(query, *mols, fps, startIdx, endIdx,
|
||||
recursionPossible, useChirality,
|
||||
useQueryQueryMatches, numThreads, -1);
|
||||
}
|
||||
|
||||
bool SubstructLibrary::hasMatch(const ROMol &query, bool recursionPossible,
|
||||
|
||||
Reference in New Issue
Block a user