diff --git a/build.gradle b/build.gradle index 7c1a6d37..d9d5d8b7 100644 --- a/build.gradle +++ b/build.gradle @@ -123,6 +123,12 @@ test { jvmArgs '--sun-misc-unsafe-memory-access=allow' } + // Forward kdtree benchmark properties to test JVM + systemProperties System.properties.subMap([ + 'kdtree.benchmark', 'kdtree.points', 'kdtree.queries', + 'kdtree.iterations', 'kdtree.seed', 'kdtree.radii' + ]) + inputs.dir "$distroDir/test_data" inputs.files("$distroDir/models", "$distroDir/config/default.groovy") diff --git a/misc/test-scripts/kdtree-benchmark.sh b/misc/test-scripts/kdtree-benchmark.sh new file mode 100755 index 00000000..d4b0a673 --- /dev/null +++ b/misc/test-scripts/kdtree-benchmark.sh @@ -0,0 +1,104 @@ +#!/usr/bin/env bash + +# +# Runs comparative correctness and performance benchmark for v1 vs v2 KdTree. +# Generates random points, builds both trees, verifies identical query results, +# and measures relative performance. First iteration is warmup (excluded from averages). +# +# Usage: ./misc/test-scripts/kdtree-benchmark.sh [OPTIONS] +# +# Options: +# -p, --points N Number of random data points (default: 5000) +# -q, --queries N Number of random query points (default: 500) +# -i, --iterations N Number of iterations with different seeds (default: 5) +# -s, --seed N Base random seed (default: 42) +# -r, --radii LIST Comma-separated radii to test (default: "2.0,6.0,10.0") +# -h, --help Show this help +# +# Examples: +# +# # Quick smoke test — small dataset, few iterations, fast feedback +# ./misc/test-scripts/kdtree-benchmark.sh -p 1000 -q 100 -i 3 +# +# # Default run — typical protein size (5K atoms), good balance of speed and accuracy +# ./misc/test-scripts/kdtree-benchmark.sh +# +# # Large stress test — 15K points with many queries, more iterations for stable timings +# ./misc/test-scripts/kdtree-benchmark.sh -p 15000 -q 10000 -i 20 +# +# # Reproducibility check — use a specific seed to reproduce exact results +# ./misc/test-scripts/kdtree-benchmark.sh -s 12345 +# +# # Test with tight radii typical for SAS point consolidation (~1.5 A) +# ./misc/test-scripts/kdtree-benchmark.sh -r "1.0,1.5,3.0" +# + +set -euo pipefail + +# Defaults +POINTS=5000 +QUERIES=500 +ITERATIONS=5 +SEED=42 +RADII="2.0,6.0,10.0" + +usage() { + sed -n '3,13p' "$0" | sed 's/^# \?//' + exit 0 +} + +# Parse arguments +while [[ $# -gt 0 ]]; do + case "$1" in + -p|--points) POINTS="$2"; shift 2 ;; + -q|--queries) QUERIES="$2"; shift 2 ;; + -i|--iterations) ITERATIONS="$2"; shift 2 ;; + -s|--seed) SEED="$2"; shift 2 ;; + -r|--radii) RADII="$2"; shift 2 ;; + -h|--help) usage ;; + *) echo "Unknown option: $1"; usage ;; + esac +done + +echo "KdTree Benchmark: points=$POINTS queries=$QUERIES iterations=$ITERATIONS seed=$SEED radii=$RADII" +echo "" + +RESULT_XML="build/test-results/test/TEST-cz.siret.prank.geom.kdtree.v2.KdTreeBenchmarkTest.xml" + +./gradlew cleanTest test \ + --tests 'cz.siret.prank.geom.kdtree.v2.KdTreeBenchmarkTest' \ + -Dkdtree.benchmark=true \ + -Dkdtree.points="$POINTS" \ + -Dkdtree.queries="$QUERIES" \ + -Dkdtree.iterations="$ITERATIONS" \ + -Dkdtree.seed="$SEED" \ + -Dkdtree.radii="$RADII" \ + --quiet + +# Extract and display report from test output (gradle hides stdout by default) +if [[ -f "$RESULT_XML" ]]; then + # Extract content between CDATA tags in system-out + python3 -c " +import xml.etree.ElementTree as ET +tree = ET.parse('$RESULT_XML') +root = tree.getroot() +for tc in root.findall('testcase'): + if tc.get('name') == 'compareV1vsV2()': + errors = tc.findall('failure') + tc.findall('error') + if errors: + print('FAILED:', errors[0].text[:500] if errors[0].text else 'unknown error') +so = root.find('system-out') +if so is not None and so.text: + print(so.text) +" 2>/dev/null || { + # Fallback: show report file + if [[ -f local/kdtree-benchmark-report.txt ]]; then + cat local/kdtree-benchmark-report.txt + else + echo "Test completed. See build/reports/tests/test/index.html for details." + fi +} +else + echo "ERROR: Test result not found. Check build/reports/tests/test/index.html" + exit 1 +fi diff --git a/src/test/groovy/cz/siret/prank/geom/kdtree/v2/KdTreeBenchmarkTest.groovy b/src/test/groovy/cz/siret/prank/geom/kdtree/v2/KdTreeBenchmarkTest.groovy new file mode 100644 index 00000000..463197da --- /dev/null +++ b/src/test/groovy/cz/siret/prank/geom/kdtree/v2/KdTreeBenchmarkTest.groovy @@ -0,0 +1,263 @@ +package cz.siret.prank.geom.kdtree.v2 + +import cz.siret.prank.geom.Atoms +import cz.siret.prank.geom.Point +import cz.siret.prank.geom.kdtree.AtomKdTree as AtomKdTreeV1 +import cz.siret.prank.geom.kdtree.v2.AtomKdTree as AtomKdTreeV2 +import groovy.transform.CompileStatic +import org.biojava.nbio.structure.Atom +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.Test + +import static org.junit.jupiter.api.Assertions.* +import static org.junit.jupiter.api.Assumptions.assumeTrue + +/** + * Comparative correctness and performance benchmark for v1 vs v2 KdTree. + * + * Generates random point sets, builds both trees, verifies identical results + * for all query types, and measures relative performance. + * + * Skipped during normal test runs. Run via: + * ./misc/test-scripts/kdtree-benchmark.sh [OPTIONS] + */ +@CompileStatic +class KdTreeBenchmarkTest { + + @BeforeAll + static void checkEnabled() { + assumeTrue("true" == System.getProperty("kdtree.benchmark"), + "Skipped: set -Dkdtree.benchmark=true to run") + } + + // Configuration from system properties (with defaults) + static final int POINT_COUNT = Integer.getInteger("kdtree.points", 5000) + static final int QUERY_COUNT = Integer.getInteger("kdtree.queries", 500) + static final int ITERATIONS = Integer.getInteger("kdtree.iterations", 5) + static final long BASE_SEED = Long.getLong("kdtree.seed", 42L) + static final double[] RADII = parseRadii(System.getProperty("kdtree.radii", "2.0,6.0,10.0")) + static final int[] KNN_KS = [1, 5, 9, 20] as int[] + + // Timing accumulators (indexed by iteration, excluding warmup) + // Operations: build, findNearest, radius×N, knn + static final int OP_BUILD = 0 + static final int OP_NEAREST = 1 + // OP_RADIUS_BASE + i for each radius + // OP_KNN after radii + + @Test + void compareV1vsV2() { + int radiusCount = RADII.length + int opCount = 2 + radiusCount + 1 // build, nearest, radii..., knn(k=9) + int OP_KNN = 2 + radiusCount + + // Timing arrays: [operation][iteration] in nanoseconds + long[][] v1Times = new long[opCount][ITERATIONS] + long[][] v2Times = new long[opCount][ITERATIONS] + long totalChecks = 0 + + println "=== KdTree v1 vs v2 Benchmark ===" + println "Points: $POINT_COUNT Queries: $QUERY_COUNT Iterations: $ITERATIONS Base seed: $BASE_SEED" + println "Radii: ${RADII.collect { String.format('%.1f', it) }.join(', ')}" + println "" + + for (int iter = 0; iter < ITERATIONS; iter++) { + long seed = BASE_SEED + iter + Random rng = new Random(seed) + + // Generate random data points and query points + List dataPoints = generatePoints(rng, POINT_COUNT) + List queryPoints = generatePoints(rng, QUERY_COUNT) + // Also include some data points as queries (tests self-match behavior) + for (int i = 0; i < Math.min(50, dataPoints.size()); i++) { + queryPoints.add(dataPoints.get(i)) + } + Atoms atoms = new Atoms(dataPoints as List) + + // --- Build both trees --- + long t0, t1 + + t0 = System.nanoTime() + AtomKdTreeV1 v1Tree = AtomKdTreeV1.build(atoms) + t1 = System.nanoTime() + v1Times[OP_BUILD][iter] = t1 - t0 + + t0 = System.nanoTime() + AtomKdTreeV2 v2Tree = AtomKdTreeV2.build(atoms) + t1 = System.nanoTime() + v2Times[OP_BUILD][iter] = t1 - t0 + + // --- findNearest correctness + timing --- + int checks = 0 + + t0 = System.nanoTime() + for (Point q : queryPoints) { + v1Tree.findNearest(q) + } + t1 = System.nanoTime() + v1Times[OP_NEAREST][iter] = t1 - t0 + + t0 = System.nanoTime() + for (Point q : queryPoints) { + v2Tree.findNearest(q) + } + t1 = System.nanoTime() + v2Times[OP_NEAREST][iter] = t1 - t0 + + // Correctness check (separate pass to not pollute timing) + for (Point q : queryPoints) { + double v1dist = v1Tree.nearestSqrDist(q) + double v2dist = v2Tree.nearestSqrDist(q) + assertEquals(v1dist, v2dist, 1e-10, + "nearestSqrDist mismatch, iter=$iter, seed=$seed, q=(${q.getX()},${q.getY()},${q.getZ()})") + checks++ + } + + // --- findAtomsWithinRadius correctness + timing --- + for (int ri = 0; ri < radiusCount; ri++) { + double radius = RADII[ri] + int opIdx = 2 + ri + + t0 = System.nanoTime() + for (Point q : queryPoints) { + v1Tree.findAtomsWithinRadius(q, radius, false) + } + t1 = System.nanoTime() + v1Times[opIdx][iter] = t1 - t0 + + t0 = System.nanoTime() + for (Point q : queryPoints) { + v2Tree.findAtomsWithinRadius(q, radius, false) + } + t1 = System.nanoTime() + v2Times[opIdx][iter] = t1 - t0 + + // Correctness + for (Point q : queryPoints) { + Set v1set = v1Tree.findAtomsWithinRadius(q, radius, false).list.toSet() + Set v2set = v2Tree.findAtomsWithinRadius(q, radius, false).list.toSet() + assertEquals(v1set, v2set, + "radius=$radius mismatch, iter=$iter, seed=$seed, q=(${q.getX()},${q.getY()},${q.getZ()})") + checks++ + } + } + + // --- findNearestNAtoms correctness + timing (k=9 for timing) --- + t0 = System.nanoTime() + for (Point q : queryPoints) { + v1Tree.findNearestNAtoms(q, 9, false) + } + t1 = System.nanoTime() + v1Times[OP_KNN][iter] = t1 - t0 + + t0 = System.nanoTime() + for (Point q : queryPoints) { + v2Tree.findNearestNAtoms(q, 9, false) + } + t1 = System.nanoTime() + v2Times[OP_KNN][iter] = t1 - t0 + + // Correctness for all k values + for (int k : KNN_KS) { + for (Point q : queryPoints) { + Set v1set = v1Tree.findNearestNAtoms(q, k, false).list.toSet() + Set v2set = v2Tree.findNearestNAtoms(q, k, false).list.toSet() + assertEquals(v1set, v2set, + "kNN k=$k mismatch, iter=$iter, seed=$seed, q=(${q.getX()},${q.getY()},${q.getZ()})") + checks++ + } + } + + totalChecks += checks + println " Iteration ${iter + 1}/$ITERATIONS (seed=$seed): $checks checks passed" + } + + // --- Generate report --- + println "" + String report = generateReport(v1Times, v2Times, RADII, totalChecks) + println report + + // Write to file if local/ directory exists + File reportFile = new File("local/kdtree-benchmark-report.txt") + if (reportFile.parentFile.exists()) { + reportFile.text = report + println "(Report written to local/kdtree-benchmark-report.txt)" + } + } + + // ==================== Helpers ==================== + + private static List generatePoints(Random rng, int count) { + List points = new ArrayList<>(count) + for (int i = 0; i < count; i++) { + double x = rng.nextDouble() * 100.0 - 50.0 // [-50, 50] + double y = rng.nextDouble() * 100.0 - 50.0 + double z = rng.nextDouble() * 100.0 - 50.0 + points.add(new Point(x, y, z)) + } + return points + } + + private static double[] parseRadii(String s) { + String[] parts = s.split(",") + double[] result = new double[parts.length] + for (int i = 0; i < parts.length; i++) { + result[i] = Double.parseDouble(parts[i].trim()) + } + return result + } + + /** + * Compute average time in ms, skipping the first iteration (warmup). + */ + private static double avgMs(long[] nanoTimes) { + if (nanoTimes.length <= 1) { + return nanoTimes[0] / 1_000_000.0 + } + long sum = 0 + for (int i = 1; i < nanoTimes.length; i++) { + sum += nanoTimes[i] + } + return (sum / (nanoTimes.length - 1)) / 1_000_000.0 + } + + private static String generateReport(long[][] v1Times, long[][] v2Times, double[] radii, long totalChecks) { + StringBuilder sb = new StringBuilder() + sb.append("=== KdTree Benchmark Report ===\n") + sb.append("Points: $POINT_COUNT Queries: $QUERY_COUNT Iterations: $ITERATIONS Base seed: $BASE_SEED\n") + if (ITERATIONS > 1) { + sb.append("(First iteration is warmup, excluded from averages)\n") + } + sb.append("\n") + sb.append(String.format("%-28s %10s %10s %10s\n", "Operation", "v1 (ms)", "v2 (ms)", "speedup")) + sb.append("-".multiply(60)) + sb.append("\n") + + // Build + appendRow(sb, "Build", v1Times[OP_BUILD], v2Times[OP_BUILD]) + + // findNearest + appendRow(sb, "findNearest", v1Times[OP_NEAREST], v2Times[OP_NEAREST]) + + // Radius queries + for (int i = 0; i < radii.length; i++) { + appendRow(sb, "radius r=${String.format('%.1f', radii[i])}", v1Times[2 + i], v2Times[2 + i]) + } + + // k-NN + int opKnn = 2 + radii.length + appendRow(sb, "findNearestN k=9", v1Times[opKnn], v2Times[opKnn]) + + sb.append("-".multiply(60)) + sb.append("\n") + sb.append("\nAll correctness checks passed: $totalChecks total\n") + return sb.toString() + } + + private static void appendRow(StringBuilder sb, String name, long[] v1Nanos, long[] v2Nanos) { + double v1ms = avgMs(v1Nanos) + double v2ms = avgMs(v2Nanos) + double speedup = v1ms > 0 ? v1ms / v2ms : 0 + sb.append(String.format("%-28s %10.1f %10.1f %9.2fx\n", name, v1ms, v2ms, speedup)) + } +}