选择性地加载网络模型的前几层进行训练(27)---《深度学习》

来源:互联网 发布:石家庄直销软件 编辑:程序博客网 时间:2024/04/29 11:17

加载模型的前几层拼接自己构建的层进行训练
注意:这里我们使用 nets.inception.inception_v3_base 来恢复网络模型的一部分,因为 nets.inception.inception_v3_base 可以通过 final_endpoint 参数指定网络的末尾层;然后在构造 tf.train.Saver 时传入需要恢复的变量列表(variables_to_restore),再调用 restore 函数,从而确定哪些权值需要恢复、哪些不需要恢复。
train.py

# -*- coding: utf-8 -*-
"""Fine-tune a partially restored Inception-v3 with a custom classifier head.

The Inception-v3 layers up to ``Mixed_7c`` are built with
``inception_v3_base`` and their weights are restored from a checkpoint; a
new fully connected layer is stacked on top and the whole graph is trained.
"""
from PIL import Image
import os
import os.path
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
import inception_resnet_v2
import img_convert

height = 299
width = 299
channels = 3
num_classes = 1001


def convert(dir):
    """Load every file in ``dir`` as a ``(height, width, 3)`` RGB array.

    Args:
        dir: Directory whose entries are all opened as images.

    Returns:
        A list with one ``uint8`` array per file, in sorted file-name order.
    """
    arr_col = []
    # Sort so the image order does not depend on the file system.
    for file_name in sorted(os.listdir(dir)):
        file_path = os.path.join(dir, file_name)
        # The context manager closes the underlying file handle.
        with Image.open(file_path) as img:
            rgb = img.resize((width, height)).convert("RGB")
            # np.array on an RGB image already yields height x width x 3;
            # concatenating the separate colour planes and reshaping would
            # scramble the pixel layout.
            arr_col.append(np.array(rgb))
    return arr_col


def convert_3_2_4_dims(arr_):
    """Stack a list of equally shaped 3-D arrays into one 4-D float array.

    Args:
        arr_: Non-empty sequence of arrays that all share the same shape.

    Returns:
        A ``float64`` array of shape ``(len(arr_),) + arr_[0].shape``.

    Raises:
        ValueError: If ``arr_`` is empty.
    """
    if len(arr_) == 0:
        raise ValueError("convert_3_2_4_dims() needs at least one array")
    return np.stack(arr_).astype(np.float64)


if __name__ == "__main__":
    o_dir = "E:/test"
    num_classes = 182
    batch_size = 3
    epoches = 2

    X = tf.placeholder(tf.float32, shape=[None, height, width, channels])
    y = tf.placeholder(tf.float32, shape=[None, num_classes])

    with slim.arg_scope(nets.inception.inception_v3_arg_scope()):
        # Build Inception-v3 only up to (and including) Mixed_7c.
        logits, end_points_ = nets.inception.inception_v3_base(
            X, final_endpoint='Mixed_7c')

        # Snapshot the variables created so far. The head defined below is
        # created afterwards, so it is not in this list and is not looked up
        # in the checkpoint.
        variables_to_restore = slim.get_variables_to_restore()

        # Flatten the Mixed_7c feature map for the fully connected layer.
        shape = logits.get_shape().as_list()
        dim = int(np.prod(shape[1:]))
        fc_ = tf.reshape(logits, [-1, dim])

        fc0_weights = tf.get_variable(
            name="fc0_weights", shape=(dim, num_classes),
            initializer=tf.contrib.layers.xavier_initializer())
        fc0_biases = tf.get_variable(
            name="fc0_biases", shape=[num_classes],
            initializer=tf.contrib.layers.xavier_initializer())
        logits_ = tf.nn.bias_add(tf.matmul(fc_, fc0_weights), fc0_biases)
        predictions = tf.nn.softmax(logits_)

        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits_))
        train_step = tf.train.GradientDescentOptimizer(1e-6).minimize(
            cross_entropy)

        correct_pred = tf.equal(tf.argmax(y, 1), tf.argmax(predictions, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    with tf.Session() as sess:
        batches = img_convert.data_lrn(
            img_convert.load_data(o_dir, num_classes, batch_size))

        # Initialise everything first, then overwrite the Inception layers
        # with the checkpoint values.
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(variables_to_restore)
        saver.restore(sess, os.path.join("E:\\", "inception_v3.ckpt"))

        for epoch in range(epoches):
            for batch in batches:
                sess.run(train_step, feed_dict={X: batch[0], y: batch[1]})

        # Evaluate on the first batch; images and labels must come from the
        # same batch.
        acc = sess.run(
            accuracy, feed_dict={X: batches[0][0], y: batches[0][1]})
        print(acc)
        print("Done")

img_convert.py

# coding=utf-8
"""Image loading, labelling and batching helpers for the training script."""
from PIL import Image
import os
import os.path
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
import inception_resnet_v2


def convert(dir):
    """Load every file in ``dir`` as a ``(299, 299, 3)`` RGB array.

    Args:
        dir: Directory whose entries are all opened as images.

    Returns:
        A list with one ``uint8`` array per file, in sorted file-name order.
    """
    arr_col = []
    # Sort so the image order does not depend on the file system.
    for file_name in sorted(os.listdir(dir)):
        file_path = os.path.join(dir, file_name)
        # The context manager closes the underlying file handle.
        with Image.open(file_path) as img:
            rgb = img.resize((299, 299)).convert("RGB")
            # np.array on an RGB image already yields height x width x 3;
            # concatenating the separate colour planes and reshaping would
            # scramble the pixel layout.
            arr_col.append(np.array(rgb))
    return arr_col


def convert_3_2_4_dims(arr_):
    """Stack a list of equally shaped 3-D arrays into one 4-D float array.

    Args:
        arr_: Non-empty sequence of arrays that all share the same shape.

    Returns:
        A ``float64`` array of shape ``(len(arr_),) + arr_[0].shape``.

    Raises:
        ValueError: If ``arr_`` is empty.
    """
    if len(arr_) == 0:
        raise ValueError("convert_3_2_4_dims() needs at least one array")
    return np.stack(arr_).astype(np.float64)


def to_categorial(y, n_classes):
    """One-hot encode integer class labels.

    Args:
        y: Sequence of integer class indices.
        n_classes: Number of columns in the encoded matrix.

    Returns:
        A ``float64`` array of shape ``(len(y), n_classes)`` with a single
        1.0 per row at the column given by the label.
    """
    labels = np.asarray(y, dtype=int)
    y_std = np.zeros([len(labels), n_classes])
    y_std[np.arange(len(labels)), labels] = 1.0
    return y_std


def batch_list(x, y, batch_size):
    """Split ``x`` and ``y`` into consecutive ``[x_part, y_part]`` batches.

    The last batch holds the remainder and may be shorter than
    ``batch_size``. Inputs shorter than ``batch_size`` yield a single
    partial batch, and empty inputs yield an empty list.

    Args:
        x: Sliceable sequence of samples.
        y: Sliceable sequence of labels aligned with ``x``.
        batch_size: Positive number of samples per batch.

    Returns:
        A list of two-element lists ``[x_slice, y_slice]``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    return [
        [x[start:start + batch_size], y[start:start + batch_size]]
        for start in range(0, len(x), batch_size)
    ]


def load_data(dir, num_classes, batch_size):
    """Load the images in ``dir`` and pair them with random one-hot labels.

    Args:
        dir: Directory of image files, read with ``convert``.
        num_classes: Number of classes to draw the random labels from.
        batch_size: Number of samples per batch.

    Returns:
        A list of ``[images, labels]`` batches as produced by ``batch_list``;
        images are ``float32``.
    """
    arr_col = convert_3_2_4_dims(convert(dir)).astype(np.float32)
    # No real labels are available here, so random labels are assigned as
    # placeholders.
    labels = np.random.randint(0, num_classes, size=arr_col.shape[0])
    return batch_list(arr_col, to_categorial(labels, num_classes), batch_size)


def data_lrn(batches):
    """Divide the image part of every batch by 255, in place.

    Args:
        batches: List of ``[images, labels]`` batches; each ``images`` entry
            must support in-place true division.

    Returns:
        The same ``batches`` list that was passed in.
    """
    for batch in batches:
        batch[0] /= 255
    return batches
阅读全文
0 0
原创粉丝点击