DistViolationContribs optimization (#8208)

* DistViolationContribs optimization.

* Update preFactor calculation to avoid numerical instability: use square of distance instead of squared distance directly, to prevent accumulated errors leading to CI test failures on Mac when d2 > c.ub2.

* Optimize distance calculations by moving sqrt under condition branches to avoid unnecessary computations when squared distance bounds are not met.
This commit is contained in:
EvaSnow
2025-01-29 19:58:53 -08:00
committed by greg landrum
parent dbaa50f240
commit b0a1d6fe50
2 changed files with 31 additions and 27 deletions

View File

@@ -1,5 +1,5 @@
//
// Copyright (C) 2024 Greg Landrum and other RDKit contributors
// Copyright (C) 2024-2025 Greg Landrum and other RDKit contributors
//
// @@ All Rights Reserved @@
// This file is part of the RDKit.
@@ -19,8 +19,8 @@ DistViolationContribs::DistViolationContribs(ForceFields::ForceField *owner) {
dp_forceField = owner;
}
auto distance2 = [](const unsigned int idx1, const unsigned int idx2,
const double *pos, const unsigned int dim) {
inline double distance2(const unsigned int idx1, const unsigned int idx2,
const double *pos, const unsigned int dim) {
const auto *end1Coords = &(pos[dim * idx1]);
const auto *end2Coords = &(pos[dim * idx2]);
double d2 = 0.0;
@@ -29,11 +29,13 @@ auto distance2 = [](const unsigned int idx1, const unsigned int idx2,
d2 += d * d;
}
return d2;
};
auto distance = [](const unsigned int idx1, const unsigned int idx2,
const double *pos, const unsigned int dim) {
}
inline double distance(const unsigned int idx1, const unsigned int idx2,
const double *pos, const unsigned int dim) {
return sqrt(distance2(idx1, idx2, pos, dim));
};
}
double DistViolationContribs::getEnergy(double *pos) const {
PRECONDITION(dp_forceField, "no owner");
PRECONDITION(pos, "bad vector");
@@ -42,10 +44,10 @@ double DistViolationContribs::getEnergy(double *pos) const {
auto contrib = [&](const auto &c) {
double d2 = distance2(c.idx1, c.idx2, pos, dp_forceField->dimension());
double val = 0.0;
if (d2 > c.ub * c.ub) {
val = (d2 / (c.ub * c.ub)) - 1.0;
} else if (d2 < c.lb * c.lb) {
val = ((2 * c.lb * c.lb) / (c.lb * c.lb + d2)) - 1.0;
if (d2 > c.ub2) {
val = (d2 / (c.ub2)) - 1.0;
} else if (d2 < c.lb2) {
val = ((2 * c.lb2) / (c.lb2 + d2)) - 1.0;
}
if (val > 0.0) {
accum += c.weight * val * val;
@@ -64,16 +66,16 @@ void DistViolationContribs::getGrad(double *pos, double *grad) const {
const unsigned int dim = this->dp_forceField->dimension();
auto contrib = [&](const auto &c) {
double d = distance(c.idx1, c.idx2, pos, dp_forceField->dimension());
double d2 = distance2(c.idx1, c.idx2, pos, dp_forceField->dimension());
double d;
double preFactor = 0.0;
if (d > c.ub) {
double u2 = c.ub * c.ub;
preFactor = 4. * (((d * d) / u2) - 1.0) * (d / u2);
} else if (d < c.lb) {
double d2 = d * d;
double l2 = c.lb * c.lb;
double l2d2 = d2 + l2;
preFactor = 8. * l2 * d * (1. - 2 * l2 / l2d2) / (l2d2 * l2d2);
if (d2 > c.ub2) {
d = sqrt(d2);
preFactor = 4. * (((d * d) / c.ub2) - 1.0) * (d / c.ub2);
} else if (d2 < c.lb2) {
d = sqrt(d2);
double l2d2 = d2 + c.lb2;
preFactor = 8. * c.lb2 * d * (1. - 2 * c.lb2 / l2d2) / (l2d2 * l2d2);
} else {
return;
}

View File

@@ -1,5 +1,5 @@
//
// Copyright (C) 2024 Greg Landrum and other RDKit contributors
// Copyright (C) 2024-2025 Greg Landrum and other RDKit contributors
//
// @@ All Rights Reserved @@
// This file is part of the RDKit.
@@ -17,14 +17,16 @@
namespace DistGeom {
struct DistViolationContribsParams {
unsigned int idx1{0}; //!< index of end1 in the ForceField's positions
unsigned int idx2{0}; //!< index of end2 in the ForceField's positions
double ub{1000.0}; //!< upper bound on the distance
double lb{0.0}; //!< lower bound on the distance
double weight{1.0}; //!< used to adjust relative contribution weights
unsigned int idx1{0}; //!< index of end1 in the ForceField's positions
unsigned int idx2{0}; //!< index of end2 in the ForceField's positions
double ub{1000.0}; //!< upper bound on the distance
double lb{0.0}; //!< lower bound on the distance
double ub2{1000000.0}; //!< squared upper bound on the distance
double lb2{0.0}; //!< squared lower bound on the distance
double weight{1.0}; //!< used to adjust relative contribution weights
DistViolationContribsParams(unsigned int i1, unsigned int i2, double u,
double l, double w = 1.0)
: idx1(i1), idx2(i2), ub(u), lb(l), weight(w) {};
: idx1(i1), idx2(i2), ub(u), lb(l), ub2(u * u), lb2(l * l), weight(w) {};
};
//! A term to capture all violations of the upper and lower bounds by
//! distance between two points