Files
rdkit/Python/ML/Data/SplitData.py
Greg Landrum 75a79b6327 initial import
2006-05-06 22:20:08 +00:00

299 lines
8.7 KiB
Python
Executable File

#
# Copyright (C) 2003-2004 Rational Discovery LLC
# All Rights Reserved
#
import RDConfig,RDRandom
from Numeric import *
import RandomArray
import types,os.path,sys
# the Python 2 list and tuple type objects, grouped so that a value's
# type can be tested against both with a single membership check
SeqTypes=(types.ListType,types.TupleType)
def SplitIndices(nPts,frac,silent=1,legacy=0,replacement=0):
    """ splits the indices of a data set into 2 pieces

    **Arguments**

      - nPts: the total number of points

      - frac: the fraction of the data to be put in the first data set

      - silent: (optional) toggles display of stats

      - legacy: (optional) use the legacy splitting approach

      - replacement: (optional) use selection with replacement
        (takes precedence over _legacy_ if both are set)

    **Returns**

      a 2-tuple containing the two sets of indices; the first element
      is the set whose size is controlled by _frac_.

    **Raises**

      ValueError if _frac_ is not in the range [0.0, 1.0]

    **Notes**

      - the _legacy_ splitting approach uses randomly-generated floats
        and compares them to _frac_.  This is provided for
        backwards-compatibility reasons.

      - the default splitting approach uses a random permutation of
        indices which is split into two parts.

      - selection with replacement can generate duplicates.

      - the examples below unpack the result as _test,train_; the first
        name therefore receives the _frac_-sized set.

    **Usage**:

    We'll start with a set of indices and pick from them using
    the three different approaches:

    >>> from ML.Data import DataUtils

    The base approach always returns the same number of compounds in
    each set and has no duplicates:

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> test,train = SplitIndices(10,.5)
    >>> test
    [9, 4, 3, 8, 2]
    >>> train
    [7, 6, 1, 5, 0]
    >>> test,train = SplitIndices(10,.5)
    >>> test
    [4, 6, 8, 2, 7]
    >>> train
    [5, 9, 0, 3, 1]

    The legacy approach can return varying numbers, but still has no
    duplicates.  Note the indices come back ordered:

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> test,train = SplitIndices(10,.5,legacy=1)
    >>> test
    [0, 1, 2, 3, 4, 7, 9]
    >>> train
    [5, 6, 8]
    >>> test,train = SplitIndices(10,.5,legacy=1)
    >>> test
    [4, 5, 7, 8, 9]
    >>> train
    [0, 1, 2, 3, 6]

    The replacement approach returns a fixed number in the training set,
    a variable number in the test set and can contain duplicates in the
    training set.

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> test,train = SplitIndices(10,.5,replacement=1)
    >>> test
    [1, 1, 3, 0, 1]
    >>> train
    [2, 4, 5, 6, 7, 8, 9]
    >>> test,train = SplitIndices(10,.5,replacement=1)
    >>> test
    [9, 5, 4, 8, 0]
    >>> train
    [1, 2, 3, 6, 7]

    """
    if frac<0. or frac > 1.:
        raise ValueError('frac must be between 0.0 and 1.0 (frac=%f)'%(frac))

    if replacement:
        nTrain = int(nPts*frac)
        resData = [None]*nTrain
        for i in range(nTrain):
            val = int(RDRandom.random()*nPts)
            # guard against random() returning exactly 1.0
            if val==nPts:
                val = nPts-1
            resData[i] = val
        # everything that was never drawn goes into the second set; a set
        # keeps each membership test O(1) instead of scanning resData
        picked = set(resData)
        resTest = [i for i in range(nPts) if i not in picked]
    elif legacy:
        resData = []
        resTest = []
        # one random draw per point, so the set sizes vary from call to call
        for i in range(nPts):
            val = RDRandom.random()
            if val < frac:
                resData.append(i)
            else:
                resTest.append(i)
    else:
        # shuffle all the indices once and cut the permutation in two
        perm = RandomArray.permutation(nPts)
        nTrain = int(nPts*frac)
        resData = list(perm[:nTrain])
        resTest = list(perm[nTrain:])

    if not silent:
        # single-argument print(...) behaves the same as the print statement
        print('Training with %d (of %d) points.'%(len(resData),nPts))
        print('\t%d points are in the hold-out set.'%(len(resTest)))
    return resData,resTest
def SplitDataSet(data,frac,silent=0):
    """ splits a data set into two pieces

    **Arguments**

      - data: a list of examples to be split

      - frac: the fraction of the data to be put in the first data set

      - silent: controls the amount of visual noise produced.

    **Returns**

      a 2-tuple containing the two new data sets.

    **Raises**

      ValueError if _frac_ is not in the range [0.0, 1.0]

    """
    # reject only values outside [0,1]; the previous test
    # (frac>0. or frac<1.) was true for every number
    if frac<0. or frac > 1.:
        raise ValueError('frac must be between 0.0 and 1.0')

    nOrig = len(data)
    # the index split is always done silently; stats are reported below
    train,test = SplitIndices(nOrig,frac,silent=1)
    resData = [data[x] for x in train]
    resTest = [data[x] for x in test]

    if not silent:
        print('Training with %d (of %d) points.'%(len(resData),nOrig))
        print('\t%d points are in the hold-out set.'%(len(resTest)))
    return resData,resTest
def SplitDbData(conn,fracs,table='',fields='*',where='',join='',
                labelCol='',
                useActs=0,nActs=2,actCol='',actBounds=[],
                silent=0):
    """ "splits" a data set held in a DB by returning lists of ids

    **Arguments**:

      - conn: a DbConnect object

      - fracs: the split fraction. This can optionally be specified as a
        sequence with a different fraction for each activity value.
        If a sequence is provided and _useActs_ is zero, only the first
        entry is used.

      - table,fields,where,join: (optional) SQL query parameters

      - labelCol: (optional) currently unused

      - useActs: (optional) toggles splitting based on activities
        (so that the requested fraction of each activity class ends
        up in the first list returned)
        Defaults to 0

      - nActs: (optional) number of possible activity values, only
        used if _useActs_ is nonzero
        Defaults to 2

      - actCol: (optional) name of the activity column
        Defaults to use the last column returned by the query

      - actBounds: (optional) sequence of activity bounds
        (for cases where the activity isn't quantized in the db)
        Defaults to an empty sequence

      - silent: currently unused; this function produces no output.

    **Returns**

      a 2-tuple with the list of ids selected by the fraction(s) and
      the list of remaining ids.

    **Raises**

      ValueError if _actBounds_ does not have nActs-1 entries, if a
      fraction is outside [0.0, 1.0], or if _useActs_ is set and
      _fracs_ is a sequence with fewer than _nActs_ entries.

    **Usage**:

    Set up the db connection, the simple tables we're using have actives with even
    ids and inactives with odd ids:

    >>> from ML.Data import DataUtils
    >>> from Dbase.DbConnection import DbConnect
    >>> if not RDConfig.usePgSQL:
    ...   fName = os.path.join(RDConfig.RDCodeDir,'ML','Data','test_data','data.gdb')
    ... else:
    ...   fName = '::RDTests'
    >>> conn = DbConnect(fName)

    Pull a set of points from a simple table... take 33% of all points:

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,1./3.,'basic_2class')
    >>> train
    ['id-10', 'id-5', 'id-4', 'id-9']

    ...take 50% of actives and 50% of inactives:

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'basic_2class',useActs=1)
    >>> train
    ['id-9', 'id-7', 'id-5', 'id-8', 'id-6', 'id-4']

    Notice how the results came out sorted by activity

    We can be asymmetrical: take 33% of actives and 50% of inactives:

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,[.5,1./3.],'basic_2class',useActs=1)
    >>> train
    ['id-9', 'id-7', 'id-5', 'id-8', 'id-6']

    And we can pull from tables with non-quantized activities by providing
    activity quantization bounds:

    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'float_2class',useActs=1,actBounds=[1.0])
    >>> train
    ['id-9', 'id-7', 'id-5', 'id-8', 'id-6', 'id-4']

    """
    # NOTE: the actBounds=[] default is shared between calls, but it is only
    # ever read here, never mutated.
    if not table:
        table=conn.tableName
    if actBounds and len(actBounds)!=nActs-1:
        raise ValueError('activity bounds list length incorrect')

    # ---- validate the fractions up front, before any query is issued ----
    # (list, tuple) are the same objects as types.ListType/types.TupleType
    fracsIsSeq = isinstance(fracs,(list,tuple))
    if useActs:
        if not fracsIsSeq:
            fracs = tuple([fracs]*nActs)
        elif len(fracs)<nActs:
            # otherwise fracs[act] below would fail part-way through the queries
            raise ValueError('a fraction is required for each of the %d activity values'%(nActs))
        for frac in fracs:
            if frac<0.0 or frac>1.0:
                raise ValueError('fractions must be between 0.0 and 1.0')
    else:
        if fracsIsSeq:
            frac = fracs[0]
        else:
            frac = fracs
        if frac<0.0 or frac>1.0:
            raise ValueError('fractions must be between 0.0 and 1.0')

    # start by getting the name of the ID column:
    colNames = conn.GetColumnNames(table=table,what=fields,join=join)
    idCol = colNames[0]

    if not useActs:
        # get the IDS:
        # NOTE(review): _where_ is not passed to this query, so the caller's
        # filter is ignored when useActs is zero -- confirm whether that is
        # intended before changing it.
        d = conn.GetData(table=table,fields=idCol,join=join)
        ids = [x[0] for x in d]
        nRes = len(ids)
        train,test = SplitIndices(nRes,frac,silent=1)
        trainPts = [ids[x] for x in train]
        testPts = [ids[x] for x in test]
    else:
        trainPts = []
        testPts = []
        if not actCol:
            actCol = colNames[-1]

        # build the common prefix of the per-activity where clause.  The
        # keyword check is case-insensitive and a whitespace-only _where_
        # is treated the same as an empty one.
        userWhere = where.strip()
        if userWhere:
            if not userWhere.lower().startswith('where'):
                userWhere = 'where '+userWhere
            whereBase = userWhere+' and '
        else:
            whereBase = 'where '

        for act in range(nActs):
            frac = fracs[act]
            if not actBounds:
                # activities are stored quantized: select on the value itself
                whereTxt = whereBase + '%s=%d'%(actCol,act)
            else:
                # activity class _act_ covers [actBounds[act-1], actBounds[act]);
                # the first class has no lower bound, the last no upper bound
                limits = []
                if act!=0:
                    limits.append('%s>=%f'%(actCol,actBounds[act-1]))
                if act < nActs-1:
                    limits.append('%s<%f'%(actCol,actBounds[act]))
                whereTxt = whereBase + ' and '.join(limits)
            d = conn.GetData(table=table,fields=idCol,join=join,where=whereTxt)
            ids = [x[0] for x in d]
            nRes = len(ids)
            train,test = SplitIndices(nRes,frac,silent=1)
            trainPts.extend([ids[x] for x in train])
            testPts.extend([ids[x] for x in test])

    return trainPts,testPts
def _test():
import doctest,sys
return doctest.testmod(sys.modules["__main__"])
if __name__ == '__main__':
    import sys
    # when run as a script: execute the doctests and use the number of
    # failures as the process exit status (0 means everything passed)
    nFailed,nTried = _test()
    sys.exit(nFailed)