机器学习笔记:构建简单决策树模型

来源:互联网 发布:薛之谦封面抄袭知乎 编辑:程序博客网 时间:2024/05/16 08:02

简介

本文是优达学城机器学习纳米学位作业,目标是引导学生构建一个决策树模型来预测泰坦尼克号上的乘客生还率。

依赖库

import numpy as npimport pandas as pd# RMS Titanic data visualization codefrom titanic_visualizations import survival_statsfrom IPython.display import display%matplotlib inline

其中自定义库代码见文末附录。

数据

# Load the datasetin_file = 'titanic_data.csv'full_data = pd.read_csv(in_file)# Print the first few entries of the RMS Titanic datadisplay(full_data.head())

  • Survived:是否存活
  • Pclass:社会阶级
  • SibSp:兄弟姐妹与配偶数量
  • Parch:父母与小孩数量
  • Ticket:船票编号
  • Fare:支付费用
  • Embarked:登船港口(C 代表Cherbourg,Q 代表Queenstown,S 代表Southampton)

因为需要预测的是是否存活即Survived项,所以将这一字段数据移动到新变量outcomes中:

outcomes = full_data['Survived']data = full_data.drop('Survived', axis = 1)

决策模型

全部死亡

对于灾难性事件,最简单的决策就是预测乘客全部死亡:

def predictions_0(data):    """ Model with no features. Always predicts a passenger did not survive. """    predictions = []    for _, passenger in data.iterrows():        predictions.append(0)    return pd.Series(predictions)predictions = predictions_0(data)print(accuracy_score(outcomes, predictions))


可以看到乘客的死亡率到达了61.62%,因为即使简单地做出死亡判定也有61.62%的正确率。

加入性别特征

显然这种简单的只作出一种判断的方法是不实用的,我们需要加入一些样本中的特征值来提高准确度。
首先考虑“性别”特征:

survival_stats(data, outcomes, 'Sex')


这这张图不难想到,想要提高准确度最简单的方法就是预测所有女性全部存活,而所有男性全部死亡:

def predictions_1(data):    """ Model with one feature:             - Predict a passenger survived if they are female. """    predictions = []    for _, passenger in data.iterrows():        if(passenger['Sex']=="female"):            predictions.append(1)        else:            predictions.append(0)    return pd.Series(predictions)predictions = predictions_1(data)print(accuracy_score(outcomes, predictions))


准确率一下子提高到了78.68%。

加入年龄特征

为进一步提高准确度,需要在男性乘客中找出存活率较高乘客的特征,先考虑年龄:

survival_stats(data, outcomes, 'Age', ["Sex == 'male'"])

可以看到10岁以下的小男孩基本上都存活了下来,所以可以在上一布的基础上再预测10岁以下的男性全部存活:

def predictions_2(data):    """ Model with two features:             - Predict a passenger survived if they are female.            - Predict a passenger survived if they are male and younger than 10. """    predictions = []    for _, passenger in data.iterrows():        if((passenger['Sex']=="female")or(passenger['Age']<10)):            predictions.append(1)        else:            predictions.append(0)    return pd.Series(predictions)predictions = predictions_2(data)print(accuracy_score(outcomes, predictions))

多特征

如果需要再提高准确度,需要在活下来的人群中筛选出更多特征,按不同限定条件分别绘制存活人数与特征的关系已找出存活率更高的特征人群。

此处的难点在于如何根据特征来划分这些人群,从而在由不同特征集所划分的人群中找到存活率高的特征集。

性别与舱位

优先考虑女性:

survival_stats(data, outcomes, 'Pclass', ["Sex == 'female'"])


在决策模型中加入条件:
(passenger['Sex']=="female")and(passenger['Pclass']<3)

年龄与舱位

上一步已包含了中层阶级与上层阶级的女性,此处只考虑男性:

survival_stats(data, outcomes, 'Pclass', ["Sex == 'male'","Age < 10"])


在模型中加入条件:
(passenger['Sex']=='male')and(passenger['Age']<10)and(passenger['Pclass']<3)

于是模型变为:

def predictions_3(data):    """ Model with multiple features."""    predictions = []    for _, passenger in data.iterrows():        # Remove the 'pass' statement below and write your prediction conditions here        if(((passenger['Sex']=="female")and(passenger['Pclass']<3))           or((passenger['Sex']=='male')and(passenger['Age']<10)and(passenger['Pclass']<3))):            predictions.append(1)        else:            predictions.append(0)    return pd.Series(predictions)predictions = predictions_3(data)print(accuracy_score(outcomes, predictions))


作业要求是准确度达到80%,但是在筛选特征的时候,在男性乘客中找出存活率高的人群实在是太难了,这个准确度将就一下算了。

总结

决策树每次按照一个特征把数据分割成越来越小的群组(被称为 nodes)。每次数据的一个子集被分出来,如果分割结果的子集中的数据比之前更同质(包含近似的标签),我们的预测也就更加准确。

原创粉丝点击