Files
rdkit/Python/ML/DecTree/TreeVis.py

228 lines
6.3 KiB
Python
Executable File

# $Id$
#
# Copyright (C) 2002,2003 Greg Landrum and Rational Discovery LLC
# All Rights Reserved
#
""" functionality for drawing trees on sping canvases
"""
from sping import pid as piddle
import math
class VisOpts(object):
circRad = 10
minCircRad = 4
maxCircRad = 16
circColor = piddle.Color(0.6,0.6,0.9)
terminalEmptyColor = piddle.Color(.8,.8,.2)
terminalOnColor = piddle.Color(0.8,0.8,0.8)
terminalOffColor = piddle.Color(0.2,0.2,0.2)
outlineColor = piddle.transparent
lineColor = piddle.Color(0,0,0)
lineWidth = 2
horizOffset = 10
vertOffset = 50
labelFont = piddle.Font(face='helvetica',size=10)
highlightColor = piddle.Color(1.,1.,.4)
highlightWidth = 2
visOpts = VisOpts()
def CalcTreeNodeSizes(node):
"""Recursively calculate the total number of nodes under us.
results are set in node.totNChildren for this node and
everything underneath it.
"""
children = node.GetChildren()
if len(children) > 0:
nHere = 0
nBelow=0
for child in children:
CalcTreeNodeSizes(child)
nHere = nHere + child.totNChildren
if child.nLevelsBelow > nBelow:
nBelow = child.nLevelsBelow
else:
nBelow = 0
nHere = 1
node.nExamples = len(node.GetExamples())
node.totNChildren = nHere
node.nLevelsBelow = nBelow+1
def _ExampleCounter(node,min,max):
if node.GetTerminal():
cnt = node.nExamples
if cnt < min: min = cnt
if cnt > max: max = cnt
else:
for child in node.GetChildren():
provMin,provMax = _ExampleCounter(child,min,max)
if provMin < min: min = provMin
if provMax > max: max = provMax
return min,max
def _ApplyNodeScales(node,min,max):
if node.GetTerminal():
if max!=min:
loc = float(node.nExamples - min)/(max-min)
else:
loc = .5
node._scaleLoc = loc
else:
for child in node.GetChildren():
_ApplyNodeScales(child,min,max)
def SetNodeScales(node):
min,max = 1e8,-1e8
min,max = _ExampleCounter(node,min,max)
node._scales=min,max
_ApplyNodeScales(node,min,max)
def DrawTreeNode(node,loc,canvas,nRes=2,scaleLeaves=False,showPurity=False):
"""Recursively displays the given tree node and all its children on the canvas
"""
try:
nChildren = node.totNChildren
except AttributeError:
nChildren = None
if nChildren is None:
CalcTreeNodeSizes(node)
if not scaleLeaves or not node.GetTerminal():
rad = visOpts.circRad
else:
try:
scaleLoc = node._scaleLoc
except:
scaleLoc = 0.5
rad = visOpts.minCircRad + node._scaleLoc*(visOpts.maxCircRad-visOpts.minCircRad)
x1 = loc[0] - rad
y1 = loc[1] - rad
x2 = loc[0] + rad
y2 = loc[1] + rad
if showPurity and node.GetTerminal():
examples = node.GetExamples()
nEx = len(examples)
if nEx:
tgtVal = int(node.GetLabel())
purity = 0.0
for ex in examples:
if int(ex[-1])==tgtVal:
purity += 1./len(examples)
else:
purity = 1.0
deg = purity*math.pi
xFact = rad*math.sin(deg)
yFact = rad*math.cos(deg)
pureX = loc[0]+xFact
pureY = loc[1]+yFact
children = node.GetChildren()
# just move down one level
childY = loc[1] + visOpts.vertOffset
# this is the left-hand side of the leftmost span
childX = loc[0] - ((visOpts.horizOffset+visOpts.circRad)*node.totNChildren)/2
for i in xrange(len(children)):
# center on this child's space
child = children[i]
halfWidth = ((visOpts.horizOffset+visOpts.circRad)*child.totNChildren)/2
childX = childX + halfWidth
childLoc = [childX,childY]
canvas.drawLine(loc[0],loc[1],childLoc[0],childLoc[1],
visOpts.lineColor,visOpts.lineWidth)
DrawTreeNode(child,childLoc,canvas,nRes=nRes,scaleLeaves=scaleLeaves,
showPurity=showPurity)
# and move over to the leftmost point of the next child
childX = childX + halfWidth
if node.GetTerminal():
lab = node.GetLabel()
cFac = float(lab)/float(nRes-1)
if hasattr(node,'GetExamples') and node.GetExamples():
theColor = (1.-cFac)*visOpts.terminalOffColor + cFac*visOpts.terminalOnColor
outlColor = visOpts.outlineColor
else:
theColor = (1.-cFac)*visOpts.terminalOffColor + cFac*visOpts.terminalOnColor
outlColor = visOpts.terminalEmptyColor
canvas.drawEllipse(x1,y1,x2,y2,
outlColor,visOpts.lineWidth,
theColor)
if showPurity:
canvas.drawLine(loc[0],loc[1],pureX,pureY,piddle.Color(1,1,1),2)
else:
theColor = visOpts.circColor
canvas.drawEllipse(x1,y1,x2,y2,
visOpts.outlineColor,visOpts.lineWidth,
theColor)
# this does not need to be done every time
canvas.defaultFont=visOpts.labelFont
labelStr = str(node.GetLabel())
strLoc = (loc[0] - canvas.stringWidth(labelStr)/2,
loc[1]+canvas.fontHeight()/4)
canvas.drawString(labelStr,strLoc[0],strLoc[1])
node._bBox = (x1,y1,x2,y2)
def CalcTreeWidth(tree):
try:
tree.totNChildren
except AttributeError:
CalcTreeNodeSizes(tree)
totWidth = tree.totNChildren * (visOpts.circRad+visOpts.horizOffset)
return totWidth
def DrawTree(tree,canvas,nRes=2,scaleLeaves=False,allowShrink=True,showPurity=False):
dims = canvas.size
loc = (dims[0]/2,visOpts.vertOffset)
if scaleLeaves:
#try:
# l = tree._scales
#except AttributeError:
# l = None
#if l is None:
SetNodeScales(tree)
if allowShrink:
treeWid = CalcTreeWidth(tree)
while treeWid > dims[0]:
visOpts.circRad /= 2
visOpts.horizOffset /= 2
treeWid = CalcTreeWidth(tree)
DrawTreeNode(tree,loc,canvas,nRes,scaleLeaves=scaleLeaves,
showPurity=showPurity)
def ResetTree(tree):
tree._scales = None
tree.totNChildren = None
for child in tree.GetChildren():
ResetTree(child)
def _simpleTest(canv):
from Tree import TreeNode as Node
root = Node(None,'r',label='r')
c1 = root.AddChild('l1_1',label='l1_1')
c2 = root.AddChild('l1_2',isTerminal=1,label=1)
c3 = c1.AddChild('l2_1',isTerminal=1,label=0)
c4 = c1.AddChild('l2_2',isTerminal=1,label=1)
DrawTreeNode(root,(150,visOpts.vertOffset),canv)
if __name__ == '__main__':
from sping.PIL.pidPIL import PILCanvas
canv = PILCanvas(size=(300,300),name='test.png')
_simpleTest(canv)
canv.save()