From b0a1d6fe50ee6b7cf3038cadf366d605b96551c3 Mon Sep 17 00:00:00 2001 From: EvaSnow <10635420+evasnow1992@users.noreply.github.com> Date: Wed, 29 Jan 2025 19:58:53 -0800 Subject: [PATCH] DistViolationContribs optimization (#8208) * DistViolationContribs optimization. * Update preFactor calculation to avoid numerical instability: use square of distance instead of squared distance directly, to prevent accumulated errors leading to CI test failures on Mac when d2 > c.ub2. * Optimize distance calculations by moving sqrt under condition branches to avoid unnecessary computations when squared distance bounds are not met. --- Code/DistGeom/DistViolationContribs.cpp | 42 +++++++++++++------------ Code/DistGeom/DistViolationContribs.h | 16 +++++----- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/Code/DistGeom/DistViolationContribs.cpp b/Code/DistGeom/DistViolationContribs.cpp index 513efa9e3..5ccdefcdc 100644 --- a/Code/DistGeom/DistViolationContribs.cpp +++ b/Code/DistGeom/DistViolationContribs.cpp @@ -1,5 +1,5 @@ // -// Copyright (C) 2024 Greg Landrum and other RDKit contributors +// Copyright (C) 2024-2025 Greg Landrum and other RDKit contributors // // @@ All Rights Reserved @@ // This file is part of the RDKit. @@ -19,8 +19,8 @@ DistViolationContribs::DistViolationContribs(ForceFields::ForceField *owner) { dp_forceField = owner; } -auto distance2 = [](const unsigned int idx1, const unsigned int idx2, - const double *pos, const unsigned int dim) { +inline double distance2(const unsigned int idx1, const unsigned int idx2, + const double *pos, const unsigned int dim) { const auto *end1Coords = &(pos[dim * idx1]); const auto *end2Coords = &(pos[dim * idx2]); double d2 = 0.0; @@ -29,11 +29,13 @@ auto distance2 = [](const unsigned int idx1, const unsigned int idx2, d2 += d * d; } return d2; -}; -auto distance = [](const unsigned int idx1, const unsigned int idx2, - const double *pos, const unsigned int dim) { +} + +inline double distance(const unsigned int idx1, const unsigned int idx2, + const double *pos, const unsigned int dim) { return sqrt(distance2(idx1, idx2, pos, dim)); -}; +} + double DistViolationContribs::getEnergy(double *pos) const { PRECONDITION(dp_forceField, "no owner"); PRECONDITION(pos, "bad vector"); @@ -42,10 +44,10 @@ double DistViolationContribs::getEnergy(double *pos) const { auto contrib = [&](const auto &c) { double d2 = distance2(c.idx1, c.idx2, pos, dp_forceField->dimension()); double val = 0.0; - if (d2 > c.ub * c.ub) { - val = (d2 / (c.ub * c.ub)) - 1.0; - } else if (d2 < c.lb * c.lb) { - val = ((2 * c.lb * c.lb) / (c.lb * c.lb + d2)) - 1.0; + if (d2 > c.ub2) { + val = (d2 / (c.ub2)) - 1.0; + } else if (d2 < c.lb2) { + val = ((2 * c.lb2) / (c.lb2 + d2)) - 1.0; } if (val > 0.0) { accum += c.weight * val * val; @@ -64,16 +66,16 @@ void DistViolationContribs::getGrad(double *pos, double *grad) const { const unsigned int dim = this->dp_forceField->dimension(); auto contrib = [&](const auto &c) { - double d = distance(c.idx1, c.idx2, pos, dp_forceField->dimension()); + double d2 = distance2(c.idx1, c.idx2, pos, dp_forceField->dimension()); + double d; double preFactor = 0.0; - if (d > c.ub) { - double u2 = c.ub * c.ub; - preFactor = 4. * (((d * d) / u2) - 1.0) * (d / u2); - } else if (d < c.lb) { - double d2 = d * d; - double l2 = c.lb * c.lb; - double l2d2 = d2 + l2; - preFactor = 8. * l2 * d * (1. - 2 * l2 / l2d2) / (l2d2 * l2d2); + if (d2 > c.ub2) { + d = sqrt(d2); + preFactor = 4. * (((d * d) / c.ub2) - 1.0) * (d / c.ub2); + } else if (d2 < c.lb2) { + d = sqrt(d2); + double l2d2 = d2 + c.lb2; + preFactor = 8. * c.lb2 * d * (1. - 2 * c.lb2 / l2d2) / (l2d2 * l2d2); } else { return; } diff --git a/Code/DistGeom/DistViolationContribs.h b/Code/DistGeom/DistViolationContribs.h index 9c045fc76..4295221f6 100644 --- a/Code/DistGeom/DistViolationContribs.h +++ b/Code/DistGeom/DistViolationContribs.h @@ -1,5 +1,5 @@ // -// Copyright (C) 2024 Greg Landrum and other RDKit contributors +// Copyright (C) 2024-2025 Greg Landrum and other RDKit contributors // // @@ All Rights Reserved @@ // This file is part of the RDKit. @@ -17,14 +17,16 @@ namespace DistGeom { struct DistViolationContribsParams { - unsigned int idx1{0}; //!< index of end1 in the ForceField's positions - unsigned int idx2{0}; //!< index of end2 in the ForceField's positions - double ub{1000.0}; //!< upper bound on the distance - double lb{0.0}; //!< lower bound on the distance - double weight{1.0}; //!< used to adjust relative contribution weights + unsigned int idx1{0}; //!< index of end1 in the ForceField's positions + unsigned int idx2{0}; //!< index of end2 in the ForceField's positions + double ub{1000.0}; //!< upper bound on the distance + double lb{0.0}; //!< lower bound on the distance + double ub2{1000000.0}; //!< squared upper bound on the distance + double lb2{0.0}; //!< squared lower bound on the distance + double weight{1.0}; //!< used to adjust relative contribution weights DistViolationContribsParams(unsigned int i1, unsigned int i2, double u, double l, double w = 1.0) - : idx1(i1), idx2(i2), ub(u), lb(l), weight(w) {}; + : idx1(i1), idx2(i2), ub(u), lb(l), ub2(u * u), lb2(l * l), weight(w) {}; }; //! A term to capture all violations of the upper and lower bounds by //! distance between two points