决策树算法原理及JAVA实现(ID3)
2014-06-21 21:48
483 查看
0 引言
决策树的目的在于构造一颗树像下面这样的树。图1
图2
1. 如何构造呢?
1.1 参考资料。
本例以图2为例,并参考了以下资料。(1) http://www.cnblogs.com/zhangchaoyang/articles/2196631.html
写的东西非常经典。
(2) http://blog.sina.com.cn/s/blog_67bc5aa60100qays.html(3)机器学习(Tom.Mitchell著) 第三章 决策树,里面详细介绍了信息增益的计算,和熵的计算。建议大家参考
1.2 数据集(训练数据集)
outlook | temperature | humidity | windy | play |
sunny | hot | high | FALSE | no |
sunny | hot | high | TRUE | no |
overcast | hot | high | FALSE | yes |
rainy | mild | high | FALSE | yes |
rainy | cool | normal | FALSE | yes |
rainy | cool | normal | TRUE | no |
overcast | cool | normal | TRUE | yes |
sunny | mild | high | FALSE | no |
sunny | cool | normal | FALSE | yes |
rainy | mild | normal | FALSE | yes |
sunny | mild | normal | TRUE | yes |
overcast | mild | high | TRUE | yes |
overcast | hot | normal | FALSE | yes |
rainy | mild | high | TRUE | no |
1.3 构造原则—选信息增益最大的
从图中知,一共有四个属性,outlook temperature humidity windy,首先选哪一个作为树的第一个节点呢。答案是选信息增益越大的作为开始的节点。信息增益的计算公式如下:Entropy(s)是熵,S样本集,Sv是子集。熵的计算公式如下:
举例:
根据以上的数据,我们只知道新的一天打球的概率是9/14,不打的概率是5/14。此时的熵为
对每项指标分别统计:在不同的取值下打球和不打球的次数。
table 2
outlook | temperature | humidity | windy | play | |||||||||
yes | no | yes | no | yes | no | yes | no | yes | no | ||||
sunny | 2 | 3 | hot | 2 | 2 | high | 3 | 4 | FALSE | 6 | 2 | 9 | 5 |
overcast | 4 | 0 | mild | 4 | 2 | normal | 6 | 1 | TRUR | 3 | 3 | ||
rainy | 3 | 2 | cool | 3 | 1 |
outlook=sunny时,2/5的概率打球,3/5的概率不打球。entropy=0.971
outlook=overcast时,entropy=0
outlook=rainy时,entropy=0.971
而根据历史统计数据,outlook取值为sunny、overcast、rainy的概率分别是5/14、4/14、5/14,所以当已知变量outlook的值时,信息熵为:5/14 × 0.971 + 4/14 × 0 + 5/14 × 0.971 = 0.693
这样的话系统熵就从0.940下降到了0.693,信息增溢gain(outlook)为0.940-0.693=0.247
同样可以计算出gain(temperature)=0.029,gain(humidity)=0.152,gain(windy)=0.048。
gain(outlook)最大(即outlook在第一步使系统的信息熵下降得最快),所以决策树的根节点就取outlook。
1.4 为什么选信息增益最大的?
根据参考资料(2)的结论是:信息增益量越大,这个属性作为一棵树的根节点就能使这棵树更简洁(2)1.5 递归:
接下来要确定N1取temperature、humidity还是windy?在已知outlook=sunny的情况,根据历史数据,我们作出类似table 2的一张表,分别计算gain(temperature)、gain(humidity)和gain(windy),选最大者为N1。依此类推,构造决策树。当系统的信息熵降为0时,就没有必要再往下构造决策树了,此时叶子节点都是纯的--这是理想情况。最坏的情况下,决策树的高度为属性(决策变量)的个数,叶子节点不纯(这意味着我们要以一定的概率来作出决策)。
1.6 递归结束的条件:
如果Examples都为正,那么返回label =+ 的单结点树Root ,熵为0 如果Examples都为反,那么返回label =- 的单结点树Root ,熵为0
如果Attributes为空,那么返回单结点树Root,label=Examples中最普遍的
2. 伪代码
3. java 实现
此仅贴主要的代码,源码请到我的github下载:https://github.com/Bellonor/myHadoopProject/tree/master/com.homework/src/sequence/machinelearning/decisiontree/sequence/machinelearning/decisiontree/myid3
package sequence.machinelearning.decisiontree.myid3; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.LinkedList; public class MyID3 { private static LinkedList<String> attribute = new LinkedList<String>(); // 存储属性的名称 private static LinkedList<ArrayList<String>> attributevalue = new LinkedList<ArrayList<String>>(); // 存储每个属性的取值 private static LinkedList<String[]> data = new LinkedList<String[]>();; // 原始数据 public static final String patternString = "@attribute(.*)[{](.*?)[}]"; public static String[] yesNo; public static TreeNode root; /** * * @param lines 传入要分析的数据集 * @param index 哪个属性?attribute的index */ public Double getGain(LinkedList<String[]> lines,int index){ Double gain=-1.0; List<Double> li=new ArrayList<Double>(); //统计Yes No的次数 for(int i=0;i<yesNo.length;i++){ Double sum=0.0; for(int j=0;j<lines.size();j++){ String[] line=lines.get(j); //data为结构化数据,如果数据最后一列==yes,sum+1 if(line[line.length-1].equals(yesNo[i])){ sum=sum+1; } } li.add(sum); } //计算Entropy(S)计算Entropy(S) 见参考书《机器学习 》Tom.Mitchell著 第3.4.1.2节 Double entropyS=TheMath.getEntropy(lines.size(), li); //下面计算gain List<String> la=attributevalue.get(index); List<Point> lasv=new ArrayList<Point>(); for(int n=0;n<la.size();n++){ String attvalue=la.get(n); //统计Yes No的次数 List<Double> lisub=new ArrayList<Double>();//如:sunny 是yes时发生的次数,是no发生的次数 Double Sv=0.0;//公式3.4中的Sv 见参考书《机器学习(Tom.Mitchell著)》 for(int i=0;i<yesNo.length;i++){ Double sum=0.0; for(int j=0;j<lines.size();j++){ String[] line=lines.get(j); //data为结构化数据,如果数据最后一列==yes,sum+1 if(line[index].equals(attvalue)&&line[line.length-1].equals(yesNo[i])){ sum=sum+1; } } Sv=Sv+sum;//计算总数 lisub.add(sum); } //计算Entropy(S) 见参考书《机器学习(Tom.Mitchell著)》 Double entropySv=TheMath.getEntropy(Sv.intValue(), lisub); // Point p=new Point(); p.setSv(Sv); p.setEntropySv(entropySv); lasv.add(p); } gain=TheMath.getGain(entropyS,lines.size(),lasv); return gain; } //寻找最大的信息增益,将最大的属性定为当前节点,并返回该属性所在list的位置和gain值 public Maxgain getMaxGain(LinkedList<String[]> lines){ if(lines==null||lines.size()<=0){ return null; } Maxgain maxgain = new Maxgain(); Double maxvalue=0.0; int maxindex=-1; for(int i=0;i<attribute.size();i++){ Double tmp=getGain(lines,i); if(maxvalue< tmp){ maxvalue=tmp; maxindex=i; } } maxgain.setMaxgain(maxvalue); maxgain.setMaxindex(maxindex); return maxgain; } //剪取数组 public LinkedList<String[]> filterLines(LinkedList<String[]> lines, String attvalue, int index){ LinkedList<String[]> newlines=new LinkedList<String[]>(); for(int i=0;i<lines.size();i++){ String[] line=lines.get(i); if(line[index].equals(attvalue)){ newlines.add(line); } } return newlines; } public void createDTree(){ root=new TreeNode(); Maxgain maxgain=getMaxGain(data); if(maxgain==null){ System.out.println("没有数据集,请检查!"); } int maxKey=maxgain.getMaxindex(); String nodename=attribute.get(maxKey); root.setName(nodename); root.setLiatts(attributevalue.get(maxKey)); insertNode(data,root,maxKey); } /** * * @param lines 传入的数据集,作为新的递归数据集 * @param node 深入此节点 * @param index 属性位置 */ public void insertNode(LinkedList<String[]> lines,TreeNode node,int index){ List<String> liatts=node.getLiatts(); for(int i=0;i<liatts.size();i++){ String attname=liatts.get(i); LinkedList<String[]> newlines=filterLines(lines,attname,index); if(newlines.size()<=0){ System.out.println("出现异常,循环结束"); return; } Maxgain maxgain=getMaxGain(newlines); double gain=maxgain.getMaxgain(); Integer maxKey=maxgain.getMaxindex(); //不等于0继续递归,等于0说明是叶子节点,结束递归。 if(gain!=0){ TreeNode subnode=new TreeNode(); subnode.setParent(node); subnode.setFatherAttribute(attname); String nodename=attribute.get(maxKey); subnode.setName(nodename); subnode.setLiatts(attributevalue.get(maxKey)); node.addChild(subnode); //不等于0,继续递归 insertNode(newlines,subnode,maxKey); }else{ TreeNode subnode=new TreeNode(); subnode.setParent(node); subnode.setFatherAttribute(attname); //叶子节点是yes还是no?取新行中最后一个必是其名称,因为只有完全是yes,或完全是no的情况下才会是叶子节点 String[] line=newlines.get(0); String nodename=line[line.length-1]; subnode.setName(nodename); node.addChild(subnode); } } } //输出决策树 public void printDTree(TreeNode node) { if(node.getChildren()==null){ System.out.println("--"+node.getName()); return; } System.out.println(node.getName()); List<TreeNode> childs = node.getChildren(); for (int i = 0; i < childs.size(); i++) { System.out.println(childs.get(i).getFatherAttribute()); printDTree(childs.get(i)); } } public static void main(String[] args) { // TODO Auto-generated method stub MyID3 myid3 = new MyID3(); myid3.readARFF(new File("datafile/decisiontree/test/in/weather.nominal.arff")); myid3.createDTree(); myid3.printDTree(root); } //读取arff文件,给attribute、attributevalue、data赋值 public void readARFF(File file) { try { FileReader fr = new FileReader(file); BufferedReader br = new BufferedReader(fr); String line; Pattern pattern = Pattern.compile(patternString); while ((line = br.readLine()) != null) { if (line.startsWith("@decision")) { line = br.readLine(); if(line=="") continue; yesNo = line.split(","); } Matcher matcher = pattern.matcher(line); if (matcher.find()) { attribute.add(matcher.group(1).trim()); String[] values = matcher.group(2).split(","); ArrayList<String> al = new ArrayList<String>(values.length); for (String value : values) { al.add(value.trim()); } attributevalue.add(al); } else if (line.startsWith("@data")) { while ((line = br.readLine()) != null) { if(line=="") continue; String[] row = line.split(","); data.add(row); } } else { continue; } } br.close(); } catch (IOException e1) { e1.printStackTrace(); } } }
相关文章推荐
- 决策树算法原理及JAVA实现(ID3)
- ID3决策树算法原理及C++实现
- 利用Java Reflection(反射)原理,在hibernate里面实现对单表、视图的动态组合查询
- java中使用线程实现Timer(定时器)原理和源码
- java中使用线程实现Timer(定时器)原理和源码
- Java对象池技术原理及其实现
- Java对象池技术原理及其实现
- Taglib原理和实现(作者WalkingWithJava)
- java中使用线程实现Timer(定时器)原理和源码
- java中使用线程实现Timer(定时器)原理和源码
- java中使用线程实现Timer(定时器)原理和源码
- java中使用线程实现Timer(定时器)原理和源码
- Java对象池技术的原理及其实现
- java中使用线程实现Timer(定时器)原理和源码
- java中使用线程实现Timer(定时器)原理和源码
- java中使用线程实现Timer(定时器)原理和源码
- 快速排序原理及java实现
- Java 虚拟机类装载:原理、实现与应用
- java中使用线程实现Timer(定时器)原理和源码
- Java对象池技术的原理及其实现的小结。