# File: rdkit/Python/ML/DecTree/Forest.py
# Commit 75a79b6327 ("initial import") by Greg Landrum, 2006-05-06 22:20:08 +00:00
# 299 lines, 7.8 KiB, Python, executable file
#
# Copyright (C) 2000 greg Landrum
#
""" code for dealing with forests (collections) of decision trees
**NOTE** This code should be obsolete now that ML.Composite.Composite is up and running.
"""
import cPickle
from Numeric import *
from ML.DecTree import CrossValidate,PruneTree
class Forest(object):
    """a forest of unique decision trees.

    The forest is stored as three parallel lists (treeList, countList,
    errList); entry i of each describes the same tree.

    Adding a tree that is already present (as judged by ==) does not create a
    new entry: that tree's count is incremented and the new error is added to
    its accumulated error.

    typical usage:
      1) grow the forest with AddTree (or Grow) until happy with it
      2) call AverageErrors to turn the accumulated errors into averages
      3) call SortTrees to put things in order by either error or count
    """

    def __init__(self):
        self.treeList = []
        self.errList = []
        self.countList = []
        # per-tree results of the most recent ClassifyExample call
        self.treeVotes = []

    def MakeHistogram(self, eps=0.001):
        """ creates a histogram of error/count pairs

        Walking the trees in their current order, a new bin is started
        whenever a tree's error exceeds the error of the current bin's first
        tree by more than eps; otherwise the tree's count is added to the
        current bin.  The forest should therefore already be sorted in
        ascending order of error (SortTrees with sortOnError set).

        **Arguments**
          - eps: (optional) tolerance used when deciding to start a new bin

        **Returns**
          a list of (error, count) 2-tuples, one per bin; an empty list
          if the forest contains no trees
        """
        if not self.treeList:
            return []
        histo = []
        binErr = self.errList[0]
        binCount = self.countList[0]
        for err, count in zip(self.errList[1:], self.countList[1:]):
            if err - binErr > eps:
                histo.append((binErr, binCount))
                binErr = err
                binCount = count
            else:
                binCount = binCount + count
        # flush the last bin, which the loop above never reaches
        histo.append((binErr, binCount))
        return histo

    def CollectVotes(self, example):
        """ collects votes across every member of the forest for the given example

        **Returns**
          a list with one entry per tree (in treeList order): the result of
          that tree's ClassifyExample for the example
        """
        return [tree.ClassifyExample(example) for tree in self.treeList]

    def ClassifyExample(self, example):
        """ classifies the given example using the entire forest

        **Returns**
          a 2-tuple: (winning result, confidence)

        **Raises**
          ValueError if no votes are cast (the forest is empty or every
          tree count is zero)

        **FIX:** statistics sucks... I'm not seeing an obvious way to get
          the confidence intervals. For that matter, I'm not seeing
          an unobvious way.
          For now, this is just treated as a voting problem: each tree votes
          for its result with a weight equal to its count, and the confidence
          is the fraction (0-1) of the weighted votes that went to the winning
          result.  Ties go to the smallest result.  Results are assumed to be
          mutually comparable (e.g. integer class indices).
        """
        self.treeVotes = self.CollectVotes(example)
        # Tally in a dict keyed by result so that any result value can be
        # counted, independent of the nPossibleVals passed to Grow.
        tally = {}
        for res, count in zip(self.treeVotes, self.countList):
            tally[res] = tally.get(res, 0) + count
        totVotes = sum(tally.values())
        if not totVotes:
            raise ValueError('no votes were cast: the forest is empty or all tree counts are zero')
        # max() keeps the first maximal key it sees, so iterating the keys in
        # sorted order breaks ties in favor of the smallest result
        winner = max(sorted(tally), key=tally.get)
        return winner, float(tally[winner]) / float(totVotes)

    def GetVoteDetails(self):
        """ Returns the details of the last vote the forest conducted

        this is the per-tree result list from the last ClassifyExample call,
        or an empty list if no voting has yet been done
        """
        return self.treeVotes

    def Grow(self, examples, attrs, nPossibleVals, nTries=10, pruneIt=0,
             lessGreedy=0):
        """ Grows the forest by adding trees

        **Arguments**
          - examples: the examples to be used for training
          - attrs: a list of the attributes to be used in training
          - nPossibleVals: a list with the number of possible values each variable
            (as well as the result) can take on
          - nTries: the number of new trees to add
          - pruneIt: a toggle for whether or not the tree should be pruned
          - lessGreedy: toggles the use of a less greedy construction algorithm where
            each possible tree root is used. The best tree from each step is actually
            added to the forest.
        """
        self._nPossible = nPossibleVals
        # report progress roughly every tenth of the run; max() keeps the
        # step nonzero when nTries < 10 (nTries // 10 would be 0)
        reportEvery = max(1, nTries // 10)
        for i in range(nTries):
            tree, frac = CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals,
                                                             silent=1, calcTotalError=1,
                                                             lessGreedy=lessGreedy)
            if pruneIt:
                tree, frac2 = PruneTree.PruneTree(tree, tree.GetTrainingExamples(),
                                                  tree.GetTestExamples(),
                                                  minimizeTestErrorOnly=0)
                # same output as the Python 2 statement `print 'prune: ', frac, frac2`
                print('prune:  %s %s' % (frac, frac2))
                frac = frac2
            self.AddTree(tree, frac)
            if i % reportEvery == 0:
                print('Cycle: % 4d' % (i))

    def Pickle(self, fileName='foo.pkl'):
        """ Writes this forest off to a file so that it can be easily loaded later

        **Arguments**
          fileName is the name of the file to be written
        """
        pFile = open(fileName, 'wb+')
        try:
            cPickle.dump(self, pFile, 1)
        finally:
            # close the handle even if pickling fails
            pFile.close()

    def AddTree(self, tree, error):
        """ Adds a tree to the forest

        If an identical tree (==) is already present, its count is
        incremented and error is added to its accumulated error.

        **Arguments**
          - tree: the new tree
          - error: its error value

        **NOTE:** the errList is run as an accumulator,
          you probably want to call AverageErrors after finishing the forest
        """
        try:
            idx = self.treeList.index(tree)
        except ValueError:
            self.treeList.append(tree)
            self.errList.append(error)
            self.countList.append(1)
        else:
            self.errList[idx] = self.errList[idx] + error
            self.countList[idx] = self.countList[idx] + 1

    def AverageErrors(self):
        """ convert summed error to average error

        Each entry of errList is divided by the matching count, in place.
        Float division is used so integer errors are not truncated.  Calling
        this more than once divides again.
        """
        self.errList = [float(err) / count
                        for err, count in zip(self.errList, self.countList)]

    def SortTrees(self, sortOnError=1):
        """ sorts the list of trees (ascending)

        **Arguments**
          sortOnError: toggles sorting on the trees' errors rather than their counts
        """
        if sortOnError:
            keys = self.errList
        else:
            keys = self.countList
        # Reorder with plain Python lists so the tree objects never have to be
        # packed into (and unpickled from) a Numeric array; ties keep their
        # current relative order.
        order = sorted(range(len(keys)), key=keys.__getitem__)
        self.treeList = [self.treeList[i] for i in order]
        self.countList = [self.countList[i] for i in order]
        self.errList = [self.errList[i] for i in order]

    def GetTree(self, i):
        """ returns the i'th tree """
        return self.treeList[i]

    def SetTree(self, i, val):
        """ replaces the i'th tree """
        self.treeList[i] = val

    def GetCount(self, i):
        """ returns the count of the i'th tree """
        return self.countList[i]

    def SetCount(self, i, val):
        """ sets the count of the i'th tree """
        self.countList[i] = val

    def GetError(self, i):
        """ returns the error of the i'th tree """
        return self.errList[i]

    def SetError(self, i, val):
        """ sets the error of the i'th tree """
        self.errList[i] = val

    def GetDataTuple(self, i):
        """ returns all relevant data about a particular tree in the forest

        **Arguments**
          i: an integer indicating which tree should be returned

        **Returns**
          a 3-tuple consisting of:
            1) the tree
            2) its count
            3) its error
        """
        return (self.treeList[i], self.countList[i], self.errList[i])

    def SetDataTuple(self, i, tup):
        """ sets all relevant data for a particular tree in the forest

        **Arguments**
          - i: an integer indicating which tree's data should be set
          - tup: a 3-tuple consisting of:
            1) the tree
            2) its count
            3) its error
        """
        self.treeList[i], self.countList[i], self.errList[i] = tup

    def GetAllData(self):
        """ Returns everything we know

        **Returns**
          a 3-tuple consisting of:
            1) our list of trees
            2) our list of tree counts
            3) our list of tree errors
        """
        return (self.treeList, self.countList, self.errList)

    def __len__(self):
        """ allows len(forest) to work
        """
        return len(self.treeList)

    def __getitem__(self, which):
        """ allows forest[i] to work. return the data tuple
        """
        return self.GetDataTuple(which)

    def __str__(self):
        """ allows the forest to show itself as a string
        """
        pieces = ['Forest\n']
        for i in range(len(self.treeList)):
            pieces.append(' Tree % 4d: % 5d occurances %%% 5.2f average error\n' %
                          (i, self.countList[i], 100. * self.errList[i]))
        return ''.join(pieces)
if __name__ == '__main__':
    # Quick smoke test: adding the same tree object twice should leave a
    # single forest entry with a count of 2 and, after averaging, an error
    # of 0.5.
    from ML.DecTree import DecTree
    forest = Forest()
    node = DecTree.DecTreeNode(None, 'foo')
    for _ in range(2):
        forest.AddTree(node, 0.5)
    forest.AverageErrors()
    forest.SortTrees()
    print(forest)