利用Theano理解深度学习——Logistic Regression
来源:互联网 发布:nginx 函数 编辑:程序博客网 时间:2024/05/16 14:33
一、Logistic Regression
1、LR模型
Logistic回归是广义线性模型的一种,属于线性的分类模型,在其模型中主要有两个参数,即:权重矩阵
对于输入向量
模型对于输入向量
通常使用的是其二分类的模型,即属于类别
2、损失函数
在LR模型中,需要求解的参数为权重矩阵
为了方便,通常使用负的Log似然函数,即the negative log-likelihood(NLL)作为其损失函数,此时,需要计算的是NLL的极小值。损失函数
3、随机梯度下降法
为了求解LR模型中的参数,在上面定义了LR模型的损失函数,即NLL。此时,只需计算NLL的极小值条件下的参数
随机梯度下降法(Stochastic gradient descent,SGD)与传统的批梯度下降法的原则一致,都是选择最快的下降方向,但是,与批梯度不同的是,在选择下降方向时,批梯度是对所有的训练样本计算其梯度,而SGD仅仅是对一部分样本计算其梯度,通常情况下,在SGD中,通常选择根据一个样本计算其梯度,SGD的伪代码如下:
在深度学习算法的模型训练中,可以使用SGD的一个变种形式,称为“minibatches”。在Minibatch SGD中,其工作原理与SGD一致,其区别仅仅是在Minibatch SGD中,通过多个样本计算其梯度,而不是根据一个样本,但又不同于批梯度下降法中的根据整个训练集计算其梯度。根据所需样本量的大小,Minibatch SGD是出于SGD与批梯度之间的一种变形形式。其伪代码如下所示:
对于minibatch的大小
在LR模型的计算中,此时只需计算NLL的对于参数
二、基于Theano的Logistic Regression实现解析
1、导入数据集
导入数据集的函数为load_data(dataset)
,具体的函数形式如下:
def load_data(dataset):'''导入数据:type dataset: string:param dataset: MNIST数据集''' #1、处理文件目录 data_dir, data_file = os.path.split(dataset)#把路径分割成dirname和basename,返回一个元组 if data_dir == "" and not os.path.isfile(dataset): new_path = os.path.join( os.path.split(__file__)[0],#__file__表示的是当前的路径 ".", "data", dataset ) if os.path.isfile(new_path) or data_file == 'mnist.pkl.gz': dataset = new_path#文件所在的目录 if (not os.path.isfile(dataset)) and data_file == 'mnist.pkl.gz': import urllib origin = ( 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' ) print 'Downloading data from %s' % origin urllib.urlretrieve(origin, dataset) print '... loading data' #2、打开文件 f = gzip.open(dataset, 'rb')# 打开一个gzip已经压缩好的gzip格式的文件,并返回一个文件对象:file object. train_set, valid_set, test_set = cPickle.load(f)#载入本地的文件 f.close() '''训练集train_set,验证集valid_set和测试集test_set的格式:元组(input, target) 其中,input是一个矩阵(numpy.ndarray),每一行代表一个样本;target是一个向量(numpy.ndarray),大小与input的行数对应 ''' def shared_dataset(data_xy, borrow=True): data_x, data_y = data_xy shared_x = theano.shared(numpy.asarray(data_x, dtype=theano.config.floatX), borrow=borrow) shared_y = theano.shared(numpy.asarray(data_y, dtype=theano.config.floatX), borrow=borrow) return shared_x, T.cast(shared_y, 'int32')#将shared_y转换成整型 #3、将数据处理成需要的形式 test_set_x, test_set_y = shared_dataset(test_set) valid_set_x, valid_set_y = shared_dataset(valid_set) train_set_x, train_set_y = shared_dataset(train_set) #4、返回数据集 rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y), (test_set_x, test_set_y)] return rval
需要导入的模块主要有os
、gzip
和cPickle
。其中os
模块主要用于在本地查找dataset文件,具有目录的处理以及文件的判断等函数;gzip
模块提供了一些简单的对文件进行压缩和解压缩的函数功能;cPickle
模块可以对任意一种类型的python对象进行序列化操作。
1、程序中的os
模块
在load_data(dataset)
函数中,使用到的主要是os.path
模块,使用到的函数是:
os.path.split(path)
:把路径分割成dirname和basename,返回一个元组os.path.isfile(path)
:判断路径是否为文件os.path.join(path1[, path2[, ...]])
:把目录和文件名合成一个路径
注:__file__
表示的是当前的路径
2、程序中的gzip
模块
gzip
模块主要提供了一些简单的对文件进行压缩和解压缩的函数功能。使用到的函数是:
gzip.open(dataset, 'rb')
: 打开一个gzip已经压缩好的gzip格式的文件,并返回一个文件对象:file object.
3、程序中的cPickle
模块
cPickle
模块可以对任意一种类型的python对象进行序列化操作,使用到的函数是:
cPickle.load(file)
:主要是载入本地的文件。
在导入数据的过程中,将数据做成了带有存储性质的形式,这样的形式可以使得变量在不同的函数之间共享,具体的构造函数为theano.shared()
。
4、theano.shared()
函数
函数theano.shared()
的格式如下:
如果设置borrow=False
,这表示在使用变量的过程中将是深拷贝,对数据的任何改变不会影响到原始的变量,通过控制该参数可以实现不同函数之间对变量的共享。
2、构建LogisticRegression
类
LogisticRegression
类的代码如下所示:
class LogisticRegression(object): def __init__(self, input, n_in, n_out): """ 初始化参数 :type input: theano.tensor.TensorType :param input: 一个minibatch :type n_in: int :param n_in: 输入的特征的个数 :type n_out: int :param n_out: 输出单元的个数,即输出的类别个数,在本例中共有10个类别 """ #初始化参数W和b self.W = theano.shared(value=numpy.zeros((n_in, n_out), dtype=theano.config.floatX), name='W', borrow=True) self.b = theano.shared(value=numpy.zeros((n_out,), dtype=theano.config.floatX), name='b', borrow=True) #计算属于不同的类别的概率 self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b) # 计算所属的类别 self.y_pred = T.argmax(self.p_y_given_x, axis=1) #参数声明 self.params = [self.W, self.b] self.input = input def negative_log_likelihood(self, y): """负的log似然函数 :type y: theano.tensor.TensorType :param y: 对应的类别标签 """ return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y]) '''T.arange(y.shape[0])返回的是一个向量[0,1,...,len(y)],y也是一个向量,如[3,5,6...,9],代表的是所属的类别 T.log(self.p_y_given_x)[T.arange(y.shape[0]), y]表示的是T.log(self.p_y_given_x)[0,3],mean函数内部是一个向量 ''' def errors(self, y): """计算在minibatch中的错误率 :type y: theano.tensor.TensorType :param y: 对应的类别标签 """ # 检查y与y_pred是否具有相同的维度 if y.ndim != self.y_pred.ndim: raise TypeError( 'y should have the same shape as self.y_pred', ('y', y.type, 'y_pred', self.y_pred.type) ) # 检查y的数据格式 if y.dtype.startswith('int'): #返回错误率 return T.mean(T.neq(self.y_pred, y)) else: raise NotImplementedError()
在LogisticRegression
类中主要有三个函数,构造函数__init__()
,负的log似然函数negative_log_likelihood()
和计算错误率函数errors()
。
1、构造函数__init__()
在构造函数中主要有这样一些函数,theano.shared()
、theano.tensor.nnet.softmax()
,theano.tensor.nnet.dot()
和theano.tensor.argmax()
。其中,theano.shared()
在上面已经简单解释了;theano.tensor.nnet.softmax()
主要用于计算属于每一个类别的概率;theano.tensor.nnet.dot()
用于计算矩阵计算;theano.tensor.argmax()
用于返回最终所属的类别。
2、负的log似然函数negative_log_likelihood()
在负的log似然函数negative_log_likelihood()
中使用到的函数是theano.tensor.mean()
,该函数用于计算均值。
3、计算错误率函数errors()
计算错误率函数用于在validation阶段和testing阶段对模型的评估,主要的思想是利用模型对验证集以及测试集进行预测,用预测的结果y_pred
与样本标签y
进行对比,记录错误的个数,并返回错误的概率。用到的函数为theano.tensor.neq()
和theano.tensor.mean()
,函数theano.tensor.neq(self.y_pred, y)
用于统计self.y_pred
和y
中不相等的样本的个数。
3、sgd_optimization_mnist
函数
这个函数是整个Logistic回归算法的核心部分,用于构建整个算法的流程,该函数主要分为以下几个部分:
- 导入数据集
- 建立模型
- 训练模型
1、导入数据集
导入函数部分的代码已在上面解释过了。处理数据集部分的代码如下:
#1、导入数据集 datasets = load_data(dataset) train_set_x, train_set_y = datasets[0]#训练集 valid_set_x, valid_set_y = datasets[1]#验证集 #计算minibatches,得到训练集,验证集和测试集的minibatch的大小 n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] / batch_size
2、建立模型
在建立模型阶段,首先是一些全局符号变量的声明,然后初始化分类器,接着是构建好验证模型和训练模型,具有的代码如下:
# 2、建模 print '... building the model' index = T.lscalar() # 声明一个符号变量,用于minibatch索引 # 声明符号变量x和y x = T.matrix('x') y = T.ivector('y') # 2.1 初始化分类器 classifier = LogisticRegression(input=x, n_in=28 * 28, n_out=10) cost = classifier.negative_log_likelihood(y) #验证模型 validate_model = theano.function( inputs=[index], outputs=classifier.errors(y), givens={ x: valid_set_x[index * batch_size: (index + 1) * batch_size], y: valid_set_y[index * batch_size: (index + 1) * batch_size] } ) # 计算梯度 g_W = T.grad(cost=cost, wrt=classifier.W) g_b = T.grad(cost=cost, wrt=classifier.b) # 对参数的更新 updates = [(classifier.W, classifier.W - learning_rate * g_W), (classifier.b, classifier.b - learning_rate * g_b)] # 模型的训练规则 train_model = theano.function( inputs=[index], outputs=cost, updates=updates, givens={ x: train_set_x[index * batch_size: (index + 1) * batch_size], y: train_set_y[index * batch_size: (index + 1) * batch_size] } )
3、训练模型
在模型的训练过程中,通过随机梯度下降不断调整模型中的参数
#3、训练模型 print '... training the model' # early-stopping 参数 patience = 5000 patience_increase = 2 improvement_threshold = 0.995 validation_frequency = min(n_train_batches, patience / 2) best_validation_loss = numpy.inf#初始化一个最大的误差 start_time = timeit.default_timer()#计算开始的时间 done_looping = False#用于判断是否结束循环的标志 epoch = 0#当前的迭代次数 while (epoch < n_epochs) and (not done_looping): epoch = epoch + 1 #对每一个minibatch的数据进行训练 for minibatch_index in xrange(n_train_batches): minibatch_avg_cost = train_model(minibatch_index)#得到负的log似然值 #计算迭代的次数 iter = (epoch - 1) * n_train_batches + minibatch_index #每次将minibatch数据集计算一遍便开始计算validation if (iter + 1) % validation_frequency == 0: # 在验证集上验证模型的优劣 validation_losses = [validate_model(i) for i in xrange(n_valid_batches)] this_validation_loss = numpy.mean(validation_losses) print( 'epoch %i, minibatch %i/%i, validation error %f %%' % ( epoch, minibatch_index + 1, n_train_batches, this_validation_loss * 100. ) ) # 记录下在验证集上表现最好的模型 if this_validation_loss < best_validation_loss: #当性能足够好时不断提高patience以使得模型提早结束 if this_validation_loss < best_validation_loss * improvement_threshold: patience = max(patience, iter * patience_increase) best_validation_loss = this_validation_loss # 保存最优的模型 with open('best_model.pkl', 'w') as f: cPickle.dump(classifier, f) #提早退出循环 if patience <= iter: done_looping = True break end_time = timeit.default_timer()# 运行结束的时间 # 打印最优的验证结果 print( ( 'Optimization complete with best validation score of %f %%' ) % (best_validation_loss * 100.) ) # 打印运行的时间 print 'The code run for %d epochs, with %f epochs/sec' % ( epoch, 1. * epoch / (end_time - start_time)) print >> sys.stderr, ('The code for file ' + os.path.split(__file__)[1] + ' ran for %.1fs' % ((end_time - start_time)))
4、predict
函数
在predict
函数中,使用到的是模型和测试数据集,具体的函数如下:
def predict(): """用训练好的模型进行预测 """ # 导入训练好的模型 classifier = cPickle.load(open('best_model.pkl')) # 建立预测模型 predict_model = theano.function( inputs=[classifier.input], outputs=classifier.y_pred) # 导入测试数据集 dataset='mnist.pkl.gz' datasets = load_data(dataset) test_set_x, test_set_y = datasets[2] test_set_x = test_set_x.get_value() predicted_values = predict_model(test_set_x[:10])#进行预测 print ("Predicted values for the first 10 examples in test set:") print predicted_values
使用的函数主要是导入函数和模型的函数,在上述都已经简单介绍过。
三、实验结果
1、训练模型
... loading data... building the model... training the modelepoch 1, minibatch 83/83, validation error 12.458333 %epoch 2, minibatch 83/83, validation error 11.010417 %epoch 3, minibatch 83/83, validation error 10.312500 %epoch 4, minibatch 83/83, validation error 9.875000 %epoch 5, minibatch 83/83, validation error 9.562500 %epoch 6, minibatch 83/83, validation error 9.322917 %epoch 7, minibatch 83/83, validation error 9.187500 %epoch 8, minibatch 83/83, validation error 8.989583 %epoch 9, minibatch 83/83, validation error 8.937500 %epoch 10, minibatch 83/83, validation error 8.750000 %epoch 11, minibatch 83/83, validation error 8.666667 %epoch 12, minibatch 83/83, validation error 8.583333 %epoch 13, minibatch 83/83, validation error 8.489583 %epoch 14, minibatch 83/83, validation error 8.427083 %epoch 15, minibatch 83/83, validation error 8.354167 %epoch 16, minibatch 83/83, validation error 8.302083 %epoch 17, minibatch 83/83, validation error 8.250000 %epoch 18, minibatch 83/83, validation error 8.229167 %epoch 19, minibatch 83/83, validation error 8.260417 %epoch 20, minibatch 83/83, validation error 8.260417 %epoch 21, minibatch 83/83, validation error 8.208333 %epoch 22, minibatch 83/83, validation error 8.187500 %epoch 23, minibatch 83/83, validation error 8.156250 %epoch 24, minibatch 83/83, validation error 8.114583 %epoch 25, minibatch 83/83, validation error 8.093750 %epoch 26, minibatch 83/83, validation error 8.104167 %epoch 27, minibatch 83/83, validation error 8.104167 %epoch 28, minibatch 83/83, validation error 8.052083 %epoch 29, minibatch 83/83, validation error 8.052083 %epoch 30, minibatch 83/83, validation error 8.031250 %epoch 31, minibatch 83/83, validation error 8.010417 %epoch 32, minibatch 83/83, validation error 7.979167 %epoch 33, minibatch 83/83, validation error 7.947917 %epoch 34, minibatch 83/83, validation error 7.875000 %epoch 35, minibatch 83/83, validation error 7.885417 %epoch 36, minibatch 83/83, validation error 7.843750 %epoch 37, minibatch 83/83, validation error 7.802083 %epoch 38, minibatch 83/83, validation error 7.812500 %epoch 39, minibatch 83/83, validation error 7.812500 %epoch 40, minibatch 83/83, validation error 7.822917 %epoch 41, minibatch 83/83, validation error 7.791667 %epoch 42, minibatch 83/83, validation error 7.770833 %epoch 43, minibatch 83/83, validation error 7.750000 %epoch 44, minibatch 83/83, validation error 7.739583 %epoch 45, minibatch 83/83, validation error 7.739583 %epoch 46, minibatch 83/83, validation error 7.739583 %epoch 47, minibatch 83/83, validation error 7.739583 %epoch 48, minibatch 83/83, validation error 7.708333 %epoch 49, minibatch 83/83, validation error 7.677083 %epoch 50, minibatch 83/83, validation error 7.677083 %epoch 51, minibatch 83/83, validation error 7.677083 %epoch 52, minibatch 83/83, validation error 7.656250 %epoch 53, minibatch 83/83, validation error 7.656250 %epoch 54, minibatch 83/83, validation error 7.635417 %epoch 55, minibatch 83/83, validation error 7.635417 %epoch 56, minibatch 83/83, validation error 7.635417 %epoch 57, minibatch 83/83, validation error 7.604167 %epoch 58, minibatch 83/83, validation error 7.583333 %epoch 59, minibatch 83/83, validation error 7.572917 %epoch 60, minibatch 83/83, validation error 7.572917 %epoch 61, minibatch 83/83, validation error 7.583333 %epoch 62, minibatch 83/83, validation error 7.572917 %epoch 63, minibatch 83/83, validation error 7.562500 %epoch 64, minibatch 83/83, validation error 7.572917 %epoch 65, minibatch 83/83, validation error 7.562500 %epoch 66, minibatch 83/83, validation error 7.552083 %epoch 67, minibatch 83/83, validation error 7.552083 %epoch 68, minibatch 83/83, validation error 7.531250 %epoch 69, minibatch 83/83, validation error 7.531250 %epoch 70, minibatch 83/83, validation error 7.510417 %epoch 71, minibatch 83/83, validation error 7.520833 %epoch 72, minibatch 83/83, validation error 7.510417 %epoch 73, minibatch 83/83, validation error 7.500000 %Optimization complete with best validation score of 7.500000 %The code run for 74 epochs, with 2.780229 epochs/secThe code for file logistic_sgd.py ran for 26.6s
2、测试结果
... loading dataPredicted values for the first 10 examples in test set:[7 2 1 0 4 1 4 9 6 9]
参考文献
Deep Learning Tutorials (http://www.deeplearning.net/tutorial/)
- 利用Theano理解深度学习——Logistic Regression
- 利用Theano理解深度学习——Logistic Regression
- 利用Theano理解深度学习——Multilayer Perceptron
- 利用Theano理解深度学习——Convolutional Neural Networks
- 利用Theano理解深度学习——Auto Encoder
- 利用Theano理解深度学习——Multilayer Perceptron
- 利用Theano理解深度学习——Auto Encoder
- 吴恩达学习—Logistic Regression
- Theano Logistic Regression
- theano logistic regression讲解
- Python学习——Logistic Regression
- MXNet学习7——Logistic Regression
- 机器学习笔记——Logistic Regression
- theano学习笔记(一):Classifying MNIST digits using Logistic Regression
- 理解线性回归(二)——Logistic Regression 回归
- Logistic Regression(逻辑回归)(二)—深入理解
- 【theano-windows】学习笔记七——logistic回归
- Logistic Regression的理解
- 【C#高效编程50例】条目2:用运行时常量(readonly)而不是编译期常量(const)
- STL之list容器详解
- 【c++ templates读书笔记】【6】模板的多态
- Java____Timer实现定时功能及其源码研究
- jquery倒计时功能
- 利用Theano理解深度学习——Logistic Regression
- 算法时间复杂度T(n)大小顺序
- Java语言简介
- SQL Server基础--SQL语句
- hdu1272 小希的迷宫 并查集
- Android 开发第五弹:简易时钟(闹钟)
- superoj906 flood
- yum.Errors.MiscError: xz compression not available
- 欢迎使用CSDN-markdown编辑器