mirror of
https://github.com/rdk/p2rank.git
synced 2026-06-04 12:44:24 +08:00
Merge changes from upstream
This commit is contained in:
@@ -9,7 +9,7 @@ apply plugin: 'idea'
|
||||
|
||||
|
||||
group = 'cz.siret'
|
||||
version = '2.0-dev.7'
|
||||
version = '2.0-dev.8'
|
||||
|
||||
description = 'Ligand binding site prediction based on machine learning.'
|
||||
|
||||
|
||||
@@ -46,11 +46,6 @@ import cz.siret.prank.program.params.Params
|
||||
|
||||
//== FEATURES
|
||||
|
||||
/**
|
||||
* include volsite pharmacophore properties
|
||||
*/
|
||||
use_volsite_features = true
|
||||
|
||||
extra_features = ["chem","volsite","protrusion","bfactor"]
|
||||
|
||||
atom_table_features = ["ap5sasaValids","ap5sasaInvalids"] // "apRawValids","apRawInvalids","atomicHydrophobicity"
|
||||
@@ -264,11 +259,6 @@ import cz.siret.prank.program.params.Params
|
||||
*/
|
||||
pred_clustering_dist = 5
|
||||
|
||||
/**
|
||||
* distance to extend clusters around hotspots
|
||||
*/
|
||||
pred_surrounding = 3.5
|
||||
|
||||
/**
|
||||
* cutoff distance of protein surface atoms considered as part of the pocket
|
||||
*/
|
||||
|
||||
@@ -46,11 +46,6 @@ import cz.siret.prank.program.params.Params
|
||||
|
||||
//== FEATURES
|
||||
|
||||
/**
|
||||
* include volsite pharmacophore properties
|
||||
*/
|
||||
use_volsite_features = true
|
||||
|
||||
atom_table_features = ["apRawValids","apRawInvalids","atomicHydrophobicity"]
|
||||
|
||||
extra_features = ["chem","volsite","protrusion","bfactor"]
|
||||
@@ -265,9 +260,9 @@ import cz.siret.prank.program.params.Params
|
||||
pred_clustering_dist = 3
|
||||
|
||||
/**
|
||||
* distance to extend clusters around hotspots
|
||||
* SAS points around ligandable points (and their score) will be included in the pocket
|
||||
*/
|
||||
pred_surrounding = 3.5
|
||||
extended_pocket_cutoff = 3.5
|
||||
|
||||
/**
|
||||
* cutoff distance of protein surface atoms considered as part of the pocket
|
||||
|
||||
37
new.groovy
37
new.groovy
@@ -35,18 +35,11 @@ import cz.siret.prank.program.params.Params
|
||||
|
||||
predictions = true
|
||||
|
||||
out_prefix_date = true
|
||||
|
||||
crossval_threads = 5
|
||||
|
||||
cache_datasets = true
|
||||
|
||||
clear_sec_caches = false
|
||||
|
||||
clear_prim_caches = false
|
||||
|
||||
|
||||
log_cases = true
|
||||
|
||||
|
||||
/**
|
||||
* calculate feature importance
|
||||
@@ -54,7 +47,8 @@ import cz.siret.prank.program.params.Params
|
||||
*/
|
||||
feature_importances = false
|
||||
|
||||
output_only_stats = true
|
||||
stats_collect_predictions = true
|
||||
|
||||
|
||||
/**
|
||||
* collect negatives just from decoy pockets found by other method
|
||||
@@ -66,16 +60,39 @@ import cz.siret.prank.program.params.Params
|
||||
|
||||
atom_table_features = ["ap5sasaValids","ap5sasaInvalids","apRawValids","apRawInvalids","atomicHydrophobicity"]
|
||||
|
||||
extra_features = ["protrusion","bfactor"]
|
||||
extra_features = ["chem","volsite","protrusion","bfactor"]
|
||||
|
||||
residue_table_features = ["RAx"]
|
||||
|
||||
average_feat_vectors = true
|
||||
|
||||
balance_class_weights = true
|
||||
|
||||
target_class_weight_ratio = 0.055
|
||||
|
||||
|
||||
// technical
|
||||
|
||||
cache_datasets = true
|
||||
|
||||
clear_sec_caches = false
|
||||
|
||||
clear_prim_caches = false
|
||||
|
||||
log_cases = true
|
||||
|
||||
output_only_stats = true
|
||||
|
||||
log_to_console = false
|
||||
|
||||
log_level = "WARN"
|
||||
|
||||
log_to_file = true
|
||||
|
||||
ploop_delete_runs = true
|
||||
|
||||
zip_log_file = true
|
||||
|
||||
out_prefix_date = true
|
||||
|
||||
}
|
||||
|
||||
2
prank.sh
2
prank.sh
@@ -42,6 +42,6 @@ PARAMS="$PRANK_LOCALENV_PARAMS"
|
||||
|
||||
CMD="$JAVACMD $JAVA_OPTS -cp ${CLASSPATH} cz.siret.prank.program.Main ${PARAMS} $@"
|
||||
echo "+" $CMD
|
||||
"$JAVACMD" $JAVA_OPTS -cp "${CLASSPATH}" cz.siret.prank.program.Main ${PARAMS} "$@" 2>>local-debug.log
|
||||
"$JAVACMD" $JAVA_OPTS -cp "${CLASSPATH}" cz.siret.prank.program.Main ${PARAMS} "$@"
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package cz.siret.prank.domain
|
||||
|
||||
import groovy.transform.CompileStatic
|
||||
|
||||
/**
|
||||
* 20 main amino acid codes
|
||||
*/
|
||||
@CompileStatic
|
||||
enum AA {
|
||||
|
||||
ALA,
|
||||
@@ -24,4 +29,17 @@ enum AA {
|
||||
TRP,
|
||||
TYR;
|
||||
|
||||
/** Lookup table from enum constant name (e.g. "ALA") to the constant itself; filled once in the static initializer. */
private static final Map<String, AA> index = new HashMap<String, AA>()

static {
    // register every constant under its own name
    for (AA aa : AA.values()) {
        index.put(aa.name(), aa)
    }
}

/**
 * Finds the amino acid whose enum constant name equals the given name (exact, case-sensitive match).
 *
 * @param name enum constant name, e.g. "ALA"
 * @return the matching constant, or null when no constant has that name
 */
public static AA forName(String name) {
    AA found = index.get(name)
    return found
}
|
||||
|
||||
}
|
||||
|
||||
@@ -157,8 +157,8 @@ class Dataset implements Parametrized {
|
||||
void clearSecondaryCaches() {
|
||||
items.each {
|
||||
if (it.cachedPair!=null) {
|
||||
it.cachedPair.prediction.protein.clearCachedSurfaces()
|
||||
it.cachedPair.liganatedProtein.clearCachedSurfaces()
|
||||
it.cachedPair.prediction.protein.clearSecondaryData()
|
||||
it.cachedPair.liganatedProtein.clearSecondaryData()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,10 +121,14 @@ class Protein implements Parametrized {
|
||||
return trainSurface
|
||||
}
|
||||
|
||||
void clearCachedSurfaces() {
|
||||
/**
|
||||
* clears generated surfaces and secondary data
|
||||
*/
|
||||
void clearSecondaryData() {
|
||||
connollySurface = null
|
||||
trainSurface = null
|
||||
exposedAtoms = null
|
||||
secondaryData.clear()
|
||||
}
|
||||
|
||||
Atoms getAllLigandAtoms() {
|
||||
|
||||
@@ -113,7 +113,7 @@ class PrankFeatureExtractor extends FeatureExtractor<PrankFeatureVector> impleme
|
||||
surfaceLayerAtoms = protein.exposedAtoms
|
||||
}
|
||||
|
||||
log.debug "surfaceLayerAtoms:$surfaceLayerAtoms.count (surfaceAtoms: $pocket.surfaceAtoms.count) "
|
||||
log.debug "surfaceLayerAtoms:$surfaceLayerAtoms.count (surfaceAtoms: ${pocket?.surfaceAtoms?.count}) "
|
||||
|
||||
preEvaluateProperties(surfaceLayerAtoms)
|
||||
if (DO_SMOOTH_REPRESENTATION) {
|
||||
@@ -195,7 +195,7 @@ class PrankFeatureExtractor extends FeatureExtractor<PrankFeatureVector> impleme
|
||||
|
||||
|
||||
|
||||
log.info "P2R protein:$protein.proteinAtoms.count exposedAtoms:$res.surfaceLayerAtoms.count deepSurrounding:$res.deepSurrounding.count connollyPoints:$res.sampledPoints.count"
|
||||
log.info "P2R protein:$protein.proteinAtoms.count exposedAtoms:$res.surfaceLayerAtoms.count deepSurrounding:$res.deepSurrounding.count sasPoints:$res.sampledPoints.count"
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package cz.siret.prank.features.api
|
||||
|
||||
import cz.siret.prank.features.implementation.Asa2Feature
|
||||
import cz.siret.prank.features.implementation.AsaFeature
|
||||
import cz.siret.prank.features.implementation.AsaResiduesFeature
|
||||
import cz.siret.prank.features.implementation.BfactorFeature
|
||||
import cz.siret.prank.features.implementation.ContactResiduesPositionFeature
|
||||
import cz.siret.prank.features.implementation.ProtrusionFeature
|
||||
import cz.siret.prank.features.implementation.ProtrusionHistogramFeature
|
||||
import cz.siret.prank.features.implementation.SurfaceProtrusionFeature
|
||||
@@ -48,6 +52,10 @@ class FeatureRegistry {
|
||||
registerFeature(new ConservationFeature())
|
||||
registerFeature(new ConservationCloudFeature())
|
||||
registerFeature(new ConservationCloudScaledFeature())
|
||||
registerFeature(new ContactResiduesPositionFeature())
|
||||
registerFeature(new AsaFeature())
|
||||
registerFeature(new Asa2Feature())
|
||||
registerFeature(new AsaResiduesFeature())
|
||||
|
||||
// Register new feature implementations here
|
||||
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package cz.siret.prank.features.implementation
|
||||
|
||||
import cz.siret.prank.domain.Protein
|
||||
import cz.siret.prank.features.api.SasFeatureCalculationContext
|
||||
import cz.siret.prank.features.api.SasFeatureCalculator
|
||||
import cz.siret.prank.geom.Atoms
|
||||
import cz.siret.prank.program.params.Parametrized
|
||||
import cz.siret.prank.utils.Writable
|
||||
import groovy.transform.CompileStatic
|
||||
import groovy.util.logging.Slf4j
|
||||
import org.biojava.nbio.structure.Atom
|
||||
import org.biojava.nbio.structure.StructureTools
|
||||
import org.biojava.nbio.structure.asa.AsaCalculator
|
||||
|
||||
/**
 * Local protein solvent accessible surface area (ASA) feature calculated for two probe radii.
 *
 * For every SAS point it produces two values: the summed per-atom ASA of the exposed protein atoms within
 * params.feat_asa_neigh_radius of the point, once for params.feat_asa_probe_radius ("asa2.1")
 * and once for params.feat_asa_probe_radius2 ("asa2.2").
 */
@Slf4j
@CompileStatic
class Asa2Feature extends SasFeatureCalculator implements Parametrized, Writable {

    static final String NAME = "asa2"

    /**
     * Keys of the per-protein ASA tables cached in protein.secondaryData.
     * The first key is private to this feature: AsaFeature stores its own, unrelated ProtAsa class under
     * "prot_atom_asa", so sharing that key would make one of the two casts fail when both features are enabled.
     */
    private static final String ASA_KEY_1 = "asa2.prot_atom_asa"
    private static final String ASA_KEY_2 = "prot_atom_asa2"

    @Override
    String getName() { NAME }

    @Override
    List<String> getHeader() {
        return ["asa2.1", "asa2.2"]
    }

    /**
     * Calculates ASA of the atoms returned by StructureTools.getAllNonHAtomArray (het atoms excluded).
     *
     * @param protein protein whose structure is measured
     * @param probeRadius probe radius passed to the BioJava AsaCalculator
     * @return table of ASA values keyed by PDB serial number of the atom
     */
    ProtAsa calcProtAsa(Protein protein, double probeRadius) {
        int nSpherePoints = AsaCalculator.DEFAULT_N_SPHERE_POINTS
        int threads = 1
        boolean hetAtoms = false

        Atom[] protAtoms = StructureTools.getAllNonHAtomArray(protein.structure, hetAtoms)
        AsaCalculator asaCalculator = new AsaCalculator(protein.structure, probeRadius, nSpherePoints, threads, hetAtoms)
        double[] atomAsas = asaCalculator.calculateAsas()

        // no access to protAtoms[0] here: an empty atom array must yield an empty table, not an exception
        Map<Integer, Double> asaByAtom = new HashMap<>()
        for (int i = 0; i < protAtoms.length; ++i) {
            asaByAtom.put(protAtoms[i].PDBserial, atomAsas[i])
        }

        return new ProtAsa(protein, asaByAtom)
    }

    /**
     * Makes sure both ASA tables (one per probe radius) are cached in protein.secondaryData.
     */
    @Override
    void preProcessProtein(Protein protein) {
        ensureCached(protein, ASA_KEY_1, params.feat_asa_probe_radius)
        ensureCached(protein, ASA_KEY_2, params.feat_asa_probe_radius2)
    }

    /** Calculates and stores the ASA table under the given key unless it is already present. */
    private void ensureCached(Protein protein, String key, double probeRadius) {
        if (!protein.secondaryData.containsKey(key)) {
            protein.secondaryData.put(key, calcProtAsa(protein, probeRadius))
        }
    }

    /**
     * @return two-element vector [local ASA for probe radius 1, local ASA for probe radius 2]
     */
    @Override
    double[] calculateForSasPoint(Atom sasPoint, SasFeatureCalculationContext context) {
        Atoms localAtoms = context.protein.exposedAtoms.cutoffAroundAtom(sasPoint, params.feat_asa_neigh_radius)

        double localAsa = sumAsa(localAtoms, (ProtAsa) context.protein.secondaryData.get(ASA_KEY_1))
        double localAsa2 = sumAsa(localAtoms, (ProtAsa) context.protein.secondaryData.get(ASA_KEY_2))

        return [localAsa, localAsa2] as double[]
    }

    /** Sums the ASA of the given atoms; atoms missing from the table contribute 0. */
    private static double sumAsa(Atoms atoms, ProtAsa protAsa) {
        double total = 0d
        for (Atom a : atoms) {
            Double asa = protAsa.asaByAtom.get(a.PDBserial)
            if (asa != null) {
                total += asa
            }
        }
        return total
    }


    /** Per-protein table of atom ASA values keyed by PDB serial number. */
    static class ProtAsa {
        Protein protein
        Map<Integer, Double> asaByAtom

        ProtAsa(Protein protein, Map<Integer, Double> asaByAtom) {
            this.protein = protein
            this.asaByAtom = asaByAtom
        }
    }

}
|
||||
@@ -0,0 +1,71 @@
|
||||
package cz.siret.prank.features.implementation
|
||||
|
||||
import cz.siret.prank.domain.Protein
|
||||
import cz.siret.prank.features.api.SasFeatureCalculationContext
|
||||
import cz.siret.prank.features.api.SasFeatureCalculator
|
||||
import cz.siret.prank.geom.Atoms
|
||||
import cz.siret.prank.program.params.Parametrized
|
||||
import cz.siret.prank.utils.Writable
|
||||
import groovy.transform.CompileStatic
|
||||
import groovy.util.logging.Slf4j
|
||||
import org.biojava.nbio.structure.Atom
|
||||
import org.biojava.nbio.structure.StructureTools
|
||||
import org.biojava.nbio.structure.asa.AsaCalculator
|
||||
|
||||
/**
 * Local protein solvent accessible surface area (ASA) feature.
 *
 * The value for a SAS point is the sum of the per-atom ASA of the exposed protein atoms that lie within
 * params.feat_asa_neigh_radius of the point.
 */
@Slf4j
@CompileStatic
class AsaFeature extends SasFeatureCalculator implements Parametrized, Writable {

    static final String NAME = "asa"

    /** key of the cached per-protein ASA table in protein.secondaryData */
    private static final String PROT_ASA_KEY = "prot_atom_asa"

    @Override
    String getName() { NAME }

    /**
     * Calculates per-atom ASA (probe radius params.feat_asa_probe_radius) for the atoms returned by
     * StructureTools.getAllNonHAtomArray (het atoms excluded) and caches the table in protein.secondaryData.
     * Does nothing when the table is already cached.
     */
    @Override
    void preProcessProtein(Protein protein) {
        if (protein.secondaryData.containsKey(PROT_ASA_KEY)) {
            return
        }

        double probeRadius = params.feat_asa_probe_radius
        int nSpherePoints = AsaCalculator.DEFAULT_N_SPHERE_POINTS
        int threads = 1
        boolean hetAtoms = false

        Atom[] protAtoms = StructureTools.getAllNonHAtomArray(protein.structure, hetAtoms)
        AsaCalculator asaCalculator = new AsaCalculator(protein.structure, probeRadius, nSpherePoints, threads, hetAtoms)
        double[] atomAsas = asaCalculator.calculateAsas()

        // no access to protAtoms[0] here: an empty atom array must yield an empty table, not an exception
        Map<Integer, Double> asaByAtom = new HashMap<>()
        for (int i = 0; i < protAtoms.length; ++i) {
            asaByAtom.put(protAtoms[i].PDBserial, atomAsas[i])
        }

        protein.secondaryData.put(PROT_ASA_KEY, new ProtAsa(protein, asaByAtom))
    }

    /**
     * @return one-element vector holding the summed ASA of exposed atoms around the SAS point;
     *         atoms missing from the cached table contribute 0
     */
    @Override
    double[] calculateForSasPoint(Atom sasPoint, SasFeatureCalculationContext context) {
        Atoms localAtoms = context.protein.exposedAtoms.cutoffAroundAtom(sasPoint, params.feat_asa_neigh_radius)
        ProtAsa protAsa = (ProtAsa) context.protein.secondaryData.get(PROT_ASA_KEY)

        double localAsa = 0d
        for (Atom a : localAtoms) {
            Double asa = protAsa.asaByAtom.get(a.PDBserial)
            if (asa != null) {
                localAsa += asa
            }
        }

        return [localAsa] as double[]
    }


    /** Per-protein table of atom ASA values keyed by PDB serial number. */
    static class ProtAsa {
        Protein protein
        Map<Integer, Double> asaByAtom

        ProtAsa(Protein protein, Map<Integer, Double> asaByAtom) {
            this.protein = protein
            this.asaByAtom = asaByAtom
        }
    }

}
|
||||
@@ -0,0 +1,75 @@
|
||||
package cz.siret.prank.features.implementation
|
||||
|
||||
import cz.siret.prank.domain.Protein
|
||||
import cz.siret.prank.features.api.SasFeatureCalculationContext
|
||||
import cz.siret.prank.features.api.SasFeatureCalculator
|
||||
import cz.siret.prank.geom.Atoms
|
||||
import cz.siret.prank.program.params.Parametrized
|
||||
import groovy.transform.CompileStatic
|
||||
import groovy.util.logging.Slf4j
|
||||
import org.biojava.nbio.structure.Atom
|
||||
import org.biojava.nbio.structure.Group
|
||||
import org.biojava.nbio.structure.ResidueNumber
|
||||
import org.biojava.nbio.structure.asa.AsaCalculator
|
||||
import org.biojava.nbio.structure.asa.GroupAsa
|
||||
|
||||
/**
 * Local protein solvent accessible surface area (ASA) feature based on whole residues.
 *
 * The value for a SAS point is the sum of the per-residue ASA of all distinct residues that have at least one
 * exposed atom within params.feat_asa_neigh_radius of the point.
 */
@Slf4j
@CompileStatic
class AsaResiduesFeature extends SasFeatureCalculator implements Parametrized {

    static final String NAME = "asares"

    /** key of the cached per-protein residue ASA table in protein.secondaryData */
    private static final String PROT_ASA_KEY = "prot_asa"

    @Override
    String getName() { NAME }

    /**
     * Calculates per-residue ASA (probe radius params.feat_asa_probe_radius, het atoms excluded) and caches it
     * in protein.secondaryData. Does nothing when the table is already cached, like the other ASA features.
     */
    @Override
    void preProcessProtein(Protein protein) {
        if (protein.secondaryData.containsKey(PROT_ASA_KEY)) {
            return
        }

        double probeRadius = params.feat_asa_probe_radius
        int nSpherePoints = AsaCalculator.DEFAULT_N_SPHERE_POINTS
        int threads = 1
        boolean hetAtoms = false

        AsaCalculator asaCalculator = new AsaCalculator(protein.structure, probeRadius, nSpherePoints, threads, hetAtoms)
        List<GroupAsa> asas = asaCalculator.groupAsas.toList()

        protein.secondaryData.put(PROT_ASA_KEY, new ProtAsa(protein, asas))
    }

    /**
     * @return one-element vector holding the summed ASA of residues around the SAS point;
     *         residues missing from the cached table contribute 0
     */
    @Override
    double[] calculateForSasPoint(Atom sasPoint, SasFeatureCalculationContext context) {
        List<Group> groups = context.protein.exposedAtoms.cutoffAroundAtom(sasPoint, params.feat_asa_neigh_radius).distinctGroups
        ProtAsa protAsa = (ProtAsa) context.protein.secondaryData.get(PROT_ASA_KEY)

        double localAsa = 0d
        for (Group g : groups) {
            Double asa = protAsa.groupAsaMap.get(g.residueNumber)
            if (asa != null) {
                localAsa += asa
            }
        }

        return [localAsa] as double[]
    }


    /** Per-protein residue ASA values, kept both as the original list and indexed by residue number. */
    static class ProtAsa {
        Protein protein
        List<GroupAsa> groupAsas

        Map<ResidueNumber, Double> groupAsaMap = new HashMap<>()

        ProtAsa(Protein protein, List<GroupAsa> groupAsas) {
            this.protein = protein
            this.groupAsas = groupAsas

            for (GroupAsa gasa : groupAsas) {
                // resolve the residue number first so the null-safe navigation actually protects the access below
                ResidueNumber resNum = gasa?.group?.residueNumber
                if (resNum != null) {
                    // NOTE(review): AsaCalculator.getGroupAsas() appears to accumulate into asaU (uncomplexed ASA)
                    // and leave asaC at 0, so asaC would make this feature constant - confirm against BioJava version in use
                    groupAsaMap.put(resNum, gasa.asaU)
                }
            }
        }

    }

}
|
||||
@@ -0,0 +1,109 @@
|
||||
package cz.siret.prank.features.implementation
|
||||
|
||||
import com.google.common.collect.ArrayListMultimap
|
||||
import com.google.common.collect.Multimap
|
||||
import cz.siret.prank.domain.AA
|
||||
import cz.siret.prank.features.api.SasFeatureCalculationContext
|
||||
import cz.siret.prank.features.api.SasFeatureCalculator
|
||||
import cz.siret.prank.geom.Atoms
|
||||
import cz.siret.prank.geom.Struct
|
||||
import cz.siret.prank.program.params.Parametrized
|
||||
import cz.siret.prank.utils.PDBUtils
|
||||
import groovy.transform.CompileStatic
|
||||
import groovy.util.logging.Slf4j
|
||||
import org.biojava.nbio.structure.AminoAcid
|
||||
import org.biojava.nbio.structure.AminoAcidImpl
|
||||
import org.biojava.nbio.structure.Atom
|
||||
|
||||
/**
 * Describes the residues in contact with a SAS point, separately for every amino acid type in AA.
 *
 * Contact residues are the amino acid residues owning at least one neighbourhood atom within contactDist
 * of the SAS point. For each amino acid type four values are produced: the number of contact residues of that
 * type and, for the closest one of them, the distance from the SAS point to its CA atom, to its closest atom and
 * to its center of mass. Types without a contact residue get count 0 and all distances set to MAX_DIST.
 */
@Slf4j
@CompileStatic
class ContactResiduesPositionFeature extends SasFeatureCalculator implements Parametrized {

    static final String NAME = 'crpos'

    /** amino acid types sorted by enum constant name; fixes the column order of the feature vector */
    static final List<AA> AATYPES = Collections.unmodifiableList(AA.values().sort { AA a -> a.name() }.toList())

    /** column names, four per amino acid type, filled in the constructor */
    final List<String> HEADER = new ArrayList<>()

    /** distance value reported for amino acid types that have no contact residue */
    static final double MAX_DIST = 20d

//===========================================================================================================//

    /** contact distance cutoff, read from params.feat_crang_contact_dist when the feature is constructed */
    double contactDist

    ContactResiduesPositionFeature() {
        contactDist = params.feat_crang_contact_dist

        for (AA aa : AATYPES) {
            // Locale.ROOT keeps column names stable regardless of the JVM default locale
            String prefix = NAME + '.' + aa.name().toLowerCase(Locale.ROOT) + '.'
            HEADER.add prefix + 'count'
            HEADER.add prefix + 'distca'
            HEADER.add prefix + 'distclosest'
            HEADER.add prefix + 'distcenter'
        }
    }

    @Override
    String getName() {
        NAME
    }

    @Override
    List<String> getHeader() {
        HEADER
    }

    /**
     * @return vector of HEADER.size() values: for each type in AATYPES order [count, distca, distclosest, distcenter]
     */
    @Override
    double[] calculateForSasPoint(Atom sasPoint, SasFeatureCalculationContext context) {

        Atoms contactAtoms = context.neighbourhoodAtoms.cutoffAroundAtom(sasPoint, contactDist)
        List<AminoAcid> contactResidues = (List<AminoAcid>)(List)contactAtoms.getDistinctGroups().findAll{ it instanceof AminoAcid }.toList()

        log.debug 'contact residues: ' + contactResidues.size()

        // TODO: this can be optimized

        // group contact residues by type; residues whose code is not an AA constant are skipped
        Multimap<AA, AminoAcid> contactResIndex = ArrayListMultimap.create(20, 3)
        for (AminoAcid res : contactResidues) {
            AA aa = AA.forName(PDBUtils.getResidueCode(res))
            if (aa != null) {
                contactResIndex.put(aa, res)
            }
        }

        double[] vect = new double[HEADER.size()]

        int i = 0
        for (AA aa : AATYPES) {
            double count = 0
            double distclosest = MAX_DIST
            double distca = MAX_DIST
            double distcenter = MAX_DIST

            // Multimap.get returns an empty collection (never null) for types without residues
            Collection<AminoAcid> residues = contactResIndex.get(aa)
            if (!residues.empty) {

                // find the residue closest to the SAS point, building its atoms only once
                AminoAcid closestResOfType = null
                Atoms ratoms = null
                for (AminoAcid res : residues) {
                    Atoms candidateAtoms = Atoms.allFromGroup(res)
                    double candidateDist = candidateAtoms.dist(sasPoint)
                    if (closestResOfType == null || candidateDist < distclosest) {
                        closestResOfType = res
                        ratoms = candidateAtoms
                        distclosest = candidateDist
                    }
                }

                count = residues.size()
                distcenter = Struct.dist ratoms.centerOfMass, sasPoint
                // fall back to the center of mass distance when the residue has no CA atom
                distca = (closestResOfType.CA==null) ? distcenter : Struct.dist(closestResOfType.CA, sasPoint)
            }

            vect[i]   = count
            vect[i+1] = distca
            vect[i+2] = distclosest
            vect[i+3] = distcenter

            i += 4
        }

        return vect
    }

}
|
||||
@@ -13,7 +13,7 @@ import org.biojava.nbio.structure.Atom
|
||||
@CompileStatic
|
||||
class ProtrusionHistogramFeature extends SasFeatureCalculator implements Parametrized {
|
||||
|
||||
static final double MIN_DIST = 2d
|
||||
static final double MIN_DIST = 4
|
||||
|
||||
@Override
|
||||
String getName() {
|
||||
|
||||
@@ -187,7 +187,9 @@ public final class Atoms implements Iterable<Atom> {
|
||||
public List<Group> getDistinctGroups() {
|
||||
Set<Group> res = new HashSet<>();
|
||||
for (Atom a : list) {
|
||||
res.add(a.getGroup());
|
||||
if (a.getGroup()!=null) {
|
||||
res.add(a.getGroup());
|
||||
}
|
||||
}
|
||||
|
||||
List<Group> sres = Struct.sortGroups(res);
|
||||
|
||||
@@ -22,6 +22,7 @@ class LogManager implements Writable {
|
||||
static final String LOGGER_NAME = "cz.siret.prank"
|
||||
static final String CONSOLE_APPENDER_NAME = "Console"
|
||||
static final String FILE_APPENDER_NAME = "File"
|
||||
static final String PATTERN = "[%level] %logger{0} - %msg%n"
|
||||
|
||||
boolean loggingToFile = false
|
||||
String logFile
|
||||
@@ -40,9 +41,10 @@ class LogManager implements Writable {
|
||||
config = ctx.getConfiguration();
|
||||
loggerConfig = config.getLoggerConfig(loggerName)
|
||||
|
||||
loggerConfig.getAppenders().each { System.out.println "APPENDER: " + it.value.name }
|
||||
// loggerConfig.getAppenders().each { System.out.println "APPENDER: " + it.value.name }
|
||||
write "logToConsole: $logToConsole"
|
||||
write "logToFile: $logToFile"
|
||||
write "logLevel: ${level.name()}"
|
||||
|
||||
loggerConfig.setLevel(level)
|
||||
|
||||
@@ -63,12 +65,12 @@ class LogManager implements Writable {
|
||||
|
||||
ctx.updateLoggers();
|
||||
|
||||
loggerConfig.getAppenders().each { System.out.println "APPENDER: " + it.value.name }
|
||||
//loggerConfig.getAppenders().each { System.out.println "APPENDER: " + it.value.name }
|
||||
}
|
||||
|
||||
private static Appender addFileAppender(Configuration config, String loggerName, String logFile, Level level) {
|
||||
|
||||
String pattern = "[%level] %logger{0} - %msg%n"
|
||||
String pattern = PATTERN
|
||||
int bufferSize = 5000
|
||||
|
||||
Futils.delete(logFile)
|
||||
|
||||
@@ -10,6 +10,7 @@ import groovy.util.logging.Slf4j
|
||||
import org.apache.commons.lang3.StringUtils
|
||||
|
||||
import static cz.siret.prank.utils.ATimer.startTimer
|
||||
import static cz.siret.prank.utils.Futils.writeFile
|
||||
|
||||
@Slf4j
|
||||
class Main implements Parametrized, Writable {
|
||||
@@ -174,6 +175,10 @@ class Main implements Parametrized, Writable {
|
||||
|
||||
}
|
||||
|
||||
/**
 * Writes the args of this run to the file cmdline_args.txt in the given directory.
 *
 * @param outdir directory the file is written to
 */
void writeCmdLineArgs(String outdir) {
    String target = "$outdir/cmdline_args.txt"
    writeFile(target, args)
}
|
||||
|
||||
//===========================================================================================================//
|
||||
|
||||
void doRunPredict(String label, boolean evalPredict) {
|
||||
@@ -366,15 +371,20 @@ class Main implements Parametrized, Writable {
|
||||
|
||||
boolean error = false
|
||||
|
||||
Main main
|
||||
try {
|
||||
|
||||
error = new Main(parsedArgs).run()
|
||||
main = new Main(parsedArgs)
|
||||
error = main.run()
|
||||
|
||||
} catch (PrankException e) {
|
||||
|
||||
error = true
|
||||
writeError e.message
|
||||
log.error(e.message, e)
|
||||
if (main.logManager.loggingToFile) {
|
||||
write "For details see log file: '$main.logManager.logFile'"
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
|
||||
|
||||
@@ -68,11 +68,6 @@ class Params {
|
||||
|
||||
//== FEATURES
|
||||
|
||||
/**
|
||||
* include volsite pharmacophore properties
|
||||
*/
|
||||
boolean use_volsite_features = true
|
||||
|
||||
List<String> extra_features = ["protrusion","bfactor"]
|
||||
|
||||
List<String> atom_table_features = ["apRawValids","apRawInvalids","atomicHydrophobicity"] // "ap5sasaValids","ap5sasaInvalids"
|
||||
@@ -314,6 +309,8 @@ class Params {
|
||||
*/
|
||||
double pred_point_threshold = 0.4
|
||||
|
||||
boolean include_surrounding_score = false
|
||||
|
||||
/**
|
||||
* minimum cluster size (of ligandable points) for initial clustering
|
||||
*/
|
||||
@@ -325,9 +322,9 @@ class Params {
|
||||
double pred_clustering_dist = 5
|
||||
|
||||
/**
|
||||
* distance to extend clusters around hotspots
|
||||
* SAS points around ligandable points (and their score) will be included in the pocket
|
||||
*/
|
||||
double pred_surrounding = 3.5
|
||||
double extended_pocket_cutoff = 3.5
|
||||
|
||||
/**
|
||||
* cutoff distance of protein surface atoms considered as part of the pocket
|
||||
@@ -491,6 +488,26 @@ class Params {
|
||||
/** produce ROC and PR curve graphs (not fully implemented yet) */
|
||||
boolean stats_curves = false
|
||||
|
||||
/**
|
||||
* Contact residues distance cutoff
|
||||
*/
|
||||
double feat_crang_contact_dist = 3
|
||||
|
||||
/**
|
||||
* probe radius for calculating accessible surface area for asa feature
|
||||
*/
|
||||
double feat_asa_probe_radius = 1.4
|
||||
|
||||
/**
|
||||
* second probe radius for calculating accessible surface area, used for the second value of the asa2 feature
|
||||
*/
|
||||
double feat_asa_probe_radius2 = 1.4
|
||||
|
||||
/**
|
||||
* radius of the neighbourhood considered in asa feature
|
||||
*/
|
||||
double feat_asa_neigh_radius = 6
|
||||
|
||||
//===========================================================================================================//
|
||||
|
||||
String getVersion() {
|
||||
|
||||
@@ -53,6 +53,7 @@ class Experiments extends Routine {
|
||||
datadirRoot = params.dataset_base_dir
|
||||
label = "run_" + trainDataSet.label + "_" + (doCrossValidation ? "crossval" : evalDataSet.label)
|
||||
outdir = main.findOutdir(label)
|
||||
main.writeCmdLineArgs(outdir)
|
||||
|
||||
main.configureLoggers(outdir)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import groovyx.gpars.GParsPool
|
||||
|
||||
import static cz.siret.prank.utils.ATimer.startTimer
|
||||
import static cz.siret.prank.utils.Futils.mkdirs
|
||||
import static cz.siret.prank.utils.Futils.writeFile
|
||||
|
||||
/**
|
||||
* Routine for grid optimization. Loops through values of one or more RangeParam and produces resulting statistics and plots.
|
||||
|
||||
@@ -3,6 +3,7 @@ package cz.siret.prank.program.routines
|
||||
import cz.siret.prank.program.Main
|
||||
import cz.siret.prank.program.PrankException
|
||||
import cz.siret.prank.program.params.Parametrized
|
||||
import cz.siret.prank.program.params.Params
|
||||
import cz.siret.prank.utils.Futils
|
||||
import cz.siret.prank.utils.Writable
|
||||
import groovy.transform.CompileStatic
|
||||
@@ -37,5 +38,5 @@ class Routine implements Parametrized, Writable {
|
||||
String v = "version: " + Main.version + "\n"
|
||||
writeFile("$outdir/params.txt", v + params.toString())
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -18,6 +18,8 @@ import weka.classifiers.Classifier
|
||||
import weka.core.Instance
|
||||
import weka.core.Instances
|
||||
|
||||
import static cz.siret.prank.score.prediction.PointScoreCalculator.predictedPositive
|
||||
import static cz.siret.prank.score.prediction.PointScoreCalculator.predictedScore
|
||||
import static cz.siret.prank.utils.ATimer.startTimer
|
||||
|
||||
@Slf4j
|
||||
@@ -99,8 +101,8 @@ class TrainEvalRoutine extends EvalRoutine implements Parametrized {
|
||||
ClassifierStats trainStats = new ClassifierStats()
|
||||
for (Instance inst : trainVectors) {
|
||||
double[] hist = classifier.distributionForInstance(inst)
|
||||
double score = PointScoreCalculator.predictedScore(hist)
|
||||
boolean predicted = hist[1] > hist[0]
|
||||
double score = predictedScore(hist)
|
||||
boolean predicted = predictedPositive(score)
|
||||
boolean observed = inst.classValue() > 0
|
||||
|
||||
trainStats.addPrediction(observed, predicted, score, hist)
|
||||
|
||||
@@ -133,8 +133,8 @@ class EvalResults implements Parametrized, Writable {
|
||||
|
||||
//===========================================================================================================//
|
||||
|
||||
m.TIME_TRAIN = trainTime
|
||||
m.TIME_EVAL = evalTime
|
||||
m.TIME_TRAIN_M = (double)(trainTime ?: 0) / 60000
|
||||
m.TIME_EVAL_M = (double)(evalTime ?: 0) / 60000
|
||||
|
||||
m.TRAIN_VECTORS = avgTrainVectors
|
||||
m.TRAIN_POSITIVES = avgTrainPositives
|
||||
|
||||
@@ -6,11 +6,13 @@ import cz.siret.prank.domain.PredictionPair
|
||||
import cz.siret.prank.domain.Protein
|
||||
import cz.siret.prank.features.implementation.conservation.ConservationScore
|
||||
import cz.siret.prank.geom.Atoms
|
||||
import cz.siret.prank.program.rendering.LabeledPoint
|
||||
import cz.siret.prank.score.criteria.*
|
||||
import groovy.util.logging.Slf4j
|
||||
import org.apache.commons.lang3.StringUtils
|
||||
|
||||
import static cz.siret.prank.utils.Formatter.*
|
||||
import static java.util.Collections.emptyList
|
||||
|
||||
/**
|
||||
* Represents evaluation of pocket prediction on a dataset of proteins
|
||||
@@ -22,6 +24,9 @@ import static cz.siret.prank.utils.Formatter.*
|
||||
@Slf4j
|
||||
class Evaluation {
|
||||
|
||||
/** cutoff distance in A around ligand atoms that determines which SAS points cover the ligand */
|
||||
static final double LIG_SAS_CUTOFF = 2
|
||||
|
||||
IdentificationCriterium standardCriterium = new DCA(4.0)
|
||||
List<IdentificationCriterium> criteria
|
||||
List<ProteinRow> proteinRows = Collections.synchronizedList(new ArrayList<>())
|
||||
@@ -36,6 +41,9 @@ class Evaluation {
|
||||
int smallLigandCount
|
||||
int distantLigandCount
|
||||
|
||||
int ligSASPointsCount
|
||||
int ligSASPointsCoveredCount
|
||||
|
||||
Evaluation(List<IdentificationCriterium> criteria) {
|
||||
this.criteria = criteria
|
||||
}
|
||||
@@ -92,6 +100,8 @@ class Evaluation {
|
||||
List<PocketRow> tmpPockets = new ArrayList<>()
|
||||
|
||||
Protein lp = pair.liganatedProtein
|
||||
Atoms sasPoints = pair.prediction.protein.connollySurface.points
|
||||
Atoms labeledPoints = new Atoms(pair.prediction.labeledPoints ?: emptyList())
|
||||
|
||||
ProteinRow protRow = new ProteinRow()
|
||||
protRow.name = pair.name
|
||||
@@ -109,7 +119,13 @@ class Evaluation {
|
||||
protRow.smallLigNames = lp.smallLigands.collect { "$it.name($it.size)" }.join(" ")
|
||||
protRow.distantLigands = lp.distantLigands.size()
|
||||
protRow.distantLigNames = lp.distantLigands.collect { "$it.name($it.size|${format(it.contactDistance,1)}|${format(it.centerToProteinDist,1)})" }.join(" ")
|
||||
protRow.connollyPoints = pair.prediction.protein.connollySurface.points.count
|
||||
protRow.sasPoints = sasPoints.count
|
||||
|
||||
// ligand coverage
|
||||
Atoms ligSasPoints = labeledPoints.cutoffAtoms(lp.allLigandAtoms, LIG_SAS_CUTOFF)
|
||||
int n_ligSasPoints = ligSasPoints.count
|
||||
int n_ligSasPointsCovered = ligSasPoints.toList().findAll { ((LabeledPoint)it).predicted }.toList().size()
|
||||
log.debug "XXXX n_ligSasPoints: {} covered: {}", n_ligSasPoints, n_ligSasPointsCovered
|
||||
|
||||
// Conservation stats
|
||||
ConservationScore score = lp.secondaryData.get(ConservationScore.conservationScoreKey)
|
||||
@@ -187,6 +203,8 @@ class Evaluation {
|
||||
proteinRows.add(protRow)
|
||||
ligandRows.addAll(tmpLigRows)
|
||||
pocketRows.addAll(tmpPockets)
|
||||
ligSASPointsCount += n_ligSasPoints
|
||||
ligSASPointsCoveredCount += n_ligSasPointsCovered
|
||||
}
|
||||
}
|
||||
|
||||
@@ -200,6 +218,8 @@ class Evaluation {
|
||||
ignoredLigandCount += eval.ignoredLigandCount
|
||||
smallLigandCount += eval.smallLigandCount
|
||||
distantLigandCount += eval.distantLigandCount
|
||||
ligSASPointsCount += eval.ligSASPointsCount
|
||||
ligSASPointsCoveredCount += eval.ligSASPointsCoveredCount
|
||||
}
|
||||
|
||||
double calcSuccRate(int assesorNum, int tolerance) {
|
||||
@@ -266,26 +286,38 @@ class Evaluation {
|
||||
return a
|
||||
}
|
||||
|
||||
//===========================================================================================================//
|
||||
|
||||
/**
 * Averages the values the closure yields for the list items.
 * Values equal to Double.NaN are left out of the sum, but the divisor is always the size of the whole list.
 *
 * @param list items to evaluate
 * @param closure maps an item to its value
 * @return sum of the retained values divided by list.size(), or NaN for an empty list
 */
public <T> double avg(List<T> list, Closure<T> closure) {
    int n = list.size()
    if (n == 0) {
        return Double.NaN
    }
    def retained = list.collect { closure(it) }.findAll { it!=Double.NaN }
    return retained.sum(0) / n
}
|
||||
|
||||
double div(double a, double b) {
|
||||
if (b==0d)
|
||||
return Double.NaN
|
||||
return a / b
|
||||
}
|
||||
|
||||
//===========================================================================================================//
|
||||
|
||||
double getAvgPockets() {
|
||||
pocketCount / proteinCount
|
||||
div pocketCount, proteinCount
|
||||
}
|
||||
|
||||
double getAvgLigandAtoms() {
|
||||
ligandRows.collect {it.atoms}.sum(0) / ligandCount
|
||||
div ligandRows.collect {it.atoms}.sum(0), ligandCount
|
||||
}
|
||||
|
||||
double getAvgPocketVolume() {
|
||||
pocketRows.collect { it.pocketVolume }.sum(0) / pocketCount
|
||||
div pocketRows.collect { it.pocketVolume }.sum(0), pocketCount
|
||||
}
|
||||
double getAvgPocketVolumeTruePockets() {
|
||||
avg pocketRows.findAll { it.truePocket }, {PocketRow it -> it.pocketVolume }
|
||||
}
|
||||
|
||||
double getAvgPocketSurfAtoms() {
|
||||
pocketRows.collect { it.surfaceAtomCount }.sum(0) / pocketCount
|
||||
div pocketRows.collect { it.surfaceAtomCount }.sum(0), pocketCount
|
||||
}
|
||||
|
||||
double getAvgPocketSurfAtomsTruePockets() {
|
||||
@@ -293,32 +325,30 @@ class Evaluation {
|
||||
}
|
||||
|
||||
double getAvgPocketInnerPoints() {
|
||||
pocketRows.collect { it.auxInfo.samplePoints }.sum(0) / pocketCount
|
||||
div pocketRows.collect { it.auxInfo.samplePoints }.sum(0), pocketCount
|
||||
}
|
||||
double getAvgPocketInnerPointsTruePockets() {
|
||||
avg pocketRows.findAll { it.truePocket }, {PocketRow it -> it.auxInfo.samplePoints }
|
||||
}
|
||||
|
||||
double getAvgProteinAtoms() {
|
||||
proteinRows.collect { it.protAtoms }.sum(0) / proteinCount
|
||||
div proteinRows.collect { it.protAtoms }.sum(0), proteinCount
|
||||
}
|
||||
|
||||
double getAvgExposedAtoms() {
|
||||
proteinRows.collect { it.exposedAtoms }.sum(0) / proteinCount
|
||||
div proteinRows.collect { it.exposedAtoms }.sum(0), proteinCount
|
||||
}
|
||||
|
||||
double getAvgProteinConollyPoints() {
|
||||
avg proteinRows, {ProteinRow it -> it.connollyPoints }
|
||||
avg proteinRows, {ProteinRow it -> it.sasPoints }
|
||||
}
|
||||
|
||||
double getAvgLigCenterToProtDist() {
|
||||
avg ligandRows, {LigRow it -> it.centerToProtDist}
|
||||
}
|
||||
|
||||
public <T> double avg(List<T> list, Closure<T> closure) {
|
||||
if (list.size()==0) return Double.NaN
|
||||
list.collect { closure(it) }.findAll { it!=Double.NaN }.sum(0) / list.size()
|
||||
|
||||
double getLigandCoverage() {
|
||||
div ligSASPointsCoveredCount, ligSASPointsCount
|
||||
}
|
||||
|
||||
double getAvgClosestPocketDist() {
|
||||
@@ -347,6 +377,7 @@ class Evaluation {
|
||||
|
||||
m.AVG_LIG_CENTER_TO_PROT_DIST = avgLigCenterToProtDist
|
||||
m.AVG_LIG_CLOSTES_POCKET_DIST = avgClosestPocketDist
|
||||
m.LIGAND_COVERAGE = ligandCoverage
|
||||
|
||||
m.AVG_POCKETS = avgPockets
|
||||
m.AVG_POCKET_SURF_ATOMS = avgPocketSurfAtoms
|
||||
@@ -491,6 +522,8 @@ class Evaluation {
|
||||
double avgConservation
|
||||
double avgBindingConservation
|
||||
double avgNonBindingConservation
|
||||
|
||||
int sasPoints
|
||||
}
|
||||
|
||||
static class LigRow {
|
||||
|
||||
@@ -19,6 +19,7 @@ import weka.classifiers.Classifier
|
||||
import weka.core.DenseInstance
|
||||
import weka.core.Instances
|
||||
|
||||
import static cz.siret.prank.score.prediction.PointScoreCalculator.predictedPositive
|
||||
import static cz.siret.prank.score.prediction.PointScoreCalculator.predictedScore
|
||||
|
||||
/**
|
||||
@@ -90,14 +91,12 @@ class WekaSumRescorer extends PocketRescorer implements Parametrized {
|
||||
// classification
|
||||
|
||||
FeatureVector props = extractor.calcFeatureVector(point.point)
|
||||
point.@hist = getDistributionForPoint(classifier, props)
|
||||
double[] hist = getDistributionForPoint(classifier, props)
|
||||
|
||||
// labels and statistics
|
||||
|
||||
double[] hist = point.hist
|
||||
double predictedScore = predictedScore(hist) // not all classifiers give histogram that sums up to 1
|
||||
|
||||
boolean predicted = hist[1] > hist[0]
|
||||
boolean predicted = predictedPositive(predictedScore)
|
||||
boolean observed = false
|
||||
|
||||
if (ligandAtoms!=null) {
|
||||
@@ -105,6 +104,10 @@ class WekaSumRescorer extends PocketRescorer implements Parametrized {
|
||||
observed = (closestLigandDistance <= POSITIVE_POINT_LIGAND_DISTANCE)
|
||||
}
|
||||
|
||||
point.@hist = hist
|
||||
point.predicted = predicted
|
||||
point.observed = observed
|
||||
|
||||
if (collectingStatistics) {
|
||||
stats.addPrediction(observed, predicted, predictedScore, hist)
|
||||
}
|
||||
@@ -143,7 +146,7 @@ class WekaSumRescorer extends PocketRescorer implements Parametrized {
|
||||
|
||||
double[] hist = getDistributionForPoint(classifier, props)
|
||||
double predictedScore = predictedScore(hist) // not all classifiers give histogram that sums up to 1
|
||||
boolean predicted = hist[1] > hist[0]
|
||||
boolean predicted = predictedPositive(predictedScore)
|
||||
boolean observed = false
|
||||
|
||||
if (collectingStatistics) {
|
||||
|
||||
@@ -6,6 +6,8 @@ import groovy.transform.CompileStatic
|
||||
import java.text.DecimalFormat
|
||||
|
||||
import static cz.siret.prank.utils.Formatter.formatPercent
|
||||
import static java.lang.Double.NaN
|
||||
import static java.lang.Math.log
|
||||
|
||||
/**
|
||||
* Binary classifier statistics collector and calculator
|
||||
@@ -13,8 +15,11 @@ import static cz.siret.prank.utils.Formatter.formatPercent
|
||||
@CompileStatic
|
||||
class ClassifierStats implements Parametrized {
|
||||
|
||||
static final double EPS = 1e-15d
|
||||
static final int HISTOGRAM_BINS = 100
|
||||
|
||||
|
||||
|
||||
String name
|
||||
|
||||
int[][] op // [observed][predicted]
|
||||
@@ -27,6 +32,7 @@ class ClassifierStats implements Parametrized {
|
||||
double sumSE = 0
|
||||
double sumSEpos = 0
|
||||
double sumSEneg = 0
|
||||
double sumLogLoss = 0
|
||||
|
||||
Histograms histograms = new Histograms()
|
||||
|
||||
@@ -77,7 +83,7 @@ class ClassifierStats implements Parametrized {
|
||||
void addPrediction(boolean observed, boolean predicted, double score, double[] hist) {
|
||||
|
||||
double obsv = observed ? 1 : 0
|
||||
double e = Math.abs(obsv-score)
|
||||
double e = Math.abs(obsv - score)
|
||||
double se = e*e
|
||||
|
||||
sumE += e
|
||||
@@ -91,6 +97,11 @@ class ClassifierStats implements Parametrized {
|
||||
sumSEneg += se
|
||||
}
|
||||
|
||||
double pCorrect = observed ? score : 1-score
|
||||
if (pCorrect<EPS)
|
||||
pCorrect = EPS
|
||||
sumLogLoss -= log(pCorrect)
|
||||
|
||||
histograms.score.put(score)
|
||||
if (observed) {
|
||||
histograms.scorePos.put(score)
|
||||
@@ -114,17 +125,17 @@ class ClassifierStats implements Parametrized {
|
||||
double calcMCC(double TP, double FP, double TN, double FN) {
|
||||
double n = TP*TN - FP*FN
|
||||
double d = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
|
||||
d = Math.sqrt(d);
|
||||
d = Math.sqrt(d)
|
||||
if (d == 0d) {
|
||||
d = 1d;
|
||||
d = 1d
|
||||
}
|
||||
|
||||
return n / d;
|
||||
return n / d
|
||||
}
|
||||
|
||||
double div(double a, double b) {
|
||||
if (b==0d)
|
||||
return Double.NaN
|
||||
return NaN
|
||||
return a / b
|
||||
}
|
||||
|
||||
@@ -160,19 +171,47 @@ class ClassifierStats implements Parametrized {
|
||||
|
||||
private Advanced advanced = null
|
||||
|
||||
double getTp() { op[1][1] }
|
||||
double getFp() { op[0][1] }
|
||||
double getTn() { op[0][0] }
|
||||
double getFn() { op[1][0] }
|
||||
double getTP() { op[1][1] }
|
||||
double getFP() { op[0][1] }
|
||||
double getTN() { op[0][0] }
|
||||
double getFN() { op[1][0] }
|
||||
|
||||
/** Observed Positive */
|
||||
double getOP() {
|
||||
TP + FN
|
||||
}
|
||||
|
||||
/** Observed Negative */
|
||||
double getON() {
|
||||
FP + TN
|
||||
}
|
||||
|
||||
/** Predicted Positive */
|
||||
double getPP() {
|
||||
TP + FP
|
||||
}
|
||||
|
||||
/** Predicted Negative */
|
||||
double getPN() {
|
||||
TN + FN
|
||||
}
|
||||
|
||||
double getOPON_ratio() {
|
||||
div OP, ON
|
||||
}
|
||||
|
||||
double getPPPN_ratio() {
|
||||
div PP, PN
|
||||
}
|
||||
|
||||
/** Precision = Positive Predictive Value */
|
||||
double getP() {
|
||||
div tp , (tp + fp)
|
||||
div TP , (TP + FP)
|
||||
}
|
||||
|
||||
/** Recall = Sensitivity = True Positive Rate */
|
||||
double getR() {
|
||||
div tp , (tp + fn)
|
||||
div TP , (TP + FN)
|
||||
}
|
||||
|
||||
/** F-measure */
|
||||
@@ -188,36 +227,51 @@ class ClassifierStats implements Parametrized {
|
||||
}
|
||||
|
||||
double getMCC() {
|
||||
calcMCC(tp, fp, tn, fn)
|
||||
calcMCC(TP, FP, TN, FN)
|
||||
}
|
||||
|
||||
/** negative predictive value */
|
||||
double getNPV() {
|
||||
div tn , (tn + fn)
|
||||
div TN , (TN + FN)
|
||||
}
|
||||
|
||||
/** specificity = true negative rate */
|
||||
double getSPC() {
|
||||
div tn , (tn + fp)
|
||||
div TN , (TN + FP)
|
||||
}
|
||||
|
||||
/** accuracy */
|
||||
double getACC() {
|
||||
div( (tp + tn) , count )
|
||||
div( (TP + TN) , count )
|
||||
}
|
||||
|
||||
/** balanced accuracy */
|
||||
double getBACC() {
|
||||
(r + SPC) / 2
|
||||
}
|
||||
|
||||
/** TP versus the bad */
|
||||
double getTPX() {
|
||||
div tp, tp + fn + fp
|
||||
div TP, TP + FN + FP
|
||||
}
|
||||
|
||||
/** log TP */
|
||||
double getLTP() {
|
||||
try {
|
||||
-log( TP / (PP * OP) )
|
||||
} catch (Exception e) {
|
||||
NaN
|
||||
}
|
||||
}
|
||||
|
||||
/** false positive rate */
|
||||
double getFPR() {
|
||||
div fp , (fp + tn)
|
||||
div FP , (FP + TN)
|
||||
}
|
||||
|
||||
/** false negative rate */
|
||||
double getFNR() {
|
||||
div fn , (tp + fn)
|
||||
div FN , (TP + FN)
|
||||
}
|
||||
|
||||
/** positive likelihood ratio */
|
||||
@@ -237,28 +291,34 @@ class ClassifierStats implements Parametrized {
|
||||
|
||||
/** false discovery rate */
|
||||
double getFDR() {
|
||||
div fp , (tp + fp)
|
||||
div FP , (TP + FP)
|
||||
}
|
||||
|
||||
/** false omission rate */
|
||||
double getFOR() {
|
||||
div fn , (fn + tn)
|
||||
div FN , (FN + TN)
|
||||
}
|
||||
|
||||
/** Youden's J statistic = Youden's index */
|
||||
/** Youden's J statistic = Youden's index = Informedness */
|
||||
double getYJS() {
|
||||
r + SPC -1
|
||||
r + SPC - 1
|
||||
}
|
||||
|
||||
/** Markedness */
|
||||
double getMRK() {
|
||||
p + NPV - 1
|
||||
}
|
||||
|
||||
|
||||
/** Discriminant Power ... <1 = poor, >3 = good, fair otherwise */
|
||||
double getDPOW() {
|
||||
if (r==1 || SPC==1)
|
||||
return Double.NaN
|
||||
return NaN
|
||||
double x = r / (1-r)
|
||||
double y = SPC / (1-SPC)
|
||||
double c = Math.sqrt(3) / Math.PI
|
||||
|
||||
c * ( Math.log(x) + Math.log(y) )
|
||||
c * ( log(x) + log(y) )
|
||||
}
|
||||
|
||||
double getME() { div sumE, count }
|
||||
@@ -271,6 +331,26 @@ class ClassifierStats implements Parametrized {
|
||||
double getMSEneg() { div sumSEneg, count }
|
||||
double getMSEbalanced() { (MSEneg + MSEpos) / 2 }
|
||||
|
||||
double getLogLoss() {
|
||||
div sumLogLoss, count
|
||||
}
|
||||
|
||||
/** Uncertainty coefficient, aka Proficiency */
|
||||
double getUC() {
|
||||
try {
|
||||
double L = (OP + ON) * log(OP + ON)
|
||||
double LTP = TP * log( TP / (PP * OP) )
|
||||
double LFP = FP * log( FP / (PP * ON) )
|
||||
double LFN = FN * log( FN / (PN * OP) )
|
||||
double LTN = TN * log( TN / (PN * ON) )
|
||||
double LP = OP * log( OP / count )
|
||||
double LN = ON * log( ON / count )
|
||||
double UC = (L + LTP + LFP + LFN + LTN) / (L + LP + LN)
|
||||
return UC
|
||||
} catch (Exception e) {
|
||||
return NaN
|
||||
}
|
||||
}
|
||||
|
||||
double getAUC() {
|
||||
if (advanced==null) advanced = calculateAdvanced()
|
||||
@@ -310,8 +390,8 @@ class ClassifierStats implements Parametrized {
|
||||
}
|
||||
|
||||
class Advanced {
|
||||
double wekaAUC = Double.NaN
|
||||
double wekaAUPRC = Double.NaN
|
||||
double wekaAUC = 0
|
||||
double wekaAUPRC = 0
|
||||
}
|
||||
|
||||
}
|
||||
@@ -354,13 +434,13 @@ class ClassifierStats implements Parametrized {
|
||||
sb << ",(npv),(p)\n"
|
||||
sb << "\n"
|
||||
sb << "pred: , [0], [1]\n"
|
||||
sb << "obs[0] , ${tn}, ${fp}, ${formatPercent(SPC)}\n"
|
||||
sb << "obs[1] , ${fn}, ${tp}, ${formatPercent(R)}\n"
|
||||
sb << "obs[0] , ${TN}, ${FP}, ${formatPercent(SPC)}\n"
|
||||
sb << "obs[1] , ${FN}, ${TP}, ${formatPercent(R)}\n"
|
||||
sb << " , ${formatPercent(NPV)}, ${formatPercent(P)}\n"
|
||||
sb << "\n"
|
||||
sb << "%:\n"
|
||||
sb << ", ${rel(tn)}, ${rel(fp)}\n"
|
||||
sb << ", ${rel(fn)}, ${rel(tp)}\n"
|
||||
sb << ", ${rel(TN)}, ${rel(FP)}\n"
|
||||
sb << ", ${rel(FN)}, ${rel(TP)}\n"
|
||||
sb << "\n"
|
||||
sb << "ACC:, ${format(ACC)}, accuracy\n"
|
||||
sb << "\n"
|
||||
@@ -372,17 +452,6 @@ class ClassifierStats implements Parametrized {
|
||||
sb << "\n"
|
||||
sb << "FM:, ${format(f1)}, F-measure\n"
|
||||
sb << "MCC:, ${format(MCC)}, Matthews correlation coefficient\n"
|
||||
|
||||
sb << "\n"
|
||||
sb << "ME:, ${format(ME)}, Mean error\n"
|
||||
sb << "MEpos:, ${format(MEpos)}, ME on positive observations\n"
|
||||
sb << "MEneg:, ${format(MEneg)}, Mean error on negative observations\n"
|
||||
sb << "MEbal:, ${format(MEbalanced)}, Mean error balanced\n"
|
||||
sb << "\n"
|
||||
sb << "MSE:, ${format(MSE)}, Mean squared error\n"
|
||||
sb << "MSEpos:, ${format(MSEpos)}, MSE on positive observations\n"
|
||||
sb << "MSEneg:, ${format(MSEneg)}, MSE on negative observations\n"
|
||||
sb << "MSEbal:, ${format(MSEbalanced)}, Mean error balanced\n"
|
||||
}
|
||||
|
||||
return sb.toString()
|
||||
|
||||
@@ -21,7 +21,7 @@ class PocketPredictor implements Parametrized {
|
||||
|
||||
private double POCKET_PROT_SURFACE_CUTOFF = params.pred_protein_surface_cutoff
|
||||
private int MIN_CLUSTER_SIZE = params.pred_min_cluster_size
|
||||
private double SURROUNDING_DIST = params.pred_surrounding
|
||||
private double EXTENDED_POCKET_CUTOFF = params.extended_pocket_cutoff
|
||||
private double CLUSTERING_DIST = params.pred_clustering_dist
|
||||
private double POINT_THRESHOLD = params.pred_point_threshold
|
||||
private boolean BALANCE_POINT_DENSITY = params.balance_density
|
||||
@@ -40,8 +40,9 @@ class PocketPredictor implements Parametrized {
|
||||
}
|
||||
|
||||
private boolean admitPoint(LabeledPoint point) {
|
||||
double p = PointScoreCalculator.predictedScore(point.hist)
|
||||
return p > POINT_THRESHOLD
|
||||
// double p = PointScoreCalculator.predictedScore(point.hist)
|
||||
// return p > POINT_THRESHOLD
|
||||
point.predicted
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -67,9 +68,14 @@ class PocketPredictor implements Parametrized {
|
||||
|
||||
List<PrankPocket> pockets = filteredClusters.collect { Atoms clusterPoints ->
|
||||
|
||||
Atoms extendedPocketPoints = connollyPoints.cutoffAtoms(clusterPoints, SURROUNDING_DIST)
|
||||
double score = extendedPocketPoints.collect { score((LabeledPoint)it, connollyPoints) }.sum()
|
||||
Atoms pocketSurfaceAtoms = protein.exposedAtoms.cutoffAtoms(extendedPocketPoints, POCKET_PROT_SURFACE_CUTOFF)
|
||||
Atoms pocketPoints = clusterPoints
|
||||
if (EXTENDED_POCKET_CUTOFF > 0d) {
|
||||
Atoms extendedPocketPoints = connollyPoints.cutoffAtoms(clusterPoints, EXTENDED_POCKET_CUTOFF)
|
||||
pocketPoints = extendedPocketPoints
|
||||
}
|
||||
|
||||
double score = (double) pocketPoints.collect { score((LabeledPoint)it, connollyPoints) }.sum(0)
|
||||
Atoms pocketSurfaceAtoms = protein.exposedAtoms.cutoffAtoms(pocketPoints, POCKET_PROT_SURFACE_CUTOFF)
|
||||
|
||||
try {
|
||||
if (params.score_pockets_by == "conservation" || params.score_pockets_by == "combi") {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package cz.siret.prank.score.prediction
|
||||
|
||||
import cz.siret.prank.program.params.Parametrized
|
||||
import cz.siret.prank.program.params.Params
|
||||
import groovy.transform.CompileStatic
|
||||
|
||||
/**
|
||||
@@ -21,6 +22,11 @@ class PointScoreCalculator implements Parametrized {
|
||||
hist[1] / (hist[0] + hist[1])
|
||||
}
|
||||
|
||||
static boolean predictedPositive(double predictedScore) {
|
||||
predictedScore >= Params.inst.pred_point_threshold
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* calculates ligandability score of the point from binary classification histogram
|
||||
*
|
||||
|
||||
@@ -13,6 +13,10 @@ class CmdLineArgs {
|
||||
static class NamedArg {
|
||||
String name
|
||||
String value
|
||||
|
||||
public String toString() {
|
||||
return name + "=" + value
|
||||
}
|
||||
}
|
||||
|
||||
String[] argList
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
<configuration status="INFO">
|
||||
<!--<ThresholdFilter level="all"/>-->
|
||||
<appenders>
|
||||
<Console name="Console" target="SYSTEM_ERR">
|
||||
<Console name="Console" > <!-- target="SYSTEM_ERR" -->
|
||||
<PatternLayout pattern="[%level] %logger{0} - %msg%n"/>
|
||||
<!--<PatternLayout pattern="%d{HH:mm:ss} [%t] %-5level %logger{36} - %msg%n"/>-->
|
||||
<!--<pattern>[%level] %logger{0} - %msg%n</pattern>-->
|
||||
|
||||
@@ -62,6 +62,8 @@ import cz.siret.prank.program.params.Params
|
||||
|
||||
log_to_console = false
|
||||
|
||||
log_level = "WARN"
|
||||
|
||||
log_to_file = true
|
||||
|
||||
ploop_delete_runs = true
|
||||
|
||||
Reference in New Issue
Block a user