归纳决策树ID3(Java实现)
2013-01-05 18:27
453 查看
先上问题吧,我们统计了14天的气象数据(指标包括outlook,temperature,humidity,windy),并已知这些天气是否打球(play)。如果给出新一天的气象指标数据:sunny,cool,high,TRUE,判断一下会不会去打球。
table 1
这个问题当然可以用朴素贝叶斯法求解,分别计算在给定天气条件下打球和不打球的概率,选概率大者作为推测结果。
现在我们使用ID3归纳决策树的方法来求解该问题。
通常以2为底数,所以信息熵的单位是bit。
补充两个对数去处公式:
在没有给定任何天气信息时,根据历史数据,我们只知道新的一天打球的概率是9/14,不打的概率是5/14。此时的熵为:
属性有4个:outlook,temperature,humidity,windy。我们首先要决定哪个属性作树的根节点。
对每项指标分别统计:在不同的取值下打球和不打球的次数。
table 2
下面我们计算当已知变量outlook的值时,信息熵为多少。
outlook=sunny时,2/5的概率打球,3/5的概率不打球。entropy=0.971
outlook=overcast时,entropy=0
outlook=rainy时,entropy=0.971
而根据历史统计数据,outlook取值为sunny、overcast、rainy的概率分别是5/14、4/14、5/14,所以当已知变量outlook的值时,信息熵为:5/14 × 0.971 + 4/14 × 0 + 5/14 × 0.971 = 0.693
这样的话系统熵就从0.940下降到了0.693,信息增溢gain(outlook)为0.940-0.693=0.247
同样可以计算出gain(temperature)=0.029,gain(humidity)=0.152,gain(windy)=0.048。
gain(outlook)最大(即outlook在第一步使系统的信息熵下降得最快),所以决策树的根节点就取outlook。
接下来要确定N1取temperature、humidity还是windy?在已知outlook=sunny的情况,根据历史数据,我们作出类似table 2的一张表,分别计算gain(temperature)、gain(humidity)和gain(windy),选最大者为N1。
依此类推,构造决策树。当系统的信息熵降为0时,就没有必要再往下构造决策树了,此时叶子节点都是纯的--这是理想情况。最坏的情况下,决策树的高度为属性(决策变量)的个数,叶子节点不纯(这意味着我们要以一定的概率来作出决策)。
实验用的数据文件:
程序代码:
最终生成的文件如下:
用图形象地表示就是:
table 1
outlook | temperature | humidity | windy | play |
sunny | hot | high | FALSE | no |
sunny | hot | high | TRUE | no |
overcast | hot | high | FALSE | yes |
rainy | mild | high | FALSE | yes |
rainy | cool | normal | FALSE | yes |
rainy | cool | normal | TRUE | no |
overcast | cool | normal | TRUE | yes |
sunny | mild | high | FALSE | no |
sunny | cool | normal | FALSE | yes |
rainy | mild | normal | FALSE | yes |
sunny | mild | normal | TRUE | yes |
overcast | mild | high | TRUE | yes |
overcast | hot | normal | FALSE | yes |
rainy | mild | high | TRUE | no |
现在我们使用ID3归纳决策树的方法来求解该问题。
预备知识:信息熵
熵是无序性(或不确定性)的度量指标。假如事件A的全概率划分是(A1,A2,...,An),每部分发生的概率是(p1,p2,...,pn),那信息熵定义为:通常以2为底数,所以信息熵的单位是bit。
补充两个对数去处公式:
ID3算法
构造树的基本想法是随着树深度的增加,节点的熵迅速地降低。熵降低的速度越快越好,这样我们有望得到一棵高度最矮的决策树。在没有给定任何天气信息时,根据历史数据,我们只知道新的一天打球的概率是9/14,不打的概率是5/14。此时的熵为:
属性有4个:outlook,temperature,humidity,windy。我们首先要决定哪个属性作树的根节点。
对每项指标分别统计:在不同的取值下打球和不打球的次数。
table 2
outlook | temperature | humidity | windy | play | |||||||||
yes | no | yes | no | yes | no | yes | no | yes | no | ||||
sunny | 2 | 3 | hot | 2 | 2 | high | 3 | 4 | FALSE | 6 | 2 | 9 | 5 |
overcast | 4 | 0 | mild | 4 | 2 | normal | 6 | 1 | TRUR | 3 | 3 | ||
rainy | 3 | 2 | cool | 3 | 1 |
outlook=sunny时,2/5的概率打球,3/5的概率不打球。entropy=0.971
outlook=overcast时,entropy=0
outlook=rainy时,entropy=0.971
而根据历史统计数据,outlook取值为sunny、overcast、rainy的概率分别是5/14、4/14、5/14,所以当已知变量outlook的值时,信息熵为:5/14 × 0.971 + 4/14 × 0 + 5/14 × 0.971 = 0.693
这样的话系统熵就从0.940下降到了0.693,信息增溢gain(outlook)为0.940-0.693=0.247
同样可以计算出gain(temperature)=0.029,gain(humidity)=0.152,gain(windy)=0.048。
gain(outlook)最大(即outlook在第一步使系统的信息熵下降得最快),所以决策树的根节点就取outlook。
接下来要确定N1取temperature、humidity还是windy?在已知outlook=sunny的情况,根据历史数据,我们作出类似table 2的一张表,分别计算gain(temperature)、gain(humidity)和gain(windy),选最大者为N1。
依此类推,构造决策树。当系统的信息熵降为0时,就没有必要再往下构造决策树了,此时叶子节点都是纯的--这是理想情况。最坏的情况下,决策树的高度为属性(决策变量)的个数,叶子节点不纯(这意味着我们要以一定的概率来作出决策)。
Java实现
最终的决策树保存在了XML中,使用了Dom4J,注意如果要让Dom4J支持按XPath选择节点,还得引入包jaxen.jar。程序代码要求输入文件满足ARFF格式,并且属性都是标称变量。实验用的数据文件:
@relation weather.symbolic @attribute outlook {sunny, overcast, rainy} @attribute temperature {hot, mild, cool} @attribute humidity {high, normal} @attribute windy {TRUE, FALSE} @attribute play {yes, no} @data sunny,hot,high,FALSE,no sunny,hot,high,TRUE,no overcast,hot,high,FALSE,yes rainy,mild,high,FALSE,yes rainy,cool,normal,FALSE,yes rainy,cool,normal,TRUE,no overcast,cool,normal,TRUE,yes sunny,mild,high,FALSE,no sunny,cool,normal,FALSE,yes rainy,mild,normal,FALSE,yes sunny,mild,normal,TRUE,yes overcast,mild,high,TRUE,yes overcast,hot,normal,FALSE,yes rainy,mild,high,TRUE,no
程序代码:
package schoolarship; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; import java.io.FileWriter; import java.io.IOException; import java.io.InputStreamReader; import java.io.UnsupportedEncodingException; import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedList; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; import org.dom4j.Document; import org.dom4j.DocumentHelper; import org.dom4j.Element; import org.dom4j.io.OutputFormat; import org.dom4j.io.XMLWriter; public class ID3 { //存储属性的名称,这里判别变量和决策变量一律称为“属性” private ArrayList<String> attribute = new ArrayList<String>(); //存储每个属性(都是离散变量)的取值集合 private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); //存储所有的训练数据,这是一个二维数组 private ArrayList<String[]> data = new ArrayList<String[]>(); //决策变量的属性列表中的索引号 int decatt; //用于匹配ARFF文件中的@attribute行 public static final String patternString = "@attribute(.*)[{](.*?)[}]"; //使用Dom4j读写XML文件 Document xmldoc; Element root; //构造函数中初始化Dom元素 public ID3() { xmldoc = DocumentHelper.createDocument(); root = xmldoc.addElement("root"); root.addElement("DecisionTree").addAttribute("value", "null"); } public static void main(String[] args) { ID3 inst = new ID3(); //读入训练文件 inst.readARFF(new File("d:\\weather.arff")); //设置决策变量的名称 inst.setDec("play"); //将所有属性(决策变量除外)的索引号存入ll LinkedList<Integer> ll = new LinkedList<Integer>(); for (int i = 0; i < inst.attribute.size(); i++) { if (i != inst.decatt) ll.add(i); } //将全部训练数据的序号存入al ArrayList<Integer> al = new ArrayList<Integer>(); for (int i = 0; i < inst.data.size(); i++) { al.add(i); } //递归构建决策树 inst.buildDT(inst.root, al, ll); //将决策树写入XML文件 inst.writeXML("d:\\dt.xml"); } //读取输入文件,为全局变量attribute、attributevalue和data赋值 public void readARFF(File file) { try { FileInputStream fis = new FileInputStream(file); InputStreamReader isr = new InputStreamReader(fis, initBookEncode(fis)); BufferedReader br = new BufferedReader(isr); String line; Pattern pattern = Pattern.compile(patternString); while ((line = br.readLine()) != null) { Matcher matcher = pattern.matcher(line); if (matcher.find()) { attribute.add(matcher.group(1).trim()); String[] values = matcher.group(2).split(","); ArrayList<String> al = new ArrayList<String>(values.length); for (String value : values) { al.add(value.trim()); } attributevalue.add(al); } else if (line.startsWith("@data")) { while ((line = br.readLine()) != null) { if (line.equals("")) continue; String[] row = line.split(","); data.add(row); } } else { continue; } } br.close(); } catch (IOException e1) { e1.printStackTrace(); } } //将参数n赋给全局变量decatt public void setDec(int n) { if (n < 0 || n >= attribute.size()) { System.err.println("给定的决策变量名称有误"); System.exit(2); } decatt = n; } //根据属性的名称设置全局变量decatt public void setDec(String name) { int n = attribute.indexOf(name); setDec(n); } //计算信息熵。arr中存储各种情况的频数 public double getEntropy(int[] arr) { int sum = 0; for (int i = 0; i < arr.length; i++) { sum += arr[i]; } return getEntropy(arr, sum); } //计算信息熵。arr中存储各种情况的频数,sum给出频数的总和 public double getEntropy(int[] arr, int sum) { if (sum == 0) return 0; double entropy = 0.0; for (int i = 0; i < arr.length; i++) { //加上Double.MIN_VALUE是为了防止出现log(0)的情况 entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE) / Math.log(2); } entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2); entropy /= sum; //由于上面加了Double.MIN_VALUE,所以算出来的熵可能会略大于1 if (entropy > 1 && entropy - 1 < 0.00001) entropy = 1; return entropy; } //subset给写训练数据的一个子集(subset中存储的是每条数据的索引号),判断这些子集的决策变量值是否都相同 public boolean infoPure(ArrayList<Integer> subset) { String value = data.get(subset.get(0))[decatt]; for (int i = 1; i < subset.size(); i++) { String next = data.get(subset.get(i))[decatt]; if (!value.equals(next)) return false; } return true; } /** * 计算节点的信息熵 * @param subset 节点上所包含的数据子集 * @param index 节点以第index个属性作为判断的依据 * @return 节点的信息熵 */ public double calNodeEntropy(ArrayList<Integer> subset, int index) { int sum = subset.size(); double entropy = 0.0; int[][] info = new int[attributevalue.get(index).size()][]; for (int i = 0; i < info.length; i++) info[i] = new int[attributevalue.get(decatt).size()]; int[] count = new int[attributevalue.get(index).size()]; for (int i = 0; i < sum; i++) { int n = subset.get(i); String nodevalue = data.get(n)[index]; int nodeind = attributevalue.get(index).indexOf(nodevalue); count[nodeind]++; String decvalue = data.get(n)[decatt]; int decind = attributevalue.get(decatt).indexOf(decvalue); info[nodeind][decind]++; } for (int i = 0; i < info.length; i++) { entropy += getEntropy(info[i]) * count[i] / sum; } return entropy; } // 递归构建决策树 public void buildDT(Element ele, ArrayList<Integer> subset, LinkedList<Integer> selatt) { //指定name和value的节点不包含数据子集时,递归可以终止。同时要删除该节点 if (subset.size() == 0){ ele.getParent().remove(ele); return; } //selatt.size() == 0说明树已经达到最大的深度,即所有判别属性都已经用完了。 //这个时候递归还没有终止说明训练数据中存在判别属性值完全相同,决策属性值却不相同的情况,取决策属性值最多的情况为最终结果 if(selatt.size() == 0){ Map<String,Integer> map=new HashMap<String,Integer>(); for(int i:subset){ String key=data.get(i)[decatt]; Integer v=map.get(key); if(v!=null) map.put(key, v+1); else map.put(key, 1); } String decision="should not appear"; int maxCount=-1; Set<Entry<String,Integer>> set=map.entrySet(); for(Entry<String,Integer> entry:set){ if(entry.getValue()>maxCount){ maxCount=entry.getValue(); decision=entry.getKey(); } } ele.setText(decision); return; } //如果节点是纯的,那么就到达叶子节点了,给出决策,不需要继续递归了 if (infoPure(subset)) { ele.setText(data.get(subset.get(0))[decatt]); return; } //选择下一个用于判别的属性。应该选熵最小的,因为这样信息增益最大 int minIndex = -1; double minEntropy = Double.MAX_VALUE; for (int i = 0; i < selatt.size(); i++) { if (i == decatt) continue; double entropy = calNodeEntropy(subset, selatt.get(i)); if (entropy < minEntropy) { minIndex = selatt.get(i); minEntropy = entropy; } } String nodeName = attribute.get(minIndex); //每次递归时selatt都会少一个元素,即去除刚刚选择的判别属性 selatt.remove(new Integer(minIndex)); //刚刚选择的属性有多少种取值,该节点就有多少个分枝。遍历这些分枝,递归完善子树。 ArrayList<String> attvalues = attributevalue.get(minIndex); for (String val : attvalues) { Element child=ele.addElement(nodeName).addAttribute("value", val); ArrayList<Integer> al = new ArrayList<Integer>(); for (int i = 0; i < subset.size(); i++) { if (data.get(subset.get(i))[minIndex].equals(val)) { al.add(subset.get(i)); } } //注意bBuildDT()里面selatt会被改变,所以每次传递这个参数的时候要进行深复制 buildDT(child, al, new LinkedList<Integer>(selatt)); } } //将Dom写入XML文件 public void writeXML(String filename) { try { File file = new File(filename); if (!file.exists()) file.createNewFile(); FileWriter fw = new FileWriter(file); OutputFormat format = OutputFormat.createPrettyPrint(); XMLWriter output = new XMLWriter(fw, format); output.write(xmldoc); output.close(); } catch (IOException e) { System.out.println(e.getMessage()); } } /*正面这两个函数用于正确读取中文文件*/ String changeToGBK(String ss, String code) { String temp = null; try { temp = new String(ss.getBytes(), code); } catch (UnsupportedEncodingException e) { e.printStackTrace(); } return temp; } public String initBookEncode(FileInputStream fileInputStream) { String encode = "gb2312"; try { byte[] head = new byte[3]; fileInputStream.read(head); if (head[0] == -17 && head[1] == -69 && head[2] == -65) encode = "UTF-8"; else if (head[0] == -1 && head[1] == -2) encode = "UTF-16"; else if (head[0] == -2 && head[1] == -1) encode = "Unicode"; } catch (IOException e) { System.out.println(e.getMessage()); } return encode; } }
最终生成的文件如下:
<?xml version="1.0" encoding="UTF-8"?> <root> <DecisionTree value="null"> <outlook value="sunny"> <humidity value="high">no</humidity> <humidity value="normal">yes</humidity> </outlook> <outlook value="overcast">yes</outlook> <outlook value="rainy"> <windy value="TRUE">no</windy> <windy value="FALSE">yes</windy> </outlook> </DecisionTree> </root>
用图形象地表示就是:
相关文章推荐
- 归纳决策树ID3(Java实现)
- 归纳决策树ID3(Java实现)
- 归纳决策树ID3(Java实现)
- 归纳决策树ID3(Java实现)
- 决策树归纳(ID3属性选择度量)Java实现
- 归纳决策树ID3(Java实现)
- 归纳决策树ID3(Java实现)
- 决策树ID3(Java实现)
- 决策树ID3算法的java实现(基本适用所有的ID3)
- ID3决策树的Java实现
- ID3决策树预测的java实现
- 基本排序算法Java实现归纳(一)
- 机器学习(周志华)习题解答4.3: Python小白详解ID3决策树的实现
- 归纳决策树ID3
- 归纳决策树ID3(信息熵的计算和计算原理写的很清楚)
- 机器学习入门学习笔记:(3.2)ID3决策树程序实现
- 决策树ID3 算法python实现
- 机器学习实战决策树的java实现
- 分类算法-----决策树(ID3)算法原理和Python实现
- 数据挖掘-决策树ID3分类算法的C++实现