机器学习实战——改进约会网站匹配效果

来源:互联网 发布:直播cms 编辑:程序博客网 时间:2024/06/06 03:46

接上文,改进约会网站的匹配效果,数据集有四列,分别为:飞行时间,玩游戏时间,冰淇淋消费,是否为感兴趣的约会对象。其中是否为感兴趣的约会对象分为三类:不感兴趣,有点感兴趣和非常感兴趣。

def file2matrix(filename):  #读入文本记录
    fr = open(filename)
    numberOfLines = len(fr.readlines())         #get the number of lines in the file
    returnMat = zeros((numberOfLines,3))        #prepare matrix to return
    classLabelVector = []                       #prepare labels return  
    fr = open(filename)
    index = 0
    for line in fr.readlines():
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat,classLabelVector
   

 len(fr.readlines()) :获取整个文件有多少行

 zeros((numberOfLines,3))   :生成一个空的矩阵,内容都是0,这样生成二维矩阵,可以明确有几行几列

returnMat[index,:]  :表示对returnMat中第index行所有元素按从头到尾顺序赋值,:前后都省略,表示从编号0项开始直到最后一位

 listFromLine[0:3]   :实际上是左闭右开区间,包括0但不包括3

.append :是list中不断在末尾增加值的方法

这里主要说明了python中读文件和将文件内容转化为矩阵


def autoNorm(dataSet):    #数据归一化
    minVals = dataSet.min(0)   #取最小值
    maxVals = dataSet.max(0) #取最大值
    ranges = maxVals - minVals
    normDataSet = zeros(shape(dataSet)) #建一个和dataSet形状相同的矩阵
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals, (m,1))
    normDataSet = normDataSet/tile(ranges, (m,1))   #newValue=(oldValue-min)/(max-min)
    return normDataSet, ranges, minVals

normDataSet = zeros(shape(dataSet)) :建一个和dataSet形状相同的矩阵,用0填充

这里的归一化,也是全部用矩阵处理,比起写循环简练很多


def datingClassTest():       #计算准确率
    hoRatio = 0.50      #hold out 10%
    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data setfrom file
    normMat, ranges, minVals = autoNorm(datingDataMat)  #归一化
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
        print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
        if (classifierResult != datingLabels[i]): errorCount += 1.0   #计算预测错误的个数
    print "the total error rate is: %f" % (errorCount/float(numTestVecs))
    print errorCount




0 0