Create a function to extract some specified atoms from a ROMol as a new ROMol by creating new graph (#8742) (#8743)

* Create a function to extract some specified atoms from a ROMol as a new ROMol by creating new graph (#8742)

This adds a new api, `RDKit::MolOps::ExtractMolFragment`, to allow efficient
extractions of mol fragments from large mols. Compared to the approach where
we delete "unwanted" atoms/bonds from the input mol, this api is faster for
small mols (about 2x faster) and at least 3x faster for big mols
(was 10x faster for "CCC"*1000).

* clang-format

* review comments

* cleanup

* review comments

* fix build failure

---------

Co-authored-by: Your Name <you@example.com>
This commit is contained in:
Hussein Faara
2025-09-23 21:39:52 -07:00
committed by GitHub
parent d13b002d2d
commit 040bdb61c7
3 changed files with 471 additions and 2 deletions

View File

@@ -20,7 +20,18 @@
#include <boost/dynamic_bitset.hpp>
namespace RDKit {
namespace {
// Helper api to hold atom selection data for copyMolSubset
struct [[nodiscard]] SelectedAtomInfo {
boost::dynamic_bitset<> selected_atoms;
boost::dynamic_bitset<> selected_bonds;
std::unordered_map<unsigned int, unsigned int> atom_mapping;
std::unordered_map<unsigned int, unsigned int> bond_mapping;
};
} // namespace
namespace Subgraphs {
void getNbrsList(const ROMol &mol, bool useHs, INT_INT_VECT_MAP &nbrs) {
nbrs.clear();
int nAtoms = mol.getNumAtoms();
@@ -640,4 +651,187 @@ PATH_TYPE findAtomEnvironmentOfRadiusN(
return res;
}
static void copySelectedAtomsAndBonds(RWMol &extracted_mol,
const ROMol &reference_mol,
SelectedAtomInfo &selection_info) {
auto &[selected_atoms, selected_bonds, atom_mapping, bond_mapping] =
selection_info;
for (const auto &ref_atom : reference_mol.atoms()) {
if (!selected_atoms[ref_atom->getIdx()]) {
continue;
}
std::unique_ptr<Atom> extracted_atom{ref_atom->copy()};
static constinit bool updateLabel = true;
static constinit bool takeOwnership = true;
atom_mapping[ref_atom->getIdx()] = extracted_mol.addAtom(
extracted_atom.release(), updateLabel, takeOwnership);
}
for (const auto &ref_bond : reference_mol.bonds()) {
if (!selected_bonds[ref_bond->getIdx()]) {
continue;
}
std::unique_ptr<Bond> extracted_bond{ref_bond->copy()};
extracted_bond->setBeginAtomIdx(atom_mapping[ref_bond->getBeginAtomIdx()]);
extracted_bond->setEndAtomIdx(atom_mapping[ref_bond->getEndAtomIdx()]);
static constinit bool takeOwnership = true;
auto num_bonds =
extracted_mol.addBond(extracted_bond.release(), takeOwnership);
bond_mapping[ref_bond->getIdx()] = num_bonds - 1;
}
}
[[nodiscard]] static bool is_selected_sgroup(
const SubstanceGroup &sgroup, const SelectedAtomInfo &selection_info) {
auto is_selected_component = [](auto &indices, auto &selection_test) {
return indices.empty() ||
std::all_of(indices.begin(), indices.end(), selection_test);
};
// clang-format off
auto atom_test = [&](int idx) { return selection_info.selected_atoms[idx]; };
auto bond_test = [&](int idx) { return selection_info.selected_bonds[idx]; };
return is_selected_component(sgroup.getAtoms(), atom_test) &&
is_selected_component(sgroup.getBonds(), bond_test) &&
is_selected_component(sgroup.getParentAtoms(), atom_test);
// clang-format on
}
static void copySelectedSubstanceGroups(
RWMol &extracted_mol, const ROMol &reference_mol,
const SelectedAtomInfo &selection_info) {
auto update_indices = [](auto &sgroup, auto getter, auto setter,
auto &mapping) {
auto indices = getter(sgroup);
std::for_each(indices.begin(), indices.end(),
[&](auto &idx) { idx = mapping.at(idx); });
setter(sgroup, std::move(indices));
};
const auto &[selected_atoms, selected_bonds, atom_mapping, bond_mapping] =
selection_info;
for (const auto &sgroup : getSubstanceGroups(reference_mol)) {
if (!is_selected_sgroup(sgroup, selection_info)) {
continue;
}
SubstanceGroup extracted_sgroup(sgroup);
extracted_sgroup.setOwningMol(&extracted_mol);
update_indices(extracted_sgroup, std::mem_fn(&SubstanceGroup::getAtoms),
std::mem_fn(&SubstanceGroup::setAtoms), atom_mapping);
update_indices(extracted_sgroup,
std::mem_fn(&SubstanceGroup::getParentAtoms),
std::mem_fn(&SubstanceGroup::setParentAtoms), atom_mapping);
update_indices(extracted_sgroup, std::mem_fn(&SubstanceGroup::getBonds),
std::mem_fn(&SubstanceGroup::setBonds), bond_mapping);
addSubstanceGroup(extracted_mol, std::move(extracted_sgroup));
}
}
static void copySelectedStereoGroups(RWMol &extracted_mol,
const ROMol &reference_mol,
const SelectedAtomInfo &selection_info) {
auto is_selected_component = [](auto &objects, auto &selected_indices) {
return objects.empty() ||
std::all_of(objects.begin(), objects.end(), [&](auto &object) {
return selected_indices[object->getIdx()];
});
};
auto is_selected_stereo_group = [&](const auto &stereo_group) {
return is_selected_component(stereo_group.getAtoms(),
selection_info.selected_atoms) &&
is_selected_component(stereo_group.getBonds(),
selection_info.selected_bonds);
};
std::vector<Atom *> extracted_atoms(extracted_mol.getNumAtoms());
for (const auto &atom : extracted_mol.atoms()) {
extracted_atoms[atom->getIdx()] = atom;
}
std::vector<Bond *> extracted_bonds(extracted_mol.getNumBonds());
for (const auto &bond : extracted_mol.bonds()) {
extracted_bonds[bond->getIdx()] = bond;
}
const auto &[selected_atoms, selected_bonds, atom_mapping, bond_mapping] =
selection_info;
std::vector<StereoGroup> extracted_stereo_groups;
for (const auto &stereo_group : reference_mol.getStereoGroups()) {
if (!is_selected_stereo_group(stereo_group)) {
continue;
}
std::vector<Atom *> atoms;
for (const auto &atom : stereo_group.getAtoms()) {
atoms.push_back(extracted_atoms[atom_mapping.at(atom->getIdx())]);
}
std::vector<Bond *> bonds;
for (const auto &bond : stereo_group.getBonds()) {
bonds.push_back(extracted_bonds[bond_mapping.at(bond->getIdx())]);
}
extracted_stereo_groups.push_back({stereo_group.getGroupType(),
std::move(atoms), std::move(bonds),
stereo_group.getReadId()});
extracted_stereo_groups.back().setWriteId(stereo_group.getWriteId());
}
extracted_mol.setStereoGroups(std::move(extracted_stereo_groups));
}
[[nodiscard]] static SelectedAtomInfo getSelectedAtomInfo(
const ROMol &mol, const std::vector<unsigned int> &path,
SubsetMethod method) {
const auto num_atoms = mol.getNumAtoms();
SelectedAtomInfo selection_info{
.selected_atoms = boost::dynamic_bitset<>(num_atoms),
.selected_bonds = boost::dynamic_bitset<>(mol.getNumBonds()),
.atom_mapping = {},
.bond_mapping = {}};
if (method == SubsetMethod::BONDS_BETWEEN_ATOMS) {
auto &[selected_atoms, selected_bonds, atom_mapping, bond_mapping] =
selection_info;
for (const auto &atom_idx : path) {
if (atom_idx < num_atoms) {
selected_atoms[atom_idx] = true;
}
}
for (const auto &bond : mol.bonds()) {
if (selected_atoms[bond->getBeginAtomIdx()] &&
selected_atoms[bond->getEndAtomIdx()]) {
selected_bonds[bond->getIdx()] = true;
}
}
} else {
UNDER_CONSTRUCTION("not implemented");
}
return selection_info;
}
boost::shared_ptr<RWMol> copyMolSubset(const ROMol &mol,
const std::vector<unsigned int> &path,
SubsetMethod method, bool sanitize) {
auto selection_info = getSelectedAtomInfo(mol, path, method);
auto extracted_mol = std::make_unique<RWMol>();
copySelectedAtomsAndBonds(*extracted_mol, mol, selection_info);
copySelectedSubstanceGroups(*extracted_mol, mol, selection_info);
copySelectedStereoGroups(*extracted_mol, mol, selection_info);
if (sanitize) {
MolOps::sanitizeMol(*extracted_mol);
}
// NOTE: Bookmarks and coordinates are currently not copied
return extracted_mol;
}
} // namespace RDKit

View File

@@ -29,13 +29,19 @@
#ifndef RD_SUBGRAPHS_H
#define RD_SUBGRAPHS_H
#include <vector>
#include <boost/smart_ptr.hpp>
#include <list>
#include <map>
#include <unordered_map>
#include <vector>
#include <RDGeneral/BetterEnums.h>
namespace RDKit {
class ROMol;
class RWMol;
// NOTE: before replacing the defn of PATH_TYPE: be aware that
// we do occasionally use reverse iterators on these things, so
// replacing with a slist would probably be a bad idea.
@@ -151,6 +157,32 @@ RDKIT_SUBGRAPHS_EXPORT PATH_TYPE findAtomEnvironmentOfRadiusN(
bool useHs = false, bool enforceSize = true,
std::unordered_map<unsigned int, unsigned int> *atomMap = nullptr);
BETTER_ENUM_CLASS(SubsetMethod, unsigned int,
BONDS_BETWEEN_ATOMS,
BOND_PATH
);
//!
/*
* Helper api to extract a subgraph from an ROMol. Bonds, substance groups and
* stereo groups are only extracted to the subgraph if all participant entities
* are selected by the `path` parameter.
*
* @param mol starting mol
* @param path the indices of atoms or bonds to extract. If an index falls
* outside of the acceptable indices, it is ignored.
* @param method the method by which to extract this subgraph.
* @param sanitize whether to sanitize the extracted mol.
*
* NOTE: Bookmarks and coordinates are currently not copied
*
*/
RDKIT_SUBGRAPHS_EXPORT boost::shared_ptr<RDKit::RWMol>
copyMolSubset(const RDKit::ROMol& mol,
const std::vector<unsigned int>& path,
SubsetMethod method = SubsetMethod::BONDS_BETWEEN_ATOMS,
bool sanitize = true);
} // namespace RDKit
#endif

View File

@@ -58,4 +58,247 @@ TEST_CASE("shortestPathsOnly") {
CHECK(ps.size() == 1);
CHECK(ps[3].size() == 2);
}
}
}
// helper api to get test data for copyMolSubset
struct SelectedComponents {
std::vector<bool> selected_atoms;
std::vector<bool> selected_bonds;
};
// helper api to get test mol for copyMolSubset api.
[[nodiscard]] static std::unique_ptr<RDKit::RWMol> get_test_mol()
{
std::unique_ptr<RDKit::RWMol> mol{RDKit::SmilesToMol("CCCCCCCCCCCCCCC")};
for (auto& atom : mol->atoms()) {
atom->setProp("orig_idx", atom->getIdx());
}
for (auto& bond : mol->bonds()) {
bond->setProp("orig_idx", bond->getIdx());
}
return mol;
}
// Helper api to get the included atoms and bonds from test atom indices.
[[nodiscard]] static SelectedComponents
get_selected_components(::RDKit::RWMol& mol,
const std::vector<unsigned int>& atom_ids)
{
const auto num_atoms = mol.getNumAtoms();
std::vector<bool> selected_atoms(num_atoms);
for (auto& atom_idx : atom_ids) {
if (atom_idx < num_atoms) {
selected_atoms[atom_idx] = true;
}
}
std::vector<bool> selected_bonds(mol.getNumBonds());
for (auto& bond : mol.bonds()) {
if (selected_atoms[bond->getBeginAtomIdx()] &&
selected_atoms[bond->getEndAtomIdx()]) {
selected_bonds[bond->getIdx()] = true;
}
}
return {std::move(selected_atoms), std::move(selected_bonds)};
}
// This test makes sure we correctly extract atoms
TEST_CASE("test_extract_atoms", "[copyMolSubset]") {
auto selected_atoms = GENERATE(
// unique values
std::vector<unsigned int>{0, 2, 4, 6, 8, 10, 12},
// duplicate values
std::vector<unsigned int>{0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12},
// values outside of atom indices
std::vector<unsigned int>{0, 2, 4, 6, 8, 10, 12, 100, 200, 300}
);
std::vector<unsigned int> expected_atoms{0, 2, 4, 6, 8, 10, 12};
auto mol = get_test_mol();
auto extracted_mol = copyMolSubset(*mol, selected_atoms);
REQUIRE(extracted_mol->getNumAtoms() == expected_atoms.size());
std::vector<unsigned int> extracted_atoms;
for (auto& atom : extracted_mol->atoms()) {
extracted_atoms.push_back(atom->template getProp<unsigned int>("orig_idx"));
}
CHECK(extracted_atoms == expected_atoms);
}
// This test makes sure we correctly extract atoms
TEST_CASE("test_extract_bonds", "[copyMolSubset]")
{
auto test_mol = get_test_mol();
for (auto& bond : test_mol->bonds()) {
bond->setProp("test_prop", true);
}
for (auto& bond : test_mol->bonds()) {
auto begin_idx = bond->getBeginAtomIdx();
auto end_idx = bond->getEndAtomIdx();
auto m = copyMolSubset(*test_mol, {begin_idx, end_idx});
REQUIRE(m->getNumBonds() == 1);
CHECK(m->getBondWithIdx(0)->getProp<bool>("test_prop") == true);
CHECK(m->getNumAtoms() == 2);
CHECK(m->getAtomWithIdx(0)->getProp<unsigned int>("orig_idx") == begin_idx);
CHECK(m->getAtomWithIdx(1)->getProp<unsigned int>("orig_idx") == end_idx);
}
}
// This test makes sure we correctly extract substance groups
TEST_CASE("test_extract_substance_groups", "[copyMolSubset]") {
auto mol = get_test_mol();
::RDKit::SubstanceGroup sgroup{mol.get(), "COP"};
auto test_sgroup_atoms = GENERATE(
std::vector<unsigned int>{},
std::vector<unsigned int>{0, 1, 2, 3, 4},
std::vector<unsigned int>{9, 10, 11}
);
sgroup.setAtoms(test_sgroup_atoms);
auto test_sgroup_bonds = GENERATE(
std::vector<unsigned int>{},
std::vector<unsigned int>{0, 1, 2},
std::vector<unsigned int>{3, 4, 5}
);
sgroup.setBonds(test_sgroup_bonds);
auto test_sgroup_patoms = GENERATE(
std::vector<unsigned int>{},
std::vector<unsigned int>{3, 4},
std::vector<unsigned int>{5, 6}
);
sgroup.setParentAtoms(test_sgroup_patoms);
::RDKit::addSubstanceGroup(*mol, std::move(sgroup));
auto has_selected_components = [&](auto& components, auto& ref_bitset) {
return components.empty() ||
std::ranges::all_of(components, [&](auto& idx) {
return idx < ref_bitset.size() && ref_bitset[idx];
});
};
auto test_selected_atoms = GENERATE(
std::vector<unsigned int>{0, 1, 2, 3, 4, 5},
std::vector<unsigned int>{0, 2, 4, 6, 8, 10, 12},
std::vector<unsigned int>{0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12},
std::vector<unsigned int>{0, 2, 4, 6, 8, 10, 12, 100, 200, 300}
);
auto extracted_mol = copyMolSubset(*mol, test_selected_atoms);
auto [selected_atoms, selected_bonds] = get_selected_components(*mol, test_selected_atoms);
// sgroup should only be copied if all components are selected
auto flag = ::RDKit::getSubstanceGroups(*extracted_mol).size() == 1;
REQUIRE(flag ==
(has_selected_components(test_sgroup_atoms, selected_atoms) &&
has_selected_components(test_sgroup_patoms, selected_atoms) &&
has_selected_components(test_sgroup_bonds, selected_bonds)));
// now make sure we copied the correct components
if (flag) {
auto& extracted_sgroup = ::RDKit::getSubstanceGroups(*extracted_mol)[0];
for (auto& idx : extracted_sgroup.getAtoms()) {
auto atom = extracted_mol->getAtomWithIdx(idx);
CHECK(
selected_atoms[atom->template getProp<unsigned int>("orig_idx")] ==
true);
}
for (auto& idx : extracted_sgroup.getParentAtoms()) {
auto atom = extracted_mol->getAtomWithIdx(idx);
CHECK(
selected_atoms[atom->template getProp<unsigned int>("orig_idx")] ==
true);
}
for (auto& idx : extracted_sgroup.getBonds()) {
auto bond = extracted_mol->getBondWithIdx(idx);
CHECK(
selected_bonds[bond->template getProp<unsigned int>("orig_idx")] ==
true);
}
}
}
// This test makes sure we correctly extract stereo groups
TEST_CASE("test_extract_stereo_groups", "[copyMolSubset]") {
auto mol = get_test_mol();
auto test_stereo_group_atoms = GENERATE(
std::vector<unsigned int>{},
std::vector<unsigned int>{0, 1, 2, 3, 4},
std::vector<unsigned int>{9, 10, 11}
);
std::vector<::RDKit::Atom*> sg_atoms;
for (auto& idx : test_stereo_group_atoms) {
sg_atoms.push_back(mol->getAtomWithIdx(idx));
}
auto test_stereo_group_bonds = GENERATE(
std::vector<unsigned int>{},
std::vector<unsigned int>{0, 1, 2},
std::vector<unsigned int>{3, 4, 5}
);
std::vector<::RDKit::Bond*> sg_bonds;
for (auto& idx : test_stereo_group_bonds) {
sg_bonds.push_back(mol->getBondWithIdx(idx));
}
::RDKit::StereoGroup stereo_group{::RDKit::StereoGroupType::STEREO_ABSOLUTE,
std::move(sg_atoms), std::move(sg_bonds)};
mol->setStereoGroups({std::move(stereo_group)});
auto has_selected_components = [&](auto& components, auto& ref_bitset) {
return components.empty() ||
std::ranges::all_of(components, [&](auto& idx) {
return idx < ref_bitset.size() && ref_bitset[idx];
});
};
auto test_selected_atoms = GENERATE(
std::vector<unsigned int>{0, 1, 2, 3, 4, 5},
std::vector<unsigned int>{0, 2, 4, 6, 8, 10, 12},
std::vector<unsigned int>{0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12},
std::vector<unsigned int>{0, 2, 4, 6, 8, 10, 12, 100, 200, 300}
);
auto extracted_mol = copyMolSubset(*mol, test_selected_atoms);
auto [selected_atoms, selected_bonds] =
get_selected_components(*mol, test_selected_atoms);
// stereo group should only be copied if all components are selected
auto flag = extracted_mol->getStereoGroups().size() == 1;
REQUIRE(flag ==
(has_selected_components(test_stereo_group_atoms, selected_atoms) &&
has_selected_components(test_stereo_group_bonds, selected_bonds)));
// now make sure we copied the correct components
if (flag) {
auto& extracted_stereo_group = extracted_mol->getStereoGroups()[0];
for (auto& atom : extracted_stereo_group.getAtoms()) {
CHECK(
selected_atoms[atom->template getProp<int>("orig_idx")] ==
true);
}
for (auto& bond : extracted_stereo_group.getBonds()) {
CHECK(
selected_bonds[bond->template getProp<int>("orig_idx")] ==
true);
}
}
}