mirror of
https://github.com/rdk/p2rank.git
synced 2026-06-04 12:44:24 +08:00
export-points command that works with custom feature setup
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -27,6 +27,7 @@ out/*
|
||||
|
||||
venv/
|
||||
tmp
|
||||
pastebin.txt
|
||||
|
||||
# P2Rank distro
|
||||
distro/bin/
|
||||
|
||||
@@ -110,6 +110,9 @@ quick() {
|
||||
# test export_points feature
|
||||
test ./prank.sh predict -f distro/test_data/1fbl.pdb -export_points 1 -export_points_format csv.gz -out_subdir TEST/TESTS
|
||||
|
||||
# test export-points command (no model)
|
||||
test ./prank.sh export-points -f distro/test_data/1fbl.pdb -export_points_format csv.gz -out_subdir TEST/TESTS
|
||||
|
||||
}
|
||||
|
||||
quick_train() {
|
||||
@@ -399,21 +402,33 @@ export_points() {
|
||||
|
||||
title EXPORT POINTS FEATURE
|
||||
|
||||
test ./prank.sh predict -f distro/test_data/1fbl.pdb -export_points 1 -export_points_format csv -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh predict -f distro/test_data/1fbl.pdb -export_points 1 -export_points_format csv.gz -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh predict -f distro/test_data/1fbl.pdb -export_points 1 -export_points_format csv.zst -out_subdir TEST/EXPORT_POINTS
|
||||
|
||||
test ./prank.sh predict -f distro/test_data/1fbl.pdb -export_points 1 -export_points_format csv -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh predict -f distro/test_data/1fbl.pdb -export_points 1 -export_points_format csv.gz -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh predict -f distro/test_data/1fbl.pdb -export_points 1 -export_points_format csv.zst -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh predict -f distro/test_data/1fbl.pdb -export_points 1 -export_points_format arrow -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh predict -f distro/test_data/1fbl.pdb -export_points 1 -export_points_format arrow.gz -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh predict -f distro/test_data/1fbl.pdb -export_points 1 -export_points_format arrow.zst -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh predict -f distro/test_data/1fbl.pdb -export_points 1 -export_points_format parquet -out_subdir TEST/EXPORT_POINTS
|
||||
|
||||
test ./prank.sh predict -f distro/test_data/1fbl.pdb -export_points 1 -export_points_format parquet -out_subdir TEST/EXPORT_POINTS
|
||||
# export-points command (no model needed)
|
||||
test ./prank.sh export-points -f distro/test_data/2W83.cif -extra_features 'surface_protrusion' -export_points_format arrow -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh export-points -f distro/test_data/1fbl.pdb -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh export-points -f distro/test_data/1fbl.pdb -export_points_format csv.gz -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh export-points -f distro/test_data/1fbl.pdb -export_points_format csv.zst -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh export-points -f distro/test_data/1fbl.pdb -export_points_format arrow -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh export-points -f distro/test_data/1fbl.pdb -export_points_format arrow.gz -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh export-points -f distro/test_data/1fbl.pdb -export_points_format arrow.zst -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh export-points -f distro/test_data/1fbl.pdb -export_points_format parquet -out_subdir TEST/EXPORT_POINTS
|
||||
|
||||
# predict/rescore on datasets
|
||||
test ./prank.sh predict chen11.ds -c config/test-default -export_points 1 -export_points_format arrow.zst -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh rescore coach420-fpocket.ds -c config/test-default -export_points 1 -export_points_format csv.zst -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh rescore chen11-fpocket.ds -c config/test-default -export_points 1 -export_points_format parquet -out_subdir TEST/EXPORT_POINTS
|
||||
|
||||
# export-points command on datasets
|
||||
test ./prank.sh export-points chen11.ds -c config/test-default -export_points_format arrow.zst -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh export-points coach420-fpocket.ds -c config/test-default -export_points_format csv.zst -out_subdir TEST/EXPORT_POINTS
|
||||
test ./prank.sh export-points chen11-fpocket.ds -c config/test-default -export_points_format parquet -out_subdir TEST/EXPORT_POINTS
|
||||
}
|
||||
|
||||
classifiers() {
|
||||
|
||||
@@ -11,6 +11,7 @@ import cz.siret.prank.program.routines.analyze.AnalyzeRoutine
|
||||
import cz.siret.prank.program.routines.analyze.PrintRoutine
|
||||
import cz.siret.prank.program.routines.analyze.TransformRoutine
|
||||
import cz.siret.prank.program.routines.benchmark.Benchmarks
|
||||
import cz.siret.prank.program.routines.predict.ExportPointsRoutine
|
||||
import cz.siret.prank.program.routines.predict.PredictPocketsRoutine
|
||||
import cz.siret.prank.program.routines.predict.PredictResiduesRoutine
|
||||
import cz.siret.prank.program.routines.predict.RescorePocketsRoutine
|
||||
@@ -298,6 +299,15 @@ class Main implements Parametrized, Writable {
|
||||
doRunPredict("predict", false)
|
||||
}
|
||||
|
||||
void runExportPoints() {
|
||||
Dataset dataset = loadDatasetOrFile()
|
||||
String outdir = findOutdir("export_points_$dataset.label")
|
||||
configureLoggers(outdir)
|
||||
|
||||
Dataset.Result result = new ExportPointsRoutine(dataset, outdir).execute()
|
||||
finalizeDatasetResult(result)
|
||||
}
|
||||
|
||||
void runEvalPredict() {
|
||||
doRunPredict("eval_predict", true)
|
||||
}
|
||||
@@ -424,6 +434,8 @@ class Main implements Parametrized, Writable {
|
||||
switch (command) {
|
||||
case 'predict': runPredict()
|
||||
break
|
||||
case 'export-points': runExportPoints()
|
||||
break
|
||||
case 'eval-predict': runEvalPredict()
|
||||
break
|
||||
case 'rescore': runRescore()
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
package cz.siret.prank.program.routines.predict
|
||||
|
||||
import cz.siret.prank.domain.Dataset
|
||||
import cz.siret.prank.domain.PredictionPair
|
||||
import cz.siret.prank.domain.labeling.LabeledPoint
|
||||
import cz.siret.prank.domain.loaders.LoaderParams
|
||||
import cz.siret.prank.features.FeatureExtractor
|
||||
import cz.siret.prank.features.FeatureVector
|
||||
import cz.siret.prank.features.PrankFeatureExtractor
|
||||
import cz.siret.prank.program.routines.Routine
|
||||
import cz.siret.prank.program.routines.predict.output.PointExportData
|
||||
import cz.siret.prank.program.routines.predict.output.PointsExporter
|
||||
import groovy.transform.CompileStatic
|
||||
import groovy.util.logging.Slf4j
|
||||
import org.biojava.nbio.structure.Atom
|
||||
|
||||
import static cz.siret.prank.utils.ATimer.startTimer
|
||||
import static cz.siret.prank.utils.Futils.mkdirs
|
||||
|
||||
/**
|
||||
* Routine for exporting SAS points with feature vectors — no model, no prediction.
|
||||
*
|
||||
* Generates SAS surface for each protein, calculates configured features
|
||||
* (including extra_features), and exports point coordinates + feature values.
|
||||
* No model is loaded, no predictions are made, no score column in output.
|
||||
*
|
||||
* Backs prank command 'export-points'.
|
||||
*/
|
||||
@Slf4j
|
||||
@CompileStatic
|
||||
class ExportPointsRoutine extends Routine {
|
||||
|
||||
Dataset dataset
|
||||
|
||||
ExportPointsRoutine(Dataset dataset, String outdir) {
|
||||
super(outdir)
|
||||
this.dataset = dataset
|
||||
}
|
||||
|
||||
Dataset.Result execute() {
|
||||
def timer = startTimer()
|
||||
|
||||
write "exporting SAS points with features for proteins from dataset [$dataset.name]"
|
||||
|
||||
mkdirs(outdir)
|
||||
writeParams(outdir)
|
||||
log.info "outdir: $outdir"
|
||||
|
||||
FeatureExtractor extractorFactory = FeatureExtractor.createFactory()
|
||||
|
||||
LoaderParams.ignoreLigandsSwitch = true
|
||||
|
||||
String format = params.export_points_format
|
||||
|
||||
Dataset.Result result = dataset.processItems { Dataset.Item item ->
|
||||
|
||||
PredictionPair pair = item.predictionPair
|
||||
|
||||
PointExportData exportData = calculateExportData(extractorFactory, pair, item)
|
||||
|
||||
PointsExporter.exportPoints(exportData, outdir, item.label, format)
|
||||
|
||||
if (!dataset.cached) {
|
||||
item.cachedPair = null
|
||||
}
|
||||
}
|
||||
|
||||
write "exporting points finished in $timer.formatted"
|
||||
write "results saved to directory [$outdir]"
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate feature vectors for all SAS points on the protein surface.
|
||||
*/
|
||||
private static PointExportData calculateExportData(FeatureExtractor extractorFactory,
|
||||
PredictionPair pair,
|
||||
Dataset.Item item) {
|
||||
FeatureExtractor proteinExtractor = extractorFactory.createPrototypeForProtein(pair.protein, item.context)
|
||||
|
||||
try {
|
||||
PrankFeatureExtractor extractor = (PrankFeatureExtractor) proteinExtractor
|
||||
extractor = (PrankFeatureExtractor) extractor.createInstanceForWholeProtein()
|
||||
|
||||
int nPoints = extractor.sampledPoints.points.count
|
||||
List<LabeledPoint> labeledPoints = new ArrayList<>(nPoints)
|
||||
List<FeatureVector> vectors = new ArrayList<>(nPoints)
|
||||
|
||||
for (Atom point : extractor.sampledPoints.points) {
|
||||
labeledPoints.add(new LabeledPoint(point))
|
||||
vectors.add(extractor.calcFeatureVector(point))
|
||||
}
|
||||
|
||||
return PointExportData.createWithoutScores(labeledPoints, vectors, extractor.vectorHeader)
|
||||
} finally {
|
||||
proteinExtractor.finalizeProteinPrototype()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import cz.siret.prank.features.FeatureVector
|
||||
import groovy.transform.CompileStatic
|
||||
|
||||
/**
|
||||
* Encapsulates data needed for exporting SAS points with their feature vectors and scores.
|
||||
* Encapsulates data needed for exporting SAS points with their feature vectors and optionally scores.
|
||||
* Implements TableData for generic export via TableExporter.
|
||||
*/
|
||||
@CompileStatic
|
||||
@@ -15,12 +15,19 @@ class PointExportData implements TableData {
|
||||
final List<FeatureVector> featureVectors
|
||||
final List<String> featureHeader
|
||||
|
||||
/** Cached full header: [x, y, z, score, ...featureHeader] */
|
||||
/** Whether to include score column in export (false for export-points command) */
|
||||
final boolean includeScore
|
||||
|
||||
/** Number of fixed columns before features (3 without score, 4 with score) */
|
||||
private final int fixedColumns
|
||||
|
||||
/** Cached full header */
|
||||
private List<String> cachedHeader
|
||||
|
||||
private PointExportData(List<LabeledPoint> labeledPoints,
|
||||
List<FeatureVector> featureVectors,
|
||||
List<String> featureHeader) {
|
||||
List<String> featureHeader,
|
||||
boolean includeScore) {
|
||||
if (labeledPoints.size() != featureVectors.size()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Size mismatch: ${labeledPoints.size()} points but ${featureVectors.size()} feature vectors")
|
||||
@@ -28,6 +35,8 @@ class PointExportData implements TableData {
|
||||
this.labeledPoints = labeledPoints
|
||||
this.featureVectors = featureVectors
|
||||
this.featureHeader = featureHeader
|
||||
this.includeScore = includeScore
|
||||
this.fixedColumns = includeScore ? 4 : 3
|
||||
}
|
||||
|
||||
// --- TableData Implementation ---
|
||||
@@ -35,7 +44,8 @@ class PointExportData implements TableData {
|
||||
@Override
|
||||
List<String> getHeader() {
|
||||
if (cachedHeader == null) {
|
||||
cachedHeader = ["x", "y", "z", "score"] + featureHeader
|
||||
List<String> prefix = includeScore ? ["x", "y", "z", "score"] : ["x", "y", "z"]
|
||||
cachedHeader = prefix + featureHeader
|
||||
}
|
||||
return cachedHeader
|
||||
}
|
||||
@@ -51,12 +61,14 @@ class PointExportData implements TableData {
|
||||
double[] coords = lp.getCoords()
|
||||
double[] features = featureVectors.get(index).getArray()
|
||||
|
||||
double[] row = new double[4 + features.length]
|
||||
double[] row = new double[fixedColumns + features.length]
|
||||
row[0] = coords[0]
|
||||
row[1] = coords[1]
|
||||
row[2] = coords[2]
|
||||
row[3] = lp.score
|
||||
System.arraycopy(features, 0, row, 4, features.length)
|
||||
if (includeScore) {
|
||||
row[3] = lp.score
|
||||
}
|
||||
System.arraycopy(features, 0, row, fixedColumns, features.length)
|
||||
return row
|
||||
}
|
||||
|
||||
@@ -74,14 +86,14 @@ class PointExportData implements TableData {
|
||||
for (int i = 0; i < n; i++) {
|
||||
column[i] = labeledPoints.get(i).getCoords()[colIndex]
|
||||
}
|
||||
} else if (colIndex == 3) {
|
||||
// Score column
|
||||
} else if (includeScore && colIndex == 3) {
|
||||
// Score column (only when included)
|
||||
for (int i = 0; i < n; i++) {
|
||||
column[i] = labeledPoints.get(i).score
|
||||
}
|
||||
} else {
|
||||
// Feature columns
|
||||
int featureIndex = colIndex - 4
|
||||
int featureIndex = colIndex - fixedColumns
|
||||
for (int i = 0; i < n; i++) {
|
||||
column[i] = featureVectors.get(i).getArray()[featureIndex]
|
||||
}
|
||||
@@ -100,13 +112,21 @@ class PointExportData implements TableData {
|
||||
// --- Factory Methods ---
|
||||
|
||||
/**
|
||||
* Creates export data from pre-collected lists.
|
||||
* Used when vectors are computed in batch (predict mode).
|
||||
* Creates export data with score column (for predict/rescore).
|
||||
*/
|
||||
static PointExportData create(List<LabeledPoint> labeledPoints,
|
||||
List<FeatureVector> featureVectors,
|
||||
List<String> featureHeader) {
|
||||
return new PointExportData(labeledPoints, featureVectors, featureHeader)
|
||||
return new PointExportData(labeledPoints, featureVectors, featureHeader, true)
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates export data without score column (for export-points command).
|
||||
*/
|
||||
static PointExportData createWithoutScores(List<LabeledPoint> labeledPoints,
|
||||
List<FeatureVector> featureVectors,
|
||||
List<String> featureHeader) {
|
||||
return new PointExportData(labeledPoints, featureVectors, featureHeader, false)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -133,7 +153,7 @@ class PointExportData implements TableData {
|
||||
}
|
||||
|
||||
PointExportData build() {
|
||||
return new PointExportData(labeledPoints, featureVectors, featureHeader)
|
||||
return new PointExportData(labeledPoints, featureVectors, featureHeader, true)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import groovy.transform.CompileStatic
|
||||
import groovy.util.logging.Slf4j
|
||||
|
||||
/**
|
||||
* Exports SAS points with their feature vectors and predicted scores.
|
||||
* Exports SAS points with their feature vectors and optionally predicted scores.
|
||||
* Delegates to TableExporter for format-specific logic.
|
||||
*
|
||||
* Supported formats: csv, csv.gz, csv.zst, arrow, arrow.gz, arrow.zst, parquet
|
||||
@@ -31,10 +31,17 @@ class PointsExporter {
|
||||
}
|
||||
|
||||
/**
|
||||
* Export points to file.
|
||||
* Export points to file. Uses format from Params.
|
||||
*/
|
||||
static void exportPoints(PointExportData data, String outdir, String label) {
|
||||
String format = Params.inst.export_points_format
|
||||
exportPoints(data, outdir, label, Params.inst.export_points_format)
|
||||
}
|
||||
|
||||
/**
|
||||
* Export points to file with explicit format.
|
||||
* Used by export-points command (always exports, format passed directly).
|
||||
*/
|
||||
static void exportPoints(PointExportData data, String outdir, String label, String format) {
|
||||
String filepath = "${outdir}/${label}_points.${format}"
|
||||
|
||||
long start = System.currentTimeMillis()
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
|
||||
commands:
|
||||
|
||||
predict ... predict pockets (P2RANK)
|
||||
eval-predict ... evaluate model on a dataset with known ligands
|
||||
rescore ... rescore previously detected pockets (PRANK)
|
||||
eval-rescore ... evaluate rescoring model on a dataset with known ligands
|
||||
predict ... predict pockets (P2RANK)
|
||||
eval-predict ... evaluate model on a dataset with known ligands
|
||||
rescore ... rescore previously detected pockets (PRANK)
|
||||
eval-rescore ... evaluate rescoring model on a dataset with known ligands
|
||||
export-points ... export SAS points with feature vectors (no model needed)
|
||||
|
||||
datasets:
|
||||
|
||||
@@ -38,4 +39,4 @@
|
||||
-visualizations <0/1> produce PyMOL visualizations
|
||||
default: true
|
||||
|
||||
-<param> <value> for full list of parameters see config/default.groovy
|
||||
-<param> <value> for full list of parameters see config/default.groovy
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
package cz.siret.prank.program.routines.predict
|
||||
|
||||
import cz.siret.prank.domain.Dataset
|
||||
import cz.siret.prank.domain.loaders.LoaderParams
|
||||
import cz.siret.prank.program.params.Params
|
||||
import cz.siret.prank.utils.Futils
|
||||
import groovy.transform.CompileStatic
|
||||
import org.junit.jupiter.api.*
|
||||
import org.junit.jupiter.api.parallel.Isolated
|
||||
import org.junit.jupiter.api.parallel.ResourceLock
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*
|
||||
|
||||
/**
|
||||
* Integration tests for export-points command.
|
||||
* Uses existing test data — no additional downloads required.
|
||||
*/
|
||||
@Isolated
|
||||
@ResourceLock("Params")
|
||||
@CompileStatic
|
||||
class ExportPointsRoutineTest {
|
||||
|
||||
static final String TEST_DATA = "distro/test_data"
|
||||
static final String PDB_1FBL = "$TEST_DATA/1fbl.pdb"
|
||||
static final String OUT_DIR = "$TEST_DATA/../test_output/export_points_test"
|
||||
|
||||
static Params originalParams
|
||||
static boolean origIgnoreLigandsSwitch
|
||||
|
||||
@BeforeAll
|
||||
static void setup() {
|
||||
originalParams = (Params) Params.inst.clone()
|
||||
origIgnoreLigandsSwitch = LoaderParams.ignoreLigandsSwitch
|
||||
Params.INSTANCE = new Params()
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
static void tearDown() {
|
||||
Params.INSTANCE = originalParams
|
||||
LoaderParams.ignoreLigandsSwitch = origIgnoreLigandsSwitch
|
||||
try { Futils.delete(OUT_DIR) } catch (Exception ignored) {}
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void resetParams() {
|
||||
Params.INSTANCE = new Params()
|
||||
}
|
||||
|
||||
@Test
|
||||
void exportsCsvWithDefaultFeatures() {
|
||||
Params.inst.export_points_format = "csv"
|
||||
|
||||
Dataset dataset = Dataset.createSingleFileDataset(PDB_1FBL)
|
||||
String outdir = "$OUT_DIR/csv_default"
|
||||
|
||||
Dataset.Result result = new ExportPointsRoutine(dataset, outdir).execute()
|
||||
|
||||
assertFalse(result.hasErrors(), "Should not have errors")
|
||||
|
||||
String csvFile = "$outdir/1fbl.pdb_points.csv"
|
||||
assertTrue(Futils.exists(csvFile), "CSV file should exist")
|
||||
assertTrue(Futils.size(csvFile) > 0, "CSV file should not be empty")
|
||||
|
||||
// Verify header has no score column
|
||||
String firstLine = new File(csvFile).readLines().first()
|
||||
assertTrue(firstLine.startsWith("x,y,z,"), "Should start with x,y,z")
|
||||
assertFalse(firstLine.contains("score"), "Should not contain score column")
|
||||
|
||||
// Verify has feature columns
|
||||
assertTrue(firstLine.contains("chem."), "Should contain chem features")
|
||||
|
||||
// Verify has data rows (default tessellation=2 gives ~5000 points for ~200 residue protein)
|
||||
int lineCount = new File(csvFile).readLines().size()
|
||||
assertTrue(lineCount > 1000, "Should have >1000 data rows (SAS points), got $lineCount")
|
||||
}
|
||||
|
||||
@Test
|
||||
void exportsParquet() {
|
||||
Params.inst.export_points_format = "parquet"
|
||||
|
||||
Dataset dataset = Dataset.createSingleFileDataset(PDB_1FBL)
|
||||
String outdir = "$OUT_DIR/parquet"
|
||||
|
||||
Dataset.Result result = new ExportPointsRoutine(dataset, outdir).execute()
|
||||
|
||||
assertFalse(result.hasErrors())
|
||||
|
||||
String pqFile = "$outdir/1fbl.pdb_points.parquet"
|
||||
assertTrue(Futils.exists(pqFile), "Parquet file should exist")
|
||||
assertTrue(Futils.size(pqFile) > 0, "Parquet file should not be empty")
|
||||
}
|
||||
|
||||
@Test
|
||||
void writesParamsFile() {
|
||||
Params.inst.export_points_format = "csv"
|
||||
|
||||
Dataset dataset = Dataset.createSingleFileDataset(PDB_1FBL)
|
||||
String outdir = "$OUT_DIR/params_check"
|
||||
|
||||
new ExportPointsRoutine(dataset, outdir).execute()
|
||||
|
||||
assertTrue(Futils.exists("$outdir/params.txt"), "params.txt should exist")
|
||||
}
|
||||
|
||||
@Test
|
||||
void featureCountMatchesHeader() {
|
||||
Params.inst.export_points_format = "csv"
|
||||
|
||||
Dataset dataset = Dataset.createSingleFileDataset(PDB_1FBL)
|
||||
String outdir = "$OUT_DIR/feature_count"
|
||||
|
||||
new ExportPointsRoutine(dataset, outdir).execute()
|
||||
|
||||
String csvFile = "$outdir/1fbl.pdb_points.csv"
|
||||
List<String> lines = new File(csvFile).readLines()
|
||||
String header = lines.first()
|
||||
String dataLine = lines.get(1)
|
||||
|
||||
int headerCols = header.split(",").length
|
||||
int dataCols = dataLine.split(",").length
|
||||
assertEquals(headerCols, dataCols,
|
||||
"Header columns ($headerCols) should match data columns ($dataCols)")
|
||||
}
|
||||
|
||||
@Test
|
||||
void noScoreColumnInOutput() {
|
||||
Params.inst.export_points_format = "csv"
|
||||
|
||||
Dataset dataset = Dataset.createSingleFileDataset(PDB_1FBL)
|
||||
String outdir = "$OUT_DIR/no_score"
|
||||
|
||||
new ExportPointsRoutine(dataset, outdir).execute()
|
||||
|
||||
String csvFile = "$outdir/1fbl.pdb_points.csv"
|
||||
String header = new File(csvFile).readLines().first()
|
||||
List<String> columns = header.split(",").toList()
|
||||
|
||||
// First 3 columns are coordinates
|
||||
assertEquals("x", columns[0])
|
||||
assertEquals("y", columns[1])
|
||||
assertEquals("z", columns[2])
|
||||
|
||||
// 4th column should be a feature, not "score"
|
||||
assertNotEquals("score", columns[3], "4th column should be a feature, not score")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
package cz.siret.prank.program.routines.predict.output
|
||||
|
||||
import cz.siret.prank.collectors.DoubleVector
|
||||
import cz.siret.prank.domain.labeling.LabeledPoint
|
||||
import cz.siret.prank.features.FeatureVector
|
||||
import groovy.transform.CompileStatic
|
||||
import org.biojava.nbio.structure.AtomImpl
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*
|
||||
|
||||
@CompileStatic
|
||||
class PointExportDataTest {
|
||||
|
||||
// --- With scores (existing behavior) ---
|
||||
|
||||
@Test
|
||||
void withScores_headerIncludesScore() {
|
||||
def data = PointExportData.create(
|
||||
[point(1, 2, 3, 0.8)],
|
||||
[vector(0.1, 0.2)],
|
||||
["feat1", "feat2"]
|
||||
)
|
||||
|
||||
assertEquals(["x", "y", "z", "score", "feat1", "feat2"], data.header)
|
||||
}
|
||||
|
||||
@Test
|
||||
void withScores_rowIncludesScore() {
|
||||
def data = PointExportData.create(
|
||||
[point(1, 2, 3, 0.8)],
|
||||
[vector(0.1, 0.2)],
|
||||
["feat1", "feat2"]
|
||||
)
|
||||
|
||||
double[] row = data.getRow(0)
|
||||
assertEquals(6, row.length)
|
||||
assertEquals(1.0d, row[0], 1e-9) // x
|
||||
assertEquals(2.0d, row[1], 1e-9) // y
|
||||
assertEquals(3.0d, row[2], 1e-9) // z
|
||||
assertEquals(0.8d, row[3], 1e-9) // score
|
||||
assertEquals(0.1d, row[4], 1e-9) // feat1
|
||||
assertEquals(0.2d, row[5], 1e-9) // feat2
|
||||
}
|
||||
|
||||
@Test
|
||||
void withScores_columnAccess() {
|
||||
def data = PointExportData.create(
|
||||
[point(1, 2, 3, 0.8), point(4, 5, 6, 0.9)],
|
||||
[vector(0.1, 0.2), vector(0.3, 0.4)],
|
||||
["feat1", "feat2"]
|
||||
)
|
||||
|
||||
// Score column
|
||||
double[] scoreCol = data.getColumn(3)
|
||||
assertArrayEquals([0.8d, 0.9d] as double[], scoreCol, 1e-9)
|
||||
|
||||
// First feature column (index 4)
|
||||
double[] feat1Col = data.getColumn(4)
|
||||
assertArrayEquals([0.1d, 0.3d] as double[], feat1Col, 1e-9)
|
||||
}
|
||||
|
||||
// --- Without scores (export-points behavior) ---
|
||||
|
||||
@Test
|
||||
void withoutScores_headerExcludesScore() {
|
||||
def data = PointExportData.createWithoutScores(
|
||||
[point(1, 2, 3, 0.8)],
|
||||
[vector(0.1, 0.2)],
|
||||
["feat1", "feat2"]
|
||||
)
|
||||
|
||||
assertEquals(["x", "y", "z", "feat1", "feat2"], data.header)
|
||||
}
|
||||
|
||||
@Test
|
||||
void withoutScores_rowExcludesScore() {
|
||||
def data = PointExportData.createWithoutScores(
|
||||
[point(1, 2, 3, 0.8)],
|
||||
[vector(0.1, 0.2)],
|
||||
["feat1", "feat2"]
|
||||
)
|
||||
|
||||
double[] row = data.getRow(0)
|
||||
assertEquals(5, row.length)
|
||||
assertEquals(1.0d, row[0], 1e-9) // x
|
||||
assertEquals(2.0d, row[1], 1e-9) // y
|
||||
assertEquals(3.0d, row[2], 1e-9) // z
|
||||
assertEquals(0.1d, row[3], 1e-9) // feat1 (no score gap)
|
||||
assertEquals(0.2d, row[4], 1e-9) // feat2
|
||||
}
|
||||
|
||||
@Test
|
||||
void withoutScores_columnAccess() {
|
||||
def data = PointExportData.createWithoutScores(
|
||||
[point(1, 2, 3, 0.8), point(4, 5, 6, 0.9)],
|
||||
[vector(0.1, 0.2), vector(0.3, 0.4)],
|
||||
["feat1", "feat2"]
|
||||
)
|
||||
|
||||
// Column 3 is now feat1 (not score)
|
||||
double[] feat1Col = data.getColumn(3)
|
||||
assertArrayEquals([0.1d, 0.3d] as double[], feat1Col, 1e-9)
|
||||
|
||||
// Column 4 is feat2
|
||||
double[] feat2Col = data.getColumn(4)
|
||||
assertArrayEquals([0.2d, 0.4d] as double[], feat2Col, 1e-9)
|
||||
}
|
||||
|
||||
@Test
|
||||
void withoutScores_coordinateColumns() {
|
||||
def data = PointExportData.createWithoutScores(
|
||||
[point(1, 2, 3, 0), point(4, 5, 6, 0)],
|
||||
[vector(0.1), vector(0.2)],
|
||||
["feat"]
|
||||
)
|
||||
|
||||
assertArrayEquals([1.0d, 4.0d] as double[], data.getColumn(0), 1e-9) // x
|
||||
assertArrayEquals([2.0d, 5.0d] as double[], data.getColumn(1), 1e-9) // y
|
||||
assertArrayEquals([3.0d, 6.0d] as double[], data.getColumn(2), 1e-9) // z
|
||||
}
|
||||
|
||||
@Test
|
||||
void withoutScores_rowCount() {
|
||||
def data = PointExportData.createWithoutScores(
|
||||
[point(1, 2, 3, 0), point(4, 5, 6, 0)],
|
||||
[vector(0.1), vector(0.2)],
|
||||
["feat"]
|
||||
)
|
||||
|
||||
assertEquals(2, data.rowCount)
|
||||
}
|
||||
|
||||
@Test
|
||||
void withoutScores_includeScoreIsFalse() {
|
||||
def data = PointExportData.createWithoutScores(
|
||||
[point(1, 2, 3, 0)],
|
||||
[vector(0.1)],
|
||||
["feat"]
|
||||
)
|
||||
assertFalse(data.includeScore)
|
||||
}
|
||||
|
||||
@Test
|
||||
void withScores_includeScoreIsTrue() {
|
||||
def data = PointExportData.create(
|
||||
[point(1, 2, 3, 0)],
|
||||
[vector(0.1)],
|
||||
["feat"]
|
||||
)
|
||||
assertTrue(data.includeScore)
|
||||
}
|
||||
|
||||
@Test
|
||||
void builderProducesDataWithScores() {
|
||||
def builder = PointExportData.builder(["feat1"])
|
||||
builder.add(point(1, 2, 3, 0.5), vector(0.1))
|
||||
def data = builder.build()
|
||||
|
||||
assertTrue(data.includeScore)
|
||||
assertEquals(["x", "y", "z", "score", "feat1"], data.header)
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
private static LabeledPoint point(double x, double y, double z, double score) {
|
||||
def atom = new AtomImpl()
|
||||
atom.coords = [x, y, z] as double[]
|
||||
def lp = new LabeledPoint(atom)
|
||||
lp.score = score
|
||||
return lp
|
||||
}
|
||||
|
||||
private static FeatureVector vector(double... values) {
|
||||
new DoubleVector(values)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user