选择性的加载网络模型的前几层训练(27)---《深度学习》
来源:互联网 发布:石家庄直销软件 编辑:程序博客网 时间:2024/04/29 11:17
加载模型的前几层拼接自己构建的层进行训练
注意这里我们使用了nets.inception.inception_v3_base来进行网络模型的部分恢复,因为nets.inception.inception_v3_base中可以指定final_endpoint参数进行网络的末尾层指定,然后通过在saver的restore函数中进行参数的设定来确保那些权值进行恢复,那些不需要进行恢复!
train.py
#-*-coding=utf-8-*-from PIL import Imageimport osimport os.pathimport numpy as npimport tensorflow as tfimport tensorflow.contrib.slim as slimimport tensorflow.contrib.slim.nets as netsimport inception_resnet_v2import img_convertheight = 299width = 299channels = 3num_classes=1001def convert(dir): filelists=os.listdir(dir) arr_col=[] for file in filelists: file_path=os.path.join(dir,file) img=Image.open(file_path).resize((299,299)).convert("RGB") r,g,b=img.split() r_arr=np.array(r) g_arr=np.array(g) b_arr=np.array(b) img_arr=np.concatenate((r_arr,g_arr,b_arr)) result=img_arr.reshape((299,299,3)) arr_col.append(result) return arr_coldef convert_3_2_4_dims(arr_): ret=np.zeros((len(arr_),arr_[0].shape[0],arr_[0].shape[1],arr_[0].shape[2])) for i in range(len(arr_)): ret[i,:,:,:]=arr_[i] return retif __name__=="__main__": o_dir="E:/test" num_classes=182 batch_size=3 epoches=2 X = tf.placeholder(tf.float32, shape=[None, height, width, channels]) y = tf.placeholder(tf.float32,shape=[None,182]) with slim.arg_scope(nets.inception.inception_v3_arg_scope()): logits,end_points_ = nets.inception.inception_v3_base(X,final_endpoint='Mixed_7c') variables_to_restore=slim.get_variables_to_restore() shape=logits.get_shape().as_list() dim=1 for d in shape[1:]: dim*=d fc_=tf.reshape(logits,[-1,dim]) fc0_weights=tf.get_variable(name="fc0_weights",shape=(dim,182),initializer=tf.contrib.layers.xavier_initializer()) fc0_biases=tf.get_variable(name="fc0_biases",shape=(182),initializer=tf.contrib.layers.xavier_initializer()) logits_=tf.nn.bias_add(tf.matmul(fc_,fc0_weights),fc0_biases) predictions=tf.nn.softmax(logits_) #cross_entropy = -tf.reduce_sum(y*tf.log(predictions)) cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=logits_)) #cross_entropy_mean=tf.reduce_mean(cross_entropy) train_step=tf.train.GradientDescentOptimizer(1e-6).minimize(cross_entropy) correct_pred=tf.equal(tf.argmax(y,1),tf.argmax(predictions,1)) #acc=tf.reduce_sum(tf.cast(correct_pred,tf.float32)) accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32)) with tf.Session() as sess: batches=img_convert.data_lrn(img_convert.load_data(o_dir,num_classes,batch_size)) sess.run(tf.global_variables_initializer()) saver=tf.train.Saver(variables_to_restore) saver.restore(sess,os.path.join("E:\\","inception_v3.ckpt")) for epoch in range(epoches): for batch in batches: sess.run(train_step,feed_dict={X:batch[0],y:batch[1]}) acc=sess.run(accuracy,feed_dict={X:batches[0][0],y:batches[1][1]}) print(acc) print("Done")
img_convert.py
#coding=utf-8from PIL import Imageimport osimport os.pathimport numpy as npimport tensorflow as tfimport tensorflow.contrib.slim as slimimport tensorflow.contrib.slim.nets as netsimport inception_resnet_v2def convert(dir): filelists=os.listdir(dir) arr_col=[] for file in filelists: file_path=os.path.join(dir,file) img=Image.open(file_path).resize((299,299)).convert("RGB") r,g,b=img.split() r_arr=np.array(r) g_arr=np.array(g) b_arr=np.array(b) img_arr=np.concatenate((r_arr,g_arr,b_arr)) result=img_arr.reshape((299,299,3)) arr_col.append(result) return arr_coldef convert_3_2_4_dims(arr_): ret=np.zeros((len(arr_),arr_[0].shape[0],arr_[0].shape[1],arr_[0].shape[2])) for i in range(len(arr_)): ret[i,:,:,:]=arr_[i] return retdef to_categorial(y,n_classes): y_std=np.zeros([len(y),n_classes]) for i in range(len(y)): y_std[i,y[i]]=1.0 return y_stddef batch_list(x,y,batch_size): batches=[] for i in range(int(len(x)/batch_size)): batch_data=[x[batch_size*i:batch_size*i+batch_size],y[batch_size*i:batch_size*i+batch_size]] batches.append(list(batch_data)) if(i+1)*batch_size<len(x): batch_data=[x[batch_size*(i+1):],y[batch_size*(i+1):]] batches.append(list(batch_data)) return batchesdef load_data(dir,num_classes,batch_size): arr_col=convert_3_2_4_dims(convert(dir)) arr_col=arr_col.astype(np.float32) #因为这儿我没指定它的标签,所以就随机指定了一些标签 z=np.random.rand(arr_col.shape[0])*num_classes z=z.astype("int") labels=np.array(z) batches=batch_list(arr_col,to_categorial(labels,num_classes),batch_size) return batchesdef data_lrn(batches): for batch in batches: batch[0]/=255 return batches
阅读全文
0 0
- 选择性的加载网络模型的前几层训练(27)---《深度学习》
- tensorflow对自己的数据进行训练(选择性的恢复权值)(26)---《深度学习》
- VggNet10模型的cifar10深度学习训练
- 【深度学习】训练网络的方法总结
- 深度学习-CAFFE利用CIFAR10网络模型训练自己的图像数据获得模型-4应用生成模型进行预测
- 深度学习实战——caffe windows 下训练自己的网络模型
- 深度学习实战——caffe windows 下训练自己的网络模型
- 深度学习网络模型训练中loss为nans的总结
- 【神经网络与深度学习】深度学习实战——caffe windows 下训练自己的网络模型
- 深度学习-CAFFE利用CIFAR10网络模型训练自己的图像数据获得模型-1.制作自己的数据集
- 深度学习ssd检测模型训练自己的数据集
- 对深度学习训练模型过程的理解
- 【深度学习】tensorflow加载VGG16的网络结构和模型参数
- 深度学习-CAFFE利用CIFAR10网络模型训练自己的图像数据获得模型-3结合caffe中的CIFAR10修改相关配置文件并训练
- 深度网络的预训练
- 深度学习笔记(二)-模型训练
- 深度学习(十四):详解Matconvnet使用imagenet模型训练自己的数据集
- 深度学习(一)学会用CAFFE训练自己的模型
- 语句
- https的安全证书与网页重写
- 整数划分问题输出所有划分结果及总数
- CodeForces
- httpd性能调整及服务器安全
- 选择性的加载网络模型的前几层训练(27)---《深度学习》
- Java递归实现99乘法表
- JavaMail-发送一封简单邮件(附带附件)
- 阿里云分布式缓存OCS与DB之间的数据一致性
- tomcat 环境迁移至weblogic 下载文件失败
- 自动装配bean
- <C++> 基于C++11/14/17的线程池实现
- JavaScript 关于进制之间的转换实现
- JS无缝滚动