Add kdtree_implementation param and fix quickselect duplicate-key hang

Add runtime parameter to switch between KdTree3D (default) and v1
AtomKdTree. Fix O(N²) quickselect degeneration on duplicate coordinates
by adding post-partition equal-range scan.
This commit is contained in:
rdk
2026-03-02 22:20:59 +01:00
parent 24b9f5f709
commit 6d47285116
3 changed files with 84 additions and 64 deletions

View File

@@ -1,42 +1,56 @@
package cz.siret.prank.geom.kdtree.v2
import cz.siret.prank.geom.Atoms
import cz.siret.prank.geom.kdtree.AtomKdTree as AtomKdTreeV1
import cz.siret.prank.program.params.Params
import groovy.transform.CompileStatic
import org.biojava.nbio.structure.Atom
/**
* Drop-in replacement for kdtree.AtomKdTree, delegating to immutable KdTree3D.
* Unified KdTree wrapper that delegates to either v1 (Rednaxela) or v2 (KdTree3D)
* based on the kdtree_version parameter.
*
* Same class name and method signatures as the v1 AtomKdTree — only the import
* in Atoms.java needs to change. Groovy callers accessing via Atoms.getKdTree()
* are transparent (duck typing).
* Same method signatures as the original v1 AtomKdTree — Groovy callers accessing
* via Atoms.getKdTree() are transparent (duck typing).
*
* Key difference from v1: no add()/addAll() tree is immutable.
* Atoms.add() invalidates the tree; it gets rebuilt lazily via withKdTree().
*
* Uses getX()/getY()/getZ() instead of .coords — eliminates new double[3] allocation per call.
* v2 (default): immutable KdTree3D — no add()/addAll(), thread-safe, SoA layout.
* v1: original Rednaxela KdTree — mutable, generic N-dimensional.
*/
@CompileStatic
class AtomKdTree {
// Exactly one of these is non-null, depending on the selected implementation.
private final KdTree3D tree
private final AtomKdTreeV1 treeV1
private AtomKdTree(KdTree3D tree) {
this.tree = tree
this.treeV1 = null
}
private AtomKdTree(AtomKdTreeV1 treeV1) {
this.tree = null
this.treeV1 = treeV1
}
static AtomKdTree build(Atoms atoms) {
if ("AtomKdTree" == Params.INSTANCE.kdtree_implementation) {
return new AtomKdTree(AtomKdTreeV1.build(atoms))
}
return new AtomKdTree(KdTree3D.build(atoms))
}
int size() {
return tree.size()
return tree != null ? tree.size() : treeV1.size()
}
// --- Single nearest neighbor ---
Atom findNearest(Atom a) {
return tree.findNearest(a.getX(), a.getY(), a.getZ())
if (tree != null) {
return tree.findNearest(a.getX(), a.getY(), a.getZ())
}
return treeV1.findNearest(a)
}
double nearestDist(Atom a) {
@@ -44,18 +58,20 @@ class AtomKdTree {
}
double nearestSqrDist(Atom a) {
return tree.nearestSqrDist(a.getX(), a.getY(), a.getZ())
if (tree != null) {
return tree.nearestSqrDist(a.getX(), a.getY(), a.getZ())
}
return treeV1.nearestSqrDist(a)
}
// --- Nearest different (excluding identity-equal atom) ---
/**
* Find nearest atom that is not identity-equal to a.
* Uses k-NN with k=2 and filters. Returns null if no different atom exists.
*/
Atom findNearestDifferent(Atom a) {
KdTree3D.NNEntry entry = singleNearestDifferent(a)
return entry?.atom()
if (tree != null) {
KdTree3D.NNEntry entry = singleNearestDifferent(a)
return entry?.atom()
}
return treeV1.findNearestDifferent(a)
}
double nearestDifferentDist(Atom a) {
@@ -63,19 +79,17 @@ class AtomKdTree {
}
double nearestDifferentSqrDist(Atom a) {
KdTree3D.NNEntry entry = singleNearestDifferent(a)
return entry != null ? entry.sqrDist() : Double.NaN
if (tree != null) {
KdTree3D.NNEntry entry = singleNearestDifferent(a)
return entry != null ? entry.sqrDist() : Double.NaN
}
return treeV1.nearestDifferentSqrDist(a)
}
/**
* Internal: find nearest non-self entry.
* Requests k=2 neighbors and picks the one that isn't identity-equal to a.
* If both are identity-equal (duplicate points), returns null.
*/
private KdTree3D.NNEntry singleNearestDifferent(Atom a) {
List<KdTree3D.NNEntry> entries = tree.findNearestN(a.getX(), a.getY(), a.getZ(), 2, false)
for (KdTree3D.NNEntry entry : entries) {
if (!(entry.atom().is(a))) { // identity check, not equals()
if (!(entry.atom().is(a))) {
return entry
}
}
@@ -84,40 +98,33 @@ class AtomKdTree {
// --- k-NN ---
List<KdTree3D.NNEntry> findNearestN(Atom a, int count, boolean sorted) {
return tree.findNearestN(a.getX(), a.getY(), a.getZ(), count, sorted)
}
Atoms findNearestNAtoms(Atom a, int count, boolean sorted) {
return toAtoms(findNearestN(a, count, sorted))
}
List<KdTree3D.NNEntry> findNearestNDifferent(Atom a, int count, boolean sorted) {
// Request count+1 to account for the self-match, then filter
List<KdTree3D.NNEntry> entries = tree.findNearestN(a.getX(), a.getY(), a.getZ(), count + 1, sorted)
entries.removeIf { KdTree3D.NNEntry entry -> entry.atom().is(a) }
// Trim to requested count (in case self wasn't in results)
if (entries.size() > count) {
entries = entries.subList(0, count)
if (tree != null) {
return toAtoms(tree.findNearestN(a.getX(), a.getY(), a.getZ(), count, sorted))
}
return entries
return treeV1.findNearestNAtoms(a, count, sorted)
}
Atoms findNearestNDifferentAtoms(Atom a, int count, boolean sorted) {
return toAtoms(findNearestNDifferent(a, count, sorted))
if (tree != null) {
List<KdTree3D.NNEntry> entries = tree.findNearestN(a.getX(), a.getY(), a.getZ(), count + 1, sorted)
entries.removeIf { KdTree3D.NNEntry entry -> entry.atom().is(a) }
if (entries.size() > count) {
entries = entries.subList(0, count)
}
return toAtoms(entries)
}
return treeV1.findNearestNDifferentAtoms(a, count, sorted)
}
// --- Radius search ---
/**
* Find all atoms within radius of a.
* Squares the radius (KdTree3D works in squared distances throughout).
* sorted param kept for API compat but ignored — always false in production
* (only caller is Atoms.cutoutSphereKD which passes false).
*/
Atoms findAtomsWithinRadius(Atom a, double radius, boolean sorted) {
double sqrRadius = radius * radius
return tree.findWithinRadius(a.getX(), a.getY(), a.getZ(), sqrRadius)
if (tree != null) {
double sqrRadius = radius * radius
return tree.findWithinRadius(a.getX(), a.getY(), a.getZ(), sqrRadius)
}
return treeV1.findAtomsWithinRadius(a, radius, sorted)
}
// --- Helpers ---

View File

@@ -319,11 +319,12 @@ public final class KdTree3D {
* Resolves the split-axis array once (keys = xs/ys/zs based on dim) to avoid
* per-comparison branching in the inner partition loop.
*
* Implementation: Sedgewick-style partition with sentinels.
* After median-of-three: keys[lo] <= keys[mid] <= keys[hi].
* keys[lo] is left sentinel, keys[hi] is right sentinel, pivot parked at hi-1.
* IMPORTANT: must re-read values from arrays after each swap (not cache in locals)
* to maintain the sentinel property correctly.
* Implementation: Sedgewick-style partition with sentinels (fast for random data),
* plus post-partition equal-range scan to skip duplicate regions. This prevents O(N²)
* degeneration when many elements share the same coordinate value (common for surface
* points on flat protein regions where many atoms align along one axis). With many
* duplicates, the equal range around the pivot is large → the scan catches most of them
* and skips the entire region in one step.
*/
private static void quickselect(double[] xs, double[] ys, double[] zs, Atom[] atoms,
int lo, int hi, int k, int dim) {
@@ -340,14 +341,10 @@ public final class KdTree3D {
}
// Median-of-three: sort keys[lo], keys[mid], keys[hi] to select pivot.
// Re-read from keys[] after each swap to avoid stale sentinel values.
int mid = lo + (hi - lo) / 2;
if (keys[lo] > keys[mid])
swap(xs, ys, zs, atoms, lo, mid);
if (keys[lo] > keys[hi])
swap(xs, ys, zs, atoms, lo, hi);
if (keys[mid] > keys[hi])
swap(xs, ys, zs, atoms, mid, hi);
if (keys[lo] > keys[mid]) swap(xs, ys, zs, atoms, lo, mid);
if (keys[lo] > keys[hi]) swap(xs, ys, zs, atoms, lo, hi);
if (keys[mid] > keys[hi]) swap(xs, ys, zs, atoms, mid, hi);
// Now: keys[lo] <= keys[mid] <= keys[hi].
// keys[lo] is left sentinel (<=pivot), keys[hi] is right sentinel (>=pivot).
@@ -367,9 +364,19 @@ public final class KdTree3D {
}
swap(xs, ys, zs, atoms, i, hi - 1); // restore pivot to final position
// Narrow search to the half containing k
if (k <= i) hi = i - 1;
else lo = i + 1;
// After partition: keys[lo..i-1] <= pivot, keys[i] == pivot, keys[i+1..hi] >= pivot.
// Expand the equal region around position i to skip duplicates.
// For random data: region is [i, i] (2 extra comparisons, negligible).
// For all-equal data: region is [lo, hi] → return immediately, O(N) total.
int eqLo = i;
while (eqLo > lo && keys[eqLo - 1] == pivot) eqLo--;
int eqHi = i;
while (eqHi < hi && keys[eqHi + 1] == pivot) eqHi++;
// Narrow search to the region containing k
if (k >= eqLo && k <= eqHi) return; // k in equal region — done
if (k < eqLo) hi = eqLo - 1;
else lo = eqHi + 1;
}
}

View File

@@ -1375,6 +1375,12 @@ class Params {
@ModelParam // training
boolean identify_peptides_by_labeling = false
/**
* KD-tree implementation: "KdTree3D" (immutable, SoA) or "AtomKdTree" (Rednaxela, mutable, generic)
*/
@RuntimeParam
String kdtree_implementation = "KdTree3D"
/**
* Atoms size threshold for using KD-tree in cutoutSphere routine
*/