Synthon substructure search 2x performance (#9307)

* synthon perf: replace O(N) haveEnoughHits scan with O(1) atomic counter

processPartHitsFromDetails called haveEnoughHits after each verified hit,
which scanned every slot of the pre-sized results vector (up to toTryChunkSize
= 2.5M entries) to count non-null entries via std::accumulate. With ~3000
verified hits per search that is ~7.5B pointer reads per query.

Replace with a std::atomic<int64_t> numHitsFound counter in makeHitsFromToTry,
incremented via fetch_add on each verified hit. The early-exit condition becomes
a single atomic read, O(1) per hit regardless of vector size. The atomic is
local to makeHitsFromToTry so it resets correctly per chunk and is safe for
the multi-threaded path without added synchronization.

Measured on synthon_perf branch (42-rxn / 140B-product Freedom space,
maxHits=3000, hitStart=1000, before boost::unordered_flat_set change):
  search-several (9 queries): ~30s → ~16.5s (~1.8x)
  search-one (benzene):       ~3.5s → ~1.8s  (~1.9x)

All 4 synthon ctest cases pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* style ++

* Update Code/GraphMol/SynthonSpaceSearch/SynthonSpaceSearcher.cpp

Co-authored-by: Greg Landrum <greg.landrum@gmail.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Greg Landrum <greg.landrum@gmail.com>
This commit is contained in:
Dan Nealschneider
2026-06-02 05:23:49 -07:00
committed by GitHub
parent bf711414a3
commit 226427e0bc
2 changed files with 28 additions and 40 deletions

View File

@@ -66,6 +66,7 @@ void SynthonSpaceSearcher::search(const SearchResultCallback &cb) {
// from buildAllhits
std::vector<std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>> toTry;
std::atomic<std::int64_t> numHitsFound = 0;
std::int64_t hitCount = 0;
bool stop = false; // set by callback
@@ -86,7 +87,7 @@ void SynthonSpaceSearcher::search(const SearchResultCallback &cb) {
toTry.emplace_back(hitset.get(), stepper.d_currState);
if (toTry.size() == static_cast<size_t>(d_params.toTryChunkSize)) {
std::vector<std::unique_ptr<ROMol>> partResults;
processToTrySet(toTry, endTime, partResults);
processToTrySet(toTry, endTime, partResults, numHitsFound);
hitCount += partResults.size();
stop = cb(partResults);
toTry.clear();
@@ -105,7 +106,7 @@ void SynthonSpaceSearcher::search(const SearchResultCallback &cb) {
if ((d_params.maxHits == -1 || hitCount < d_params.maxHits) && !stop &&
!toTry.empty()) {
std::vector<std::unique_ptr<ROMol>> partResults;
processToTrySet(toTry, endTime, partResults);
processToTrySet(toTry, endTime, partResults, numHitsFound);
cb(partResults);
}
}
@@ -396,25 +397,6 @@ void sortAndUniquifyToTry(
toTry.end());
}
bool haveEnoughHits(const std::vector<std::unique_ptr<ROMol>> &results,
const std::int64_t maxHits, const std::int64_t hitStart) {
const std::int64_t numHits = std::accumulate(
results.begin(), results.end(), 0,
[](const size_t prevVal, const std::unique_ptr<ROMol> &m) -> size_t {
if (m) {
return prevVal + 1;
}
return prevVal;
});
// If there's a limit on the number of hits, we still need to keep the
// first hitStart hits and remove them later. They had to be built
// to see if they passed verifyHit.
if (maxHits != -1 && numHits >= maxHits + hitStart) {
return true;
}
return false;
}
} // namespace
void SynthonSpaceSearcher::buildHits(
@@ -449,6 +431,7 @@ void SynthonSpaceSearcher::buildAllHits(
// they will be built into molecules, verified and accepted or
// rejected as hits.
std::vector<std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>> toTry;
std::atomic<std::int64_t> numHitsFound = 0;
bool enoughHits = false;
// Each hitset contains possible hits from a single SynthonSet.
@@ -468,13 +451,13 @@ void SynthonSpaceSearcher::buildAllHits(
toTry.emplace_back(hitset.get(), stepper.d_currState);
if (toTry.size() == static_cast<size_t>(d_params.toTryChunkSize)) {
std::vector<std::unique_ptr<ROMol>> partResults;
processToTrySet(toTry, endTime, partResults);
processToTrySet(toTry, endTime, partResults, numHitsFound);
results.insert(results.end(),
std::make_move_iterator(partResults.begin()),
std::make_move_iterator(partResults.end()));
partResults.clear();
enoughHits =
haveEnoughHits(results, d_params.maxHits, d_params.hitStart);
enoughHits = d_params.maxHits != -1 &&
numHitsFound.load() >= d_params.maxHits + d_params.hitStart;
timedOut = details::checkTimeOut(endTime);
toTry.clear();
if (enoughHits || timedOut || ControlCHandler::getGotSignal()) {
@@ -490,7 +473,7 @@ void SynthonSpaceSearcher::buildAllHits(
// Do any remaining.
if (!enoughHits && !timedOut && !toTry.empty()) {
processToTrySet(toTry, endTime, results);
processToTrySet(toTry, endTime, results, numHitsFound);
}
sortHits(results);
@@ -526,8 +509,11 @@ void processPartHitsFromDetails(
std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>> &toTry,
const TimePoint *endTime, std::vector<std::unique_ptr<ROMol>> &results,
const SynthonSpaceSearcher *searcher,
std::atomic<std::int64_t> &mostRecentTry, std::int64_t lastTry) {
std::atomic<std::int64_t> &mostRecentTry, std::int64_t lastTry,
std::atomic<std::int64_t> &numHitsFound) {
std::uint64_t numTries = 100;
const std::int64_t maxHits = searcher->getParams().maxHits;
const std::int64_t hitStart = searcher->getParams().hitStart;
while (true) {
std::int64_t thisTry = ++mostRecentTry;
if (thisTry > lastTry) {
@@ -536,8 +522,8 @@ void processPartHitsFromDetails(
if (auto prod = searcher->buildAndVerifyHit(toTry[thisTry].first,
toTry[thisTry].second)) {
results[thisTry] = std::move(prod);
if (haveEnoughHits(results, searcher->getParams().maxHits,
searcher->getParams().hitStart)) {
++numHitsFound;
if (maxHits != -1 && numHitsFound >= maxHits + hitStart) {
break;
}
}
@@ -559,8 +545,8 @@ void processPartHitsFromDetails(
void SynthonSpaceSearcher::makeHitsFromToTry(
const std::vector<
std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>> &toTry,
const TimePoint *endTime,
std::vector<std::unique_ptr<ROMol>> &results) const {
const TimePoint *endTime, std::vector<std::unique_ptr<ROMol>> &results,
std::atomic<std::int64_t> &numHitsFound) const {
results.resize(toTry.size());
std::int64_t lastTry = toTry.size() - 1;
std::atomic<std::int64_t> mostRecentTry = -1;
@@ -581,18 +567,19 @@ void SynthonSpaceSearcher::makeHitsFromToTry(
for (unsigned int i = 0U; i < numThreads; ++i, start += eachThread) {
threads.push_back(std::thread(processPartHitsFromDetails, std::ref(toTry),
endTime, std::ref(results), this,
std::ref(mostRecentTry), lastTry));
std::ref(mostRecentTry), lastTry,
std::ref(numHitsFound)));
}
for (auto &t : threads) {
t.join();
}
} else {
processPartHitsFromDetails(toTry, endTime, results, this, mostRecentTry,
lastTry);
lastTry, numHitsFound);
}
#else
processPartHitsFromDetails(toTry, endTime, results, this, mostRecentTry,
lastTry);
lastTry, numHitsFound);
#endif
// Take out any gaps in the results set, where products didn't make the grade.
@@ -606,8 +593,8 @@ void SynthonSpaceSearcher::makeHitsFromToTry(
void SynthonSpaceSearcher::processToTrySet(
std::vector<std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>>
&toTry,
const TimePoint *endTime,
std::vector<std::unique_ptr<ROMol>> &results) const {
const TimePoint *endTime, std::vector<std::unique_ptr<ROMol>> &results,
std::atomic<std::int64_t> &numHitsFound) const {
// There are possibly duplicate entries in toTry, because 2
// different fragmentations might produce overlapping synthon lists in
// the same reaction. The duplicates need to be removed. Although
@@ -617,6 +604,6 @@ void SynthonSpaceSearcher::processToTrySet(
if (d_params.randomSample) {
std::shuffle(toTry.begin(), toTry.end(), *d_randGen);
}
makeHitsFromToTry(toTry, endTime, results);
makeHitsFromToTry(toTry, endTime, results, numHitsFound);
}
} // namespace RDKit::SynthonSpaceSearch

View File

@@ -15,6 +15,7 @@
#ifndef SYNTHONSPACESEARCHER_H
#define SYNTHONSPACESEARCHER_H
#include <atomic>
#include <chrono>
#include <random>
@@ -124,13 +125,13 @@ class SynthonSpaceSearcher {
void makeHitsFromToTry(
const std::vector<
std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>> &toTry,
const TimePoint *endTime,
std::vector<std::unique_ptr<ROMol>> &results) const;
const TimePoint *endTime, std::vector<std::unique_ptr<ROMol>> &results,
std::atomic<std::int64_t> &numHitsFound) const;
void processToTrySet(
std::vector<std::pair<const SynthonSpaceHitSet *, std::vector<size_t>>>
&toTry,
const TimePoint *endTime,
std::vector<std::unique_ptr<ROMol>> &results) const;
const TimePoint *endTime, std::vector<std::unique_ptr<ROMol>> &results,
std::atomic<std::int64_t> &numHitsFound) const;
// get the subset of synthons for the given reaction to use for this
// enumeration.