refactor model flattening to use FasterForestConverter API with configurable target types

Generalize Model classifier from Classifier to Object to support both
trainable classifiers and flat BinaryForest models. Add rf_flatten_target
parameter for selecting forest type (FlatBinaryForest, LegacyFlatBinaryForest,
InterleavedBfsForest, etc). Deprecate rf_flatten_as_legacy in favor of the
new target type selection.
This commit is contained in:
rdk
2026-02-16 01:00:55 +01:00
parent de75ac6be1
commit b8f802b145
9 changed files with 137 additions and 79 deletions

2
.gitignore vendored
View File

@@ -37,3 +37,5 @@ distro/test_output/
distro/README.md
pastebin.txt
CLAUDE.local.md

View File

@@ -157,18 +157,19 @@ predict() {
#
# title PREDICTIONS WITH FLATTENED FOREST
#
# test ./prank.sh predict joined.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict holo4k.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict coach420.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict ah4h.holoraw.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict joined.ds -c config/test-default -rf_flatten 1 -rf_flatten_target FlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict holo4k.ds -c config/test-default -rf_flatten 1 -rf_flatten_target LegacyFlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict coach420.ds -c config/test-default -rf_flatten 1 -rf_flatten_target ShortFlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict ah4h.holoraw.ds -c config/test-default -rf_flatten 1 -rf_flatten_target SuperShortLegacyFlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
#
# test ./prank.sh predict chen11.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict fptrain.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict 'joined(mlig).ds' -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict 'holo4k(mlig).ds' -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict chen11.ds -c config/test-default -rf_flatten 1 -rf_flatten_target FlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict fptrain.ds -c config/test-default -rf_flatten 1 -rf_flatten_target FlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict 'joined(mlig).ds' -c config/test-default -rf_flatten 1 -rf_flatten_target InterleavedBfsForest -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict 'holo4k(mlig).ds' -c config/test-default -rf_flatten 1 -rf_flatten_target InterleavedBfsForest -out_subdir TEST/PREDICT_FLATTENED
#
#}
conservation() {
title PREDICTIONS USING CONSERVATION

View File

@@ -3,6 +3,7 @@ package cz.siret.prank.prediction.pockets.rescorers;
import cz.siret.prank.features.FeatureExtractor;
import cz.siret.prank.features.FeatureVector;
import cz.siret.prank.fforest.FasterForest;
import cz.siret.prank.fforest.api.BinaryForest;
import cz.siret.prank.fforest.api.FlatBinaryForest;
import cz.siret.prank.fforest2.FasterForest2;
import cz.siret.prank.program.ml.Model;
@@ -60,7 +61,7 @@ public interface InstancePredictor {
static InstancePredictor create(Model model, FeatureExtractor<?> proteinExtractor) {
Classifier classifier = model.getClassifier();
Object classifier = model.getClassifier();
InstancePredictor res = null;
@@ -119,9 +120,9 @@ public interface InstancePredictor {
return ff.distributionForAttributes(vect.getArray(), 2);
}
};
} else if (classifier instanceof FlatBinaryForest) {
} else if (classifier instanceof BinaryForest) {
res = new InstancePredictor() { // predictor using faster distributionForAttributes()
final FlatBinaryForest ff = (FlatBinaryForest) classifier;
final BinaryForest ff = (BinaryForest) classifier;
@Override
public double predictPositive(FeatureVector vect) {
@@ -143,7 +144,7 @@ public interface InstancePredictor {
if (res == null) {
log.info("Creating WekaInstancePredictor");
res = new WekaInstancePredictor(model.getClassifier(), proteinExtractor);
res = new WekaInstancePredictor(model.asWekaClassifier(), proteinExtractor);
}
return res;

View File

@@ -3,7 +3,7 @@ package cz.siret.prank.program.ml
import cz.siret.prank.features.FeatureExtractor
import cz.siret.prank.features.PrankFeatureExtractor
import cz.siret.prank.fforest.FasterForest
import cz.siret.prank.fforest.api.FlatBinaryForest
import cz.siret.prank.fforest.api.BinaryForest
import cz.siret.prank.fforest2.FasterForest2
import cz.siret.prank.program.params.Params
import cz.siret.prank.utils.Console
@@ -27,9 +27,9 @@ import javax.annotation.Nullable
class Model {
String label
Classifier classifier
Object classifier // Classifier or BinaryForest (flattened random forest)
Model(String label, Classifier classifier) {
Model(String label, Object classifier) {
this.label = label
this.classifier = Objects.requireNonNull(classifier)
}
@@ -39,9 +39,21 @@ class Model {
return this
}
boolean isTrainable() {
return classifier instanceof Classifier
}
Classifier asWekaClassifier() {
if (classifier instanceof Classifier) {
return (Classifier) classifier
} else {
throw new IllegalStateException("Model classifier is not a trainable Classifier: ${classifier.class.name}")
}
}
boolean hasFeatureImportances() {
// Use Class.isInstance() instead of instanceof to avoid Groovy 5 union type issue (GROOVY-11289)
Classifier c = classifier
Object c = classifier
return FastRandomForest.isInstance(c)
|| FasterForest.isInstance(c)
|| FasterForest2.isInstance(c)
@@ -50,7 +62,7 @@ class Model {
@Nullable
List<Double> getFeatureImportances() {
// Use local variable to avoid Groovy 5 field type narrowing with union types
Classifier c = classifier
Object c = classifier
List<Double> res = null
if (c instanceof FastRandomForest) {
res = (c as FastRandomForest).featureImportances.toList()
@@ -102,7 +114,7 @@ class Model {
}
void saveToFile(String fname) {
WekaUtils.saveClassifier((Classifier)classifier, fname)
WekaUtils.saveClassifier(classifier, fname)
Console.write "model saved to file $fname (${Futils.sizeMBFormatted(fname)} MB)"
}
@@ -132,7 +144,7 @@ class Model {
*/
static Model loadFromDirectoryV3(String dir) {
log.info "Loading model from directory (v3 format): $dir"
Classifier classifier = WekaUtils.loadClassifier(Futils.inputStream(dir + "/model.zst"))
Object classifier = WekaUtils.loadClassifier(Futils.inputStream(dir + "/model.zst"))
return new Model(Futils.shortName(dir), classifier)
}
@@ -148,7 +160,7 @@ class Model {
private static Model loadFromFileV2(String fname) {
//fname += ".zst"
Classifier classifier = WekaUtils.loadClassifier(Futils.inputStream(fname))
Object classifier = WekaUtils.loadClassifier(Futils.inputStream(fname))
return new Model(Futils.shortName(fname), classifier)
}
@@ -191,8 +203,8 @@ class Model {
info.numTrees = rf.numTrees
info.numFeatures = rf.@m_Info?.enumerateAttributes()?.toList()?.size()
info.maxDepth = rf.maxDepth
} else if (classifier instanceof FlatBinaryForest) {
FlatBinaryForest rf = (FlatBinaryForest)classifier
} else if (classifier instanceof BinaryForest) {
BinaryForest rf = (BinaryForest)classifier
info.isForest = true
info.numTrees = rf.numTrees
info.numFeatures = rf.numAttributes

View File

@@ -2,8 +2,9 @@ package cz.siret.prank.program.ml
import cz.siret.prank.fforest.FasterForest
import cz.siret.prank.fforest.FasterTree
import cz.siret.prank.fforest.api.FlatBinaryForest
import cz.siret.prank.fforest.api.FlatBinaryForestBuilder
import cz.siret.prank.fforest.api.BinaryForest
import cz.siret.prank.fforest.api.FasterForestConverter
import cz.siret.prank.fforest.api.TrainableFasterForest
import cz.siret.prank.fforest2.FasterForest2
import cz.siret.prank.program.params.Parametrized
import cz.siret.prank.utils.ATimer
@@ -14,52 +15,68 @@ import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import cz.siret.prank.utils.Parallel
import hr.irb.fastRandomForest.FastRandomForest
import org.apache.commons.lang3.StringUtils
import weka.classifiers.Classifier
import weka.core.Instances
import javax.annotation.Nullable
/**
*
* Utility class for converting models to different formats, e.g. flattening random forests to a more efficient format for prediction.
*/
@Slf4j
@CompileStatic
class ModelConverter implements Parametrized, Writable {
Model applyConversions(Model model) {
if (params.rf_flatten) {
model = flattenRandomForest(model)
if (!StringUtils.isBlank(params.rf_flatten_target)) {
model = flattenRandomForest(model, params.rf_flatten_target)
} else {
// useful as no-op option when running ploop for rf_flatten_target param
log.info "'rf_flatten_target' parameter is empty, no flattening is applied."
}
}
return model
}
//===========================================================================================================//
static List<Class> FLATTABLE_CLASSIFIERS = (List) [FastRandomForest, FasterForest, FasterForest2]
static List<Class> FLATTABLE_CLASSIFIERS = [FastRandomForest, FasterForest, FasterForest2] as List
static List<String> FLATTABLE_CLASSIFIER_NAMES = FLATTABLE_CLASSIFIERS*.simpleName
static boolean isFlattableClassifier(Classifier c) {
static boolean isFlattableClassifier(Object c) {
return SysUtils.isInstanceOfAny(c, FLATTABLE_CLASSIFIERS)
}
Model flattenRandomForest(Model model) {
Model flattenRandomForest(Model model, String targetType) {
def c = model.classifier
if (isFlattableClassifier(c)) {
ATimer timer = ATimer.startTimer()
write "Converting ${c.class.simpleName} to FlatBinaryForest"
write "Flattening ${c.class.simpleName} to $targetType"
FlatBinaryForest fbf
if (c instanceof FastRandomForest) {
fbf = frfToFlatForest((FastRandomForest)c)
} else if (c instanceof FasterForest) {
fbf = ((FasterForest)c).toFlatBinaryForest(params.rf_flatten_as_legacy)
} else { // FF2
fbf = ((FasterForest2)c).toFlatBinaryForest(params.rf_flatten_as_legacy)
FasterForestConverter.ForestType forestType
try {
forestType = FasterForestConverter.ForestType.valueOf(targetType)
} catch (Exception e) {
throw new IllegalArgumentException("Unknown target forest type '$targetType'. Supported types: ${FasterForestConverter.ForestType.values()*.name()}.")
}
BinaryForest flatForest
if (c instanceof TrainableFasterForest) {
flatForest = FasterForestConverter.convertFasterForest((TrainableFasterForest) c, forestType)
} else if (c instanceof FastRandomForest) {
TrainableFasterForest trainableForest = frfToTrainableBinaryForest((FastRandomForest) c)
flatForest = FasterForestConverter.convertFasterForest(trainableForest, forestType)
} else {
throw new IllegalStateException("Unexpected flattable forest type: ${c.class.simpleName}")
}
write " - flattened in: $timer.formatted"
return new Model("FlatBinaryForest_from_${model.label}", fbf)
return new Model("FlatBinaryForest_from_${model.label}", flatForest)
} else {
log.warn "Cannot flatten classifier of type ${c.class.simpleName}. Flattable classifiers: ${FLATTABLE_CLASSIFIER_NAMES}"
return model
@@ -68,19 +85,29 @@ class ModelConverter implements Parametrized, Writable {
//===========================================================================================================//
@CompileDynamic
FlatBinaryForest frfToFlatForest(FastRandomForest forest) {
ATimer timer = ATimer.startTimer()
int numAttributes = forest.@m_Info.numAttributes();
TrainableFasterForest frfToTrainableBinaryForest(FastRandomForest forest) {
int numAttributes = forest.@m_Info.numAttributes()
List<Classifier> mTrees = Arrays.asList(forest.@m_bagger.@m_Classifiers)
List<FasterTree> trees = Parallel.collectParallel(mTrees, params.threads * 2) { frfTreeToFasterTree(it) }
write " - faster trees converted in: $timer.formatted"
return new FlatBinaryForestBuilder().buildFromFasterTrees(numAttributes, trees, params.rf_flatten_as_legacy)
return new TrainableFasterForest() {
@Override
int getNumAttributes() {
return numAttributes
}
@Override
List<FasterTree> getTrees() {
return trees
}
@Override
void buildClassifier(Instances instances) throws Exception {
// NO-OP
}
}
}
/**
@@ -107,5 +134,4 @@ class ModelConverter implements Parametrized, Writable {
return new FasterTree(childLeft, childRight, attribute, splitPoint, classProbs)
}
}

View File

@@ -377,10 +377,25 @@ class Params {
@ModelParam // training
boolean rf_flatten = false
/**
* Flattening target type for random forest. Only relevant if rf_flatten=true.
*
* Available options:
* - LegacyFlatBinaryForest
* - FlatBinaryForest
* - ShortFlatBinaryForest
* - SuperShortLegacyFlatBinaryForest
* - InterleavedBfsForest
*/
@RuntimeParam
@ModelParam // training
String rf_flatten_target = "LegacyFlatBinaryForest"
/**
* Flatten random forest in a way that has exactly the same output
* by preserving weird way tree results are aggregated in FastRandomForest.
*/
@Deprecated
@RuntimeParam
@ModelParam // training
boolean rf_flatten_as_legacy = true

View File

@@ -1,10 +1,10 @@
package cz.siret.prank.program.routines.traineval
import cz.siret.prank.domain.Dataset
import cz.siret.prank.fforest.api.FlattableForest
import cz.siret.prank.prediction.metrics.ClassifierStats
import cz.siret.prank.program.ml.FeatureVectors
import cz.siret.prank.program.ml.Model
import cz.siret.prank.program.ml.ModelConverter
import cz.siret.prank.program.params.Parametrized
import cz.siret.prank.program.routines.results.EvalResults
import cz.siret.prank.program.routines.results.FeatureImportances
@@ -106,17 +106,22 @@ class TrainEvalRoutine extends EvalRoutine implements Parametrized {
Futils.delete(evalVectorFile)
}
ClassifierStats calculateTrainStats(Classifier classifier, FeatureVectors trainVectors) {
ClassifierStats calculateTrainStats(Object classifier, FeatureVectors trainVectors) {
if (params.classifier_train_stats) {
ClassifierStats trainStats = new ClassifierStats()
for (Instance inst : trainVectors.instances) {
double[] hist = classifier.distributionForInstance(inst)
double score = normalizedScore(hist)
boolean predicted = applyPointScoreThreshold(score)
boolean observed = inst.classValue() > 0
// for (Instance inst : trainVectors.instances) {
// double[] hist = classifier.distributionForInstance(inst)
// double score = normalizedScore(hist)
// boolean predicted = applyPointScoreThreshold(score)
// boolean observed = inst.classValue() > 0
//
// trainStats.addPrediction(observed, predicted, score)
// }
// TODO: implementation needs reconsidering since classifier can be of various types
// and not all of them support distributionForInstance() method (e.g. flat BinaryForest)
log.warn("Calculating training stats for classifier of type ${classifier.class.simpleName} is not implemented. Returning empty stats.")
trainStats.addPrediction(observed, predicted, score)
}
return trainStats
} else {
return null
@@ -199,18 +204,15 @@ class TrainEvalRoutine extends EvalRoutine implements Parametrized {
void trainModel(Model model, FeatureVectors data) {
WekaUtils.trainClassifier(model.classifier, data)
WekaUtils.trainClassifier(model.asWekaClassifier(), data)
if (params.rf_flatten) {
if (model.classifier instanceof FlattableForest) {
log.info "Flattening random forest"
def timer = startTimer()
model.classifier = ((FlattableForest)model.classifier).toFlatBinaryForest()
logTime "model flattened in " + timer.formatted
model.label = model.label + "_flat"
if (ModelConverter.isFlattableClassifier(model.classifier)) {
Model flattenedModel = new ModelConverter().applyConversions(model)
model.classifier = flattenedModel.classifier
model.label = flattenedModel.label
} else {
log.warn("Trying to flatten classifier that does not support it: " + model.classifier.class.simpleName)
throw new IllegalStateException("Trying to flatten classifier that does not support it: " + model.classifier.class.simpleName)
}
}
}

View File

@@ -48,7 +48,7 @@ class WekaUtils {
// == classifiers ===
static void saveClassifier(Classifier classifier, String fileName) {
static void saveClassifier(Object classifier, String fileName) {
ZipOutputStream zos = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(fileName), BUFFER_SIZE))
//zos.setLevel(9)
@@ -67,7 +67,7 @@ class WekaUtils {
oos.close()
}
static Classifier loadClassifier(String fileName) {
static Object loadClassifier(String fileName) {
InputStream zis = null
try {
zis = new ZipInputStream(new BufferedInputStream(new FileInputStream(fileName), BUFFER_SIZE))
@@ -81,9 +81,9 @@ class WekaUtils {
}
}
static Classifier loadClassifier(InputStream ins) {
static Object loadClassifier(InputStream ins) {
try {
return (Classifier) SerializationHelper.read(ins)
return SerializationHelper.read(ins)
} finally {
ins.close()
}
@@ -94,7 +94,7 @@ class WekaUtils {
* @param classifier
*/
@CompileDynamic
static void disableParallelism(Classifier classifier) {
static void disableParallelism(Object classifier) {
String[] threadPropNames = ["numThreads","numExecutionSlots"] // names used for num.threads property by different classifiers
threadPropNames.each { String name ->
if (classifier.hasProperty(name))
@@ -102,13 +102,6 @@ class WekaUtils {
}
}
/**
* load from jar
*/
static Classifier loadClassifierFromPath(String path) {
return (Classifier) SerializationHelper.read(path.class.getResourceAsStream(path));
}
static void trainClassifier(Classifier classifier, FeatureVectors data) {
validateDataset(data.instances)
classifier.buildClassifier(data.instances)

View File

@@ -1,5 +1,7 @@
package cz.siret.prank.program.ml
import cz.siret.prank.fforest.api.FasterForestConverter
import cz.siret.prank.fforest.api.TrainableFasterForest
import groovy.transform.CompileStatic
import hr.irb.fastRandomForest.FastRandomForest
import org.junit.jupiter.api.Disabled
@@ -20,7 +22,11 @@ class ModelConverterTest {
Model model = Model.load("distro/models/default.model")
assert model.classifier instanceof FastRandomForest
new ModelConverter().frfToFlatForest((FastRandomForest)model.classifier)
TrainableFasterForest trainableForest = new ModelConverter().frfToTrainableBinaryForest((FastRandomForest)model.classifier)
FasterForestConverter.convertFasterForest(trainableForest, FasterForestConverter.ForestType.FlatBinaryForest)
FasterForestConverter.convertFasterForest(trainableForest, FasterForestConverter.ForestType.LegacyFlatBinaryForest)
FasterForestConverter.convertFasterForest(trainableForest, FasterForestConverter.ForestType.InterleavedBfsForest)
}
}