mirror of
https://github.com/rdk/p2rank.git
synced 2026-06-04 12:44:24 +08:00
Add kdtree_implementation param and fix quickselect duplicate-key hang
Add runtime parameter to switch between KdTree3D (default) and v1 AtomKdTree. Fix O(N²) quickselect degeneration on duplicate coordinates by adding post-partition equal-range scan.
This commit is contained in:
@@ -1,42 +1,56 @@
|
||||
package cz.siret.prank.geom.kdtree.v2
|
||||
|
||||
import cz.siret.prank.geom.Atoms
|
||||
import cz.siret.prank.geom.kdtree.AtomKdTree as AtomKdTreeV1
|
||||
import cz.siret.prank.program.params.Params
|
||||
import groovy.transform.CompileStatic
|
||||
import org.biojava.nbio.structure.Atom
|
||||
|
||||
/**
|
||||
* Drop-in replacement for kdtree.AtomKdTree, delegating to immutable KdTree3D.
|
||||
* Unified KdTree wrapper that delegates to either v1 (Rednaxela) or v2 (KdTree3D)
|
||||
* based on the kdtree_version parameter.
|
||||
*
|
||||
* Same class name and method signatures as the v1 AtomKdTree — only the import
|
||||
* in Atoms.java needs to change. Groovy callers accessing via Atoms.getKdTree()
|
||||
* are transparent (duck typing).
|
||||
* Same method signatures as the original v1 AtomKdTree — Groovy callers accessing
|
||||
* via Atoms.getKdTree() are transparent (duck typing).
|
||||
*
|
||||
* Key difference from v1: no add()/addAll() — tree is immutable.
|
||||
* Atoms.add() invalidates the tree; it gets rebuilt lazily via withKdTree().
|
||||
*
|
||||
* Uses getX()/getY()/getZ() instead of .coords — eliminates new double[3] allocation per call.
|
||||
* v2 (default): immutable KdTree3D — no add()/addAll(), thread-safe, SoA layout.
|
||||
* v1: original Rednaxela KdTree — mutable, generic N-dimensional.
|
||||
*/
|
||||
@CompileStatic
|
||||
class AtomKdTree {
|
||||
|
||||
// Exactly one of these is non-null, depending on the selected implementation.
|
||||
private final KdTree3D tree
|
||||
private final AtomKdTreeV1 treeV1
|
||||
|
||||
private AtomKdTree(KdTree3D tree) {
|
||||
this.tree = tree
|
||||
this.treeV1 = null
|
||||
}
|
||||
|
||||
private AtomKdTree(AtomKdTreeV1 treeV1) {
|
||||
this.tree = null
|
||||
this.treeV1 = treeV1
|
||||
}
|
||||
|
||||
static AtomKdTree build(Atoms atoms) {
|
||||
if ("AtomKdTree" == Params.INSTANCE.kdtree_implementation) {
|
||||
return new AtomKdTree(AtomKdTreeV1.build(atoms))
|
||||
}
|
||||
return new AtomKdTree(KdTree3D.build(atoms))
|
||||
}
|
||||
|
||||
int size() {
|
||||
return tree.size()
|
||||
return tree != null ? tree.size() : treeV1.size()
|
||||
}
|
||||
|
||||
// --- Single nearest neighbor ---
|
||||
|
||||
Atom findNearest(Atom a) {
|
||||
return tree.findNearest(a.getX(), a.getY(), a.getZ())
|
||||
if (tree != null) {
|
||||
return tree.findNearest(a.getX(), a.getY(), a.getZ())
|
||||
}
|
||||
return treeV1.findNearest(a)
|
||||
}
|
||||
|
||||
double nearestDist(Atom a) {
|
||||
@@ -44,18 +58,20 @@ class AtomKdTree {
|
||||
}
|
||||
|
||||
double nearestSqrDist(Atom a) {
|
||||
return tree.nearestSqrDist(a.getX(), a.getY(), a.getZ())
|
||||
if (tree != null) {
|
||||
return tree.nearestSqrDist(a.getX(), a.getY(), a.getZ())
|
||||
}
|
||||
return treeV1.nearestSqrDist(a)
|
||||
}
|
||||
|
||||
// --- Nearest different (excluding identity-equal atom) ---
|
||||
|
||||
/**
|
||||
* Find nearest atom that is not identity-equal to a.
|
||||
* Uses k-NN with k=2 and filters. Returns null if no different atom exists.
|
||||
*/
|
||||
Atom findNearestDifferent(Atom a) {
|
||||
KdTree3D.NNEntry entry = singleNearestDifferent(a)
|
||||
return entry?.atom()
|
||||
if (tree != null) {
|
||||
KdTree3D.NNEntry entry = singleNearestDifferent(a)
|
||||
return entry?.atom()
|
||||
}
|
||||
return treeV1.findNearestDifferent(a)
|
||||
}
|
||||
|
||||
double nearestDifferentDist(Atom a) {
|
||||
@@ -63,19 +79,17 @@ class AtomKdTree {
|
||||
}
|
||||
|
||||
double nearestDifferentSqrDist(Atom a) {
|
||||
KdTree3D.NNEntry entry = singleNearestDifferent(a)
|
||||
return entry != null ? entry.sqrDist() : Double.NaN
|
||||
if (tree != null) {
|
||||
KdTree3D.NNEntry entry = singleNearestDifferent(a)
|
||||
return entry != null ? entry.sqrDist() : Double.NaN
|
||||
}
|
||||
return treeV1.nearestDifferentSqrDist(a)
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal: find nearest non-self entry.
|
||||
* Requests k=2 neighbors and picks the one that isn't identity-equal to a.
|
||||
* If both are identity-equal (duplicate points), returns null.
|
||||
*/
|
||||
private KdTree3D.NNEntry singleNearestDifferent(Atom a) {
|
||||
List<KdTree3D.NNEntry> entries = tree.findNearestN(a.getX(), a.getY(), a.getZ(), 2, false)
|
||||
for (KdTree3D.NNEntry entry : entries) {
|
||||
if (!(entry.atom().is(a))) { // identity check, not equals()
|
||||
if (!(entry.atom().is(a))) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
@@ -84,40 +98,33 @@ class AtomKdTree {
|
||||
|
||||
// --- k-NN ---
|
||||
|
||||
List<KdTree3D.NNEntry> findNearestN(Atom a, int count, boolean sorted) {
|
||||
return tree.findNearestN(a.getX(), a.getY(), a.getZ(), count, sorted)
|
||||
}
|
||||
|
||||
Atoms findNearestNAtoms(Atom a, int count, boolean sorted) {
|
||||
return toAtoms(findNearestN(a, count, sorted))
|
||||
}
|
||||
|
||||
List<KdTree3D.NNEntry> findNearestNDifferent(Atom a, int count, boolean sorted) {
|
||||
// Request count+1 to account for the self-match, then filter
|
||||
List<KdTree3D.NNEntry> entries = tree.findNearestN(a.getX(), a.getY(), a.getZ(), count + 1, sorted)
|
||||
entries.removeIf { KdTree3D.NNEntry entry -> entry.atom().is(a) }
|
||||
// Trim to requested count (in case self wasn't in results)
|
||||
if (entries.size() > count) {
|
||||
entries = entries.subList(0, count)
|
||||
if (tree != null) {
|
||||
return toAtoms(tree.findNearestN(a.getX(), a.getY(), a.getZ(), count, sorted))
|
||||
}
|
||||
return entries
|
||||
return treeV1.findNearestNAtoms(a, count, sorted)
|
||||
}
|
||||
|
||||
Atoms findNearestNDifferentAtoms(Atom a, int count, boolean sorted) {
|
||||
return toAtoms(findNearestNDifferent(a, count, sorted))
|
||||
if (tree != null) {
|
||||
List<KdTree3D.NNEntry> entries = tree.findNearestN(a.getX(), a.getY(), a.getZ(), count + 1, sorted)
|
||||
entries.removeIf { KdTree3D.NNEntry entry -> entry.atom().is(a) }
|
||||
if (entries.size() > count) {
|
||||
entries = entries.subList(0, count)
|
||||
}
|
||||
return toAtoms(entries)
|
||||
}
|
||||
return treeV1.findNearestNDifferentAtoms(a, count, sorted)
|
||||
}
|
||||
|
||||
// --- Radius search ---
|
||||
|
||||
/**
|
||||
* Find all atoms within radius of a.
|
||||
* Squares the radius (KdTree3D works in squared distances throughout).
|
||||
* sorted param kept for API compat but ignored — always false in production
|
||||
* (only caller is Atoms.cutoutSphereKD which passes false).
|
||||
*/
|
||||
Atoms findAtomsWithinRadius(Atom a, double radius, boolean sorted) {
|
||||
double sqrRadius = radius * radius
|
||||
return tree.findWithinRadius(a.getX(), a.getY(), a.getZ(), sqrRadius)
|
||||
if (tree != null) {
|
||||
double sqrRadius = radius * radius
|
||||
return tree.findWithinRadius(a.getX(), a.getY(), a.getZ(), sqrRadius)
|
||||
}
|
||||
return treeV1.findAtomsWithinRadius(a, radius, sorted)
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
@@ -319,11 +319,12 @@ public final class KdTree3D {
|
||||
* Resolves the split-axis array once (keys = xs/ys/zs based on dim) to avoid
|
||||
* per-comparison branching in the inner partition loop.
|
||||
*
|
||||
* Implementation: Sedgewick-style partition with sentinels.
|
||||
* After median-of-three: keys[lo] <= keys[mid] <= keys[hi].
|
||||
* keys[lo] is left sentinel, keys[hi] is right sentinel, pivot parked at hi-1.
|
||||
* IMPORTANT: must re-read values from arrays after each swap (not cache in locals)
|
||||
* to maintain the sentinel property correctly.
|
||||
* Implementation: Sedgewick-style partition with sentinels (fast for random data),
|
||||
* plus post-partition equal-range scan to skip duplicate regions. This prevents O(N²)
|
||||
* degeneration when many elements share the same coordinate value (common for surface
|
||||
* points on flat protein regions where many atoms align along one axis). With many
|
||||
* duplicates, the equal range around the pivot is large → the scan catches most of them
|
||||
* and skips the entire region in one step.
|
||||
*/
|
||||
private static void quickselect(double[] xs, double[] ys, double[] zs, Atom[] atoms,
|
||||
int lo, int hi, int k, int dim) {
|
||||
@@ -340,14 +341,10 @@ public final class KdTree3D {
|
||||
}
|
||||
|
||||
// Median-of-three: sort keys[lo], keys[mid], keys[hi] to select pivot.
|
||||
// Re-read from keys[] after each swap to avoid stale sentinel values.
|
||||
int mid = lo + (hi - lo) / 2;
|
||||
if (keys[lo] > keys[mid])
|
||||
swap(xs, ys, zs, atoms, lo, mid);
|
||||
if (keys[lo] > keys[hi])
|
||||
swap(xs, ys, zs, atoms, lo, hi);
|
||||
if (keys[mid] > keys[hi])
|
||||
swap(xs, ys, zs, atoms, mid, hi);
|
||||
if (keys[lo] > keys[mid]) swap(xs, ys, zs, atoms, lo, mid);
|
||||
if (keys[lo] > keys[hi]) swap(xs, ys, zs, atoms, lo, hi);
|
||||
if (keys[mid] > keys[hi]) swap(xs, ys, zs, atoms, mid, hi);
|
||||
// Now: keys[lo] <= keys[mid] <= keys[hi].
|
||||
// keys[lo] is left sentinel (<=pivot), keys[hi] is right sentinel (>=pivot).
|
||||
|
||||
@@ -367,9 +364,19 @@ public final class KdTree3D {
|
||||
}
|
||||
swap(xs, ys, zs, atoms, i, hi - 1); // restore pivot to final position
|
||||
|
||||
// Narrow search to the half containing k
|
||||
if (k <= i) hi = i - 1;
|
||||
else lo = i + 1;
|
||||
// After partition: keys[lo..i-1] <= pivot, keys[i] == pivot, keys[i+1..hi] >= pivot.
|
||||
// Expand the equal region around position i to skip duplicates.
|
||||
// For random data: region is [i, i] (2 extra comparisons, negligible).
|
||||
// For all-equal data: region is [lo, hi] → return immediately, O(N) total.
|
||||
int eqLo = i;
|
||||
while (eqLo > lo && keys[eqLo - 1] == pivot) eqLo--;
|
||||
int eqHi = i;
|
||||
while (eqHi < hi && keys[eqHi + 1] == pivot) eqHi++;
|
||||
|
||||
// Narrow search to the region containing k
|
||||
if (k >= eqLo && k <= eqHi) return; // k in equal region — done
|
||||
if (k < eqLo) hi = eqLo - 1;
|
||||
else lo = eqHi + 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1375,6 +1375,12 @@ class Params {
|
||||
@ModelParam // training
|
||||
boolean identify_peptides_by_labeling = false
|
||||
|
||||
/**
|
||||
* KD-tree implementation: "KdTree3D" (immutable, SoA) or "AtomKdTree" (Rednaxela, mutable, generic)
|
||||
*/
|
||||
@RuntimeParam
|
||||
String kdtree_implementation = "KdTree3D"
|
||||
|
||||
/**
|
||||
* Atoms size threshold for using KD-tree in cutoutSphere routine
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user