CNN_CT

#coding:utf-8
import tensorflow as tf
import numpy as np

# Load the data.
train_dataset = np.load('train_dataset.npy')
train_labels = np.load('train_labels.npy')
test_dataset = np.load('test_dataset.npy')
test_labels = np.load('test_labels.npy')

image_height = 2000
image_width = 1600
num_channels = 1
num_labels = 2
kernel_size = 5

def reformat(dataset, labels):
    # Reshape images to NHWC and one-hot encode the labels.
    dataset = dataset.reshape((-1, image_height, image_width, num_channels)).astype(np.float32)
    labels = (np.arange(num_labels) == labels[:, None]).astype(np.float32)
    return dataset, labels

train_dataset, train_labels = reformat(train_dataset, train_labels)
test_dataset, test_labels = reformat(test_dataset, test_labels)
print(train_dataset.shape, train_labels.shape)
print(test_dataset.shape, test_labels.shape)

def weight_variable(shape):
    # Sample initial weights from a truncated normal distribution;
    # values outside (mean - 2*stddev, mean + 2*stddev) are redrawn.
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    # Initialize biases with a small positive constant.
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

"""
def variable_summaries(var):
    with tf.name_scope('summaries'):
"""  # unfinished placeholder left over from the original post

graph = tf.Graph()
with graph.as_default() as graph:
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, shape=[None, image_height, image_width, num_channels], name='x-input')
        y_ = tf.placeholder(tf.float32, shape=[None, num_labels], name='y-input')

    def cnn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
        # One convolution + activation + 2x2 max-pooling block.
        with tf.name_scope(layer_name):
            with tf.name_scope('weights'):
                weights = weight_variable([kernel_size, kernel_size, input_dim, output_dim])
            with tf.name_scope('biases'):
                biases = bias_variable([output_dim])
            with tf.name_scope('convolution'):
                preactivate = conv2d(input_tensor, weights) + biases
                activations = act(preactivate, name='activation')
            with tf.name_scope('pooling'):
                pooling = max_pool_2x2(activations)
            print('create cnn layer')
            return pooling

    def fc_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
        # Fully connected layer. The original code reused cnn_layer here, which cannot
        # operate on the flattened 2-D tensor, so a dedicated dense layer is used instead.
        with tf.name_scope(layer_name):
            with tf.name_scope('weights'):
                weights = weight_variable([input_dim, output_dim])
            with tf.name_scope('biases'):
                biases = bias_variable([output_dim])
            preactivate = tf.matmul(input_tensor, weights) + biases
            return act(preactivate, name='activation')

    conv_pool_1 = cnn_layer(x, 1, 32, 'conv_pool_1')
    conv_pool_2 = cnn_layer(conv_pool_1, 32, 64, 'conv_pool_2')
    conv_pool_3 = cnn_layer(conv_pool_2, 64, 128, 'conv_pool_3')
    conv_pool_4 = cnn_layer(conv_pool_3, 128, 256, 'conv_pool_4')

    with tf.name_scope('reshape_conv_output'):
        # After four 2x2 poolings: 2000 -> 125 and 1600 -> 100, with 256 feature maps.
        h_pool4_flat = tf.reshape(conv_pool_4, [-1, 125 * 100 * 256])

    h_fc1 = fc_layer(h_pool4_flat, 125 * 100 * 256, 1024, 'densely_layer')

    with tf.name_scope('dropout'):
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    y_conv = fc_layer(h_fc1_drop, 1024, num_labels, 'output', act=tf.identity)

    with tf.name_scope('cross_entropy'):
        # Softmax cross-entropy loss.
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))

    with tf.name_scope('train'):
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)  # optimizer

    with tf.name_scope('accuracy'):
        with tf.name_scope('correct_prediction'):
            correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', accuracy)
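The one-hot encoding inside reformat relies on NumPy broadcasting: comparing a (num_labels,) range against a (batch, 1) column of class indices yields a (batch, num_labels) boolean matrix. A minimal sketch with toy labels (the values below are illustrative, not from the real dataset):

import numpy as np

num_labels = 2
labels = np.array([0, 1, 1, 0])  # hypothetical class indices
one_hot = (np.arange(num_labels) == labels[:, None]).astype(np.float32)
print(one_hot)
# [[1. 0.]
#  [0. 1.]
#  [0. 1.]
#  [1. 0.]]

The same broadcasting trick is what the corrected reformat above uses; the original had the .astype() parenthesis in the wrong place, which left the labels as a boolean comparison against a float array.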
merged = tf.summary.merge_all()
sess = tf.InteractiveSession(graph=graph)
train_writer = tf.summary.FileWriter('summaries/train', sess.graph)
test_writer = tf.summary.FileWriter('summaries/test', sess.graph)
sess.run(tf.global_variables_initializer())  # initialize_all_variables() is deprecated
print('initialized')

train_batch_size = 20
test_batch_size = 10

for i in range(801):
    print('step %s' % i)

    def feed_dict_train():
        # Slide a window over the training set, wrapping around at the end.
        offset = (i * train_batch_size) % (train_labels.shape[0] - train_batch_size)
        xs = train_dataset[offset:(offset + train_batch_size), :]
        ys = train_labels[offset:(offset + train_batch_size), :]
        return {keep_prob: 0.5, x: xs, y_: ys}

    def feed_dict_test():
        # Same sliding-window batching over the test set, with dropout disabled.
        offset = (i * test_batch_size) % (test_labels.shape[0] - test_batch_size)
        xs = test_dataset[offset:(offset + test_batch_size), :]
        ys = test_labels[offset:(offset + test_batch_size), :]
        return {keep_prob: 1.0, x: xs, y_: ys}

    if i % 10 == 0:
        summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict_test())
        test_writer.add_summary(summary, i)
        print('Accuracy at step %s: %s' % (i, acc))
    elif i % 20 == 19:
        summary, _ = sess.run([merged, train_step], feed_dict=feed_dict_train())
        train_writer.add_summary(summary, i)
        print('Adding run metadata for', i)
    else:
        summary, _ = sess.run([merged, train_step], feed_dict=feed_dict_train())
        train_writer.add_summary(summary, i)

train_writer.close()
test_writer.close()
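To see how feed_dict_train cycles through the data, here is the same offset arithmetic run on toy numbers (train_size here is a made-up value, just for illustration):

# Sliding-window batch offsets wrap around once the end of the dataset is reached.
train_size = 50
batch_size = 20
for step in range(5):
    offset = (step * batch_size) % (train_size - batch_size)
    print(step, offset)
# prints: 0 0, 1 20, 2 10, 3 0, 4 20

The accuracy summaries written to summaries/train and summaries/test can then be viewed with TensorBoard, e.g. `tensorboard --logdir summaries`.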

Figure: screenshot of the program's run output (image not reproduced here).

Catching up on yesterday's post; I spent all of yesterday on this.
I feel I really need to study the API properly.
