Refactor Evaluation: simplify avg/div methods, use Function instead of Closure, extract writeScoresToFileIfRequested

This commit is contained in:
rdk
2026-03-15 19:27:15 +01:00
parent 20236ef092
commit 5ac9aab18a
2 changed files with 83 additions and 73 deletions

View File

@@ -16,6 +16,7 @@ import groovy.util.logging.Slf4j
import org.apache.commons.lang3.StringUtils
import javax.annotation.concurrent.ThreadSafe
import java.util.function.Function
import static cz.siret.prank.geom.Atoms.intersection
import static cz.siret.prank.geom.Atoms.union
@@ -559,12 +560,8 @@ class Evaluation implements Parametrized {
return res
}
//===========================================================================================================//
/**
*
* @param criterionIndex
@@ -651,43 +648,47 @@ class Evaluation implements Parametrized {
//===========================================================================================================//
/**
 * Sums the non-NaN values of the list and divides the sum by the full list size,
 * so NaN entries are left out of the sum but still counted in the divisor.
 * Returns NaN for an empty list.
 */
@CompileDynamic
double avg(List<Double> list) {
if (list.size()==0) return Double.NaN
list.findAll { !it.isNaN() }.sum(0) / list.size()
}
/**
 * Applies the closure to every element, sums the kept results and divides by the
 * full list size. Returns NaN for an empty list.
 *
 * NOTE(review): results are filtered with it!=Double.NaN; whether that comparison
 * really drops NaN values depends on Groovy's equality semantics for Double - verify.
 */
@CompileDynamic
<T> double avg(List<T> list, Closure<T> closure) {
if (list.size()==0) return Double.NaN
list.collect { closure(it) }.findAll { it!=Double.NaN }.sum(0) / list.size()
}
/**
 * Applies the closure to every element, passes each result through nanNullTo0,
 * sums the outcome and divides by the full list size.
 * Returns NaN for an empty list.
 */
@CompileDynamic
<T> double avgNanTo0(List<T> list, Closure<T> closure) {
if (list.size()==0) return Double.NaN
list.collect { closure(it) }.collect { nanNullTo0(it) }.sum(0) / list.size()
}
double nanNullTo0(Double d) {
if (d == null || d.isNaN()) {
return 0d
} else {
return d
/**
 * Mean over the whole list in which null and NaN entries add nothing to the sum.
 * The divisor is always the total element count (skipped entries included).
 * An empty list yields NaN.
 */
static double avg(List<Double> list) {
    int count = list.size()
    if (count == 0) {
        return Double.NaN
    }
    double total = 0d
    for (int i = 0; i < count; i++) {
        Double value = list.get(i)
        if (value == null || value.isNaN()) {
            continue    // invalid entries are skipped in the sum only
        }
        total += value
    }
    return total / count
}
/**
* average only on proteins that have relevant ligands
* Apply closure to each element, then average non-null, non-NaN results. Divides by total list size.
*/
double avgLigProt(List<ProteinRow> list, Closure<ProteinRow> closure) {
List<ProteinRow> ligProts = list.findAll { it.ligands > 0 }.toList()
return avg(ligProts, closure)
/**
 * Applies the function to each element and averages the results.
 *
 * Null and NaN results are left out of the sum, but the divisor is always the
 * total list size, so such results effectively count as 0.
 *
 * @param list elements to evaluate
 * @param function maps an element to a value (may return null or NaN)
 * @return sum of the valid results divided by list.size(), or NaN for an empty list
 */
static <T> double avg(List<T> list, Function<T, Double> function) {
    if (list.isEmpty()) return Double.NaN
    double sum = 0d
    for (T item : list) {
        // a java.util.function.Function is not a Closure: it has no call(),
        // so it must be invoked through apply()
        Double v = function.apply(item)
        if (v != null && !v.isNaN()) {
            sum += v
        }
    }
    return sum / list.size()
}
double div(double a, double b) {
if (b==0d)
return Double.NaN
/**
 * Averages the function value over those proteins whose ligand count is positive.
 *
 * The divisor is the number of proteins with ligands (the filtered list),
 * because the filtered list is what gets handed to avg.
 */
static double avgLigProt(List<ProteinRow> list, Function<ProteinRow, Double> function) {
    List<ProteinRow> withLigands = list.findAll { ProteinRow row -> row.ligands > 0 }
    return avg(withLigands, function)
}
/**
 * Division that yields NaN instead of dividing by zero.
 *
 * @return a / b, or NaN when b equals 0
 */
static double div(double a, double b) {
    return (b == 0d) ? Double.NaN : a / b
}
@@ -709,7 +710,7 @@ class Evaluation implements Parametrized {
/**
 * Average of pocketVolume over the pocket rows flagged as truePocket,
 * computed with avg(list, function).
 */
@CompileDynamic
double getAvgPocketVolumeTruePockets() {
avg pocketRows.findAll { it.truePocket }, {PocketRow it -> it.pocketVolume }
avg pocketRows.findAll { it.truePocket }, { PocketRow it -> it.pocketVolume }
}
@CompileDynamic
@@ -719,7 +720,7 @@ class Evaluation implements Parametrized {
/**
 * Average of surfaceAtomCount over the pocket rows flagged as truePocket,
 * computed with avg(list, function).
 */
@CompileDynamic
double getAvgPocketSurfAtomsTruePockets() {
avg pocketRows.findAll { it.truePocket }, {PocketRow it -> it.surfaceAtomCount }
avg pocketRows.findAll { it.truePocket }, { PocketRow it -> it.surfaceAtomCount }
}
@CompileDynamic
@@ -729,7 +730,7 @@ class Evaluation implements Parametrized {
/**
 * Average of auxInfo.samplePoints over the pocket rows flagged as truePocket,
 * computed with avg(list, function).
 */
@CompileDynamic
double getAvgPocketInnerPointsTruePockets() {
avg pocketRows.findAll { it.truePocket }, {PocketRow it -> it.auxInfo.samplePoints }
avg pocketRows.findAll { it.truePocket }, { PocketRow it -> it.auxInfo.samplePoints }
}
@CompileDynamic
@@ -768,7 +769,6 @@ class Evaluation implements Parametrized {
/**
* Todo optimize closures
*/
@CompileDynamic
Map getStats() {
def m = new LinkedHashMap() // keep insertion order
@@ -799,10 +799,10 @@ class Evaluation implements Parametrized {
m.AVG_LIGCOV_SUCC = avgLigProt proteinRows, { it.ligandCoverageSucc } // avg by proteins (unlike DCA and others)
m.AVG_LIG_POINT_SCORE = avgLigandPointScore // average of all ligand adjacent points
m.AVG_LIG_AVG_POINT_SCORE = avgNanTo0 ligandRows, { it.avgPointScore } // average of ligand averages
m.AVG_LIG_MAX_POINT_SCORE = avgNanTo0 ligandRows, { it.maxPointScore }
m.AVG_LIG_AVG_MAX3_POINT_SCORE = avgNanTo0 ligandRows, { it.avgMax3PointScore }
m.AVG_LIG_AVG_MAXHALF_POINT_SCORE = avgNanTo0 ligandRows, { it.avgMaxHalfPointScore }
m.AVG_LIG_AVG_POINT_SCORE = avg ligandRows, { it.avgPointScore } // average of ligand averages
m.AVG_LIG_MAX_POINT_SCORE = avg ligandRows, { it.maxPointScore }
m.AVG_LIG_AVG_MAX3_POINT_SCORE = avg ligandRows, { it.avgMax3PointScore }
m.AVG_LIG_AVG_MAXHALF_POINT_SCORE = avg ligandRows, { it.avgMaxHalfPointScore }
m.AVG_POCKETS = avgPockets
m.AVG_POCKET_SURF_ATOMS = avgPocketSurfAtoms
@@ -812,16 +812,19 @@ class Evaluation implements Parametrized {
m.AVG_POCKET_VOLUME = avgPocketVolume
m.AVG_POCKET_VOLUME_TRUE_POCKETS = avgPocketVolumeTruePockets
m.AVG_POCKET_CONSERVATION = avg pocketRows, { it.avgConservation }
m.AVG_TRUE_POCKET_CONSERVATION = avg pocketRows.findAll { it.truePocket }, { it.avgConservation }
m.AVG_FALSE_POCKET_CONSERVATION = avg pocketRows.findAll { !it.truePocket }, { it.avgConservation }
def truePockets = pocketRows.findAll { it.truePocket }
def falsePockets = pocketRows.findAll { !it.truePocket }
m.AVG_TRUE_POCKET_PRANK_RANK = avg pocketRows.findAll { it.truePocket }, { it.newRank }
m.AVG_FALSE_POCKET_PRANK_RANK = avg pocketRows.findAll { !it.truePocket }, { it.newRank }
m.AVG_TRUE_POCKET_CONSERVATION_RANK = avg pocketRows.findAll { it.truePocket }, { it.conservationRank }
m.AVG_FALSE_POCKET_CONSERVATION_RANK = avg pocketRows.findAll { !it.truePocket }, { it.conservationRank }
m.AVG_TRUE_POCKET_COMBINED_RANK = avg pocketRows.findAll { it.truePocket }, { it.combinedRank }
m.AVG_FALSE_POCKET_COMBINED_RANK = avg pocketRows.findAll { !it.truePocket }, { it.combinedRank }
m.AVG_POCKET_CONSERVATION = avg pocketRows, { it.avgConservation }
m.AVG_TRUE_POCKET_CONSERVATION = avg truePockets, { it.avgConservation }
m.AVG_FALSE_POCKET_CONSERVATION = avg falsePockets, { it.avgConservation }
m.AVG_TRUE_POCKET_PRANK_RANK = avg truePockets, { it.newRank }
m.AVG_FALSE_POCKET_PRANK_RANK = avg falsePockets, { it.newRank }
m.AVG_TRUE_POCKET_CONSERVATION_RANK = avg truePockets, { it.conservationRank as double }
m.AVG_FALSE_POCKET_CONSERVATION_RANK = avg falsePockets, { it.conservationRank as double }
m.AVG_TRUE_POCKET_COMBINED_RANK = avg truePockets, { it.combinedRank as double }
m.AVG_FALSE_POCKET_COMBINED_RANK = avg falsePockets, { it.combinedRank as double }
m.DCA_4_0 = calcSuccessRate("DCA_4", 0)
m.DCA_4_1 = calcSuccessRate("DCA_4", 1)
@@ -830,11 +833,11 @@ class Evaluation implements Parametrized {
m.DCA_4_10 = calcSuccessRate("DCA_4", 10)
m.DCA_4_99 = calcSuccessRate("DCA_4", 99)
m.DCA_4_0_NOMINAL = m.DCA_4_0 * m.LIGANDS
m.DCA_4_1_NOMINAL = m.DCA_4_1 * m.LIGANDS
m.DCA_4_2_NOMINAL = m.DCA_4_2 * m.LIGANDS
m.DCA_4_4_NOMINAL = m.DCA_4_4 * m.LIGANDS
m.DCA_4_10_NOMINAL = m.DCA_4_10 * m.LIGANDS
m.DCA_4_0_NOMINAL = (double)m.DCA_4_0 * (long)m.LIGANDS
m.DCA_4_1_NOMINAL = (double)m.DCA_4_1 * (long)m.LIGANDS
m.DCA_4_2_NOMINAL = (double)m.DCA_4_2 * (long)m.LIGANDS
m.DCA_4_4_NOMINAL = (double)m.DCA_4_4 * (long)m.LIGANDS
m.DCA_4_10_NOMINAL = (double)m.DCA_4_10 * (long)m.LIGANDS
m.DCA_4_0_PC = calcSuccessRateProteinCentric("DCA_4", 0)
m.DCA_4_2_PC = calcSuccessRateProteinCentric("DCA_4", 2)
@@ -882,26 +885,33 @@ class Evaluation implements Parametrized {
m.DSO_02_T5 = calcSuccessRateTopN("DSO_0.2",5)
m.DSO_02_T7 = calcSuccessRateTopN("DSO_0.2",7)
m.OPT1 = 100*m.DCA_4_0 + 100*m.DCA_4_2 + 50*m.DCA_4_4 + 10*m.AVG_LIGCOV_SUCC + 5*m.AVG_DSO_SUCC
m.OPT2 = 100*m.DCA_4_0_PC + 50*m.DCA_4_2_PC + 5*m.AVG_LIGCOV_SUCC + 3*m.AVG_DSO_SUCC
// OPT1 reads the DCA_4_0 entry stored earlier in this map; a misspelled key would
// return null from the map and make the (double) cast fail
m.OPT1 = 100*(double)m.DCA_4_0 + 100*(double)m.DCA_4_2 + 50*(double)m.DCA_4_4 + 10*(double)m.AVG_LIGCOV_SUCC + 5*(double)m.AVG_DSO_SUCC
m.OPT2 = 100*(double)m.DCA_4_0_PC + 50*(double)m.DCA_4_2_PC + 5*(double)m.AVG_LIGCOV_SUCC + 3*(double)m.AVG_DSO_SUCC
// write predicted scores to file if requested
// TODO: move this somewhere else (getStats() shouldn't write to disk)
if (StringUtils.isNotBlank(params.log_scores_to_file)) {
PrintWriter w = new PrintWriter(new BufferedWriter(
new FileWriter(params.log_scores_to_file, true)))
w.println("First line of the file")
nonBindingScores.forEach({ it -> w.print(it); w.print(' ') })
w.println()
bindingScores.forEach({ it -> w.print(it); w.print(' ') })
w.println()
w.close()
}
writeScoresToFileIfRequested()
return m
}
/**
 * Append binding/non-binding scores to the file specified by the log_scores_to_file param.
 *
 * Does nothing when the param is blank. Each call appends a header line, one line of
 * space-separated non-binding scores and one line of space-separated binding scores.
 */
private void writeScoresToFileIfRequested() {
    if (StringUtils.isNotBlank(params.log_scores_to_file)) {
        // the second FileWriter argument is the append flag; it has to be true so that
        // scores written by earlier calls are not truncated
        PrintWriter w = new PrintWriter(new BufferedWriter(
                new FileWriter(params.log_scores_to_file, true)))
        try {
            w.println("First line of the file")
            nonBindingScores.forEach({ it -> w.print(it); w.print(' ') })
            w.println()
            bindingScores.forEach({ it -> w.print(it); w.print(' ') })
            w.println()
        } finally {
            // always release the file handle, even if printing fails
            w.close()
        }
    }
}
/**
* get list of evaluation criteria used during eval routines
*/

View File

@@ -28,7 +28,7 @@ public class MathUtils {
//===============================================================================================//
/**
 * Replaces NaN with zero.
 *
 * @param x value to check
 * @return 0 if x is NaN, otherwise x unchanged
 */
public static double nanToZero(double x) {
if (Double.isNaN(x)) return 0;
if (Double.isNaN(x)) return 0.0d;
return x;
}
@@ -43,7 +43,7 @@ public class MathUtils {
}
/**
 * Division that returns zero instead of dividing by zero.
 *
 * @param x dividend
 * @param y divisor
 * @return 0 when y compares equal to 0.0, otherwise x / y
 */
public static double safeDiv(double x, double y){
if (y == 0.0) return 0;
if (y == 0.0d) return 0.0d;
return x / y;
}