基于FTRL的在线CTR预测算法

来源:互联网 发布:类似于知乎的app 编辑:程序博客网 时间:2024/05/16 19:36

在程序化广告投放中,一个优秀的CTR预测算法会给广告主、Adx以及用户都将带来好处。Google公司2013在《ResearchGate》上发表了一篇“Ad click prediction: a view from the trenches”论文,这篇论文是基于FTRL的在线CTR预测算法,下面将讲解该算法的主要思想以及Java实现。

什么是Online Learning

传统的批量算法的每次迭代是对全体训练数据集进行计算(例如计算全局梯度),优点是精度和收敛还可以,缺点是无法有效处理大数据集(此时全局梯度计算代价太大),且没法应用于数据流做在线学习。而在线学习算法的特点是:每来一个训练样本,就用该样本产生的loss和梯度对模型迭代一次,一个一个数据地进行训练,因此可以处理大数据量训练和在线训练。准确地说,Online Learning并不是一种模型,而是一种模型的训练方法,Online Learning能够根据线上反馈数据,实时快速地进行模型调整,使得模型及时反映线上的变化,提高线上预测的准确率。Online Learning的流程包括:将模型的预测结果展现给用户,然后收集用户的反馈数据,再用来训练模型,形成闭环的系统。如下图所示:

这里写图片描述

这篇论文提出的基于FTRL的在线CTR预测算法,就是一种Online Learning算法。即,针对每一个训练样本,首先通过一种方式进行预测,然后再利用一种损失函数进行误差评估,最后再通过所评估的误差值对参数进行更新迭代。直到所有样本全部遍历完,则结束。那么,如何选择模型预测方法、评估指标以及模型更新公式就是该算法的重点所在。下面将介绍论文中这三部分内容:

  1. 预测方法:在每一轮t中,针对特征样本xtRd,以及迭代后(第一此则是给定初值)的模型参数wt,我们可以预测该样本的标记值:pt=σ(wt,xt),其中σ(a)=1/(1+exp(a))是一个sigmoid函数。

  2. 损失函数:对一个特征样本xt,其对应的标记为yt0,1,则通过LogLoss(logistic loss)来作为损失函数,即: lt(wt)=ytlogpt(1yt)log(1pt)

  3. 迭代公式:我们的目的是使得损失函数尽可能的小,即可以采用极大似然估计来求解参数。首先求梯度 gt=dltdw=(σ(wxt)yt)xt=(ptyt)xt,使用FTRL进行迭代:
    这里写图片描述
    其中,σs为学习率且σ1:t=1ntg1:t=ts=1gtλ1为正则化参数。该最优化公式可以化简为:
    这里写图片描述
    则,如果我们令zt1=g1:t1t1s=1σsws,则在第t轮迭代前,令zt=zt1+gt(1nt1nt1)wt(此处和论文中的公式不一致,我觉得应该是减去最后一项,而不是加,作者在后面伪代码中也改成了减,故此处可能是作者笔误)

下面令梯度为0,则可以得到该优化问题的解析解:
这里写图片描述

到此就叙述完该算法的理论部分了,我想大部分人对这部分也不太感兴趣吧,下面直接上伪代码和Java实现吧(过程和理论部分其实是一致的,嘿嘿,想深入的还是研究下理论部分吧):

这里写图片描述

基于FTRL的在线CTR预测算法的Java实现

模型参数类

package DataClass;public class FTRLParameters {    public double alpha;//学习速率参数    public double beta;//调整参数,值为1时效果较好,无需调整    public double L1_lambda;//L1范式参数    public double L2_lambda;//L2范式参数    public int dataDimensions;//数据特征维度数    public int testDataSize;//测试集分次处理每次处理的个数    public int interval;//每间隔interval进行一次打印    public String modelPath;//模型训练参数的存放路径    public FTRLParameters(double alpha, double beta,                          double L1, double L2, int dataDimensions,int testDataSize, int interval,String modelPath) {        this.alpha = alpha;        this.beta = beta;        this.L1_lambda = L1;        this.L2_lambda = L2;        this.dataDimensions = dataDimensions;        this.testDataSize = testDataSize;        this.interval = interval;        this.modelPath = modelPath;    }}

模型训练类

package model;import DataPreprocessing.FileOperation;import evaluate.LogLossEvalutor;import java.io.*;import java.util.Map;import java.util.TreeMap;public class FTRLLocalTrain {    private FTRLProximal learner;    private FTRLModelLoad mload;    private LogLossEvalutor evalutor;    private int printInterval;    public FTRLLocalTrain(FTRLModelLoad mload, FTRLProximal learner, LogLossEvalutor evalutor, int interval) {        this.mload = mload;        this.learner = learner;        this.evalutor = evalutor;        this.printInterval = interval;    }    /**     * 训练方法     * */    public void train(String modelPath,double[][] X,double[] Y) throws IOException {        int trainedNum = 0;        double totalLoss = 0.0;//损失值        long startTime = System.currentTimeMillis();       BufferedReader mp = new BufferedReader(new InputStreamReader(new FileInputStream(new File(modelPath)), "UTF-8"));        while((line = mp.readLine())!=null){            learner.loadModel(modelPath);        }        for(int j=0;j<X.length;j++){            Map<Integer, Double> x = new TreeMap<Integer, Double>();            for (int i = 0; i < X[0].length; i++) {                x.put(i, X[j][i]);            }            double y = ((int)Y[j] == 1) ? 1. : 0.;            double p = learner.predict(x);            learner.updateModel(x, p, y);            double loss = LogLossEvalutor.calLogLoss(p, y);            evalutor.addLogLoss(loss);            totalLoss += loss;            trainedNum += 1;            if (trainedNum % printInterval == 0) {                long currentTime = System.currentTimeMillis();                double minutes = (double) (currentTime - startTime) / 60000;                System.out.printf("%.3f, %.5f\n", minutes, evalutor.getAverageLogLoss());            }        }        learner.saveModel(modelPath);        System.out.printf("global average loss: %.5f\n", totalLoss / trainedNum);    }}

模型更新类

package model;import DataClass.FTRLParameters;import java.io.*;import java.util.HashMap;import java.util.Map;import java.util.Map.Entry;public class FTRLProximal {    // parameters->alpha, beta, l1, l2, dimensions    private FTRLParameters parameters;    // n->squared sum of past gradients    public double[] n;    // z->weights    public double[] z;    // w->lazy weights    public Map<Integer, Double> w;    public double[] n_;    public double[] z_;    public Map<Integer, Double> w_;    public FTRLProximal(FTRLParameters parameters) {        this.parameters = parameters;        this.n = new double[parameters.dataDimensions];        this.z = new double[parameters.dataDimensions];        this.w = null;    }    /** x->p(y=1|x; w) , get w, nothing is changed*/    public double  predict(Map<Integer, Double> x) {        w = new HashMap<Integer, Double>();        double decisionValue = 0.0;        for (Entry<Integer, Double> e : x.entrySet()) {            double sgn = sign(z[e.getKey()]);            double weight = 0.0;            if (sgn * z[e.getKey()] <= parameters.L1_lambda) {                w.put(e.getKey(), weight);            } else {                weight = (sgn * parameters.L1_lambda - z[e.getKey()])                        / ((parameters.beta + Math.sqrt(n[e.getKey()]))                                / parameters.alpha + parameters.L2_lambda);                w.put(e.getKey(), weight);            }            decisionValue += e.getValue() * weight;        }        decisionValue = Math.max(Math.min(decisionValue, 35.), -35.);        return 1. / (1. + Math.exp(-decisionValue));    }    /** input: sample x, probability p, label y(-1(or 0) or 1)      *  used: w      *  update: n, z*/    public void updateModel(Map<Integer, Double> x, double p, double y) {        for(Entry<Integer, Double> e : x.entrySet()) {            double grad = p * e.getValue();            if(y == 1.0) {                grad = (p - y) * e.getValue();            }            double sigma = (Math.sqrt(n[e.getKey()] + grad * grad) -                     Math.sqrt(n[e.getKey()])) / parameters.alpha;            z[e.getKey()] += (grad - sigma * w.get(e.getKey()));            n[e.getKey()] += grad * grad;        }    }    /**     * N、Z、W     * 模型参数保存函数     * */    public void saveModel(String filePath) throws IOException {        String n_=String.valueOf(n[0]);        String z_=String.valueOf(z[0]);        String w_=String.valueOf(w.get(0));        for(int i=1;i<n.length;i++){            n_ = n_+" "+String.valueOf(n[i]);            z_ = z_+" "+String.valueOf(z[i]);            w_ = w_+" "+String.valueOf(w.get(i));        }        try{            File file = new File(filePath);            if(!file.exists()){                file.createNewFile();            }            FileWriter fileWriter = new FileWriter(filePath);            BufferedWriter bufferWriter = new BufferedWriter(fileWriter);            bufferWriter.write(n_+"\r\n");            bufferWriter.write(z_+"\r\n");            bufferWriter.write(w_);            bufferWriter.close();            System.out.print("Done");        }catch (IOException e){            e.printStackTrace();        }    }    public void loadModel(String filePath) throws IOException {        BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(filePath)), "UTF-8"));        String line = null;        String[][] Str = new String[3][];        int i = 0;        while((line = br.readLine()) != null) {            Str[i] = line.split(" ");            i++;        }        n = new double[n.length];        z = new double[z.length];        w = new HashMap<Integer, Double>();        for(int j=0;j<n.length;j++){            n[j] = Double.valueOf(Str[0][j]);            z[j] = Double.valueOf(Str[1][j]);            w.put(j,Double.valueOf(Str[2][j]));        }    }    public double predict_(Map<Integer, Double> x) {        double decisionValue = 0.0;        for (Entry<Integer, Double> e : x.entrySet()) {            decisionValue += e.getValue() * w_.get(e.getKey());        }        decisionValue = Math.max(Math.min(decisionValue, 35.), -35.);        return 1. / (1. + Math.exp(-decisionValue));    }    private double sign(double x) {        if (x > 0) {            return 1.0;        } else if (x < 0) {            return -1.0;        } else {            return 0.0;        }    }}

模型预测类

package model;import java.io.*;import java.util.HashMap;import java.util.Map;import java.util.TreeMap;/** * 模型下载与预测方法 * n、z、w为需下载的模型参数 */public class FTRLModelLoad {    public double[] n;    public double[] z;    public Map<Integer, Double> w;    /**     * 模型下载方法     * 输入:模型文件所在路径     * 功能:算法全局参数更新     * */    public Map<Integer, Double> loadModel(String filePath) throws IOException {        BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(filePath)), "UTF-8"));        String line = null;        String[][] Str = new String[3][];        int i = 0;        while((line = br.readLine()) != null) {            Str[i] = line.split(" ");            i++;        }        n = new double[Str[0].length];        z = new double[Str[0].length];        w = new HashMap<Integer, Double>();        for(int j=0;j<Str[0].length;j++){            n[j] = Double.valueOf(Str[0][j]);            z[j] = Double.valueOf(Str[1][j]);            w.put(j,Double.valueOf(Str[2][j]));        }        return w;    }    /**     * 预测函数     * */    public double predict_(double[] x_,Map<Integer,Double> w) {        Map<Integer,Double> x = new TreeMap<Integer, Double>();        for(int i=0;i<x_.length;i++){            x.put(i,x_[i]);        }        double decisionValue = 0.0;        for (Map.Entry<Integer, Double> e : x.entrySet()) {            decisionValue += e.getValue() * w.get(e.getKey());        }        decisionValue = Math.max(Math.min(decisionValue, 35.), -35.);        return 1. / (1. + Math.exp(-decisionValue));    }}

损失函数类

package evaluate;public class LogLossEvalutor {    private int testDataSize;    private double[] logloss;    private int position;    private double totalLoss;    private boolean enoughData;    public LogLossEvalutor(int testDataSize) {        this.testDataSize = testDataSize;        logloss = new double[testDataSize];        position = 0;        totalLoss = 0.0;    }    public void addLogLoss(double loss) {        totalLoss = totalLoss + loss - logloss[position];        logloss[position] = loss;        position += 1;        if(position >= testDataSize) {            position = 0;            enoughData = true;        }    }    public double getAverageLogLoss() {        if(enoughData) {            return totalLoss / testDataSize;        } else {            return totalLoss / position;        }    }    /** prob: p(y=1|x;w), y: 1 or 0(-1) */    public static double calLogLoss(double prob, double y) {        //预测值范围控制方法        double p = Math.max(Math.min(prob,  1-1e-15), 1e-15);        return y == 1.? -Math.log(p) : -Math.log(1. - p);    }    public static void main(String[] args) {        LogLossEvalutor evalutor = new LogLossEvalutor(4);        double[] losses = {3, 2, 1, 0.7, 0.5, 0.2};        for(int i=0; i<losses.length; i++) {            evalutor.addLogLoss(losses[i]);            System.out.println(evalutor.getAverageLogLoss());        }    }}

下面为程序的一个测试结果图:

这里写图片描述

参考文献
(1) Ad Click Prediction: a View from the Trenches.H. Brendan McMahan, Gary Holt, D. Sculley et al
(2)美团技术团队《Online Learning算法理论与实践》

原创粉丝点击