您的位置:首页 > 其它

kaggle——泰坦尼克号生死预测

2018-04-02 09:50 393 查看
把很久以前做的泰坦尼克号的代码贴出来。

# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 14:23:12 2017

@author: Yichengfan
"""

import pandas as pd

train = pd.read_csv(r"F:\TS\03_other_parts\Titanic\02_data\train.csv")
test = pd.read_csv(r"F:\TS\03_other_parts\Titanic\02_data\test.csv")

#先分别输出训练集和测试数据的基本信息,这是一个好习惯,可以对数据的规模,
#各个特征的数据类型以及是否缺失等,有一个整体的了解
print(train.info())
print(test.info())
'''
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
PassengerId    891 non-null int64
Survived       891 non-null int64
Pclass         891 non-null int64
Name           891 non-null object
Sex            891 non-null object
Age            714 non-null float64
SibSp          891 non-null int64
Parch          891 non-null int64
Ticket         891 non-null object
Fare           891 non-null float64
Cabin          204 non-null object
Embarked       889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.6+ KB
None
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 418 entries, 0 to 417
Data columns (total 11 columns):
PassengerId    418 non-null int64
Pclass         418 non-null int64
Name           418 non-null object
Sex            418 non-null object
Age            332 non-null float64
SibSp          418 non-null int64
Parch          418 non-null int64
Ticket         418 non-null object
Fare           417 non-null float64
Cabin          91 non-null object
Embarked       418 non-null object
dtypes: float64(2), int64(4), object(5)
memory usage: 36.0+ KB
None
'''

selectd_features = ['Pclass','Sex', 'Age', 'Embarked','SibSp','Parch', 'Fare']

X_train = train[selectd_features]
X_test = test[selectd_features]

y_train = train['Survived']

#通过之前对数据的总体观察,得知Embarked特征存在缺失值,需要补充
print (X_train['Embarked'].value_counts())
print (X_test['Embarked'].value_counts())
'''
S    644
C    168
Q     77
Name: Embarked, dtype: int64
S    270
C    102
Q     46
Name: Embarked, dtype: int64
'''

#对于Embarked这种类别的型的特征,我们使用出现频率最高的特征值来填充,
#这也是相对可以减少引入误差的一种填充方法

X_train['Embarked'].fillna('S', inplace = True)
X_test['Embarked'].fillna('S', inplace = True)

#对于Age这种数值类型的特征,我们习惯使用求平均值或者中位数来填充缺失值,
#也是相对可以减少引入误差的一种填充方法

X_train['Age'].fillna(X_train['Age'].mean(), inplace = True)
X_test['Age'].fillna(X_test['Age'].mean(), inplace = True)
X_test['Fare'].fillna(X_test['Fare'].mean(), inplace = True)

#重新处理后的训练和测试数据进行验证
print(X_train.info())
print(X_test.info())
'''
\<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 7 columns):
Pclass      891 non-null int64
Sex         891 non-null object
Age         891 non-null float64
Embarked    891 non-null object
SibSp       891 non-null int64
Parch       891 non-null int64
Fare        891 non-null float64
dtypes: float64(2), int64(3), object(2)
memory usage: 48.8+ KB
None
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 418 entries, 0 to 417
Data columns (total 7 columns):
Pclass      418 non-null int64
Sex         418 non-null object
Age         418 non-null float64
Embarked    418 non-null object
SibSp       418 non-null int64
Parch       418 non-null int64
Fare        418 non-null float64
dtypes: float64(2), int64(3), object(2)
memory usage: 22.9+ KB
None
'''

#接下来采用DictVectorizer对特征进行向量化

from sklearn.feature_extraction import DictVectorizer

dict_vec = DictVectorizer(sparse = False)
X_train = dict_vec.fit_transform(X_train.to_dict(orient = 'record'))
dict_vec.feature_names_
'''
['Age',
'Embarked=C',
'Embarked=Q',
'Embarked=S',
'Fare',
'Parch',
'Pclass',
'Sex=female',
'Sex=male',
'SibSp']
'''

X_test = dict_vec.fit_transform(X_test.to_dict(orient = 'record'))

#从sklearn中引入 RandomForestClassifier
from sklearn.ensemble import RandomForestClassifier
#使用默认配置初始化RandomForestClassifier
rfc = RandomForestClassifier()

#从流行的工具包XGBoost导入XGBClassifier
from xgboost import XGBClassifier
xgbc = XGBClassifier()

from sklearn.cross_validation import cross_val_score

#使用5折交叉验证的方法在训练集上分别对默认配置的RandomForestClassifier和
#XGBClassifier进行性能评估,并获得平均分类器准确的得分

cross_val_score(rfc, X_train, y_train, cv= 5).mean()
'''
0.80476830342149963
'''

cross_val_score(xgbc, X_train, y_train, cv= 5).mean()
'''
0.81824559798311003
'''

from sklearn.linear_model import LogisticRegression
lr = LogisticRegression()
cross_val_score(lr, X_train, y_train, cv= 5).mean()
'''
0.79128522828142689
'''

#使用默认配置的RandomForestClassifier进行预操作
rfc.fit(X_train, y_train)
rfc_y_predict = rfc.predict(X_test)
rfc_submission = pd.DataFrame({'PassengerId':test['PassengerId'],
'Survived':rfc_y_predict})
#将RandomForestClassifier测试数据存储在文件中
rfc_submission.to_csv(r'F:\TS\03_other_parts\Titanic\04_output\rfc_submission.csv'
,index = False)

#使用默认配置的RandomForestClassifier进行预操作
xgbc.fit(X_train, y_train)
'''
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
colsample_bytree=1, gamma=0, learning_rate=0.1, max_delta_step=0,
max_depth=3, min_child_weight=1, missing=None, n_estimators=100,
n_jobs=1, nthread=None, objective='binary:logistic', random_state=0,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
silent=True, subsample=1)
'''
xgbc_y_predict = xgbc.predict(X_test)
xgbc_submission = pd.DataFrame({'PassengerId':test['PassengerId'],
'Survived':xgbc_y_predict})
#将RandomForestClassifier测试数据存储在文件中
xgbc_submission.to_csv(r'F:\TS\03_other_parts\Titanic\04_output\xgbc_submission.csv'
,index = False)

#使用并行网格搜索的方式寻找更好的超参数组合,以期待进一步提供XGBClassifier的预测性能
from sklearn.grid_search import GridSearchCV
params = {'max_depth':list(range(2,7)),'n_estimators':list(range(100,1100,200)),
'learning_rate':[0.05,0.1,0.25,0.5,1.0]}

xgbc_best = XGBClassifier()
#n_jobs= -1使用计算机全部的CPU核数
gs = GridSearchCV(xgbc_best, params, n_jobs= -1, cv = 5,verbose = 1)
gs.fit(X_train, y_train)

#使用经过优化超参数配置的XGBClassifier的超参数配置以及交叉验证的准确性
print (gs.best_score_)
print (gs.best_params_)

#使用经过优化的超参数配置的XGBClassifier对测试数据的预测结果存储在文件xgbc_best_submission中
xgbc_best_y_predict = gs.predict(X_test)
xgbc_best_submission = pd.DataFrame({'PassengerId':test['PassengerId'],
'Survived':rfc_y_predict})
xgbc_best_submission.to_csv(r'F:\TS\03_other_parts\Titanic\04_output\xgbc_submission.csv' ,index = False)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: