您的位置:首页 > 编程语言 > Python开发

Python 决策树算法(ID3 & C4.5)

2016-11-14 23:09 369 查看
决策树(DecisionTree)算法:按照样本的属性逐步进行分类,为了能够使分类更快、更有效。每一个新分类属性的选择依据可以是信息增益IG和信息增益率IGR,前者为最基本的ID3算法,后者为改进后的C4.5算法。

以ID3为例,其训练过程的编程思路如下:

(1)输入x、y(x为样本,y为label),行为样本,列为样本特征。

(2)计算信息增益IG,获取使IG最大的特征。

(3)获得删除最佳分类特征后的样本阵列。

(4)按照最佳分类特征的属性值将更新后的样本进行归类。

属性值1(x1,y1)   属性值2(x2,y2)   属性值(x3,y3)

(5)分别对以上类别重复以上操作直至到达叶节点(递归调用)。

叶节点的特征:

(1)所有的标签值y都一样。

(2)没有特征可以继续划分。

测试过程的编程思路如下:

(1)读取训练好的决策树。

(2)从根节点开始递归遍历整个决策树直到到达叶节点为止。

以下为具体代码,训练后的决策树结构为递归套用的字典,其是由特征值组成的索引加上label组成的。
# -*- coding: utf-8 -*-
"""

Created on Mon Nov 07 09:06:37 2016

@author: yehx

"""
# -*- coding: utf-8 -*-
"""

Created on Sun Feb 21 12:17:10 2016

Decision Tree Source Code

@author: liudiwei

"""
import os
import numpy as np

class DecitionTree():

    """Thisis a decision tree classifier. """

    

    def __init__(self, criteria='ID3'):

       self._tree = None

       if criteria == 'ID3' or criteria == 'C4.5':

          self._criteria = criteria

       else:

          raise Exception("criterionshould
be ID3 or C4.5")

    

    def _calEntropy(slef, y):

       '''

       功能:_calEntropy用于计算香农熵 e=-sum(pi*logpi)

       参数:其中y为数组array

       输出:信息熵entropy

       '''

       n = y.shape[0]  

       labelCounts = {}

       for label in y:

          if label not in labelCounts.keys():

              labelCounts[label] = 1

          else:

              labelCounts[label] += 1

       entropy = 0.0

       for key in labelCounts:

          prob = float(labelCounts[key])/n

           entropy-= prob* np.log2(prob)

       return entropy

    

    def _splitData(self, X, y, axis, cutoff):

       """

      参数:X为特征,y为label,axis为某个特征的下标,cutoff是下标为axis特征取值值

       输出:返回数据集中特征下标为axis,特征值等于cutoff的子数据集

       先将特征列从样本矩阵里除去,然后将属性值为cutoff的数据归为一类

       """

       ret = []

       featVec = X[:,axis]

       n = X.shape[1]   
 #特征个数

       #除去第axis列特征后的样本矩阵

       X = X[:,[i for i in range(n) if i!=axis]]

       for i in range(len(featVec)):

          if featVec[i] == cutoff:

              ret.append(i)

   
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: