用 theano 求解 Logistic Regression (SGD 优化算法)
来源:互联网 发布:中国跳水最厉害的知乎 编辑:程序博客网 时间:2024/05/21 09:30
1. model
这里待求解的是一个 binary logistic regression,它是一个分类模型,参数是权值矩阵
当然最终的目标是求解在整个样本集
- 这里的取均值是为了解耦后续的正则化系数,以及 SGD 时的步长的选择;
当然也可对
2. theano 的使用
实现 theano 下的最小化问题的求解,涉及如下的四个流程:
(1)声明符号变量;
import numpy import theano.tensor as Tfrom theano import shared, functionx = T.matrix()y = T.lvector()w = shared(numpy.random.randn(100))b = shared(numpy.zeros(()))print 'step 1, initial mode: 'print w.get_value(), b.get_value()
(2)使用这些变量构建符号表达式图(symbolic expression graph)
# hypothesisp_1 = 1/(1+T.exp(-T.dot(x, w)-b))xent = -y*T.log(p_1)-(1-y)*T.log(1-p_1)cost = xent.mean() + 0.01*(w**2).sum()gw, gb = T.grad(cost, [w, b]);prediction = p_1 > .5
(3)编译 Theano functions;
train = function(inputs=[x, y], outputs=[predication, xent], updates={w:w-0.1*gw, b:b-0.1*gb})predict = function(inputs=[x], outputs=predication)
(4)调用编译好的函数来执行数值计算;
N = 4feats = 100D = (numpy.random.randn(N, feats), numpy.random.randi(low=0, high=2, size=(N,)))training_epochs = 10for _ in range(training_epochs): pred, err = train(D[0], D[1])print 'final model: 'print 'target values for D', D[1]print 'predication on D', predict(D[0])
0 0
- 用 theano 求解 Logistic Regression (SGD 优化算法)
- Theano(5):Logistic Regression实例
- Theano Logistic Regression
- theano logistic regression讲解
- 测试Mahout的Logistic Regression (SGD)
- 机器学习之logistic regression(SGD)
- Logistic Regression求解classification问题
- logistic regression using Theano 注释版
- theano tutorial(四) logistic regression 练习
- 逻辑回归(Logistic Regression)-牛顿法求解参数
- Logistic Regression 分类算法
- Logistic Regression 算法学习
- OWLQN算法介绍,及go语言实现logistic regression优化
- Logistic Regression 逻辑回归算法
- LR(Logistic Regression)算法详解
- 利用Theano理解深度学习——Logistic Regression
- 利用Theano理解深度学习——Logistic Regression
- theano logistic regression讲解之续模型测试
- 关于Page directive must not have multiple occurrences of pageencoding
- less
- 数据结构实验之栈六:下一较大值(二)
- 排序算法时间、空间复杂度
- 从MVC到前后端分离(REST-个人也认为是目前比较流行和比较好的方式)
- 用 theano 求解 Logistic Regression (SGD 优化算法)
- 第十二周实践求最大公约数
- centos7安装nfs服务
- JVM基础(7)——jdk常用内置工具
- Linux+db2+was部署问题总结
- 【iOS开发】在程序被送入后台时,开启一个长期任务(voip)。
- VC++ 出现Debug Assertion Failed!
- 【leetcode】454. 4Sum II【M】
- augmented reality(AR)入门实例