python机器学习案例系列教程——线性分类器

来源:互联网 发布:ubuntu 安装软件中心 编辑:程序博客网 时间:2024/05/17 06:12

全栈工程师开发手册 (作者:栾鹏)

python数据挖掘系列教程

场景:男女不同的属性信息,例如年龄、是否吸烟、是否要孩子、兴趣列表、家庭住址。产生的输出结果,配对成功还是不成功。

我们先使用一个简单的年龄配对的数据集agesonly.csv,如下
每一行分别是女性年龄、男性年龄、是否配对成功(1表示配对成功,0表示配对不成功)

24,30,1

数据构造函数

# 定义数据类class matchrow:    def __init__(self,row,allnum=False):        if allnum:            self.data=[float(row[i]) for i in range(len(row)-1)]   #如果每个属性都是数字就转化为浮点型        else:            self.data=row[0:len(row)-1]  #如果并不是数字,就保留源数据类型        self.match=int(row[len(row)-1])   #最后一位表示匹配结果,0表示匹配失败,1表示匹配成功# 从文件中加载数据.allnum表示是否所有属性都是数字def loadmatch(filename,allnum=False):    rows=[]    for line in open(filename):        rows.append(matchrow(line.split(','),allnum))    return rows

为了查看方便我们使用图形的方式展示。
可视化函数

import matplotlib.pyplot as plt# 绘制只根据年龄进行配对的结果分布散点图def plotagematches(rows):    xdm,ydm=[r.data[0] for r in rows if r.matchresult==1],[r.data[1] for r in rows if r.matchresult==1]    xdn,ydn=[r.data[0] for r in rows if r.matchresult==0],[r.data[1] for r in rows if r.matchresult==0]    plt.plot(xdm,ydm,'bo')    plt.plot(xdn,ydn,'b+')    plt.show()

调用代码

if __name__=='__main__':  #只有在执行当前模块时才会运行此函数    agesonly = loadmatch('agesonly.csv')    #读入只关注年龄的配对情况    plotagematches(agesonly)    #绘制年龄配对散点图

这里写图片描述

上图展示了年龄配对数据的可视化效果。+号表示配对失败的,实心圆表示配对成功的。横坐标为男性年龄,纵坐标为女性年龄。

使用决策树分类

我们知道决策树更适合对离散型数据进行分类,如果是连续型数据(年龄)就需要对数据进行区域划分。

在查看线性分类器以前,这里我们可以先看一下决策树尝试的结果。

这里写图片描述

最终产生的决策树非常复杂,效果不佳。

这里写图片描述

基本的线性分类

线性分类的原理很简单:计算样本数据每个分类中所有节点的平均值。对新输入对象计算到哪个中心点最近就属于哪个分类。

对于只有场景中只有两个分类(匹配成功或不成功)的情况,即两个分类的均值点连线的垂直平分线就是分割线。

首先要实现一个计算样本数据集每个分类均值点的函数

# 使用基本的线性分类。rows为样本数据集。(计算样本数据每个分类中所有节点的平均值。对新输入对象计算到哪个中心点最近就属于哪个分类)def lineartrain(rows):    averages={}    counts={}    for row in rows:        # 得到该坐标点所属的分类        cat=row.matchresult        averages.setdefault(cat,[0.0]*(len(row.data)))        counts.setdefault(cat,0)        # 将该坐标点加入averages中。每个维度都要求均值        for i in range(len(row.data)):            averages[cat][i]+=float(row.data[i])        # 记录每个分类中有多少个坐标点        counts[cat]+=1    # 将总和除以计数值以求得平均值    for cat,avg in averages.items():        for i in range(len(avg)):            avg[i]/=counts[cat]    return averages

我们通过画图查看显示效果

可视化函数

# 绘制线性分类器均值点和分割线def plotlinear(rows):    xdm,ydm=[r.data[0] for r in rows if r.matchresult==1],[r.data[1] for r in rows if r.matchresult==1]    xdn,ydn=[r.data[0] for r in rows if r.matchresult==0],[r.data[1] for r in rows if r.matchresult==0]    plt.plot(xdm,ydm,'bo')    plt.plot(xdn,ydn,'b+')    # 获取均值点    averages = lineartrain(rows)    #绘制均值点    averx=[]    avery=[]    for value in averages.values():        averx.append(value[0])        avery.append(value[1])        plt.plot(averx,avery,'r*')    #绘制垂直平分线作为分割线    # y=-(x1-x0)/(y1-y0)* (x-(x0+x1)/2)+(y0+y1)/2    xnew = range(15,60,1)    print(xnew)    print(averx,avery)    ynew = [-(averx[1]-averx[0])/(avery[1]-avery[0])*(x-(averx[0]+averx[1])/2)+(avery[0]+avery[1])/2 for x in xnew]    plt.plot(xnew, ynew, 'r--')    plt.axis([15, 52, 15, 50])  #设置显示范围    plt.show()

调用函数

if __name__=='__main__':  #只有在执行当前模块时才会运行此函数    agesonly = loadmatch('agesonly.csv')    #读入只关注年龄的配对情况    plotlinear(agesonly)   #绘制线性分类器均值点和分割线

这里写图片描述

图中红色五角星为两种分类的均值点,红色线为分割线。由于不匹配的用户在并不分布在连续区域,而且两个分类的所在区域也不相同,线性分类基本无法使用。

更复杂的线性分类器

上面的线性分类器使用数值的样本数据,距离哪个最近使用的欧几里德距离,而实际中我们会遇到的情况可能更复杂。还是那婚介数据集为例,我们进行复杂数据集的线性分类。

数据集存储在matchmaker.csv,每行数据格式如下
每一行分表表示女性年龄,是否吸烟、是否要孩子、兴趣列表、家庭住址、男性年龄、是否吸烟、是否要孩子、兴趣列表、家庭住址、是否配对成功

39,yes,no,skiing:knitting:dancing,220 W 42nd St New York NY,43,no,yes,soccer:reading:scrabble,824 3rd Ave New York NY,0

在数据集中除了数值型数据(年龄),还有分类型数据(是否吸烟,是否要孩子),还有列表型数据(兴趣列表),以及信息型数据(家庭住址)。而这些信息对于最终的匹配结果都是有用的。

我们需要构造一个新的数据集,将属性全部划归为数值型数据,并对不同的取值范围归一化缩放的相同范围,并且采用另一种距离计算方式——点乘。

当然具体如果转化要看实际问题,我们这里只是以婚介数据集为例,实现了针对婚介数据集的转化方式。

将是否问题转化为数值

# 将是否问题转化为数值。yes转化为1,no转化为-1,缺失或模棱两可转化为0def yesno(v):    if v=='yes': return 1    elif v=='no': return -1    else: return 0

将列表转化为数值

# 将列表转化为数值。获取公共项的数目。获取两个人相同的兴趣数量def matchcount(interest1,interest2):    l1=interest1.split(':')    l2=interest2.split(':')    x=0    for v in l1:        if v in l2: x+=1    return x

将家庭住址信息转化为距离数值

# 利用百度地图来计算两个人的位置距离baidukey="tc42noD8p3SO1hZhFTryMeRv"import urllibimport json# 使用geocoding api发起指定格式的请求,解析指定格式的返回数据,获取地址的经纬度# http://api.map.baidu.com/geocoder/v2/?address=北京市海淀区上地十街10号&output=json&ak=您的ak&callback=showLocationak ='HIa8GVmtk9WSjhuevGfqMCGu'loc_cache={}def getlocation(address):   #这个结果每次获取最好存储在数据库中,不然每次运行都要花费大量的时间获取地址    if address in loc_cache: return loc_cache[address]    urlpath = 'http://api.map.baidu.com/geocoder/v2/?address=%s&output=json&ak=%s' % (urllib.parse.quote_plus(address),ak)    data=urllib.request.urlopen(urlpath).read()    response = json.loads(data,encoding='UTF-8')  # dict    if not response['result']:        print('没有找到地址:'+address)        return None    long = response['result']['location']['lng']    lat = response['result']['location']['lat']    loc_cache[address]=(float(lat),float(long))    print('地址:' + address+"===经纬度:"+str(loc_cache[address]))    return loc_cache[address]# 计算两个地点之间的实际距离def milesdistance(a1,a2):    try:        lat1,long1=getlocation(a1)        lat2,long2=getlocation(a2)        latdif=69.1*(lat2-lat1)        longdif=53.0*(long2-long1)        return (latdif**2+longdif**2)**.5    except:        return None

构造新的数据集

# 构造新的数据集。包含各个复杂属性转化为数值数据def loadnumerical():    oldrows=loadmatch('matchmaker.csv')    newrows=[]    for row in oldrows:        d=row.data        distance = milesdistance(d[4],d[9])  # 以为有可能无法获取地址的经纬度,进而无法获取两地之间的距离,这里就成了缺失值。我们暂且直接抛弃缺失值        if distance:            data=[float(d[0]),yesno(d[1]),yesno(d[2]),                  float(d[5]),yesno(d[6]),yesno(d[7]),                  matchcount(d[3],d[8]),distance,row.matchresult]            newrows.append(matchrow(data))    return newrows

对数值数据进行缩放

# 对数据进行缩放处理,全部归一化到0-1上,因为不同参考变量之间的数值尺度不同def scaledata(rows):    low=[999999999.0]*len(rows[0].data)    high=[-999999999.0]*len(rows[0].data)    # 寻找最大值和最小值    for row in rows:        d=row.data        for i in range(len(d)):            if d[i]<low[i]: low[i]=d[i]            if d[i]>high[i]: high[i]=d[i]    # 对数据进行缩放处理的函数    def scaleinput(d):        return [(d[i]-low[i])/(high[i]-low[i])                for i in range(len(low))]    # 对所有数据进行缩放处理    newrows=[matchrow(scaleinput(row.data)+[row.matchresult]) for row in rows]    # 返回新的数据和缩放处理函数    return newrows,scaleinput

测试代码

现在我们可以使用线性分类器对复杂数据集进行分类了。

if __name__=='__main__':  #只有在执行当前模块时才会运行此函数    numercalset=loadnumerical()   #获取转化为数值型的复杂数据集    scaledset,scalef=scaledata(numercalset)  #对复杂数据集进行比例缩放    catavgs = lineartrain(scaledset)   #计算分类均值点    print(catavgs)    onedata = scalef(numercalset[0].data) #取一个数据作为新数据先比例缩放    dpclassify(onedata,catavgs)   #使用点积结果来判断属于哪个分类
阅读全文
0 0