基于FTRL的在线CTR预测算法
来源:互联网 发布:类似于知乎的app 编辑:程序博客网 时间:2024/05/16 19:36
在程序化广告投放中,一个优秀的CTR预测算法会给广告主、Adx以及用户都将带来好处。Google公司2013在《ResearchGate》上发表了一篇“Ad click prediction: a view from the trenches”论文,这篇论文是基于FTRL的在线CTR预测算法,下面将讲解该算法的主要思想以及Java实现。
什么是Online Learning
传统的批量算法的每次迭代是对全体训练数据集进行计算(例如计算全局梯度),优点是精度和收敛还可以,缺点是无法有效处理大数据集(此时全局梯度计算代价太大),且没法应用于数据流做在线学习。而在线学习算法的特点是:每来一个训练样本,就用该样本产生的loss和梯度对模型迭代一次,一个一个数据地进行训练,因此可以处理大数据量训练和在线训练。准确地说,Online Learning并不是一种模型,而是一种模型的训练方法,Online Learning能够根据线上反馈数据,实时快速地进行模型调整,使得模型及时反映线上的变化,提高线上预测的准确率。Online Learning的流程包括:将模型的预测结果展现给用户,然后收集用户的反馈数据,再用来训练模型,形成闭环的系统。如下图所示:
这篇论文提出的基于FTRL的在线CTR预测算法,就是一种Online Learning算法。即,针对每一个训练样本,首先通过一种方式进行预测,然后再利用一种损失函数进行误差评估,最后再通过所评估的误差值对参数进行更新迭代。直到所有样本全部遍历完,则结束。那么,如何选择模型预测方法、评估指标以及模型更新公式就是该算法的重点所在。下面将介绍论文中这三部分内容:
预测方法:在每一轮
t 中,针对特征样本xt∈Rd ,以及迭代后(第一此则是给定初值)的模型参数wt ,我们可以预测该样本的标记值:pt=σ(wt,xt) ,其中σ(a)=1/(1+exp(−a)) 是一个sigmoid函数。损失函数:对一个特征样本
xt ,其对应的标记为yt∈0,1 ,则通过LogLoss(logistic loss)来作为损失函数,即:lt(wt)=−ytlogpt−(1−yt)log(1−pt) 迭代公式:我们的目的是使得损失函数尽可能的小,即可以采用极大似然估计来求解参数。首先求梯度
gt=dltdw=(σ(w∗xt)−yt)xt=(pt−yt)xt ,使用FTRL进行迭代:
其中,σs 为学习率且σ1:t=1nt ,g1:t=∑ts=1gt ,λ1 为正则化参数。该最优化公式可以化简为:
则,如果我们令zt−1=g1:t−1−∑t−1s=1σsws ,则在第t 轮迭代前,令zt=zt−1+gt−(1nt−1nt−1)wt (此处和论文中的公式不一致,我觉得应该是减去最后一项,而不是加,作者在后面伪代码中也改成了减,故此处可能是作者笔误)
下面令梯度为0,则可以得到该优化问题的解析解:
到此就叙述完该算法的理论部分了,我想大部分人对这部分也不太感兴趣吧,下面直接上伪代码和Java实现吧(过程和理论部分其实是一致的,嘿嘿,想深入的还是研究下理论部分吧):
基于FTRL的在线CTR预测算法的Java实现
模型参数类
package DataClass;public class FTRLParameters { public double alpha;//学习速率参数 public double beta;//调整参数,值为1时效果较好,无需调整 public double L1_lambda;//L1范式参数 public double L2_lambda;//L2范式参数 public int dataDimensions;//数据特征维度数 public int testDataSize;//测试集分次处理每次处理的个数 public int interval;//每间隔interval进行一次打印 public String modelPath;//模型训练参数的存放路径 public FTRLParameters(double alpha, double beta, double L1, double L2, int dataDimensions,int testDataSize, int interval,String modelPath) { this.alpha = alpha; this.beta = beta; this.L1_lambda = L1; this.L2_lambda = L2; this.dataDimensions = dataDimensions; this.testDataSize = testDataSize; this.interval = interval; this.modelPath = modelPath; }}
模型训练类
package model;import DataPreprocessing.FileOperation;import evaluate.LogLossEvalutor;import java.io.*;import java.util.Map;import java.util.TreeMap;public class FTRLLocalTrain { private FTRLProximal learner; private FTRLModelLoad mload; private LogLossEvalutor evalutor; private int printInterval; public FTRLLocalTrain(FTRLModelLoad mload, FTRLProximal learner, LogLossEvalutor evalutor, int interval) { this.mload = mload; this.learner = learner; this.evalutor = evalutor; this.printInterval = interval; } /** * 训练方法 * */ public void train(String modelPath,double[][] X,double[] Y) throws IOException { int trainedNum = 0; double totalLoss = 0.0;//损失值 long startTime = System.currentTimeMillis(); BufferedReader mp = new BufferedReader(new InputStreamReader(new FileInputStream(new File(modelPath)), "UTF-8")); while((line = mp.readLine())!=null){ learner.loadModel(modelPath); } for(int j=0;j<X.length;j++){ Map<Integer, Double> x = new TreeMap<Integer, Double>(); for (int i = 0; i < X[0].length; i++) { x.put(i, X[j][i]); } double y = ((int)Y[j] == 1) ? 1. : 0.; double p = learner.predict(x); learner.updateModel(x, p, y); double loss = LogLossEvalutor.calLogLoss(p, y); evalutor.addLogLoss(loss); totalLoss += loss; trainedNum += 1; if (trainedNum % printInterval == 0) { long currentTime = System.currentTimeMillis(); double minutes = (double) (currentTime - startTime) / 60000; System.out.printf("%.3f, %.5f\n", minutes, evalutor.getAverageLogLoss()); } } learner.saveModel(modelPath); System.out.printf("global average loss: %.5f\n", totalLoss / trainedNum); }}
模型更新类
package model;import DataClass.FTRLParameters;import java.io.*;import java.util.HashMap;import java.util.Map;import java.util.Map.Entry;public class FTRLProximal { // parameters->alpha, beta, l1, l2, dimensions private FTRLParameters parameters; // n->squared sum of past gradients public double[] n; // z->weights public double[] z; // w->lazy weights public Map<Integer, Double> w; public double[] n_; public double[] z_; public Map<Integer, Double> w_; public FTRLProximal(FTRLParameters parameters) { this.parameters = parameters; this.n = new double[parameters.dataDimensions]; this.z = new double[parameters.dataDimensions]; this.w = null; } /** x->p(y=1|x; w) , get w, nothing is changed*/ public double predict(Map<Integer, Double> x) { w = new HashMap<Integer, Double>(); double decisionValue = 0.0; for (Entry<Integer, Double> e : x.entrySet()) { double sgn = sign(z[e.getKey()]); double weight = 0.0; if (sgn * z[e.getKey()] <= parameters.L1_lambda) { w.put(e.getKey(), weight); } else { weight = (sgn * parameters.L1_lambda - z[e.getKey()]) / ((parameters.beta + Math.sqrt(n[e.getKey()])) / parameters.alpha + parameters.L2_lambda); w.put(e.getKey(), weight); } decisionValue += e.getValue() * weight; } decisionValue = Math.max(Math.min(decisionValue, 35.), -35.); return 1. / (1. + Math.exp(-decisionValue)); } /** input: sample x, probability p, label y(-1(or 0) or 1) * used: w * update: n, z*/ public void updateModel(Map<Integer, Double> x, double p, double y) { for(Entry<Integer, Double> e : x.entrySet()) { double grad = p * e.getValue(); if(y == 1.0) { grad = (p - y) * e.getValue(); } double sigma = (Math.sqrt(n[e.getKey()] + grad * grad) - Math.sqrt(n[e.getKey()])) / parameters.alpha; z[e.getKey()] += (grad - sigma * w.get(e.getKey())); n[e.getKey()] += grad * grad; } } /** * N、Z、W * 模型参数保存函数 * */ public void saveModel(String filePath) throws IOException { String n_=String.valueOf(n[0]); String z_=String.valueOf(z[0]); String w_=String.valueOf(w.get(0)); for(int i=1;i<n.length;i++){ n_ = n_+" "+String.valueOf(n[i]); z_ = z_+" "+String.valueOf(z[i]); w_ = w_+" "+String.valueOf(w.get(i)); } try{ File file = new File(filePath); if(!file.exists()){ file.createNewFile(); } FileWriter fileWriter = new FileWriter(filePath); BufferedWriter bufferWriter = new BufferedWriter(fileWriter); bufferWriter.write(n_+"\r\n"); bufferWriter.write(z_+"\r\n"); bufferWriter.write(w_); bufferWriter.close(); System.out.print("Done"); }catch (IOException e){ e.printStackTrace(); } } public void loadModel(String filePath) throws IOException { BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(filePath)), "UTF-8")); String line = null; String[][] Str = new String[3][]; int i = 0; while((line = br.readLine()) != null) { Str[i] = line.split(" "); i++; } n = new double[n.length]; z = new double[z.length]; w = new HashMap<Integer, Double>(); for(int j=0;j<n.length;j++){ n[j] = Double.valueOf(Str[0][j]); z[j] = Double.valueOf(Str[1][j]); w.put(j,Double.valueOf(Str[2][j])); } } public double predict_(Map<Integer, Double> x) { double decisionValue = 0.0; for (Entry<Integer, Double> e : x.entrySet()) { decisionValue += e.getValue() * w_.get(e.getKey()); } decisionValue = Math.max(Math.min(decisionValue, 35.), -35.); return 1. / (1. + Math.exp(-decisionValue)); } private double sign(double x) { if (x > 0) { return 1.0; } else if (x < 0) { return -1.0; } else { return 0.0; } }}
模型预测类
package model;import java.io.*;import java.util.HashMap;import java.util.Map;import java.util.TreeMap;/** * 模型下载与预测方法 * n、z、w为需下载的模型参数 */public class FTRLModelLoad { public double[] n; public double[] z; public Map<Integer, Double> w; /** * 模型下载方法 * 输入:模型文件所在路径 * 功能:算法全局参数更新 * */ public Map<Integer, Double> loadModel(String filePath) throws IOException { BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(filePath)), "UTF-8")); String line = null; String[][] Str = new String[3][]; int i = 0; while((line = br.readLine()) != null) { Str[i] = line.split(" "); i++; } n = new double[Str[0].length]; z = new double[Str[0].length]; w = new HashMap<Integer, Double>(); for(int j=0;j<Str[0].length;j++){ n[j] = Double.valueOf(Str[0][j]); z[j] = Double.valueOf(Str[1][j]); w.put(j,Double.valueOf(Str[2][j])); } return w; } /** * 预测函数 * */ public double predict_(double[] x_,Map<Integer,Double> w) { Map<Integer,Double> x = new TreeMap<Integer, Double>(); for(int i=0;i<x_.length;i++){ x.put(i,x_[i]); } double decisionValue = 0.0; for (Map.Entry<Integer, Double> e : x.entrySet()) { decisionValue += e.getValue() * w.get(e.getKey()); } decisionValue = Math.max(Math.min(decisionValue, 35.), -35.); return 1. / (1. + Math.exp(-decisionValue)); }}
损失函数类
package evaluate;public class LogLossEvalutor { private int testDataSize; private double[] logloss; private int position; private double totalLoss; private boolean enoughData; public LogLossEvalutor(int testDataSize) { this.testDataSize = testDataSize; logloss = new double[testDataSize]; position = 0; totalLoss = 0.0; } public void addLogLoss(double loss) { totalLoss = totalLoss + loss - logloss[position]; logloss[position] = loss; position += 1; if(position >= testDataSize) { position = 0; enoughData = true; } } public double getAverageLogLoss() { if(enoughData) { return totalLoss / testDataSize; } else { return totalLoss / position; } } /** prob: p(y=1|x;w), y: 1 or 0(-1) */ public static double calLogLoss(double prob, double y) { //预测值范围控制方法 double p = Math.max(Math.min(prob, 1-1e-15), 1e-15); return y == 1.? -Math.log(p) : -Math.log(1. - p); } public static void main(String[] args) { LogLossEvalutor evalutor = new LogLossEvalutor(4); double[] losses = {3, 2, 1, 0.7, 0.5, 0.2}; for(int i=0; i<losses.length; i++) { evalutor.addLogLoss(losses[i]); System.out.println(evalutor.getAverageLogLoss()); } }}
下面为程序的一个测试结果图:
参考文献
(1) Ad Click Prediction: a View from the Trenches.H. Brendan McMahan, Gary Holt, D. Sculley et al
(2)美团技术团队《Online Learning算法理论与实践》
- 基于FTRL的在线CTR预测算法
- 点击率预测算法:FTRL
- 在线学习算法FTRL
- 在线学习算法FTRL
- 在线学习算法FTRL详解
- 在线学习算法FTRL-Proximal
- 【算法】在线学习算法FTRL详解
- 在线学习算法FTRL-Proximal原理
- Ftrl算法和FFM算法 广告点击率预测
- 【转载】各大公司广泛使用的在线学习算法FTRL详解
- 各大公司广泛使用的在线学习算法FTRL详解
- 各大公司广泛使用的在线学习算法FTRL详解
- 各大公司广泛使用的在线学习算法FTRL详解
- 各大公司广泛使用的在线学习算法FTRL详解
- 各大公司广泛使用的在线学习算法FTRL详解
- 各大公司广泛使用的在线学习算法FTRL详解
- 各大公司广泛使用的在线学习算法FTRL详解
- 各大公司广泛使用的在线学习算法FTRL详解
- 第一讲, HelloWorld基础语法
- 不用ajax的局部页面跳转实现方法(iframe)
- android 程序全局自动捕获异常 专制系统奔溃,异常退出无法找到原因的问题
- php 安装 sqlsrv pdo_sqlsrv拓展
- Light Bulb ZOJ
- 基于FTRL的在线CTR预测算法
- linux配置Tomcat
- 使用Python的SymPy库解决数学运算问题
- cuda矩阵之心得
- 【拉格朗日求自然数幂和】cf622F
- springboot—restful风格
- Linux 之内核态与用户态
- JavaScript原型
- CTF中的编码与加密题