mirror of
https://github.com/rdkit/rdkit.git
synced 2026-06-05 22:04:27 +08:00
299 lines
8.7 KiB
Python
Executable File
299 lines
8.7 KiB
Python
Executable File
#
|
|
# Copyright (C) 2003-2004 Rational Discovery LLC
|
|
# All Rights Reserved
|
|
#
|
|
import RDConfig,RDRandom
|
|
from Numeric import *
|
|
import RandomArray
|
|
import types,os.path,sys
|
|
# the built-in list and tuple types, grouped for use in type checks
SeqTypes=(types.ListType,types.TupleType)
|
|
|
|
def SplitIndices(nPts,frac,silent=1,legacy=0,replacement=0):
    """ splits the indices of a data set into 2 pieces

    **Arguments**

      - nPts: the total number of points

      - frac: the fraction of the data to be put in the first data set

      - silent: (optional) toggles display of stats

      - legacy: (optional) use the legacy splitting approach

      - replacement: (optional) use selection with replacement
        (takes precedence over _legacy_ if both are set)

    **Returns**

      a 2-tuple containing the two sets of indices: the first
      ("training") set, then the hold-out ("test") set.

    **Raises**

      ValueError if _frac_ is not between 0.0 and 1.0

    **Notes**

      - the _legacy_ splitting approach uses randomly-generated floats
        and compares them to _frac_.  This is provided for
        backwards-compatibility reasons.

      - the default splitting approach uses a random permutation of
        indices which is split into two parts.

      - selection with replacement can generate duplicates.


    **Usage**:

    We'll start with a set of indices and pick from them using
    the three different approaches:
    >>> from ML.Data import DataUtils

    The base approach always returns the same number of compounds in
    each set and has no duplicates:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5)
    >>> train
    [9, 4, 3, 8, 2]
    >>> test
    [7, 6, 1, 5, 0]
    >>> train,test = SplitIndices(10,.5)
    >>> train
    [4, 6, 8, 2, 7]
    >>> test
    [5, 9, 0, 3, 1]

    The legacy approach can return varying numbers, but still has no
    duplicates.  Note the indices come back ordered:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5,legacy=1)
    >>> train
    [0, 1, 2, 3, 4, 7, 9]
    >>> test
    [5, 6, 8]
    >>> train,test = SplitIndices(10,.5,legacy=1)
    >>> train
    [4, 5, 7, 8, 9]
    >>> test
    [0, 1, 2, 3, 6]

    The replacement approach returns a fixed number in the training set,
    a variable number in the test set and can contain duplicates in the
    training set.
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5,replacement=1)
    >>> train
    [1, 1, 3, 0, 1]
    >>> test
    [2, 4, 5, 6, 7, 8, 9]
    >>> train,test = SplitIndices(10,.5,replacement=1)
    >>> train
    [9, 5, 4, 8, 0]
    >>> test
    [1, 2, 3, 6, 7]

    """
    if frac<0. or frac>1.:
        raise ValueError('frac must be between 0.0 and 1.0 (frac=%f)'%(frac))

    if replacement:
        nTrain = int(nPts*frac)
        resData = [None]*nTrain
        for i in range(nTrain):
            val = int(RDRandom.random()*nPts)
            # clamp so that the drawn index always stays below nPts
            if val==nPts:
                val = nPts-1
            resData[i] = val
        # everything that was never drawn goes to the hold-out set; a set
        # keeps the membership test O(1) instead of scanning the list
        picked = set(resData)
        resTest = [i for i in range(nPts) if i not in picked]
    elif legacy:
        resData = []
        resTest = []
        # one random draw per point decides which side it lands on
        for i in range(nPts):
            val = RDRandom.random()
            if val < frac:
                resData.append(i)
            else:
                resTest.append(i)
    else:
        # shuffle the indices once and cut the permutation in two
        perm = RandomArray.permutation(nPts)
        nTrain = int(nPts*frac)

        resData = list(perm[:nTrain])
        resTest = list(perm[nTrain:])

    if not silent:
        # single-argument parenthesised form prints identically on
        # python 2 and is also valid python 3 syntax
        print('Training with %d (of %d) points.'%(len(resData),nPts))
        print('\t%d points are in the hold-out set.'%(len(resTest)))
    return resData,resTest
|
|
|
|
|
|
def SplitDataSet(data,frac,silent=0):
    """ splits a data set into two pieces

    **Arguments**

      - data: a list of examples to be split

      - frac: the fraction of the data to be put in the first data set

      - silent: controls the amount of visual noise produced.

    **Returns**

      a 2-tuple containing the two new data sets: the first
      ("training") set, then the hold-out set.

    **Raises**

      ValueError if _frac_ is not between 0.0 and 1.0

    """
    # only fractions outside of [0,1] are rejected
    if frac<0. or frac>1.:
        raise ValueError('frac must be between 0.0 and 1.0')

    nOrig = len(data)
    # split the positions, then map them back onto the examples
    train,test = SplitIndices(nOrig,frac,silent=1)
    resData = [data[x] for x in train]
    resTest = [data[x] for x in test]

    if not silent:
        # single-argument parenthesised form prints identically on
        # python 2 and is also valid python 3 syntax
        print('Training with %d (of %d) points.'%(len(resData),nOrig))
        print('\t%d points are in the hold-out set.'%(len(resTest)))
    return resData,resTest
|
|
|
|
|
|
def SplitDbData(conn,fracs,table='',fields='*',where='',join='',
                labelCol='',
                useActs=0,nActs=2,actCol='',actBounds=[],
                silent=0):
    """ "splits" a data set held in a DB by returning lists of ids

    **Arguments**:

      - conn: a DbConnect object

      - fracs: the split fraction. This can optionally be specified as a
        sequence (list or tuple) with a different fraction for each
        activity value.  If a sequence is provided and _useActs_ is not
        set, only its first element is used.

      - table,fields,where,join: (optional) SQL query parameters.
        _table_ defaults to the connection's table; _where_ may be given
        with or without the leading "where" keyword.

      - labelCol: (optional) not used by this function

      - useActs: (optional) toggles splitting based on activities
        (ensuring that a given fraction of each activity class ends
        up in the hold-out set)
        Defaults to 0

      - nActs: (optional) number of possible activity values, only
        used if _useActs_ is nonzero
        Defaults to 2

      - actCol: (optional) name of the activity column
        Defaults to use the last column returned by the query

      - actBounds: (optional) sequence of activity bounds
        (for cases where the activity isn't quantized in the db)
        Defaults to an empty sequence

      - silent: controls the amount of visual noise produced
        (this function currently produces no output)

    **Returns**

      a 2-tuple with the list of ids picked for the first ("training")
      set and the list of ids in the hold-out set.

    **Raises**

      ValueError if the length of _actBounds_ is not _nActs_-1, if a
      fraction is not between 0.0 and 1.0, or if _useActs_ is set and
      fewer than _nActs_ fractions are provided.

    **Usage**:

    Set up the db connection, the simple tables we're using have actives with even
    ids and inactives with odd ids:
    >>> from ML.Data import DataUtils
    >>> from Dbase.DbConnection import DbConnect
    >>> if not RDConfig.usePgSQL:
    ...   fName = os.path.join(RDConfig.RDCodeDir,'ML','Data','test_data','data.gdb')
    ... else:
    ...   fName = '::RDTests'
    >>> conn = DbConnect(fName)

    Pull a set of points from a simple table... take 33% of all points:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,1./3.,'basic_2class')
    >>> train
    ['id-10', 'id-5', 'id-4', 'id-9']

    ...take 50% of actives and 50% of inactives:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'basic_2class',useActs=1)
    >>> train
    ['id-9', 'id-7', 'id-5', 'id-8', 'id-6', 'id-4']

    Notice how the results came out sorted by activity

    We can be asymmetrical: take 33% of actives and 50% of inactives:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,[.5,1./3.],'basic_2class',useActs=1)
    >>> train
    ['id-9', 'id-7', 'id-5', 'id-8', 'id-6']

    And we can pull from tables with non-quantized activities by providing
    activity quantization bounds:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'float_2class',useActs=1,actBounds=[1.0])
    >>> train
    ['id-9', 'id-7', 'id-5', 'id-8', 'id-6', 'id-4']

    """
    if not table:
        table=conn.tableName
    if actBounds and len(actBounds)!=nActs-1:
        raise ValueError('activity bounds list length incorrect')

    # validate all fractions before touching the database
    fracsAreSeq = isinstance(fracs,(list,tuple))
    if useActs:
        if not fracsAreSeq:
            # a single value applies to every activity class
            fracs = tuple([fracs]*nActs)
        elif len(fracs)<nActs:
            raise ValueError('a fraction is required for each of the %d activity values'%nActs)
        for frac in fracs:
            if frac<0.0 or frac>1.0:
                raise ValueError('fractions must be between 0.0 and 1.0')
    else:
        if fracsAreSeq:
            frac = fracs[0]
        else:
            frac = fracs
        if frac<0.0 or frac>1.0:
            raise ValueError('fractions must be between 0.0 and 1.0')

    # normalize the caller's filter so that it is either empty or starts
    # with the "where" keyword (the form passed to conn.GetData below)
    whereClause = (where or '').strip()
    if whereClause and not whereClause.lower().startswith('where'):
        whereClause = 'where '+whereClause

    # start by getting the name of the ID column:
    colNames = conn.GetColumnNames(table=table,what=fields,join=join)
    idCol = colNames[0]

    if not useActs:
        # get the IDS, honoring the caller's filter if there is one:
        if whereClause:
            d = conn.GetData(table=table,fields=idCol,join=join,where=whereClause)
        else:
            d = conn.GetData(table=table,fields=idCol,join=join)
        ids = [x[0] for x in d]
        nRes = len(ids)
        train,test = SplitIndices(nRes,frac,silent=1)
        trainPts = [ids[x] for x in train]
        testPts = [ids[x] for x in test]
    else:
        trainPts = []
        testPts = []
        if not actCol:
            actCol = colNames[-1]
        # the per-activity condition is appended to this prefix
        if whereClause:
            whereBase = whereClause+' and '
        else:
            whereBase = 'where '
        for act in range(nActs):
            frac = fracs[act]
            if not actBounds:
                # quantized activities: match the value directly
                whereTxt = whereBase + '%s=%d'%(actCol,act)
            else:
                # non-quantized activities: select the half-open interval
                # [actBounds[act-1],actBounds[act]), which is unbounded
                # below for the first class and above for the last one
                whereTxt = whereBase
                if act!=0:
                    whereTxt += '%s>=%f '%(actCol,actBounds[act-1])
                if act < nActs-1:
                    if act!=0:
                        whereTxt += 'and '
                    whereTxt += '%s<%f'%(actCol,actBounds[act])
            d = conn.GetData(table=table,fields=idCol,join=join,where=whereTxt)
            ids = [x[0] for x in d]
            nRes = len(ids)
            train,test = SplitIndices(nRes,frac,silent=1)
            trainPts.extend([ids[x] for x in train])
            testPts.extend([ids[x] for x in test])

    return trainPts,testPts
|
|
|
|
def _test():
    """ runs the doctests of the module registered as "__main__"

    **Returns**

      whatever doctest.testmod() returns for that module

    """
    import sys
    import doctest
    mainModule = sys.modules["__main__"]
    return doctest.testmod(mainModule)
|
|
|
|
if __name__ == '__main__':
    import sys
    # run the doctests and use the failure count as the exit status
    results = _test()
    nFailed,nTried = results
    sys.exit(nFailed)
|