CNN_CT
来源:互联网 发布:js button 不可用 编辑:程序博客网 时间:2024/06/06 10:52
# coding: utf-8
"""Train a small CNN to classify single-channel CT images into two classes.

The script targets the TensorFlow 1.x graph/session API.  It loads four
``.npy`` files from the working directory, builds four conv+pool stages
followed by a dense layer, dropout and a linear output layer, then trains
with Adam while writing accuracy summaries for TensorBoard to
``summaries/train`` and ``summaries/test``.
"""
import tensorflow as tf
import numpy as np

# Load the data.
train_dataset = np.load('train_dataset.npy')
train_labels = np.load('train_labels.npy')
test_dataset = np.load('test_dataset.npy')
test_labels = np.load('test_labels.npy')

image_height = 2000
image_width = 1600
num_channels = 1
num_labels = 2
kernel_size = 5


def reformat(dataset, labels):
    """Reshape images to NHWC float32 and one-hot encode integer labels.

    Args:
        dataset: array reshapeable to
            (-1, image_height, image_width, num_channels).
        labels: 1-D array of integer class ids in [0, num_labels).

    Returns:
        (dataset, labels) where labels is float32 of shape (N, num_labels).
    """
    dataset = dataset.reshape(
        (-1, image_height, image_width, num_channels)).astype(np.float32)
    # Cast the *comparison result*, not the label column, so the one-hot
    # matrix really is float32 rather than bool.
    labels = (np.arange(num_labels) == labels[:, None]).astype(np.float32)
    return dataset, labels


train_dataset, train_labels = reformat(train_dataset, train_labels)
# The original assigned this result to train_labels, clobbering the training
# labels and leaving test_labels un-encoded.
test_dataset, test_labels = reformat(test_dataset, test_labels)
print(train_dataset.shape, train_labels.shape)
print(test_dataset.shape, test_labels.shape)


def weight_variable(shape):
    """Return a Variable initialised from a truncated normal (stddev 0.1).

    Values outside (mean - 2*stddev, mean + 2*stddev) are re-drawn.
    """
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    """Return a Variable initialised to the constant 0.1."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W):
    """Stride-1, SAME-padded 2-D convolution of x with filter W."""
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
    """2x2 max pooling with stride 2 and SAME padding."""
    return tf.nn.max_pool(
        x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


def cnn_layer(input_tensor, input_dim, output_dim, layer_name,
              act=tf.nn.relu):
    """Add a conv -> activation -> 2x2 max-pool stage to the default graph.

    Args:
        input_tensor: rank-4 NHWC tensor.
        input_dim: number of input channels.
        output_dim: number of output channels.
        layer_name: name scope for the ops of this stage.
        act: activation applied to the pre-activation; it is called as
            ``act(tensor, name=...)``.

    Returns:
        The pooled activation tensor.
    """
    with tf.name_scope(layer_name):
        with tf.name_scope('weights'):
            weights = weight_variable(
                [kernel_size, kernel_size, input_dim, output_dim])
        with tf.name_scope('biases'):
            biases = bias_variable([output_dim])
        with tf.name_scope('convolution'):
            preactivate = conv2d(input_tensor, weights) + biases
        activations = act(preactivate, name='activation')
        with tf.name_scope('pooling'):
            pooling = max_pool_2x2(activations)
    print('create cnn layer')
    return pooling


def fc_layer(input_tensor, input_dim, output_dim, layer_name,
             act=tf.nn.relu):
    """Add a fully connected layer (matmul + bias + activation).

    The original script routed the dense and output layers through
    ``cnn_layer``, which applies conv2d/max_pool to a rank-2 tensor and
    cannot work; flat features need a matrix multiply instead.

    Args:
        input_tensor: rank-2 tensor of shape (batch, input_dim).
        input_dim: number of input features.
        output_dim: number of output units.
        layer_name: name scope for the ops of this layer.
        act: activation, called as ``act(tensor, name=...)``.

    Returns:
        Tensor of shape (batch, output_dim).
    """
    with tf.name_scope(layer_name):
        with tf.name_scope('weights'):
            weights = weight_variable([input_dim, output_dim])
        with tf.name_scope('biases'):
            biases = bias_variable([output_dim])
        with tf.name_scope('Wx_plus_b'):
            preactivate = tf.matmul(input_tensor, weights) + biases
        activations = act(preactivate, name='activation')
    print('create fc layer')
    return activations


def batch_offset(step, batch_size, num_examples):
    """Return the start index of the mini-batch used at ``step``.

    Cycles through the data with the same modulus as the original code, but
    returns 0 instead of dividing by zero (or going negative) when the
    dataset is not larger than one batch.
    """
    span = num_examples - batch_size
    if span <= 0:
        return 0
    return (step * batch_size) % span


# tf.Graph must be instantiated; the original bound the class itself.
graph = tf.Graph()
with graph.as_default():
    # tf.reset_default_graph() was removed here: it is not valid inside a
    # nested graph context, and a fresh Graph() needs no reset.
    with tf.name_scope('input'):
        x = tf.placeholder(
            tf.float32,
            shape=[None, image_height, image_width, num_channels],
            name='x-input')
        y_ = tf.placeholder(
            tf.float32, shape=[None, num_labels], name='y-input')

    conv_pool_1 = cnn_layer(x, num_channels, 32, 'conv_pool_1')
    conv_pool_2 = cnn_layer(conv_pool_1, 32, 64, 'conv_pool2')
    conv_pool_3 = cnn_layer(conv_pool_2, 64, 128, 'conv_pool_3')
    conv_pool_4 = cnn_layer(conv_pool_3, 128, 256, 'conv_pool_4')

    with tf.name_scope('reshape_conv_output'):
        # Derive the flattened size from the tensor instead of hard-coding
        # 125*100*256, so it follows image_height / image_width.
        flat_dim = int(np.prod(conv_pool_4.get_shape().as_list()[1:]))
        h_pool4_flat = tf.reshape(conv_pool_4, [-1, flat_dim])

    # NOTE(review): with 2000x1600 inputs flat_dim is 125*100*256 = 3.2M, so
    # this weight matrix holds ~3.3e9 float32 values (~13 GB).  Consider
    # downsampling the images or adding more pooling before this layer.
    h_fc1 = fc_layer(h_pool4_flat, flat_dim, 1024, 'densely_layer')

    with tf.name_scope('dropout'):
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    y_conv = fc_layer(h_fc1_drop, 1024, num_labels, 'output',
                      act=tf.identity)

    with tf.name_scope('cross_entropy'):
        # Cross-entropy loss.  TF 1.x only accepts named arguments here.
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                labels=y_, logits=y_conv))

    with tf.name_scope('train'):
        # Optimizer.
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

    with tf.name_scope('accuracy'):
        with tf.name_scope('correct_prediction'):
            correct_prediction = tf.equal(
                tf.argmax(y_conv, 1), tf.argmax(y_, 1))
        with tf.name_scope('accuracy'):
            accuracy = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', accuracy)
    merged = tf.summary.merge_all()

    # Build the initializer while `graph` is the default graph so it covers
    # exactly the variables created above.  Replaces the deprecated
    # tf.initialize_all_variables().
    init_op = tf.global_variables_initializer()

sess = tf.InteractiveSession(graph=graph)
train_writer = tf.summary.FileWriter('summaries/train', sess.graph)
test_writer = tf.summary.FileWriter('summaries/test', sess.graph)
sess.run(init_op)
print('initialized')

train_batch_size = 20
test_batch_size = 10


def feed_dict_train(step):
    """Feed dict for one training step (dropout keep probability 0.5)."""
    offset = batch_offset(step, train_batch_size, train_labels.shape[0])
    xs = train_dataset[offset:offset + train_batch_size]
    ys = train_labels[offset:offset + train_batch_size]
    return {keep_prob: 0.5, x: xs, y_: ys}


def feed_dict_test(step):
    """Feed dict for one evaluation step (dropout disabled)."""
    offset = batch_offset(step, test_batch_size, test_labels.shape[0])
    xs = test_dataset[offset:offset + test_batch_size]
    ys = test_labels[offset:offset + test_batch_size]
    # The original used the undefined name `y` as the key; the label
    # placeholder is `y_`.
    return {keep_prob: 1.0, x: xs, y_: ys}


try:
    for i in range(801):
        print('step %s' % i)
        if i % 10 == 0:
            # Every 10th step: evaluate on a test batch, no training.
            summary, acc = sess.run(
                [merged, accuracy], feed_dict=feed_dict_test(i))
            test_writer.add_summary(summary, i)
            print('Accuracy at step %s:%s' % (i, acc))
        else:
            # The original had a separate `i % 20 == 19` branch that printed
            # "Adding run metadata" but recorded none and otherwise did the
            # same as this branch, so the two are merged.
            summary, _ = sess.run(
                [merged, train_step], feed_dict=feed_dict_train(i))
            train_writer.add_summary(summary, i)
finally:
    # Flush and close the event files even if training is interrupted.
    train_writer.close()
    test_writer.close()
    sess.close()
补上昨天的,昨天弄了一天了,
感觉得好好看API了。