# Mirror of https://github.com/rdkit/rdkit.git
# (synced 2026-06-05 22:04:27 +08:00; 228 lines, 6.3 KiB, Python, executable file)
# $Id: TreeVis.py 5033 2006-03-02 19:24:02Z glandrum $
|
|
#
|
|
# Copyright (C) 2002,2003 Greg Landrum and Rational Discovery LLC
|
|
# All Rights Reserved
|
|
#
|
|
""" functionality for drawing trees on sping canvases
|
|
|
|
"""
|
|
from sping import pid as piddle
|
|
import math
|
|
|
|
class VisOpts(object):
    """Drawing options for the tree renderer.

    Every option is a class attribute, so all instances share these
    defaults until an attribute is reassigned on a particular instance.
    """
    # --- node circles ------------------------------------------------
    circRad = 10       # radius of internal nodes (and of leaves when not scaled)
    minCircRad = 4     # smallest leaf radius when leaf scaling is on
    maxCircRad = 16    # largest leaf radius when leaf scaling is on

    # --- colours and strokes -----------------------------------------
    circColor = piddle.Color(0.6, 0.6, 0.9)           # fill of internal nodes
    terminalEmptyColor = piddle.Color(0.8, 0.8, 0.2)  # outline of leaves without examples
    terminalOnColor = piddle.Color(0.8, 0.8, 0.8)     # leaf fill at the top of the label ramp
    terminalOffColor = piddle.Color(0.2, 0.2, 0.2)    # leaf fill at label 0
    outlineColor = piddle.transparent
    lineColor = piddle.Color(0, 0, 0)
    lineWidth = 2

    # --- layout -------------------------------------------------------
    horizOffset = 10   # horizontal padding added to circRad for each leaf
    vertOffset = 50    # vertical distance between tree levels

    # --- text ---------------------------------------------------------
    labelFont = piddle.Font(face='helvetica', size=10)

    # --- highlighting (not read by the drawing code in this module;
    #     presumably used by callers that highlight nodes -- confirm) ---
    highlightColor = piddle.Color(1.0, 1.0, 0.4)
    highlightWidth = 2


# Module-wide option instance read by the drawing functions.
visOpts = VisOpts()
|
|
|
|
def CalcTreeNodeSizes(node):
    """Recursively annotate *node* and its whole subtree with size data.

    Every visited node gets three attributes:
      totNChildren -- number of leaves (childless nodes) in its subtree;
                      a childless node counts itself, giving 1.
      nLevelsBelow -- depth of its subtree; 1 for a childless node.
      nExamples    -- len(node.GetExamples()).
    """
    kids = node.GetChildren()
    if len(kids) == 0:
        leafCount, deepest = 1, 0
    else:
        leafCount, deepest = 0, 0
        for kid in kids:
            # children must be annotated before their values can be summed
            CalcTreeNodeSizes(kid)
            leafCount += kid.totNChildren
            deepest = max(deepest, kid.nLevelsBelow)

    node.nExamples = len(node.GetExamples())
    node.totNChildren = leafCount
    node.nLevelsBelow = deepest + 1
|
|
|
|
def _ExampleCounter(node, min, max):
    """Return the (low, high) range of ``nExamples`` over terminal nodes.

    *min* and *max* are the running bounds: the returned low is the smaller
    of *min* and every terminal count under *node*, and the returned high is
    the larger of *max* and those counts.  A non-terminal node without
    children returns the bounds unchanged.  ``nExamples`` must already be
    set on the terminal nodes.
    """
    lo, hi = min, max
    if node.GetTerminal():
        count = node.nExamples
        if count < lo:
            lo = count
        if count > hi:
            hi = count
        return lo, hi

    for kid in node.GetChildren():
        kidLo, kidHi = _ExampleCounter(kid, lo, hi)
        lo = kidLo if kidLo < lo else lo
        hi = kidHi if kidHi > hi else hi
    return lo, hi
|
|
|
|
def _ApplyNodeScales(node, min, max):
    """Set ``_scaleLoc`` on every terminal node under *node*.

    A terminal node's ``nExamples`` is mapped linearly so that *min* -> 0.0
    and *max* -> 1.0; when *min* equals *max* every terminal node gets 0.5.
    Non-terminal nodes are only traversed, never annotated.
    """
    if not node.GetTerminal():
        for kid in node.GetChildren():
            _ApplyNodeScales(kid, min, max)
        return

    if max == min:
        # degenerate range: put every leaf in the middle of the scale
        node._scaleLoc = .5
    else:
        node._scaleLoc = float(node.nExamples - min) / (max - min)
|
|
|
|
def SetNodeScales(node):
    """Normalise terminal-node example counts for leaf scaling.

    Finds the smallest and largest ``nExamples`` over all terminal nodes
    under *node* and stores that pair as ``node._scales``.  It then sets each
    terminal node's ``_scaleLoc`` to its count mapped linearly onto [0, 1]
    (0.5 for every leaf when all counts are equal).

    ``nExamples`` is only set by CalcTreeNodeSizes.  That function is run
    here whenever the tree's sizes are missing or have been reset
    (``totNChildren`` absent or None), so this works on a freshly built tree.
    """
    # Without this, a fresh tree raises AttributeError on ``nExamples``,
    # e.g. DrawTree(..., scaleLeaves=True) calls us before sizes exist.
    if getattr(node, 'totNChildren', None) is None:
        CalcTreeNodeSizes(node)

    # local names avoid shadowing the min/max builtins
    lo, hi = _ExampleCounter(node, 1e8, -1e8)
    node._scales = lo, hi
    _ApplyNodeScales(node, lo, hi)
|
|
|
|
|
|
def DrawTreeNode(node, loc, canvas, nRes=2, scaleLeaves=False, showPurity=False):
    """Recursively draw *node* and all of its descendants on *canvas*.

    The edge to each child is drawn before that child's subtree.  The node's
    own circle and label are drawn last, so they sit on top of those edges.

    Args:
      node: tree node providing GetChildren(), GetTerminal(), GetLabel()
        and GetExamples().
      loc: (x, y) canvas coordinates of this node's centre.
      canvas: sping canvas to draw on.
      nRes: number of possible result values.  Terminal fill colour is
        interpolated from ``visOpts.terminalOffColor`` to
        ``visOpts.terminalOnColor`` using ``label / (nRes - 1)``.  With
        ``nRes <= 1`` the off colour is used.
      scaleLeaves: if true, terminal circles are sized between
        ``visOpts.minCircRad`` and ``visOpts.maxCircRad`` by their
        ``_scaleLoc`` attribute (0.5 when that attribute is missing).
      showPurity: if true, each terminal node gets a white line from its
        centre.  The line's angle encodes the fraction of the node's
        examples whose last element equals the node label (1.0 when there
        are no examples).

    Side effects: sets ``node._bBox`` to the (x1, y1, x2, y2) box of the
    drawn circle, and computes subtree sizes via CalcTreeNodeSizes when
    ``node.totNChildren`` is missing or None.
    """
    if getattr(node, 'totNChildren', None) is None:
        CalcTreeNodeSizes(node)

    isTerminal = node.GetTerminal()

    if scaleLeaves and isTerminal:
        # Leaves that SetNodeScales never reached fall back to mid size.
        scaleLoc = getattr(node, '_scaleLoc', 0.5)
        rad = visOpts.minCircRad + scaleLoc * (visOpts.maxCircRad - visOpts.minCircRad)
    else:
        rad = visOpts.circRad

    x1, y1 = loc[0] - rad, loc[1] - rad
    x2, y2 = loc[0] + rad, loc[1] + rad

    if showPurity and isTerminal:
        examples = node.GetExamples()
        nEx = len(examples)
        if nEx:
            tgtVal = int(node.GetLabel())
            nMatch = sum(1 for ex in examples if int(ex[-1]) == tgtVal)
            purity = float(nMatch) / nEx
        else:
            purity = 1.0
        # purity 0 points the line towards +y (down the tree), purity 1
        # towards -y, and 0.5 horizontally to the right.
        angle = purity * math.pi
        pureX = loc[0] + rad * math.sin(angle)
        pureY = loc[1] + rad * math.cos(angle)

    # Each child gets a horizontal span proportional to its leaf count.
    span = visOpts.horizOffset + visOpts.circRad
    # just move down one level
    childY = loc[1] + visOpts.vertOffset
    # left-hand side of the leftmost child's span
    childX = loc[0] - (span * node.totNChildren) / 2
    for child in node.GetChildren():
        halfWidth = (span * child.totNChildren) / 2
        # centre on this child's span
        childX = childX + halfWidth
        childLoc = [childX, childY]
        canvas.drawLine(loc[0], loc[1], childLoc[0], childLoc[1],
                        visOpts.lineColor, visOpts.lineWidth)
        DrawTreeNode(child, childLoc, canvas, nRes=nRes, scaleLeaves=scaleLeaves,
                     showPurity=showPurity)
        # and move over to the leftmost point of the next child
        childX = childX + halfWidth

    if isTerminal:
        lab = node.GetLabel()
        # guard against nRes == 1, which would otherwise divide by zero
        cFac = float(lab) / float(nRes - 1) if nRes > 1 else 0.0
        theColor = (1. - cFac) * visOpts.terminalOffColor + cFac * visOpts.terminalOnColor
        if hasattr(node, 'GetExamples') and node.GetExamples():
            outlColor = visOpts.outlineColor
        else:
            # terminal nodes without examples get a distinct outline
            outlColor = visOpts.terminalEmptyColor
        canvas.drawEllipse(x1, y1, x2, y2,
                           outlColor, visOpts.lineWidth,
                           theColor)
        if showPurity:
            canvas.drawLine(loc[0], loc[1], pureX, pureY, piddle.Color(1, 1, 1), 2)
    else:
        canvas.drawEllipse(x1, y1, x2, y2,
                           visOpts.outlineColor, visOpts.lineWidth,
                           visOpts.circColor)

    # this does not need to be done every time
    canvas.defaultFont = visOpts.labelFont

    labelStr = str(node.GetLabel())
    # centre the label horizontally on the node
    strLoc = (loc[0] - canvas.stringWidth(labelStr) / 2,
              loc[1] + canvas.fontHeight() / 4)
    canvas.drawString(labelStr, strLoc[0], strLoc[1])
    node._bBox = (x1, y1, x2, y2)
|
|
|
|
def CalcTreeWidth(tree):
    """Return the horizontal canvas width needed to draw *tree*.

    Each leaf is allotted ``visOpts.circRad + visOpts.horizOffset`` of
    width, so the result is the leaf count (``tree.totNChildren``) times that
    span.  Subtree sizes are (re)computed with CalcTreeNodeSizes when
    ``totNChildren`` is missing or None.  It is None, for example, after
    ResetTree.
    """
    # ResetTree leaves totNChildren == None rather than deleting it.  An
    # AttributeError-only check would therefore compute None * int.
    if getattr(tree, 'totNChildren', None) is None:
        CalcTreeNodeSizes(tree)
    return tree.totNChildren * (visOpts.circRad + visOpts.horizOffset)
|
|
|
|
def DrawTree(tree, canvas, nRes=2, scaleLeaves=False, allowShrink=True, showPurity=False):
    """Draw *tree* on *canvas*, with the root centred horizontally.

    The root is placed at x = canvas width / 2 and y = ``visOpts.vertOffset``.

    Args:
      tree: root node of the tree to draw.
      canvas: sping canvas; ``canvas.size[0]`` is taken as its width.
      nRes, showPurity: passed through to DrawTreeNode.
      scaleLeaves: if true, SetNodeScales is run on the tree first and leaf
        circles are sized by example count.
      allowShrink: if true, ``visOpts.circRad`` and ``visOpts.horizOffset``
        are repeatedly halved until the tree fits the canvas width.  The
        original values are restored once drawing finishes.
    """
    dims = canvas.size
    loc = (dims[0] / 2, visOpts.vertOffset)
    if scaleLeaves:
        SetNodeScales(tree)

    # Shrinking edits the shared options.  Restore them afterwards so one
    # large tree does not leave every later drawing shrunk.
    origCircRad = visOpts.circRad
    origHorizOffset = visOpts.horizOffset
    try:
        if allowShrink:
            treeWid = CalcTreeWidth(tree)
            while treeWid > dims[0]:
                visOpts.circRad /= 2
                visOpts.horizOffset /= 2
                treeWid = CalcTreeWidth(tree)
        DrawTreeNode(tree, loc, canvas, nRes, scaleLeaves=scaleLeaves,
                     showPurity=showPurity)
    finally:
        visOpts.circRad = origCircRad
        visOpts.horizOffset = origHorizOffset
|
|
|
|
def ResetTree(tree):
    """Clear cached layout data on *tree* and every node below it.

    Sets ``_scales`` and ``totNChildren`` to None on each node, visiting
    nodes parent-first and children left to right.
    """
    pending = [tree]
    while pending:
        current = pending.pop()
        current._scales = None
        current.totNChildren = None
        # push children in reverse so the leftmost one is handled next
        pending.extend(reversed(list(current.GetChildren())))
|
|
|
|
def _simpleTest(canv):
    """Build a small demo tree and draw it on *canv*.

    The tree has a root with one internal child (holding two terminal
    leaves labelled 0 and 1) and one terminal leaf labelled 1.
    """
    from Tree import TreeNode as Node

    root = Node(None, 'r', label='r')
    inner = root.AddChild('l1_1', label='l1_1')
    root.AddChild('l1_2', isTerminal=1, label=1)
    inner.AddChild('l2_1', isTerminal=1, label=0)
    inner.AddChild('l2_2', isTerminal=1, label=1)

    DrawTreeNode(root, (150, visOpts.vertOffset), canv)
|
|
|
|
|
|
if __name__ == '__main__':
    # Render the demo tree into a 300x300 PNG named test.png.
    from sping.PIL.pidPIL import PILCanvas
    demoCanvas = PILCanvas(size=(300, 300), name='test.png')
    _simpleTest(demoCanvas)
    demoCanvas.save()
|