mirror of
https://github.com/rdkit/rdkit.git
synced 2026-06-04 21:54:27 +08:00
* iterators for random-access MolSuppliers add optional caching to SDMolSupplier * add support to SmilesMolSupplier too There is a lot of duplicate code between the random-access suppliers that would be worth trying to remove but at the moment it looks like it would require multiple inheritance, and I think we want to avoid that * add input iterators for ForwardSDMolSupplier() * throw when calling begin() on a used supplier * switch to use the spaceship operator * init() should reset the mol cache * Make SDMolSupplier and SmilesMolSupplier safe for multi-threaded reads * add benchmarking * add TDTMolSupplier support improved testing add benchmarks for parallel iteration optional TBB support * better const handling, add reverse iterators doesn't look like const_iterator is possible since getting data from the underlying supplier object is non-const * improve docs more usings add reverse iterator to TDTMolSupplier * tests only try execution::par when it is there * fix typo * more testing/demo * remove accidentally added files * review changes * add default ctors * disable a false-positive compiler warning it is stupid to have to do this --------- Co-authored-by: = <=>
507 lines
17 KiB
C++
507 lines
17 KiB
C++
//
|
|
// Copyright (C) 2026 Greg Landrum and other RDKit contributors
|
|
//
|
|
// @@ All Rights Reserved @@
|
|
// This file is part of the RDKit.
|
|
// The contents are covered by the terms of the BSD license
|
|
// which is included in the file license.txt, found at the root
|
|
// of the RDKit source tree.
|
|
//
|
|
#include <catch2/catch_all.hpp>

#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <execution>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric>
#include <ranges>
#include <string>
#include <utility>
#include <vector>

#include <GraphMol/RDKitBase.h>
#include "MolSupplier.h"
#include <GraphMol/SmilesParse/SmilesWrite.h>
#include <GraphMol/SmilesParse/SmilesParse.h>
#include <GraphMol/Substruct/SubstructMatch.h>
|
|
|
|
// Root of the RDKit source tree, taken from $RDBASE.
// std::getenv() returns nullptr when the variable is unset, and constructing a
// std::string from a null pointer is undefined behaviour (in practice a crash
// during static initialisation, before any test runs). Fall back to an empty
// string instead, so the tests that need data files fail individually.
static const std::string rdbase = [] {
  const char *env = std::getenv("RDBASE");
  return env ? std::string(env) : std::string();
}();
|
|
|
|
using namespace RDKit;
|
|
|
|
namespace {
|
|
unsigned int countAtoms(const std::shared_ptr<RWMol> &mol) {
|
|
REQUIRE(mol);
|
|
return mol->getNumAtoms();
|
|
}
|
|
size_t molPtr(const std::shared_ptr<RWMol> &mol) { return (size_t)mol.get(); }
|
|
} // namespace
|
|
|
|
template <typename Supplier>
|
|
void iterTest(Supplier &reader, size_t len) {
|
|
CHECK(reader.length() == len);
|
|
std::vector<unsigned int> expected;
|
|
expected.reserve(reader.length());
|
|
for (unsigned int i = 0; i < reader.length(); ++i) {
|
|
expected.push_back(reader[i]->getNumAtoms());
|
|
}
|
|
std::vector<unsigned int> actual;
|
|
std::ranges::transform(reader, std::back_inserter(actual), countAtoms);
|
|
CHECK(actual == expected);
|
|
}
|
|
|
|
template <typename Supplier>
|
|
void forwardIterTest(Supplier &reader, size_t len) {
|
|
std::vector<unsigned int> actual;
|
|
std::ranges::transform(reader, std::back_inserter(actual), countAtoms);
|
|
CHECK(actual.size() == len);
|
|
}
|
|
|
|
template <typename Supplier>
|
|
void cacheTest(Supplier &reader, size_t len) {
|
|
reader.setCaching(true);
|
|
CHECK(reader.length() == len);
|
|
std::vector<std::shared_ptr<RWMol>> mols(reader.length());
|
|
std::copy(reader.begin(), reader.end(), mols.begin());
|
|
REQUIRE(mols.size() == reader.length());
|
|
std::vector<size_t> expected;
|
|
std::ranges::transform(mols, std::back_inserter(expected), molPtr);
|
|
std::vector<size_t> actual;
|
|
std::ranges::transform(reader, std::back_inserter(actual), molPtr);
|
|
CHECK(actual == expected);
|
|
}
|
|
|
|
TEST_CASE("basic SDMolSupplier iteration") {
  const std::string sdfFile =
      rdbase + "/Code/GraphMol/FileParsers/test_data/NCI_aids_few.sdf";
  v2::FileParsers::SDMolSupplier reader(sdfFile);

  SECTION("basics") { iterTest(reader, 16); }
  SECTION("with caching") { cacheTest(reader, 16); }
  SECTION("reverse iteration") {
    // atom counts read front-to-back...
    std::vector<unsigned int> forwardCounts;
    std::ranges::transform(reader.begin(), reader.end(),
                           std::back_inserter(forwardCounts), countAtoms);
    // ...and back-to-front, flipped into forward order for comparison
    std::vector<unsigned int> backwardCounts;
    std::ranges::transform(reader.rbegin(), reader.rend(),
                           std::back_inserter(backwardCounts), countAtoms);
    std::reverse(backwardCounts.begin(), backwardCounts.end());
    CHECK(backwardCounts == forwardCounts);
  }
}
|
|
|
|
TEST_CASE("ForwardSDMolSupplier iteration") {
  std::string infile =
      rdbase + "/Code/GraphMol/FileParsers/test_data/NCI_aids_few.sdf";
  std::ifstream strm(infile);
  // Fail early and clearly if the data file can't be opened. Otherwise a
  // missing file surfaces later as confusing molecule-count mismatches.
  REQUIRE(strm.is_open());
  bool takeOwnership = false;
  v2::FileParsers::ForwardSDMolSupplier reader(&strm, takeOwnership);

  SECTION("basics") { forwardIterTest(reader, 16); }

  SECTION("error handling") {
    // once the supplier has been advanced manually, begin() must throw
    reader.next();
    CHECK_THROWS_AS(reader.begin(), ValueErrorException);
  }

  SECTION("pre-increment") {
    unsigned int i = 0;
    auto it = reader.begin();
    auto end = reader.end();
    while (it != end) {
      auto mol = *it;
      REQUIRE(mol);
      ++it;
      ++i;
    }
    CHECK(i == 16);
  }

  SECTION("post-increment") {
    unsigned int i = 0;
    auto it = reader.begin();
    auto end = reader.end();
    while (it != end) {
      auto mol = *it;
      REQUIRE(mol);
      it++;
      ++i;
    }
    CHECK(i == 16);
  }
}
|
|
TEST_CASE("ForwardSDMolSupplier iteration with failing molecules") {
  std::string infile =
      rdbase + "/Code/GraphMol/FileParsers/test_data/good_bad_good_good.sdf";
  std::ifstream strm(infile);
  // fail early and clearly if the data file can't be opened
  REQUIRE(strm.is_open());
  bool takeOwnership = false;
  v2::FileParsers::ForwardSDMolSupplier reader(&strm, takeOwnership);

  // The second record fails to parse. Iteration still yields an entry for it,
  // as a null molecule (counted as 0 atoms below).
  SECTION("basics") {
    std::vector<unsigned int> expected{6, 0, 6, 6};
    std::vector<unsigned int> actual;
    std::ranges::transform(
        reader, std::back_inserter(actual),
        [](const auto &mol) { return mol ? mol->getNumAtoms() : 0; });
    CHECK(actual == expected);
  }

  SECTION("pre-increment") {
    unsigned int i = 0;
    auto it = reader.begin();
    auto end = reader.end();
    while (it != end) {
      auto mol = *it;
      if (i == 1) {
        CHECK(!mol);
      } else {
        CHECK(mol);
      }
      ++it;
      ++i;
    }
    CHECK(i == 4);
  }

  SECTION("post-increment") {
    unsigned int i = 0;
    auto it = reader.begin();
    auto end = reader.end();
    while (it != end) {
      auto mol = *it;
      if (i == 1) {
        CHECK(!mol);
      } else {
        CHECK(mol);
      }
      it++;
      ++i;
    }
    CHECK(i == 4);
  }
}
|
|
|
|
TEST_CASE("ForwardSDMolSupplier iteration with failing molecule at end") {
  std::string infile =
      rdbase + "/Code/GraphMol/FileParsers/test_data/good_bad_good_bad.sdf";
  std::ifstream strm(infile);
  // fail early and clearly if the data file can't be opened
  REQUIRE(strm.is_open());
  bool takeOwnership = false;
  v2::FileParsers::ForwardSDMolSupplier reader(&strm, takeOwnership);

  // Records alternate good/bad, including a bad final record. Iteration must
  // still yield (null) entries for the bad ones, the last one included.
  SECTION("basics") {
    std::vector<unsigned int> expected{6, 0, 6, 0};
    std::vector<unsigned int> actual;
    std::ranges::transform(
        reader, std::back_inserter(actual),
        [](const auto &mol) { return mol ? mol->getNumAtoms() : 0; });
    CHECK(actual == expected);
  }

  SECTION("pre-increment") {
    unsigned int i = 0;
    auto it = reader.begin();
    auto end = reader.end();
    while (it != end) {
      auto mol = *it;
      if (i % 2) {
        CHECK(!mol);
      } else {
        CHECK(mol);
      }
      ++it;
      ++i;
    }
    CHECK(i == 4);
  }

  SECTION("post-increment") {
    unsigned int i = 0;
    auto it = reader.begin();
    auto end = reader.end();
    while (it != end) {
      auto mol = *it;
      if (i % 2) {
        CHECK(!mol);
      } else {
        CHECK(mol);
      }
      it++;
      ++i;
    }
    CHECK(i == 4);
  }
}
|
|
|
|
TEST_CASE("cached SDMolSupplier error handling") {
  std::string infile =
      rdbase + "/Code/GraphMol/FileParsers/test_data/good_bad_good_bad.sdf";
  SECTION("basics") {
    v2::FileParsers::SDMolSupplier reader(infile);
    reader.setCaching(true);
    CHECK(reader.length() == 4);

    // First pass, filling the cache. Collect with a back_inserter so the size
    // check is meaningful (a pre-sized vector always has the right size) and
    // an over-long iteration cannot overflow the buffer.
    std::vector<std::shared_ptr<RWMol>> mols;
    mols.reserve(reader.length());
    std::copy(reader.begin(), reader.end(), std::back_inserter(mols));
    REQUIRE(mols.size() == reader.length());
    CHECK(mols[0]);
    CHECK(!mols[1]);
    CHECK(mols[2]);
    CHECK(!mols[3]);
    std::vector<size_t> expected;
    std::ranges::transform(mols, std::back_inserter(expected), molPtr);

    // now use the cached versions:
    std::vector<std::shared_ptr<RWMol>> nmols;
    nmols.reserve(reader.length());
    std::copy(reader.begin(), reader.end(), std::back_inserter(nmols));
    REQUIRE(nmols.size() == reader.length());
    CHECK(nmols[0]);
    CHECK(!nmols[1]);
    CHECK(nmols[2]);
    CHECK(!nmols[3]);
    // confirm the caching: same objects (and same nulls) as the first pass
    std::vector<size_t> actual;
    std::ranges::transform(nmols, std::back_inserter(actual), molPtr);
    CHECK(actual == expected);
  }
}
|
|
|
|
TEST_CASE("basic SmilesMolSupplier iteration") {
  const std::string csvFile =
      rdbase + "/Code/GraphMol/FileParsers/test_data/first_200.tpsa.csv";
  // comma-separated, SMILES in the first column, no name column
  v2::FileParsers::SmilesMolSupplierParams params;
  params.delimiter = ',';
  params.smilesColumn = 0;
  params.nameColumn = -1;
  v2::FileParsers::SmilesMolSupplier reader(csvFile, params);

  SECTION("basics") { iterTest(reader, 200); }
  SECTION("with caching") { cacheTest(reader, 200); }
  SECTION("reverse iteration") {
    // atom counts read front-to-back...
    std::vector<unsigned int> forwardCounts;
    std::ranges::transform(reader.begin(), reader.end(),
                           std::back_inserter(forwardCounts), countAtoms);
    // ...and back-to-front, flipped into forward order for comparison
    std::vector<unsigned int> backwardCounts;
    std::ranges::transform(reader.rbegin(), reader.rend(),
                           std::back_inserter(backwardCounts), countAtoms);
    std::reverse(backwardCounts.begin(), backwardCounts.end());
    CHECK(backwardCounts == forwardCounts);
  }
}
|
|
|
|
TEST_CASE("cached SmilesMolSupplier error handling") {
  v2::FileParsers::SmilesMolSupplierParams params;
  params.delimiter = ',';
  params.smilesColumn = 0;
  params.nameColumn = -1;
  // Header line plus five records; the `is_valid` column records which
  // SMILES are expected to parse (only CCO and CCN).
  std::string data = R"SMI(smiles,is_valid
CCO,1
CFC,0
CCN,1
c1cc1,0
c1cc,0
)SMI";
  SECTION("basics") {
    v2::FileParsers::SmilesMolSupplier reader;
    reader.setData(data, params);
    reader.setCaching(true);
    CHECK(reader.length() == 5);

    // First pass, filling the cache. Collect with a back_inserter so the size
    // check is meaningful (a pre-sized vector always has the right size) and
    // an over-long iteration cannot overflow the buffer.
    std::vector<std::shared_ptr<RWMol>> mols;
    mols.reserve(reader.length());
    std::copy(reader.begin(), reader.end(), std::back_inserter(mols));
    REQUIRE(mols.size() == reader.length());
    CHECK(mols[0]);
    CHECK(!mols[1]);
    CHECK(mols[2]);
    CHECK(!mols[3]);
    CHECK(!mols[4]);
    std::vector<size_t> expected;
    std::ranges::transform(mols, std::back_inserter(expected), molPtr);

    // now use the cached versions:
    std::vector<std::shared_ptr<RWMol>> nmols;
    nmols.reserve(reader.length());
    std::copy(reader.begin(), reader.end(), std::back_inserter(nmols));
    REQUIRE(nmols.size() == reader.length());
    CHECK(nmols[0]);
    CHECK(!nmols[1]);
    CHECK(nmols[2]);
    CHECK(!nmols[3]);
    CHECK(!nmols[4]);
    // confirm the caching: same objects (and same nulls) as the first pass
    std::vector<size_t> actual;
    std::ranges::transform(nmols, std::back_inserter(actual), molPtr);
    CHECK(actual == expected);
  }
}
|
|
|
|
#if defined(RDK_BUILD_THREADSAFE_SSS) && defined(__cpp_lib_execution)
|
|
// NOTE: will only run in parallel on linux if TBB is installed
|
|
TEST_CASE("parallel reads") {
  // There's likely no speed benefit from parallelizing reads of files this
  // small. The point is to have several threads pull molecules from the same
  // supplier at once, and confirm the results match a serial read.
  // (Named so it doesn't shadow the file-level `rdbase` string.)
  const char *rdbaseEnv = std::getenv("RDBASE");
  REQUIRE(rdbaseEnv);

  using Clock = std::chrono::steady_clock;  // monotonic: suitable for timing
  constexpr unsigned int numIters = 20;

  // Catch2's assertion macros must not be used from worker threads. Also, an
  // exception escaping the element function of a parallel algorithm calls
  // std::terminate. countAtoms() uses REQUIRE, so it can't be handed to
  // std::execution::par. A null molecule is mapped to 0 atoms instead; the
  // comparison with the serial reference, done on the test thread, reports it.
  const auto countAtomsNoAssert =
      [](const std::shared_ptr<RWMol> &mol) -> unsigned int {
    return mol ? mol->getNumAtoms() : 0u;
  };

  // Writes the time elapsed since `start` to stderr.
  const auto reportTiming = [](auto nMols, unsigned int nIters,
                               Clock::time_point start) {
    const auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
        Clock::now() - start);
    std::cerr << "Read of " << nMols << "x" << nIters << " molecules took "
              << duration.count() << " ms" << std::endl;
  };

  SECTION("sdf") {
    auto path = std::filesystem::path(rdbaseEnv) /
                "Code/GraphMol/FileParsers/test_data/zinc.leads.500.q.sdf";
    REQUIRE(std::filesystem::exists(path));
    // serial reference read (on the test thread, so REQUIRE is fine)
    v2::FileParsers::SDMolSupplier reader1(path.string());
    std::vector<unsigned int> nAts1(reader1.length());
    std::transform(reader1.begin(), reader1.end(), nAts1.begin(), countAtoms);

    auto start = Clock::now();
    for (unsigned int iter = 0; iter < numIters; ++iter) {
      // a fresh supplier every pass, so its cache starts out empty
      v2::FileParsers::SDMolSupplier reader2(path.string());
      reader2.setCaching(true);
      std::vector<unsigned int> nAts2(reader1.length());
      std::transform(std::execution::par, reader2.begin(), reader2.end(),
                     nAts2.begin(), countAtomsNoAssert);
      REQUIRE(nAts1 == nAts2);
    }
    reportTiming(reader1.length(), numIters, start);
  }

  SECTION("smiles") {
    auto path = std::filesystem::path(rdbaseEnv) /
                "Code/GraphMol/FileParsers/test_data/zinc.leads.500.q.smi";
    REQUIRE(std::filesystem::exists(path));
    v2::FileParsers::SmilesMolSupplierParams params;
    params.delimiter = '\t';
    params.smilesColumn = 0;
    params.nameColumn = 1;
    params.titleLine = false;

    v2::FileParsers::SmilesMolSupplier reader1(path.string(), params);
    std::vector<unsigned int> nAts1(reader1.length());
    std::transform(reader1.begin(), reader1.end(), nAts1.begin(), countAtoms);

    auto start = Clock::now();
    for (unsigned int iter = 0; iter < numIters; ++iter) {
      v2::FileParsers::SmilesMolSupplier reader2(path.string(), params);
      reader2.setCaching(true);
      std::vector<unsigned int> nAts2(reader1.length());
      std::transform(std::execution::par, reader2.begin(), reader2.end(),
                     nAts2.begin(), countAtomsNoAssert);
      REQUIRE(nAts1 == nAts2);
    }
    reportTiming(reader1.length(), numIters, start);
  }

  SECTION("TDT") {
    auto path = std::filesystem::path(rdbaseEnv) /
                "Code/GraphMol/FileParsers/test_data/zinc.leads.500.q.tdt";
    REQUIRE(std::filesystem::exists(path));

    v2::FileParsers::TDTMolSupplier reader1(path.string());
    std::vector<unsigned int> nAts1(reader1.length());
    std::transform(reader1.begin(), reader1.end(), nAts1.begin(), countAtoms);

    auto start = Clock::now();
    for (unsigned int iter = 0; iter < numIters; ++iter) {
      v2::FileParsers::TDTMolSupplier reader2(path.string());
      reader2.setCaching(true);
      std::vector<unsigned int> nAts2(reader1.length());
      std::transform(std::execution::par, reader2.begin(), reader2.end(),
                     nAts2.begin(), countAtomsNoAssert);
      REQUIRE(nAts1 == nAts2);
    }
    reportTiming(reader1.length(), numIters, start);
  }
}
|
|
TEST_CASE("benchmarking") {
  // (Named so it doesn't shadow the file-level `rdbase` string.)
  const char *rdbaseEnv = std::getenv("RDBASE");
  REQUIRE(rdbaseEnv);
  auto path = std::filesystem::path(rdbaseEnv) /
              "Code/GraphMol/FileParsers/test_data/zinc.leads.500.q.smi";
  REQUIRE(std::filesystem::exists(path));
  v2::FileParsers::SmilesMolSupplierParams params;
  params.delimiter = '\t';
  params.smilesColumn = 0;
  params.nameColumn = 1;
  params.titleLine = false;
  v2::FileParsers::SmilesMolSupplier reader(path.string(), params);
  reader.setCaching(true);

  SECTION("transform") {
    using Clock = std::chrono::steady_clock;  // monotonic: suitable for timing
    constexpr unsigned int numIters = 200;

    // Prime the cache *before* starting the clock. The sequential and parallel
    // loops then time the same, fully cached, workload; previously the
    // sequential timing also included the initial uncached read.
    std::vector<unsigned int> nAts1;
    std::transform(reader.begin(), reader.end(), std::back_inserter(nAts1),
                   countAtoms);

    // Per-molecule work: length of the SMILES. The pointer is taken by const
    // reference to avoid a shared_ptr copy (atomic refcount update) per call.
    const auto smilesLength = [](const std::shared_ptr<RWMol> &mol) {
      return MolToSmiles(*mol).size();
    };

    // Sum the SMILES lengths rather than just the number of results, so the
    // checks below depend on the work actually performed.
    double seqTotal = 0.0;
    auto start = Clock::now();
    for (unsigned int iter = 0; iter < numIters; ++iter) {
      std::vector<size_t> lens(reader.length());
      std::transform(std::execution::seq, reader.begin(), reader.end(),
                     lens.begin(), smilesLength);
      seqTotal += std::accumulate(lens.begin(), lens.end(), 0.0);
    }
    auto duration =
        std::chrono::duration_cast<std::chrono::milliseconds>(Clock::now() - start);
    std::cerr << "Base transform of " << reader.length() << "x" << numIters
              << " molecules took " << duration.count() << " ms" << std::endl;
    CHECK(seqTotal > 0);

    double parTotal = 0.0;
    start = Clock::now();
    for (unsigned int iter = 0; iter < numIters; ++iter) {
      std::vector<size_t> lens(reader.length());
      std::transform(std::execution::par, reader.begin(), reader.end(),
                     lens.begin(), smilesLength);
      parTotal += std::accumulate(lens.begin(), lens.end(), 0.0);
    }
    duration =
        std::chrono::duration_cast<std::chrono::milliseconds>(Clock::now() - start);
    std::cerr << "Parallel transform of " << reader.length() << "x" << numIters
              << " molecules took " << duration.count() << " ms" << std::endl;
    CHECK(parTotal > 0);
    // Both loops do the same work on the same cached molecules. The totals are
    // sums of integers (exact in a double), so they must match exactly.
    CHECK(parTotal == seqTotal);
  }
}
|
|
#endif
|
|
|
|
TEST_CASE("views and filtered reads") {
  // (Named so it doesn't shadow the file-level `rdbase` string.)
  const char *rdbaseEnv = std::getenv("RDBASE");
  REQUIRE(rdbaseEnv);
  auto path = std::filesystem::path(rdbaseEnv) /
              "Code/GraphMol/FileParsers/test_data/zinc.leads.500.q.sdf";
  REQUIRE(std::filesystem::exists(path));
  v2::FileParsers::SDMolSupplier reader1(path.string());

  SECTION("filters") {
    std::vector<unsigned int> nAts1(reader1.length());
    std::transform(reader1.begin(), reader1.end(), nAts1.begin(), countAtoms);
    constexpr unsigned int tgtSize = 15;
    auto nSmall =
        std::ranges::count_if(nAts1, [](const auto &v) { return v < tgtSize; });
    REQUIRE(nSmall > 0);

    // shared predicate for the views below
    const auto isSmall = [](const auto &mol) {
      return mol->getNumAtoms() < tgtSize;
    };

    auto filtered = reader1 | std::views::filter(isSmall);
    CHECK(std::distance(std::begin(filtered), std::end(filtered)) == nSmall);

    // only read until we have a set number of molecules matching our filter:
    const unsigned int subsetSz = nSmall / 4;
    auto firstN =
        reader1 | std::views::filter(isSmall) | std::views::take(subsetSz);
    // std::distance() needs begin and end iterators of the same type. A
    // take_view over a filter_view (which is never a sized range) returns
    // different types from begin() and end(), so use std::ranges::distance,
    // which accepts an iterator/sentinel pair.
    CHECK(std::cmp_equal(std::ranges::distance(firstN), subsetSz));
    std::vector<std::shared_ptr<RWMol>> mols;
    std::ranges::copy(firstN, std::back_inserter(mols));
    CHECK(mols.size() == subsetSz);
  }

  SECTION("substructure filter") {
    auto query = "c1ncncn1"_smiles;
    REQUIRE(query);
    SubstructMatchParameters params;
    params.maxMatches = 1;
    // keep the first 5 molecules that do NOT contain the query
    auto firstN = reader1 |
                  std::views::filter([&query, &params](const auto &mol) {
                    return SubstructMatch(*mol, *query, params).empty();
                  }) |
                  std::views::take(5);
    std::vector<std::shared_ptr<RWMol>> mols;
    std::ranges::copy(firstN, std::back_inserter(mols));
    CHECK(mols.size() == 5);
  }
}
|