C4.5决策树算法
2012-11-18 21:49
176 查看
属性选择采用信息增益率。另外要注意,决策树是穷举的:每一个条件分支都必须能给出决策,即使训练数据中没有出现某个属性值(只要该取值已在属性声明中列出)。在默认类的选择上,C4.5 采用的策略是取未被覆盖样本中的多数类。
package com.tur4.algorithm; import java.io.BufferedReader; import java.io.File; import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.FileReader; import java.io.IOException; import java.io.OutputStream; import java.io.UnsupportedEncodingException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; import org.apache.log4j.Logger; import org.dom4j.Document; import org.dom4j.DocumentHelper; import org.dom4j.Element; import org.dom4j.io.OutputFormat; import org.dom4j.io.XMLWriter; public class C4_5 { private class Res{ public boolean isPure = true; public String clazz; } private static Logger LOG = Logger.getLogger(C4_5.class); private static final String ROOT = "decision_tree"; private static final String VALUE = "value"; private static final String ALL = "all"; private List<ArrayList<String>> data = new ArrayList<ArrayList<String>>(); private int decidx; private List<String> attrNames = new ArrayList<String>(); private List<ArrayList<String>> attrValues = new ArrayList<ArrayList<String>>(); private File file; private static String patternString = "@attribute\\s+([^\\s]+)\\s*\\{([^\\}]+)\\}"; private Element root; private Document doc; private String outFilePath = "decisionTree.xml"; private String mostClass;//未被覆盖的多数类 private int[] flags; public void setOutFilePath(String path){ this.outFilePath = path; } public C4_5(String filePath, String decision, String split){ this.file = new File(filePath); BufferedReader reader = null; try { reader = new BufferedReader(new FileReader(this.file)); String str = null; boolean isData = false; while((str = reader.readLine()) != null){ if(str.trim().length() == 0) continue; if(str.trim().startsWith("@data")){ isData = true; continue; } if(str.trim().startsWith("@attribute")){ Pattern pattern = Pattern.compile(patternString); Matcher m = 
pattern.matcher(str); if(m.find()){ attrNames.add(m.group(1)); ArrayList<String> values = new ArrayList<String>(); String[] vals = m.group(2).split(split); for(String val: vals) values.add(val.trim()); attrValues.add(values); } }else if(isData){ String[] vals = str.split(split); ArrayList<String> record = new ArrayList<String>(attrNames.size()); for(String val: vals) record.add(val.trim()); data.add(record); } } } catch (FileNotFoundException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); }finally{ if(reader != null) try { reader.close(); } catch (IOException e) { e.printStackTrace(); } } //LOG.debug("attrNames=" + attrNames + "\r\nattrValues=" + attrValues + "\r\ndata=" + data); if(decision==null || decision.trim().length()==0) this.decidx = attrNames.size()-1; else setDecision(decision); flags = new int[data.size()]; for(int i=0;i<data.size();++i) flags[i] = 0; doc = DocumentHelper.createDocument(); root = doc.addElement(ROOT).addAttribute(VALUE, ALL); } public void setDecision(String decision){ for(int i=0;i<attrNames.size();++i){ if(attrNames.get(i).equals(decision)){ decidx = i; break; } } } private double calcEntropy(int[] info){ int sum = 0; for(int integer:info) sum += integer; if(sum == 0) return 0.0; double entropy = 0.0; for(int i=0;i<info.length;++i){ entropy -= info[i] * Math.log(Double.MIN_VALUE + info[i]) / Math.log(2); } entropy += sum * Math.log(Double.MIN_VALUE + sum) / Math.log(2); //LOG.info("info=" + Arrays.toString(info)+"\tentropy="+entropy/sum); return entropy/sum; } private Integer findAttrValueIndex(int attrIdx, String val){ List<String> attrs = attrValues.get(attrIdx); for(int i=0;i<attrs.size();++i) if(attrs.get(i).equals(val)) return i; return null; } private double calcExpectEntropyByAttr(int attrIdx, List<Integer> idxSet){ int diffValues = attrValues.get(attrIdx).size(); int classNum = attrValues.get(decidx).size(); int info[][] = new int[diffValues][]; for(int i=0;i<diffValues;++i){ info[i] = new 
int[classNum]; } int count[] = new int[diffValues]; for(Integer i: idxSet){ List<String> record = data.get(i); String val = record.get(attrIdx); int idx = findAttrValueIndex(attrIdx, val); count[idx]++; String clazzVal = record.get(decidx); int classIdx = findAttrValueIndex(decidx, clazzVal); info[idx][classIdx]++; } double entropy = 0.0; double splitEntropy = 0.0; double sum = idxSet.size(); for(int i=0;i<diffValues;++i){ //LOG.debug("count[i]/sum * calcEntropy(info[i])="+(count[i]/sum)+"*"+calcEntropy(info[i])); entropy += count[i]/sum * calcEntropy(info[i]); splitEntropy -= count[i] * Math.log(Double.MIN_VALUE + count[i]) / Math.log(2); } splitEntropy += sum * Math.log(Double.MIN_VALUE + sum) / Math.log(2); splitEntropy /= sum; //LOG.debug("entropy="+entropy+"\tsplitEntropy="+splitEntropy+"\tres="+entropy/(splitEntropy + Double.MIN_VALUE)); return entropy/(splitEntropy + Double.MIN_VALUE); } private Res pureInfo(List<Integer> set){ Res res = new Res(); String clazz = data.get(set.get(0)).get(decidx); for(Integer idx: set){ if(!data.get(idx).get(decidx).equals(clazz)){ res.isPure = false; } } res.clazz = clazz; return res; } public void buildTree(){ List<Integer> records = new ArrayList<Integer>(); for(int i=0;i<data.size();++i) records.add(i); ArrayList<String> attrs = new ArrayList<String>(); for(String attr: attrNames){ if(!attr.equals(attrNames.get(decidx))) attrs.add(attr); } buildTree("", ROOT, ALL, records, attrs); XMLWriter writer = null; try { OutputStream out = new FileOutputStream(new File(outFilePath)); writer = new XMLWriter(out, OutputFormat.createPrettyPrint()); writer.write(doc); } catch (FileNotFoundException e) { e.printStackTrace(); } catch (UnsupportedEncodingException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); }finally{ if(writer != null) try { writer.close(); } catch (IOException e) { e.printStackTrace(); } } } private Integer findAttrNameIndex(String attrName){ for(int i=0;i<attrNames.size();++i) 
if(attrNames.get(i).equals(attrName)) return i; return null; } private String getNotCoverMostClass(){ int classNum[] = new int[attrValues.get(decidx).size()]; for(int i=0;i<attrValues.get(decidx).size();++i) classNum[i] = 0; for(int i=0;i<flags.length;++i){ if(flags[i] == 0) classNum[attrNames.indexOf(data.get(i).get(decidx))]++; } int max = 0; int idx = -1; for(int i=0;i<classNum.length;++i) if(max < classNum[i]){ max = classNum[i]; idx = i; } return attrValues.get(decidx).get(idx); } /** * find the most class in the set * @param set * @return */ private String getMostProbabilityClass(List<Integer> set){ Map<String, Integer> map = new HashMap<String, Integer>(); String res = null; Integer most = 0; for(Integer idx: set){ String key = data.get(idx).get(decidx); if(!map.containsKey(key)) map.put(key, 1); else map.put(key, map.get(key) + 1); if(map.get(key) > most){ most = map.get(key); res = key; } } mostClass = res; return res; } private void buildTree(String xpath, String preEleName, String preVal, List<Integer> set, ArrayList<String> attrs){ String newXpath = xpath + "/" + preEleName + "[@" + VALUE + "='" + preVal + "']"; @SuppressWarnings("rawtypes") List nodes = root.selectNodes(newXpath); LOG.debug(newXpath); Element ele = null; Iterator it = nodes.iterator(); while(it.hasNext()){ ele = (Element) it.next(); if(ele.attributeValue(VALUE).equals(preVal)) break; } if(set.size() == 0){ if(mostClass == null){ List<Integer> records = new ArrayList<Integer>(data.size()); for(int i=0;i<data.size();++i) records.add(i); mostClass = getMostProbabilityClass(records); } ele.setText(mostClass); return; } //LOG.debug("element=" + ele); //LOG.debug("eleName=" + preEleName +"\tval=" + preVal); Res res = pureInfo(set); if(res.isPure){ ele.addText(res.clazz); for(Integer idx: set) flags[idx] = 1; return; } if(attrs==null || attrs.size()==0){ String className = getMostProbabilityClass(set); ele.addText(className); return; } double minEntropy = Double.MAX_VALUE; int attrIdx = -1; 
String attr = null; for(int i=0;i<attrs.size();++i){ String tmpAttr = attrs.get(i); int tmpAttrIdx = findAttrNameIndex(tmpAttr); double entropy = calcExpectEntropyByAttr(tmpAttrIdx, set); if(entropy < minEntropy){ minEntropy = entropy; attr = tmpAttr; attrIdx = findAttrNameIndex(attr);//attrNames的索引不一定与可选属性集的索引相同 } } ArrayList<String> remainAttrs = (ArrayList<String>) attrs.clone(); for(int i=0;i<remainAttrs.size();++i){ if(remainAttrs.get(i).equals(attr)){ remainAttrs.remove(i); break; } } if(attrIdx == -1){ ele.setText(getNotCoverMostClass()); return; } LOG.debug("attrs="+attrs + "\tattr="+attr + "\tattrIdx="+attrIdx); List<ArrayList<Integer>> subsets = new ArrayList<ArrayList<Integer>>(attrValues.get(attrIdx).size()); for(int i=0;i<attrValues.get(attrIdx).size();++i){ subsets.add(new ArrayList<Integer>()); } for(Integer idx: set){ List<String> record = data.get(idx); int attrValIndex = findAttrValueIndex(attrIdx, record.get(attrIdx)); subsets.get(attrValIndex).add(idx); } /* for(int i=0;i<subsets.size();++i){ for(int j=0;j<subsets.get(i).size();++j) System.out.println(data.get(subsets.get(i).get(j))); System.out.println("+++++++++++++++++++++++++"); }*/ List<String> values = attrValues.get(attrIdx); LOG.debug(attr + "=" + values + "attrIdx=" + attrIdx); for(int i=0;i<values.size();++i){ if(subsets.get(i).size() != 0){ ele.addElement(attr).addAttribute(VALUE, values.get(i)); buildTree(newXpath, attr, values.get(i), subsets.get(i), (ArrayList<String>) remainAttrs.clone()); } } } }
相关文章推荐
- Accord C4.5决策树算法(C# C4.5决策树算法)
- 决策树算法:ID3和C4.5
- 经典决策树算法:ID3、C4.5和CART
- 【机器学习2】决策树算法之CART、C4.5
- Thinking in SQL系列之四:数据挖掘C4.5决策树算法
- C4.5决策树算法
- C4.5决策树算法
- Thinking in SQL系列之四:数据挖掘C4.5决策树算法
- C4.5决策树算法介绍
- Ng机器学习系列补充:1、决策树算法ID3和C4.5
- R语言-决策树算法(C4.5和CART)的实现
- ID3和C4.5决策树算法总结
- Python 决策树算法(ID3 & C4.5)
- 整理--决策树算法:ID3和C4.5
- 决策树算法ID3,C4.5, CART
- C4.5决策树算法思想
- 机器学习中决策树算法原理主要有ID3、C4.5、CART算法
- 整理--决策树算法:ID3和C4.5
- C4.5决策树算法
- 数据挖掘——分类——决策树算法之ID3与C4.5原理解析