mirror of
https://github.com/rdkit/rdkit.git
synced 2026-06-03 21:44:30 +08:00
DistViolationContribs optimization (#8208)
* DistViolationContribs optimization. * Update preFactor calculation to avoid numerical instability: use square of distance instead of squared distance directly, to prevent accumulated errors leading to CI test failures on Mac when d2 > c.ub2. * Optimize distance calculations by moving sqrt under condition branches to avoid unnecessary computations when squared distance bounds are not met.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
//
|
||||
// Copyright (C) 2024 Greg Landrum and other RDKit contributors
|
||||
// Copyright (C) 2024-2025 Greg Landrum and other RDKit contributors
|
||||
//
|
||||
// @@ All Rights Reserved @@
|
||||
// This file is part of the RDKit.
|
||||
@@ -19,8 +19,8 @@ DistViolationContribs::DistViolationContribs(ForceFields::ForceField *owner) {
|
||||
dp_forceField = owner;
|
||||
}
|
||||
|
||||
auto distance2 = [](const unsigned int idx1, const unsigned int idx2,
|
||||
const double *pos, const unsigned int dim) {
|
||||
inline double distance2(const unsigned int idx1, const unsigned int idx2,
|
||||
const double *pos, const unsigned int dim) {
|
||||
const auto *end1Coords = &(pos[dim * idx1]);
|
||||
const auto *end2Coords = &(pos[dim * idx2]);
|
||||
double d2 = 0.0;
|
||||
@@ -29,11 +29,13 @@ auto distance2 = [](const unsigned int idx1, const unsigned int idx2,
|
||||
d2 += d * d;
|
||||
}
|
||||
return d2;
|
||||
};
|
||||
auto distance = [](const unsigned int idx1, const unsigned int idx2,
|
||||
const double *pos, const unsigned int dim) {
|
||||
}
|
||||
|
||||
inline double distance(const unsigned int idx1, const unsigned int idx2,
|
||||
const double *pos, const unsigned int dim) {
|
||||
return sqrt(distance2(idx1, idx2, pos, dim));
|
||||
};
|
||||
}
|
||||
|
||||
double DistViolationContribs::getEnergy(double *pos) const {
|
||||
PRECONDITION(dp_forceField, "no owner");
|
||||
PRECONDITION(pos, "bad vector");
|
||||
@@ -42,10 +44,10 @@ double DistViolationContribs::getEnergy(double *pos) const {
|
||||
auto contrib = [&](const auto &c) {
|
||||
double d2 = distance2(c.idx1, c.idx2, pos, dp_forceField->dimension());
|
||||
double val = 0.0;
|
||||
if (d2 > c.ub * c.ub) {
|
||||
val = (d2 / (c.ub * c.ub)) - 1.0;
|
||||
} else if (d2 < c.lb * c.lb) {
|
||||
val = ((2 * c.lb * c.lb) / (c.lb * c.lb + d2)) - 1.0;
|
||||
if (d2 > c.ub2) {
|
||||
val = (d2 / (c.ub2)) - 1.0;
|
||||
} else if (d2 < c.lb2) {
|
||||
val = ((2 * c.lb2) / (c.lb2 + d2)) - 1.0;
|
||||
}
|
||||
if (val > 0.0) {
|
||||
accum += c.weight * val * val;
|
||||
@@ -64,16 +66,16 @@ void DistViolationContribs::getGrad(double *pos, double *grad) const {
|
||||
const unsigned int dim = this->dp_forceField->dimension();
|
||||
|
||||
auto contrib = [&](const auto &c) {
|
||||
double d = distance(c.idx1, c.idx2, pos, dp_forceField->dimension());
|
||||
double d2 = distance2(c.idx1, c.idx2, pos, dp_forceField->dimension());
|
||||
double d;
|
||||
double preFactor = 0.0;
|
||||
if (d > c.ub) {
|
||||
double u2 = c.ub * c.ub;
|
||||
preFactor = 4. * (((d * d) / u2) - 1.0) * (d / u2);
|
||||
} else if (d < c.lb) {
|
||||
double d2 = d * d;
|
||||
double l2 = c.lb * c.lb;
|
||||
double l2d2 = d2 + l2;
|
||||
preFactor = 8. * l2 * d * (1. - 2 * l2 / l2d2) / (l2d2 * l2d2);
|
||||
if (d2 > c.ub2) {
|
||||
d = sqrt(d2);
|
||||
preFactor = 4. * (((d * d) / c.ub2) - 1.0) * (d / c.ub2);
|
||||
} else if (d2 < c.lb2) {
|
||||
d = sqrt(d2);
|
||||
double l2d2 = d2 + c.lb2;
|
||||
preFactor = 8. * c.lb2 * d * (1. - 2 * c.lb2 / l2d2) / (l2d2 * l2d2);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
//
|
||||
// Copyright (C) 2024 Greg Landrum and other RDKit contributors
|
||||
// Copyright (C) 2024-2025 Greg Landrum and other RDKit contributors
|
||||
//
|
||||
// @@ All Rights Reserved @@
|
||||
// This file is part of the RDKit.
|
||||
@@ -17,14 +17,16 @@
|
||||
namespace DistGeom {
|
||||
|
||||
struct DistViolationContribsParams {
|
||||
unsigned int idx1{0}; //!< index of end1 in the ForceField's positions
|
||||
unsigned int idx2{0}; //!< index of end2 in the ForceField's positions
|
||||
double ub{1000.0}; //!< upper bound on the distance
|
||||
double lb{0.0}; //!< lower bound on the distance
|
||||
double weight{1.0}; //!< used to adjust relative contribution weights
|
||||
unsigned int idx1{0}; //!< index of end1 in the ForceField's positions
|
||||
unsigned int idx2{0}; //!< index of end2 in the ForceField's positions
|
||||
double ub{1000.0}; //!< upper bound on the distance
|
||||
double lb{0.0}; //!< lower bound on the distance
|
||||
double ub2{1000000.0}; //!< squared upper bound on the distance
|
||||
double lb2{0.0}; //!< squared lower bound on the distance
|
||||
double weight{1.0}; //!< used to adjust relative contribution weights
|
||||
DistViolationContribsParams(unsigned int i1, unsigned int i2, double u,
|
||||
double l, double w = 1.0)
|
||||
: idx1(i1), idx2(i2), ub(u), lb(l), weight(w) {};
|
||||
: idx1(i1), idx2(i2), ub(u), lb(l), ub2(u * u), lb2(l * l), weight(w) {};
|
||||
};
|
||||
//! A term to capture all violations of the upper and lower bounds by
|
||||
//! distance between two points
|
||||
|
||||
Reference in New Issue
Block a user