mirror of
https://github.com/rdkit/rdkit.git
synced 2026-06-07 22:44:25 +08:00
473 lines
14 KiB
Python
Executable File
473 lines
14 KiB
Python
Executable File
# $Id$
|
|
#
|
|
# Copyright (C) 2002-2006 greg Landrum and Rational Discovery LLC
|
|
#
|
|
# @@ All Rights Reserved @@
|
|
#
|
|
|
|
"""Command line tool to construct an enrichment plot from saved composite models
|
|
|
|
Usage: EnrichPlot [optional args] -d dbname -t tablename <models>
|
|
|
|
Required Arguments:
|
|
-d "dbName": the name of the database for screening
|
|
|
|
-t "tablename": provide the name of the table with the data to be screened
|
|
|
|
<models>: file name(s) of pickled composite model(s).
|
|
If the -p argument is also provided (see below), this argument is ignored.
|
|
|
|
Optional Arguments:
|
|
- -a "list": the list of result codes to be considered active. This will be
|
|
eval'ed, so be sure that it evaluates as a list or sequence of
|
|
integers. For example, -a "[1,2]" will consider activity values 1 and 2
|
|
to be active
|
|
|
|
- --enrich "list": identical to the -a argument above.
|
|
|
|
- --thresh: sets a threshold for the plot. If the confidence falls below
|
|
this value, picking will be terminated
|
|
|
|
- -H: screen only the hold out set (works only if a version of
|
|
BuildComposite more recent than 1.2.2 was used).
|
|
|
|
- -T: screen only the training set (works only if a version of
|
|
BuildComposite more recent than 1.2.2 was used).
|
|
|
|
- -S: shuffle activity values before screening
|
|
|
|
- -R: randomize activity values before screening
|
|
|
|
- -F *filter frac*: filters the data before training to change the
|
|
distribution of activity values in the training set. *filter frac*
|
|
is the fraction of the training set that should have the target value.
|
|
**See note in BuildComposite help about data filtering**
|
|
|
|
- -v *filter value*: filters the data before training to change the
|
|
distribution of activity values in the training set. *filter value*
|
|
is the target value to use in filtering.
|
|
**See note in BuildComposite help about data filtering**
|
|
|
|
- -p "tableName": provides the name of a db table containing the
|
|
models to be screened. If you use this argument, you should also
|
|
use the -N argument (below) to specify a note value.
|
|
|
|
- -N "note": provides a note to be used to pull models from a db table.
|
|
|
|
- --plotFile "filename": writes the data to an output text file (filename.dat)
|
|
and creates a gnuplot input file (filename.gnu) to plot it
|
|
|
|
- --showPlot: causes the gnuplot plot constructed using --plotFile to be
|
|
displayed in gnuplot.
|
|
|
|
"""
|
|
import RDConfig
|
|
from Numeric import *
|
|
import cPickle,copy
|
|
#from Dbase.DbConnection import DbConnect
|
|
from ML.Data import DataUtils,SplitData,Stats
|
|
from Dbase.DbConnection import DbConnect
|
|
import DataStructs
|
|
from ML import CompositeRun
|
|
import sys,os,types
|
|
|
|
# Version tag embedded in the header of the gnuplot script written by MakePlot.
__VERSION_STRING="2.3.3"
|
|
def message(msg, noRet=0, dest=sys.stderr):
    """ writes a progress message to _dest_ (standard error by default)

    Modules importing this one may replace this function to send the
    output somewhere else.

    **Arguments**

      - msg: the object to display (formatted with %s)

      - noRet: if true, the message is followed by a single space instead
        of a newline, so the next message continues on the same line

      - dest: (Optional) file-like object that receives the text
    """
    if noRet:
        template = '%s '
    else:
        template = '%s\n'
    dest.write(template % (msg))
|
|
def error(msg, dest=sys.stderr):
    """ emits an error message, prefixed with "ERROR: ", to _dest_

    Modules importing this one may replace this function to redirect output.

    **Arguments**

      - msg: the object to display (formatted with %s)

      - dest: (Optional) file-like object that receives the text;
        defaults to standard error
    """
    # The previous version ignored _dest_ and always wrote to sys.stderr.
    dest.write('ERROR: %s\n' % (msg))
|
|
|
|
def ScreenModel(mdl, descs, data, picking=None, indices=None, errorEstimate=0):
    """ collects the results of screening an individual composite model that match
    a particular set of values

    **Arguments**

      - mdl: the composite model

      - descs: a list of descriptor names corresponding to the data set

      - data: the data set, a list of points to be screened. The first entry
        of each point is used as its id and the last entry as its true result.

      - picking: (Optional) a sequence of values that are to be collected.
        For example, if you want an enrichment plot for picking the values
        1 and 2, you'd use picking=[1,2]. Defaults to [1].

      - indices: (Optional) indices of the points in _data_ to screen.
        If empty or not provided, every point is screened.

      - errorEstimate: (Optional) if true, each point is classified using only
        the member models whose _trainIndices_ do not contain that point's
        index (members without _trainIndices_ are always used).

    **Returns**

      a 2-tuple:

        - the number of screened points whose true result is in _picking_

        - a list of 4-tuples, one per screened point whose predicted result
          is in _picking_, containing:

            - the id of the point

            - the true result (from the data set; quantized first if the
              model has quantization bounds)

            - the predicted result

            - the confidence value for the prediction
    """
    # None sentinels replace the old shared mutable defaults.
    if picking is None:
        picking = [1]
    mdl.SetInputOrder(descs)
    needsQuant = bool(mdl.GetQuantBounds())
    if not indices:
        indices = range(len(data))

    # For out-of-bag estimates, collect each member's training indices once,
    # as sets, instead of re-fetching the member and scanning a list per point.
    trainSets = None
    if errorEstimate:
        trainSets = []
        for j in range(len(mdl)):
            member = mdl.GetModel(j)
            if hasattr(member, '_trainIndices'):
                trainSets.append(set(member._trainIndices))
            else:
                trainSets.append(set())

    nTrueActives = 0
    res = []
    for i in indices:
        if trainSets is None:
            use = None
        else:
            use = [j for j, trainSet in enumerate(trainSets) if i not in trainSet]
        pt = data[i]
        pred, conf = mdl.ClassifyExample(pt, onlyModels=use)
        if needsQuant:
            # quantize a copy so the caller's data point is left untouched
            pt = mdl.QuantizeActivity(pt[:])
        trueRes = pt[-1]
        if trueRes in picking:
            nTrueActives += 1
        if pred in picking:
            res.append((pt[0], trueRes, pred, conf))
    return nTrueActives, res
|
|
|
|
def AccumulateCounts(predictions, thresh=0, sortIt=1):
    """ Accumulates the data for the enrichment plot for a single model

    **Arguments**

      - predictions: a list of 4-tuples (id, true result, predicted result,
        confidence), i.e. the list returned as the second element of
        _ScreenModel_. If _sortIt_ is set, this list is sorted in place.

      - thresh: a threshold for the confidence level. Predictions whose
        confidence is not strictly greater than this value are skipped.

      - sortIt: toggles sorting on confidence levels (highest first)

    **Returns**

      - a list of 3-tuples, one per prediction above the threshold:

        - the id of the point picked here

        - number of correct picks (predicted == true) so far

        - number of picks made so far
    """
    if sortIt:
        # key/reverse sorting is stable, so ties keep their original order,
        # exactly as the old cmp-based descending sort did.
        predictions.sort(key=lambda p: p[3], reverse=True)
    res = []
    nCorrect = 0
    nPts = 0
    for ptId, real, pred, conf in predictions:
        if conf > thresh:
            if pred == real:
                nCorrect += 1
            nPts += 1
            res.append((ptId, nCorrect, nPts))
    return res
|
|
|
|
def MakePlot(details, final, counts, pickVects, nModels, nTrueActs=-1):
    """ writes the enrichment data to disk along with a gnuplot script to plot it

    Nothing is done unless _details_ has a non-empty _plotFile_ attribute.

    **Arguments**

      - details: run details; _details.plotFile_ is the base name of the
        output files, and a true _details.showPlot_ asks for the plot to be
        displayed through the Gnuplot module

      - final: per-pick accumulators; final[i][0] is the summed number of
        correct picks and final[i][1] the summed number of picks at pick i

      - counts: counts[i] is the number of models that contributed to pick i;
        output stops at the first zero entry (or at the end of _final_)

      - pickVects: dict mapping a pick index to the list of per-model
        correct-pick counts (only used when _nModels_ > 1)

      - nModels: the number of models that were screened

      - nTrueActs: (Optional) if positive, used as the upper limit of the y axis

    **Side effects**

      writes _plotFile_.dat (the data) and _plotFile_.gnu (the gnuplot script)
    """
    if not getattr(details, 'plotFile', None):
        return

    dataFileName = '%s.dat' % (details.plotFile)
    outF = open(dataFileName, 'w+')
    try:
        nRows = 0
        while nRows < len(final) and counts[nRows] != 0:
            # float() keeps the averages fractional even for integer inputs
            avgCorrect = float(final[nRows][0]) / counts[nRows]
            avgPicked = float(final[nRows][1]) / counts[nRows]
            if nModels > 1:
                sd = Stats.MeanAndDev(pickVects[nRows])[1]
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[nRows]), level=90)
                outF.write('%d %f %f %d %f\n' % (nRows + 1, avgCorrect, avgPicked,
                                                counts[nRows], confInterval))
            else:
                outF.write('%d %f %f %d\n' % (nRows + 1, avgCorrect, avgPicked,
                                             counts[nRows]))
            nRows += 1
    finally:
        outF.close()

    plotFileName = '%s.gnu' % (details.plotFile)
    gnuHdr = """# Generated by EnrichPlot.py version: %s
set size square 0.7
set xr [0:]
set data styl points
set ylab 'Num Correct Picks'
set xlab 'Num Picks'
set grid
set nokey
set term postscript enh color solid "Helvetica" 16
set term win
""" % (__VERSION_STRING)
    gnuF = open(plotFileName, 'w+')
    try:
        gnuF.write(gnuHdr + '\n')
        if nTrueActs > 0:
            gnuF.write('set yr [0:%d]\n' % nTrueActs)
        gnuF.write('plot x with lines\n')
        if nModels > 1:
            # Error bars are drawn on roughly every 20th point. gnuplot rejects
            # "every 0", so the stride never drops below 1 for short curves.
            everyGap = max(1, nRows // 20)
            gnuF.write('replot "%s" using 1:2 with lines, "%s" every %d using 1:2:5 with yerrorbars\n'
                       % (dataFileName, dataFileName, everyGap))
        else:
            gnuF.write('replot "%s" with points\n' % (dataFileName))
    finally:
        gnuF.close()

    if getattr(details, 'showPlot', False):
        # Displaying the plot is best-effort: report problems (e.g. the Gnuplot
        # module being unavailable) without aborting the run.
        try:
            from Gnuplot import Gnuplot
            p = Gnuplot()
            p('load "%s"' % (plotFileName))
            raw_input('press return to continue...\n')
        except Exception:
            import traceback
            traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
def Usage():
    """ writes the module's help text to standard error and exits with status -1 """
    helpText = __doc__
    sys.stderr.write(helpText)
    raise SystemExit(-1)
|
|
|
|
if __name__ == '__main__':
    # Command-line driver: parse the options, load the data set and the models,
    # screen each model, then print (or plot) the averaged enrichment curve.
    import getopt
    try:
        args, extras = getopt.getopt(sys.argv[1:], 'd:t:a:N:p:cSTHF:v:',
                                     ('thresh=', 'plotFile=', 'showPlot',
                                      'pickleCol=', 'OOB', 'noSort', 'pickBase=',
                                      'doROC', 'rocThresh=', 'enrich='))
    except getopt.GetoptError:
        import traceback
        traceback.print_exc()
        Usage()

    details = CompositeRun.CompositeRun()
    CompositeRun.SetDefaults(details)

    details.activeTgt = [1]
    details.doTraining = 0
    details.doHoldout = 0
    details.dbTableName = ''
    details.plotFile = ''
    details.showPlot = 0
    details.pickleCol = -1
    details.errorEstimate = 0
    details.sortIt = 1
    details.pickBase = ''
    details.doROC = 0
    details.rocThresh = -1
    for arg, val in args:
        if arg == '-d':
            details.dbName = val
        elif arg == '-t':
            details.dbTableName = val
        elif arg in ('-a', '--enrich'):
            # NOTE(review): the value is eval'ed, as the usage text documents;
            # only run this tool with trusted command lines.
            details.activeTgt = eval(val)
            if not isinstance(details.activeTgt, (tuple, list)):
                details.activeTgt = (details.activeTgt,)
        elif arg == '--thresh':
            details.threshold = float(val)
        elif arg == '-N':
            details.note = val
        elif arg == '-p':
            details.persistTblName = val
        elif arg == '-S':
            details.shuffleActivities = 1
        elif arg == '-H':
            details.doTraining = 0
            details.doHoldout = 1
        elif arg == '-T':
            details.doTraining = 1
            details.doHoldout = 0
        elif arg == '-F':
            details.filterFrac = float(val)
        elif arg == '-v':
            details.filterVal = float(val)
        elif arg == '--plotFile':
            details.plotFile = val
        elif arg == '--showPlot':
            details.showPlot = 1
        elif arg == '--pickleCol':
            # the command line is 1-based; the data loader wants a 0-based column
            details.pickleCol = int(val) - 1
        elif arg == '--OOB':
            details.errorEstimate = 1
        elif arg == '--noSort':
            details.sortIt = 0
        elif arg == '--doROC':
            details.doROC = 1
        elif arg == '--rocThresh':
            details.rocThresh = int(val)
        elif arg == '--pickBase':
            details.pickBase = val
        # NOTE(review): '-c' is accepted by getopt but ignored, and doROC /
        # rocThresh are stored but not used anywhere below.

    if not details.dbName or not details.dbTableName:
        # Usage() exits, so the hint must be emitted first or it is never seen.
        message('*******Please provide both the -d and -t arguments')
        Usage()

    message('Building Data set\n')
    dataSet = DataUtils.DBToData(details.dbName, details.dbTableName,
                                 user=RDConfig.defaultDBUser,
                                 password=RDConfig.defaultDBPassword,
                                 pickleCol=details.pickleCol,
                                 pickleClass=DataStructs.ExplicitBitVect)

    descs = dataSet.GetVarNames()
    nPts = dataSet.GetNPts()
    message('npts: %d\n' % (nPts))
    # final[i] = [summed correct picks, summed picks] at pick i over all models;
    # counts[i] = how many models reached pick i; selPts[i] = last id picked at i
    final = zeros((nPts, 2), Float)
    counts = zeros(nPts, Int)
    selPts = [None] * nPts

    models = []
    if details.persistTblName:
        conn = DbConnect(details.dbName, details.persistTblName)
        message('-> Retrieving models from database')
        curs = conn.GetCursor()
        # NOTE(review): the table name and note are interpolated straight into
        # the SQL; a note containing a quote character will break this query.
        curs.execute("select model from %s where note='%s'" % (details.persistTblName, details.note))
        message('-> Reconstructing models')
        try:
            blob = curs.fetchone()
        except Exception:
            blob = None
        while blob:
            message(' Building model %d' % len(models))
            blob = blob[0]
            try:
                # NOTE(review): unpickling can execute arbitrary code; only read
                # models from trusted databases.
                models.append(cPickle.loads(str(blob)))
            except Exception:
                import traceback
                traceback.print_exc()
                print('Model failed')
            else:
                message(' <-Done')
            try:
                blob = curs.fetchone()
            except Exception:
                blob = None
        curs = None
    else:
        for modelName in extras:
            try:
                inF = open(modelName, 'rb')
                try:
                    model = cPickle.load(inF)
                finally:
                    inF.close()
            except Exception:
                import traceback
                print('problems with model %s:' % modelName)
                traceback.print_exc()
            else:
                models.append(model)
    nModels = len(models)
    if not nModels:
        # Without models there is nothing to screen, and the reporting code below
        # would fail on undefined per-model results (e.g. nTrueActives).
        error('no models could be loaded')
        sys.exit(-1)

    pickVects = {}
    # Picks needed by each model to find half of the true actives. The 1e8
    # sentinel remains when a model never reaches that point.
    halfwayPts = [1e8] * nModels
    for whichModel, model in enumerate(models):
        # NOTE(review): tmpD aliases dataSet (no copy), so with -S the shuffles
        # accumulate from one model to the next -- confirm this is intended.
        tmpD = dataSet
        try:
            seed = model._randomSeed
        except AttributeError:
            pass
        else:
            DataUtils.InitRandomNumbers(seed)
        if details.shuffleActivities:
            DataUtils.RandomizeActivities(tmpD,
                                          shuffle=1)
        if hasattr(model, '_splitFrac') and (details.doHoldout or details.doTraining):
            # reproduce the model's train/hold-out split (same seed as at build time)
            trainIdx, testIdx = SplitData.SplitIndices(tmpD.GetNPts(), model._splitFrac,
                                                       silent=1)
            if details.filterFrac != 0.0:
                trainFilt, temp = DataUtils.FilterData(tmpD, details.filterVal,
                                                       details.filterFrac, -1,
                                                       indicesToUse=trainIdx,
                                                       indicesOnly=1)
                testIdx += temp
                trainIdx = trainFilt
            if details.doTraining:
                testIdx, trainIdx = trainIdx, testIdx
        else:
            testIdx = range(tmpD.GetNPts())

        message('screening %d examples' % (len(testIdx)))
        nTrueActives, screenRes = ScreenModel(model, descs, tmpD, picking=details.activeTgt,
                                              indices=testIdx,
                                              errorEstimate=details.errorEstimate)
        message('accumulating')
        runningCounts = AccumulateCounts(screenRes,
                                         sortIt=details.sortIt,
                                         thresh=details.threshold)
        if details.pickBase:
            pickFile = open('%s.%d.picks' % (details.pickBase, whichModel + 1), 'w+')
        else:
            pickFile = None

        for i, entry in enumerate(runningCounts):
            selPts[i] = entry[0]
            final[i][0] += entry[1]
            final[i][1] += entry[2]
            pickVects.setdefault(i, []).append(entry[1])
            counts[i] += 1
            if pickFile is not None:
                pickFile.write('%s\n' % (entry[0]))
            # explicit float division: under Python 2 "/" floored odd counts,
            # declaring the halfway point reached before half the actives were found
            if entry[1] >= nTrueActives / 2.0 and entry[2] < halfwayPts[whichModel]:
                halfwayPts[whichModel] = entry[2]
        if pickFile is not None:
            pickFile.close()
        message('Halfway point: %d\n' % halfwayPts[whichModel])

    if details.plotFile:
        # NOTE: nTrueActives is the count from the last model screened
        MakePlot(details, final, counts, pickVects, nModels, nTrueActs=nTrueActives)
    else:
        if nModels > 1:
            print('#Index\tAvg_num_correct\tConf90Pct\tAvg_num_picked\tNum_picks\tlast_selection')
        else:
            print('#Index\tAvg_num_correct\tAvg_num_picked\tNum_picks\tlast_selection')

        i = 0
        while i < nPts and counts[i] != 0:
            if nModels > 1:
                mean, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]), level=90)
                print('%d\t%f\t%f\t%f\t%d\t%s' % (i + 1, final[i][0] / counts[i], confInterval,
                                                  final[i][1] / counts[i],
                                                  counts[i], str(selPts[i])))
            else:
                print('%d\t%f\t%f\t%d\t%s' % (i + 1, final[i][0] / counts[i],
                                              final[i][1] / counts[i],
                                              counts[i], str(selPts[i])))
            i += 1

    mean, sd = Stats.MeanAndDev(halfwayPts)
    print('Halfway point: %.2f(%.2f)' % (mean, sd))
|
|
|
|
|