Refactor KdTree into AtomKdTree interface with V1/V2 implementations

Rewrite AtomKdTreeV1 from Groovy to Java to eliminate Groovy IndyInterface
monitor contention that serialized 16 parallel threads down to ~2.
Move V1 KdTree into v1/ subpackage, extract AtomKdTree as a Java interface
with factory method dispatching by kdtree_implementation param, and rename
the old v2 wrapper to AtomKdTreeV2 implementing the same interface.
This commit is contained in:
rdk
2026-03-03 00:17:38 +01:00
parent 6d47285116
commit a66a973e1c
10 changed files with 258 additions and 275 deletions

View File

@@ -1,8 +1,7 @@
package cz.siret.prank.geom;
import com.google.common.collect.Lists;
import cz.siret.prank.geom.kdtree.v2.AtomKdTree;
import cz.siret.prank.geom.kdtree.v2.KdTree3D;
import cz.siret.prank.geom.kdtree.AtomKdTree;
import cz.siret.prank.program.params.Params;
import cz.siret.prank.utils.ATimer;
import cz.siret.prank.utils.CutoffAtomsCallLog;
@@ -303,13 +302,13 @@ public final class Atoms implements Iterable<Atom> {
public void add(Atom a) {
list.add(a);
// KdTree3D is immutable — invalidate, rebuilt lazily via withKdTree()/buildKdTree()
// invalidate, rebuilt lazily via withKdTree()/buildKdTree()
kdTree = null;
}
public Atoms addAll(Atoms atoms) {
list.addAll(atoms.list);
kdTree = null; // invalidate immutable tree
kdTree = null; // invalidate
return this;
}
@@ -528,7 +527,7 @@ public final class Atoms implements Iterable<Atom> {
public static Atoms consolidate(Atoms atoms, double dist) {
List<Atom> result = new ArrayList<>();
KdTree3D tree = null;
AtomKdTree tree = null;
int lastBuild = 0;
double sqrDist = dist * dist;
@@ -537,7 +536,7 @@ public final class Atoms implements Iterable<Atom> {
// Check against tree (covers result[0..lastBuild))
if (tree != null) {
if (tree.nearestSqrDist(a.getX(), a.getY(), a.getZ()) <= sqrDist) {
if (tree.nearestSqrDist(a) <= sqrDist) {
tooClose = true;
}
}
@@ -556,7 +555,7 @@ public final class Atoms implements Iterable<Atom> {
result.add(a);
// Rebuild tree when gap grows to batch size
if (result.size() - lastBuild >= CONSOLIDATE_BATCH) {
tree = KdTree3D.build(result);
tree = AtomKdTree.build(new Atoms(result));
lastBuild = result.size();
}
}

View File

@@ -1,110 +0,0 @@
package cz.siret.prank.geom.kdtree
import cz.siret.prank.geom.Atoms
import groovy.transform.CompileStatic
import org.biojava.nbio.structure.Atom
@CompileStatic
class AtomKdTree extends KdTree.SqrEuclid3D<Atom> {
AtomKdTree(Integer sizeLimit) {
super(sizeLimit)
}
public static AtomKdTree build(Atoms atoms) {
AtomKdTree res = new AtomKdTree(Integer.MAX_VALUE)
res.addAll(atoms)
return res
}
//===========================================================================================================//
public add(Atom a) {
addPoint(a.coords, a)
}
public addAll(Atoms atoms) {
for (Atom a : atoms.list) {
add(a)
}
}
public Atom findNearest(Atom a) {
return singleNearestNeighbor(a.coords)?.value
}
public double nearestDist(Atom a) {
double dist = nearestSqrDist(a)
return Math.sqrt(dist)
}
public double nearestSqrDist(Atom a) {
return singleNearestNeighbor(a.coords).distance
}
public Atom findNearestDifferent(Atom a) {
return singleNearestDifferent(a)?.value
}
public double nearestDifferentDist(Atom a) {
return Math.sqrt(nearestDifferentSqrDist(a))
}
public double nearestDifferentSqrDist(Atom a) {
Entry<Atom> ent = singleNearestDifferent(a)
return ent!=null ? ent.distance : Double.NaN
}
public Entry<Atom> singleNearestDifferent(Atom a) {
List<Entry<Atom>> resList = nearestNeighbor(a.coords, 2, false)
for (Entry<Atom> ent in resList) {
if (!(ent.value == a)) {
return ent
}
}
return null
}
public List<Entry<Atom>> findNearestN(Atom a, int count, boolean sorted) {
return nearestNeighbor(a.coords, count, sorted)
}
public List<Entry<Atom>> findNearestNDifferent(Atom a, int count, boolean sorted) {
List<Entry<Atom>> aaa = nearestNeighbor(a.coords, count, sorted)
Iterator<Entry<Atom>> it = aaa.iterator();
while (it.hasNext()) {
Atom ai = it.next().value
if (a.equals(ai)) {
it.remove();
}
}
return aaa
}
public Atoms findNearestNAtoms(Atom a, int count, boolean sorted) {
return atoms(findNearestN(a, count, sorted))
}
public Atoms findNearestNDifferentAtoms(Atom a, int count, boolean sorted) {
return atoms( findNearestNDifferent(a, count, sorted))
}
public final Atoms findAtomsWithinRadius(Atom a, double radius, boolean sorted) {
radius = radius*radius // since we inherit from SqrEuclid
return atoms(neighboursWithinRadius(a.coords, radius, sorted))
}
private Atoms atoms(List<Entry<Atom>> entries) {
List<Atom> list = new ArrayList<>(entries.size());
for (Entry<Atom> e : entries) {
list.add(e.value)
}
return new Atoms(list)
}
}

View File

@@ -0,0 +1,48 @@
package cz.siret.prank.geom.kdtree;
import cz.siret.prank.geom.Atoms;
import cz.siret.prank.geom.kdtree.v1.AtomKdTreeV1;
import cz.siret.prank.geom.kdtree.v2.AtomKdTreeV2;
import cz.siret.prank.program.params.Params;
import org.biojava.nbio.structure.Atom;
/**
* Interface for spatial atom queries backed by a KD-tree.
*
* Implementations:
* AtomKdTreeV1 — Rednaxela (mutable, generic N-dimensional)
* AtomKdTreeV2 — KdTree3D wrapper (immutable, SoA, hardcoded 3D)
*
* Use the static {@link #build(Atoms)} factory to obtain an instance
* based on the {@code kdtree_implementation} runtime parameter.
*/
public interface AtomKdTree {
static AtomKdTree build(Atoms atoms) {
if ("AtomKdTreeV1".equals(Params.INSTANCE.getKdtree_implementation())) {
return AtomKdTreeV1.build(atoms);
}
return AtomKdTreeV2.build(atoms);
}
int size();
Atom findNearest(Atom a);
double nearestSqrDist(Atom a);
Atom findNearestDifferent(Atom a);
double nearestDifferentSqrDist(Atom a);
Atoms findNearestNAtoms(Atom a, int count, boolean sorted);
Atoms findNearestNDifferentAtoms(Atom a, int count, boolean sorted);
Atoms findAtomsWithinRadius(Atom a, double radius, boolean sorted);
default double nearestDist(Atom a) {
return Math.sqrt(nearestSqrDist(a));
}
default double nearestDifferentDist(Atom a) {
return Math.sqrt(nearestDifferentSqrDist(a));
}
}

View File

@@ -0,0 +1,90 @@
package cz.siret.prank.geom.kdtree.v1;
import cz.siret.prank.geom.Atoms;
import cz.siret.prank.geom.kdtree.AtomKdTree;
import org.biojava.nbio.structure.Atom;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
public class AtomKdTreeV1 extends KdTree.SqrEuclid3D<Atom> implements AtomKdTree {
AtomKdTreeV1(int sizeLimit) {
super(sizeLimit);
}
public static AtomKdTreeV1 build(Atoms atoms) {
AtomKdTreeV1 res = new AtomKdTreeV1(Integer.MAX_VALUE);
List<Atom> list = atoms.list;
for (int i = 0, n = list.size(); i < n; i++) {
Atom a = list.get(i);
res.addPoint(a.getCoords(), a);
}
return res;
}
@Override
public Atom findNearest(Atom a) {
Entry<Atom> entry = singleNearestNeighbor(a.getCoords());
return entry != null ? entry.value : null;
}
@Override
public double nearestSqrDist(Atom a) {
return singleNearestNeighbor(a.getCoords()).distance;
}
@Override
public Atom findNearestDifferent(Atom a) {
Entry<Atom> entry = singleNearestDifferent(a);
return entry != null ? entry.value : null;
}
@Override
public double nearestDifferentSqrDist(Atom a) {
Entry<Atom> entry = singleNearestDifferent(a);
return entry != null ? entry.distance : Double.NaN;
}
private Entry<Atom> singleNearestDifferent(Atom a) {
List<Entry<Atom>> resList = nearestNeighbor(a.getCoords(), 2, false);
for (Entry<Atom> ent : resList) {
if (ent.value != a) {
return ent;
}
}
return null;
}
@Override
public Atoms findNearestNAtoms(Atom a, int count, boolean sorted) {
return toAtoms(nearestNeighbor(a.getCoords(), count, sorted));
}
@Override
public Atoms findNearestNDifferentAtoms(Atom a, int count, boolean sorted) {
List<Entry<Atom>> entries = nearestNeighbor(a.getCoords(), count, sorted);
Iterator<Entry<Atom>> it = entries.iterator();
while (it.hasNext()) {
if (it.next().value == a) {
it.remove();
}
}
return toAtoms(entries);
}
@Override
public Atoms findAtomsWithinRadius(Atom a, double radius, boolean sorted) {
double sqrRadius = radius * radius; // SqrEuclid uses squared distances
return toAtoms(neighboursWithinRadius(a.getCoords(), sqrRadius, sorted));
}
private static Atoms toAtoms(List<Entry<Atom>> entries) {
List<Atom> list = new ArrayList<>(entries.size());
for (Entry<Atom> e : entries) {
list.add(e.value);
}
return new Atoms(list);
}
}

View File

@@ -1,4 +1,4 @@
package cz.siret.prank.geom.kdtree;
package cz.siret.prank.geom.kdtree.v1;
/**
* Copyright 2009 Rednaxela
@@ -1000,4 +1000,4 @@ public abstract class KdTree<T> {
return distance[0];
}
}
}
}

View File

@@ -1,139 +0,0 @@
package cz.siret.prank.geom.kdtree.v2
import cz.siret.prank.geom.Atoms
import cz.siret.prank.geom.kdtree.AtomKdTree as AtomKdTreeV1
import cz.siret.prank.program.params.Params
import groovy.transform.CompileStatic
import org.biojava.nbio.structure.Atom
/**
* Unified KdTree wrapper that delegates to either v1 (Rednaxela) or v2 (KdTree3D)
* based on the kdtree_version parameter.
*
* Same method signatures as the original v1 AtomKdTree — Groovy callers accessing
* via Atoms.getKdTree() are transparent (duck typing).
*
* v2 (default): immutable KdTree3D — no add()/addAll(), thread-safe, SoA layout.
* v1: original Rednaxela KdTree — mutable, generic N-dimensional.
*/
@CompileStatic
class AtomKdTree {
// Exactly one of these is non-null, depending on the selected implementation.
private final KdTree3D tree
private final AtomKdTreeV1 treeV1
private AtomKdTree(KdTree3D tree) {
this.tree = tree
this.treeV1 = null
}
private AtomKdTree(AtomKdTreeV1 treeV1) {
this.tree = null
this.treeV1 = treeV1
}
static AtomKdTree build(Atoms atoms) {
if ("AtomKdTree" == Params.INSTANCE.kdtree_implementation) {
return new AtomKdTree(AtomKdTreeV1.build(atoms))
}
return new AtomKdTree(KdTree3D.build(atoms))
}
int size() {
return tree != null ? tree.size() : treeV1.size()
}
// --- Single nearest neighbor ---
Atom findNearest(Atom a) {
if (tree != null) {
return tree.findNearest(a.getX(), a.getY(), a.getZ())
}
return treeV1.findNearest(a)
}
double nearestDist(Atom a) {
return Math.sqrt(nearestSqrDist(a))
}
double nearestSqrDist(Atom a) {
if (tree != null) {
return tree.nearestSqrDist(a.getX(), a.getY(), a.getZ())
}
return treeV1.nearestSqrDist(a)
}
// --- Nearest different (excluding identity-equal atom) ---
Atom findNearestDifferent(Atom a) {
if (tree != null) {
KdTree3D.NNEntry entry = singleNearestDifferent(a)
return entry?.atom()
}
return treeV1.findNearestDifferent(a)
}
double nearestDifferentDist(Atom a) {
return Math.sqrt(nearestDifferentSqrDist(a))
}
double nearestDifferentSqrDist(Atom a) {
if (tree != null) {
KdTree3D.NNEntry entry = singleNearestDifferent(a)
return entry != null ? entry.sqrDist() : Double.NaN
}
return treeV1.nearestDifferentSqrDist(a)
}
private KdTree3D.NNEntry singleNearestDifferent(Atom a) {
List<KdTree3D.NNEntry> entries = tree.findNearestN(a.getX(), a.getY(), a.getZ(), 2, false)
for (KdTree3D.NNEntry entry : entries) {
if (!(entry.atom().is(a))) {
return entry
}
}
return null
}
// --- k-NN ---
Atoms findNearestNAtoms(Atom a, int count, boolean sorted) {
if (tree != null) {
return toAtoms(tree.findNearestN(a.getX(), a.getY(), a.getZ(), count, sorted))
}
return treeV1.findNearestNAtoms(a, count, sorted)
}
Atoms findNearestNDifferentAtoms(Atom a, int count, boolean sorted) {
if (tree != null) {
List<KdTree3D.NNEntry> entries = tree.findNearestN(a.getX(), a.getY(), a.getZ(), count + 1, sorted)
entries.removeIf { KdTree3D.NNEntry entry -> entry.atom().is(a) }
if (entries.size() > count) {
entries = entries.subList(0, count)
}
return toAtoms(entries)
}
return treeV1.findNearestNDifferentAtoms(a, count, sorted)
}
// --- Radius search ---
Atoms findAtomsWithinRadius(Atom a, double radius, boolean sorted) {
if (tree != null) {
double sqrRadius = radius * radius
return tree.findWithinRadius(a.getX(), a.getY(), a.getZ(), sqrRadius)
}
return treeV1.findAtomsWithinRadius(a, radius, sorted)
}
// --- Helpers ---
private static Atoms toAtoms(List<KdTree3D.NNEntry> entries) {
List<Atom> list = new ArrayList<>(entries.size())
for (KdTree3D.NNEntry e : entries) {
list.add(e.atom())
}
return new Atoms(list)
}
}

View File

@@ -0,0 +1,97 @@
package cz.siret.prank.geom.kdtree.v2
import cz.siret.prank.geom.Atoms
import cz.siret.prank.geom.kdtree.AtomKdTree
import groovy.transform.CompileStatic
import org.biojava.nbio.structure.Atom
/**
* AtomKdTree implementation backed by immutable KdTree3D (SoA layout, hardcoded 3D).
* Thread-safe for concurrent queries after construction.
*/
@CompileStatic
class AtomKdTreeV2 implements AtomKdTree {
private final KdTree3D tree
private AtomKdTreeV2(KdTree3D tree) {
this.tree = tree
}
static AtomKdTreeV2 build(Atoms atoms) {
return new AtomKdTreeV2(KdTree3D.build(atoms))
}
int size() {
return tree.size()
}
// --- Single nearest neighbor ---
Atom findNearest(Atom a) {
return tree.findNearest(a.getX(), a.getY(), a.getZ())
}
double nearestSqrDist(Atom a) {
return tree.nearestSqrDist(a.getX(), a.getY(), a.getZ())
}
/** Raw-coordinate variant for hot paths (e.g. consolidate). */
double nearestSqrDist(double x, double y, double z) {
return tree.nearestSqrDist(x, y, z)
}
// --- Nearest different (excluding identity-equal atom) ---
Atom findNearestDifferent(Atom a) {
KdTree3D.NNEntry entry = singleNearestDifferent(a)
return entry?.atom()
}
double nearestDifferentSqrDist(Atom a) {
KdTree3D.NNEntry entry = singleNearestDifferent(a)
return entry != null ? entry.sqrDist() : Double.NaN
}
private KdTree3D.NNEntry singleNearestDifferent(Atom a) {
List<KdTree3D.NNEntry> entries = tree.findNearestN(a.getX(), a.getY(), a.getZ(), 2, false)
for (KdTree3D.NNEntry entry : entries) {
if (!(entry.atom().is(a))) {
return entry
}
}
return null
}
// --- k-NN ---
Atoms findNearestNAtoms(Atom a, int count, boolean sorted) {
return toAtoms(tree.findNearestN(a.getX(), a.getY(), a.getZ(), count, sorted))
}
Atoms findNearestNDifferentAtoms(Atom a, int count, boolean sorted) {
List<KdTree3D.NNEntry> entries = tree.findNearestN(a.getX(), a.getY(), a.getZ(), count + 1, sorted)
entries.removeIf { KdTree3D.NNEntry entry -> entry.atom().is(a) }
if (entries.size() > count) {
entries = entries.subList(0, count)
}
return toAtoms(entries)
}
// --- Radius search ---
Atoms findAtomsWithinRadius(Atom a, double radius, boolean sorted) {
double sqrRadius = radius * radius
return tree.findWithinRadius(a.getX(), a.getY(), a.getZ(), sqrRadius)
}
// --- Helpers ---
private static Atoms toAtoms(List<KdTree3D.NNEntry> entries) {
List<Atom> list = new ArrayList<>(entries.size())
for (KdTree3D.NNEntry e : entries) {
list.add(e.atom())
}
return new Atoms(list)
}
}

View File

@@ -1376,10 +1376,10 @@ class Params {
boolean identify_peptides_by_labeling = false
/**
* KD-tree implementation: "KdTree3D" (immutable, SoA) or "AtomKdTree" (Rednaxela, mutable, generic)
* KD-tree implementation: "AtomKdTreeV2" (immutable, SoA) or "AtomKdTreeV1" (Rednaxela, mutable, generic)
*/
@RuntimeParam
String kdtree_implementation = "KdTree3D"
String kdtree_implementation = "AtomKdTreeV2"
/**
* Atoms size threshold for using KD-tree in cutoutSphere routine

View File

@@ -3,7 +3,6 @@ package cz.siret.prank.geom.kdtree.v2
import cz.siret.prank.domain.Protein
import cz.siret.prank.geom.Atoms
import cz.siret.prank.geom.Point
import cz.siret.prank.geom.Struct
import cz.siret.prank.utils.PerfUtils
import groovy.transform.CompileStatic
import org.biojava.nbio.structure.Atom
@@ -15,7 +14,7 @@ import java.util.concurrent.atomic.AtomicInteger
import static org.junit.jupiter.api.Assertions.*
/**
* Tests for KdTree3D and v2.AtomKdTree.
* Tests for KdTree3D and v2.AtomKdTreeV2.
*
* Parity tests compare v2 results against brute-force serial computation
* on real protein data to catch any algorithmic bugs.
@@ -37,7 +36,7 @@ class KdTree3DTest {
double RADIUS = 6d
Atoms atoms = p.proteinAtoms
AtomKdTree kdTree = AtomKdTree.build(atoms)
AtomKdTreeV2 kdTree = AtomKdTreeV2.build(atoms)
for (Atom a : atoms) {
// Brute-force serial scan
@@ -54,7 +53,7 @@ class KdTree3DTest {
void findWithinRadius_multipleRadii() {
Protein p = Protein.load('distro/test_data/2W83.pdb')
Atoms atoms = p.proteinAtoms
AtomKdTree kdTree = AtomKdTree.build(atoms)
AtomKdTreeV2 kdTree = AtomKdTreeV2.build(atoms)
// Test with different radii to exercise different pruning paths
for (double radius : [2d, 6d, 10d, 15d]) {
@@ -71,7 +70,7 @@ class KdTree3DTest {
void findNearest_parity() {
Protein p = Protein.load('distro/test_data/2W83.pdb')
Atoms atoms = p.proteinAtoms
AtomKdTree kdTree = AtomKdTree.build(atoms)
AtomKdTreeV2 kdTree = AtomKdTreeV2.build(atoms)
for (Atom a : atoms) {
// Brute-force: find nearest by scanning all
@@ -97,7 +96,7 @@ class KdTree3DTest {
void findNearestNAtoms_parity() {
Protein p = Protein.load('distro/test_data/2W83.pdb')
Atoms atoms = p.proteinAtoms
AtomKdTree kdTree = AtomKdTree.build(atoms)
AtomKdTreeV2 kdTree = AtomKdTreeV2.build(atoms)
int k = 9 // same as PyramidFeature usage
@@ -129,7 +128,7 @@ class KdTree3DTest {
@Test
void emptyTree() {
Atoms empty = new Atoms(0)
AtomKdTree tree = AtomKdTree.build(empty)
AtomKdTreeV2 tree = AtomKdTreeV2.build(empty)
assertEquals(0, tree.size())
assertNull(tree.findNearest(new Point(0, 0, 0)))
@@ -139,7 +138,7 @@ class KdTree3DTest {
@Test
void singlePoint() {
Point p = new Point(1, 2, 3)
AtomKdTree tree = AtomKdTree.build(new Atoms(p))
AtomKdTreeV2 tree = AtomKdTreeV2.build(new Atoms(p))
assertEquals(1, tree.size())
assertSame(p, tree.findNearest(new Point(0, 0, 0)))
@@ -155,7 +154,7 @@ class KdTree3DTest {
void twoPoints() {
Point p1 = new Point(0, 0, 0)
Point p2 = new Point(10, 0, 0)
AtomKdTree tree = AtomKdTree.build(new Atoms([p1, p2] as List<Atom>))
AtomKdTreeV2 tree = AtomKdTreeV2.build(new Atoms([p1, p2] as List<Atom>))
assertEquals(2, tree.size())
assertSame(p1, tree.findNearest(new Point(1, 0, 0)))
@@ -169,7 +168,7 @@ class KdTree3DTest {
for (int i = 0; i < 100; i++) {
points.add(new Point(i as double, 0, 0))
}
AtomKdTree tree = AtomKdTree.build(new Atoms(points))
AtomKdTreeV2 tree = AtomKdTreeV2.build(new Atoms(points))
assertEquals(100, tree.size())
// Nearest to (50.4, 0, 0) should be the point at x=50
@@ -184,7 +183,7 @@ class KdTree3DTest {
for (int i = 0; i < 50; i++) {
points.add(new Point(5, 5, 5))
}
AtomKdTree tree = AtomKdTree.build(new Atoms(points))
AtomKdTreeV2 tree = AtomKdTreeV2.build(new Atoms(points))
assertEquals(50, tree.size())
// All should be within any radius
@@ -197,7 +196,7 @@ class KdTree3DTest {
Point p1 = new Point(0, 0, 0)
Point p2 = new Point(1, 0, 0)
Point p3 = new Point(10, 0, 0)
AtomKdTree tree = AtomKdTree.build(new Atoms([p1, p2, p3] as List<Atom>))
AtomKdTreeV2 tree = AtomKdTreeV2.build(new Atoms([p1, p2, p3] as List<Atom>))
// findNearestDifferent(p1) should return p2, not p1 itself
Atom different = tree.findNearestDifferent(p1)
@@ -246,7 +245,7 @@ class KdTree3DTest {
// v1 KdTree would fail here due to mutable node.status field.
Protein p = Protein.load('distro/test_data/2W83.pdb')
Atoms atoms = p.proteinAtoms
AtomKdTree kdTree = AtomKdTree.build(atoms)
AtomKdTreeV2 kdTree = AtomKdTreeV2.build(atoms)
double RADIUS = 6d
int THREADS = 8

View File

@@ -2,8 +2,7 @@ package cz.siret.prank.geom.kdtree.v2
import cz.siret.prank.geom.Atoms
import cz.siret.prank.geom.Point
import cz.siret.prank.geom.kdtree.AtomKdTree as AtomKdTreeV1
import cz.siret.prank.geom.kdtree.v2.AtomKdTree as AtomKdTreeV2
import cz.siret.prank.geom.kdtree.v1.AtomKdTreeV1
import groovy.transform.CompileStatic
import org.biojava.nbio.structure.Atom
import org.junit.jupiter.api.BeforeAll