Tests for pocket-grid-point descriptors + extract DescriptorListValidator

Adds focused regression tests for the new framework: 18 tests in four
new files plus 4 added to PocketGridRowsTest.

  PocketGridRowsTest +4
    - descriptor schema uses "{name}.{col}" prefix for multi-col
    - getRow appends descriptor values after the base 4 columns
    - unknown descriptor name throws at construction
    - scalar descriptor emits bare name() with no prefix (uses an
      inline ScalarTestDescriptor registered via the now-public
      registry hook — none of the shipped descriptors are scalar so
      the branch was untested)
  VolsiteGridPointDescriptorTest (new, 4 tests)
    - covers indicator aggregation + radius cutoff
  VolsiteSmoothGridPointDescriptorTest (new, 4 tests)
    - covers Gaussian kernel arithmetic + 4σ cutoff
  PocketGridPointDescriptorRegistryTest (new, 2 tests)
    - shipped names resolve, unknown name throws helpful error
  DescriptorListValidatorTest (new, 8 tests)
    - null/empty/valid/unknown/duplicate/null-entry/blank/dash-prefix

Refactors Main.validateDescriptorList out to a self-contained Java
utility (DescriptorListValidator) under predict/output/. The two call
sites in Main.validatePocketGridParams now invoke the static helper;
the private helper in Main is removed (-37 lines).

PocketGridPointDescriptorRegistry.register is promoted from private to
public so tests (and future external descriptor plugins) can add
descriptors without touching the registry's static initializer. The
shipped registrations still happen at class-load.
This commit is contained in:
rdk
2026-05-19 10:29:02 +02:00
parent 1931ef1f93
commit 6888716aa0
8 changed files with 456 additions and 34 deletions

View File

@@ -235,13 +235,8 @@ class Main implements Parametrized, Writable {
"Known: ${knownAssigners}")
}
// Every name in -pocket_descriptors must be registered. null/blank entries are
// rejected (rather than skipped) so a malformed config file fails fast instead
// of slipping through to the consumers (which would otherwise hit
// PocketDescriptorRegistry.get('') with a less useful error). Duplicates are
// rejected too — accepting them would produce a CSV with duplicate header cells
// and break Parquet's schema builder.
validateDescriptorList(params.pocket_descriptors,
cz.siret.prank.program.routines.predict.output.DescriptorListValidator.validate(
params.pocket_descriptors,
cz.siret.prank.program.routines.predict.output.descriptors.PocketDescriptorRegistry.knownNames(),
"pocket_descriptors")
@@ -285,7 +280,8 @@ class Main implements Parametrized, Writable {
"-vis_pocket_grid_gaussian_iso must be > 0 (got ${params.vis_pocket_grid_gaussian_iso}).")
}
validateDescriptorList(params.pocket_grid_point_descriptors,
cz.siret.prank.program.routines.predict.output.DescriptorListValidator.validate(
params.pocket_grid_point_descriptors,
cz.siret.prank.program.routines.predict.output.grid.descriptors.PocketGridPointDescriptorRegistry.knownNames(),
"pocket_grid_point_descriptors")
if (params.pocket_grid_volsite_radius <= 0d) {
@@ -305,31 +301,6 @@ class Main implements Parametrized, Writable {
}
}
/**
* Shared shape for validating a name-list param against a registry: rejects
* null/blank entries, unknown names, and duplicates. {@code paramName} is the
* Params property name (without the {@code -} prefix) used in error messages.
*/
private static void validateDescriptorList(List<String> names, Set<String> known, String paramName) {
if (names == null) return
Set<String> seen = new HashSet<>()
for (String name : names) {
if (name == null || name.trim().isEmpty()) {
throw new PrankException(
"-${paramName} contains an empty/null entry. Known: ${known}")
}
if (!known.contains(name)) {
throw new PrankException(
"Unknown name in -${paramName}: '${name}'. Known: ${known}")
}
if (!seen.add(name)) {
throw new PrankException(
"-${paramName} contains duplicate name '${name}'. " +
"Each descriptor may be listed at most once.")
}
}
}
String evalDirParam(String dirParam, String relativePrefixDir) {
if (dirParam == null) {
dirParam = "."

View File

@@ -0,0 +1,58 @@
package cz.siret.prank.program.routines.predict.output;
import cz.siret.prank.program.PrankException;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
 * Checks a user-supplied list of descriptor names (for example the value of
 * {@code -pocket_descriptors} or {@code -pocket_grid_point_descriptors})
 * against the set of names a descriptor registry reports as known.
 *
 * <p>An entry is refused when it is
 * <ul>
 *   <li>{@code null} or blank, so that a malformed config file fails here with a
 *       clear message instead of reaching a registry lookup with an empty name;</li>
 *   <li>absent from the supplied known-names set;</li>
 *   <li>a repeat of an earlier entry, since repeated names would yield repeated
 *       output columns.</li>
 * </ul>
 *
 * <p>A {@code null} list stands for "nothing selected" and passes without any
 * check, in line with the Params convention that a missing list means no extra
 * columns.
 */
public final class DescriptorListValidator {

    /** Static utility; never instantiated. */
    private DescriptorListValidator() {}

    /**
     * Walks {@code names} in order and fails on the first offending entry.
     * Per entry the checks run in a fixed sequence: null/blank, then unknown,
     * then duplicate.
     *
     * @param names     the names given on the CLI or in a config file; {@code null}
     *                  is accepted and treated like an empty list.
     * @param known     the names the registry knows (e.g. the result of its
     *                  {@code knownNames()}); also echoed in error messages.
     * @param paramName the Params property name without the leading {@code -};
     *                  the dash is added when the message is built.
     * @throws PrankException if an entry is null/blank, not in {@code known},
     *                        or listed more than once.
     */
    public static void validate(List<String> names, Set<String> known, String paramName)
            throws PrankException {
        if (names == null) {
            return;
        }

        final Set<String> accepted = new HashSet<>();
        for (String candidate : names) {
            final boolean missing = candidate == null || candidate.trim().isEmpty();
            if (missing) {
                throw new PrankException(
                        "-" + paramName + " contains an empty/null entry. Known: " + known);
            }

            final boolean registered = known.contains(candidate);
            if (!registered) {
                throw new PrankException(
                        "Unknown name in -" + paramName + ": '" + candidate + "'. Known: " + known);
            }

            // Set.add returns false when the element was already present.
            final boolean repeated = !accepted.add(candidate);
            if (repeated) {
                throw new PrankException(
                        "-" + paramName + " contains duplicate name '" + candidate + "'. " +
                        "Each descriptor may be listed at most once.");
            }
        }
    }
}

View File

@@ -29,7 +29,14 @@ public final class PocketGridPointDescriptorRegistry {
private PocketGridPointDescriptorRegistry() {}
private static void register(PocketGridPointDescriptor d) {
/**
* Add a descriptor to the registry. Called from the static initializer for
* the shipped descriptors; also exposed for tests that need to register a
* fixture descriptor and for future external descriptor plugins. The
* registry has no remove/clear — a registered descriptor lives for the JVM's
* lifetime, which is intentional (CLI selection by name must be deterministic).
*/
public static void register(PocketGridPointDescriptor d) {
List<String> cols = d.columnNames();
if (cols.size() > 1 && new HashSet<>(cols).size() != cols.size()) {
throw new IllegalStateException(

View File

@@ -0,0 +1,81 @@
package cz.siret.prank.program.routines.predict.output
import cz.siret.prank.program.PrankException
import groovy.transform.CompileStatic
import org.junit.jupiter.api.Test
import static org.junit.jupiter.api.Assertions.*
/**
 * Unit tests for {@link DescriptorListValidator#validate}: the accepted inputs
 * (null list, empty list, all-known names) and every rejection path (unknown,
 * duplicate, null and blank entries), plus the dash-prefixed parameter name in
 * error messages.
 */
@CompileStatic
class DescriptorListValidatorTest {

    private static final Set<String> KNOWN = ['volume', 'sphericity', 'num_residues'] as Set
    private static final String PARAM = 'pocket_descriptors'

    /**
     * Runs the validator against {@link #KNOWN}, requires it to throw a
     * {@link PrankException}, and hands that exception back for message checks.
     */
    private static PrankException rejectionFor(List<String> names, String paramName) {
        return assertThrows(PrankException.class) {
            DescriptorListValidator.validate(names, KNOWN, paramName)
        } as PrankException
    }

    @Test
    void nullListIsAcceptedAsEmpty() {
        // Passing here simply means "nothing was thrown": a Params field that was
        // never set must not force callers to special-case null.
        DescriptorListValidator.validate(null, KNOWN, PARAM)
    }

    @Test
    void emptyListIsAccepted() {
        DescriptorListValidator.validate([], KNOWN, PARAM)
    }

    @Test
    void validNamesAreAccepted() {
        DescriptorListValidator.validate(['volume', 'sphericity'], KNOWN, PARAM)
    }

    @Test
    void unknownNameThrowsAndNamesTheOffender() {
        // The message has to quote the misspelt name (so it can be found in the
        // config), carry the flag name, and show the known names for correction.
        PrankException e = rejectionFor(['volume', 'sphericty'], PARAM)

        assertTrue(e.message.contains("'sphericty'"), "missing typo: ${e.message}")
        assertTrue(e.message.contains('-pocket_descriptors'), "missing param: ${e.message}")
        assertTrue(e.message.contains('volume'), "missing known list: ${e.message}")
    }

    @Test
    void duplicateNameThrows() {
        PrankException e = rejectionFor(['volume', 'volume'], PARAM)

        assertTrue(e.message.contains("'volume'"), "missing dup name: ${e.message}")
        assertTrue(e.message.toLowerCase().contains('duplicate'), "missing 'duplicate': ${e.message}")
    }

    @Test
    void nullEntryThrows() {
        // A list holding a null element is not the same as a null list; this guards
        // against malformed Groovy config files (e.g. a stray comma in a list literal).
        PrankException e = rejectionFor([null] as List<String>, PARAM)

        assertTrue(e.message.toLowerCase().contains('empty/null'), e.message)
    }

    @Test
    void blankEntryThrows() {
        PrankException e = rejectionFor(['volume', ' '], PARAM)

        assertTrue(e.message.toLowerCase().contains('empty/null'), e.message)
    }

    @Test
    void paramNameIsRenderedWithDashPrefix() {
        // The validator, not the caller, is responsible for the CLI-style leading dash.
        PrankException e = rejectionFor(['xx'], 'some_param')

        assertTrue(e.message.contains('-some_param'),
                "expected leading dash on param name, got: ${e.message}")
    }
}

View File

@@ -1,11 +1,18 @@
package cz.siret.prank.program.routines.predict.output
import com.carrotsearch.hppc.LongIntHashMap
import cz.siret.prank.domain.Pocket
import cz.siret.prank.domain.Protein
import cz.siret.prank.geom.Atoms
import cz.siret.prank.geom.Point
import cz.siret.prank.program.PrankException
import cz.siret.prank.program.routines.predict.output.grid.PocketGrid
import cz.siret.prank.program.routines.predict.output.grid.descriptors.PocketGridPointContext
import cz.siret.prank.program.routines.predict.output.grid.descriptors.PocketGridPointDescriptor
import cz.siret.prank.program.routines.predict.output.grid.descriptors.PocketGridPointDescriptorRegistry
import groovy.transform.CompileStatic
import org.biojava.nbio.structure.Atom
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import static org.junit.jupiter.api.Assertions.*
@@ -80,4 +87,83 @@ class PocketGridRowsTest {
assertEquals(TableData.ColumnType.INT, data.getColumnType(3))
}
/**
 * Fixture: a {@link Protein} whose {@code proteinAtoms} is an empty {@link Atoms}.
 * Used by the descriptor tests below where the atom content is irrelevant.
 */
private static Protein emptyProtein() {
Protein p = new Protein()
p.proteinAtoms = new Atoms()
return p
}
/**
 * Header check for a multi-column descriptor: with {@code ['volsite']} selected,
 * the header must be the four base columns followed by six
 * {@code volsite.<column>} names, in the order asserted below.
 */
@Test
void descriptorColumnsPrefixedWithDescriptorName() {
// Multi-column descriptor (volsite, 6 cols) must produce 6 prefixed
// headers; the prefix rule is documented contract for the export.
// NOTE(review): the meaning of the boolean constructor argument (false) is
// defined by PocketGridRows, which is not visible here.
PocketGridRows data = new PocketGridRows(buildTwoPocketGrid(), false,
emptyProtein(), [] as List<Pocket>, ['volsite'])
assertEquals(['x', 'y', 'z', 'pocket',
'volsite.vsAromatic', 'volsite.vsCation', 'volsite.vsAnion',
'volsite.vsHydrophobic', 'volsite.vsAcceptor', 'volsite.vsDonor'],
data.header)
}
/**
 * Row-layout check: a row has 10 values — the 4 base columns first, then the
 * 6 descriptor columns.
 */
@Test
void getRowAppendsDescriptorValuesAfterBaseColumns() {
// Empty protein → cutoutSphere is empty → all 6 indicator columns are 0.
// The point of the test is the row LAYOUT (base 4 then 6 descriptor cols),
// not the descriptor's numeric semantics — that's covered in
// VolsiteGridPointDescriptorTest.
PocketGridRows data = new PocketGridRows(buildTwoPocketGrid(), false,
emptyProtein(), [] as List<Pocket>, ['volsite'])
double[] row = data.getRow(0)
assertEquals(10, row.length)
// base columns intact
// (expected values x=1, y=0, z=0, pocket=1 come from buildTwoPocketGrid(),
// which is defined outside this view)
assertEquals(1.0d, row[0], 0d); assertEquals(0d, row[1], 0d); assertEquals(0d, row[2], 0d)
assertEquals(1, (int) row[3])
// descriptor columns all zero (no atoms to classify)
for (int i = 4; i < row.length; i++) assertEquals(0d, row[i], 0d)
}
/**
 * Constructing {@code PocketGridRows} with a descriptor name that is not
 * registered must throw a {@link PrankException} whose message contains that name.
 */
@Test
void unknownDescriptorNameThrowsAtConstruction() {
// Built outside the assertThrows closure so that only the constructor call
// can be the source of the expected exception.
PocketGrid grid = buildTwoPocketGrid()
PrankException e = assertThrows(PrankException.class) {
new PocketGridRows(grid, false, emptyProtein(), [] as List<Pocket>, ['no_such_descriptor'])
} as PrankException
// The message must name the typo so the user can fix it.
assertTrue(e.message.contains('no_such_descriptor'),
"expected message to mention typo, got: ${e.message}")
}
/**
 * Fixture: a 1-column descriptor that exercises the scalar branch of the header rule.
 * It reports {@code TEST_SCALAR_NAME} as its name, declares a single DOUBLE column
 * whose own name ('ignored') is expected NOT to show up in the header, and returns
 * the constant 42.0 for every grid point regardless of the context passed in.
 */
@CompileStatic
private static final class ScalarTestDescriptor implements PocketGridPointDescriptor {
@Override String name() { return TEST_SCALAR_NAME }
@Override List<String> columnNames() { return ['ignored'] }
@Override List<TableData.ColumnType> columnTypes() { return [TableData.ColumnType.DOUBLE] }
@Override double[] compute(PocketGridPointContext ctx) { return [42.0d] as double[] }
}
// Registry name of the scalar fixture descriptor; wrapped in double underscores
// to keep it apart from user-facing CLI names.
private static final String TEST_SCALAR_NAME = '__test_scalar_descriptor__'
/**
 * Registers the scalar fixture once before any test in this class runs, using
 * the public {@code PocketGridPointDescriptorRegistry.register} hook.
 */
@BeforeAll
static void registerScalarFixture() {
// Intended to be idempotent: the original author states register() overwrites by
// name, so re-running tests in the same JVM is safe.
// NOTE(review): the overwrite behaviour is not visible in this change — confirm
// against the registry implementation.
// Name is namespaced with underscores so it can't collide with any
// user-facing CLI name.
// The registry offers no remove/clear, so this fixture stays registered (and
// visible to any other test in the same JVM) after this class finishes.
PocketGridPointDescriptorRegistry.register(new ScalarTestDescriptor())
}
/**
 * Header and row check for a single-column descriptor: the header gains exactly
 * one cell equal to the descriptor's name, and the row gains exactly one value.
 */
@Test
void scalarDescriptorEmitsBareNameWithNoPrefix() {
// The "{name}.{col}" prefix rule applies ONLY when a descriptor has more than
// one column. A single-column descriptor's header is exactly name() — sub-name
// is ignored. None of the shipped descriptors are scalar, so this branch
// exists for future descriptors and the registered fixture exercises it.
PocketGridRows data = new PocketGridRows(buildTwoPocketGrid(), false,
emptyProtein(), [] as List<Pocket>, [TEST_SCALAR_NAME])
assertEquals(['x', 'y', 'z', 'pocket', TEST_SCALAR_NAME], data.header)
// The value 42 from compute() must land in the trailing descriptor column.
double[] row = data.getRow(0)
// 4 base columns + 1 descriptor column
assertEquals(5, row.length)
assertEquals(42.0d, row[4], 0d)
}
}

View File

@@ -0,0 +1,32 @@
package cz.siret.prank.program.routines.predict.output.grid.descriptors
import cz.siret.prank.program.PrankException
import groovy.transform.CompileStatic
import org.junit.jupiter.api.Test
import static org.junit.jupiter.api.Assertions.*
/**
 * Tests for name lookup in {@link PocketGridPointDescriptorRegistry}: the shipped
 * descriptor names resolve, and an unregistered name fails with a helpful message.
 */
@CompileStatic
class PocketGridPointDescriptorRegistryTest {

    @Test
    void shippedDescriptorsAreRegistered() {
        // Dropping a shipped descriptor would break the -pocket_grid_point_descriptors
        // default and any user config that names it, so both shipped names are
        // pinned here.
        for (String shipped : ['volsite', 'volsite_smooth']) {
            assertNotNull(PocketGridPointDescriptorRegistry.get(shipped))
        }
    }

    @Test
    void unknownNameThrowsWithKnownList() {
        PrankException e = assertThrows(PrankException.class) {
            PocketGridPointDescriptorRegistry.get('does_not_exist')
        } as PrankException

        // Both pieces are required: the offending name, and the known names so the
        // user can spot the right spelling.
        String text = e.message
        assertTrue(text.contains('does_not_exist'), "missing typo in: ${e.message}")
        assertTrue(text.contains('volsite'), "missing 'volsite' in known list: ${e.message}")
    }
}

View File

@@ -0,0 +1,97 @@
package cz.siret.prank.program.routines.predict.output.grid.descriptors
import cz.siret.prank.domain.Protein
import cz.siret.prank.geom.Atoms
import cz.siret.prank.geom.Point
import groovy.transform.CompileStatic
import org.biojava.nbio.structure.Atom
import org.biojava.nbio.structure.AtomImpl
import org.biojava.nbio.structure.Element
import org.biojava.nbio.structure.Group
import org.biojava.nbio.structure.AminoAcidImpl
import org.junit.jupiter.api.Test
import static org.junit.jupiter.api.Assertions.*
/**
 * Behavioural tests for the indicator volsite descriptor. We don't test every
 * pharmacophore branch (that's {@code VolSitePharmacophore}'s domain) — we
 * verify that the descriptor correctly aggregates per-atom flags into the
 * 6-column row and honors the cutoff radius.
 *
 * <p>Column order (matches {@code VolSitePharmacophore.COLUMN_NAMES}):
 * aromatic, cation, anion, hydrophobic, acceptor, donor.
 *
 * <p>Every test evaluates a fresh {@code VolsiteGridPointDescriptor} at a grid
 * point placed at the origin and positions atoms relative to it.
 */
@CompileStatic
class VolsiteGridPointDescriptorTest {
// Indices into the array returned by compute(), in the column order given above.
private static final int AROMATIC = 0, CATION = 1, ANION = 2,
HYDROPHOBIC = 3, ACCEPTOR = 4, DONOR = 5
/**
 * Builds a BioJava {@link AtomImpl} with the given element symbol, atom name and
 * coordinates, attached to a fresh {@link AminoAcidImpl} group whose PDB name is
 * {@code resName}. Note that every atom, including the ZN used below, is wrapped
 * in an amino-acid group.
 */
private static Atom atomAt(String element, String resName, String atomName,
double x, double y, double z) {
AtomImpl a = new AtomImpl()
a.element = Element.valueOfIgnoreCase(element)
a.name = atomName
a.x = x; a.y = y; a.z = z
Group g = new AminoAcidImpl()
g.setPDBName(resName)
a.setGroup(g)
return a
}
/**
 * Wraps {@code proteinAtoms} in a new {@link Protein} and builds a context for a
 * grid point at (x, y, z). The other constructor arguments (0, 0, null, null) are
 * placeholders.
 * NOTE(review): their meaning is defined by PocketGridPointContext, which is not
 * visible here — confirm the descriptor under test does not read them.
 */
private static PocketGridPointContext ctxAt(double x, double y, double z, Atoms proteinAtoms) {
Protein p = new Protein()
p.proteinAtoms = proteinAtoms
return new PocketGridPointContext(0, new Point(x, y, z), 0, null, p, null)
}
/** One atom 1 Å away: only the hydrophobic column is expected to be 1. */
@Test
void singleHydrophobicAtomNearbySetsOnlyHydrophobic() {
// Atom name "C" + any residue → hydrophobic (first branch in VolSitePharmacophore).
Atom c = atomAt("C", "ALA", "C", 1d, 0d, 0d) // 1 Å from grid point — well inside default 4 Å
double[] out = new VolsiteGridPointDescriptor().compute(
ctxAt(0d, 0d, 0d, new Atoms([c])))
// No delta argument: the indicator is compared for exact equality with 1.
assertEquals(1d, out[HYDROPHOBIC])
for (int i = 0; i < 6; i++) if (i != HYDROPHOBIC) assertEquals(0d, out[i], 0d, "col $i")
}
/** One atom 5 Å away: all six columns are expected to stay 0. */
@Test
void atomOutsideRadiusContributesNothing() {
// Default -pocket_grid_volsite_radius is 4.0; place a hydrophobic atom at 5 Å.
// NOTE(review): the radius is not read from Params here, so this test assumes the
// default stays below 5 — confirm if the default ever changes.
Atom c = atomAt("C", "ALA", "C", 5d, 0d, 0d)
double[] out = new VolsiteGridPointDescriptor().compute(
ctxAt(0d, 0d, 0d, new Atoms([c])))
for (int i = 0; i < 6; i++) assertEquals(0d, out[i], 0d, "col $i")
}
/** Three atoms of different kinds, each 1 Å away: exactly three columns are expected to be 1. */
@Test
void mixedNearbyAtomsSetMultipleIndependentFlags() {
// One donor (N backbone) + one anion (OD1 in ASP) + one cation (ZN — name-only rule).
// All within 4 Å of origin. Three indicators should fire; the other three should not.
Atoms protein = new Atoms([
atomAt("N", "ALA", "N", 1d, 0d, 0d),
atomAt("O", "ASP", "OD1", 0d, 1d, 0d),
atomAt("ZN", "ZN", "ZN", 0d, 0d, 1d),
])
double[] out = new VolsiteGridPointDescriptor().compute(
ctxAt(0d, 0d, 0d, protein))
assertEquals(1d, out[DONOR])
assertEquals(1d, out[ANION])
assertEquals(1d, out[CATION])
assertEquals(0d, out[AROMATIC])
assertEquals(0d, out[HYDROPHOBIC])
assertEquals(0d, out[ACCEPTOR])
}
/** No atoms at all: all six columns are expected to be 0. */
@Test
void emptyNeighborhoodAllZeros() {
double[] out = new VolsiteGridPointDescriptor().compute(
ctxAt(0d, 0d, 0d, new Atoms()))
for (int i = 0; i < 6; i++) assertEquals(0d, out[i], 0d)
}
}

View File

@@ -0,0 +1,90 @@
package cz.siret.prank.program.routines.predict.output.grid.descriptors
import cz.siret.prank.domain.Protein
import cz.siret.prank.geom.Atoms
import cz.siret.prank.geom.Point
import cz.siret.prank.program.params.Params
import groovy.transform.CompileStatic
import org.biojava.nbio.structure.Atom
import org.biojava.nbio.structure.AtomImpl
import org.biojava.nbio.structure.Element
import org.biojava.nbio.structure.Group
import org.biojava.nbio.structure.AminoAcidImpl
import org.junit.jupiter.api.Test
import static org.junit.jupiter.api.Assertions.*
/**
 * Tests the Gaussian weighting math: kernel value at known distances, summing
 * across atoms, and the 4σ cutoff. These are the numeric facts that, if broken,
 * silently produce wrong scores — exactly what unit tests should catch.
 *
 * <p>Every test evaluates a fresh {@code VolsiteSmoothGridPointDescriptor} at a
 * grid point placed at the origin and inspects only the hydrophobic column. σ is
 * read from {@code Params.inst.pocket_grid_volsite_sigma} at test time, so the
 * expectations follow the configured value rather than a hard-coded one.
 */
@CompileStatic
class VolsiteSmoothGridPointDescriptorTest {
// Absolute tolerance used for all floating-point comparisons in this class.
private static final double DELTA = 1e-9
// Indices into the array returned by compute(); only HYDROPHOBIC is used below.
private static final int AROMATIC = 0, CATION = 1, ANION = 2,
HYDROPHOBIC = 3, ACCEPTOR = 4, DONOR = 5
/**
 * Builds a BioJava {@link AtomImpl} with the given atom name and coordinates,
 * attached to a fresh {@link AminoAcidImpl} group whose PDB name is
 * {@code resName}. The element is always carbon, whatever the atom name.
 */
private static Atom atomAt(String atomName, String resName, double x, double y, double z) {
AtomImpl a = new AtomImpl()
a.element = Element.C
a.name = atomName
a.x = x; a.y = y; a.z = z
Group g = new AminoAcidImpl()
g.setPDBName(resName)
a.setGroup(g)
return a
}
/**
 * Wraps {@code proteinAtoms} in a new {@link Protein} and builds a context for a
 * grid point at (x, y, z). The other constructor arguments (0, 0, null, null) are
 * placeholders.
 * NOTE(review): their meaning is defined by PocketGridPointContext, which is not
 * visible here — confirm the descriptor under test does not read them.
 */
private static PocketGridPointContext ctxAt(double x, double y, double z, Atoms proteinAtoms) {
Protein p = new Protein()
p.proteinAtoms = proteinAtoms
return new PocketGridPointContext(0, new Point(x, y, z), 0, null, p, null)
}
/** Atom exactly on the grid point: expected weight is 1. */
@Test
void weightAtZeroDistanceIsOne() {
// exp(0) = 1.0 exactly. Atom name "C" is hydrophobic.
Atom c = atomAt("C", "ALA", 0d, 0d, 0d)
double[] out = new VolsiteSmoothGridPointDescriptor().compute(
ctxAt(0d, 0d, 0d, new Atoms([c])))
assertEquals(1.0d, out[HYDROPHOBIC], DELTA)
}
/** Atom at distance σ along x: expected weight is exp(-1/2). */
@Test
void weightAtSigmaMatchesGaussianFormula() {
// At distance r = σ, weight = exp(-r²/(2σ²)) = exp(-1/2) ≈ 0.6065.
double sigma = Params.inst.pocket_grid_volsite_sigma
Atom c = atomAt("C", "ALA", sigma, 0d, 0d)
double[] out = new VolsiteSmoothGridPointDescriptor().compute(
ctxAt(0d, 0d, 0d, new Atoms([c])))
assertEquals(Math.exp(-0.5d), out[HYDROPHOBIC], DELTA)
}
/** Two atoms at distance σ (one along x, one along y): expected value is the sum of both weights. */
@Test
void weightsFromMultipleAtomsOfSameTypeSum() {
// Two hydrophobic atoms at distance σ each → sum = 2 × exp(-0.5).
double sigma = Params.inst.pocket_grid_volsite_sigma
Atoms protein = new Atoms([
atomAt("C", "ALA", sigma, 0d, 0d),
atomAt("C", "ALA", 0d, sigma, 0d),
])
double[] out = new VolsiteSmoothGridPointDescriptor().compute(
ctxAt(0d, 0d, 0d, protein))
assertEquals(2d * Math.exp(-0.5d), out[HYDROPHOBIC], DELTA)
}
/** Atom at distance 5σ: expected contribution is 0 (within DELTA). */
@Test
void atomBeyondCutoffContributesZero() {
// 4σ is the hard cutoff (cutoutSphere is the gate). At 5σ the atom isn't even
// in the kdtree result. Zero contribution.
// NOTE(review): the cutoff mechanism lives in the descriptor implementation, which
// is not visible here. An un-truncated Gaussian at 5σ would give exp(-12.5) ≈ 3.7e-6,
// which is above DELTA, so this assertion does depend on a cutoff being applied.
double sigma = Params.inst.pocket_grid_volsite_sigma
Atom c = atomAt("C", "ALA", 5d * sigma, 0d, 0d)
double[] out = new VolsiteSmoothGridPointDescriptor().compute(
ctxAt(0d, 0d, 0d, new Atoms([c])))
assertEquals(0d, out[HYDROPHOBIC], DELTA)
}
}