Files
rdkit/Code/GraphMol/FileParsers/fileParsersIterTest.cpp
Greg Landrum 6d75052459 Support using iterators with MolSuppliers (#9230)
* iterators for random-access MolSuppliers
add optional caching to SDMolSupplier

* add support to SmilesMolSupplier too
There is a lot of duplicate code between the random-access suppliers that would be worth trying to remove
but at the moment it looks like it would require multiple inheritance, and I think we want to avoid that

* add input iterators for ForwardSDMolSupplier()

* throw when calling begin() on a used supplier

* switch to use the spaceship operator

* init() should reset the mol cache

* Make SDMolSupplier and SmilesMolSupplier safe for multi-threaded reads

* add benchmarking

* add TDTMolSupplier support
improved testing
add benchmarks for parallel iteration
optional TBB support

* better const handling, add reverse iterators

doesn't look like const_iterator is possible since getting data from the underlying supplier object is non-const

* improve docs
more usings
add reverse iterator to TDTMolSupplier

* tests only try execution::par when it is there

* fix typo

* more testing/demo

* remove accidentally added files

* review changes

* add default ctors

* disable a false-positive compiler warning
it is stupid to have to do this

---------

Co-authored-by: = <=>
2026-05-05 13:36:15 +02:00

507 lines
17 KiB
C++

//
// Copyright (C) 2026 Greg Landrum and other RDKit contributors
//
// @@ All Rights Reserved @@
// This file is part of the RDKit.
// The contents are covered by the terms of the BSD license
// which is included in the file license.txt, found at the root
// of the RDKit source tree.
//
#include <catch2/catch_all.hpp>
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <execution>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <ranges>
#include <string>
#include <vector>
#include <GraphMol/RDKitBase.h>
#include "MolSupplier.h"
#include <GraphMol/SmilesParse/SmilesWrite.h>
#include <GraphMol/SmilesParse/SmilesParse.h>
#include <GraphMol/Substruct/SubstructMatch.h>
// Root of the RDKit source tree, taken from the RDBASE environment variable.
// std::getenv() returns a null pointer when the variable is unset, and
// constructing a std::string from a null pointer is undefined behavior, so an
// unset variable is mapped to the empty string here. The test cases that build
// paths from this value will then fail on the missing file instead of the
// whole test binary dying during static initialization.
static const std::string rdbase = []() {
  const char *envValue = std::getenv("RDBASE");
  return std::string(envValue != nullptr ? envValue : "");
}();
using namespace RDKit;
namespace {
unsigned int countAtoms(const std::shared_ptr<RWMol> &mol) {
REQUIRE(mol);
return mol->getNumAtoms();
}
size_t molPtr(const std::shared_ptr<RWMol> &mol) { return (size_t)mol.get(); }
} // namespace
// Checks that iterating over `reader` as a range yields the same atom counts,
// in the same order, as indexed access through operator[].
//
// \param reader  supplier providing length(), operator[] and range iteration
// \param len     number of entries the supplier is expected to report
template <typename Supplier>
void iterTest(Supplier &reader, size_t len) {
  CHECK(reader.length() == len);

  // reference values collected entry by entry via operator[]
  std::vector<unsigned int> expected;
  expected.reserve(reader.length());
  for (unsigned int idx = 0; idx < reader.length(); ++idx) {
    auto &&mol = reader[idx];
    expected.push_back(mol->getNumAtoms());
  }

  // the same values collected through the range interface
  std::vector<unsigned int> actual;
  std::ranges::transform(reader, std::back_inserter(actual), countAtoms);

  CHECK(actual == expected);
}
// Runs a single pass over `reader` as a range and checks that exactly `len`
// molecules were visited. countAtoms() REQUIREs each molecule to be non-null.
//
// \param reader  supplier usable as a range
// \param len     number of molecules the pass is expected to produce
template <typename Supplier>
void forwardIterTest(Supplier &reader, size_t len) {
  std::vector<unsigned int> atomCounts;
  auto sink = std::back_inserter(atomCounts);
  std::ranges::transform(reader, sink, countAtoms);
  CHECK(atomCounts.size() == len);
}
template <typename Supplier>
void cacheTest(Supplier &reader, size_t len) {
reader.setCaching(true);
CHECK(reader.length() == len);
std::vector<std::shared_ptr<RWMol>> mols(reader.length());
std::copy(reader.begin(), reader.end(), mols.begin());
REQUIRE(mols.size() == reader.length());
std::vector<size_t> expected;
std::ranges::transform(mols, std::back_inserter(expected), molPtr);
std::vector<size_t> actual;
std::ranges::transform(reader, std::back_inserter(actual), molPtr);
CHECK(actual == expected);
}
// Forward, cached and reverse iteration over an SDMolSupplier reading
// NCI_aids_few.sdf, which the sections below expect to hold 16 entries.
TEST_CASE("basic SDMolSupplier iteration") {
  const std::string sdfPath =
      rdbase + "/Code/GraphMol/FileParsers/test_data/NCI_aids_few.sdf";
  v2::FileParsers::SDMolSupplier reader(sdfPath);

  SECTION("basics") {
    iterTest(reader, 16);
  }

  SECTION("with caching") {
    cacheTest(reader, 16);
  }

  SECTION("reverse iteration") {
    // atom counts in file order
    std::vector<unsigned int> forwardCounts;
    std::ranges::transform(std::begin(reader), std::end(reader),
                           std::back_inserter(forwardCounts), countAtoms);

    // atom counts read back to front, flipped afterwards so that both
    // sequences should line up
    std::vector<unsigned int> backwardCounts;
    std::ranges::transform(std::rbegin(reader), std::rend(reader),
                           std::back_inserter(backwardCounts), countAtoms);
    std::ranges::reverse(backwardCounts);

    CHECK(backwardCounts == forwardCounts);
  }
}
// Input-iterator support of ForwardSDMolSupplier on a stream the supplier
// does not own: range iteration, begin() on an already-used supplier, and
// manual stepping with both increment forms.
TEST_CASE("ForwardSDMolSupplier iteration") {
  const std::string sdfPath =
      rdbase + "/Code/GraphMol/FileParsers/test_data/NCI_aids_few.sdf";
  std::ifstream strm(sdfPath);
  const bool takeOwnership = false;
  v2::FileParsers::ForwardSDMolSupplier reader(&strm, takeOwnership);

  SECTION("basics") {
    forwardIterTest(reader, 16);
  }

  SECTION("error handling") {
    // once a molecule has been consumed, begin() is expected to throw
    reader.next();
    CHECK_THROWS_AS(reader.begin(), ValueErrorException);
  }

  SECTION("pre-increment") {
    unsigned int numSeen = 0;
    auto it = reader.begin();
    auto end = reader.end();
    for (; it != end; ++it) {
      auto mol = *it;
      REQUIRE(mol);
      ++numSeen;
    }
    CHECK(numSeen == 16);
  }

  SECTION("post-increment") {
    unsigned int numSeen = 0;
    auto it = reader.begin();
    auto end = reader.end();
    // the post-increment form is deliberate: it is the operator under test
    for (; it != end; it++) {
      auto mol = *it;
      REQUIRE(mol);
      ++numSeen;
    }
    CHECK(numSeen == 16);
  }
}
// Iteration over good_bad_good_good.sdf: the second record is expected to
// come back as a null molecule without ending the iteration.
TEST_CASE("ForwardSDMolSupplier iteration with failing molecules") {
  const std::string sdfPath =
      rdbase + "/Code/GraphMol/FileParsers/test_data/good_bad_good_good.sdf";
  std::ifstream strm(sdfPath);
  const bool takeOwnership = false;
  v2::FileParsers::ForwardSDMolSupplier reader(&strm, takeOwnership);

  SECTION("basics") {
    // a null molecule is recorded as 0 atoms
    const auto atomCountOrZero = [](const auto &mol) {
      return mol ? mol->getNumAtoms() : 0;
    };
    const std::vector<unsigned int> expected{6, 0, 6, 6};
    std::vector<unsigned int> actual;
    std::ranges::transform(reader, std::back_inserter(actual),
                           atomCountOrZero);
    CHECK(actual == expected);
  }

  SECTION("pre-increment") {
    unsigned int pos = 0;
    auto it = reader.begin();
    auto end = reader.end();
    for (; it != end; ++it) {
      auto mol = *it;
      if (pos != 1) {
        CHECK(mol);
      } else {
        CHECK(!mol);
      }
      ++pos;
    }
    CHECK(pos == 4);
  }

  SECTION("post-increment") {
    unsigned int pos = 0;
    auto it = reader.begin();
    auto end = reader.end();
    // the post-increment form is deliberate: it is the operator under test
    for (; it != end; it++) {
      auto mol = *it;
      if (pos != 1) {
        CHECK(mol);
      } else {
        CHECK(!mol);
      }
      ++pos;
    }
    CHECK(pos == 4);
  }
}
// Iteration over good_bad_good_bad.sdf: every second record (including the
// final one) is expected to come back as a null molecule, and all four
// records must still be visited.
TEST_CASE("ForwardSDMolSupplier iteration with failing molecule at end") {
  const std::string sdfPath =
      rdbase + "/Code/GraphMol/FileParsers/test_data/good_bad_good_bad.sdf";
  std::ifstream strm(sdfPath);
  const bool takeOwnership = false;
  v2::FileParsers::ForwardSDMolSupplier reader(&strm, takeOwnership);

  SECTION("basics") {
    // a null molecule is recorded as 0 atoms
    const auto atomCountOrZero = [](const auto &mol) {
      return mol ? mol->getNumAtoms() : 0;
    };
    const std::vector<unsigned int> expected{6, 0, 6, 0};
    std::vector<unsigned int> actual;
    std::ranges::transform(reader, std::back_inserter(actual),
                           atomCountOrZero);
    CHECK(actual == expected);
  }

  SECTION("pre-increment") {
    unsigned int pos = 0;
    auto it = reader.begin();
    auto end = reader.end();
    for (; it != end; ++it) {
      auto mol = *it;
      const bool oddPosition = (pos % 2) != 0;
      if (oddPosition) {
        CHECK(!mol);
      } else {
        CHECK(mol);
      }
      ++pos;
    }
    CHECK(pos == 4);
  }

  SECTION("post-increment") {
    unsigned int pos = 0;
    auto it = reader.begin();
    auto end = reader.end();
    // the post-increment form is deliberate: it is the operator under test
    for (; it != end; it++) {
      auto mol = *it;
      const bool oddPosition = (pos % 2) != 0;
      if (oddPosition) {
        CHECK(!mol);
      } else {
        CHECK(mol);
      }
      ++pos;
    }
    CHECK(pos == 4);
  }
}
// With caching enabled, records that fail to parse must come back as null on
// the first pass and on the repeated pass, and both passes must hand out the
// same pointers.
TEST_CASE("cached SDMolSupplier error handling") {
  const std::string sdfPath =
      rdbase + "/Code/GraphMol/FileParsers/test_data/good_bad_good_bad.sdf";

  SECTION("basics") {
    v2::FileParsers::SDMolSupplier reader(sdfPath);
    reader.setCaching(true);
    CHECK(reader.length() == 4);

    // first pass
    std::vector<std::shared_ptr<RWMol>> firstPass(reader.length());
    std::copy(reader.begin(), reader.end(), firstPass.begin());
    REQUIRE(firstPass.size() == reader.length());
    CHECK(firstPass[0]);
    CHECK(!firstPass[1]);
    CHECK(firstPass[2]);
    CHECK(!firstPass[3]);
    std::vector<size_t> expected;
    std::ranges::transform(firstPass, std::back_inserter(expected), molPtr);

    // now use the cached versions:
    std::vector<std::shared_ptr<RWMol>> secondPass(reader.length());
    std::copy(reader.begin(), reader.end(), secondPass.begin());
    REQUIRE(secondPass.size() == reader.length());
    CHECK(secondPass[0]);
    CHECK(!secondPass[1]);
    CHECK(secondPass[2]);
    CHECK(!secondPass[3]);

    // confirm the caching
    std::vector<size_t> actual;
    std::ranges::transform(secondPass, std::back_inserter(actual), molPtr);
    CHECK(actual == expected);
  }
}
// Forward, cached and reverse iteration over a SmilesMolSupplier reading a
// comma-separated file with the SMILES in the first column and no name
// column; the sections below expect 200 entries.
TEST_CASE("basic SmilesMolSupplier iteration") {
  const std::string csvPath =
      rdbase + "/Code/GraphMol/FileParsers/test_data/first_200.tpsa.csv";

  v2::FileParsers::SmilesMolSupplierParams params;
  params.delimiter = ',';
  params.smilesColumn = 0;
  params.nameColumn = -1;
  v2::FileParsers::SmilesMolSupplier reader(csvPath, params);

  SECTION("basics") {
    iterTest(reader, 200);
  }

  SECTION("with caching") {
    cacheTest(reader, 200);
  }

  SECTION("reverse iteration") {
    // atom counts in file order
    std::vector<unsigned int> forwardCounts;
    std::ranges::transform(std::begin(reader), std::end(reader),
                           std::back_inserter(forwardCounts), countAtoms);

    // atom counts read back to front, flipped afterwards so that both
    // sequences should line up
    std::vector<unsigned int> backwardCounts;
    std::ranges::transform(std::rbegin(reader), std::rend(reader),
                           std::back_inserter(backwardCounts), countAtoms);
    std::ranges::reverse(backwardCounts);

    CHECK(backwardCounts == forwardCounts);
  }
}
// With caching enabled, SMILES rows that fail to parse must come back as null
// on the first pass and on the repeated pass, and both passes must hand out
// the same pointers. The second column of the inline data marks which rows
// are expected to be valid.
TEST_CASE("cached SmilesMolSupplier error handling") {
  v2::FileParsers::SmilesMolSupplierParams params;
  params.delimiter = ',';
  params.smilesColumn = 0;
  params.nameColumn = -1;

  const std::string data = R"SMI(smiles,is_valid
CCO,1
CFC,0
CCN,1
c1cc1,0
c1cc,0
)SMI";

  SECTION("basics") {
    v2::FileParsers::SmilesMolSupplier reader;
    reader.setData(data, params);
    reader.setCaching(true);
    CHECK(reader.length() == 5);

    // first pass
    std::vector<std::shared_ptr<RWMol>> firstPass(reader.length());
    std::copy(reader.begin(), reader.end(), firstPass.begin());
    REQUIRE(firstPass.size() == reader.length());
    CHECK(firstPass[0]);
    CHECK(!firstPass[1]);
    CHECK(firstPass[2]);
    CHECK(!firstPass[3]);
    CHECK(!firstPass[4]);
    std::vector<size_t> expected;
    std::ranges::transform(firstPass, std::back_inserter(expected), molPtr);

    // now use the cached versions:
    std::vector<std::shared_ptr<RWMol>> secondPass(reader.length());
    std::copy(reader.begin(), reader.end(), secondPass.begin());
    REQUIRE(secondPass.size() == reader.length());
    CHECK(secondPass[0]);
    CHECK(!secondPass[1]);
    CHECK(secondPass[2]);
    CHECK(!secondPass[3]);
    CHECK(!secondPass[4]);

    // confirm the caching
    std::vector<size_t> actual;
    std::ranges::transform(secondPass, std::back_inserter(actual), molPtr);
    CHECK(actual == expected);
  }
}
#if defined(RDK_BUILD_THREADSAFE_SSS) && defined(__cpp_lib_execution)
// NOTE: will only run in parallel on linux if TBB is installed
// Reads the same file repeatedly through a caching supplier using
// std::execution::par and compares the per-molecule atom counts against a
// reference obtained with a plain serial pass. One section per supplier type
// (SD, SMILES, TDT). The elapsed time of the parallel passes is written to
// std::cerr for information only.
TEST_CASE("parallel reads") {
  // there's likely no benefit from the parallelization here,
  // but we want to be sure that the mutexes are working correctly
  auto *rdbase = std::getenv("RDBASE");
  REQUIRE(rdbase);

  // Assertion-free atom counter for the parallel passes. Catch2's assertion
  // macros are not safe to use from worker threads by default, and an
  // exception leaving a parallel algorithm ends in std::terminate, so the
  // REQUIRE inside countAtoms() must not run there. A null molecule is
  // reported as 0 atoms, which is then caught by the comparison against the
  // serial reference (where countAtoms() has already REQUIREd non-null).
  const auto countAtomsNoAssert = [](const auto &mol) -> unsigned int {
    return mol ? mol->getNumAtoms() : 0u;
  };

  SECTION("sdf") {
    auto path = std::filesystem::path(rdbase) /
                "Code/GraphMol/FileParsers/test_data/zinc.leads.500.q.sdf";
    REQUIRE(std::filesystem::exists(path));
    // serial reference, computed on the test thread
    v2::FileParsers::SDMolSupplier reader1(path.string());
    std::vector<unsigned int> nAts1(reader1.length());
    std::transform(reader1.begin(), reader1.end(), nAts1.begin(), countAtoms);
    auto start = std::chrono::high_resolution_clock::now();
    constexpr unsigned int numIters = 20;
    for (unsigned int iter = 0; iter < numIters; ++iter) {
      v2::FileParsers::SDMolSupplier reader2(path.string());
      reader2.setCaching(true);
      // size the output from the supplier that fills it, and verify the
      // length before writing so a mismatch cannot overrun the buffer
      std::vector<unsigned int> nAts2(reader2.length());
      REQUIRE(nAts1.size() == nAts2.size());
      std::transform(std::execution::par, reader2.begin(), reader2.end(),
                     nAts2.begin(), countAtomsNoAssert);
      REQUIRE(nAts1 == nAts2);
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto duration =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    std::cerr << "Read of " << reader1.length() << "x" << numIters
              << " molecules took " << duration.count() << " ms" << std::endl;
  }

  SECTION("smiles") {
    auto path = std::filesystem::path(rdbase) /
                "Code/GraphMol/FileParsers/test_data/zinc.leads.500.q.smi";
    REQUIRE(std::filesystem::exists(path));
    v2::FileParsers::SmilesMolSupplierParams params;
    params.delimiter = '\t';
    params.smilesColumn = 0;
    params.nameColumn = 1;
    params.titleLine = false;
    // serial reference, computed on the test thread
    v2::FileParsers::SmilesMolSupplier reader1(path.string(), params);
    std::vector<unsigned int> nAts1(reader1.length());
    std::transform(reader1.begin(), reader1.end(), nAts1.begin(), countAtoms);
    auto start = std::chrono::high_resolution_clock::now();
    constexpr unsigned int numIters = 20;
    for (unsigned int iter = 0; iter < numIters; ++iter) {
      v2::FileParsers::SmilesMolSupplier reader2(path.string(), params);
      reader2.setCaching(true);
      std::vector<unsigned int> nAts2(reader2.length());
      REQUIRE(nAts1.size() == nAts2.size());
      std::transform(std::execution::par, reader2.begin(), reader2.end(),
                     nAts2.begin(), countAtomsNoAssert);
      REQUIRE(nAts1 == nAts2);
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto duration =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    std::cerr << "Read of " << reader1.length() << "x" << numIters
              << " molecules took " << duration.count() << " ms" << std::endl;
  }

  SECTION("TDT") {
    auto path = std::filesystem::path(rdbase) /
                "Code/GraphMol/FileParsers/test_data/zinc.leads.500.q.tdt";
    REQUIRE(std::filesystem::exists(path));
    // serial reference, computed on the test thread
    v2::FileParsers::TDTMolSupplier reader1(path.string());
    std::vector<unsigned int> nAts1(reader1.length());
    std::transform(reader1.begin(), reader1.end(), nAts1.begin(), countAtoms);
    auto start = std::chrono::high_resolution_clock::now();
    constexpr unsigned int numIters = 20;
    for (unsigned int iter = 0; iter < numIters; ++iter) {
      v2::FileParsers::TDTMolSupplier reader2(path.string());
      reader2.setCaching(true);
      std::vector<unsigned int> nAts2(reader2.length());
      REQUIRE(nAts1.size() == nAts2.size());
      std::transform(std::execution::par, reader2.begin(), reader2.end(),
                     nAts2.begin(), countAtomsNoAssert);
      REQUIRE(nAts1 == nAts2);
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto duration =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    std::cerr << "Read of " << reader1.length() << "x" << numIters
              << " molecules took " << duration.count() << " ms" << std::endl;
  }
}
// Informal benchmark: times repeated std::transform passes over a caching
// SmilesMolSupplier, once with std::execution::seq and once with
// std::execution::par, and writes both timings to std::cerr. The only checked
// property is that the passes produced output at all.
TEST_CASE("benchmarking") {
  auto *rdbase = std::getenv("RDBASE");
  REQUIRE(rdbase);
  auto path = std::filesystem::path(rdbase) /
              "Code/GraphMol/FileParsers/test_data/zinc.leads.500.q.smi";
  REQUIRE(std::filesystem::exists(path));
  v2::FileParsers::SmilesMolSupplierParams params;
  params.delimiter = '\t';
  params.smilesColumn = 0;
  params.nameColumn = 1;
  params.titleLine = false;
  v2::FileParsers::SmilesMolSupplier reader(path.string(), params);
  reader.setCaching(true);

  SECTION("transform") {
    // prime the cache before starting the clock, so that the sequential
    // timing does not also include the initial pass and both timings cover
    // the same work
    std::vector<unsigned int> nAts1;
    std::transform(reader.begin(), reader.end(), std::back_inserter(nAts1),
                   countAtoms);

    // Work done per molecule. The argument is taken by const reference so no
    // extra copy of the molecule pointer is made, and the size_t length is
    // narrowed explicitly to the element type of the output vector.
    const auto smilesLength = [](const auto &mol) {
      return static_cast<unsigned int>(MolToSmiles(*mol).size());
    };

    double accum = 0.0;
    constexpr unsigned int numIters = 200;
    auto start = std::chrono::high_resolution_clock::now();
    for (unsigned int iter = 0; iter < numIters; ++iter) {
      std::vector<unsigned int> nAts(reader.length());
      std::transform(std::execution::seq, reader.begin(), reader.end(),
                     nAts.begin(), smilesLength);
      accum += nAts.size();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto duration =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    std::cerr << "Base transform of " << reader.length() << "x" << numIters
              << " molecules took " << duration.count() << " ms" << std::endl;
    CHECK(accum > 0);

    accum = 0.0;
    start = std::chrono::high_resolution_clock::now();
    for (unsigned int iter = 0; iter < numIters; ++iter) {
      std::vector<unsigned int> nAts(reader.length());
      std::transform(std::execution::par, reader.begin(), reader.end(),
                     nAts.begin(), smilesLength);
      accum += nAts.size();
    }
    end = std::chrono::high_resolution_clock::now();
    duration =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    std::cerr << "Parallel transform of " << reader.length() << "x" << numIters
              << " molecules took " << duration.count() << " ms" << std::endl;
    CHECK(accum > 0);
  }
}
#endif
// Demonstrates composing an SDMolSupplier with std::views: filtering by atom
// count, taking only the first few matches, and filtering on the result of a
// substructure search.
TEST_CASE("views and filtered reads") {
  auto *rdbase = std::getenv("RDBASE");
  REQUIRE(rdbase);
  auto path = std::filesystem::path(rdbase) /
              "Code/GraphMol/FileParsers/test_data/zinc.leads.500.q.sdf";
  REQUIRE(std::filesystem::exists(path));
  v2::FileParsers::SDMolSupplier reader1(path.string());

  SECTION("filters") {
    // reference: count the small molecules from a plain pass
    std::vector<unsigned int> nAts1(reader1.length());
    std::transform(reader1.begin(), reader1.end(), nAts1.begin(), countAtoms);
    constexpr unsigned int tgtSize = 15;
    auto nSmall =
        std::ranges::count_if(nAts1, [](const auto &v) { return v < tgtSize; });
    REQUIRE(nSmall > 0);

    // predicate shared by both views below
    const auto isSmall = [](const auto &mol) {
      return mol->getNumAtoms() < tgtSize;
    };

    auto filtered = reader1 | std::views::filter(isSmall);
    CHECK(std::distance(std::begin(filtered), std::end(filtered)) == nSmall);

    // only read until we have a set number of molecules matching our filter:
    // NOTE: subsetSz is 0 (and the checks below are trivial) if nSmall < 4
    const unsigned int subsetSz = nSmall / 4;
    auto firstN =
        reader1 | std::views::filter(isSmall) | std::views::take(subsetSz);
    // std::distance() cannot be used on firstN: taking from a filtered view
    // yields different iterator and sentinel types, and std::distance()
    // needs both arguments to have the same type. std::ranges::distance()
    // accepts an iterator/sentinel pair (or the range itself).
    CHECK(static_cast<size_t>(std::ranges::distance(firstN)) == subsetSz);
    std::vector<std::shared_ptr<RWMol>> mols;
    std::ranges::copy(firstN, std::back_inserter(mols));
    CHECK(mols.size() == subsetSz);
  }

  SECTION("substructure filter") {
    auto query = "c1ncncn1"_smiles;
    SubstructMatchParameters params;
    params.maxMatches = 1;
    // NOTE(review): the predicate keeps molecules for which the match list
    // is empty, i.e. those that do NOT contain the query - confirm that this
    // is the intended direction of the filter.
    auto firstN = reader1 |
                  std::views::filter([&query, &params](const auto &mol) {
                    return SubstructMatch(*mol, *query, params).empty();
                  }) |
                  std::views::take(5);
    std::vector<std::shared_ptr<RWMol>> mols;
    std::ranges::copy(firstN, std::back_inserter(mols));
    CHECK(mols.size() == 5);
  }
}