spark2.0 AFTSurvivalRegression算法

来源:互联网 发布:(,) 矩阵 编辑:程序博客网 时间:2024/06/08 14:58

spark2.0的机器学习算法比之前的改变最大的是2.0基本采用了dataframe来实现的,但事前的都是用的RDD,看官网说貌似在3.0的时候RDD就不用了,不知道真的假的。

还有一个就是hiveContext和sqlcontext进行了合并,统一是sessioncontext

val spark = SparkSession  .builder  .appName("AFTSurvivalRegressionExample").master("local")  .getOrCreate()

AFTSurvivalRegression
实现了加速失效时间(AFT)模型,这是一个用于检查数据的参数生存回归模型。 它描述了生存时间对数的模型,因此它通常被称为生存分析的对数线性模型
val training = spark.createDataFrame(Seq(  (1.218, 1.0, Vectors.dense(1.560, -0.605)),  (2.949, 0.0, Vectors.dense(0.346, 2.158)),  (3.627, 0.0, Vectors.dense(1.380, 0.231)),  (0.273, 1.0, Vectors.dense(0.520, 1.151)),  (4.199, 0.0, Vectors.dense(0.795, -0.226)))).toDF("label", "censor", "features")第一个label表示的是存活的时间,你可以把这个模型看做是预测你能活多长时间的,当然是需要很多方面的参数的不然就是在扯淡了,虽然这预测听起来很扯淡。。。。。。
第二个censor是结局,1表示死亡,0表示删失数据,病历失访或者尚存活表现在病人身上就是,你这个人得了一个癌症,根据你的各项指标,用这个模型预测你能活的时间听起来就很残酷,1表示这个人已经去世,0可能是还活着或者其他因素而没获取到数据后面的几个参数就是各种病症或者身体情况的症状了,最终都要转化为数据的形式,俗称归一化分位数概率数组参数。分位数概率数组的值应在范围内(0,1)数组应该是非空的。
val quantileProbabilities = Array(0.3, 0.6)val aft = new AFTSurvivalRegression()  .setQuantileProbabilities(quantileProbabilities)如果设置该列,则会输出相应的分位数概率的分位数 .setQuantilesCol("quantiles")
val model = aft.fit(training)输出模型的系数
println(s"Coefficients: ${model.coefficients}")模型的截距 println(s"Intercept: ${model.intercept}")源码里面是这个 val scale = math.exp(parameters(0)) println(s"Scale: ${model.scale}")Coefficients: [-0.4963111466650707,0.19844437699933098]Intercept: 2.63809461510401Scale: 1.5472345574364692model.transform(training).show(false)
+-----+------+--------------+------------------+--------------------------------------+
|label|censor|features      |prediction        |quantiles                             |
+-----+------+--------------+------------------+--------------------------------------+
|1.218|1.0   |[1.56,-0.605] |5.718979487635007 |[1.1603238947151664,4.99545601027477] |
|2.949|0.0   |[0.346,2.158] |18.07652118149533 |[3.667545845471739,15.789611866277625]|
|3.627|0.0   |[1.38,0.231]  |7.381861804239096 |[1.4977061305190829,6.44796261233896] |
|0.273|1.0   |[0.52,1.151]  |13.577612501425284|[2.7547621481506854,11.8598722240697] |
|4.199|0.0   |[0.795,-0.226]|9.013097744073898 |[1.8286676321297826,7.87282650587843] |
+-----+------+--------------+------------------+--------------------------------------+
还可以通过类似sql的方式来选择展示结果
model.transform(training).
selectExpr(
  "label", "censor",
  "round(prediction,2) as prediction").orderBy("label")









原创粉丝点击