A few improvements to SubstructLibrary (#3557)

* - enable SubstructLibrary to build without threading support
- enforce consistency across single- and multi-threaded runs
- improve performance on single-threaded runs avoiding overhead of spawning threads
- consolidate internalCountMatches and internalGetMatches into one function to reduce code duplication
- fix a bug in Python CountMatches whjich would run on 1000 threads

* reverted Code/GraphMol/SubstructLibrary/Wrap/SubstructLibraryWrap.cpp as it is now in its own PR (#3558)

* - added missing cast
- removed unnecessary std::make_move_iterator
- added Brian's test

Co-authored-by: Greg Landrum <greg.landrum@gmail.com>
This commit is contained in:
Paolo Tosco
2020-11-18 22:07:34 +01:00
committed by GitHub
parent 874fd5c1bd
commit c5dc6cc258
4 changed files with 253 additions and 140 deletions

View File

@@ -35,7 +35,6 @@
#include <future>
#endif
#include <atomic>
#include <GraphMol/Substruct/SubstructMatch.h>
@@ -94,30 +93,28 @@ bool query_needs_rings(const ROMol &in_query) {
for (auto &atom: in_query.atoms()) {
if(atom->hasQuery()) {
if (describeQuery(atom).find("Ring") != std::string::npos) {
return true;
return true;
}
}
}
for (auto &bond: in_query.bonds()) {
if(bond->hasQuery()) {
if (describeQuery(bond).find("Ring") != std::string::npos) {
return true;
return true;
}
}
}
return false;
}
// end is exclusive here
void SubSearcher(const ROMol &in_query, const Bits &bits,
const MolHolderBase &mols, std::vector<unsigned int> &idxs,
unsigned int start, unsigned int end, unsigned int numThreads,
std::atomic<int> &counter, const int maxResults, const bool needs_rings) {
const MolHolderBase &mols, unsigned int start,
unsigned int &end, unsigned int numThreads,
const bool needs_rings, int &counter, const int maxResults,
std::vector<unsigned int> *idxs) {
ROMol query(in_query);
MatchVectType matchVect;
for (unsigned int idx = start;
idx < end && (maxResults == -1 || counter < maxResults);
idx += numThreads) {
for (unsigned int idx = start; idx < end; idx += numThreads) {
if (!bits.check(idx)) {
continue;
}
@@ -126,144 +123,155 @@ void SubSearcher(const ROMol &in_query, const Bits &bits,
const boost::shared_ptr<ROMol> &m = mols.getMol(idx);
ROMol *mol = m.get();
if (needs_rings && (!mol->getRingInfo() || !mol->getRingInfo()->isInitialized())) {
// I have no idea what happens when symmetrizeSSSR gets called
// on the same molecule twice in two threads.
// This most likely WILL NOT HAPPEN since only one molholder
// likely needs ring info.
MolOps::symmetrizeSSSR(*mol);
}
if (SubstructMatch(*mol, query, matchVect, bits.recursionPossible,
bits.useChirality, bits.useQueryQueryMatches)) {
// this is squishy when updating the counter. While incrementing is
// atomic
// several substructure runs can update the counter beyond the maxResults
// This okay: if we get one or two extra, we can fix it on the way out
if (maxResults != -1 && counter >= maxResults) {
break;
}
idxs.push_back(idx);
if (maxResults != -1) {
counter++;
++counter;
if (idxs) {
idxs->push_back(idx);
if (maxResults > 0 && counter == maxResults) {
// if we reached maxResults, record the last idx we processed and bail
// out
end = idx;
break;
}
}
}
}
}
// end is inclusive here
void SubSearchMatchCounter(const ROMol &in_query, const Bits &bits,
const MolHolderBase &mols, unsigned int start,
unsigned int end, int numThreads,
std::atomic<int> &counter, bool needs_rings) {
ROMol query(in_query);
MatchVectType matchVect;
for (unsigned int idx = start; idx < end; idx += numThreads) {
if (!bits.check(idx)) {
continue;
}
// need shared_ptr as it (may) controls the lifespan of the
// returned molecule!
const boost::shared_ptr<ROMol> &m = mols.getMol(idx);
ROMol *mol = m.get();
if (needs_rings && (!mol->getRingInfo() || !mol->getRingInfo()->isInitialized())) {
// I have no idea what happens when symmetrizeSSSR gets called
// on the same molecule twice in two threads.
// This most likely WILL NOT HAPPEN since only one molholder
// likely needs ring info.
MolOps::symmetrizeSSSR(*mol);
}
if (SubstructMatch(*mol, query, matchVect, bits.recursionPossible,
bits.useChirality, bits.useQueryQueryMatches)) {
counter++;
}
}
}
std::vector<unsigned int> internalGetMatches(
int internalGetMatches(
const ROMol &query, MolHolderBase &mols, const FPHolderBase *fps,
unsigned int startIdx, unsigned int endIdx, bool recursionPossible,
bool useChirality, bool useQueryQueryMatches, int numThreads = -1,
int maxResults = 1000) {
PRECONDITION(startIdx < mols.size(), "startIdx out of bounds");
PRECONDITION(endIdx > startIdx, "endIdx > startIdx");
numThreads = (int)getNumThreadsToUse(numThreads);
endIdx = std::min(mols.size(), endIdx);
if (endIdx < static_cast<unsigned int>(numThreads)) {
numThreads = endIdx;
}
std::vector<std::future<void>> thread_group;
std::atomic<int> counter(0);
std::vector<std::vector<unsigned int>> internal_results(numThreads);
bool needs_rings = query_needs_rings(query);
Bits bits(fps, query, recursionPossible, useChirality, useQueryQueryMatches);
for (int thread_group_idx = 0; thread_group_idx < numThreads;
++thread_group_idx) {
// need to use boost::ref otherwise things are passed by value
thread_group.emplace_back(
std::async(std::launch::async, SubSearcher, std::ref(query), bits,
std::ref(mols), std::ref(internal_results[thread_group_idx]),
startIdx + thread_group_idx, endIdx, numThreads,
std::ref(counter), maxResults, needs_rings));
}
for (auto &fut : thread_group) {
fut.get();
}
delete bits.queryBits;
std::vector<unsigned int> results;
for (int thread_group_idx = 0; thread_group_idx < numThreads;
++thread_group_idx) {
results.insert(results.end(), internal_results[thread_group_idx].begin(),
internal_results[thread_group_idx].end());
}
// this is so we don't really have to do locking on the atomic counter...
if (maxResults != -1 && rdcast<int>(results.size()) > maxResults) {
results.resize(maxResults);
}
return results;
}
int internalMatchCounter(const ROMol &query, MolHolderBase &mols,
const FPHolderBase *fps, unsigned int startIdx,
unsigned int endIdx, bool recursionPossible,
bool useChirality, bool useQueryQueryMatches,
int numThreads = -1) {
bool useChirality, bool useQueryQueryMatches,
int numThreads = -1, int maxResults = 1000,
std::vector<unsigned int> *idxs = nullptr) {
PRECONDITION(startIdx < mols.size(), "startIdx out of bounds");
PRECONDITION(endIdx > startIdx, "endIdx > startIdx");
// do not do any work if no results were requested
if (maxResults == 0) {
return 0;
}
endIdx = std::min(mols.size(), endIdx);
numThreads = (int)getNumThreadsToUse(numThreads);
numThreads = static_cast<int>(getNumThreadsToUse(numThreads));
numThreads = std::min(numThreads, static_cast<int>(endIdx));
if (endIdx < static_cast<unsigned int>(numThreads)) {
numThreads = endIdx;
}
std::vector<std::future<void>> thread_group;
std::atomic<int> counter(0);
bool needs_rings = query_needs_rings(query);
Bits bits(fps, query, recursionPossible, useChirality, useQueryQueryMatches);
for (int thread_group_idx = 0; thread_group_idx < numThreads;
++thread_group_idx) {
// need to use boost::ref otherwise things are passed by value
thread_group.emplace_back(
std::async(std::launch::async, SubSearchMatchCounter, std::ref(query),
bits, std::ref(mols), startIdx + thread_group_idx, endIdx,
numThreads, std::ref(counter), needs_rings));
int counter = 0;
#ifdef RDK_THREADSAFE_SSS
if (numThreads > 1) {
std::vector<int> counterVect(numThreads, 0);
int maxResultsPerThread = maxResults;
if (maxResults > 0) {
maxResultsPerThread /= numThreads;
}
std::vector<int> maxResultsVect(numThreads, maxResultsPerThread);
std::vector<unsigned int> endIdxVect(numThreads, endIdx);
if (maxResults > 0) {
int excess = maxResults % numThreads;
for (int i = 0; i < excess; ++i) {
++maxResultsVect[i];
}
}
std::vector<std::future<void>> thread_group;
std::vector<std::vector<unsigned int>> internal_results;
if (idxs) {
internal_results.resize(numThreads);
}
int thread_group_idx;
for (thread_group_idx = 0; thread_group_idx < numThreads;
++thread_group_idx) {
// need to use boost::ref otherwise things are passed by value
thread_group.emplace_back(
std::async(std::launch::async, SubSearcher, std::ref(query), bits,
std::ref(mols), startIdx + thread_group_idx,
std::ref(endIdxVect[thread_group_idx]), numThreads,
needs_rings, std::ref(counterVect[thread_group_idx]),
maxResultsVect[thread_group_idx],
idxs ? &internal_results[thread_group_idx] : nullptr));
}
if (maxResults > 0) {
// If we are running with maxResults in a multi-threaded settings,
// some threads may have screened more molecules than others.
// If maxResults was close to the theoretical maximum, some threads
// might have even run out of molecules to screen without reaching
// maxResults so we need to make sure that all threads have screened as
// many molecules as the most productive thread if we want multi-threaded
// runs to yield the same results independently from the number of
// threads.
thread_group_idx = 0;
for (auto &fut : thread_group) {
fut.get();
counter += counterVect[thread_group_idx++];
}
thread_group.clear();
// Find out out the max number of molecules that was screened by the most
// productive thread and do the same in all other threads, unless the
// max number of molecules was reached
unsigned int maxEndIdx =
*std::max_element(endIdxVect.begin(), endIdxVect.end());
for (thread_group_idx = 0; thread_group_idx < numThreads;
++thread_group_idx) {
if (endIdxVect[thread_group_idx] >= maxEndIdx) {
continue;
}
// need to use boost::ref otherwise things are passed by value
thread_group.emplace_back(std::async(
std::launch::async, SubSearcher, std::ref(query), bits,
std::ref(mols), endIdxVect[thread_group_idx] + numThreads,
std::ref(maxEndIdx), numThreads, needs_rings,
std::ref(counterVect[thread_group_idx]), -1,
&internal_results[thread_group_idx]));
}
}
for (auto &fut : thread_group) {
fut.get();
}
for (thread_group_idx = 0; thread_group_idx < numThreads;
++thread_group_idx) {
if (idxs) {
idxs->insert(idxs->end(), internal_results[thread_group_idx].begin(),
internal_results[thread_group_idx].end());
}
// If there was no maxResults, we still need to count, otherwise
// this has already been done previously
if (maxResults < 0) {
counter += counterVect[thread_group_idx];
}
}
} else {
// if this is running single-threaded, no need to suffer the overhead of
// std::async
SubSearcher(query, bits, mols, startIdx, endIdx, 1, needs_rings, counter,
maxResults, idxs);
}
for (auto &thread : thread_group) {
thread.get();
if (idxs) {
// the sort is necessary to ensure consistency across runs with different
// numbers of threads
std::sort(idxs->begin(), idxs->end());
// we may have actually accumulated more results than maxResults due
// to the top up above, so trim the results down if that's the case
if (maxResults > 0 &&
idxs->size() > static_cast<unsigned int>(maxResults)) {
idxs->resize(maxResults);
}
}
#else
SubSearcher(query, bits, mols, startIdx, endIdx, 1, needs_rings, counter, maxResults, idxs);
#endif
delete bits.queryBits;
return (int)counter;
return counter;
}
}
std::vector<unsigned int> SubstructLibrary::getMatches(
@@ -277,9 +285,11 @@ std::vector<unsigned int> SubstructLibrary::getMatches(
const ROMol &query, unsigned int startIdx, unsigned int endIdx,
bool recursionPossible, bool useChirality, bool useQueryQueryMatches,
int numThreads, int maxResults) {
return internalGetMatches(query, *mols, fps, startIdx, endIdx,
recursionPossible, useChirality,
useQueryQueryMatches, numThreads, maxResults);
std::vector<unsigned int> idxs;
internalGetMatches(query, *mols, fps, startIdx, endIdx,
recursionPossible, useChirality,
useQueryQueryMatches, numThreads, maxResults, &idxs);
return idxs;
}
unsigned int SubstructLibrary::countMatches(const ROMol &query,
@@ -295,9 +305,9 @@ unsigned int SubstructLibrary::countMatches(
const ROMol &query, unsigned int startIdx, unsigned int endIdx,
bool recursionPossible, bool useChirality, bool useQueryQueryMatches,
int numThreads) {
return internalMatchCounter(query, *mols, fps, startIdx, endIdx,
recursionPossible, useChirality,
useQueryQueryMatches, numThreads);
return internalGetMatches(query, *mols, fps, startIdx, endIdx,
recursionPossible, useChirality,
useQueryQueryMatches, numThreads, -1);
}
bool SubstructLibrary::hasMatch(const ROMol &query, bool recursionPossible,