您的位置:首页 > 编程语言 > Java开发

决策树归纳(ID3属性选择度量)Java实现

2014-12-31 10:31 573 查看
一般的决策树归纳框架见之前的博文:http://blog.csdn.net/zhyoulun/article/details/41978381

ID3属性选择度量原理

ID3使用信息增益作为属性选择度量。该度量基于香农在研究消息的值或”信息内容“的信息论方面的先驱工作。该结点N代表或存放分区D的元组。选择具有最高信息增益的属性作为结点N的分裂属性。该属性使结果分区中对元祖分类所需要的信息量最小,并反映这些分区中的最小随机性或”不纯性“。这种方法使得对一个对象分类所需要的期望测试数目最小,并确保找到一颗简单的(但不必是最简单的)树。

对D中的元组分类所需要的期望信息由下式给出,



其中pi是D忠任意元组属于类Ci的非零概率。使用以2为底的对数函数是因为信息用二进制编码。Info(D)是识别D中元组的类标号所需要的平均信息量。注意,此时我们所有的信息只是每个类的元组所占的百分比。

现在假设我们要按照某属性A划分D中的元组,其中属性A根据训练数据的观测具有v个不同的值{a1,a2,...av}。可以用属性A将D划分为v个分区或子集{D1,D2,...,Dv},其中Dj包含D中的元组,它们的A值为aj。这些分区对应于从节点N生长出来的分支。理想情况下,我们希望该划分产生元组的准确分类。即希望每个分区都是纯的(实际情况多半是不纯的,如分区可能包含来自不同类的元组)。在此划分之后,为了得到准确的分类,我们还需要多少信息?这个量由下式度量:



其中|Dj|/|D|充当第j个分区的权重。Info_A(D)是基于按A划分对D的元组分类所需要的期望值信息需要的期望信息越小,分区的纯度越高

信息增益定义为原来的信息需求(仅基于类比例)与新的信息需求(对A划分后)之前的差。即



换言之,Gain(A)告诉我们通过A上的划分我们得到了多少。它是知道A的值而导致的信息需求的期望减少。选择具有最高信息增益Gain(A)的属性A作为结点N的分裂属性。

以下为例子。

数据

data.txt

youth,high,no,fair,no
youth,high,no,excellent,no
middle_aged,high,no,fair,yes
senior,medium,no,fair,yes
senior,low,yes,fair,yes
senior,low,yes,excellent,no
middle_aged,low,yes,excellent,yes
youth,medium,no,fair,no
youth,low,yes,fair,yes
senior,medium,yes,fair,yes
youth,medium,yes,excellent,yes
middle_aged,medium,no,excellent,yes
middle_aged,high,yes,fair,yes
senior,medium,no,excellent,no


attr.txt

age,income,student,credit_rating,buys_computer


运算结果

age(1:youth; 2:middle_aged; 3:senior; )
credit_rating(1:fair; 2:excellent; )
leaf:no()
leaf:yes()
leaf:yes()
student(1:no; 2:yes; )
leaf:no()
leaf:yes()




最后附上java代码

DecisionTree.java

package com.zhyoulun.decision;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Map;

/**
* 负责数据的读入和写出,以及生成决策树
*
* @author zhyoulun
*
*/
public class DecisionTree
{
private ArrayList<ArrayList<String>> allDatas;
private ArrayList<String> allAttributes;

/**
* 从文件中读取所有相关数据
* @param dataFilePath
* @param attrFilePath
*/
public DecisionTree(String dataFilePath,String attrFilePath)
{
super();

try
{
this.allDatas = new ArrayList<>();
this.allAttributes = new ArrayList<>();

InputStreamReader inputStreamReader = new InputStreamReader(new FileInputStream(new File(dataFilePath)));
BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
String line = null;
while((line=bufferedReader.readLine())!=null)
{
String[] strings = line.split(",");
ArrayList<String> data = new ArrayList<>();
for(int i=0;i<strings.length;i++)
data.add(strings[i]);
this.allDatas.add(data);
}

inputStreamReader = new InputStreamReader(new FileInputStream(new File(attrFilePath)));
bufferedReader = new BufferedReader(inputStreamReader);
while((line=bufferedReader.readLine())!=null)
{
String[] strings = line.split(",");
for(int i=0;i<strings.length;i++)
this.allAttributes.add(strings[i]);
}

inputStreamReader.close();
bufferedReader.close();

} catch (FileNotFoundException e)
{
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IOException e)
{
// TODO Auto-generated catch block
e.printStackTrace();
}

//		for(int i=0;i<this.allAttributes.size();i++)
//		{
//			System.out.print(this.allAttributes.get(i)+" ");
//		}
//		System.out.println();
//
//		for(int i=0;i<this.allDatas.size();i++)
//		{
//			for(int j=0;j<this.allDatas.get(i).size();j++)
//			{
//				System.out.print(this.allDatas.get(i).get(j)+" ");
//			}
//			System.out.println();
//		}

}

/**
* @param allDatas
* @param allAttributes
*/
public DecisionTree(ArrayList<ArrayList<String>> allDatas,
ArrayList<String> allAttributes)
{
super();
this.allDatas = allDatas;
this.allAttributes = allAttributes;
}

public ArrayList<ArrayList<String>> getAllDatas()
{
return allDatas;
}

public void setAllDatas(ArrayList<ArrayList<String>> allDatas)
{
this.allDatas = allDatas;
}

public ArrayList<String> getAllAttributes()
{
return allAttributes;
}

public void setAllAttributes(ArrayList<String> allAttributes)
{
this.allAttributes = allAttributes;
}

/**
* 递归生成决策数
* @return
*/
public static TreeNode generateDecisionTree(ArrayList<ArrayList<String>> datas, ArrayList<String> attrs)
{
TreeNode treeNode = new TreeNode();

//如果D中的元素都在同一类C中,then
if(isInTheSameClass(datas))
{
treeNode.setName(datas.get(0).get(datas.get(0).size()-1));
//			rootNode.setName();
return treeNode;
}
//如果attrs为空,then(这种情况一般不会出现,我们应该是要对所有的候选属性集合构建决策树)
if(attrs.size()==0)
return treeNode;

CriterionID3 criterionID3 = new CriterionID3(datas, attrs);
int splitingCriterionIndex = criterionID3.attributeSelectionMethod();

treeNode.setName(attrs.get(splitingCriterionIndex));
treeNode.setRules(getValueSet(datas, splitingCriterionIndex));

attrs.remove(splitingCriterionIndex);

Map<String, ArrayList<ArrayList<String>>> subDatasMap = criterionID3.getSubDatasMap(splitingCriterionIndex);
//		for(String key:subDatasMap.keySet())
//		{
//			System.out.println("===========");
//			System.out.println(key);
//			for(int i=0;i<subDatasMap.get(key).size();i++)
//			{
//				for(int j=0;j<subDatasMap.get(key).get(i).size();j++)
//				{
//					System.out.print(subDatasMap.get(key).get(i).get(j)+" ");
//				}
//				System.out.println();
//			}
//		}

for(String key:subDatasMap.keySet())
{
ArrayList<TreeNode> treeNodes = treeNode.getChildren();
treeNodes.add(generateDecisionTree(subDatasMap.get(key), attrs));
treeNode.setChildren(treeNodes);
}

return treeNode;
}

/**
* 获取datas中index列的值域
* @param data
* @param index
* @return
*/
public static ArrayList<String> getValueSet(ArrayList<ArrayList<String>> datas,int index)
{
ArrayList<String> values = new ArrayList<String>();
String r = "";
for (int i = 0; i < datas.size(); i++) {
r = datas.get(i).get(index);
if (!values.contains(r)) {
values.add(r);
}
}
return values;
}

/**
* 最后一列是类标号,判断最后一列是否相同
* @param datas
* @return
*/
private static boolean isInTheSameClass(ArrayList<ArrayList<String>> datas)
{
String flag = datas.get(0).get(datas.get(0).size()-1);//第0行,最后一列赋初值
for(int i=0;i<datas.size();i++)
{
if(!datas.get(i).get(datas.get(i).size()-1).equals(flag))
return false;
}
return true;
}

public static void main(String[] args)
{
String dataPath = "files/data.txt";
String attrPath = "files/attr.txt";

//初始化原始数据
DecisionTree decisionTree = new DecisionTree(dataPath,attrPath);

//生成决策树
TreeNode treeNode = generateDecisionTree(decisionTree.getAllDatas(), decisionTree.getAllAttributes());

print(treeNode,0);
}

private static void print(TreeNode treeNode,int level)
{
for(int i=0;i<level;i++)
System.out.print("\t");
System.out.print(treeNode.getName());
System.out.print("(");
for(int i=0;i<treeNode.getRules().size();i++)
System.out.print((i+1)+":"+treeNode.getRules().get(i)+"; ");
System.out.println(")");

ArrayList<TreeNode> treeNodes = treeNode.getChildren();
for(int i=0;i<treeNodes.size();i++)
{
print(treeNodes.get(i),level+1);
}
}

}


CriterionID3.java

package com.zhyoulun.decision;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
* ID3,选择分裂准则
*
* @author zhyoulun
*
*/
public class CriterionID3
{
private ArrayList<ArrayList<String>> datas;
private ArrayList<String> attributes;

private Map<String, ArrayList<ArrayList<String>>> subDatasMap;

/**
* 计算所有的信息增益,获取最大的一项作为分裂属性
* @return
*/
public int attributeSelectionMethod()
{
double gain = -1.0;
int maxIndex = 0;
for(int i=0;i<this.attributes.size()-1;i++)
{
double tempGain = this.calcGain(i);
if(tempGain>gain)
{
gain = tempGain;
maxIndex = i;
}
}

return maxIndex;
}

/**
* 计算 Gain(age)=Info(D)-Info_age(D) 等
* @param index
* @return
*/
/**
* @param index
* @param isCalcSubDatasMap
* @return
*/
private double calcGain(int index)
{
double result = 0;

//计算Info(D)
int lastIndex = datas.get(0).size()-1;
ArrayList<String> valueSet = DecisionTree.getValueSet(this.datas,lastIndex);
for(String value:valueSet)
{
int count = 0;
for(int i=0;i<datas.size();i++)
{
if(datas.get(i).get(lastIndex).equals(value))
count++;
}

result += -(1.0*count/datas.size())*Math.log(1.0*count/datas.size())/Math.log(2);
//			System.out.println(result);
}
//		System.out.println("==========");

//计算Info_a(D)
valueSet = DecisionTree.getValueSet(this.datas,index);

//		for(String temp:valueSet)
//			System.out.println(temp);
//		System.out.println("==========");

for(String value:valueSet)
{
ArrayList<ArrayList<String>> subDatas = new ArrayList<>();
for(int i=0;i<datas.size();i++)
{
if(datas.get(i).get(index).equals(value))
subDatas.add(datas.get(i));
}

//			for(ArrayList<String> temp:subDatas)
//			{
//				for(String temp2:temp)
//					System.out.print(temp2+" ");
//				System.out.println();
//			}

ArrayList<String> subValueSet = DecisionTree.getValueSet(subDatas, lastIndex);

//			System.out.print("subValueSet:");
//			for(String temp:subValueSet)
//				System.out.print(temp+" ");
//			System.out.println();

for(String subValue:subValueSet)
{
//				System.out.println("+++++++++++++++");
//				System.out.println(subValue);
int count = 0;
for(int i=0;i<subDatas.size();i++)
{
if(subDatas.get(i).get(lastIndex).equals(subValue))
count++;
}
//				System.out.println(count);
result += -1.0*subDatas.size()/datas.size()*(-(1.0*count/subDatas.size())*Math.log(1.0*count/subDatas.size())/Math.log(2));
//				System.out.println(result);
}

}

return result;

}

public CriterionID3(ArrayList<ArrayList<String>> datas,
ArrayList<String> attributes)
{
super();
this.datas = datas;
this.attributes = attributes;
}

public ArrayList<ArrayList<String>> getDatas()
{
return datas;
}

public void setDatas(ArrayList<ArrayList<String>> datas)
{
this.datas = datas;
}

public ArrayList<String> getAttributes()
{
return attributes;
}

public void setAttributes(ArrayList<String> attributes)
{
this.attributes = attributes;
}

public Map<String, ArrayList<ArrayList<String>>> getSubDatasMap(int index)
{
ArrayList<String> valueSet = DecisionTree.getValueSet(this.datas, index);
this.subDatasMap = new HashMap<String, ArrayList<ArrayList<String>>>();

for(String value:valueSet)
{
ArrayList<ArrayList<String>> subDatas = new ArrayList<>();
for(int i=0;i<this.datas.size();i++)
{
if(this.datas.get(i).get(index).equals(value))
subDatas.add(this.datas.get(i));
}
for(int i=0;i<subDatas.size();i++)
{
subDatas.get(i).remove(index);
}
this.subDatasMap.put(value, subDatas);
}

return subDatasMap;
}

public void setSubDatasMap(Map<String, ArrayList<ArrayList<String>>> subDatasMap)
{
this.subDatasMap = subDatasMap;
}

}


TreeNode.java

package com.zhyoulun.decision;

import java.util.ArrayList;

public class TreeNode
{
private String name; 								// 该结点的名称(分裂属性)
private ArrayList<String> rules; 				// 结点的分裂规则(假设均为离散值)
//	private ArrayList<ArrayList<String>> datas; 	// 划分到该结点的训练元组(datas.get(i)表示一个训练元组)
//	private ArrayList<String> candidateAttributes; // 划分到该结点的候选属性(与训练元组的个数一致)
private ArrayList<TreeNode> children; 			// 子结点

public TreeNode()
{
this.name = "";
this.rules = new ArrayList<String>();
this.children = new ArrayList<TreeNode>();
//		this.datas = null;
//		this.candidateAttributes = null;
}

public String getName()
{
return name;
}

public void setName(String name)
{
this.name = name;
}

public ArrayList<String> getRules()
{
return rules;
}

public void setRules(ArrayList<String> rules)
{
this.rules = rules;
}

public ArrayList<TreeNode> getChildren()
{
return children;
}

public void setChildren(ArrayList<TreeNode> children)
{
this.children = children;
}

//	public ArrayList<ArrayList<String>> getDatas()
//	{
//		return datas;
//	}
//
//	public void setDatas(ArrayList<ArrayList<String>> datas)
//	{
//		this.datas = datas;
//	}
//
//	public ArrayList<String> getCandidateAttributes()
//	{
//		return candidateAttributes;
//	}
//
//	public void setCandidateAttributes(ArrayList<String> candidateAttributes)
//	{
//		this.candidateAttributes = candidateAttributes;
//	}

}


参考:《数据挖掘概念与技术(第3版)》

转载请注明出处:
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息