Add fused addWeighted and O(N²) single-linkage clusterer

Add GenericVector.addWeighted() for fused multiply-add, eliminating per-
neighbor array allocation in feature vector aggregation. Add SLinkClustererV2
using union-find with path compression, reducing single-linkage clustering
from O(N³) to O(N²). Wire V2 via factory methods on AtomClusterer and
AtomGroupClusterer.
This commit is contained in:
rdk
2026-03-03 05:13:05 +01:00
parent 261dae09c9
commit 8f5da9fdcd
8 changed files with 355 additions and 4 deletions

View File

@@ -303,7 +303,7 @@ class PrankFeatureExtractor extends FeatureExtractor<PrankFeatureVector> impleme
double weight = calcWeight(dist)
weightSum += weight
res.add( props.copy().multiply(weight) )
res.addWeighted(props, weight)
}
if (AVERAGE_FEAT_VECTORS) {

View File

@@ -113,4 +113,12 @@ public class PrankFeatureVector extends FeatureVector implements Cloneable {
return this;
}
/**
* modifies this instance: this += p * weight (no allocation)
*/
public PrankFeatureVector addWeighted(PrankFeatureVector p, double weight) {
valueVector.addWeighted(p.valueVector, weight);
return this;
}
}

View File

@@ -93,6 +93,18 @@ public class GenericVector {
return this;
}
/**
* modifies instance: data[i] += gv.data[i] * weight
*/
public GenericVector addWeighted(final GenericVector gv, double weight) {
final double[] gvData = gv.data;
final int n = data.length;
for (int i = 0; i < n; ++i) {
data[i] += gvData[i] * weight;
}
return this;
}
/**
* modifies instance
*/

View File

@@ -7,7 +7,6 @@ import cz.siret.prank.domain.Residue
import cz.siret.prank.domain.ResidueChain
import cz.siret.prank.geom.clustering.AtomClusterer
import cz.siret.prank.geom.clustering.AtomGroupClusterer
import cz.siret.prank.geom.clustering.SLinkClusterer
import cz.siret.prank.utils.Cutils
import cz.siret.prank.utils.PdbUtils
import cz.siret.prank.utils.PerfUtils
@@ -178,11 +177,11 @@ class Struct {
* @return
*/
static List<Atoms> clusterAtoms(Atoms atoms, double clusterDist) {
return new AtomClusterer(new SLinkClusterer<Atom>()).clusterAtoms(atoms, clusterDist)
return AtomClusterer.singleLinkage().clusterAtoms(atoms, clusterDist)
}
static List<Atoms> clusterAtomGroups(List<Atoms> atomGroups, double clusterDist ) {
return new AtomGroupClusterer(new SLinkClusterer()).clusterGroups(atomGroups, clusterDist)
return AtomGroupClusterer.singleLinkage().clusterGroups(atomGroups, clusterDist)
}
static final Ordering<Group> GROUP_ORDERING = new Ordering<Group>() {

View File

@@ -19,6 +19,10 @@ class AtomClusterer implements Clusterer<Atom> {
this.clusteringAlgorithm = clusteringAlgorithm
}
static AtomClusterer singleLinkage() {
return new AtomClusterer(new SLinkClustererV2<Atom>())
}
@Override
List<List<Atom>> cluster(List<Atom> elements, double minDist, Clusterer.Distance<Atom> distDef) {
return clusteringAlgorithm.cluster(elements, minDist, distDef)

View File

@@ -17,6 +17,10 @@ class AtomGroupClusterer implements Clusterer<Atoms> {
this.clusteringAlgorithm = clusteringAlgorithm
}
static AtomGroupClusterer singleLinkage() {
return new AtomGroupClusterer(new SLinkClustererV2<Atoms>())
}
@Override
List<List<Atoms>> cluster(List<Atoms> elements, double minDist, Clusterer.Distance<Atoms> distDef) {
return clusteringAlgorithm.cluster(elements, minDist, distDef)

View File

@@ -0,0 +1,82 @@
package cz.siret.prank.geom.clustering;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
/**
* Single-linkage clusterer using Union-Find with path compression and union by rank.
* O(N² α(N)) ≈ O(N²) — eliminates the O(N) relabeling loop of SLinkClusterer (O(N³)).
*/
public class SLinkClustererV2<E> implements Clusterer<E> {
private static final Logger log = LoggerFactory.getLogger(SLinkClustererV2.class);
private int[] parent;
private int[] rank;
private int find(int x) {
while (parent[x] != x) {
parent[x] = parent[parent[x]]; // path halving
x = parent[x];
}
return x;
}
private void union(int a, int b) {
int ra = find(a);
int rb = find(b);
if (ra == rb) return;
if (rank[ra] < rank[rb]) {
parent[ra] = rb;
} else if (rank[ra] > rank[rb]) {
parent[rb] = ra;
} else {
parent[rb] = ra;
rank[ra]++;
}
}
@Override
@SuppressWarnings("unchecked")
public List<List<E>> cluster(List<E> elements, double minDist, Clusterer.Distance<E> distDef) {
if (elements.isEmpty()) return Collections.emptyList();
if (elements.size() == 1) return new ArrayList<>(Collections.singletonList(elements));
Object[] els = elements.toArray();
int N = els.length;
parent = new int[N];
rank = new int[N];
for (int i = 0; i < N; i++) {
parent[i] = i;
}
// Check all pairs; merge if within minDist (same iteration order as V1)
for (int j = N - 1; j >= 1; j--) {
for (int i = j - 1; i >= 0; i--) {
if (find(i) != find(j)) {
double dist = distDef.dist((E) els[i], (E) els[j]);
if (dist <= minDist) {
union(i, j);
}
}
}
}
// Collect clusters
Map<Integer, List<E>> clusterMap = new LinkedHashMap<>();
for (int i = 0; i < N; i++) {
int root = find(i);
clusterMap.computeIfAbsent(root, k -> new ArrayList<>()).add((E) els[i]);
}
List<List<E>> result = new ArrayList<>(clusterMap.values());
log.info("clusters: {}", result);
log.info("clusters together: {} / {}", result.stream().mapToInt(List::size).sum(), elements.size());
return result;
}
}

View File

@@ -0,0 +1,242 @@
package cz.siret.prank.geom.clustering
import cz.siret.prank.domain.Protein
import cz.siret.prank.geom.Atoms
import groovy.transform.CompileStatic
import org.biojava.nbio.structure.Atom
import org.junit.jupiter.api.Test
import static org.junit.jupiter.api.Assertions.*
@CompileStatic
class SLinkClustererTest {
static final Clusterer.Distance<Integer> INT_DIST = new Clusterer.Distance<Integer>() {
@Override
double dist(Integer a, Integer b) {
return Math.abs(a - b) as double
}
}
private static Set<Set<Integer>> toSetOfSets(List<List<Integer>> clusters) {
return clusters.collect { it.toSet() }.toSet()
}
// ==================== Basic correctness ====================
@Test
void emptyInput() {
def v1 = new SLinkClusterer<Integer>().cluster([], 1.0, INT_DIST)
def v2 = new SLinkClustererV2<Integer>().cluster([], 1.0, INT_DIST)
assertTrue(v1.isEmpty())
assertTrue(v2.isEmpty())
}
@Test
void singleElement() {
def v1 = new SLinkClusterer<Integer>().cluster([42], 1.0, INT_DIST)
def v2 = new SLinkClustererV2<Integer>().cluster([42], 1.0, INT_DIST)
assertEquals(toSetOfSets(v1), toSetOfSets(v2))
assertEquals(1, v2.size())
assertEquals([42], v2[0])
}
@Test
void twoClusters() {
// 1,2,3 should cluster together; 10,11 should cluster together (dist <= 2)
def elements = [1, 2, 3, 10, 11]
def v1 = new SLinkClusterer<Integer>().cluster(elements, 2.0, INT_DIST)
def v2 = new SLinkClustererV2<Integer>().cluster(elements, 2.0, INT_DIST)
def expected = [[1, 2, 3].toSet(), [10, 11].toSet()].toSet()
assertEquals(expected, toSetOfSets(v1))
assertEquals(expected, toSetOfSets(v2))
}
@Test
void allInOneCluster() {
def elements = [1, 2, 3, 4, 5]
def v1 = new SLinkClusterer<Integer>().cluster(elements, 5.0, INT_DIST)
def v2 = new SLinkClustererV2<Integer>().cluster(elements, 5.0, INT_DIST)
assertEquals(1, v1.size())
assertEquals(1, v2.size())
assertEquals(elements.toSet(), v2[0].toSet())
}
@Test
void allSingletons() {
def elements = [1, 10, 20, 30]
def v1 = new SLinkClusterer<Integer>().cluster(elements, 0.5, INT_DIST)
def v2 = new SLinkClustererV2<Integer>().cluster(elements, 0.5, INT_DIST)
assertEquals(4, v1.size())
assertEquals(4, v2.size())
assertEquals(toSetOfSets(v1), toSetOfSets(v2))
}
@Test
void transitiveChaining() {
// 1-3-5: each adjacent pair within dist 2, so all should merge via single-linkage
def elements = [1, 3, 5]
def v1 = new SLinkClusterer<Integer>().cluster(elements, 2.0, INT_DIST)
def v2 = new SLinkClustererV2<Integer>().cluster(elements, 2.0, INT_DIST)
assertEquals(1, v1.size())
assertEquals(1, v2.size())
assertEquals(toSetOfSets(v1), toSetOfSets(v2))
}
@Test
void boundaryDistance() {
// Elements exactly at minDist should be merged (dist <= minDist)
def elements = [0, 5]
def v1 = new SLinkClusterer<Integer>().cluster(elements, 5.0, INT_DIST)
def v2 = new SLinkClustererV2<Integer>().cluster(elements, 5.0, INT_DIST)
assertEquals(1, v1.size())
assertEquals(1, v2.size())
// Just below: should not merge
def v1b = new SLinkClusterer<Integer>().cluster(elements, 4.9, INT_DIST)
def v2b = new SLinkClustererV2<Integer>().cluster(elements, 4.9, INT_DIST)
assertEquals(2, v1b.size())
assertEquals(2, v2b.size())
}
// ==================== Random 3D data parity ====================
static final Clusterer.Distance<double[]> SQR_EUCLID_3D = new Clusterer.Distance<double[]>() {
@Override
double dist(double[] a, double[] b) {
double dx = a[0] - b[0]
double dy = a[1] - b[1]
double dz = a[2] - b[2]
return dx * dx + dy * dy + dz * dz
}
}
private static Set<Set<Integer>> indexSets(List<double[]> elements, List<List<double[]>> clusters) {
Map<double[], Integer> indexMap = new IdentityHashMap<>()
for (int i = 0; i < elements.size(); i++) {
indexMap.put(elements[i], i)
}
return clusters.collect { List<double[]> c -> c.collect { indexMap.get(it) }.toSet() }.toSet()
}
@Test
void parity_random3d_sparse() {
parity_random3d(50, 0.0, 100.0, 5.0, 42L)
}
@Test
void parity_random3d_dense() {
parity_random3d(80, 0.0, 10.0, 25.0, 123L)
}
@Test
void parity_random3d_medium() {
parity_random3d(100, 0.0, 50.0, 15.0, 7L)
}
@Test
void parity_random3d_varyingDistances() {
Random rng = new Random(999L)
List<double[]> points = generateRandom3dPoints(60, 0.0, 30.0, rng)
for (double minDist : [1.0, 5.0, 10.0, 20.0, 50.0]) {
double sqrDist = minDist * minDist
def v1 = new SLinkClusterer<double[]>().cluster(points, sqrDist, SQR_EUCLID_3D)
def v2 = new SLinkClustererV2<double[]>().cluster(points, sqrDist, SQR_EUCLID_3D)
assertEquals(indexSets(points, v1), indexSets(points, v2),
"Mismatch for minDist=$minDist")
}
}
@Test
void parity_random3d_duplicateCoordinates() {
List<double[]> points = []
Random rng = new Random(55L)
for (int i = 0; i < 5; i++) {
points.add([5.0, 5.0, 5.0] as double[])
points.add([20.0, 20.0, 20.0] as double[])
}
for (int i = 0; i < 5; i++) {
points.add([rng.nextDouble() * 100, rng.nextDouble() * 100, rng.nextDouble() * 100] as double[])
}
Collections.shuffle(points, rng)
double sqrDist = 1.0
def v1 = new SLinkClusterer<double[]>().cluster(points, sqrDist, SQR_EUCLID_3D)
def v2 = new SLinkClustererV2<double[]>().cluster(points, sqrDist, SQR_EUCLID_3D)
assertEquals(indexSets(points, v1), indexSets(points, v2))
}
@Test
void parity_random3d_largerSet() {
parity_random3d(300, -50.0, 50.0, 8.0, 2026L)
}
private void parity_random3d(int n, double min, double max, double minDist, long seed) {
Random rng = new Random(seed)
List<double[]> points = generateRandom3dPoints(n, min, max, rng)
double sqrDist = minDist * minDist
def v1 = new SLinkClusterer<double[]>().cluster(points, sqrDist, SQR_EUCLID_3D)
def v2 = new SLinkClustererV2<double[]>().cluster(points, sqrDist, SQR_EUCLID_3D)
assertEquals(v1.size(), v2.size(),
"Cluster count mismatch: n=$n minDist=$minDist seed=$seed")
assertEquals(indexSets(points, v1), indexSets(points, v2),
"Cluster memberships differ: n=$n minDist=$minDist seed=$seed")
}
private static List<double[]> generateRandom3dPoints(int n, double min, double max, Random rng) {
double range = max - min
List<double[]> points = new ArrayList<>(n)
for (int i = 0; i < n; i++) {
points.add([min + rng.nextDouble() * range, min + rng.nextDouble() * range, min + rng.nextDouble() * range] as double[])
}
return points
}
// ==================== Parity on real protein atoms ====================
@Test
void parity_proteinAtoms_2W83() {
parity_proteinAtomClusters('distro/test_data/2W83.pdb', 3.0)
parity_proteinAtomClusters('distro/test_data/2W83.pdb', 5.0)
}
@Test
void parity_proteinAtoms_1fbl() {
parity_proteinAtomClusters('distro/test_data/1fbl.pdb.gz', 3.0)
parity_proteinAtomClusters('distro/test_data/1fbl.pdb.gz', 5.0)
}
/**
* Compare V1 and V2 on a subset of real protein atoms.
* Uses first N atoms to keep test fast (O(N²) is fine for small N).
*/
void parity_proteinAtomClusters(String pdbFile, double minDist) {
Protein p = Protein.load(pdbFile)
// Take a subset to keep test duration reasonable
int n = Math.min(200, p.proteinAtoms.count)
List<Atom> atoms = p.proteinAtoms.list.subList(0, n)
double sqrDist = minDist * minDist
Clusterer.Distance<Atom> sqrEuclid = AtomClusterer.SQR_EUCLID
def v1 = new SLinkClusterer<Atom>().cluster(atoms, sqrDist, sqrEuclid)
def v2 = new SLinkClustererV2<Atom>().cluster(atoms, sqrDist, sqrEuclid)
// Same number of clusters
assertEquals(v1.size(), v2.size(),
"Cluster count mismatch for $pdbFile minDist=$minDist")
// Same cluster memberships (as sets of PDB serials, order-independent)
Set<Set<Integer>> v1sets = v1.collect { List<Atom> c -> c.collect { it.PDBserial }.toSet() }.toSet()
Set<Set<Integer>> v2sets = v2.collect { List<Atom> c -> c.collect { it.PDBserial }.toSet() }.toSet()
assertEquals(v1sets, v2sets,
"Cluster memberships differ for $pdbFile minDist=$minDist")
}
}