SubstructLibrary improvements (#4403)

* support MolBundles in the SubstructLibrary

* support SubstructMatchParameters in SubstructLibrary

* expose those methods to the python wrapper

* fix non-threadsafe build problem

* be explicit about sorting

* add the option to set the search order

* add example functions for setting the search order
This commit is contained in:
Greg Landrum
2021-08-24 05:02:41 +02:00
committed by GitHub
parent fc26b309d6
commit e21c39dcb3
7 changed files with 799 additions and 354 deletions

View File

@@ -1,4 +1,6 @@
// Copyright (c) 2017-2019, Novartis Institutes for BioMedical Research Inc.
// Copyright (c) 2017-2021, Novartis Institutes for BioMedical Research Inc.
// and other RDKit contributors
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
@@ -36,6 +38,7 @@
#endif
#include <GraphMol/Substruct/SubstructMatch.h>
#include <boost/dynamic_bitset.hpp>
namespace RDKit {
@@ -50,16 +53,11 @@ bool SubstructLibraryCanSerialize() {
struct Bits {
const ExplicitBitVect *queryBits;
const FPHolderBase *fps;
bool recursionPossible;
bool useChirality;
bool useQueryQueryMatches;
SubstructMatchParameters params;
Bits(const FPHolderBase *fps, const ROMol &m, bool recursionPossible,
bool useChirality, bool useQueryQueryMatches)
: fps(fps),
recursionPossible(recursionPossible),
useChirality(useChirality),
useQueryQueryMatches(useQueryQueryMatches) {
Bits(const FPHolderBase *fps, const ROMol &m,
const SubstructMatchParameters &ssparams)
: fps(fps), params(ssparams) {
if (fps) {
queryBits = fps->makeFingerprint(m);
} else {
@@ -68,11 +66,8 @@ struct Bits {
}
Bits(const FPHolderBase *fingerprints, const TautomerQuery &m,
bool recursionPossible, bool useChirality, bool useQueryQueryMatches)
: fps(nullptr),
recursionPossible(recursionPossible),
useChirality(useChirality),
useQueryQueryMatches(useQueryQueryMatches) {
const SubstructMatchParameters &ssparams)
: fps(nullptr), params(ssparams) {
if (fingerprints) {
const TautomerPatternHolder *tp =
dynamic_cast<const TautomerPatternHolder *>(fingerprints);
@@ -143,16 +138,23 @@ void SubSearcher(const Query &in_query, const Bits &bits,
const MolHolderBase &mols, unsigned int start,
unsigned int &end, unsigned int numThreads,
const bool needs_rings, int &counter, const int maxResults,
boost::dynamic_bitset<> &found,
const std::vector<unsigned int> &searchOrder,
std::vector<unsigned int> *idxs) {
PRECONDITION(searchOrder.empty() || searchOrder.size() >= end,
"bad searchOrder data");
Query query(in_query);
MatchVectType matchVect;
for (unsigned int idx = start; idx < end; idx += numThreads) {
if (!bits.check(idx)) {
unsigned int sidx = idx;
if (!searchOrder.empty()) {
sidx = searchOrder[idx];
}
if (!bits.check(sidx) || found[sidx]) {
continue;
}
// need shared_ptr as it (may) control the lifespan of the
// returned molecule!
const boost::shared_ptr<ROMol> &m = mols.getMol(idx);
const boost::shared_ptr<ROMol> &m = mols.getMol(sidx);
ROMol *mol = m.get();
if (!mol) {
continue;
@@ -162,11 +164,11 @@ void SubSearcher(const Query &in_query, const Bits &bits,
MolOps::symmetrizeSSSR(*mol);
}
if (SubstructMatch(*mol, query, matchVect, bits.recursionPossible,
bits.useChirality, bits.useQueryQueryMatches)) {
if (!SubstructMatch(*mol, query, bits.params).empty()) {
++counter;
found.set(sidx);
if (idxs) {
idxs->push_back(idx);
idxs->push_back(sidx);
if (maxResults > 0 && counter == maxResults) {
// if we reached maxResults, record the last idx we processed and bail
// out
@@ -181,11 +183,14 @@ void SubSearcher(const Query &in_query, const Bits &bits,
template <class Query>
int internalGetMatches(const Query &query, MolHolderBase &mols,
const FPHolderBase *fps, unsigned int startIdx,
unsigned int endIdx, bool recursionPossible,
bool useChirality, bool useQueryQueryMatches,
int numThreads = -1, int maxResults = 1000,
std::vector<unsigned int> *idxs = nullptr) {
unsigned int endIdx,
const SubstructMatchParameters &params, int numThreads,
int maxResults, boost::dynamic_bitset<> &found,
const std::vector<unsigned int> &searchOrder,
std::vector<unsigned int> *idxs) {
PRECONDITION(startIdx < mols.size(), "startIdx out of bounds");
PRECONDITION(searchOrder.empty() || startIdx < searchOrder.size(),
"startIdx out of bounds");
PRECONDITION(endIdx > startIdx, "endIdx > startIdx");
// do not do any work if no results were requested
@@ -194,12 +199,15 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
}
endIdx = std::min(mols.size(), endIdx);
if (!searchOrder.empty()) {
endIdx = std::min(static_cast<unsigned int>(searchOrder.size()), endIdx);
}
numThreads = static_cast<int>(getNumThreadsToUse(numThreads));
numThreads = std::min(numThreads, static_cast<int>(endIdx));
bool needs_rings = query_needs_rings(query);
Bits bits(fps, query, recursionPossible, useChirality, useQueryQueryMatches);
Bits bits(fps, query, params);
int counter = 0;
#ifdef RDK_THREADSAFE_SSS
@@ -219,20 +227,23 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
}
std::vector<std::future<void>> thread_group;
std::vector<std::vector<unsigned int>> internal_results;
std::vector<boost::dynamic_bitset<>> internal_found(numThreads);
if (idxs) {
internal_results.resize(numThreads);
}
int thread_group_idx;
for (thread_group_idx = 0; thread_group_idx < numThreads;
++thread_group_idx) {
internal_found[thread_group_idx] = found;
// need to use boost::ref otherwise things are passed by value
thread_group.emplace_back(
std::async(std::launch::async, SubSearcher<Query>, std::ref(query),
bits, std::ref(mols), startIdx + thread_group_idx,
std::ref(endIdxVect[thread_group_idx]), numThreads,
needs_rings, std::ref(counterVect[thread_group_idx]),
maxResultsVect[thread_group_idx],
idxs ? &internal_results[thread_group_idx] : nullptr));
thread_group.emplace_back(std::async(
std::launch::async, SubSearcher<Query>, std::ref(query), bits,
std::ref(mols), startIdx + thread_group_idx,
std::ref(endIdxVect[thread_group_idx]), numThreads, needs_rings,
std::ref(counterVect[thread_group_idx]),
maxResultsVect[thread_group_idx],
std::ref(internal_found[thread_group_idx]), std::ref(searchOrder),
idxs ? &internal_results[thread_group_idx] : nullptr));
}
unsigned int maxEndIdx;
if (maxResults > 0) {
@@ -259,12 +270,14 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
if (endIdxVect[thread_group_idx] >= maxEndIdx) {
continue;
}
internal_found[thread_group_idx] = found;
// need to use boost::ref otherwise things are passed by value
thread_group.emplace_back(std::async(
std::launch::async, SubSearcher<Query>, std::ref(query), bits,
std::ref(mols), endIdxVect[thread_group_idx] + numThreads,
std::ref(maxEndIdx), numThreads, needs_rings,
std::ref(counterVect[thread_group_idx]), -1,
std::ref(internal_found[thread_group_idx]), std::ref(searchOrder),
&internal_results[thread_group_idx]));
}
}
@@ -277,6 +290,7 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
idxs->insert(idxs->end(), internal_results[thread_group_idx].begin(),
internal_results[thread_group_idx].end());
}
found |= internal_found[thread_group_idx];
// If there was no maxResults, we still need to count, otherwise
// this has already been done previously
if (maxResults < 0) {
@@ -287,11 +301,19 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
// if this is running single-threaded, no need to suffer the overhead of
// std::async
SubSearcher(query, bits, mols, startIdx, endIdx, 1, needs_rings, counter,
maxResults, idxs);
maxResults, found, searchOrder, idxs);
}
if (idxs) {
if (idxs && numThreads > 1) {
// the sort is necessary to ensure consistency across runs with different
// numbers of threads
if (!searchOrder.empty()) {
std::transform(idxs->begin(), idxs->end(), idxs->begin(),
[searchOrder](unsigned int v) -> unsigned int {
return std::find(searchOrder.begin(), searchOrder.end(),
v) -
searchOrder.begin();
});
}
std::sort(idxs->begin(), idxs->end());
// we may have actually accumulated more results than maxResults due
// to the top up above, so trim the results down if that's the case
@@ -299,10 +321,16 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
idxs->size() > static_cast<unsigned int>(maxResults)) {
idxs->resize(maxResults);
}
if (!searchOrder.empty()) {
std::transform(idxs->begin(), idxs->end(), idxs->begin(),
[searchOrder](unsigned int v) -> unsigned int {
return searchOrder[v];
});
}
}
#else
SubSearcher(query, bits, mols, startIdx, endIdx, 1, needs_rings, counter,
maxResults, idxs);
maxResults, found, searchOrder, idxs);
#endif
delete bits.queryBits;
@@ -310,66 +338,102 @@ int internalGetMatches(const Query &query, MolHolderBase &mols,
return counter;
}
int molbundleGetMatches(const MolBundle &query, MolHolderBase &mols,
const FPHolderBase *fps, unsigned int startIdx,
unsigned int endIdx,
const SubstructMatchParameters &params, int numThreads,
int maxResults,
const std::vector<unsigned int> &searchOrder,
std::vector<unsigned int> *idxs) {
int res = 0;
boost::dynamic_bitset<> found(mols.size());
for (const auto qmol : query.getMols()) {
maxResults -= res;
res += internalGetMatches(*qmol, mols, fps, startIdx, endIdx, params,
numThreads, maxResults, found, searchOrder, idxs);
}
return res;
}
} // namespace
std::vector<unsigned int> SubstructLibrary::getMatches(
const ROMol &query, unsigned int startIdx, unsigned int endIdx,
bool recursionPossible, bool useChirality, bool useQueryQueryMatches,
int numThreads, int maxResults) const {
const SubstructMatchParameters &params, int numThreads,
int maxResults) const {
std::vector<unsigned int> idxs;
internalGetMatches(query, *mols, fps, startIdx, endIdx, recursionPossible,
useChirality, useQueryQueryMatches, numThreads, maxResults,
&idxs);
boost::dynamic_bitset<> found(mols->size());
internalGetMatches(query, *mols, fps, startIdx, endIdx, params, numThreads,
maxResults, found, searchOrder, &idxs);
return idxs;
}
std::vector<unsigned int> SubstructLibrary::getMatches(
const TautomerQuery &query, unsigned int startIdx, unsigned int endIdx,
bool recursionPossible, bool useChirality, bool useQueryQueryMatches,
int numThreads, int maxResults) const {
const SubstructMatchParameters &params, int numThreads,
int maxResults) const {
std::vector<unsigned int> idxs;
internalGetMatches(query, *mols, fps, startIdx, endIdx, recursionPossible,
useChirality, useQueryQueryMatches, numThreads, maxResults,
&idxs);
boost::dynamic_bitset<> found(mols->size());
internalGetMatches(query, *mols, fps, startIdx, endIdx, params, numThreads,
maxResults, found, searchOrder, &idxs);
return idxs;
}
std::vector<unsigned int> SubstructLibrary::getMatches(
const MolBundle &query, unsigned int startIdx, unsigned int endIdx,
const SubstructMatchParameters &params, int numThreads,
int maxResults) const {
std::vector<unsigned int> idxs;
molbundleGetMatches(query, *mols, fps, startIdx, endIdx, params, numThreads,
maxResults, searchOrder, &idxs);
return idxs;
}
unsigned int SubstructLibrary::countMatches(
const ROMol &query, unsigned int startIdx, unsigned int endIdx,
bool recursionPossible, bool useChirality, bool useQueryQueryMatches,
int numThreads) const {
return internalGetMatches(query, *mols, fps, startIdx, endIdx,
recursionPossible, useChirality,
useQueryQueryMatches, numThreads, -1);
const SubstructMatchParameters &params, int numThreads) const {
boost::dynamic_bitset<> found(mols->size());
return internalGetMatches(query, *mols, fps, startIdx, endIdx, params,
numThreads, -1, found, searchOrder, nullptr);
}
unsigned int SubstructLibrary::countMatches(
const TautomerQuery &query, unsigned int startIdx, unsigned int endIdx,
bool recursionPossible, bool useChirality, bool useQueryQueryMatches,
int numThreads) const {
return internalGetMatches(query, *mols, fps, startIdx, endIdx,
recursionPossible, useChirality,
useQueryQueryMatches, numThreads, -1);
const SubstructMatchParameters &params, int numThreads) const {
boost::dynamic_bitset<> found(mols->size());
return internalGetMatches(query, *mols, fps, startIdx, endIdx, params,
numThreads, -1, found, searchOrder, nullptr);
}
unsigned int SubstructLibrary::countMatches(
const MolBundle &query, unsigned int startIdx, unsigned int endIdx,
const SubstructMatchParameters &params, int numThreads) const {
return molbundleGetMatches(query, *mols, fps, startIdx, endIdx, params,
numThreads, -1, searchOrder, nullptr);
}
bool SubstructLibrary::hasMatch(const ROMol &query, unsigned int startIdx,
unsigned int endIdx, bool recursionPossible,
bool useChirality, bool useQueryQueryMatches,
unsigned int endIdx,
const SubstructMatchParameters &params,
int numThreads) const {
const int maxResults = 1;
return getMatches(query, startIdx, endIdx, recursionPossible, useChirality,
useQueryQueryMatches, numThreads, maxResults)
return getMatches(query, startIdx, endIdx, params, numThreads, maxResults)
.size() > 0;
}
bool SubstructLibrary::hasMatch(const TautomerQuery &query,
unsigned int startIdx, unsigned int endIdx,
bool recursionPossible, bool useChirality,
bool useQueryQueryMatches,
const SubstructMatchParameters &params,
int numThreads) const {
const int maxResults = 1;
return getMatches(query, startIdx, endIdx, recursionPossible, useChirality,
useQueryQueryMatches, numThreads, maxResults)
return getMatches(query, startIdx, endIdx, params, numThreads, maxResults)
.size() > 0;
}
bool SubstructLibrary::hasMatch(const MolBundle &query, unsigned int startIdx,
unsigned int endIdx,
const SubstructMatchParameters &params,
int numThreads) const {
const int maxResults = 1;
return getMatches(query, startIdx, endIdx, params, numThreads, maxResults)
.size() > 0;
}