mirror of
https://github.com/rdk/p2rank.git
synced 2026-06-04 12:44:24 +08:00
Add comparative benchmark test for v1 vs v2 KdTree
Parametrized test generates random points, builds both trees, verifies identical results for all query types, and measures relative performance. Skipped during normal test runs; invoked via kdtree-benchmark.sh script.
This commit is contained in:
@@ -123,6 +123,12 @@ test {
|
||||
jvmArgs '--sun-misc-unsafe-memory-access=allow'
|
||||
}
|
||||
|
||||
// Forward kdtree benchmark properties to test JVM
|
||||
systemProperties System.properties.subMap([
|
||||
'kdtree.benchmark', 'kdtree.points', 'kdtree.queries',
|
||||
'kdtree.iterations', 'kdtree.seed', 'kdtree.radii'
|
||||
])
|
||||
|
||||
inputs.dir "$distroDir/test_data"
|
||||
inputs.files("$distroDir/models", "$distroDir/config/default.groovy")
|
||||
|
||||
|
||||
104
misc/test-scripts/kdtree-benchmark.sh
Executable file
104
misc/test-scripts/kdtree-benchmark.sh
Executable file
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
#
|
||||
# Runs comparative correctness and performance benchmark for v1 vs v2 KdTree.
|
||||
# Generates random points, builds both trees, verifies identical query results,
|
||||
# and measures relative performance. First iteration is warmup (excluded from averages).
|
||||
#
|
||||
# Usage: ./misc/test-scripts/kdtree-benchmark.sh [OPTIONS]
|
||||
#
|
||||
# Options:
|
||||
# -p, --points N Number of random data points (default: 5000)
|
||||
# -q, --queries N Number of random query points (default: 500)
|
||||
# -i, --iterations N Number of iterations with different seeds (default: 5)
|
||||
# -s, --seed N Base random seed (default: 42)
|
||||
# -r, --radii LIST Comma-separated radii to test (default: "2.0,6.0,10.0")
|
||||
# -h, --help Show this help
|
||||
#
|
||||
# Examples:
|
||||
#
|
||||
# # Quick smoke test — small dataset, few iterations, fast feedback
|
||||
# ./misc/test-scripts/kdtree-benchmark.sh -p 1000 -q 100 -i 3
|
||||
#
|
||||
# # Default run — typical protein size (5K atoms), good balance of speed and accuracy
|
||||
# ./misc/test-scripts/kdtree-benchmark.sh
|
||||
#
|
||||
# # Large stress test — 15K points with many queries, more iterations for stable timings
|
||||
# ./misc/test-scripts/kdtree-benchmark.sh -p 15000 -q 10000 -i 20
|
||||
#
|
||||
# # Reproducibility check — use a specific seed to reproduce exact results
|
||||
# ./misc/test-scripts/kdtree-benchmark.sh -s 12345
|
||||
#
|
||||
# # Test with tight radii typical for SAS point consolidation (~1.5 A)
|
||||
# ./misc/test-scripts/kdtree-benchmark.sh -r "1.0,1.5,3.0"
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Defaults
|
||||
POINTS=5000
|
||||
QUERIES=500
|
||||
ITERATIONS=5
|
||||
SEED=42
|
||||
RADII="2.0,6.0,10.0"
|
||||
|
||||
usage() {
|
||||
sed -n '3,13p' "$0" | sed 's/^# \?//'
|
||||
exit 0
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
-p|--points) POINTS="$2"; shift 2 ;;
|
||||
-q|--queries) QUERIES="$2"; shift 2 ;;
|
||||
-i|--iterations) ITERATIONS="$2"; shift 2 ;;
|
||||
-s|--seed) SEED="$2"; shift 2 ;;
|
||||
-r|--radii) RADII="$2"; shift 2 ;;
|
||||
-h|--help) usage ;;
|
||||
*) echo "Unknown option: $1"; usage ;;
|
||||
esac
|
||||
done
|
||||
|
||||
echo "KdTree Benchmark: points=$POINTS queries=$QUERIES iterations=$ITERATIONS seed=$SEED radii=$RADII"
|
||||
echo ""
|
||||
|
||||
RESULT_XML="build/test-results/test/TEST-cz.siret.prank.geom.kdtree.v2.KdTreeBenchmarkTest.xml"
|
||||
|
||||
./gradlew cleanTest test \
|
||||
--tests 'cz.siret.prank.geom.kdtree.v2.KdTreeBenchmarkTest' \
|
||||
-Dkdtree.benchmark=true \
|
||||
-Dkdtree.points="$POINTS" \
|
||||
-Dkdtree.queries="$QUERIES" \
|
||||
-Dkdtree.iterations="$ITERATIONS" \
|
||||
-Dkdtree.seed="$SEED" \
|
||||
-Dkdtree.radii="$RADII" \
|
||||
--quiet
|
||||
|
||||
# Extract and display report from test output (gradle hides stdout by default)
|
||||
if [[ -f "$RESULT_XML" ]]; then
|
||||
# Extract content between CDATA tags in system-out
|
||||
python3 -c "
|
||||
import xml.etree.ElementTree as ET
|
||||
tree = ET.parse('$RESULT_XML')
|
||||
root = tree.getroot()
|
||||
for tc in root.findall('testcase'):
|
||||
if tc.get('name') == 'compareV1vsV2()':
|
||||
errors = tc.findall('failure') + tc.findall('error')
|
||||
if errors:
|
||||
print('FAILED:', errors[0].text[:500] if errors[0].text else 'unknown error')
|
||||
so = root.find('system-out')
|
||||
if so is not None and so.text:
|
||||
print(so.text)
|
||||
" 2>/dev/null || {
|
||||
# Fallback: show report file
|
||||
if [[ -f local/kdtree-benchmark-report.txt ]]; then
|
||||
cat local/kdtree-benchmark-report.txt
|
||||
else
|
||||
echo "Test completed. See build/reports/tests/test/index.html for details."
|
||||
fi
|
||||
}
|
||||
else
|
||||
echo "ERROR: Test result not found. Check build/reports/tests/test/index.html"
|
||||
exit 1
|
||||
fi
|
||||
@@ -0,0 +1,263 @@
|
||||
package cz.siret.prank.geom.kdtree.v2
|
||||
|
||||
import cz.siret.prank.geom.Atoms
|
||||
import cz.siret.prank.geom.Point
|
||||
import cz.siret.prank.geom.kdtree.AtomKdTree as AtomKdTreeV1
|
||||
import cz.siret.prank.geom.kdtree.v2.AtomKdTree as AtomKdTreeV2
|
||||
import groovy.transform.CompileStatic
|
||||
import org.biojava.nbio.structure.Atom
|
||||
import org.junit.jupiter.api.BeforeAll
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*
|
||||
import static org.junit.jupiter.api.Assumptions.assumeTrue
|
||||
|
||||
/**
|
||||
* Comparative correctness and performance benchmark for v1 vs v2 KdTree.
|
||||
*
|
||||
* Generates random point sets, builds both trees, verifies identical results
|
||||
* for all query types, and measures relative performance.
|
||||
*
|
||||
* Skipped during normal test runs. Run via:
|
||||
* ./misc/test-scripts/kdtree-benchmark.sh [OPTIONS]
|
||||
*/
|
||||
@CompileStatic
|
||||
class KdTreeBenchmarkTest {
|
||||
|
||||
@BeforeAll
|
||||
static void checkEnabled() {
|
||||
assumeTrue("true" == System.getProperty("kdtree.benchmark"),
|
||||
"Skipped: set -Dkdtree.benchmark=true to run")
|
||||
}
|
||||
|
||||
// Configuration from system properties (with defaults)
|
||||
static final int POINT_COUNT = Integer.getInteger("kdtree.points", 5000)
|
||||
static final int QUERY_COUNT = Integer.getInteger("kdtree.queries", 500)
|
||||
static final int ITERATIONS = Integer.getInteger("kdtree.iterations", 5)
|
||||
static final long BASE_SEED = Long.getLong("kdtree.seed", 42L)
|
||||
static final double[] RADII = parseRadii(System.getProperty("kdtree.radii", "2.0,6.0,10.0"))
|
||||
static final int[] KNN_KS = [1, 5, 9, 20] as int[]
|
||||
|
||||
// Timing accumulators (indexed by iteration, excluding warmup)
|
||||
// Operations: build, findNearest, radius×N, knn
|
||||
static final int OP_BUILD = 0
|
||||
static final int OP_NEAREST = 1
|
||||
// OP_RADIUS_BASE + i for each radius
|
||||
// OP_KNN after radii
|
||||
|
||||
@Test
|
||||
void compareV1vsV2() {
|
||||
int radiusCount = RADII.length
|
||||
int opCount = 2 + radiusCount + 1 // build, nearest, radii..., knn(k=9)
|
||||
int OP_KNN = 2 + radiusCount
|
||||
|
||||
// Timing arrays: [operation][iteration] in nanoseconds
|
||||
long[][] v1Times = new long[opCount][ITERATIONS]
|
||||
long[][] v2Times = new long[opCount][ITERATIONS]
|
||||
long totalChecks = 0
|
||||
|
||||
println "=== KdTree v1 vs v2 Benchmark ==="
|
||||
println "Points: $POINT_COUNT Queries: $QUERY_COUNT Iterations: $ITERATIONS Base seed: $BASE_SEED"
|
||||
println "Radii: ${RADII.collect { String.format('%.1f', it) }.join(', ')}"
|
||||
println ""
|
||||
|
||||
for (int iter = 0; iter < ITERATIONS; iter++) {
|
||||
long seed = BASE_SEED + iter
|
||||
Random rng = new Random(seed)
|
||||
|
||||
// Generate random data points and query points
|
||||
List<Point> dataPoints = generatePoints(rng, POINT_COUNT)
|
||||
List<Point> queryPoints = generatePoints(rng, QUERY_COUNT)
|
||||
// Also include some data points as queries (tests self-match behavior)
|
||||
for (int i = 0; i < Math.min(50, dataPoints.size()); i++) {
|
||||
queryPoints.add(dataPoints.get(i))
|
||||
}
|
||||
Atoms atoms = new Atoms(dataPoints as List<Atom>)
|
||||
|
||||
// --- Build both trees ---
|
||||
long t0, t1
|
||||
|
||||
t0 = System.nanoTime()
|
||||
AtomKdTreeV1 v1Tree = AtomKdTreeV1.build(atoms)
|
||||
t1 = System.nanoTime()
|
||||
v1Times[OP_BUILD][iter] = t1 - t0
|
||||
|
||||
t0 = System.nanoTime()
|
||||
AtomKdTreeV2 v2Tree = AtomKdTreeV2.build(atoms)
|
||||
t1 = System.nanoTime()
|
||||
v2Times[OP_BUILD][iter] = t1 - t0
|
||||
|
||||
// --- findNearest correctness + timing ---
|
||||
int checks = 0
|
||||
|
||||
t0 = System.nanoTime()
|
||||
for (Point q : queryPoints) {
|
||||
v1Tree.findNearest(q)
|
||||
}
|
||||
t1 = System.nanoTime()
|
||||
v1Times[OP_NEAREST][iter] = t1 - t0
|
||||
|
||||
t0 = System.nanoTime()
|
||||
for (Point q : queryPoints) {
|
||||
v2Tree.findNearest(q)
|
||||
}
|
||||
t1 = System.nanoTime()
|
||||
v2Times[OP_NEAREST][iter] = t1 - t0
|
||||
|
||||
// Correctness check (separate pass to not pollute timing)
|
||||
for (Point q : queryPoints) {
|
||||
double v1dist = v1Tree.nearestSqrDist(q)
|
||||
double v2dist = v2Tree.nearestSqrDist(q)
|
||||
assertEquals(v1dist, v2dist, 1e-10,
|
||||
"nearestSqrDist mismatch, iter=$iter, seed=$seed, q=(${q.getX()},${q.getY()},${q.getZ()})")
|
||||
checks++
|
||||
}
|
||||
|
||||
// --- findAtomsWithinRadius correctness + timing ---
|
||||
for (int ri = 0; ri < radiusCount; ri++) {
|
||||
double radius = RADII[ri]
|
||||
int opIdx = 2 + ri
|
||||
|
||||
t0 = System.nanoTime()
|
||||
for (Point q : queryPoints) {
|
||||
v1Tree.findAtomsWithinRadius(q, radius, false)
|
||||
}
|
||||
t1 = System.nanoTime()
|
||||
v1Times[opIdx][iter] = t1 - t0
|
||||
|
||||
t0 = System.nanoTime()
|
||||
for (Point q : queryPoints) {
|
||||
v2Tree.findAtomsWithinRadius(q, radius, false)
|
||||
}
|
||||
t1 = System.nanoTime()
|
||||
v2Times[opIdx][iter] = t1 - t0
|
||||
|
||||
// Correctness
|
||||
for (Point q : queryPoints) {
|
||||
Set<Atom> v1set = v1Tree.findAtomsWithinRadius(q, radius, false).list.toSet()
|
||||
Set<Atom> v2set = v2Tree.findAtomsWithinRadius(q, radius, false).list.toSet()
|
||||
assertEquals(v1set, v2set,
|
||||
"radius=$radius mismatch, iter=$iter, seed=$seed, q=(${q.getX()},${q.getY()},${q.getZ()})")
|
||||
checks++
|
||||
}
|
||||
}
|
||||
|
||||
// --- findNearestNAtoms correctness + timing (k=9 for timing) ---
|
||||
t0 = System.nanoTime()
|
||||
for (Point q : queryPoints) {
|
||||
v1Tree.findNearestNAtoms(q, 9, false)
|
||||
}
|
||||
t1 = System.nanoTime()
|
||||
v1Times[OP_KNN][iter] = t1 - t0
|
||||
|
||||
t0 = System.nanoTime()
|
||||
for (Point q : queryPoints) {
|
||||
v2Tree.findNearestNAtoms(q, 9, false)
|
||||
}
|
||||
t1 = System.nanoTime()
|
||||
v2Times[OP_KNN][iter] = t1 - t0
|
||||
|
||||
// Correctness for all k values
|
||||
for (int k : KNN_KS) {
|
||||
for (Point q : queryPoints) {
|
||||
Set<Atom> v1set = v1Tree.findNearestNAtoms(q, k, false).list.toSet()
|
||||
Set<Atom> v2set = v2Tree.findNearestNAtoms(q, k, false).list.toSet()
|
||||
assertEquals(v1set, v2set,
|
||||
"kNN k=$k mismatch, iter=$iter, seed=$seed, q=(${q.getX()},${q.getY()},${q.getZ()})")
|
||||
checks++
|
||||
}
|
||||
}
|
||||
|
||||
totalChecks += checks
|
||||
println " Iteration ${iter + 1}/$ITERATIONS (seed=$seed): $checks checks passed"
|
||||
}
|
||||
|
||||
// --- Generate report ---
|
||||
println ""
|
||||
String report = generateReport(v1Times, v2Times, RADII, totalChecks)
|
||||
println report
|
||||
|
||||
// Write to file if local/ directory exists
|
||||
File reportFile = new File("local/kdtree-benchmark-report.txt")
|
||||
if (reportFile.parentFile.exists()) {
|
||||
reportFile.text = report
|
||||
println "(Report written to local/kdtree-benchmark-report.txt)"
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Helpers ====================
|
||||
|
||||
private static List<Point> generatePoints(Random rng, int count) {
|
||||
List<Point> points = new ArrayList<>(count)
|
||||
for (int i = 0; i < count; i++) {
|
||||
double x = rng.nextDouble() * 100.0 - 50.0 // [-50, 50]
|
||||
double y = rng.nextDouble() * 100.0 - 50.0
|
||||
double z = rng.nextDouble() * 100.0 - 50.0
|
||||
points.add(new Point(x, y, z))
|
||||
}
|
||||
return points
|
||||
}
|
||||
|
||||
private static double[] parseRadii(String s) {
|
||||
String[] parts = s.split(",")
|
||||
double[] result = new double[parts.length]
|
||||
for (int i = 0; i < parts.length; i++) {
|
||||
result[i] = Double.parseDouble(parts[i].trim())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute average time in ms, skipping the first iteration (warmup).
|
||||
*/
|
||||
private static double avgMs(long[] nanoTimes) {
|
||||
if (nanoTimes.length <= 1) {
|
||||
return nanoTimes[0] / 1_000_000.0
|
||||
}
|
||||
long sum = 0
|
||||
for (int i = 1; i < nanoTimes.length; i++) {
|
||||
sum += nanoTimes[i]
|
||||
}
|
||||
return (sum / (nanoTimes.length - 1)) / 1_000_000.0
|
||||
}
|
||||
|
||||
private static String generateReport(long[][] v1Times, long[][] v2Times, double[] radii, long totalChecks) {
|
||||
StringBuilder sb = new StringBuilder()
|
||||
sb.append("=== KdTree Benchmark Report ===\n")
|
||||
sb.append("Points: $POINT_COUNT Queries: $QUERY_COUNT Iterations: $ITERATIONS Base seed: $BASE_SEED\n")
|
||||
if (ITERATIONS > 1) {
|
||||
sb.append("(First iteration is warmup, excluded from averages)\n")
|
||||
}
|
||||
sb.append("\n")
|
||||
sb.append(String.format("%-28s %10s %10s %10s\n", "Operation", "v1 (ms)", "v2 (ms)", "speedup"))
|
||||
sb.append("-".multiply(60))
|
||||
sb.append("\n")
|
||||
|
||||
// Build
|
||||
appendRow(sb, "Build", v1Times[OP_BUILD], v2Times[OP_BUILD])
|
||||
|
||||
// findNearest
|
||||
appendRow(sb, "findNearest", v1Times[OP_NEAREST], v2Times[OP_NEAREST])
|
||||
|
||||
// Radius queries
|
||||
for (int i = 0; i < radii.length; i++) {
|
||||
appendRow(sb, "radius r=${String.format('%.1f', radii[i])}", v1Times[2 + i], v2Times[2 + i])
|
||||
}
|
||||
|
||||
// k-NN
|
||||
int opKnn = 2 + radii.length
|
||||
appendRow(sb, "findNearestN k=9", v1Times[opKnn], v2Times[opKnn])
|
||||
|
||||
sb.append("-".multiply(60))
|
||||
sb.append("\n")
|
||||
sb.append("\nAll correctness checks passed: $totalChecks total\n")
|
||||
return sb.toString()
|
||||
}
|
||||
|
||||
private static void appendRow(StringBuilder sb, String name, long[] v1Nanos, long[] v2Nanos) {
|
||||
double v1ms = avgMs(v1Nanos)
|
||||
double v2ms = avgMs(v2Nanos)
|
||||
double speedup = v1ms > 0 ? v1ms / v2ms : 0
|
||||
sb.append(String.format("%-28s %10.1f %10.1f %9.2fx\n", name, v1ms, v2ms, speedup))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user