DecisionTree_python
# coding: utf-8
"""ID3 decision tree on a numeric, whitespace-separated data set.

Every data row holds the feature values followed by the class label in the
last column.  Trees are nested dicts {feature_name: {value: subtree}} whose
leaves are class labels.
"""
from numpy import *
from math import *  # imported after numpy so `log(x, 2)` below is math.log
import operator


def file2matrix(filename):
    """Load a whitespace-separated numeric data file.

    Each non-blank line is one sample: its feature values followed by the class
    label in the last column.  Blank lines (e.g. a trailing newline) are skipped.

    Returns:
        (rematrix, label): a float array with one row per non-blank line, and a
        fresh list with the names of the feature columns (watermelon feature set).
    """
    with open(filename) as fr:
        rows = [line.split() for line in fr if line.strip()]
    # Width comes from the data; 7 (six features + class) was the old fixed width.
    ncols = len(rows[0]) if rows else 7
    rematrix = zeros((len(rows), ncols))
    for index, fields in enumerate(rows):
        rematrix[index, :] = [float(v) for v in fields]
    label = ["seze", "gendi", "qiaoshen", "wenli", "qibu", "chugan"]  # watermelon feature set
    return rematrix, label


def singlesplit(data, axis, value):
    """Return the rows whose column `axis` equals `value`, with that column removed.

    The class label (last column) is kept, so every returned row is one column
    shorter than its source row.  Rows are new lists; `data` is not modified.
    """
    subset = []
    for feat in data:
        if feat[axis] == value:
            row = list(feat)  # list() so numpy rows concatenate instead of adding elementwise
            subset.append(row[:axis] + row[axis + 1:])
    return subset


def allsplit(data):
    """Return the index of the feature column with the highest information gain.

    Ties go to the lowest index.  If no feature lowers the entropy, feature 0 is
    returned so the tree can keep splitting (features may only help in
    combination, as in XOR).  Returns -1 if `data` is empty or has no feature
    columns.
    """
    if len(data) == 0:
        return -1
    nfeatures = len(data[0]) - 1  # last column is the class label
    if nfeatures < 1:
        return -1
    baseEntry = calcshannon(data)
    ordermax = 0.0
    bestfuture = 0  # never -1 here: -1 would make createTree split on the class column
    for i in range(nfeatures):
        uniq = set(example[i] for example in data)  # distinct values of feature i
        newEntry = 0.0
        for j in uniq:
            cooldata = singlesplit(data, i, j)
            prob = len(cooldata) / float(len(data))
            newEntry += prob * calcshannon(cooldata)
        info = baseEntry - newEntry
        if info > ordermax:
            ordermax = info
            bestfuture = i
    return bestfuture


def calcshannon(data):
    """Return the Shannon entropy, in bits, of the class labels (last column) of `data`.

    An empty data set has entropy 0.0.
    """
    simplenum = len(data)
    tempdict = {}
    for line in data:
        tail = line[-1]
        tempdict[tail] = tempdict.get(tail, 0) + 1
    shannonEntry = 0.0
    for count in tempdict.values():
        prob = count / float(simplenum)
        shannonEntry -= prob * log(prob, 2)
    return shannonEntry


def selectbigger(label):
    """Return the most frequent value in the non-empty sequence `label` (majority vote).

    Among tied values the one sorted() leaves first wins (on Python 3.7+, the
    value seen first).
    """
    calcdict = {}
    for line in label:
        calcdict[line] = calcdict.get(line, 0) + 1
    Getsorted = sorted(calcdict.items(), key=operator.itemgetter(1), reverse=True)
    return Getsorted[0][0]


def createTree(data, label):
    """Recursively build an ID3 decision tree.

    Args:
        data: rows of feature values followed by the class label.
        label: names of the feature columns, in column order.  Not modified.

    Returns:
        A class label (leaf) or a nested dict {feature_name: {value: subtree}}.
    """
    labellist = [tt[-1] for tt in data]
    if labellist.count(labellist[0]) == len(labellist):  # all samples share one class
        return labellist[0]
    if len(data[0]) == 1:  # no feature columns left: majority vote
        return selectbigger(labellist)
    bestfuture = allsplit(data)
    bestlabel = label[bestfuture]
    tree = {bestlabel: {}}
    # Build a reduced copy rather than deleting in place: an in-place delete
    # would empty the caller's list and leak deletions between sibling branches.
    remaining = list(label)
    sublabel = remaining[:bestfuture] + remaining[bestfuture + 1:]
    for value in set(tt[bestfuture] for tt in data):
        tree[bestlabel][value] = createTree(singlesplit(data, bestfuture, value), sublabel)
    return tree


def classifier(inputree, featurelabel, clsdata):
    """Classify one sample by walking a tree produced by createTree.

    Args:
        inputree: the tree (a nested dict, or a bare class label when the
            training data contained a single class).
        featurelabel: feature names in the column order of `clsdata`.
        clsdata: the sample's feature values (without a class label).

    Returns:
        The predicted class label, or '' when the sample has a feature value
        that no branch of the tree covers.
    """
    if not isinstance(inputree, dict):  # leaf: the tree is just a class label
        return inputree
    firststr = next(iter(inputree))  # each node dict has exactly one feature name
    secondict = inputree[firststr]
    featindex = featurelabel.index(firststr)
    for key, subtree in secondict.items():
        if clsdata[featindex] == key:
            return classifier(subtree, featurelabel, clsdata)
    return ''


if __name__ == "__main__":
    dataset, label = file2matrix("out.txt")
    mytree = createTree(dataset, label)  # label is left intact, so no re-read is needed
    print(classifier(mytree, label, [3, 1, 1, 3, 3, 1]))
版权声明:本文为semen原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。