Add some std::ranges support (#9218)

* initial ranges support for Atom/Bond iterators.
needs more testing

* support random access
test sort

more testing please

* compiles on windows

* fix size()
more testing
add some benchmarking

* disable benchmarking code by default

* do not allow modifying the graph through the iterators

---------

Co-authored-by: = <=>
This commit is contained in:
Greg Landrum
2026-04-13 17:13:04 +02:00
committed by GitHub
parent f150381c13
commit 2c6efb4a65
4 changed files with 392 additions and 25 deletions

View File

@@ -190,6 +190,9 @@ rdkit_catch_test(canonTestsCatch catch_canon.cpp
rdkit_catch_test(moliteratorTestsCatch catch_moliterators.cpp
LINK_LIBRARIES SubstructMatch SmilesParse GraphMol)
rdkit_catch_test(molrangesTestsCatch catch_molranges.cpp
LINK_LIBRARIES GraphMol)
rdkit_catch_test(queryTestsCatch catch_queries.cpp
LINK_LIBRARIES SmilesParse GraphMol)

View File

@@ -23,6 +23,8 @@
#include <iterator>
#include <utility>
#include <map>
#include <ranges>
#include <limits>
// boost stuff
#include <RDGeneral/BoostStartInclude.h>
@@ -117,41 +119,96 @@ struct CXXAtomIterator {
Iterator vstart, vend;
// Iterator type handed out by the enclosing atom range. It wraps a graph
// pointer plus an underlying graph iterator (`pos`); dereferencing looks up the
// value stored in the graph for the element `pos` refers to, i.e.
// `(*graph)[*pos]`.
// NOTE(review): this listing shows removed and added lines of the change next
// to each other (two iterator_category aliases, two `graph` members, two
// constructors, two operator*/operator==/operator!= definitions). Presumably
// the later declaration of each pair is the one that survives the commit —
// confirm against the real header.
struct CXXAtomIter {
using iterator_category = std::forward_iterator_tag;
using iterator_category = std::random_access_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = Vertex;
using pointer = Vertex *;
using reference = Vertex &;
using const_reference = Vertex const &;
Graph *graph;
// in-class initializer so a default-constructed iterator has a null graph
Graph *graph = nullptr;
Iterator pos;
Atom *current;
CXXAtomIter(Graph *graph, Iterator pos)
: graph(graph), pos(pos), current(nullptr) {}
// default constructor: performs no initialization of its own (graph relies on
// its in-class initializer above)
CXXAtomIter() {};
// mutable-reference dereference that caches the looked-up value in `current`
reference operator*() {
current = (*graph)[*pos];
return current;
CXXAtomIter(Graph *graph, Iterator pos) : graph(graph), pos(pos) {}
// we only return const references since we don't want clients modifying the
// graph itself through these iterators
const_reference operator*() const { return (*graph)[*pos]; }
// we only return const references since we don't want clients modifying the
// graph itself through these iterators
const_reference operator[](difference_type n) const {
return (*graph)[*(pos + n)];
}
// pre-increment: step the underlying iterator forward by one
CXXAtomIter &operator++() {
++pos;
return *this;
}
// position-only comparisons (the graph pointer is not consulted here)
bool operator==(const CXXAtomIter &it) const { return pos == it.pos; }
bool operator!=(const CXXAtomIter &it) const { return pos != it.pos; }
// post-increment: returns a copy taken before advancing
CXXAtomIter operator++(int) {
CXXAtomIter tmp = *this;
++(*this);
return tmp;
}
// pre-decrement: step the underlying iterator back by one
CXXAtomIter &operator--() {
--pos;
return *this;
}
// iterator +/- offset: new iterator on the same graph, `n` positions away
CXXAtomIter operator+(difference_type n) const {
return CXXAtomIter(graph, pos + n);
}
CXXAtomIter operator-(difference_type n) const {
return CXXAtomIter(graph, pos - n);
}
// post-decrement: returns a copy taken before stepping back
CXXAtomIter operator--(int) {
CXXAtomIter tmp = *this;
--(*this);
return tmp;
}
// in-place advance / retreat by `n` positions
CXXAtomIter &operator+=(difference_type n) {
pos += n;
return *this;
}
CXXAtomIter &operator-=(difference_type n) {
pos -= n;
return *this;
}
// distance between two iterators, computed from the underlying positions only
difference_type operator-(const CXXAtomIter &other) const {
return pos - other.pos;
}
// offset + iterator, the mirrored form of operator+(difference_type)
friend CXXAtomIter operator+(difference_type n, const CXXAtomIter &it) {
return CXXAtomIter(it.graph, it.pos + n);
}
// two iterators are equal only if they refer to the same graph AND position
bool operator==(const CXXAtomIter &other) const {
return graph == other.graph && pos == other.pos;
}
bool operator!=(const CXXAtomIter &other) const {
return !(*this == other);
}
// ordering comparisons look at `pos` only; the graph pointer is not compared
bool operator<(const CXXAtomIter &other) const { return pos < other.pos; }
bool operator<=(const CXXAtomIter &other) const { return pos <= other.pos; }
bool operator>(const CXXAtomIter &other) const { return pos > other.pos; }
bool operator>=(const CXXAtomIter &other) const { return pos >= other.pos; }
};
// Range over every vertex of `*graph`: vstart/vend are taken from
// boost::vertices().
// NOTE(review): the body below shows both the three-statement form and the
// std::tie form of the same assignment (removed and added lines of the change
// listed together) — presumably only the std::tie line remains; confirm.
CXXAtomIterator(Graph *graph) : graph(graph) {
auto vs = boost::vertices(*graph);
vstart = vs.first;
vend = vs.second;
std::tie(vstart, vend) = boost::vertices(*graph);
}
// Range over a caller-supplied [start, end) pair of underlying iterators.
CXXAtomIterator(Graph *graph, Iterator start, Iterator end)
: graph(graph), vstart(start), vend(end) {};
// Iterators to the first and one-past-the-last positions of the range.
CXXAtomIter begin() { return {graph, vstart}; }
CXXAtomIter end() { return {graph, vend}; }
// Element count, computed as the difference of the two underlying iterators.
size_t size() const { return vend - vstart; }
};
// clang-format off
static_assert(
std::ranges::random_access_range<CXXAtomIterator<MolGraph, Atom *>>
and std::ranges::sized_range<CXXAtomIterator<MolGraph, Atom *>>
);
// clang-format on
template <class Graph, class Edge,
class Iterator = typename Graph::edge_iterator>
@@ -160,29 +217,46 @@ struct CXXBondIterator {
Iterator vstart, vend;
// Iterator type handed out by the enclosing bond range. It wraps a graph
// pointer plus an underlying graph iterator (`pos`); dereferencing looks up the
// value stored in the graph for the element `pos` refers to, i.e.
// `(*graph)[*pos]`. Only increment/decrement and equality are provided — there
// is no offset arithmetic or ordering.
// NOTE(review): this listing shows removed and added lines of the change next
// to each other (two iterator_category aliases, two `graph` members, two
// constructors, two operator*/operator==/operator!= definitions). Presumably
// the later declaration of each pair is the one that survives the commit —
// confirm against the real header.
struct CXXBondIter {
using iterator_category = std::forward_iterator_tag;
using iterator_category = std::bidirectional_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = Edge;
using pointer = Edge *;
using reference = Edge &;
using const_reference = Edge const &;
Graph *graph;
// in-class initializer so a default-constructed iterator has a null graph
Graph *graph = nullptr;
Iterator pos;
Bond *current;
CXXBondIter(Graph *graph, Iterator pos)
: graph(graph), pos(pos), current(nullptr) {}
// default constructor: performs no initialization of its own (graph relies on
// its in-class initializer above)
CXXBondIter() {};
// mutable-reference dereference that caches the looked-up value in `current`
reference operator*() {
current = (*graph)[*pos];
return current;
}
CXXBondIter(Graph *graph, Iterator pos) : graph(graph), pos(pos) {}
// we only return const references since we don't want clients modifying the
// graph itself through these iterators
const_reference operator*() const { return (*graph)[*pos]; }
// pre-increment: step the underlying iterator forward by one
CXXBondIter &operator++() {
++pos;
return *this;
}
// position-only comparisons (the graph pointer is not consulted here)
bool operator==(const CXXBondIter &it) const { return pos == it.pos; }
bool operator!=(const CXXBondIter &it) const { return pos != it.pos; }
// post-increment: returns a copy taken before advancing
CXXBondIter operator++(int) {
CXXBondIter tmp = *this;
++(*this);
return tmp;
}
// pre-decrement: step the underlying iterator back by one
CXXBondIter &operator--() {
--pos;
return *this;
}
// post-decrement: returns a copy taken before stepping back
CXXBondIter operator--(int) {
CXXBondIter tmp = *this;
--(*this);
return tmp;
}
// two iterators are equal only if they refer to the same graph AND position
bool operator==(const CXXBondIter &other) const {
return graph == other.graph && pos == other.pos;
}
bool operator!=(const CXXBondIter &other) const {
return !(*this == other);
}
};
CXXBondIterator(Graph *graph) : graph(graph) {
@@ -194,7 +268,19 @@ struct CXXBondIterator {
: graph(graph), vstart(start), vend(end) {};
CXXBondIter begin() { return {graph, vstart}; }
CXXBondIter end() { return {graph, vend}; }
size_t size() const {
// bond iterators aren't random access, so we can't just do vend - vstart
// here. Instead we have to iterate through;
size_t count = 0;
for (auto it = vstart; it != vend; ++it) {
++count;
}
return count;
}
};
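// we only assert bidirectional_range here: size() is O(N), so this range is
// not intended to be treated as a sized_range.
// NOTE(review): a size() member is normally enough for std::ranges::size() to
// accept the type; if opting out is really intended, confirm whether
// std::ranges::disable_sized_range needs to be specialized.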
static_assert(
std::ranges::bidirectional_range<CXXBondIterator<MolGraph, Bond *>>);
class RDKIT_GRAPHMOL_EXPORT ROMol : public RDProps {
public:
@@ -832,6 +918,7 @@ class RDKIT_GRAPHMOL_EXPORT ROMol : public RDProps {
protected:
unsigned int numBonds{0};
#ifndef WIN32
private:
#endif
void initMol();

View File

@@ -14,6 +14,9 @@
#include <GraphMol/SmilesParse/SmilesParse.h>
#include <GraphMol/SmilesParse/SmilesWrite.h>
#include <algorithm>
#include <ranges>
#include <filesystem>
#include <execution>
using namespace RDKit;
@@ -36,6 +39,14 @@ TEST_CASE("mol.atoms()") {
return atom->getAtomicNum() == 6;
});
CHECK(ccount == 4);
ccount =
std::count_if(std::begin(atoms), std::end(atoms),
[](const auto atom) { return atom->getAtomicNum() == 6; });
CHECK(ccount == 4);
ccount =
std::count_if(std::ranges::begin(atoms), std::ranges::end(atoms),
[](const auto atom) { return atom->getAtomicNum() == 6; });
CHECK(ccount == 4);
}
TEST_CASE("mol.bonds()") {
@@ -57,6 +68,14 @@ TEST_CASE("mol.bonds()") {
bonds.begin(), bonds.end(),
[](const auto bond) { return bond->getBondType() == Bond::DOUBLE; });
CHECK(doubleBondCount == 2);
doubleBondCount = std::count_if(
std::begin(bonds), std::end(bonds),
[](const auto bond) { return bond->getBondType() == Bond::DOUBLE; });
CHECK(doubleBondCount == 2);
doubleBondCount = std::count_if(
std::ranges::begin(bonds), std::ranges::end(bonds),
[](const auto bond) { return bond->getBondType() == Bond::DOUBLE; });
CHECK(doubleBondCount == 2);
}
TEST_CASE("mol.atomNeighbors()") {
@@ -87,4 +106,121 @@ TEST_CASE("mol.atomBonds()") {
}
MolOps::sanitizeMol(*m);
CHECK(MolToSmiles(*m) == "CC(C)CO");
}
}
// Checks that the atom and bond ranges of a parsed molecule work with
// std::ranges algorithms, both in forward order and through views::reverse.
TEST_CASE("ranges") {
  const auto mol = "CC(C)CO"_smiles;
  REQUIRE(mol);

  auto atomRange = mol->atoms();
  auto bondRange = mol->bonds();
  CHECK(std::ranges::distance(atomRange) == 5);
  CHECK(std::ranges::distance(bondRange) == 4);

  // shared projection: atom -> degree
  const auto degreeOf = [](const auto atom) { return atom->getDegree(); };

  {
    // forward traversal
    const std::vector<unsigned int> expected{1, 3, 1, 2, 1};
    std::vector<unsigned int> observed;
    std::ranges::transform(atomRange, std::back_inserter(observed), degreeOf);
    CHECK(observed == expected);
  }
  {
    // reverse traversal
    const std::vector<unsigned int> expected{1, 2, 1, 3, 1};
    std::vector<unsigned int> observed;
    auto backwards = atomRange | std::views::reverse;
    std::ranges::transform(backwards, std::back_inserter(observed), degreeOf);
    CHECK(observed == expected);
  }
}
// Resolves `relative` against the directory named by the RDBASE environment
// variable and returns the combined path.
// Throws std::runtime_error if RDBASE is not set.
std::filesystem::path relative_to_rdbase(
    const std::filesystem::path &relative) {
  if (const char *root = std::getenv("RDBASE")) {
    return std::filesystem::path(root) / relative;
  }
  throw std::runtime_error("RDBASE environment variable not set");
}
#if 0
// Manual timing comparison of four ways of visiting the carbon atoms of a set
// of molecules and summing their degrees: a range-based for loop, ranges
// for_each over a filter view, the same filter view reversed, and
// std::for_each over begin()/end(). Each variant runs 100000 passes over the
// molecule set and prints its wall-clock time in milliseconds to std::cerr.
// The surrounding #if 0 keeps this out of the normal build.
// NOTE(review): uses std::chrono, std::ifstream and std::cerr — the matching
// <chrono>/<fstream>/<iostream> includes are not visible in this view; confirm
// they are available before enabling.
TEST_CASE("benchmarking") {
// locate the input file relative to $RDBASE
auto *rdbase = std::getenv("RDBASE");
REQUIRE(rdbase);
auto path =
std::filesystem::path(rdbase) / "Regress/Data/zinc.leads.500.q.smi";
REQUIRE(std::filesystem::exists(path));
auto inf = std::ifstream(path);
REQUIRE(inf);
// parse one molecule per input line; every line must parse
std::vector<std::unique_ptr<RWMol>> mols;
std::string line;
while (std::getline(inf, line)) {
mols.push_back(v2::SmilesParse::MolFromSmiles(line));
REQUIRE(mols.back());
}
// variant 1: range-based for over m->atoms()
// accum is checked after each variant, so the summed value is actually used
double accum = 0;
auto start = std::chrono::high_resolution_clock::now();
for (unsigned int iter = 0; iter < 100000; ++iter) {
for (const auto &m : mols) {
for (const auto atom : m->atoms()) {
if (atom->getAtomicNum() == 6) {
accum += atom->getDegree();
}
}
}
}
auto end = std::chrono::high_resolution_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
std::cerr << "Iterating over " << mols.size() << " molecules took "
<< duration.count() << " ms" << std::endl;
CHECK(accum > 0);
// variant 2: std::ranges::for_each over a filter view
accum = 0.0;
start = std::chrono::high_resolution_clock::now();
for (unsigned int iter = 0; iter < 100000; ++iter) {
for (const auto &m : mols) {
auto atoms = m->atoms();
std::ranges::for_each(
atoms | std::views::filter([](const auto atom) {
return atom->getAtomicNum() == 6;
}),
[&accum](const auto atom) { accum += atom->getDegree(); });
}
}
end = std::chrono::high_resolution_clock::now();
duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
std::cerr << "Iterating with ranges over " << mols.size()
<< " molecules took " << duration.count() << " ms" << std::endl;
CHECK(accum > 0);
// variant 3: the same filter view, traversed through views::reverse
accum = 0.0;
start = std::chrono::high_resolution_clock::now();
for (unsigned int iter = 0; iter < 100000; ++iter) {
for (const auto &m : mols) {
auto atoms = m->atoms();
std::ranges::for_each(
atoms | std::views::filter([](const auto atom) {
return atom->getAtomicNum() == 6;
}) | std::views::reverse,
[&accum](const auto atom) { accum += atom->getDegree(); });
}
}
end = std::chrono::high_resolution_clock::now();
duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
std::cerr << "Iterating with ranges in reverse over " << mols.size()
<< " molecules took " << duration.count() << " ms" << std::endl;
CHECK(accum > 0);
// variant 4: classic std::for_each over begin()/end() with the test inlined
accum = 0.0;
start = std::chrono::high_resolution_clock::now();
for (unsigned int iter = 0; iter < 100000; ++iter) {
for (const auto &m : mols) {
auto atoms = m->atoms();
std::for_each(atoms.begin(), atoms.end(), [&accum](const auto atom) {
if (atom->getAtomicNum() == 6) accum += atom->getDegree();
});
}
}
end = std::chrono::high_resolution_clock::now();
duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
std::cerr << "Iterating with for_each over " << mols.size()
<< " molecules took " << duration.count() << " ms" << std::endl;
CHECK(accum > 0);
}
#endif

View File

@@ -0,0 +1,141 @@
//
// Copyright (C) 2026 Greg Landrum and other RDKit contributors
//
// @@ All Rights Reserved @@
// This file is part of the RDKit.
// The contents are covered by the terms of the BSD license
// which is included in the file license.txt, found at the root
// of the RDKit source tree.
//
#include <catch2/catch_all.hpp>
#include <GraphMol/RDKitBase.h>
#include <algorithm>
#include <ranges>
using namespace RDKit;
// Exercises the atom, bond, neighbor and atom-bond ranges of a molecule with
// std::ranges algorithms and views, in forward and reverse order.
TEST_CASE("ranges") {
  // Assembled by hand; corresponds to the SMILES CC(=C)C=O.
  auto m = std::make_unique<RWMol>();
  REQUIRE(m);
  for (const unsigned int atomicNum : {6u, 6u, 6u, 6u, 8u}) {
    m->addAtom(new Atom(atomicNum), true, true);
  }
  m->addBond(0, 1, Bond::SINGLE);
  m->addBond(1, 2, Bond::DOUBLE);
  m->addBond(1, 3, Bond::SINGLE);
  m->addBond(3, 4, Bond::DOUBLE);

  // projections shared by the sections below
  const auto degreeOf = [](const auto atom) { return atom->getDegree(); };
  const auto bondTypeOf = [](const auto bond) { return bond->getBondType(); };
  const auto indexOf = [](const auto item) { return item->getIdx(); };

  SECTION("atoms and bonds") {
    auto atomRange = m->atoms();
    auto bondRange = m->bonds();
    CHECK(std::ranges::distance(atomRange) == 5);
    CHECK(std::ranges::distance(bondRange) == 4);

    {
      // atom degrees, forward
      const std::vector<unsigned int> expected{1, 3, 1, 2, 1};
      std::vector<unsigned int> observed;
      std::ranges::transform(atomRange, std::back_inserter(observed),
                             degreeOf);
      CHECK(observed == expected);
    }
    {
      // atom degrees, reversed
      const std::vector<unsigned int> expected{1, 2, 1, 3, 1};
      std::vector<unsigned int> observed;
      std::ranges::transform(atomRange | std::views::reverse,
                             std::back_inserter(observed), degreeOf);
      CHECK(observed == expected);
    }
    {
      // bond orders, forward
      const std::vector<Bond::BondType> expected{Bond::SINGLE, Bond::DOUBLE,
                                                 Bond::SINGLE, Bond::DOUBLE};
      std::vector<Bond::BondType> observed;
      std::ranges::transform(bondRange, std::back_inserter(observed),
                             bondTypeOf);
      CHECK(observed == expected);
    }
    {
      // bond orders, reversed
      const std::vector<Bond::BondType> expected{Bond::DOUBLE, Bond::SINGLE,
                                                 Bond::DOUBLE, Bond::SINGLE};
      std::vector<Bond::BondType> observed;
      std::ranges::transform(bondRange | std::views::reverse,
                             std::back_inserter(observed), bondTypeOf);
      CHECK(observed == expected);
    }
  }

  SECTION("Neighbors") {
    // atom 1 is the branching atom: three neighbors, three bonds
    auto nbrRange = m->atomNeighbors(m->getAtomWithIdx(1));
    CHECK(std::ranges::distance(nbrRange) == 3);
    std::vector<unsigned int> nbrIdxs;
    std::ranges::transform(nbrRange, std::back_inserter(nbrIdxs), indexOf);
    const std::vector<unsigned int> expectedNbrs{0, 2, 3};
    CHECK(nbrIdxs == expectedNbrs);

    auto atomBondRange = m->atomBonds(m->getAtomWithIdx(1));
    CHECK(std::ranges::distance(atomBondRange) == 3);
    std::vector<unsigned int> bondIdxs;
    std::ranges::transform(atomBondRange, std::back_inserter(bondIdxs),
                           indexOf);
    const std::vector<unsigned int> expectedBonds{0, 1, 2};
    CHECK(bondIdxs == expectedBonds);
  }
}
// Exercises count_if plus filter/take view pipelines over the atom and bond
// ranges of a molecule.
TEST_CASE("algorithms") {
  // Assembled by hand; corresponds to the SMILES COC(F)C=C.
  auto m = std::make_unique<RWMol>();
  REQUIRE(m);
  for (const unsigned int atomicNum : {6u, 8u, 6u, 9u, 6u, 6u}) {
    m->addAtom(new Atom(atomicNum), true, true);
  }
  m->addBond(0, 1, Bond::SINGLE);
  m->addBond(1, 2, Bond::SINGLE);
  m->addBond(2, 3, Bond::SINGLE);
  m->addBond(4, 5, Bond::DOUBLE);
  m->addBond(2, 4, Bond::SINGLE);

  // projection shared by both sections
  const auto indexOf = [](const auto item) { return item->getIdx(); };

  SECTION("atom count_if, filter, and take") {
    const auto isCarbon = [](const auto atom) {
      return atom->getAtomicNum() == 6;
    };
    auto atomRange = m->atoms();

    const auto carbonCount = std::ranges::count_if(atomRange, isCarbon);
    CHECK(carbonCount == 4);

    // every carbon
    std::vector<unsigned int> picked;
    std::ranges::transform(atomRange | std::views::filter(isCarbon),
                           std::back_inserter(picked), indexOf);
    const std::vector<unsigned int> allCarbons{0, 2, 4, 5};
    CHECK(picked == allCarbons);

    // only the first two carbons
    picked.clear();
    std::ranges::transform(
        atomRange | std::views::filter(isCarbon) | std::views::take(2),
        std::back_inserter(picked), indexOf);
    const std::vector<unsigned int> firstTwoCarbons{0, 2};
    CHECK(picked == firstTwoCarbons);
  }

  SECTION("bond count_if, filter, and take") {
    const auto isSingle = [](const auto bond) {
      return bond->getBondType() == Bond::SINGLE;
    };
    auto bondRange = m->bonds();

    const auto singleCount = std::ranges::count_if(bondRange, isSingle);
    CHECK(singleCount == 4);

    // every single bond
    std::vector<unsigned int> picked;
    std::ranges::transform(bondRange | std::views::filter(isSingle),
                           std::back_inserter(picked), indexOf);
    const std::vector<unsigned int> allSingles{0, 1, 2, 4};
    CHECK(picked == allSingles);

    // only the first two single bonds
    picked.clear();
    std::ranges::transform(
        bondRange | std::views::filter(isSingle) | std::views::take(2),
        std::back_inserter(picked), indexOf);
    const std::vector<unsigned int> firstTwoSingles{0, 1};
    CHECK(picked == firstTwoSingles);
  }
}