From 6d47285116603f6fbfa910d44532dae8b2a208eb Mon Sep 17 00:00:00 2001 From: rdk Date: Mon, 2 Mar 2026 22:20:59 +0100 Subject: [PATCH] Add kdtree_implementation param and fix quickselect duplicate-key hang MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add runtime parameter to switch between KdTree3D (default) and v1 AtomKdTree. Fix O(N²) quickselect degeneration on duplicate coordinates by adding post-partition equal-range scan. --- .../prank/geom/kdtree/v2/AtomKdTree.groovy | 105 ++++++++++-------- .../siret/prank/geom/kdtree/v2/KdTree3D.java | 37 +++--- .../siret/prank/program/params/Params.groovy | 6 + 3 files changed, 84 insertions(+), 64 deletions(-) diff --git a/src/main/groovy/cz/siret/prank/geom/kdtree/v2/AtomKdTree.groovy b/src/main/groovy/cz/siret/prank/geom/kdtree/v2/AtomKdTree.groovy index 63bd0516..0b8ccbbe 100644 --- a/src/main/groovy/cz/siret/prank/geom/kdtree/v2/AtomKdTree.groovy +++ b/src/main/groovy/cz/siret/prank/geom/kdtree/v2/AtomKdTree.groovy @@ -1,42 +1,56 @@ package cz.siret.prank.geom.kdtree.v2 import cz.siret.prank.geom.Atoms +import cz.siret.prank.geom.kdtree.AtomKdTree as AtomKdTreeV1 +import cz.siret.prank.program.params.Params import groovy.transform.CompileStatic import org.biojava.nbio.structure.Atom /** - * Drop-in replacement for kdtree.AtomKdTree, delegating to immutable KdTree3D. + * Unified KdTree wrapper that delegates to either v1 (Rednaxela) or v2 (KdTree3D) + * based on the kdtree_implementation parameter. * - * Same class name and method signatures as the v1 AtomKdTree — only the import - * in Atoms.java needs to change. Groovy callers accessing via Atoms.getKdTree() - * are transparent (duck typing). + * Same method signatures as the original v1 AtomKdTree — Groovy callers accessing + * via Atoms.getKdTree() are transparent (duck typing). * - * Key difference from v1: no add()/addAll() — tree is immutable. - * Atoms.add() invalidates the tree; it gets rebuilt lazily via withKdTree(). 
- * - * Uses getX()/getY()/getZ() instead of .coords — eliminates new double[3] allocation per call. + * v2 (default): immutable KdTree3D — no add()/addAll(), thread-safe, SoA layout. + * v1: original Rednaxela KdTree — mutable, generic N-dimensional. */ @CompileStatic class AtomKdTree { + // Exactly one of these is non-null, depending on the selected implementation. private final KdTree3D tree + private final AtomKdTreeV1 treeV1 private AtomKdTree(KdTree3D tree) { this.tree = tree + this.treeV1 = null + } + + private AtomKdTree(AtomKdTreeV1 treeV1) { + this.tree = null + this.treeV1 = treeV1 } static AtomKdTree build(Atoms atoms) { + if ("AtomKdTree" == Params.INSTANCE.kdtree_implementation) { + return new AtomKdTree(AtomKdTreeV1.build(atoms)) + } return new AtomKdTree(KdTree3D.build(atoms)) } int size() { - return tree.size() + return tree != null ? tree.size() : treeV1.size() } // --- Single nearest neighbor --- Atom findNearest(Atom a) { - return tree.findNearest(a.getX(), a.getY(), a.getZ()) + if (tree != null) { + return tree.findNearest(a.getX(), a.getY(), a.getZ()) + } + return treeV1.findNearest(a) } double nearestDist(Atom a) { @@ -44,18 +58,20 @@ class AtomKdTree { } double nearestSqrDist(Atom a) { - return tree.nearestSqrDist(a.getX(), a.getY(), a.getZ()) + if (tree != null) { + return tree.nearestSqrDist(a.getX(), a.getY(), a.getZ()) + } + return treeV1.nearestSqrDist(a) } // --- Nearest different (excluding identity-equal atom) --- - /** - * Find nearest atom that is not identity-equal to a. - * Uses k-NN with k=2 and filters. Returns null if no different atom exists. 
- */ Atom findNearestDifferent(Atom a) { - KdTree3D.NNEntry entry = singleNearestDifferent(a) - return entry?.atom() + if (tree != null) { + KdTree3D.NNEntry entry = singleNearestDifferent(a) + return entry?.atom() + } + return treeV1.findNearestDifferent(a) } double nearestDifferentDist(Atom a) { @@ -63,19 +79,17 @@ class AtomKdTree { } double nearestDifferentSqrDist(Atom a) { - KdTree3D.NNEntry entry = singleNearestDifferent(a) - return entry != null ? entry.sqrDist() : Double.NaN + if (tree != null) { + KdTree3D.NNEntry entry = singleNearestDifferent(a) + return entry != null ? entry.sqrDist() : Double.NaN + } + return treeV1.nearestDifferentSqrDist(a) } - /** - * Internal: find nearest non-self entry. - * Requests k=2 neighbors and picks the one that isn't identity-equal to a. - * If both are identity-equal (duplicate points), returns null. - */ private KdTree3D.NNEntry singleNearestDifferent(Atom a) { List entries = tree.findNearestN(a.getX(), a.getY(), a.getZ(), 2, false) for (KdTree3D.NNEntry entry : entries) { - if (!(entry.atom().is(a))) { // identity check, not equals() + if (!(entry.atom().is(a))) { return entry } } @@ -84,40 +98,33 @@ class AtomKdTree { // --- k-NN --- - List findNearestN(Atom a, int count, boolean sorted) { - return tree.findNearestN(a.getX(), a.getY(), a.getZ(), count, sorted) - } - Atoms findNearestNAtoms(Atom a, int count, boolean sorted) { - return toAtoms(findNearestN(a, count, sorted)) - } - - List findNearestNDifferent(Atom a, int count, boolean sorted) { - // Request count+1 to account for the self-match, then filter - List entries = tree.findNearestN(a.getX(), a.getY(), a.getZ(), count + 1, sorted) - entries.removeIf { KdTree3D.NNEntry entry -> entry.atom().is(a) } - // Trim to requested count (in case self wasn't in results) - if (entries.size() > count) { - entries = entries.subList(0, count) + if (tree != null) { + return toAtoms(tree.findNearestN(a.getX(), a.getY(), a.getZ(), count, sorted)) } - return entries + return 
treeV1.findNearestNAtoms(a, count, sorted) } Atoms findNearestNDifferentAtoms(Atom a, int count, boolean sorted) { - return toAtoms(findNearestNDifferent(a, count, sorted)) + if (tree != null) { + List entries = tree.findNearestN(a.getX(), a.getY(), a.getZ(), count + 1, sorted) + entries.removeIf { KdTree3D.NNEntry entry -> entry.atom().is(a) } + if (entries.size() > count) { + entries = entries.subList(0, count) + } + return toAtoms(entries) + } + return treeV1.findNearestNDifferentAtoms(a, count, sorted) } // --- Radius search --- - /** - * Find all atoms within radius of a. - * Squares the radius (KdTree3D works in squared distances throughout). - * sorted param kept for API compat but ignored — always false in production - * (only caller is Atoms.cutoutSphereKD which passes false). - */ Atoms findAtomsWithinRadius(Atom a, double radius, boolean sorted) { - double sqrRadius = radius * radius - return tree.findWithinRadius(a.getX(), a.getY(), a.getZ(), sqrRadius) + if (tree != null) { + double sqrRadius = radius * radius + return tree.findWithinRadius(a.getX(), a.getY(), a.getZ(), sqrRadius) + } + return treeV1.findAtomsWithinRadius(a, radius, sorted) } // --- Helpers --- diff --git a/src/main/groovy/cz/siret/prank/geom/kdtree/v2/KdTree3D.java b/src/main/groovy/cz/siret/prank/geom/kdtree/v2/KdTree3D.java index 66e2ebe9..db651bf7 100644 --- a/src/main/groovy/cz/siret/prank/geom/kdtree/v2/KdTree3D.java +++ b/src/main/groovy/cz/siret/prank/geom/kdtree/v2/KdTree3D.java @@ -319,11 +319,12 @@ public final class KdTree3D { * Resolves the split-axis array once (keys = xs/ys/zs based on dim) to avoid * per-comparison branching in the inner partition loop. * - * Implementation: Sedgewick-style partition with sentinels. - * After median-of-three: keys[lo] <= keys[mid] <= keys[hi]. - * keys[lo] is left sentinel, keys[hi] is right sentinel, pivot parked at hi-1. 
- * IMPORTANT: must re-read values from arrays after each swap (not cache in locals) - * to maintain the sentinel property correctly. + * Implementation: Sedgewick-style partition with sentinels (fast for random data), + * plus post-partition equal-range scan to skip duplicate regions. This prevents O(N²) + * degeneration when many elements share the same coordinate value (common for surface + * points on flat protein regions where many atoms align along one axis). With many + * duplicates, the equal range around the pivot is large → the scan catches most of them + * and skips the entire region in one step. */ private static void quickselect(double[] xs, double[] ys, double[] zs, Atom[] atoms, int lo, int hi, int k, int dim) { @@ -340,14 +341,10 @@ public final class KdTree3D { } // Median-of-three: sort keys[lo], keys[mid], keys[hi] to select pivot. - // Re-read from keys[] after each swap to avoid stale sentinel values. int mid = lo + (hi - lo) / 2; - if (keys[lo] > keys[mid]) - swap(xs, ys, zs, atoms, lo, mid); - if (keys[lo] > keys[hi]) - swap(xs, ys, zs, atoms, lo, hi); - if (keys[mid] > keys[hi]) - swap(xs, ys, zs, atoms, mid, hi); + if (keys[lo] > keys[mid]) swap(xs, ys, zs, atoms, lo, mid); + if (keys[lo] > keys[hi]) swap(xs, ys, zs, atoms, lo, hi); + if (keys[mid] > keys[hi]) swap(xs, ys, zs, atoms, mid, hi); // Now: keys[lo] <= keys[mid] <= keys[hi]. // keys[lo] is left sentinel (<=pivot), keys[hi] is right sentinel (>=pivot). @@ -367,9 +364,19 @@ public final class KdTree3D { } swap(xs, ys, zs, atoms, i, hi - 1); // restore pivot to final position - // Narrow search to the half containing k - if (k <= i) hi = i - 1; - else lo = i + 1; + // After partition: keys[lo..i-1] <= pivot, keys[i] == pivot, keys[i+1..hi] >= pivot. + // Expand the equal region around position i to skip duplicates. + // For random data: region is [i, i] (2 extra comparisons, negligible). + // For all-equal data: region is [lo, hi] → return immediately, O(N) total. 
+ int eqLo = i; + while (eqLo > lo && keys[eqLo - 1] == pivot) eqLo--; + int eqHi = i; + while (eqHi < hi && keys[eqHi + 1] == pivot) eqHi++; + + // Narrow search to the region containing k + if (k >= eqLo && k <= eqHi) return; // k in equal region — done + if (k < eqLo) hi = eqLo - 1; + else lo = eqHi + 1; } } diff --git a/src/main/groovy/cz/siret/prank/program/params/Params.groovy b/src/main/groovy/cz/siret/prank/program/params/Params.groovy index 78519982..ba20fb00 100644 --- a/src/main/groovy/cz/siret/prank/program/params/Params.groovy +++ b/src/main/groovy/cz/siret/prank/program/params/Params.groovy @@ -1375,6 +1375,12 @@ class Params { @ModelParam // training boolean identify_peptides_by_labeling = false + /** + * KD-tree implementation: "KdTree3D" (immutable, SoA) or "AtomKdTree" (Rednaxela, mutable, generic) + */ + @RuntimeParam + String kdtree_implementation = "KdTree3D" + /** * Atoms size threshold for using KD-tree in cutoutSphere routine */