mirror of
https://github.com/rdkit/rdkit.git
synced 2026-06-03 21:44:30 +08:00
Synthon substructure search 2x performance (#9307)
* synthon perf: replace O(N) haveEnoughHits scan with O(1) atomic counter processPartHitsFromDetails called haveEnoughHits after each verified hit, which scanned every slot of the pre-sized results vector (up to toTryChunkSize = 2.5M entries) to count non-null entries via std::accumulate. With ~3000 verified hits per search that is ~7.5B pointer reads per query. Replace with a std::atomic<int64_t> numHitsFound counter in makeHitsFromToTry, incremented via fetch_add on each verified hit. The early-exit condition becomes a single atomic read, O(1) per hit regardless of vector size. The atomic is local to makeHitsFromToTry so it resets correctly per chunk and is safe for the multi-threaded path without added synchronization. Measured on synthon_perf branch (42-rxn / 140B-product Freedom space, maxHits=3000, hitStart=1000, before boost::unordered_flat_set change): search-several (9 queries): ~30s → ~16.5s (~1.8x) search-one (benzene): ~3.5s → ~1.8s (~1.9x) All 4 synthon ctest cases pass. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * style ++ * Update Code/GraphMol/SynthonSpaceSearch/SynthonSpaceSearcher.cpp Co-authored-by: Greg Landrum <greg.landrum@gmail.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: Greg Landrum <greg.landrum@gmail.com>
This commit is contained in:
committed by
GitHub
parent
bf711414a3
commit
226427e0bc
@@ -66,6 +66,7 @@ void SynthonSpaceSearcher::search(const SearchResultCallback &cb) {
|
||||
|
||||
// from buildAllhits
|
||||
std::vector<std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>> toTry;
|
||||
std::atomic<std::int64_t> numHitsFound = 0;
|
||||
std::int64_t hitCount = 0;
|
||||
bool stop = false; // set by callback
|
||||
|
||||
@@ -86,7 +87,7 @@ void SynthonSpaceSearcher::search(const SearchResultCallback &cb) {
|
||||
toTry.emplace_back(hitset.get(), stepper.d_currState);
|
||||
if (toTry.size() == static_cast<size_t>(d_params.toTryChunkSize)) {
|
||||
std::vector<std::unique_ptr<ROMol>> partResults;
|
||||
processToTrySet(toTry, endTime, partResults);
|
||||
processToTrySet(toTry, endTime, partResults, numHitsFound);
|
||||
hitCount += partResults.size();
|
||||
stop = cb(partResults);
|
||||
toTry.clear();
|
||||
@@ -105,7 +106,7 @@ void SynthonSpaceSearcher::search(const SearchResultCallback &cb) {
|
||||
if ((d_params.maxHits == -1 || hitCount < d_params.maxHits) && !stop &&
|
||||
!toTry.empty()) {
|
||||
std::vector<std::unique_ptr<ROMol>> partResults;
|
||||
processToTrySet(toTry, endTime, partResults);
|
||||
processToTrySet(toTry, endTime, partResults, numHitsFound);
|
||||
cb(partResults);
|
||||
}
|
||||
}
|
||||
@@ -396,25 +397,6 @@ void sortAndUniquifyToTry(
|
||||
toTry.end());
|
||||
}
|
||||
|
||||
bool haveEnoughHits(const std::vector<std::unique_ptr<ROMol>> &results,
|
||||
const std::int64_t maxHits, const std::int64_t hitStart) {
|
||||
const std::int64_t numHits = std::accumulate(
|
||||
results.begin(), results.end(), 0,
|
||||
[](const size_t prevVal, const std::unique_ptr<ROMol> &m) -> size_t {
|
||||
if (m) {
|
||||
return prevVal + 1;
|
||||
}
|
||||
return prevVal;
|
||||
});
|
||||
// If there's a limit on the number of hits, we still need to keep the
|
||||
// first hitStart hits and remove them later. They had to be built
|
||||
// to see if they passed verifyHit.
|
||||
if (maxHits != -1 && numHits >= maxHits + hitStart) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SynthonSpaceSearcher::buildHits(
|
||||
@@ -449,6 +431,7 @@ void SynthonSpaceSearcher::buildAllHits(
|
||||
// they will be built into molecules, verified and accepted or
|
||||
// rejected as hits.
|
||||
std::vector<std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>> toTry;
|
||||
std::atomic<std::int64_t> numHitsFound = 0;
|
||||
bool enoughHits = false;
|
||||
|
||||
// Each hitset contains possible hits from a single SynthonSet.
|
||||
@@ -468,13 +451,13 @@ void SynthonSpaceSearcher::buildAllHits(
|
||||
toTry.emplace_back(hitset.get(), stepper.d_currState);
|
||||
if (toTry.size() == static_cast<size_t>(d_params.toTryChunkSize)) {
|
||||
std::vector<std::unique_ptr<ROMol>> partResults;
|
||||
processToTrySet(toTry, endTime, partResults);
|
||||
processToTrySet(toTry, endTime, partResults, numHitsFound);
|
||||
results.insert(results.end(),
|
||||
std::make_move_iterator(partResults.begin()),
|
||||
std::make_move_iterator(partResults.end()));
|
||||
partResults.clear();
|
||||
enoughHits =
|
||||
haveEnoughHits(results, d_params.maxHits, d_params.hitStart);
|
||||
enoughHits = d_params.maxHits != -1 &&
|
||||
numHitsFound.load() >= d_params.maxHits + d_params.hitStart;
|
||||
timedOut = details::checkTimeOut(endTime);
|
||||
toTry.clear();
|
||||
if (enoughHits || timedOut || ControlCHandler::getGotSignal()) {
|
||||
@@ -490,7 +473,7 @@ void SynthonSpaceSearcher::buildAllHits(
|
||||
|
||||
// Do any remaining.
|
||||
if (!enoughHits && !timedOut && !toTry.empty()) {
|
||||
processToTrySet(toTry, endTime, results);
|
||||
processToTrySet(toTry, endTime, results, numHitsFound);
|
||||
}
|
||||
|
||||
sortHits(results);
|
||||
@@ -526,8 +509,11 @@ void processPartHitsFromDetails(
|
||||
std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>> &toTry,
|
||||
const TimePoint *endTime, std::vector<std::unique_ptr<ROMol>> &results,
|
||||
const SynthonSpaceSearcher *searcher,
|
||||
std::atomic<std::int64_t> &mostRecentTry, std::int64_t lastTry) {
|
||||
std::atomic<std::int64_t> &mostRecentTry, std::int64_t lastTry,
|
||||
std::atomic<std::int64_t> &numHitsFound) {
|
||||
std::uint64_t numTries = 100;
|
||||
const std::int64_t maxHits = searcher->getParams().maxHits;
|
||||
const std::int64_t hitStart = searcher->getParams().hitStart;
|
||||
while (true) {
|
||||
std::int64_t thisTry = ++mostRecentTry;
|
||||
if (thisTry > lastTry) {
|
||||
@@ -536,8 +522,8 @@ void processPartHitsFromDetails(
|
||||
if (auto prod = searcher->buildAndVerifyHit(toTry[thisTry].first,
|
||||
toTry[thisTry].second)) {
|
||||
results[thisTry] = std::move(prod);
|
||||
if (haveEnoughHits(results, searcher->getParams().maxHits,
|
||||
searcher->getParams().hitStart)) {
|
||||
++numHitsFound;
|
||||
if (maxHits != -1 && numHitsFound >= maxHits + hitStart) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -559,8 +545,8 @@ void processPartHitsFromDetails(
|
||||
void SynthonSpaceSearcher::makeHitsFromToTry(
|
||||
const std::vector<
|
||||
std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>> &toTry,
|
||||
const TimePoint *endTime,
|
||||
std::vector<std::unique_ptr<ROMol>> &results) const {
|
||||
const TimePoint *endTime, std::vector<std::unique_ptr<ROMol>> &results,
|
||||
std::atomic<std::int64_t> &numHitsFound) const {
|
||||
results.resize(toTry.size());
|
||||
std::int64_t lastTry = toTry.size() - 1;
|
||||
std::atomic<std::int64_t> mostRecentTry = -1;
|
||||
@@ -581,18 +567,19 @@ void SynthonSpaceSearcher::makeHitsFromToTry(
|
||||
for (unsigned int i = 0U; i < numThreads; ++i, start += eachThread) {
|
||||
threads.push_back(std::thread(processPartHitsFromDetails, std::ref(toTry),
|
||||
endTime, std::ref(results), this,
|
||||
std::ref(mostRecentTry), lastTry));
|
||||
std::ref(mostRecentTry), lastTry,
|
||||
std::ref(numHitsFound)));
|
||||
}
|
||||
for (auto &t : threads) {
|
||||
t.join();
|
||||
}
|
||||
} else {
|
||||
processPartHitsFromDetails(toTry, endTime, results, this, mostRecentTry,
|
||||
lastTry);
|
||||
lastTry, numHitsFound);
|
||||
}
|
||||
#else
|
||||
processPartHitsFromDetails(toTry, endTime, results, this, mostRecentTry,
|
||||
lastTry);
|
||||
lastTry, numHitsFound);
|
||||
#endif
|
||||
|
||||
// Take out any gaps in the results set, where products didn't make the grade.
|
||||
@@ -606,8 +593,8 @@ void SynthonSpaceSearcher::makeHitsFromToTry(
|
||||
void SynthonSpaceSearcher::processToTrySet(
|
||||
std::vector<std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>>
|
||||
&toTry,
|
||||
const TimePoint *endTime,
|
||||
std::vector<std::unique_ptr<ROMol>> &results) const {
|
||||
const TimePoint *endTime, std::vector<std::unique_ptr<ROMol>> &results,
|
||||
std::atomic<std::int64_t> &numHitsFound) const {
|
||||
// There are possibly duplicate entries in toTry, because 2
|
||||
// different fragmentations might produce overlapping synthon lists in
|
||||
// the same reaction. The duplicates need to be removed. Although
|
||||
@@ -617,6 +604,6 @@ void SynthonSpaceSearcher::processToTrySet(
|
||||
if (d_params.randomSample) {
|
||||
std::shuffle(toTry.begin(), toTry.end(), *d_randGen);
|
||||
}
|
||||
makeHitsFromToTry(toTry, endTime, results);
|
||||
makeHitsFromToTry(toTry, endTime, results, numHitsFound);
|
||||
}
|
||||
} // namespace RDKit::SynthonSpaceSearch
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#ifndef SYNTHONSPACESEARCHER_H
|
||||
#define SYNTHONSPACESEARCHER_H
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <random>
|
||||
|
||||
@@ -124,13 +125,13 @@ class SynthonSpaceSearcher {
|
||||
void makeHitsFromToTry(
|
||||
const std::vector<
|
||||
std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>> &toTry,
|
||||
const TimePoint *endTime,
|
||||
std::vector<std::unique_ptr<ROMol>> &results) const;
|
||||
const TimePoint *endTime, std::vector<std::unique_ptr<ROMol>> &results,
|
||||
std::atomic<std::int64_t> &numHitsFound) const;
|
||||
void processToTrySet(
|
||||
std::vector<std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>>
|
||||
&toTry,
|
||||
const TimePoint *endTime,
|
||||
std::vector<std::unique_ptr<ROMol>> &results) const;
|
||||
const TimePoint *endTime, std::vector<std::unique_ptr<ROMol>> &results,
|
||||
std::atomic<std::int64_t> &numHitsFound) const;
|
||||
|
||||
// get the subset of synthons for the given reaction to use for this
|
||||
// enumeration.
|
||||
|
||||
Reference in New Issue
Block a user