Logistic Regression with MLlib

Source: Internet  Published: dede cms  Edited: 程序博客网  Time: 2024/06/16 22:09

MLlib logistic regression

1. Input data:

case_data.txt

1,1 1
1,1.1 0.9
1,1 1.2
2,10 11
2,9 10
2,10 12
3,50 52
3,49 50
3,48 49

from pyspark.mllib.linalg import Vectors
from pyspark.mllib.regression import LabeledPoint

def parseLine(line):
    # Each line has the form "label,feature1 feature2".
    parts = line.split(',')
    label = float(parts[0])
    features = Vectors.dense([float(x) for x in parts[1].split(' ')])
    return LabeledPoint(label, features)

# sc is an existing SparkContext; dataPath points to case_data.txt.
# Despite its name, df is an RDD of LabeledPoint, not a DataFrame.
df = sc.textFile(dataPath).map(parseLine)
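
The parsing step can be checked without a Spark cluster. The following is a minimal plain-Python sketch of the same "label,feature1 feature2" format; `parse_line_plain` is a hypothetical helper introduced here for illustration, and it returns a plain tuple rather than a LabeledPoint.

```python
def parse_line_plain(line):
    """Parse 'label,f1 f2 ...' into (label, [f1, f2, ...])."""
    # Split once on the comma: the left part is the label, the right part the features.
    label_text, feature_text = line.strip().split(',', 1)
    features = [float(x) for x in feature_text.split()]
    return float(label_text), features

# A few lines copied from case_data.txt
sample_lines = ["1,1 1", "1,1.1 0.9", "2,10 11", "3,50 52"]
parsed = [parse_line_plain(line) for line in sample_lines]
for label, features in parsed:
    print(label, features)
```

If this behaves as expected on the sample lines, the Spark version should produce the matching LabeledPoint objects, since it applies the same split logic.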

2. Train the model:

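def logisticRegression(df, arguments):
    """
    Train a logistic regression model on an RDD of LabeledPoint.
    Only supports binary classification, so labels must be 0 or 1;
    the labels 1, 2, 3 in case_data.txt would need to be remapped first.
    """
    # Deprecated since Spark 2.0; LogisticRegressionWithLBFGS is the suggested alternative.
    from pyspark.mllib.classification import LogisticRegressionWithSGD
    maxIter = 100
    if arguments.maxIter is not None:
        # iterations is a count, so convert to int rather than float
        maxIter = int(arguments.maxIter)
    lrModel = LogisticRegressionWithSGD.train(df, iterations=maxIter)
    return lrModel

# Train the model, then save it to the path given in the arguments.
model = logisticRegression(df, arguments)
modelPath = arguments.modelPath
model.save(sc, modelPath)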
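
The docstring above says only binary classification is supported, while the labels in case_data.txt are 1, 2 and 3. One possible way to reconcile the two is one-vs-rest relabeling, sketched below in plain Python. The names `to_binary_label` and `positive_label` are hypothetical and not part of MLlib, and picking class 3 as the positive class is an arbitrary choice for illustration.

```python
def to_binary_label(label, positive_label):
    """Return 1.0 if label equals positive_label, otherwise 0.0."""
    return 1.0 if label == positive_label else 0.0

# (label, features) pairs taken from case_data.txt
rows = [(1.0, [1.0, 1.0]), (2.0, [10.0, 11.0]), (3.0, [50.0, 52.0])]

# Treat class 3 as positive and classes 1 and 2 as negative.
binary_rows = [(to_binary_label(label, 3.0), features) for label, features in rows]
print(binary_rows)
```

In Spark, the same function could presumably be applied in a map step before training, for example `df.map(lambda p: LabeledPoint(to_binary_label(p.label, 3.0), p.features))`.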

3. Input data for prediction:

# dataSet is a comma-separated string of feature values.
data = Vectors.dense([float(x) for x in dataSet.split(',')])

Prediction:

from pyspark.mllib.classification import LogisticRegressionModel
model = LogisticRegressionModel.load(sc,modelPath)

prediction = model.predict(data)
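
To make the prediction step more concrete, here is a minimal plain-Python sketch of how a binary logistic regression model turns a feature vector into a class: it computes the margin w·x + b, passes it through the sigmoid function, and compares the score with a threshold (0.5 is assumed here). The function `predict_binary` and the example weights and intercept are hypothetical values chosen for illustration; they are not the output of the training step above.

```python
import math

def predict_binary(weights, intercept, features, threshold=0.5):
    """Return 1 if sigmoid(w . x + b) exceeds threshold, otherwise 0."""
    margin = sum(w * x for w, x in zip(weights, features)) + intercept
    # Numerically stable sigmoid: avoid exp() of a large positive number.
    if margin >= 0:
        score = 1.0 / (1.0 + math.exp(-margin))
    else:
        e = math.exp(margin)
        score = e / (1.0 + e)
    return 1 if score > threshold else 0

# Hypothetical model parameters, not learned from case_data.txt.
example_weights = [0.5, 0.5]
example_intercept = -5.0

# Parse a comma-separated input string, as in the prediction input step.
small = [float(x) for x in "1,1".split(',')]
large = [float(x) for x in "10,11".split(',')]

print(predict_binary(example_weights, example_intercept, small))  # margin = -4.0
print(predict_binary(example_weights, example_intercept, large))  # margin = 5.5
```

With these assumed parameters the first input falls below the threshold and the second above it, so the two calls return different classes.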

