diff --git a/src/main/groovy/cz/siret/prank/program/routines/results/Evaluation.groovy b/src/main/groovy/cz/siret/prank/program/routines/results/Evaluation.groovy index ecda5b13..c4f7e859 100644 --- a/src/main/groovy/cz/siret/prank/program/routines/results/Evaluation.groovy +++ b/src/main/groovy/cz/siret/prank/program/routines/results/Evaluation.groovy @@ -16,6 +16,7 @@ import groovy.util.logging.Slf4j import org.apache.commons.lang3.StringUtils import javax.annotation.concurrent.ThreadSafe +import java.util.function.Function import static cz.siret.prank.geom.Atoms.intersection import static cz.siret.prank.geom.Atoms.union @@ -559,12 +560,8 @@ class Evaluation implements Parametrized { return res } - - - //===========================================================================================================// - /** * * @param criterionIndex @@ -651,43 +648,47 @@ class Evaluation implements Parametrized { //===========================================================================================================// - @CompileDynamic - double avg(List list) { - if (list.size()==0) return Double.NaN - list.findAll { !it.isNaN() }.sum(0) / list.size() - } - - @CompileDynamic - double avg(List list, Closure closure) { - if (list.size()==0) return Double.NaN - list.collect { closure(it) }.findAll { it!=Double.NaN }.sum(0) / list.size() - } - - @CompileDynamic - double avgNanTo0(List list, Closure closure) { - if (list.size()==0) return Double.NaN - list.collect { closure(it) }.collect { nanNullTo0(it) }.sum(0) / list.size() - } - - double nanNullTo0(Double d) { - if (d == null || d.isNaN()) { - return 0d - } else { - return d + /** + * Average of non-NaN values. Divides sum by total list size (not just non-NaN count). + */ + static double avg(List list) { + if (list.isEmpty()) return Double.NaN + double sum = 0d + for (Double v : list) { + if (v != null && !v.isNaN()) { + sum += v + } } + return sum / list.size() } /** - * average only on proteins that have relevant ligands + * Apply closure to each element, then average non-null, non-NaN results. Divides by total list size. */ - double avgLigProt(List list, Closure closure) { - List ligProts = list.findAll { it.ligands > 0 }.toList() - return avg(ligProts, closure) + static double avg(List list, Function function) { + if (list.isEmpty()) return Double.NaN + double sum = 0d + for (T item : list) { + Double v = function(item) + if (v!=null && !v.isNaN()) { + sum += v + } + } + return sum / list.size() } - double div(double a, double b) { - if (b==0d) - return Double.NaN + /** + * Average only on proteins that have relevant ligands. + * + * Note: this divides by the total number of proteins with ligands + */ + static double avgLigProt(List list, Function function) { + + return avg(list.findAll { it.ligands > 0 }, function) + } + + static double div(double a, double b) { + if (b == 0d) return Double.NaN return a / b } @@ -709,7 +710,7 @@ class Evaluation implements Parametrized { @CompileDynamic double getAvgPocketVolumeTruePockets() { - avg pocketRows.findAll { it.truePocket }, {PocketRow it -> it.pocketVolume } + avg pocketRows.findAll { it.truePocket }, { PocketRow it -> it.pocketVolume } } @CompileDynamic @@ -719,7 +720,7 @@ class Evaluation implements Parametrized { @CompileDynamic double getAvgPocketSurfAtomsTruePockets() { - avg pocketRows.findAll { it.truePocket }, {PocketRow it -> it.surfaceAtomCount } + avg pocketRows.findAll { it.truePocket }, { PocketRow it -> it.surfaceAtomCount } } @CompileDynamic @@ -729,7 +730,7 @@ class Evaluation implements Parametrized { @CompileDynamic double getAvgPocketInnerPointsTruePockets() { - avg pocketRows.findAll { it.truePocket }, {PocketRow it -> it.auxInfo.samplePoints } + avg pocketRows.findAll { it.truePocket }, { PocketRow it -> it.auxInfo.samplePoints } } @CompileDynamic @@ -768,7 +769,6 @@ class Evaluation implements Parametrized { /** * Todo optimize closures */ - @CompileDynamic Map getStats() { def m = new LinkedHashMap() // keep insertion order @@ -799,10 +799,10 @@ class Evaluation implements Parametrized { m.AVG_LIGCOV_SUCC = avgLigProt proteinRows, { it.ligandCoverageSucc } // avg by proteins (unlike DCA and others) m.AVG_LIG_POINT_SCORE = avgLigandPointScore // average of all ligand adjacent points - m.AVG_LIG_AVG_POINT_SCORE = avgNanTo0 ligandRows, { it.avgPointScore } // average of ligand averages - m.AVG_LIG_MAX_POINT_SCORE = avgNanTo0 ligandRows, { it.maxPointScore } - m.AVG_LIG_AVG_MAX3_POINT_SCORE = avgNanTo0 ligandRows, { it.avgMax3PointScore } - m.AVG_LIG_AVG_MAXHALF_POINT_SCORE = avgNanTo0 ligandRows, { it.avgMaxHalfPointScore } + m.AVG_LIG_AVG_POINT_SCORE = avg ligandRows, { it.avgPointScore } // average of ligand averages + m.AVG_LIG_MAX_POINT_SCORE = avg ligandRows, { it.maxPointScore } + m.AVG_LIG_AVG_MAX3_POINT_SCORE = avg ligandRows, { it.avgMax3PointScore } + m.AVG_LIG_AVG_MAXHALF_POINT_SCORE = avg ligandRows, { it.avgMaxHalfPointScore } m.AVG_POCKETS = avgPockets m.AVG_POCKET_SURF_ATOMS = avgPocketSurfAtoms @@ -812,16 +812,19 @@ class Evaluation implements Parametrized { m.AVG_POCKET_VOLUME = avgPocketVolume m.AVG_POCKET_VOLUME_TRUE_POCKETS = avgPocketVolumeTruePockets - m.AVG_POCKET_CONSERVATION = avg pocketRows, { it.avgConservation } - m.AVG_TRUE_POCKET_CONSERVATION = avg pocketRows.findAll { it.truePocket }, { it.avgConservation } - m.AVG_FALSE_POCKET_CONSERVATION = avg pocketRows.findAll { !it.truePocket }, { it.avgConservation } + def truePockets = pocketRows.findAll { it.truePocket } + def falsePockets = pocketRows.findAll { !it.truePocket } - m.AVG_TRUE_POCKET_PRANK_RANK = avg pocketRows.findAll { it.truePocket }, { it.newRank } - m.AVG_FALSE_POCKET_PRANK_RANK = avg pocketRows.findAll { !it.truePocket }, { it.newRank } - m.AVG_TRUE_POCKET_CONSERVATION_RANK = avg pocketRows.findAll { it.truePocket }, { it.conservationRank } - m.AVG_FALSE_POCKET_CONSERVATION_RANK = avg pocketRows.findAll { !it.truePocket }, { it.conservationRank } - m.AVG_TRUE_POCKET_COMBINED_RANK = avg pocketRows.findAll { it.truePocket }, { it.combinedRank } - m.AVG_FALSE_POCKET_COMBINED_RANK = avg pocketRows.findAll { !it.truePocket }, { it.combinedRank } + m.AVG_POCKET_CONSERVATION = avg pocketRows, { it.avgConservation } + m.AVG_TRUE_POCKET_CONSERVATION = avg truePockets, { it.avgConservation } + m.AVG_FALSE_POCKET_CONSERVATION = avg falsePockets, { it.avgConservation } + + m.AVG_TRUE_POCKET_PRANK_RANK = avg truePockets, { it.newRank } + m.AVG_FALSE_POCKET_PRANK_RANK = avg falsePockets, { it.newRank } + m.AVG_TRUE_POCKET_CONSERVATION_RANK = avg truePockets, { it.conservationRank as double } + m.AVG_FALSE_POCKET_CONSERVATION_RANK = avg falsePockets, { it.conservationRank as double } + m.AVG_TRUE_POCKET_COMBINED_RANK = avg truePockets, { it.combinedRank as double } + m.AVG_FALSE_POCKET_COMBINED_RANK = avg falsePockets, { it.combinedRank as double } m.DCA_4_0 = calcSuccessRate("DCA_4", 0) m.DCA_4_1 = calcSuccessRate("DCA_4", 1) @@ -830,11 +833,11 @@ class Evaluation implements Parametrized { m.DCA_4_10 = calcSuccessRate("DCA_4", 10) m.DCA_4_99 = calcSuccessRate("DCA_4", 99) - m.DCA_4_0_NOMINAL = m.DCA_4_0 * m.LIGANDS - m.DCA_4_1_NOMINAL = m.DCA_4_1 * m.LIGANDS - m.DCA_4_2_NOMINAL = m.DCA_4_2 * m.LIGANDS - m.DCA_4_4_NOMINAL = m.DCA_4_4 * m.LIGANDS - m.DCA_4_10_NOMINAL = m.DCA_4_10 * m.LIGANDS + m.DCA_4_0_NOMINAL = (double)m.DCA_4_0 * (long)m.LIGANDS + m.DCA_4_1_NOMINAL = (double)m.DCA_4_1 * (long)m.LIGANDS + m.DCA_4_2_NOMINAL = (double)m.DCA_4_2 * (long)m.LIGANDS + m.DCA_4_4_NOMINAL = (double)m.DCA_4_4 * (long)m.LIGANDS + m.DCA_4_10_NOMINAL = (double)m.DCA_4_10 * (long)m.LIGANDS m.DCA_4_0_PC = calcSuccessRateProteinCentric("DCA_4", 0) m.DCA_4_2_PC = calcSuccessRateProteinCentric("DCA_4", 2) @@ -882,26 +885,33 @@ class Evaluation implements Parametrized { m.DSO_02_T5 = calcSuccessRateTopN("DSO_0.2",5) m.DSO_02_T7 = calcSuccessRateTopN("DSO_0.2",7) - m.OPT1 = 100*m.DCA_4_0 + 100*m.DCA_4_2 + 50*m.DCA_4_4 + 10*m.AVG_LIGCOV_SUCC + 5*m.AVG_DSO_SUCC - m.OPT2 = 100*m.DCA_4_0_PC + 50*m.DCA_4_2_PC + 5*m.AVG_LIGCOV_SUCC + 3*m.AVG_DSO_SUCC + m.OPT1 = 100*(double)m.DCA_4_0a + 100*(double)m.DCA_4_2 + 50*(double)m.DCA_4_4 + 10*(double)m.AVG_LIGCOV_SUCC + 5*(double)m.AVG_DSO_SUCC + m.OPT2 = 100*(double)m.DCA_4_0_PC + 50*(double)m.DCA_4_2_PC + 5*(double)m.AVG_LIGCOV_SUCC + 3*(double)m.AVG_DSO_SUCC - - // write predicted scores to file if requested - // TODO: move this somewhere else (getStats() shouldn't write to disk) - if (StringUtils.isNotBlank(params.log_scores_to_file)) { - PrintWriter w = new PrintWriter(new BufferedWriter( - new FileWriter(params.log_scores_to_file, true))) - w.println("First line of the file") - nonBindingScores.forEach({ it -> w.print(it); w.print(' ') }) - w.println() - bindingScores.forEach({ it -> w.print(it); w.print(' ') }) - w.println() - w.close() - } + writeScoresToFileIfRequested() return m } + /** + * Append binding/non-binding scores to the file specified by log_scores_to_file param. + */ + private void writeScoresToFileIfRequested() { + if (StringUtils.isNotBlank(params.log_scores_to_file)) { + PrintWriter w = new PrintWriter(new BufferedWriter( + new FileWriter(params.log_scores_to_file, false))) + try { + w.println("First line of the file") + nonBindingScores.forEach({ it -> w.print(it); w.print(' ') }) + w.println() + bindingScores.forEach({ it -> w.print(it); w.print(' ') }) + w.println() + } finally { + w.close() + } + } + } + /** * get list of evaluation criteria used during eval routines */ diff --git a/src/main/groovy/cz/siret/prank/utils/MathUtils.java b/src/main/groovy/cz/siret/prank/utils/MathUtils.java index 15b94664..148edaf4 100644 --- a/src/main/groovy/cz/siret/prank/utils/MathUtils.java +++ b/src/main/groovy/cz/siret/prank/utils/MathUtils.java @@ -28,7 +28,7 @@ public class MathUtils { //===============================================================================================// public static double nanToZero(double x) { - if (Double.isNaN(x)) return 0; + if (Double.isNaN(x)) return 0.0d; return x; } @@ -43,7 +43,7 @@ public class MathUtils { } public static double safeDiv(double x, double y){ - if (y == 0.0) return 0; + if (y == 0.0d) return 0.0d; return x / y; }