Small refactor before PR #8968 follow up (#9135)

* done

* remove unused arg

* restore parenthesis

* more refactoring

* make checks even stricter

* don't use a string_view
This commit is contained in:
Ricardo Rodriguez
2026-03-01 01:24:59 -05:00
committed by GitHub
parent 5f1bfa1f78
commit 46a4f59048
2 changed files with 156 additions and 121 deletions

View File

@@ -924,101 +924,72 @@ void canonicalDFSTraversal(ROMol &mol, int atomIdx, int inBondIdx,
}
// Works on the bonds around `fromAtom` (other than `refBond`) that
// canHaveDirection() accepts: their bondDirCounts entry is compared with the
// entry of `refBond`, and the direction counts of one or the other are
// decremented.
// The PRECONDITIONs below require non-null `refBond` and `fromAtom` that are
// owned by `mol`, and a bondDirCounts vector with at least mol.getNumBonds()
// entries.
// NOTE(review): this listing is a rendered diff, so lines of the previous and
// of the refactored implementation appear together (two parameter lists, an
// iterator-based loop and a range-based loop); it is not compilable as shown.
void clearBondDirs(ROMol &mol, Bond *refBond, const Atom *fromAtom,
UINT_VECT &bondDirCounts, UINT_VECT &atomDirCounts,
const UINT_VECT &) {
UINT_VECT &bondDirCounts, UINT_VECT &atomDirCounts) {
PRECONDITION(bondDirCounts.size() >= mol.getNumBonds(), "bad dirCount size");
PRECONDITION(refBond, "bad bond");
PRECONDITION(&refBond->getOwningMol() == &mol, "bad bond");
PRECONDITION(fromAtom, "bad atom");
PRECONDITION(&fromAtom->getOwningMol() == &mol, "bad bond");
ROMol::OEDGE_ITER beg, end;
boost::tie(beg, end) = mol.getAtomBonds(fromAtom);
bool nbrPossible = false, adjusted = false;
while (beg != end) {
Bond *oBond = mol[*beg];
// Decrements the bondDirCounts entry of `bond`; when that entry reaches zero
// the bond direction is reset to Bond::NONE and the atomDirCounts entries of
// both end atoms are decremented as well.
auto clearDirection = [&atomDirCounts, &bondDirCounts](Bond *bond) {
--bondDirCounts[bond->getIdx()];
if (!bondDirCounts[bond->getIdx()]) {
bond->setBondDir(Bond::NONE);
--atomDirCounts[bond->getBeginAtomIdx()];
--atomDirCounts[bond->getEndAtomIdx()];
}
};
for (auto oBond : mol.atomBonds(fromAtom)) {
if (oBond != refBond && canHaveDirection(*oBond)) {
nbrPossible = true;
// the neighbor is only adjusted when its count is at least that of refBond
// and neither of its atoms has an atomDirCounts entry of exactly 1
if ((bondDirCounts[oBond->getIdx()] >=
bondDirCounts[refBond->getIdx()]) &&
atomDirCounts[oBond->getBeginAtomIdx()] != 1 &&
atomDirCounts[oBond->getEndAtomIdx()] != 1) {
adjusted = true;
bondDirCounts[oBond->getIdx()] -= 1;
if (!bondDirCounts[oBond->getIdx()]) {
// no one is setting the direction here:
oBond->setBondDir(Bond::NONE);
atomDirCounts[oBond->getBeginAtomIdx()] -= 1;
atomDirCounts[oBond->getEndAtomIdx()] -= 1;
// std::cerr<<"ob:"<<oBond->getIdx()<<" ";
}
clearDirection(oBond);
} else if (atomDirCounts[refBond->getBeginAtomIdx()] != 1 &&
atomDirCounts[refBond->getEndAtomIdx()] != 1) {
// we found a neighbor that could have directionality set,
// but it had a lower bondDirCount than us, so we must
// need to be adjusted:
clearDirection(refBond);
}
}
beg++;
}
if (nbrPossible && !adjusted &&
atomDirCounts[refBond->getBeginAtomIdx()] != 1 &&
atomDirCounts[refBond->getEndAtomIdx()] != 1) {
// we found a neighbor that could have directionality set,
// but it had a lower bondDirCount than us, so we must
// need to be adjusted:
bondDirCounts[refBond->getIdx()] -= 1;
if (!bondDirCounts[refBond->getIdx()]) {
refBond->setBondDir(Bond::NONE);
atomDirCounts[refBond->getBeginAtomIdx()] -= 1;
atomDirCounts[refBond->getEndAtomIdx()] -= 1;
break;
}
}
}
// Walks `molStack` and handles every MOL_STACK_BOND entry: a bond that
// canHaveDirection() accepts and whose bondDirCounts entry is non-zero is
// passed to clearBondDirs() from each of its atoms that also carries a
// stereo double bond; any other bond that still has a direction set gets it
// reset to Bond::NONE.
// Requires bondDirCounts to hold at least mol.getNumBonds() entries.
// NOTE(review): this listing is a rendered diff, so lines of the previous and
// of the refactored implementation appear together (two parameter lists, the
// old iterator loops and the new lambda-based form); it is not compilable as
// shown.
void removeRedundantBondDirSpecs(ROMol &mol, MolStack &molStack,
UINT_VECT &bondDirCounts,
UINT_VECT &atomDirCounts,
const UINT_VECT &bondVisitOrders) {
UINT_VECT &atomDirCounts) {
PRECONDITION(bondDirCounts.size() >= mol.getNumBonds(), "bad dirCount size");
// Calls clearBondDirs() for `tBond` from `atom` at most once, and only when
// `atom` has another bond that is a double bond with stereo > STEREOANY.
auto clearBondDirsFromAtom = [&mol, &bondDirCounts, &atomDirCounts](
Bond *tBond, const Atom *atom) {
for (auto bond : mol.atomBonds(atom)) {
if (bond != tBond && bond->getBondType() == Bond::DOUBLE &&
bond->getStereo() > Bond::STEREOANY) {
clearBondDirs(mol, tBond, atom, bondDirCounts, atomDirCounts);
return;
}
}
};
// find bonds that have directions indicated that are redundant:
for (auto &msI : molStack) {
if (msI.type == MOL_STACK_BOND) {
Bond *tBond = msI.obj.bond;
const Atom *canonBeginAtom = mol.getAtomWithIdx(msI.number);
const Atom *canonEndAtom =
mol.getAtomWithIdx(tBond->getOtherAtomIdx(msI.number));
if (canHaveDirection(*tBond) && bondDirCounts[tBond->getIdx()] >= 1) {
// start by finding the double bond that sets tBond's direction:
const Atom *dblBondAtom = nullptr;
ROMol::OEDGE_ITER beg, end;
boost::tie(beg, end) = mol.getAtomBonds(canonBeginAtom);
while (beg != end) {
if (mol[*beg] != tBond && mol[*beg]->getBondType() == Bond::DOUBLE &&
mol[*beg]->getStereo() > Bond::STEREOANY) {
dblBondAtom =
canonBeginAtom; // tBond->getOtherAtom(canonBeginAtom);
break;
}
beg++;
}
if (dblBondAtom != nullptr) {
clearBondDirs(mol, tBond, dblBondAtom, bondDirCounts, atomDirCounts,
bondVisitOrders);
}
dblBondAtom = nullptr;
boost::tie(beg, end) = mol.getAtomBonds(canonEndAtom);
while (beg != end) {
if (mol[*beg] != tBond && mol[*beg]->getBondType() == Bond::DOUBLE &&
mol[*beg]->getStereo() > Bond::STEREOANY) {
dblBondAtom = canonEndAtom; // tBond->getOtherAtom(canonEndAtom);
break;
}
beg++;
}
if (dblBondAtom != nullptr) {
clearBondDirs(mol, tBond, dblBondAtom, bondDirCounts, atomDirCounts,
bondVisitOrders);
}
} else if (tBond->getBondDir() != Bond::NONE) {
// we aren't supposed to have a direction set, but we do:
tBond->setBondDir(Bond::NONE);
}
// entries that are not bonds are skipped
if (msI.type != MOL_STACK_BOND) {
continue;
}
Bond *tBond = msI.obj.bond;
// msI.number is used as the index of one atom of tBond; the atom at the
// other end is looked up through getOtherAtomIdx()
const Atom *canonBeginAtom = mol.getAtomWithIdx(msI.number);
const Atom *canonEndAtom =
mol.getAtomWithIdx(tBond->getOtherAtomIdx(msI.number));
if (canHaveDirection(*tBond) && bondDirCounts[tBond->getIdx()]) {
clearBondDirsFromAtom(tBond, canonBeginAtom);
clearBondDirsFromAtom(tBond, canonEndAtom);
} else if (tBond->getBondDir() != Bond::NONE) {
// we aren't supposed to have a direction set, but we do:
tBond->setBondDir(Bond::NONE);
}
}
}
@@ -1271,7 +1242,7 @@ void canonicalizeFragment(ROMol &mol, int atomIdx,
}
}
Canon::removeRedundantBondDirSpecs(mol, molStack, bondDirCounts,
atomDirCounts, bondVisitOrders);
atomDirCounts);
}
void canonicalizeEnhancedStereo(ROMol &mol,

View File

@@ -1160,6 +1160,114 @@ TEST_CASE("allow disabling ring stereo in ranking") {
CHECK(res1[6] == res1[7]);
}
// Round-trips `smiles` through the SMILES parser and writer and checks
// whether the output is stable.
//
// The input is parsed, every atom gets its index stored as an atom label, and
// the molecule is written as CXSMILES (atom labels only). That CXSMILES is
// parsed and written again, and the plain-SMILES parts of both outputs are
// compared.
//
//   smiles       input SMILES; parsing it must succeed (REQUIRE)
//   shouldMatch  true: both SMILES must be equal and the stereo features must
//                be unchanged after the first and the second round trip;
//                false: the two SMILES are expected to differ
static void checkSmilesRoundtrip(const std::string &smiles,
                                 bool shouldMatch = true) {
  // Summarizes the stereo information of a molecule so that it can be
  // compared between round trips.
  auto getFeatures = [](ROMol &m) {
  // enable this for development only (it's SLOW): make the test stricter
  // by comparing CIP codes instead of just counting features
#if 0
    CIPLabeler::assignCIPLabels(m);
    std::vector<std::string> labels;
    for (const auto atom : m.atoms()) {
      std::string tag;
      std::string cip;
      if (atom->getPropIfPresent<std::string>(common_properties::_CIPCode,
                                              cip)) {
        atom->getPropIfPresent(common_properties::atomLabel, tag);
        labels.push_back(cip + "_" + tag);
      }
    }
    for (const auto bond : m.bonds()) {
      std::string cip;
      if (bond->getPropIfPresent<std::string>(common_properties::_CIPCode,
                                              cip)) {
        auto atom1 = bond->getBeginAtom();
        unsigned int idx1 = std::numeric_limits<unsigned int>::max();
        atom1->getPropIfPresent(common_properties::atomLabel, idx1);
        auto atom2 = bond->getEndAtom();
        unsigned int idx2 = std::numeric_limits<unsigned int>::max();
        atom2->getPropIfPresent(common_properties::atomLabel, idx2);
        if (idx1 > idx2) {
          std::swap(idx1, idx2);
        }
        labels.push_back(cip + "_" + std::to_string(idx1) + "_" +
                         std::to_string(idx2));
      }
    }
    std::sort(labels.begin(), labels.end());
    return labels;
#else
    // count atoms with a chiral tag and bonds with stereo above STEREOANY
    unsigned int nChiralCenters = 0;
    for (const auto atom : m.atoms()) {
      if (atom->getChiralTag() != Atom::ChiralType::CHI_UNSPECIFIED) {
        ++nChiralCenters;
      }
    }
    unsigned int nDoubleBondStereo = 0;
    for (const auto bond : m.bonds()) {
      if (bond->getStereo() > Bond::STEREOANY) {
        ++nDoubleBondStereo;
      }
    }
    return std::make_pair(nChiralCenters, nDoubleBondStereo);
#endif
  };

  SmilesWriteParams ps;
  auto fields = SmilesWrite::CXSmilesFields::CX_ATOM_LABELS;
  // Returns {plain SMILES, full CXSMILES} for a molecule.
  auto getStrings = [&ps, &fields](const ROMol &m) {
    const auto cxsmiles = MolToCXSmiles(m, ps, fields);
    // The plain SMILES is everything in front of the first space. substr()
    // clamps npos, so output without a CX extension yields the whole string
    // instead of handing an invalid length to the std::string constructor.
    const auto plainSmiles = cxsmiles.substr(0, cxsmiles.find(' '));
    return std::make_pair(plainSmiles, cxsmiles);
  };

  // pre-canonicalize SMILES: the inputs get outdated when
  // we make changes to the canonicalization algorithm
  auto m1 = v2::SmilesParse::MolFromSmiles(smiles);
  REQUIRE(m1);

  // label every atom with its index so the label travels with the CXSMILES
  for (auto atom : m1->atoms()) {
    atom->setProp(common_properties::atomLabel, atom->getIdx());
  }

  const auto [firstSmiles, firstCxsmiles] = getStrings(*m1);

  // Get the stereo features after the SMILES roundtrip,
  // so that assigning labels can't have any influence
  // on the SMILES
  const auto refFeatures = getFeatures(*m1);

  auto m2 = v2::SmilesParse::MolFromSmiles(firstCxsmiles);
  REQUIRE(m2);

  const auto [secondSmiles, secondCxsmiles] = getStrings(*m2);

  if (shouldMatch) {
    CHECK(firstSmiles == secondSmiles);

    // If the stereo labels don't match after round-tripping, something is wrong
    CHECK(refFeatures == getFeatures(*m2));

    // Check the second roundtrip too
    auto m3 = v2::SmilesParse::MolFromSmiles(secondCxsmiles);
    REQUIRE(m3);
    CHECK(refFeatures == getFeatures(*m3));
  } else {
    CHECK(firstSmiles != secondSmiles);
  }
}
TEST_CASE("Canonicalization issues watch (see GitHub Issue #8775)") {
// This is a check about the state of things with canonicalization.
// The "samples" below initially come from the list compiled in GitHub
@@ -1344,23 +1452,6 @@ TEST_CASE("Canonicalization issues watch (see GitHub Issue #8775)") {
false}, // #8965
};
auto count_features = [](RWMol m) {
unsigned int nChiralCenters = 0;
for (const auto atom : m.atoms()) {
if (atom->getChiralTag() != Atom::ChiralType::CHI_UNSPECIFIED) {
++nChiralCenters;
}
}
unsigned int nDoubleBondStereo = 0;
for (const auto bond : m.bonds()) {
if (bond->getStereo() > Bond::STEREOANY) {
++nDoubleBondStereo;
}
}
return std::make_pair(nChiralCenters, nDoubleBondStereo);
};
const auto &[smiles, legacyState, modernState] =
GENERATE_REF(values(samples));
auto usingLegacyStereo = GENERATE(false, true);
@@ -1368,34 +1459,7 @@ TEST_CASE("Canonicalization issues watch (see GitHub Issue #8775)") {
UseLegacyStereoPerceptionFixture useLegacy(usingLegacyStereo);
// pre-canonicalize SMILES: the inputs get outdated when
// we make changes to the canonicalization algorithm
auto m1 = v2::SmilesParse::MolFromSmiles(smiles);
REQUIRE(m1);
const auto firstRoundtrip = MolToSmiles(*m1);
// Get the stereo features after the SMILES roundtrip,
// so that assigning labels can't have any influence
// on the SMILES
const auto refFeatures = count_features(*m1);
auto m2 = v2::SmilesParse::MolFromSmiles(firstRoundtrip);
REQUIRE(m2);
const auto secondRoundtrip = MolToSmiles(*m2);
auto shouldMatch = usingLegacyStereo ? legacyState : modernState;
if (shouldMatch) {
CHECK(firstRoundtrip == secondRoundtrip);
// If the stereo labels don't match after round-tripping, something is wrong
CHECK(refFeatures == count_features(*m2));
// Check the second roundtrip too
auto m3 = v2::SmilesParse::MolFromSmiles(secondRoundtrip);
REQUIRE(m3);
CHECK(refFeatures == count_features(*m3));
} else {
CHECK(firstRoundtrip != secondRoundtrip);
}
}
checkSmilesRoundtrip(smiles, shouldMatch);
}