Merge changes from upstream

This commit is contained in:
Lukas Jendele
2017-04-28 14:12:28 +02:00
32 changed files with 648 additions and 118 deletions

View File

@@ -9,7 +9,7 @@ apply plugin: 'idea'
group = 'cz.siret'
version = '2.0-dev.7'
version = '2.0-dev.8'
description = 'Ligand binding site prediction based on machine learning.'

View File

@@ -46,11 +46,6 @@ import cz.siret.prank.program.params.Params
//== FAETURES
/**
* include volsite pharmacophore properties
*/
use_volsite_features = true
extra_features = ["chem","volsite","protrusion","bfactor"]
atom_table_features = ["ap5sasaValids","ap5sasaInvalids"] // "apRawValids","apRawInvalids","atomicHydrophobicity"
@@ -264,11 +259,6 @@ import cz.siret.prank.program.params.Params
*/
pred_clustering_dist = 5
/**
* distance to extend clusters around hotspots
*/
pred_surrounding = 3.5
/**
* cuttoff distance of protein surface atoms considered as part of the pocket
*/

View File

@@ -46,11 +46,6 @@ import cz.siret.prank.program.params.Params
//== FAETURES
/**
* include volsite pharmacophore properties
*/
use_volsite_features = true
atom_table_features = ["apRawValids","apRawInvalids","atomicHydrophobicity"]
extra_features = ["chem","volsite","protrusion","bfactor"]
@@ -265,9 +260,9 @@ import cz.siret.prank.program.params.Params
pred_clustering_dist = 3
/**
* distance to extend clusters around hotspots
* SAS points around ligandable points (an their score) will be included in the pocket
*/
pred_surrounding = 3.5
extended_pocket_cutoff = 3.5
/**
* cuttoff distance of protein surface atoms considered as part of the pocket

View File

@@ -35,18 +35,11 @@ import cz.siret.prank.program.params.Params
predictions = true
out_prefix_date = true
crossval_threads = 5
cache_datasets = true
clear_sec_caches = false
clear_prim_caches = false
log_cases = true
/**
* calculate feature importance
@@ -54,7 +47,8 @@ import cz.siret.prank.program.params.Params
*/
feature_importances = false
output_only_stats = true
stats_collect_predictions = true
/**
* collect negatives just from decoy pockets found by other method
@@ -66,16 +60,39 @@ import cz.siret.prank.program.params.Params
atom_table_features = ["ap5sasaValids","ap5sasaInvalids","apRawValids","apRawInvalids","atomicHydrophobicity"]
extra_features = ["protrusion","bfactor"]
extra_features = ["chem","volsite","protrusion","bfactor"]
residue_table_features = ["RAx"]
average_feat_vectors = true
balance_class_weights = true
target_class_weight_ratio = 0.055
// technical
cache_datasets = true
clear_sec_caches = false
clear_prim_caches = false
log_cases = true
output_only_stats = true
log_to_console = false
log_level = "WARN"
log_to_file = true
ploop_delete_runs = true
zip_log_file = true
out_prefix_date = true
}

View File

@@ -42,6 +42,6 @@ PARAMS="$PRANK_LOCALENV_PARAMS"
CMD="$JAVACMD $JAVA_OPTS -cp ${CLASSPATH} cz.siret.prank.program.Main ${PARAMS} $@"
echo "+" $CMD
"$JAVACMD" $JAVA_OPTS -cp "${CLASSPATH}" cz.siret.prank.program.Main ${PARAMS} "$@" 2>>local-debug.log
"$JAVACMD" $JAVA_OPTS -cp "${CLASSPATH}" cz.siret.prank.program.Main ${PARAMS} "$@"

View File

@@ -1,6 +1,11 @@
package cz.siret.prank.domain
import groovy.transform.CompileStatic
/**
* 20 main amino acid codes
*/
@CompileStatic
enum AA {
ALA,
@@ -24,4 +29,17 @@ enum AA {
TRP,
TYR;
private static final Map<String, AA> index = new HashMap<String, AA>()
static {
for (AA value : EnumSet.allOf(AA.class)) {
index.put(value.name(), value)
}
}
public static AA forName(String name) {
return index.get(name)
}
}

View File

@@ -157,8 +157,8 @@ class Dataset implements Parametrized {
void clearSecondaryCaches() {
items.each {
if (it.cachedPair!=null) {
it.cachedPair.prediction.protein.clearCachedSurfaces()
it.cachedPair.liganatedProtein.clearCachedSurfaces()
it.cachedPair.prediction.protein.clearSecondaryData()
it.cachedPair.liganatedProtein.clearSecondaryData()
}
}
}

View File

@@ -121,10 +121,14 @@ class Protein implements Parametrized {
return trainSurface
}
void clearCachedSurfaces() {
/**
* clears generated surfaces and secondary data
*/
void clearSecondaryData() {
connollySurface = null
trainSurface = null
exposedAtoms = null
secondaryData.clear()
}
Atoms getAllLigandAtoms() {

View File

@@ -113,7 +113,7 @@ class PrankFeatureExtractor extends FeatureExtractor<PrankFeatureVector> impleme
surfaceLayerAtoms = protein.exposedAtoms
}
log.debug "surfaceLayerAtoms:$surfaceLayerAtoms.count (surfaceAtoms: $pocket.surfaceAtoms.count) "
log.debug "surfaceLayerAtoms:$surfaceLayerAtoms.count (surfaceAtoms: ${pocket?.surfaceAtoms?.count}) "
preEvaluateProperties(surfaceLayerAtoms)
if (DO_SMOOTH_REPRESENTATION) {
@@ -195,7 +195,7 @@ class PrankFeatureExtractor extends FeatureExtractor<PrankFeatureVector> impleme
log.info "P2R protein:$protein.proteinAtoms.count exposedAtoms:$res.surfaceLayerAtoms.count deepSurrounding:$res.deepSurrounding.count connollyPoints:$res.sampledPoints.count"
log.info "P2R protein:$protein.proteinAtoms.count exposedAtoms:$res.surfaceLayerAtoms.count deepSurrounding:$res.deepSurrounding.count sasPoints:$res.sampledPoints.count"
return res
}

View File

@@ -1,6 +1,10 @@
package cz.siret.prank.features.api
import cz.siret.prank.features.implementation.Asa2Feature
import cz.siret.prank.features.implementation.AsaFeature
import cz.siret.prank.features.implementation.AsaResiduesFeature
import cz.siret.prank.features.implementation.BfactorFeature
import cz.siret.prank.features.implementation.ContactResiduesPositionFeature
import cz.siret.prank.features.implementation.ProtrusionFeature
import cz.siret.prank.features.implementation.ProtrusionHistogramFeature
import cz.siret.prank.features.implementation.SurfaceProtrusionFeature
@@ -48,6 +52,10 @@ class FeatureRegistry {
registerFeature(new ConservationFeature())
registerFeature(new ConservationCloudFeature())
registerFeature(new ConservationCloudScaledFeature())
registerFeature(new ContactResiduesPositionFeature())
registerFeature(new AsaFeature())
registerFeature(new Asa2Feature())
registerFeature(new AsaResiduesFeature())
// Register new feature implementations here

View File

@@ -0,0 +1,84 @@
package cz.siret.prank.features.implementation
import cz.siret.prank.domain.Protein
import cz.siret.prank.features.api.SasFeatureCalculationContext
import cz.siret.prank.features.api.SasFeatureCalculator
import cz.siret.prank.geom.Atoms
import cz.siret.prank.program.params.Parametrized
import cz.siret.prank.utils.Writable
import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import org.biojava.nbio.structure.Atom
import org.biojava.nbio.structure.StructureTools
import org.biojava.nbio.structure.asa.AsaCalculator
/**
* Local protein solvent accessible surface area feature
*/
@Slf4j
@CompileStatic
class Asa2Feature extends SasFeatureCalculator implements Parametrized, Writable {
static final String NAME = "asa2"
@Override
String getName() { NAME }
@Override
List<String> getHeader() {
return ["asa2.1", "asa2.2"]
}
ProtAsa calcProtAsa(Protein protein, double probeRadius) {
int nSpherePoints = AsaCalculator.DEFAULT_N_SPHERE_POINTS
int threads = 1
boolean hetAtoms = false
Atom[] protAtoms = StructureTools.getAllNonHAtomArray(protein.structure, hetAtoms)
AsaCalculator asaCalculator = new AsaCalculator(protein.structure, probeRadius, nSpherePoints, threads, hetAtoms)
double[] atomAsas = asaCalculator.calculateAsas()
protAtoms[0].getPDBserial()
Map<Integer, Double> asaByAtom = new HashMap<>()
for (int i=0; i!= protAtoms.length; ++i) {
asaByAtom.put protAtoms[i].PDBserial, atomAsas[i]
}
return new ProtAsa(protein, asaByAtom)
}
@Override
void preProcessProtein(Protein protein) {
if (!protein.secondaryData.containsKey("prot_atom_asa")) {
protein.secondaryData.put "prot_atom_asa", calcProtAsa(protein, params.feat_asa_probe_radius)
}
if (!protein.secondaryData.containsKey("prot_atom_asa2")) {
protein.secondaryData.put "prot_atom_asa2", calcProtAsa(protein, params.feat_asa_probe_radius2)
}
}
@Override
double[] calculateForSasPoint(Atom sasPoint, SasFeatureCalculationContext context) {
Atoms localAtoms = context.protein.exposedAtoms.cutoffAroundAtom(sasPoint, params.feat_asa_neigh_radius)
ProtAsa protAsa = (ProtAsa) context.protein.secondaryData.get("prot_atom_asa")
double localAsa = (double) localAtoms.collect { Atom a -> protAsa.asaByAtom.get(a.PDBserial) ?: 0 }.sum(0)
ProtAsa protAsa2 = (ProtAsa) context.protein.secondaryData.get("prot_atom_asa2")
double localAsa2 = (double) localAtoms.collect { Atom a -> protAsa2.asaByAtom.get(a.PDBserial) ?: 0 }.sum(0)
return [localAsa, localAsa2] as double[]
}
static class ProtAsa {
Protein protein
Map<Integer, Double> asaByAtom
ProtAsa(Protein protein, Map<Integer, Double> asaByAtom) {
this.protein = protein
this.asaByAtom = asaByAtom
}
}
}

View File

@@ -0,0 +1,71 @@
package cz.siret.prank.features.implementation
import cz.siret.prank.domain.Protein
import cz.siret.prank.features.api.SasFeatureCalculationContext
import cz.siret.prank.features.api.SasFeatureCalculator
import cz.siret.prank.geom.Atoms
import cz.siret.prank.program.params.Parametrized
import cz.siret.prank.utils.Writable
import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import org.biojava.nbio.structure.Atom
import org.biojava.nbio.structure.StructureTools
import org.biojava.nbio.structure.asa.AsaCalculator
/**
* Local protein solvent accessible surface area feature
*/
@Slf4j
@CompileStatic
class AsaFeature extends SasFeatureCalculator implements Parametrized, Writable {
static final String NAME = "asa"
@Override
String getName() { NAME }
@Override
void preProcessProtein(Protein protein) {
if (protein.secondaryData.containsKey("prot_atom_asa")) {
return
}
double probeRadius = params.feat_asa_probe_radius
int nSpherePoints = AsaCalculator.DEFAULT_N_SPHERE_POINTS
int threads = 1
boolean hetAtoms = false
Atom[] protAtoms = StructureTools.getAllNonHAtomArray(protein.structure, hetAtoms)
AsaCalculator asaCalculator = new AsaCalculator(protein.structure, probeRadius, nSpherePoints, threads, hetAtoms)
double[] atomAsas = asaCalculator.calculateAsas()
protAtoms[0].getPDBserial()
Map<Integer, Double> asaByAtom = new HashMap<>()
for (int i=0; i!= protAtoms.length; ++i) {
asaByAtom.put protAtoms[i].PDBserial, atomAsas[i]
}
protein.secondaryData.put "prot_atom_asa", new ProtAsa(protein, asaByAtom)
}
@Override
double[] calculateForSasPoint(Atom sasPoint, SasFeatureCalculationContext context) {
Atoms localAtoms = context.protein.exposedAtoms.cutoffAroundAtom(sasPoint, params.feat_asa_neigh_radius)
ProtAsa protAsa = (ProtAsa) context.protein.secondaryData.get("prot_atom_asa")
double localAsa = (double) localAtoms.collect { Atom a -> protAsa.asaByAtom.get(a.PDBserial) ?: 0 }.sum(0)
return [localAsa] as double[]
}
static class ProtAsa {
Protein protein
Map<Integer, Double> asaByAtom
ProtAsa(Protein protein, Map<Integer, Double> asaByAtom) {
this.protein = protein
this.asaByAtom = asaByAtom
}
}
}

View File

@@ -0,0 +1,75 @@
package cz.siret.prank.features.implementation
import cz.siret.prank.domain.Protein
import cz.siret.prank.features.api.SasFeatureCalculationContext
import cz.siret.prank.features.api.SasFeatureCalculator
import cz.siret.prank.geom.Atoms
import cz.siret.prank.program.params.Parametrized
import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import org.biojava.nbio.structure.Atom
import org.biojava.nbio.structure.Group
import org.biojava.nbio.structure.ResidueNumber
import org.biojava.nbio.structure.asa.AsaCalculator
import org.biojava.nbio.structure.asa.GroupAsa
/**
* Local protein solvent accessible surface area feature
*/
@Slf4j
@CompileStatic
class AsaResiduesFeature extends SasFeatureCalculator implements Parametrized {
static final String NAME = "asares"
@Override
String getName() { NAME }
@Override
void preProcessProtein(Protein protein) {
double probeRadius = params.feat_asa_probe_radius
int nSpherePoints = AsaCalculator.DEFAULT_N_SPHERE_POINTS
int threads = 1
boolean hetAtoms = false
AsaCalculator asaCalculator = new AsaCalculator(protein.structure, probeRadius, nSpherePoints, threads, hetAtoms)
List<GroupAsa> asas = asaCalculator.groupAsas.toList()
protein.secondaryData.put "prot_asa", new ProtAsa(protein, asas)
}
@Override
double[] calculateForSasPoint(Atom sasPoint, SasFeatureCalculationContext context) {
List<Group> groups = context.protein.exposedAtoms.cutoffAroundAtom(sasPoint, params.feat_asa_neigh_radius).distinctGroups
ProtAsa protAsa = (ProtAsa) context.protein.secondaryData.get("prot_asa")
double localAsa = (double) groups.collect { Group g -> protAsa.groupAsaMap.get(g.residueNumber) ?: 0 }.sum(0)
return [localAsa] as double[]
}
static class ProtAsa {
Protein protein
List<GroupAsa> groupAsas
Map<ResidueNumber, Double> groupAsaMap = new HashMap<>()
ProtAsa(Protein protein, List<GroupAsa> groupAsas) {
this.protein = protein
this.groupAsas = groupAsas
for (GroupAsa gasa : groupAsas) {
double asa = gasa.asaC
ResidueNumber resNum = gasa?.group?.residueNumber
if (resNum!=null) {
groupAsaMap.put resNum, asa
}
}
}
}
}

View File

@@ -0,0 +1,109 @@
package cz.siret.prank.features.implementation
import com.google.common.collect.ArrayListMultimap
import com.google.common.collect.Multimap
import cz.siret.prank.domain.AA
import cz.siret.prank.features.api.SasFeatureCalculationContext
import cz.siret.prank.features.api.SasFeatureCalculator
import cz.siret.prank.geom.Atoms
import cz.siret.prank.geom.Struct
import cz.siret.prank.program.params.Parametrized
import cz.siret.prank.utils.PDBUtils
import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import org.biojava.nbio.structure.AminoAcid
import org.biojava.nbio.structure.AminoAcidImpl
import org.biojava.nbio.structure.Atom
/**
*
*/
@Slf4j
@CompileStatic
class ContactResiduesPositionFeature extends SasFeatureCalculator implements Parametrized {
static String NAME = 'crpos'
static List<AA> AATYPES = AA.values().sort { it.name() }.toList()
final List<String> HEADER = new ArrayList<>()
double MAX_DIST = 20;
//===========================================================================================================//
double contactDist
ContactResiduesPositionFeature() {
contactDist = params.feat_crang_contact_dist
for (AA aa : AATYPES) {
String prefix = NAME + '.' + aa.name().toLowerCase() + '.'
HEADER.add prefix + 'count'
HEADER.add prefix + 'distca'
HEADER.add prefix + 'distclosest'
HEADER.add prefix + 'distcenter'
}
}
@Override
String getName() {
NAME
}
@Override
List<String> getHeader() {
HEADER
}
@Override
double[] calculateForSasPoint(Atom sasPoint, SasFeatureCalculationContext context) {
Atoms contactAtoms = context.neighbourhoodAtoms.cutoffAroundAtom(sasPoint, contactDist)
List<AminoAcid> contactResidues = (List<AminoAcid>)(List)contactAtoms.getDistinctGroups().findAll{ it instanceof AminoAcid }.toList()
log.debug 'contact residues: ' + contactResidues.size()
// TODO: this can be optmized
Multimap<AA, AminoAcid> contactResIndex = ArrayListMultimap.create(20, 3);
for (AminoAcid res : contactResidues) {
AA aa = AA.forName(PDBUtils.getResidueCode(res))
if (aa!=null) {
contactResIndex.put(aa, res)
}
}
Map<AA, Collection<AminoAcid>> cresmap = contactResIndex.asMap()
double[] vect = new double[HEADER.size()]
int i = 0
for (AA aa : AATYPES) {
double count = 0
double distclosest = MAX_DIST
double distca = MAX_DIST
double distcenter = MAX_DIST
Collection<AminoAcid> residues = cresmap.get(aa)
if (residues!=null && !residues.empty) {
AminoAcid closestResOfType = residues.min { Atoms.allFromGroup(it).dist(sasPoint) }
Atoms ratoms = Atoms.allFromGroup(closestResOfType)
count = residues.size()
distclosest = ratoms.dist(sasPoint)
distcenter = Struct.dist ratoms.centerOfMass, sasPoint
distca = (closestResOfType.CA==null) ? distcenter : Struct.dist(closestResOfType.CA, sasPoint)
}
vect[i] = count
vect[i+1] = distca
vect[i+2] = distclosest
vect[i+3] = distcenter
i += 4
}
return vect
}
}

View File

@@ -13,7 +13,7 @@ import org.biojava.nbio.structure.Atom
@CompileStatic
class ProtrusionHistogramFeature extends SasFeatureCalculator implements Parametrized {
static final double MIN_DIST = 2d
static final double MIN_DIST = 4
@Override
String getName() {

View File

@@ -187,7 +187,9 @@ public final class Atoms implements Iterable<Atom> {
public List<Group> getDistinctGroups() {
Set<Group> res = new HashSet<>();
for (Atom a : list) {
res.add(a.getGroup());
if (a.getGroup()!=null) {
res.add(a.getGroup());
}
}
List<Group> sres = Struct.sortGroups(res);

View File

@@ -22,6 +22,7 @@ class LogManager implements Writable {
static final String LOGGER_NAME = "cz.siret.prank"
static final String CONSOLE_APPENDER_NAME = "Console"
static final String FILE_APPENDER_NAME = "File"
static final String PATTERN = "[%level] %logger{0} - %msg%n"
boolean loggingToFile = false
String logFile
@@ -40,9 +41,10 @@ class LogManager implements Writable {
config = ctx.getConfiguration();
loggerConfig = config.getLoggerConfig(loggerName)
loggerConfig.getAppenders().each { System.out.println "APPENDER: " + it.value.name }
// loggerConfig.getAppenders().each { System.out.println "APPENDER: " + it.value.name }
write "logToConsole: $logToConsole"
write "logToFile: $logToFile"
write "logLevel: ${level.name()}"
loggerConfig.setLevel(level)
@@ -63,12 +65,12 @@ class LogManager implements Writable {
ctx.updateLoggers();
loggerConfig.getAppenders().each { System.out.println "APPENDER: " + it.value.name }
//loggerConfig.getAppenders().each { System.out.println "APPENDER: " + it.value.name }
}
private static Appender addFileAppender(Configuration config, String loggerName, String logFile, Level level) {
String pattern = "[%level] %logger{0} - %msg%n"
String pattern = PATTERN
int bufferSize = 5000
Futils.delete(logFile)

View File

@@ -10,6 +10,7 @@ import groovy.util.logging.Slf4j
import org.apache.commons.lang3.StringUtils
import static cz.siret.prank.utils.ATimer.startTimer
import static cz.siret.prank.utils.Futils.writeFile
@Slf4j
class Main implements Parametrized, Writable {
@@ -174,6 +175,10 @@ class Main implements Parametrized, Writable {
}
void writeCmdLineArgs(String outdir) {
writeFile("$outdir/cmdline_args.txt", args)
}
//===========================================================================================================//
void doRunPredict(String label, boolean evalPredict) {
@@ -366,15 +371,20 @@ class Main implements Parametrized, Writable {
boolean error = false
Main main
try {
error = new Main(parsedArgs).run()
main = new Main(parsedArgs)
error = main.run()
} catch (PrankException e) {
error = true
writeError e.message
log.error(e.message, e)
if (main.logManager.loggingToFile) {
write "For details see log file: '$main.logManager.logFile'"
}
} catch (Exception e) {

View File

@@ -68,11 +68,6 @@ class Params {
//== FAETURES
/**
* include volsite pharmacophore properties
*/
boolean use_volsite_features = true
List<String> extra_features = ["protrusion","bfactor"]
List<String> atom_table_features = ["apRawValids","apRawInvalids","atomicHydrophobicity"] // "ap5sasaValids","ap5sasaInvalids"
@@ -314,6 +309,8 @@ class Params {
*/
double pred_point_threshold = 0.4
boolean include_surrounding_score = false
/**
* minimum cluster size (of ligandable points) for initial clustering
*/
@@ -325,9 +322,9 @@ class Params {
double pred_clustering_dist = 5
/**
* distance to extend clusters around hotspots
* SAS points around ligandable points (an their score) will be included in the pocket
*/
double pred_surrounding = 3.5
double extended_pocket_cutoff = 3.5
/**
* cuttoff distance of protein surface atoms considered as part of the pocket
@@ -491,6 +488,26 @@ class Params {
/** produce ROC and PR curve graphs (not fully implemented yet) */
boolean stats_curves = false
/**
* Contact residues distance cutoff
*/
double feat_crang_contact_dist = 3
/**
* probe radius for calculating accessible surface area for asa feature
*/
double feat_asa_probe_radius = 1.4
/**
* probe radius for calculating accessible surface area for asa feature
*/
double feat_asa_probe_radius2 = 1.4
/**
* radius of the neighbourhood considered in asa feature
*/
double feat_asa_neigh_radius = 6
//===========================================================================================================//
String getVersion() {

View File

@@ -53,6 +53,7 @@ class Experiments extends Routine {
datadirRoot = params.dataset_base_dir
label = "run_" + trainDataSet.label + "_" + (doCrossValidation ? "crossval" : evalDataSet.label)
outdir = main.findOutdir(label)
main.writeCmdLineArgs(outdir)
main.configureLoggers(outdir)
}

View File

@@ -13,6 +13,7 @@ import groovyx.gpars.GParsPool
import static cz.siret.prank.utils.ATimer.startTimer
import static cz.siret.prank.utils.Futils.mkdirs
import static cz.siret.prank.utils.Futils.writeFile
/**
* Routine for grid optimization. Loops through values of one or more RangeParam and produces resulting statistics and plots.

View File

@@ -3,6 +3,7 @@ package cz.siret.prank.program.routines
import cz.siret.prank.program.Main
import cz.siret.prank.program.PrankException
import cz.siret.prank.program.params.Parametrized
import cz.siret.prank.program.params.Params
import cz.siret.prank.utils.Futils
import cz.siret.prank.utils.Writable
import groovy.transform.CompileStatic
@@ -37,5 +38,5 @@ class Routine implements Parametrized, Writable {
String v = "version: " + Main.version + "\n"
writeFile("$outdir/params.txt", v + params.toString())
}
}

View File

@@ -18,6 +18,8 @@ import weka.classifiers.Classifier
import weka.core.Instance
import weka.core.Instances
import static cz.siret.prank.score.prediction.PointScoreCalculator.predictedPositive
import static cz.siret.prank.score.prediction.PointScoreCalculator.predictedScore
import static cz.siret.prank.utils.ATimer.startTimer
@Slf4j
@@ -99,8 +101,8 @@ class TrainEvalRoutine extends EvalRoutine implements Parametrized {
ClassifierStats trainStats = new ClassifierStats()
for (Instance inst : trainVectors) {
double[] hist = classifier.distributionForInstance(inst)
double score = PointScoreCalculator.predictedScore(hist)
boolean predicted = hist[1] > hist[0]
double score = predictedScore(hist)
boolean predicted = predictedPositive(score)
boolean observed = inst.classValue() > 0
trainStats.addPrediction(observed, predicted, score, hist)

View File

@@ -133,8 +133,8 @@ class EvalResults implements Parametrized, Writable {
//===========================================================================================================//
m.TIME_TRAIN = trainTime
m.TIME_EVAL = evalTime
m.TIME_TRAIN_M = (double)(trainTime ?: 0) / 60000
m.TIME_EVAL_M = (double)(evalTime ?: 0) / 60000
m.TRAIN_VECTORS = avgTrainVectors
m.TRAIN_POSITIVES = avgTrainPositives

View File

@@ -6,11 +6,13 @@ import cz.siret.prank.domain.PredictionPair
import cz.siret.prank.domain.Protein
import cz.siret.prank.features.implementation.conservation.ConservationScore
import cz.siret.prank.geom.Atoms
import cz.siret.prank.program.rendering.LabeledPoint
import cz.siret.prank.score.criteria.*
import groovy.util.logging.Slf4j
import org.apache.commons.lang3.StringUtils
import static cz.siret.prank.utils.Formatter.*
import static java.util.Collections.emptyList
/**
* Represents evaluation of pocket prediction on a dataset of proteins
@@ -22,6 +24,9 @@ import static cz.siret.prank.utils.Formatter.*
@Slf4j
class Evaluation {
/** cutoff distance in A around ligand atoms that determins which SAS points cover the ligand */
static final double LIG_SAS_CUTOFF = 2
IdentificationCriterium standardCriterium = new DCA(4.0)
List<IdentificationCriterium> criteria
List<ProteinRow> proteinRows = Collections.synchronizedList(new ArrayList<>())
@@ -36,6 +41,9 @@ class Evaluation {
int smallLigandCount
int distantLigandCount
int ligSASPointsCount
int ligSASPointsCoveredCount
Evaluation(List<IdentificationCriterium> criteria) {
this.criteria = criteria
}
@@ -92,6 +100,8 @@ class Evaluation {
List<PocketRow> tmpPockets = new ArrayList<>()
Protein lp = pair.liganatedProtein
Atoms sasPoints = pair.prediction.protein.connollySurface.points
Atoms labeledPoints = new Atoms(pair.prediction.labeledPoints ?: emptyList())
ProteinRow protRow = new ProteinRow()
protRow.name = pair.name
@@ -109,7 +119,13 @@ class Evaluation {
protRow.smallLigNames = lp.smallLigands.collect { "$it.name($it.size)" }.join(" ")
protRow.distantLigands = lp.distantLigands.size()
protRow.distantLigNames = lp.distantLigands.collect { "$it.name($it.size|${format(it.contactDistance,1)}|${format(it.centerToProteinDist,1)})" }.join(" ")
protRow.connollyPoints = pair.prediction.protein.connollySurface.points.count
protRow.sasPoints = sasPoints.count
// ligand coverage
Atoms ligSasPoints = labeledPoints.cutoffAtoms(lp.allLigandAtoms, LIG_SAS_CUTOFF)
int n_ligSasPoints = ligSasPoints.count
int n_ligSasPointsCovered = ligSasPoints.toList().findAll { ((LabeledPoint)it).predicted }.toList().size()
log.debug "XXXX n_ligSasPoints: {} covered: {}", n_ligSasPoints, n_ligSasPointsCovered
// Conservation stats
ConservationScore score = lp.secondaryData.get(ConservationScore.conservationScoreKey)
@@ -187,6 +203,8 @@ class Evaluation {
proteinRows.add(protRow)
ligandRows.addAll(tmpLigRows)
pocketRows.addAll(tmpPockets)
ligSASPointsCount += n_ligSasPoints
ligSASPointsCoveredCount += n_ligSasPointsCovered
}
}
@@ -200,6 +218,8 @@ class Evaluation {
ignoredLigandCount += eval.ignoredLigandCount
smallLigandCount += eval.smallLigandCount
distantLigandCount += eval.distantLigandCount
ligSASPointsCount += eval.ligSASPointsCount
ligSASPointsCoveredCount += eval.ligSASPointsCoveredCount
}
double calcSuccRate(int assesorNum, int tolerance) {
@@ -266,26 +286,38 @@ class Evaluation {
return a
}
//===========================================================================================================//
public <T> double avg(List<T> list, Closure<T> closure) {
if (list.size()==0) return Double.NaN
list.collect { closure(it) }.findAll { it!=Double.NaN }.sum(0) / list.size()
}
double div(double a, double b) {
if (b==0d)
return Double.NaN
return a / b
}
//===========================================================================================================//
double getAvgPockets() {
pocketCount / proteinCount
div pocketCount, proteinCount
}
double getAvgLigandAtoms() {
ligandRows.collect {it.atoms}.sum(0) / ligandCount
div ligandRows.collect {it.atoms}.sum(0), ligandCount
}
double getAvgPocketVolume() {
pocketRows.collect { it.pocketVolume }.sum(0) / pocketCount
div pocketRows.collect { it.pocketVolume }.sum(0), pocketCount
}
double getAvgPocketVolumeTruePockets() {
avg pocketRows.findAll { it.truePocket }, {PocketRow it -> it.pocketVolume }
}
double getAvgPocketSurfAtoms() {
pocketRows.collect { it.surfaceAtomCount }.sum(0) / pocketCount
div pocketRows.collect { it.surfaceAtomCount }.sum(0), pocketCount
}
double getAvgPocketSurfAtomsTruePockets() {
@@ -293,32 +325,30 @@ class Evaluation {
}
double getAvgPocketInnerPoints() {
pocketRows.collect { it.auxInfo.samplePoints }.sum(0) / pocketCount
div pocketRows.collect { it.auxInfo.samplePoints }.sum(0), pocketCount
}
double getAvgPocketInnerPointsTruePockets() {
avg pocketRows.findAll { it.truePocket }, {PocketRow it -> it.auxInfo.samplePoints }
}
double getAvgProteinAtoms() {
proteinRows.collect { it.protAtoms }.sum(0) / proteinCount
div proteinRows.collect { it.protAtoms }.sum(0), proteinCount
}
double getAvgExposedAtoms() {
proteinRows.collect { it.exposedAtoms }.sum(0) / proteinCount
div proteinRows.collect { it.exposedAtoms }.sum(0), proteinCount
}
double getAvgProteinConollyPoints() {
avg proteinRows, {ProteinRow it -> it.connollyPoints }
avg proteinRows, {ProteinRow it -> it.sasPoints }
}
double getAvgLigCenterToProtDist() {
avg ligandRows, {LigRow it -> it.centerToProtDist}
}
public <T> double avg(List<T> list, Closure<T> closure) {
if (list.size()==0) return Double.NaN
list.collect { closure(it) }.findAll { it!=Double.NaN }.sum(0) / list.size()
double getLigandCoverage() {
div ligSASPointsCoveredCount, ligSASPointsCount
}
double getAvgClosestPocketDist() {
@@ -347,6 +377,7 @@ class Evaluation {
m.AVG_LIG_CENTER_TO_PROT_DIST = avgLigCenterToProtDist
m.AVG_LIG_CLOSTES_POCKET_DIST = avgClosestPocketDist
m.LIGAND_COVERAGE = ligandCoverage
m.AVG_POCKETS = avgPockets
m.AVG_POCKET_SURF_ATOMS = avgPocketSurfAtoms
@@ -491,6 +522,8 @@ class Evaluation {
double avgConservation
double avgBindingConservation
double avgNonBindingConservation
int sasPoints
}
static class LigRow {

View File

@@ -19,6 +19,7 @@ import weka.classifiers.Classifier
import weka.core.DenseInstance
import weka.core.Instances
import static cz.siret.prank.score.prediction.PointScoreCalculator.predictedPositive
import static cz.siret.prank.score.prediction.PointScoreCalculator.predictedScore
/**
@@ -90,14 +91,12 @@ class WekaSumRescorer extends PocketRescorer implements Parametrized {
// classification
FeatureVector props = extractor.calcFeatureVector(point.point)
point.@hist = getDistributionForPoint(classifier, props)
double[] hist = getDistributionForPoint(classifier, props)
// labels and statistics
double[] hist = point.hist
double predictedScore = predictedScore(hist) // not all classifiers give histogram that sums up to 1
boolean predicted = hist[1] > hist[0]
boolean predicted = predictedPositive(predictedScore)
boolean observed = false
if (ligandAtoms!=null) {
@@ -105,6 +104,10 @@ class WekaSumRescorer extends PocketRescorer implements Parametrized {
observed = (closestLigandDistance <= POSITIVE_POINT_LIGAND_DISTANCE)
}
point.@hist = hist
point.predicted = predicted
point.observed = observed
if (collectingStatistics) {
stats.addPrediction(observed, predicted, predictedScore, hist)
}
@@ -143,7 +146,7 @@ class WekaSumRescorer extends PocketRescorer implements Parametrized {
double[] hist = getDistributionForPoint(classifier, props)
double predictedScore = predictedScore(hist) // not all classifiers give histogram that sums up to 1
boolean predicted = hist[1] > hist[0]
boolean predicted = predictedPositive(predictedScore)
boolean observed = false
if (collectingStatistics) {

View File

@@ -6,6 +6,8 @@ import groovy.transform.CompileStatic
import java.text.DecimalFormat
import static cz.siret.prank.utils.Formatter.formatPercent
import static java.lang.Double.NaN
import static java.lang.Math.log
/**
* Binary classifier statistics collector and calculator
@@ -13,8 +15,11 @@ import static cz.siret.prank.utils.Formatter.formatPercent
@CompileStatic
class ClassifierStats implements Parametrized {
static final double EPS = 1e-15d
static final int HISTOGRAM_BINS = 100
String name
int[][] op // [observed][predicted]
@@ -27,6 +32,7 @@ class ClassifierStats implements Parametrized {
double sumSE = 0
double sumSEpos = 0
double sumSEneg = 0
double sumLogLoss = 0
Histograms histograms = new Histograms()
@@ -77,7 +83,7 @@ class ClassifierStats implements Parametrized {
void addPrediction(boolean observed, boolean predicted, double score, double[] hist) {
double obsv = observed ? 1 : 0
double e = Math.abs(obsv-score)
double e = Math.abs(obsv - score)
double se = e*e
sumE += e
@@ -91,6 +97,11 @@ class ClassifierStats implements Parametrized {
sumSEneg += se
}
double pCorrect = observed ? score : 1-score
if (pCorrect<EPS)
pCorrect = EPS
sumLogLoss -= log(pCorrect)
histograms.score.put(score)
if (observed) {
histograms.scorePos.put(score)
@@ -114,17 +125,17 @@ class ClassifierStats implements Parametrized {
double calcMCC(double TP, double FP, double TN, double FN) {
double n = TP*TN - FP*FN
double d = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
d = Math.sqrt(d);
d = Math.sqrt(d)
if (d == 0d) {
d = 1d;
d = 1d
}
return n / d;
return n / d
}
double div(double a, double b) {
if (b==0d)
return Double.NaN
return NaN
return a / b
}
@@ -160,19 +171,47 @@ class ClassifierStats implements Parametrized {
private Advanced advanced = null
double getTp() { op[1][1] }
double getFp() { op[0][1] }
double getTn() { op[0][0] }
double getFn() { op[1][0] }
double getTP() { op[1][1] }
double getFP() { op[0][1] }
double getTN() { op[0][0] }
double getFN() { op[1][0] }
/** Observed Positive */
double getOP() {
TP + FN
}
/** Observed Negative */
double getON() {
FP + TN
}
/** Predicted Positive */
double getPP() {
TP + FP
}
/** Predicted Negative */
double getPN() {
TN + FN
}
double getOPON_ratio() {
div OP, ON
}
double getPPPN_ratio() {
div PP, PN
}
/** Precision = Positive Predictive Value */
double getP() {
div tp , (tp + fp)
div TP , (TP + FP)
}
/** Recall = Sensitivity = True Positive Rate */
double getR() {
div tp , (tp + fn)
div TP , (TP + FN)
}
/** F-measure */
@@ -188,36 +227,51 @@ class ClassifierStats implements Parametrized {
}
double getMCC() {
calcMCC(tp, fp, tn, fn)
calcMCC(TP, FP, TN, FN)
}
/** negative predictive value */
double getNPV() {
div tn , (tn + fn)
div TN , (TN + FN)
}
/** specificity = true negative rate */
double getSPC() {
div tn , (tn + fp)
div TN , (TN + FP)
}
/** accuraccy */
double getACC() {
div( (tp + tn) , count )
div( (TP + TN) , count )
}
/** balanced accuracy */
double getBACC() {
(r + SPC) / 2
}
/** TP versus the bad */
double getTPX() {
div tp, tp + fn + fp
div TP, TP + FN + FP
}
/** log TP */
double getLTP() {
try {
-log( TP / (PP * OP) )
} catch (Exception e) {
NaN
}
}
/** false positive rate */
double getFPR() {
div fp , (fp + tn)
div FP , (FP + TN)
}
/** false negative rate */
double getFNR() {
div fn , (tp + fn)
div FN , (TP + FN)
}
/** positive likelihood ratio */
@@ -237,28 +291,34 @@ class ClassifierStats implements Parametrized {
/** false discovery rate */
double getFDR() {
div fp , (tp + fp)
div FP , (TP + FP)
}
/** false ommision rate */
double getFOR() {
div fn , (fn + tn)
div FN , (FN + TN)
}
/** Youden's J statistic = Youden's index */
/** Youden's J statistic = Youden's index = Informedness */
double getYJS() {
r + SPC -1
r + SPC - 1
}
/** Markedness */
double getMRK() {
p + NPV - 1
}
/** Discriminant Power ... <1 = poor, >3 = good, fair otherwise */
double getDPOW() {
if (r==1 || SPC==1)
return Double.NaN
return NaN
double x = r / (1-r)
double y = SPC / (1-SPC)
double c = Math.sqrt(3) / Math.PI
c * ( Math.log(x) + Math.log(y) )
c * ( log(x) + log(y) )
}
double getME() { div sumE, count }
@@ -271,6 +331,26 @@ class ClassifierStats implements Parametrized {
double getMSEneg() { div sumSEneg, count }
double getMSEbalanced() { (MSEneg + MSEpos) / 2 }
double getLogLoss() {
div sumLogLoss, count
}
/** Uncertainty coefficient, aka Proficiency */
double getUC() {
try {
double L = (OP + ON) * log(OP + ON)
double LTP = TP * log( TP / (PP * OP) )
double LFP = FP * log( FP / (PP * ON) )
double LFN = FN * log( FN / (PN * OP) )
double LTN = TN * log( TN / (PN * ON) )
double LP = OP * log( OP / count )
double LN = ON * log( ON / count )
double UC = (L + LTP + LFP + LFN + LTN) / (L + LP + LN)
return UC
} catch (Exception e) {
return NaN
}
}
double getAUC() {
if (advanced==null) advanced = calculateAdvanced()
@@ -310,8 +390,8 @@ class ClassifierStats implements Parametrized {
}
class Advanced {
double wekaAUC = Double.NaN
double wekaAUPRC = Double.NaN
double wekaAUC = 0
double wekaAUPRC = 0
}
}
@@ -354,13 +434,13 @@ class ClassifierStats implements Parametrized {
sb << ",(npv),(p)\n"
sb << "\n"
sb << "pred: , [0], [1]\n"
sb << "obs[0] , ${tn}, ${fp}, ${formatPercent(SPC)}\n"
sb << "obs[1] , ${fn}, ${tp}, ${formatPercent(R)}\n"
sb << "obs[0] , ${TN}, ${FP}, ${formatPercent(SPC)}\n"
sb << "obs[1] , ${FN}, ${TP}, ${formatPercent(R)}\n"
sb << " , ${formatPercent(NPV)}, ${formatPercent(P)}\n"
sb << "\n"
sb << "%:\n"
sb << ", ${rel(tn)}, ${rel(fp)}\n"
sb << ", ${rel(fn)}, ${rel(tp)}\n"
sb << ", ${rel(TN)}, ${rel(FP)}\n"
sb << ", ${rel(FN)}, ${rel(TP)}\n"
sb << "\n"
sb << "ACC:, ${format(ACC)}, accuracy\n"
sb << "\n"
@@ -372,17 +452,6 @@ class ClassifierStats implements Parametrized {
sb << "\n"
sb << "FM:, ${format(f1)}, F-measure\n"
sb << "MCC:, ${format(MCC)}, Matthews correlation coefficient\n"
sb << "\n"
sb << "ME:, ${format(ME)}, Mean error\n"
sb << "MEpos:, ${format(MEpos)}, ME on positive observations\n"
sb << "MEneg:, ${format(MEneg)}, Mean error on negative observations\n"
sb << "MEbal:, ${format(MEbalanced)}, Mean error balanced\n"
sb << "\n"
sb << "MSE:, ${format(MSE)}, Mean squared error\n"
sb << "MSEpos:, ${format(MSEpos)}, MSE on positive observations\n"
sb << "MSEneg:, ${format(MSEneg)}, MSE on negative observations\n"
sb << "MSEbal:, ${format(MSEbalanced)}, Mean error balanced\n"
}
return sb.toString()

View File

@@ -21,7 +21,7 @@ class PocketPredictor implements Parametrized {
private double POCKET_PROT_SURFACE_CUTOFF = params.pred_protein_surface_cutoff
private int MIN_CLUSTER_SIZE = params.pred_min_cluster_size
private double SURROUNDING_DIST = params.pred_surrounding
private double EXTENDED_POCKET_CUTOFF = params.extended_pocket_cutoff
private double CLUSTERING_DIST = params.pred_clustering_dist
private double POINT_THRESHOLD = params.pred_point_threshold
private boolean BALANCE_POINT_DENSITY = params.balance_density
@@ -40,8 +40,9 @@ class PocketPredictor implements Parametrized {
}
private boolean admitPoint(LabeledPoint point) {
double p = PointScoreCalculator.predictedScore(point.hist)
return p > POINT_THRESHOLD
// double p = PointScoreCalculator.predictedScore(point.hist)
// return p > POINT_THRESHOLD
point.predicted
}
/**
@@ -67,9 +68,14 @@ class PocketPredictor implements Parametrized {
List<PrankPocket> pockets = filteredClusters.collect { Atoms clusterPoints ->
Atoms extendedPocketPoints = connollyPoints.cutoffAtoms(clusterPoints, SURROUNDING_DIST)
double score = extendedPocketPoints.collect { score((LabeledPoint)it, connollyPoints) }.sum()
Atoms pocketSurfaceAtoms = protein.exposedAtoms.cutoffAtoms(extendedPocketPoints, POCKET_PROT_SURFACE_CUTOFF)
Atoms pocketPoints = clusterPoints
if (EXTENDED_POCKET_CUTOFF > 0d) {
Atoms extendedPocketPoints = connollyPoints.cutoffAtoms(clusterPoints, EXTENDED_POCKET_CUTOFF)
pocketPoints = extendedPocketPoints
}
double score = (double) pocketPoints.collect { score((LabeledPoint)it, connollyPoints) }.sum(0)
Atoms pocketSurfaceAtoms = protein.exposedAtoms.cutoffAtoms(pocketPoints, POCKET_PROT_SURFACE_CUTOFF)
try {
if (params.score_pockets_by == "conservation" || params.score_pockets_by == "combi") {

View File

@@ -1,6 +1,7 @@
package cz.siret.prank.score.prediction
import cz.siret.prank.program.params.Parametrized
import cz.siret.prank.program.params.Params
import groovy.transform.CompileStatic
/**
@@ -21,6 +22,11 @@ class PointScoreCalculator implements Parametrized {
hist[1] / (hist[0] + hist[1])
}
static boolean predictedPositive(double predictedScore) {
predictedScore >= Params.inst.pred_point_threshold
}
/**
* calculates ligandability score of the point form binary classification historgram
*

View File

@@ -13,6 +13,10 @@ class CmdLineArgs {
static class NamedArg {
String name
String value
public String toString() {
return name + "=" + value
}
}
String[] argList

View File

@@ -2,7 +2,7 @@
<configuration status="INFO">
<!--<ThresholdFilter level="all"/>-->
<appenders>
<Console name="Console" target="SYSTEM_ERR">
<Console name="Console" > <!-- target="SYSTEM_ERR" -->
<PatternLayout pattern="[%level] %logger{0} - %msg%n"/>
<!--<PatternLayout pattern="%d{HH:mm:ss} [%t] %-5level %logger{36} - %msg%n"/>-->
<!--<pattern>[%level] %logger{0} - %msg%n</pattern>-->

View File

@@ -62,6 +62,8 @@ import cz.siret.prank.program.params.Params
log_to_console = false
log_level = "WARN"
log_to_file = true
ploop_delete_runs = true