Files
rdkit/Python/ML/Cluster/Clusters.py

296 lines
7.2 KiB
Python
Executable File

# $Id$
#
# Copyright (C) 2001-2006 greg Landrum
#
# @@ All Rights Reserved @@
#
""" contains the Cluster class for representing hierarchical cluster trees
"""
from Numeric import *
CMPTOL=1e-6
class Cluster(object):
"""a class for storing clusters/data
**General Remarks**
- It is assumed that the bottom of any cluster hierarchy tree is composed of
the individual data points which were clustered.
- Clusters objects store the following pieces of data, most are
accessible via standard Setters/Getters:
- Children: *Not Settable*, the list of children. You can add children
with the _AddChild()_ and _AddChildren()_ methods.
**Note** this can be of arbitrary length,
but the current algorithms I have only produce trees with two children
per cluster
- Metric: the metric for this cluster (i.e. how far apart its children are)
- Index: the order in which this cluster was generated
- Points: *Not Settable*, the list of original points in this cluster
(calculated recursively from the children)
- PointsPositions: *Not Settable*, the list of positions of the original
points in this cluster (calculated recursively from the children)
- Position: the location of the cluster **Note** for a cluster this
probably means the location of the average of all the Points which are
its children.
- Data: a data field. This is used with the original points to store their
data value (i.e. the value we're using to classify)
- Name: the name of this cluster
"""
def __init__(self,metric=0.0,children=None,position=None,index=-1,name=None,data=None):
"""Constructor
**Arguments**
see the class documentation for the meanings of these arguments
*my wrists are tired*
"""
if children is None:
children = []
if position is None:
position = []
self.metric = metric
self.children = children
self._UpdateLength()
self.pos = position
self.index = index
self.name = name
self._points = None
self._pointsPositions = None
self.data = data
def SetMetric(self,metric):
self.metric = metric
def GetMetric(self):
return self.metric
def SetIndex(self,index):
self.index = index
def GetIndex(self):
return self.index
def SetPosition(self,pos):
self.pos = pos
def GetPosition(self):
return self.pos
def GetPointsPositions(self):
if self._pointsPositions is not None:
return self._pointsPositions
else:
self._GenPoints()
return self._pointsPositions
def GetPoints(self):
if self._points is not None:
return self._points
else:
self._GenPoints()
return self._points
def FindSubtree(self,index):
""" finds and returns the subtree with a particular index
"""
res = None
if index == self.index:
res = self
else:
for child in self.children:
res = child.FindSubtree(index)
if res:
break
return res
def _GenPoints(self):
""" Generates the _Points_ and _PointsPositions_ lists
*intended for internal use*
"""
if len(self) == 1:
self._points = [self]
self._pointsPositions = [self.GetPosition()]
return self._points
else:
res = []
children = self.GetChildren()
children.sort(lambda x,y:cmp(len(y),len(x)))
for child in children:
res += child.GetPoints()
self._points=res
self._pointsPositions = map(lambda x:x.GetPosition(),res)
def AddChild(self,child):
"""Adds a child to our list
**Arguments**
- child: a Cluster
"""
self.children.append(child)
self._GenPoints()
self._UpdateLength()
def AddChildren(self,children):
"""Adds a bunch of children to our list
**Arguments**
- children: a list of Clusters
"""
self.children += children
self._GenPoints()
self._UpdateLength()
def RemoveChild(self,child):
"""Removes a child from our list
**Arguments**
- child: a Cluster
"""
self.children.remove(child)
self._UpdateLength()
def GetChildren(self):
self.children.sort(lambda x,y:cmp(x.GetMetric(),y.GetMetric()))
return self.children
def SetData(self,data):
self.data = data
def GetData(self):
return self.data
def SetName(self,name):
self.name = name
def GetName(self):
if self.name is None:
return 'Cluster(%d)'%(self.GetIndex())
else:
return self.name
def Print(self,level=0,showData=0,offset='\t'):
if not showData or self.GetData() is None:
print '%s%s%s Metric: %f'%(' '*level,self.GetName(),offset,self.GetMetric())
else:
print '%s%s%s Data: %f\t Metric: %f'%(' '*level,self.GetName(),offset,self.GetData(),self.GetMetric())
for child in self.GetChildren():
child.Print(level=level+1,showData=showData,offset=offset)
def Compare(self,other,ignoreExtras=1):
""" not as choosy as self==other
"""
if cmp(type(self),type(other)):
return cmp(type(self),type(other))
if cmp(len(self),len(other)):
return cmp(len(self),len(other))
if not ignoreExtras:
m1,m2=self.GetMetric(),other.GetMetric()
if abs(m1-m2)>CMPTOL:
return cmp(m1,m2)
if cmp(self.GetName(),other.GetName()):
return cmp(self.GetName(),other.GetName())
sP = self.GetPosition()
oP = other.GetPosition()
try:
r = cmp(len(sP),len(oP))
except:
pass
else:
if r:
return r
try:
r = cmp(sP,oP)
except:
r = sum(sP-oP)
if r:
return r
c1,c2=self.GetChildren(),other.GetChildren()
if cmp(len(c1),len(c2)):
return cmp(len(c1),len(c2))
for i in range(len(c1)):
t = c1[i].Compare(c2[i],ignoreExtras=ignoreExtras)
if t:
return t
return 0
def _UpdateLength(self):
""" updates our length
*intended for internal use*
"""
self._len = reduce(lambda x,y: len(y)+x,self.children,1)
def IsTerminal(self):
return self._len<=1
def __len__(self):
""" allows _len(cluster)_ to work
"""
return self._len
def __cmp__(self,other):
""" allows _cluster1 == cluster2_ to work
"""
if cmp(type(self),type(other)):
return cmp(type(self),type(other))
m1,m2=self.GetMetric(),other.GetMetric()
if abs(m1-m2)>CMPTOL:
return cmp(m1,m2)
if cmp(self.GetName(),other.GetName()):
return cmp(self.GetName(),other.GetName())
c1,c2=self.GetChildren(),other.GetChildren()
return cmp(c1,c2)
if __name__ == '__main__':
from ML.Cluster import ClusterUtils
root = Cluster(index=1,metric=1000)
c1 = Cluster(index=10,metric=100)
c1.AddChild(Cluster(index=30,metric=10))
c1.AddChild(Cluster(index=31,metric=10))
c1.AddChild(Cluster(index=32,metric=10))
c2 = Cluster(index=11,metric=100)
c2.AddChild(Cluster(index=40,metric=10))
c2.AddChild(Cluster(index=41,metric=10))
root.AddChild(c1)
root.AddChild(c2)
nodes = ClusterUtils.GetNodeList(root)
indices = [x.GetIndex() for x in nodes]
print 'XXX:',indices