mirror of
https://github.com/rdk/p2rank.git
synced 2026-06-04 12:44:24 +08:00
refactor model flattening to use FasterForestConverter API with configurable target types
Generalize the Model classifier field from Classifier to Object so that it can hold both trainable classifiers and flat BinaryForest models. Add the rf_flatten_target parameter for selecting the target forest type (e.g. FlatBinaryForest, LegacyFlatBinaryForest, InterleavedBfsForest). Deprecate rf_flatten_as_legacy in favor of the new target type selection.
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -37,3 +37,5 @@ distro/test_output/
|
|||||||
distro/README.md
|
distro/README.md
|
||||||
|
|
||||||
pastebin.txt
|
pastebin.txt
|
||||||
|
|
||||||
|
CLAUDE.local.md
|
||||||
|
|||||||
@@ -157,18 +157,19 @@ predict() {
|
|||||||
#
|
#
|
||||||
# title PREDICTIONS WITH FLATTENED FOREST
|
# title PREDICTIONS WITH FLATTENED FOREST
|
||||||
#
|
#
|
||||||
# test ./prank.sh predict joined.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
|
# test ./prank.sh predict joined.ds -c config/test-default -rf_flatten 1 -rf_flatten_target FlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
|
||||||
# test ./prank.sh predict holo4k.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
|
# test ./prank.sh predict holo4k.ds -c config/test-default -rf_flatten 1 -rf_flatten_target LegacyFlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
|
||||||
# test ./prank.sh predict coach420.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
|
# test ./prank.sh predict coach420.ds -c config/test-default -rf_flatten 1 -rf_flatten_target ShortFlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
|
||||||
# test ./prank.sh predict ah4h.holoraw.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
|
# test ./prank.sh predict ah4h.holoraw.ds -c config/test-default -rf_flatten 1 -rf_flatten_target SuperShortLegacyFlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
|
||||||
#
|
#
|
||||||
# test ./prank.sh predict chen11.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
|
# test ./prank.sh predict chen11.ds -c config/test-default -rf_flatten 1 -rf_flatten_target FlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
|
||||||
# test ./prank.sh predict fptrain.ds -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
|
# test ./prank.sh predict fptrain.ds -c config/test-default -rf_flatten 1 -rf_flatten_target FlatBinaryForest -out_subdir TEST/PREDICT_FLATTENED
|
||||||
# test ./prank.sh predict 'joined(mlig).ds' -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
|
# test ./prank.sh predict 'joined(mlig).ds' -c config/test-default -rf_flatten 1 -rf_flatten_target InterleavedBfsForest -out_subdir TEST/PREDICT_FLATTENED
|
||||||
# test ./prank.sh predict 'holo4k(mlig).ds' -c config/test-default -rf_flatten 1 -out_subdir TEST/PREDICT_FLATTENED
|
# test ./prank.sh predict 'holo4k(mlig).ds' -c config/test-default -rf_flatten 1 -rf_flatten_target InterleavedBfsForest -out_subdir TEST/PREDICT_FLATTENED
|
||||||
#
|
#
|
||||||
#}
|
#}
|
||||||
|
|
||||||
|
|
||||||
conservation() {
|
conservation() {
|
||||||
|
|
||||||
title PREDICTIONS USING CONSERVATION
|
title PREDICTIONS USING CONSERVATION
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package cz.siret.prank.prediction.pockets.rescorers;
|
|||||||
import cz.siret.prank.features.FeatureExtractor;
|
import cz.siret.prank.features.FeatureExtractor;
|
||||||
import cz.siret.prank.features.FeatureVector;
|
import cz.siret.prank.features.FeatureVector;
|
||||||
import cz.siret.prank.fforest.FasterForest;
|
import cz.siret.prank.fforest.FasterForest;
|
||||||
|
import cz.siret.prank.fforest.api.BinaryForest;
|
||||||
import cz.siret.prank.fforest.api.FlatBinaryForest;
|
import cz.siret.prank.fforest.api.FlatBinaryForest;
|
||||||
import cz.siret.prank.fforest2.FasterForest2;
|
import cz.siret.prank.fforest2.FasterForest2;
|
||||||
import cz.siret.prank.program.ml.Model;
|
import cz.siret.prank.program.ml.Model;
|
||||||
@@ -60,7 +61,7 @@ public interface InstancePredictor {
|
|||||||
|
|
||||||
|
|
||||||
static InstancePredictor create(Model model, FeatureExtractor<?> proteinExtractor) {
|
static InstancePredictor create(Model model, FeatureExtractor<?> proteinExtractor) {
|
||||||
Classifier classifier = model.getClassifier();
|
Object classifier = model.getClassifier();
|
||||||
|
|
||||||
InstancePredictor res = null;
|
InstancePredictor res = null;
|
||||||
|
|
||||||
@@ -119,9 +120,9 @@ public interface InstancePredictor {
|
|||||||
return ff.distributionForAttributes(vect.getArray(), 2);
|
return ff.distributionForAttributes(vect.getArray(), 2);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} else if (classifier instanceof FlatBinaryForest) {
|
} else if (classifier instanceof BinaryForest) {
|
||||||
res = new InstancePredictor() { // predictor using faster distributionForAttributes()
|
res = new InstancePredictor() { // predictor using faster distributionForAttributes()
|
||||||
final FlatBinaryForest ff = (FlatBinaryForest) classifier;
|
final BinaryForest ff = (BinaryForest) classifier;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double predictPositive(FeatureVector vect) {
|
public double predictPositive(FeatureVector vect) {
|
||||||
@@ -143,7 +144,7 @@ public interface InstancePredictor {
|
|||||||
|
|
||||||
if (res == null) {
|
if (res == null) {
|
||||||
log.info("Creating WekaInstancePredictor");
|
log.info("Creating WekaInstancePredictor");
|
||||||
res = new WekaInstancePredictor(model.getClassifier(), proteinExtractor);
|
res = new WekaInstancePredictor(model.asWekaClassifier(), proteinExtractor);
|
||||||
}
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package cz.siret.prank.program.ml
|
|||||||
import cz.siret.prank.features.FeatureExtractor
|
import cz.siret.prank.features.FeatureExtractor
|
||||||
import cz.siret.prank.features.PrankFeatureExtractor
|
import cz.siret.prank.features.PrankFeatureExtractor
|
||||||
import cz.siret.prank.fforest.FasterForest
|
import cz.siret.prank.fforest.FasterForest
|
||||||
import cz.siret.prank.fforest.api.FlatBinaryForest
|
import cz.siret.prank.fforest.api.BinaryForest
|
||||||
import cz.siret.prank.fforest2.FasterForest2
|
import cz.siret.prank.fforest2.FasterForest2
|
||||||
import cz.siret.prank.program.params.Params
|
import cz.siret.prank.program.params.Params
|
||||||
import cz.siret.prank.utils.Console
|
import cz.siret.prank.utils.Console
|
||||||
@@ -27,9 +27,9 @@ import javax.annotation.Nullable
|
|||||||
class Model {
|
class Model {
|
||||||
|
|
||||||
String label
|
String label
|
||||||
Classifier classifier
|
Object classifier // Classifier or BinaryForest (flattened random forest)
|
||||||
|
|
||||||
Model(String label, Classifier classifier) {
|
Model(String label, Object classifier) {
|
||||||
this.label = label
|
this.label = label
|
||||||
this.classifier = Objects.requireNonNull(classifier)
|
this.classifier = Objects.requireNonNull(classifier)
|
||||||
}
|
}
|
||||||
@@ -39,9 +39,21 @@ class Model {
|
|||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
boolean isTrainable() {
|
||||||
|
return classifier instanceof Classifier
|
||||||
|
}
|
||||||
|
|
||||||
|
Classifier asWekaClassifier() {
|
||||||
|
if (classifier instanceof Classifier) {
|
||||||
|
return (Classifier) classifier
|
||||||
|
} else {
|
||||||
|
throw new IllegalStateException("Model classifier is not a trainable Classifier: ${classifier.class.name}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
boolean hasFeatureImportances() {
|
boolean hasFeatureImportances() {
|
||||||
// Use Class.isInstance() instead of instanceof to avoid Groovy 5 union type issue (GROOVY-11289)
|
// Use Class.isInstance() instead of instanceof to avoid Groovy 5 union type issue (GROOVY-11289)
|
||||||
Classifier c = classifier
|
Object c = classifier
|
||||||
return FastRandomForest.isInstance(c)
|
return FastRandomForest.isInstance(c)
|
||||||
|| FasterForest.isInstance(c)
|
|| FasterForest.isInstance(c)
|
||||||
|| FasterForest2.isInstance(c)
|
|| FasterForest2.isInstance(c)
|
||||||
@@ -50,7 +62,7 @@ class Model {
|
|||||||
@Nullable
|
@Nullable
|
||||||
List<Double> getFeatureImportances() {
|
List<Double> getFeatureImportances() {
|
||||||
// Use local variable to avoid Groovy 5 field type narrowing with union types
|
// Use local variable to avoid Groovy 5 field type narrowing with union types
|
||||||
Classifier c = classifier
|
Object c = classifier
|
||||||
List<Double> res = null
|
List<Double> res = null
|
||||||
if (c instanceof FastRandomForest) {
|
if (c instanceof FastRandomForest) {
|
||||||
res = (c as FastRandomForest).featureImportances.toList()
|
res = (c as FastRandomForest).featureImportances.toList()
|
||||||
@@ -102,7 +114,7 @@ class Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void saveToFile(String fname) {
|
void saveToFile(String fname) {
|
||||||
WekaUtils.saveClassifier((Classifier)classifier, fname)
|
WekaUtils.saveClassifier(classifier, fname)
|
||||||
Console.write "model saved to file $fname (${Futils.sizeMBFormatted(fname)} MB)"
|
Console.write "model saved to file $fname (${Futils.sizeMBFormatted(fname)} MB)"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,7 +144,7 @@ class Model {
|
|||||||
*/
|
*/
|
||||||
static Model loadFromDirectoryV3(String dir) {
|
static Model loadFromDirectoryV3(String dir) {
|
||||||
log.info "Loading model from directory (v3 format): $dir"
|
log.info "Loading model from directory (v3 format): $dir"
|
||||||
Classifier classifier = WekaUtils.loadClassifier(Futils.inputStream(dir + "/model.zst"))
|
Object classifier = WekaUtils.loadClassifier(Futils.inputStream(dir + "/model.zst"))
|
||||||
return new Model(Futils.shortName(dir), classifier)
|
return new Model(Futils.shortName(dir), classifier)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,7 +160,7 @@ class Model {
|
|||||||
|
|
||||||
private static Model loadFromFileV2(String fname) {
|
private static Model loadFromFileV2(String fname) {
|
||||||
//fname += ".zst"
|
//fname += ".zst"
|
||||||
Classifier classifier = WekaUtils.loadClassifier(Futils.inputStream(fname))
|
Object classifier = WekaUtils.loadClassifier(Futils.inputStream(fname))
|
||||||
return new Model(Futils.shortName(fname), classifier)
|
return new Model(Futils.shortName(fname), classifier)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -191,8 +203,8 @@ class Model {
|
|||||||
info.numTrees = rf.numTrees
|
info.numTrees = rf.numTrees
|
||||||
info.numFeatures = rf.@m_Info?.enumerateAttributes()?.toList()?.size()
|
info.numFeatures = rf.@m_Info?.enumerateAttributes()?.toList()?.size()
|
||||||
info.maxDepth = rf.maxDepth
|
info.maxDepth = rf.maxDepth
|
||||||
} else if (classifier instanceof FlatBinaryForest) {
|
} else if (classifier instanceof BinaryForest) {
|
||||||
FlatBinaryForest rf = (FlatBinaryForest)classifier
|
BinaryForest rf = (BinaryForest)classifier
|
||||||
info.isForest = true
|
info.isForest = true
|
||||||
info.numTrees = rf.numTrees
|
info.numTrees = rf.numTrees
|
||||||
info.numFeatures = rf.numAttributes
|
info.numFeatures = rf.numAttributes
|
||||||
|
|||||||
@@ -2,8 +2,9 @@ package cz.siret.prank.program.ml
|
|||||||
|
|
||||||
import cz.siret.prank.fforest.FasterForest
|
import cz.siret.prank.fforest.FasterForest
|
||||||
import cz.siret.prank.fforest.FasterTree
|
import cz.siret.prank.fforest.FasterTree
|
||||||
import cz.siret.prank.fforest.api.FlatBinaryForest
|
import cz.siret.prank.fforest.api.BinaryForest
|
||||||
import cz.siret.prank.fforest.api.FlatBinaryForestBuilder
|
import cz.siret.prank.fforest.api.FasterForestConverter
|
||||||
|
import cz.siret.prank.fforest.api.TrainableFasterForest
|
||||||
import cz.siret.prank.fforest2.FasterForest2
|
import cz.siret.prank.fforest2.FasterForest2
|
||||||
import cz.siret.prank.program.params.Parametrized
|
import cz.siret.prank.program.params.Parametrized
|
||||||
import cz.siret.prank.utils.ATimer
|
import cz.siret.prank.utils.ATimer
|
||||||
@@ -14,52 +15,68 @@ import groovy.transform.CompileStatic
|
|||||||
import groovy.util.logging.Slf4j
|
import groovy.util.logging.Slf4j
|
||||||
import cz.siret.prank.utils.Parallel
|
import cz.siret.prank.utils.Parallel
|
||||||
import hr.irb.fastRandomForest.FastRandomForest
|
import hr.irb.fastRandomForest.FastRandomForest
|
||||||
|
import org.apache.commons.lang3.StringUtils
|
||||||
import weka.classifiers.Classifier
|
import weka.classifiers.Classifier
|
||||||
|
import weka.core.Instances
|
||||||
|
|
||||||
import javax.annotation.Nullable
|
import javax.annotation.Nullable
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
* Utility class for converting models to different formats, e.g. flattening random forests to a more efficient format for prediction.
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@CompileStatic
|
@CompileStatic
|
||||||
class ModelConverter implements Parametrized, Writable {
|
class ModelConverter implements Parametrized, Writable {
|
||||||
|
|
||||||
|
|
||||||
Model applyConversions(Model model) {
|
Model applyConversions(Model model) {
|
||||||
|
|
||||||
if (params.rf_flatten) {
|
if (params.rf_flatten) {
|
||||||
model = flattenRandomForest(model)
|
if (!StringUtils.isBlank(params.rf_flatten_target)) {
|
||||||
|
model = flattenRandomForest(model, params.rf_flatten_target)
|
||||||
|
} else {
|
||||||
|
// useful as no-op option when running ploop for rf_flatten_target param
|
||||||
|
log.info "'rf_flatten_target' parameter is empty, no flattening is applied."
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return model
|
return model
|
||||||
}
|
}
|
||||||
|
|
||||||
//===========================================================================================================//
|
//===========================================================================================================//
|
||||||
|
|
||||||
static List<Class> FLATTABLE_CLASSIFIERS = (List) [FastRandomForest, FasterForest, FasterForest2]
|
static List<Class> FLATTABLE_CLASSIFIERS = [FastRandomForest, FasterForest, FasterForest2] as List
|
||||||
static List<String> FLATTABLE_CLASSIFIER_NAMES = FLATTABLE_CLASSIFIERS*.simpleName
|
static List<String> FLATTABLE_CLASSIFIER_NAMES = FLATTABLE_CLASSIFIERS*.simpleName
|
||||||
|
|
||||||
static boolean isFlattableClassifier(Classifier c) {
|
static boolean isFlattableClassifier(Object c) {
|
||||||
return SysUtils.isInstanceOfAny(c, FLATTABLE_CLASSIFIERS)
|
return SysUtils.isInstanceOfAny(c, FLATTABLE_CLASSIFIERS)
|
||||||
}
|
}
|
||||||
|
|
||||||
Model flattenRandomForest(Model model) {
|
Model flattenRandomForest(Model model, String targetType) {
|
||||||
def c = model.classifier
|
def c = model.classifier
|
||||||
if (isFlattableClassifier(c)) {
|
if (isFlattableClassifier(c)) {
|
||||||
ATimer timer = ATimer.startTimer()
|
ATimer timer = ATimer.startTimer()
|
||||||
|
|
||||||
write "Converting ${c.class.simpleName} to FlatBinaryForest"
|
write "Flattening ${c.class.simpleName} to $targetType"
|
||||||
|
|
||||||
FlatBinaryForest fbf
|
FasterForestConverter.ForestType forestType
|
||||||
if (c instanceof FastRandomForest) {
|
try {
|
||||||
fbf = frfToFlatForest((FastRandomForest)c)
|
forestType = FasterForestConverter.ForestType.valueOf(targetType)
|
||||||
} else if (c instanceof FasterForest) {
|
} catch (Exception e) {
|
||||||
fbf = ((FasterForest)c).toFlatBinaryForest(params.rf_flatten_as_legacy)
|
throw new IllegalArgumentException("Unknown target forest type '$targetType'. Supported types: ${FasterForestConverter.ForestType.values()*.name()}.")
|
||||||
} else { // FF2
|
}
|
||||||
fbf = ((FasterForest2)c).toFlatBinaryForest(params.rf_flatten_as_legacy)
|
|
||||||
|
BinaryForest flatForest
|
||||||
|
if (c instanceof TrainableFasterForest) {
|
||||||
|
flatForest = FasterForestConverter.convertFasterForest((TrainableFasterForest) c, forestType)
|
||||||
|
} else if (c instanceof FastRandomForest) {
|
||||||
|
TrainableFasterForest trainableForest = frfToTrainableBinaryForest((FastRandomForest) c)
|
||||||
|
flatForest = FasterForestConverter.convertFasterForest(trainableForest, forestType)
|
||||||
|
} else {
|
||||||
|
throw new IllegalStateException("Unexpected flattable forest type: ${c.class.simpleName}")
|
||||||
}
|
}
|
||||||
write " - flattened in: $timer.formatted"
|
write " - flattened in: $timer.formatted"
|
||||||
|
|
||||||
return new Model("FlatBinaryForest_from_${model.label}", fbf)
|
return new Model("FlatBinaryForest_from_${model.label}", flatForest)
|
||||||
} else {
|
} else {
|
||||||
log.warn "Cannot flatten classifier of type ${c.class.simpleName}. Flattable classifiers: ${FLATTABLE_CLASSIFIER_NAMES}"
|
log.warn "Cannot flatten classifier of type ${c.class.simpleName}. Flattable classifiers: ${FLATTABLE_CLASSIFIER_NAMES}"
|
||||||
return model
|
return model
|
||||||
@@ -68,19 +85,29 @@ class ModelConverter implements Parametrized, Writable {
|
|||||||
|
|
||||||
//===========================================================================================================//
|
//===========================================================================================================//
|
||||||
|
|
||||||
|
|
||||||
@CompileDynamic
|
@CompileDynamic
|
||||||
FlatBinaryForest frfToFlatForest(FastRandomForest forest) {
|
TrainableFasterForest frfToTrainableBinaryForest(FastRandomForest forest) {
|
||||||
ATimer timer = ATimer.startTimer()
|
int numAttributes = forest.@m_Info.numAttributes()
|
||||||
|
|
||||||
int numAttributes = forest.@m_Info.numAttributes();
|
|
||||||
List<Classifier> mTrees = Arrays.asList(forest.@m_bagger.@m_Classifiers)
|
List<Classifier> mTrees = Arrays.asList(forest.@m_bagger.@m_Classifiers)
|
||||||
|
|
||||||
List<FasterTree> trees = Parallel.collectParallel(mTrees, params.threads * 2) { frfTreeToFasterTree(it) }
|
List<FasterTree> trees = Parallel.collectParallel(mTrees, params.threads * 2) { frfTreeToFasterTree(it) }
|
||||||
|
|
||||||
write " - faster trees converted in: $timer.formatted"
|
|
||||||
|
|
||||||
return new FlatBinaryForestBuilder().buildFromFasterTrees(numAttributes, trees, params.rf_flatten_as_legacy)
|
return new TrainableFasterForest() {
|
||||||
|
@Override
|
||||||
|
int getNumAttributes() {
|
||||||
|
return numAttributes
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
List<FasterTree> getTrees() {
|
||||||
|
return trees
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
void buildClassifier(Instances instances) throws Exception {
|
||||||
|
// NO-OP
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -107,5 +134,4 @@ class ModelConverter implements Parametrized, Writable {
|
|||||||
return new FasterTree(childLeft, childRight, attribute, splitPoint, classProbs)
|
return new FasterTree(childLeft, childRight, attribute, splitPoint, classProbs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -377,10 +377,25 @@ class Params {
|
|||||||
@ModelParam // training
|
@ModelParam // training
|
||||||
boolean rf_flatten = false
|
boolean rf_flatten = false
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Flattening target type for random forest. Only relevant if rf_flatten=true.
|
||||||
|
*
|
||||||
|
* Available options:
|
||||||
|
* - LegacyFlatBinaryForest
|
||||||
|
* - FlatBinaryForest
|
||||||
|
* - ShortFlatBinaryForest
|
||||||
|
* - SuperShortLegacyFlatBinaryForest
|
||||||
|
* - InterleavedBfsForest
|
||||||
|
*/
|
||||||
|
@RuntimeParam
|
||||||
|
@ModelParam // training
|
||||||
|
String rf_flatten_target = "LegacyFlatBinaryForest"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Flatten random forest in a way that has exactly the same output
|
* Flatten random forest in a way that has exactly the same output
|
||||||
* by preserving weird way tree results are aggregated in FastRandomForest.
|
* by preserving weird way tree results are aggregated in FastRandomForest.
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
@RuntimeParam
|
@RuntimeParam
|
||||||
@ModelParam // training
|
@ModelParam // training
|
||||||
boolean rf_flatten_as_legacy = true
|
boolean rf_flatten_as_legacy = true
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
package cz.siret.prank.program.routines.traineval
|
package cz.siret.prank.program.routines.traineval
|
||||||
|
|
||||||
import cz.siret.prank.domain.Dataset
|
import cz.siret.prank.domain.Dataset
|
||||||
import cz.siret.prank.fforest.api.FlattableForest
|
|
||||||
import cz.siret.prank.prediction.metrics.ClassifierStats
|
import cz.siret.prank.prediction.metrics.ClassifierStats
|
||||||
import cz.siret.prank.program.ml.FeatureVectors
|
import cz.siret.prank.program.ml.FeatureVectors
|
||||||
import cz.siret.prank.program.ml.Model
|
import cz.siret.prank.program.ml.Model
|
||||||
|
import cz.siret.prank.program.ml.ModelConverter
|
||||||
import cz.siret.prank.program.params.Parametrized
|
import cz.siret.prank.program.params.Parametrized
|
||||||
import cz.siret.prank.program.routines.results.EvalResults
|
import cz.siret.prank.program.routines.results.EvalResults
|
||||||
import cz.siret.prank.program.routines.results.FeatureImportances
|
import cz.siret.prank.program.routines.results.FeatureImportances
|
||||||
@@ -106,17 +106,22 @@ class TrainEvalRoutine extends EvalRoutine implements Parametrized {
|
|||||||
Futils.delete(evalVectorFile)
|
Futils.delete(evalVectorFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
ClassifierStats calculateTrainStats(Classifier classifier, FeatureVectors trainVectors) {
|
ClassifierStats calculateTrainStats(Object classifier, FeatureVectors trainVectors) {
|
||||||
if (params.classifier_train_stats) {
|
if (params.classifier_train_stats) {
|
||||||
ClassifierStats trainStats = new ClassifierStats()
|
ClassifierStats trainStats = new ClassifierStats()
|
||||||
for (Instance inst : trainVectors.instances) {
|
// for (Instance inst : trainVectors.instances) {
|
||||||
double[] hist = classifier.distributionForInstance(inst)
|
// double[] hist = classifier.distributionForInstance(inst)
|
||||||
double score = normalizedScore(hist)
|
// double score = normalizedScore(hist)
|
||||||
boolean predicted = applyPointScoreThreshold(score)
|
// boolean predicted = applyPointScoreThreshold(score)
|
||||||
boolean observed = inst.classValue() > 0
|
// boolean observed = inst.classValue() > 0
|
||||||
|
//
|
||||||
|
// trainStats.addPrediction(observed, predicted, score)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// TODO: implementation needs reconsidering since classifier can be of various types
|
||||||
|
// and not all of them support distributionForInstance() method (e.g. flat BinaryForest)
|
||||||
|
log.warn("Calculating training stats for classifier of type ${classifier.class.simpleName} is not implemented. Returning empty stats.")
|
||||||
|
|
||||||
trainStats.addPrediction(observed, predicted, score)
|
|
||||||
}
|
|
||||||
return trainStats
|
return trainStats
|
||||||
} else {
|
} else {
|
||||||
return null
|
return null
|
||||||
@@ -199,18 +204,15 @@ class TrainEvalRoutine extends EvalRoutine implements Parametrized {
|
|||||||
|
|
||||||
|
|
||||||
void trainModel(Model model, FeatureVectors data) {
|
void trainModel(Model model, FeatureVectors data) {
|
||||||
WekaUtils.trainClassifier(model.classifier, data)
|
WekaUtils.trainClassifier(model.asWekaClassifier(), data)
|
||||||
|
|
||||||
if (params.rf_flatten) {
|
if (params.rf_flatten) {
|
||||||
if (model.classifier instanceof FlattableForest) {
|
if (ModelConverter.isFlattableClassifier(model.classifier)) {
|
||||||
log.info "Flattening random forest"
|
Model flattenedModel = new ModelConverter().applyConversions(model)
|
||||||
def timer = startTimer()
|
model.classifier = flattenedModel.classifier
|
||||||
model.classifier = ((FlattableForest)model.classifier).toFlatBinaryForest()
|
model.label = flattenedModel.label
|
||||||
logTime "model flattened in " + timer.formatted
|
|
||||||
|
|
||||||
model.label = model.label + "_flat"
|
|
||||||
} else {
|
} else {
|
||||||
log.warn("Trying to flatten classifier that does not support it: " + model.classifier.class.simpleName)
|
throw new IllegalStateException("Trying to flatten classifier that does not support it: " + model.classifier.class.simpleName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ class WekaUtils {
|
|||||||
|
|
||||||
// == classifiers ===
|
// == classifiers ===
|
||||||
|
|
||||||
static void saveClassifier(Classifier classifier, String fileName) {
|
static void saveClassifier(Object classifier, String fileName) {
|
||||||
|
|
||||||
ZipOutputStream zos = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(fileName), BUFFER_SIZE))
|
ZipOutputStream zos = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(fileName), BUFFER_SIZE))
|
||||||
//zos.setLevel(9)
|
//zos.setLevel(9)
|
||||||
@@ -67,7 +67,7 @@ class WekaUtils {
|
|||||||
oos.close()
|
oos.close()
|
||||||
}
|
}
|
||||||
|
|
||||||
static Classifier loadClassifier(String fileName) {
|
static Object loadClassifier(String fileName) {
|
||||||
InputStream zis = null
|
InputStream zis = null
|
||||||
try {
|
try {
|
||||||
zis = new ZipInputStream(new BufferedInputStream(new FileInputStream(fileName), BUFFER_SIZE))
|
zis = new ZipInputStream(new BufferedInputStream(new FileInputStream(fileName), BUFFER_SIZE))
|
||||||
@@ -81,9 +81,9 @@ class WekaUtils {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static Classifier loadClassifier(InputStream ins) {
|
static Object loadClassifier(InputStream ins) {
|
||||||
try {
|
try {
|
||||||
return (Classifier) SerializationHelper.read(ins)
|
return SerializationHelper.read(ins)
|
||||||
} finally {
|
} finally {
|
||||||
ins.close()
|
ins.close()
|
||||||
}
|
}
|
||||||
@@ -94,7 +94,7 @@ class WekaUtils {
|
|||||||
* @param classifier
|
* @param classifier
|
||||||
*/
|
*/
|
||||||
@CompileDynamic
|
@CompileDynamic
|
||||||
static void disableParallelism(Classifier classifier) {
|
static void disableParallelism(Object classifier) {
|
||||||
String[] threadPropNames = ["numThreads","numExecutionSlots"] // names used for num.threads property by different classifiers
|
String[] threadPropNames = ["numThreads","numExecutionSlots"] // names used for num.threads property by different classifiers
|
||||||
threadPropNames.each { String name ->
|
threadPropNames.each { String name ->
|
||||||
if (classifier.hasProperty(name))
|
if (classifier.hasProperty(name))
|
||||||
@@ -102,13 +102,6 @@ class WekaUtils {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* load from jar
|
|
||||||
*/
|
|
||||||
static Classifier loadClassifierFromPath(String path) {
|
|
||||||
return (Classifier) SerializationHelper.read(path.class.getResourceAsStream(path));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void trainClassifier(Classifier classifier, FeatureVectors data) {
|
static void trainClassifier(Classifier classifier, FeatureVectors data) {
|
||||||
validateDataset(data.instances)
|
validateDataset(data.instances)
|
||||||
classifier.buildClassifier(data.instances)
|
classifier.buildClassifier(data.instances)
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package cz.siret.prank.program.ml
|
package cz.siret.prank.program.ml
|
||||||
|
|
||||||
|
import cz.siret.prank.fforest.api.FasterForestConverter
|
||||||
|
import cz.siret.prank.fforest.api.TrainableFasterForest
|
||||||
import groovy.transform.CompileStatic
|
import groovy.transform.CompileStatic
|
||||||
import hr.irb.fastRandomForest.FastRandomForest
|
import hr.irb.fastRandomForest.FastRandomForest
|
||||||
import org.junit.jupiter.api.Disabled
|
import org.junit.jupiter.api.Disabled
|
||||||
@@ -20,7 +22,11 @@ class ModelConverterTest {
|
|||||||
Model model = Model.load("distro/models/default.model")
|
Model model = Model.load("distro/models/default.model")
|
||||||
assert model.classifier instanceof FastRandomForest
|
assert model.classifier instanceof FastRandomForest
|
||||||
|
|
||||||
new ModelConverter().frfToFlatForest((FastRandomForest)model.classifier)
|
TrainableFasterForest trainableForest = new ModelConverter().frfToTrainableBinaryForest((FastRandomForest)model.classifier)
|
||||||
|
|
||||||
|
FasterForestConverter.convertFasterForest(trainableForest, FasterForestConverter.ForestType.FlatBinaryForest)
|
||||||
|
FasterForestConverter.convertFasterForest(trainableForest, FasterForestConverter.ForestType.LegacyFlatBinaryForest)
|
||||||
|
FasterForestConverter.convertFasterForest(trainableForest, FasterForestConverter.ForestType.InterleavedBfsForest)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user