您的位置:首页 > 其它

决策树-泰坦尼克号生还预测

2017-08-18 11:53 323 查看
LR和SVM都在某种程度上要求被学习的数据特征和目标之间遵照线性假设。然后许多现实场景下,这种假设不存在。
比如根据年龄预测流感的死亡率,如果用线性模型假设,那只有两个可能:年龄越大/越小,死亡率越高。根据经验,青壮年更不容易因患流感而死亡。年龄和因流感的死亡不存在线性关系。
在机器学习模型中,决策树是描述非线性关系的不二之选。
信用卡申请的审核,涉及多项特征,是典型的决策树模型。对于是否同意申请,是二分类决策任务,只有yes/no两种分类结果。
使用多种不同特征组合搭建多层决策树的情况,模型在学习的时候需要考虑特征节点的选取顺序。常用的方式包括信息熵(Information Gain)和基尼不纯性(Gini Impurity)。本文不做讨论。sklearn中默认配置的决策树模型使用的是Gini impurity作为排序特征的度量指标。
虽然很难获取信用卡的客户资料,但有类似的借助客户档案进行二分类的任务。
本文进行泰坦尼克号的乘客的生还预测,许多专家尝试通过计算机模拟和分析找出隐藏在数据背后的生还逻辑。

Python源码:
#coding=utf-8
import pandas as pd
#-------------data split
from sklearn.cross_validation import train_test_split
#-------------feature transfer
from sklearn.feature_extraction import DictVectorizer
#-------------
from sklearn.tree import DecisionTreeClassifier
#-------------
from sklearn.metrics import classification_report

#-------------download data
titanic=pd.read_csv('http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic.txt')
print titanic.head()
#transfer to dataFrame format by pandas,use info() to show statistics of data
print titanic.info()
#-------------feature selection
X=titanic[['pclass','age','sex']]
y=titanic['survived']

print 'bf processing\n',X.info()
#-------------feature processing
X['age'].fillna(X['age'].mean(),inplace=True)
print 'af processing\n',X.info
#-------------data split
#75% training set,25% testing set
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.25,random_state=33)
#-------------feature transfer  from String to int
vec=DictVectorizer(sparse=False)
X_train=vec.fit_transform(X_train.to_dict(orient='record'))
#print vec.feature_names  60
#AttributeError: 'DictVectorizer' object has no attribute 'feature_names'
print vec.get_feature_names()
X_test=vec.transform(X_test.to_dict(orient='record'))
#-------------training
#initialize
dtc=DecisionTreeClassifier()
dtc.fit(X_train,y_train)
y_predict=dtc.predict(X_test)
#-------------performance
print 'The Accuracy is',dtc.score(X_test,y_test)
print classification_report(y_test,y_predict,target_names=['died','survived'])


Result:
   row.names pclass  survived  \

0          1    1st         1

1          2    1st         0

2          3    1st         0

3          4    1st         0

4          5    1st         1

                                              name      age     embarked  \

0                     Allen, Miss Elisabeth Walton  29.0000  Southampton

1                      Allison, Miss Helen Loraine   2.0000  Southampton

2              Allison, Mr Hudson Joshua Creighton  30.0000  Southampton

3  Allison, Mrs Hudson J.C. (Bessie Waldo Daniels)  25.0000  Southampton

4                    Allison, Master Hudson Trevor   0.9167  Southampton

                         home.dest room      ticket   boat     sex

0                     St Louis, MO  B-5  24160 L221      2  female

1  Montreal, PQ / Chesterville, ON  C26         NaN    NaN  female

2  Montreal, PQ / Chesterville, ON  C26         NaN  (135)    male

3  Montreal, PQ / Chesterville, ON  C26         NaN    NaN  female

4  Montreal, PQ / Chesterville, ON  C22         NaN     11    male

<class 'pandas.core.frame.DataFrame'>

RangeIndex: 1313 entries, 0 to 1312

Data columns (total 11 columns):

row.names    1313 non-null int64

pclass       1313 non-null object

survived     1313 non-null int64

name         1313 non-null object

age          633 non-null float64

embarked     821 non-null object

home.dest    754 non-null object

room         77 non-null object

ticket       69 non-null object

boat         347 non-null object

sex          1313 non-null object

dtypes: float64(1), int64(2), object(8)

memory usage: 112.9+ KB

None

bf processing

<class 'pandas.core.frame.DataFrame'>

RangeIndex: 1313 entries, 0 to 1312

Data columns (total 3 columns):

pclass    1313 non-null object

age       633 non-null float64

sex       1313 non-null object

dtypes: float64(1), object(2)

memory usage: 30.8+ KB

None

/Users/mac/workspace/conda/anaconda/lib/python2.7/site-packages/pandas/core/generic.py:3660: SettingWithCopyWarning:

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  self._update_inplace(new_data)

af processing

<bound method DataFrame.info of      pclass        age     sex

0       1st  29.000000  female

1       1st   2.000000  female

2       1st  30.000000    male

3       1st  25.000000  female

4       1st   0.916700    male

5       1st  47.000000    male

6       1st  63.000000  female

7       1st  39.000000    male

8       1st  58.000000  female

9       1st  71.000000    male

10      1st  47.000000    male

11      1st  19.000000  female

12      1st  31.194181  female

13      1st  31.194181    male

14      1st  31.194181    male

15      1st  50.000000  female

16      1st  24.000000    male

17      1st  36.000000    male

18      1st  37.000000    male

19      1st  47.000000  female

20      1st  26.000000    male

21      1st  25.000000    male

22      1st  25.000000    male

23      1st  19.000000  female

24      1st  28.000000    male

25      1st  45.000000    male

26      1st  39.000000    male

27      1st  30.000000  female

28      1st  58.000000  female

29      1st  31.194181    male

...     ...        ...     ...

1283    3rd  31.194181  female

1284    3rd  31.194181    male

1285    3rd  31.194181    male

1286    3rd  31.194181    male

1287    3rd  31.194181    male

1288    3rd  31.194181    male

1289    3rd  31.194181    male

1290    3rd  31.194181    male

1291    3rd  31.194181    male

1292    3rd  31.194181    male

1293    3rd  31.194181  female

1294    3rd  31.194181    male

1295    3rd  31.194181    male

1296    3rd  31.194181    male

1297    3rd  31.194181    male

1298    3rd  31.194181    male

1299    3rd  31.194181    male

1300    3rd  31.194181    male

1301    3rd  31.194181    male

1302    3rd  31.194181    male

1303    3rd  31.194181    male

1304    3rd  31.194181  female

1305    3rd  31.194181    male

1306    3rd  31.194181  female

1307    3rd  31.194181  female

1308    3rd  31.194181    male

1309    3rd  31.194181    male

1310    3rd  31.194181    male

1311    3rd  31.194181  female

1312    3rd  31.194181    male

[1313 rows x 3 columns]>

['age', 'pclass=1st', 'pclass=2nd', 'pclass=3rd', 'sex=female', 'sex=male']

The Accuracy is 0.781155015198

             precision    recall  f1-score   support

       died       0.78      0.91      0.84       202

   survived       0.80      0.58      0.67       127

avg / total       0.78      0.78      0.77       329

该数据共有1313条乘客信息,有些特征数据是缺失的,有些是数值类型,有些是字符串。
预处理环节中特征的选择十分重要,需要一些背景知识,根据对事故的了解,sex,age,pclass都可能是关键因素。
需要完成的数据处理任务:
1.初始的数据中,age列只有633个需要补充完整,一般,使用平均数或者中位数都是对模型偏离造成最小影响的策略。
2.sex和pclass列的值是列别型,需转化为数值特征,用0/1代替
算法特点:
相对于其它的模型,决策树在模型描述上有巨大的优势,推断逻辑非常直观,具有清晰的可解释性,也方便了模型的可视化。这些特征同时也保证使用该模型时,无需考虑对数据的量化甚至标准化的。与KNN不同,DT仍然属于有参数模型,需花费更多的时间在训练数据上。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息