refactor model flattening to use FasterForestConverter API with configurable target types

Generalize Model classifier from Classifier to Object to support both
trainable classifiers and flat BinaryForest models. Add rf_flatten_target
parameter for selecting forest type (FlatBinaryForest, LegacyFlatBinaryForest,
InterleavedBfsForest, etc). Deprecate rf_flatten_as_legacy in favor of the
new target type selection.
This commit is contained in:
rdk
2026-02-16 01:00:55 +01:00
parent de75ac6be1
commit b8f802b145
9 changed files with 137 additions and 79 deletions

2
.gitignore vendored
View File

@@ -37,3 +37,5 @@ distro/test_output/
distro/README.md distro/README.md
pastebin.txt pastebin.txt
CLAUDE.local.md

View File

@@ -157,18 +157,19 @@ predict() {
# #
# title PREDICTIONS WITH FLATTENED FOREST # title PREDICTIONS WITH FLATTENED FOREST
# #
# test ./prank.sh predict joined.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED # test ./prank.sh predict joined.ds -c config/test-default -rf_flatten 1 -rf_flatten_target FlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict holo4k.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED # test ./prank.sh predict holo4k.ds -c config/test-default -rf_flatten 1 -rf_flatten_target LegacyFlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict coach420.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED # test ./prank.sh predict coach420.ds -c config/test-default -rf_flatten 1 -rf_flatten_target ShortFlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict ah4h.holoraw.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED # test ./prank.sh predict ah4h.holoraw.ds -c config/test-default -rf_flatten 1 -rf_flatten_target SuperShortLegacyFlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
# #
# test ./prank.sh predict chen11.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED # test ./prank.sh predict chen11.ds -c config/test-default -rf_flatten 1 -rf_flatten_target FlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict fptrain.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED # test ./prank.sh predict fptrain.ds -c config/test-default -rf_flatten 1 -rf_flatten_target FlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict 'joined(mlig).ds' -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED # test ./prank.sh predict 'joined(mlig).ds' -c config/test-default -rf_flatten 1 -rf_flatten_target InterleavedBfsForest -out_subdir TEST/PREDICT_FLATTENED
# test ./prank.sh predict 'holo4k(mlig).ds' -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED # test ./prank.sh predict 'holo4k(mlig).ds' -c config/test-default -rf_flatten 1 -rf_flatten_target InterleavedBfsForest -out_subdir TEST/PREDICT_FLATTENED
# #
#} #}
conservation() { conservation() {
title PREDICTIONS USING CONSERVATION title PREDICTIONS USING CONSERVATION

View File

@@ -3,6 +3,7 @@ package cz.siret.prank.prediction.pockets.rescorers;
import cz.siret.prank.features.FeatureExtractor; import cz.siret.prank.features.FeatureExtractor;
import cz.siret.prank.features.FeatureVector; import cz.siret.prank.features.FeatureVector;
import cz.siret.prank.fforest.FasterForest; import cz.siret.prank.fforest.FasterForest;
import cz.siret.prank.fforest.api.BinaryForest;
import cz.siret.prank.fforest.api.FlatBinaryForest; import cz.siret.prank.fforest.api.FlatBinaryForest;
import cz.siret.prank.fforest2.FasterForest2; import cz.siret.prank.fforest2.FasterForest2;
import cz.siret.prank.program.ml.Model; import cz.siret.prank.program.ml.Model;
@@ -60,7 +61,7 @@ public interface InstancePredictor {
static InstancePredictor create(Model model, FeatureExtractor<?> proteinExtractor) { static InstancePredictor create(Model model, FeatureExtractor<?> proteinExtractor) {
Classifier classifier = model.getClassifier(); Object classifier = model.getClassifier();
InstancePredictor res = null; InstancePredictor res = null;
@@ -119,9 +120,9 @@ public interface InstancePredictor {
return ff.distributionForAttributes(vect.getArray(), 2); return ff.distributionForAttributes(vect.getArray(), 2);
} }
}; };
} else if (classifier instanceof FlatBinaryForest) { } else if (classifier instanceof BinaryForest) {
res = new InstancePredictor() { // predictor using faster distributionForAttributes() res = new InstancePredictor() { // predictor using faster distributionForAttributes()
final FlatBinaryForest ff = (FlatBinaryForest) classifier; final BinaryForest ff = (BinaryForest) classifier;
@Override @Override
public double predictPositive(FeatureVector vect) { public double predictPositive(FeatureVector vect) {
@@ -143,7 +144,7 @@ public interface InstancePredictor {
if (res == null) { if (res == null) {
log.info("Creating WekaInstancePredictor"); log.info("Creating WekaInstancePredictor");
res = new WekaInstancePredictor(model.getClassifier(), proteinExtractor); res = new WekaInstancePredictor(model.asWekaClassifier(), proteinExtractor);
} }
return res; return res;

View File

@@ -3,7 +3,7 @@ package cz.siret.prank.program.ml
import cz.siret.prank.features.FeatureExtractor import cz.siret.prank.features.FeatureExtractor
import cz.siret.prank.features.PrankFeatureExtractor import cz.siret.prank.features.PrankFeatureExtractor
import cz.siret.prank.fforest.FasterForest import cz.siret.prank.fforest.FasterForest
import cz.siret.prank.fforest.api.FlatBinaryForest import cz.siret.prank.fforest.api.BinaryForest
import cz.siret.prank.fforest2.FasterForest2 import cz.siret.prank.fforest2.FasterForest2
import cz.siret.prank.program.params.Params import cz.siret.prank.program.params.Params
import cz.siret.prank.utils.Console import cz.siret.prank.utils.Console
@@ -27,9 +27,9 @@ import javax.annotation.Nullable
class Model { class Model {
String label String label
Classifier classifier Object classifier // Classifier or BinaryForest (flattened random forest)
Model(String label, Classifier classifier) { Model(String label, Object classifier) {
this.label = label this.label = label
this.classifier = Objects.requireNonNull(classifier) this.classifier = Objects.requireNonNull(classifier)
} }
@@ -39,9 +39,21 @@ class Model {
return this return this
} }
boolean isTrainable() {
return classifier instanceof Classifier
}
Classifier asWekaClassifier() {
if (classifier instanceof Classifier) {
return (Classifier) classifier
} else {
throw new IllegalStateException("Model classifier is not a trainable Classifier: ${classifier.class.name}")
}
}
boolean hasFeatureImportances() { boolean hasFeatureImportances() {
// Use Class.isInstance() instead of instanceof to avoid Groovy 5 union type issue (GROOVY-11289) // Use Class.isInstance() instead of instanceof to avoid Groovy 5 union type issue (GROOVY-11289)
Classifier c = classifier Object c = classifier
return FastRandomForest.isInstance(c) return FastRandomForest.isInstance(c)
|| FasterForest.isInstance(c) || FasterForest.isInstance(c)
|| FasterForest2.isInstance(c) || FasterForest2.isInstance(c)
@@ -50,7 +62,7 @@ class Model {
@Nullable @Nullable
List<Double> getFeatureImportances() { List<Double> getFeatureImportances() {
// Use local variable to avoid Groovy 5 field type narrowing with union types // Use local variable to avoid Groovy 5 field type narrowing with union types
Classifier c = classifier Object c = classifier
List<Double> res = null List<Double> res = null
if (c instanceof FastRandomForest) { if (c instanceof FastRandomForest) {
res = (c as FastRandomForest).featureImportances.toList() res = (c as FastRandomForest).featureImportances.toList()
@@ -102,7 +114,7 @@ class Model {
} }
void saveToFile(String fname) { void saveToFile(String fname) {
WekaUtils.saveClassifier((Classifier)classifier, fname) WekaUtils.saveClassifier(classifier, fname)
Console.write "model saved to file $fname (${Futils.sizeMBFormatted(fname)} MB)" Console.write "model saved to file $fname (${Futils.sizeMBFormatted(fname)} MB)"
} }
@@ -132,7 +144,7 @@ class Model {
*/ */
static Model loadFromDirectoryV3(String dir) { static Model loadFromDirectoryV3(String dir) {
log.info "Loading model from directory (v3 format): $dir" log.info "Loading model from directory (v3 format): $dir"
Classifier classifier = WekaUtils.loadClassifier(Futils.inputStream(dir + "/model.zst")) Object classifier = WekaUtils.loadClassifier(Futils.inputStream(dir + "/model.zst"))
return new Model(Futils.shortName(dir), classifier) return new Model(Futils.shortName(dir), classifier)
} }
@@ -148,7 +160,7 @@ class Model {
private static Model loadFromFileV2(String fname) { private static Model loadFromFileV2(String fname) {
//fname += ".zst" //fname += ".zst"
Classifier classifier = WekaUtils.loadClassifier(Futils.inputStream(fname)) Object classifier = WekaUtils.loadClassifier(Futils.inputStream(fname))
return new Model(Futils.shortName(fname), classifier) return new Model(Futils.shortName(fname), classifier)
} }
@@ -191,8 +203,8 @@ class Model {
info.numTrees = rf.numTrees info.numTrees = rf.numTrees
info.numFeatures = rf.@m_Info?.enumerateAttributes()?.toList()?.size() info.numFeatures = rf.@m_Info?.enumerateAttributes()?.toList()?.size()
info.maxDepth = rf.maxDepth info.maxDepth = rf.maxDepth
} else if (classifier instanceof FlatBinaryForest) { } else if (classifier instanceof BinaryForest) {
FlatBinaryForest rf = (FlatBinaryForest)classifier BinaryForest rf = (BinaryForest)classifier
info.isForest = true info.isForest = true
info.numTrees = rf.numTrees info.numTrees = rf.numTrees
info.numFeatures = rf.numAttributes info.numFeatures = rf.numAttributes

View File

@@ -2,8 +2,9 @@ package cz.siret.prank.program.ml
import cz.siret.prank.fforest.FasterForest import cz.siret.prank.fforest.FasterForest
import cz.siret.prank.fforest.FasterTree import cz.siret.prank.fforest.FasterTree
import cz.siret.prank.fforest.api.FlatBinaryForest import cz.siret.prank.fforest.api.BinaryForest
import cz.siret.prank.fforest.api.FlatBinaryForestBuilder import cz.siret.prank.fforest.api.FasterForestConverter
import cz.siret.prank.fforest.api.TrainableFasterForest
import cz.siret.prank.fforest2.FasterForest2 import cz.siret.prank.fforest2.FasterForest2
import cz.siret.prank.program.params.Parametrized import cz.siret.prank.program.params.Parametrized
import cz.siret.prank.utils.ATimer import cz.siret.prank.utils.ATimer
@@ -14,52 +15,68 @@ import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j import groovy.util.logging.Slf4j
import cz.siret.prank.utils.Parallel import cz.siret.prank.utils.Parallel
import hr.irb.fastRandomForest.FastRandomForest import hr.irb.fastRandomForest.FastRandomForest
import org.apache.commons.lang3.StringUtils
import weka.classifiers.Classifier import weka.classifiers.Classifier
import weka.core.Instances
import javax.annotation.Nullable import javax.annotation.Nullable
/** /**
* * Utility class for converting models to different formats, e.g. flattening random forests to a more efficient format for prediction.
*/ */
@Slf4j @Slf4j
@CompileStatic @CompileStatic
class ModelConverter implements Parametrized, Writable { class ModelConverter implements Parametrized, Writable {
Model applyConversions(Model model) { Model applyConversions(Model model) {
if (params.rf_flatten) { if (params.rf_flatten) {
model = flattenRandomForest(model) if (!StringUtils.isBlank(params.rf_flatten_target)) {
model = flattenRandomForest(model, params.rf_flatten_target)
} else {
// useful as no-op option when running ploop for rf_flatten_target param
log.info "'rf_flatten_target' parameter is empty, no flattening is applied."
}
} }
return model return model
} }
//===========================================================================================================// //===========================================================================================================//
static List<Class> FLATTABLE_CLASSIFIERS = (List) [FastRandomForest, FasterForest, FasterForest2] static List<Class> FLATTABLE_CLASSIFIERS = [FastRandomForest, FasterForest, FasterForest2] as List
static List<String> FLATTABLE_CLASSIFIER_NAMES = FLATTABLE_CLASSIFIERS*.simpleName static List<String> FLATTABLE_CLASSIFIER_NAMES = FLATTABLE_CLASSIFIERS*.simpleName
static boolean isFlattableClassifier(Classifier c) { static boolean isFlattableClassifier(Object c) {
return SysUtils.isInstanceOfAny(c, FLATTABLE_CLASSIFIERS) return SysUtils.isInstanceOfAny(c, FLATTABLE_CLASSIFIERS)
} }
Model flattenRandomForest(Model model) { Model flattenRandomForest(Model model, String targetType) {
def c = model.classifier def c = model.classifier
if (isFlattableClassifier(c)) { if (isFlattableClassifier(c)) {
ATimer timer = ATimer.startTimer() ATimer timer = ATimer.startTimer()
write "Converting ${c.class.simpleName} to FlatBinaryForest" write "Flattening ${c.class.simpleName} to $targetType"
FlatBinaryForest fbf FasterForestConverter.ForestType forestType
if (c instanceof FastRandomForest) { try {
fbf = frfToFlatForest((FastRandomForest)c) forestType = FasterForestConverter.ForestType.valueOf(targetType)
} else if (c instanceof FasterForest) { } catch (Exception e) {
fbf = ((FasterForest)c).toFlatBinaryForest(params.rf_flatten_as_legacy) throw new IllegalArgumentException("Unknown target forest type '$targetType'. Supported types: ${FasterForestConverter.ForestType.values()*.name()}.")
} else { // FF2 }
fbf = ((FasterForest2)c).toFlatBinaryForest(params.rf_flatten_as_legacy)
BinaryForest flatForest
if (c instanceof TrainableFasterForest) {
flatForest = FasterForestConverter.convertFasterForest((TrainableFasterForest) c, forestType)
} else if (c instanceof FastRandomForest) {
TrainableFasterForest trainableForest = frfToTrainableBinaryForest((FastRandomForest) c)
flatForest = FasterForestConverter.convertFasterForest(trainableForest, forestType)
} else {
throw new IllegalStateException("Unexpected flattable forest type: ${c.class.simpleName}")
} }
write " - flattened in: $timer.formatted" write " - flattened in: $timer.formatted"
return new Model("FlatBinaryForest_from_${model.label}", fbf) return new Model("FlatBinaryForest_from_${model.label}", flatForest)
} else { } else {
log.warn "Cannot flatten classifier of type ${c.class.simpleName}. Flattable classifiers: ${FLATTABLE_CLASSIFIER_NAMES}" log.warn "Cannot flatten classifier of type ${c.class.simpleName}. Flattable classifiers: ${FLATTABLE_CLASSIFIER_NAMES}"
return model return model
@@ -68,19 +85,29 @@ class ModelConverter implements Parametrized, Writable {
//===========================================================================================================// //===========================================================================================================//
@CompileDynamic @CompileDynamic
FlatBinaryForest frfToFlatForest(FastRandomForest forest) { TrainableFasterForest frfToTrainableBinaryForest(FastRandomForest forest) {
ATimer timer = ATimer.startTimer() int numAttributes = forest.@m_Info.numAttributes()
int numAttributes = forest.@m_Info.numAttributes();
List<Classifier> mTrees = Arrays.asList(forest.@m_bagger.@m_Classifiers) List<Classifier> mTrees = Arrays.asList(forest.@m_bagger.@m_Classifiers)
List<FasterTree> trees = Parallel.collectParallel(mTrees, params.threads * 2) { frfTreeToFasterTree(it) } List<FasterTree> trees = Parallel.collectParallel(mTrees, params.threads * 2) { frfTreeToFasterTree(it) }
write " - faster trees converted in: $timer.formatted"
return new FlatBinaryForestBuilder().buildFromFasterTrees(numAttributes, trees, params.rf_flatten_as_legacy) return new TrainableFasterForest() {
@Override
int getNumAttributes() {
return numAttributes
}
@Override
List<FasterTree> getTrees() {
return trees
}
@Override
void buildClassifier(Instances instances) throws Exception {
// NO-OP
}
}
} }
/** /**
@@ -107,5 +134,4 @@ class ModelConverter implements Parametrized, Writable {
return new FasterTree(childLeft, childRight, attribute, splitPoint, classProbs) return new FasterTree(childLeft, childRight, attribute, splitPoint, classProbs)
} }
} }

View File

@@ -377,10 +377,25 @@ class Params {
@ModelParam // training @ModelParam // training
boolean rf_flatten = false boolean rf_flatten = false
/**
* Flattening target type for random forest. Only relevant if rf_flatten=true.
*
* Available options:
* - LegacyFlatBinaryForest
* - FlatBinaryForest
* - ShortFlatBinaryForest
* - SuperShortLegacyFlatBinaryForest
* - InterleavedBfsForest
*/
@RuntimeParam
@ModelParam // training
String rf_flatten_target = "LegacyFlatBinaryForest"
/** /**
* Flatten random forest in a way that has exactly the same output * Flatten random forest in a way that has exactly the same output
* by preserving weird way tree results are aggregated in FastRandomForest. * by preserving weird way tree results are aggregated in FastRandomForest.
*/ */
@Deprecated
@RuntimeParam @RuntimeParam
@ModelParam // training @ModelParam // training
boolean rf_flatten_as_legacy = true boolean rf_flatten_as_legacy = true

View File

@@ -1,10 +1,10 @@
package cz.siret.prank.program.routines.traineval package cz.siret.prank.program.routines.traineval
import cz.siret.prank.domain.Dataset import cz.siret.prank.domain.Dataset
import cz.siret.prank.fforest.api.FlattableForest
import cz.siret.prank.prediction.metrics.ClassifierStats import cz.siret.prank.prediction.metrics.ClassifierStats
import cz.siret.prank.program.ml.FeatureVectors import cz.siret.prank.program.ml.FeatureVectors
import cz.siret.prank.program.ml.Model import cz.siret.prank.program.ml.Model
import cz.siret.prank.program.ml.ModelConverter
import cz.siret.prank.program.params.Parametrized import cz.siret.prank.program.params.Parametrized
import cz.siret.prank.program.routines.results.EvalResults import cz.siret.prank.program.routines.results.EvalResults
import cz.siret.prank.program.routines.results.FeatureImportances import cz.siret.prank.program.routines.results.FeatureImportances
@@ -106,17 +106,22 @@ class TrainEvalRoutine extends EvalRoutine implements Parametrized {
Futils.delete(evalVectorFile) Futils.delete(evalVectorFile)
} }
ClassifierStats calculateTrainStats(Classifier classifier, FeatureVectors trainVectors) { ClassifierStats calculateTrainStats(Object classifier, FeatureVectors trainVectors) {
if (params.classifier_train_stats) { if (params.classifier_train_stats) {
ClassifierStats trainStats = new ClassifierStats() ClassifierStats trainStats = new ClassifierStats()
for (Instance inst : trainVectors.instances) { // for (Instance inst : trainVectors.instances) {
double[] hist = classifier.distributionForInstance(inst) // double[] hist = classifier.distributionForInstance(inst)
double score = normalizedScore(hist) // double score = normalizedScore(hist)
boolean predicted = applyPointScoreThreshold(score) // boolean predicted = applyPointScoreThreshold(score)
boolean observed = inst.classValue() > 0 // boolean observed = inst.classValue() > 0
//
// trainStats.addPrediction(observed, predicted, score)
// }
// TODO: implementation needs reconsidering since classifier can be of various types
// and not all of them support distributionForInstance() method (e.g. flat BinaryForest)
log.warn("Calculating training stats for classifier of type ${classifier.class.simpleName} is not implemented. Returning empty stats.")
trainStats.addPrediction(observed, predicted, score)
}
return trainStats return trainStats
} else { } else {
return null return null
@@ -199,18 +204,15 @@ class TrainEvalRoutine extends EvalRoutine implements Parametrized {
void trainModel(Model model, FeatureVectors data) { void trainModel(Model model, FeatureVectors data) {
WekaUtils.trainClassifier(model.classifier, data) WekaUtils.trainClassifier(model.asWekaClassifier(), data)
if (params.rf_flatten) { if (params.rf_flatten) {
if (model.classifier instanceof FlattableForest) { if (ModelConverter.isFlattableClassifier(model.classifier)) {
log.info "Flattening random forest" Model flattenedModel = new ModelConverter().applyConversions(model)
def timer = startTimer() model.classifier = flattenedModel.classifier
model.classifier = ((FlattableForest)model.classifier).toFlatBinaryForest() model.label = flattenedModel.label
logTime "model flattened in " + timer.formatted
model.label = model.label + "_flat"
} else { } else {
log.warn("Trying to flatten classifier that does not support it: " + model.classifier.class.simpleName) throw new IllegalStateException("Trying to flatten classifier that does not support it: " + model.classifier.class.simpleName)
} }
} }
} }

View File

@@ -48,7 +48,7 @@ class WekaUtils {
// == classifiers === // == classifiers ===
static void saveClassifier(Classifier classifier, String fileName) { static void saveClassifier(Object classifier, String fileName) {
ZipOutputStream zos = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(fileName), BUFFER_SIZE)) ZipOutputStream zos = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(fileName), BUFFER_SIZE))
//zos.setLevel(9) //zos.setLevel(9)
@@ -67,7 +67,7 @@ class WekaUtils {
oos.close() oos.close()
} }
static Classifier loadClassifier(String fileName) { static Object loadClassifier(String fileName) {
InputStream zis = null InputStream zis = null
try { try {
zis = new ZipInputStream(new BufferedInputStream(new FileInputStream(fileName), BUFFER_SIZE)) zis = new ZipInputStream(new BufferedInputStream(new FileInputStream(fileName), BUFFER_SIZE))
@@ -81,9 +81,9 @@ class WekaUtils {
} }
} }
static Classifier loadClassifier(InputStream ins) { static Object loadClassifier(InputStream ins) {
try { try {
return (Classifier) SerializationHelper.read(ins) return SerializationHelper.read(ins)
} finally { } finally {
ins.close() ins.close()
} }
@@ -94,7 +94,7 @@ class WekaUtils {
* @param classifier * @param classifier
*/ */
@CompileDynamic @CompileDynamic
static void disableParallelism(Classifier classifier) { static void disableParallelism(Object classifier) {
String[] threadPropNames = ["numThreads","numExecutionSlots"] // names used for num.threads property by different classifiers String[] threadPropNames = ["numThreads","numExecutionSlots"] // names used for num.threads property by different classifiers
threadPropNames.each { String name -> threadPropNames.each { String name ->
if (classifier.hasProperty(name)) if (classifier.hasProperty(name))
@@ -102,13 +102,6 @@ class WekaUtils {
} }
} }
/**
* load from jar
*/
static Classifier loadClassifierFromPath(String path) {
return (Classifier) SerializationHelper.read(path.class.getResourceAsStream(path));
}
static void trainClassifier(Classifier classifier, FeatureVectors data) { static void trainClassifier(Classifier classifier, FeatureVectors data) {
validateDataset(data.instances) validateDataset(data.instances)
classifier.buildClassifier(data.instances) classifier.buildClassifier(data.instances)

View File

@@ -1,5 +1,7 @@
package cz.siret.prank.program.ml package cz.siret.prank.program.ml
import cz.siret.prank.fforest.api.FasterForestConverter
import cz.siret.prank.fforest.api.TrainableFasterForest
import groovy.transform.CompileStatic import groovy.transform.CompileStatic
import hr.irb.fastRandomForest.FastRandomForest import hr.irb.fastRandomForest.FastRandomForest
import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Disabled
@@ -20,7 +22,11 @@ class ModelConverterTest {
Model model = Model.load("distro/models/default.model") Model model = Model.load("distro/models/default.model")
assert model.classifier instanceof FastRandomForest assert model.classifier instanceof FastRandomForest
new ModelConverter().frfToFlatForest((FastRandomForest)model.classifier) TrainableFasterForest trainableForest = new ModelConverter().frfToTrainableBinaryForest((FastRandomForest)model.classifier)
FasterForestConverter.convertFasterForest(trainableForest, FasterForestConverter.ForestType.FlatBinaryForest)
FasterForestConverter.convertFasterForest(trainableForest, FasterForestConverter.ForestType.LegacyFlatBinaryForest)
FasterForestConverter.convertFasterForest(trainableForest, FasterForestConverter.ForestType.InterleavedBfsForest)
} }
} }