resnet__残差神经网络搭建
来源:互联网 发布:数据库系统概论王珊 编辑:程序博客网 时间:2024/06/05 04:36
# -*- coding: utf-8 -*-
"""ResNet (residual network) built on TensorFlow 1.x graph-mode APIs.

Defines small helpers (`print_activations`, `conv2d`, `linear`) and the
`ResNet` graph builder, which stacks bottleneck residual blocks over a
square single-channel image input.
"""
import tensorflow as tf
from collections import namedtuple
from math import sqrt


def print_activations(t):
    """Print a tensor's op name and static shape (debugging aid)."""
    print(t.op.name, t.get_shape().as_list())


def conv2d(x, n_fliters,  # NOTE: misspelling kept — part of the public signature
           k_h=5, k_w=5,
           stride_h=2, stride_w=2,
           stddev=0.02, activation=lambda x: x,
           bias=True, padding='SAME', name="Conv2D"):
    """2-D convolution layer with optional bias.

    Args:
        x: input tensor, NHWC layout (channels inferred from last dim).
        n_fliters: number of output filters/channels.
        k_h, k_w: kernel height/width.
        stride_h, stride_w: spatial strides.
        stddev: stddev for the truncated-normal weight/bias initializer.
        activation: callable applied to the conv output (identity by default).
        bias: whether to add a learned per-channel bias.
        padding: 'SAME' or 'VALID'.
        name: variable scope / summary tag prefix.

    Returns:
        activation(conv2d(x) [+ bias]) tensor.
    """
    with tf.variable_scope(name):
        w = tf.get_variable(
            'weight', [k_h, k_w, x.get_shape()[-1], n_fliters],
            initializer=tf.truncated_normal_initializer(stddev=stddev))
        tf.summary.histogram(name + 'weight', w)
        conv = tf.nn.conv2d(x, w, strides=[1, stride_h, stride_w, 1],
                            padding=padding)
        if bias:
            b = tf.get_variable(
                'bias', [n_fliters],
                initializer=tf.truncated_normal_initializer(stddev=stddev))
            tf.summary.histogram(name + 'bias', b)
            conv = conv + b
        print_activations(conv)
        return activation(conv)


def linear(x, n_units, scope=None, stddev=0.02, activation=tf.identity):
    """Fully-connected layer: activation(x @ W + b).

    Args:
        x: 2-D input tensor [batch, features].
        n_units: number of output units.
        scope: optional variable scope name (defaults to "linear").
        stddev: stddev for the random-normal initializers.
        activation: callable applied to the affine output.

    Returns:
        activation(tf.matmul(x, weight) + bias) tensor.
    """
    shape = x.get_shape().as_list()
    with tf.variable_scope(scope or "linear"):
        weight = tf.get_variable(
            "weight", [shape[1], n_units], tf.float32,
            tf.random_normal_initializer(stddev=stddev))
        tf.summary.histogram('weight', weight)
        bias = tf.get_variable(
            'bias', [n_units], tf.float32,
            tf.random_normal_initializer(stddev=stddev))
        tf.summary.histogram('bias', bias)
        # BUG FIX: original ended with a tag-less tf.summary.histogram(...)
        # call (invalid) and never returned, so the layer output was lost.
        return activation(tf.matmul(x, weight) + bias)


def ResNet(x, n_outputs, activation=tf.nn.relu):
    """Build a residual network graph over input `x`.

    Args:
        x: input tensor — either NHWC images or a flat [batch, n*n] tensor
           of a square single-channel image (reshaped internally).
        n_outputs: number of output units of the final linear layer.
        activation: nonlinearity used inside the residual blocks.

    Returns:
        [batch, n_outputs] logits tensor.

    Raises:
        ValueError: if a flat input's feature count is not a perfect square.
    """
    # Plain data container describing one stage of residual blocks.
    LayerBlock = namedtuple(
        'LayerBlock', ['num_repeats', 'num_filters', 'bottleneck_size'])
    blocks = [LayerBlock(3, 128, 32),
              LayerBlock(3, 256, 64),
              LayerBlock(3, 512, 128),
              LayerBlock(3, 1024, 256),
              LayerBlock(3, 2048, 512),
              LayerBlock(3, 4096, 1024)]

    # Accept a flattened square image and restore the NHWC layout.
    input_shape = x.get_shape().as_list()
    if len(input_shape) == 2:
        ndim = int(sqrt(input_shape[1]))
        if ndim * ndim != input_shape[1]:
            raise ValueError('input_shape should be square')
        x = tf.reshape(x, [-1, ndim, ndim, 1])
    tf.summary.image('input', x, 10)

    # First conv expands to 64 channels and downsamples; then max-pool.
    net = conv2d(x, 64, k_h=7, k_w=7, name='conv1', activation=activation)
    net = tf.nn.max_pool(net, [1, 2, 2, 1], strides=[1, 2, 2, 1],
                         padding='SAME')
    print_activations(net)
    # BUG FIX: padding was misspelled 'VAlID'; TF requires exact 'VALID'.
    net = conv2d(net, blocks[0].num_filters, k_h=1, k_w=1,
                 stride_h=1, stride_w=1, padding='VALID', name='conv2')

    # Residual stages: 1x1 reduce -> 3x3 bottleneck -> 1x1 expand, plus
    # an identity shortcut; between stages a 3x3 conv widens the channels.
    for block_i, block in enumerate(blocks):
        for repeat_i in range(block.num_repeats):
            name = 'block_%d/repeat_%d' % (block_i, repeat_i)
            conv = conv2d(net, block.bottleneck_size, k_h=1, k_w=1,
                          padding='VALID', stride_h=1, stride_w=1,
                          activation=activation, name=name + '/conv_in')
            # BUG FIX: the 3x3 conv must use 'SAME' padding — 'VALID'
            # shrinks spatial dims and breaks the residual add below.
            conv = conv2d(conv, block.bottleneck_size, k_h=3, k_w=3,
                          padding='SAME', stride_h=1, stride_w=1,
                          activation=activation,
                          name=name + '/conv_bottleneck')
            conv = conv2d(conv, block.num_filters, k_h=1, k_w=1,
                          padding='VALID', stride_h=1, stride_w=1,
                          activation=activation, name=name + '/conv_out')
            net = conv + net  # identity shortcut (residual connection)
        try:
            next_block = blocks[block_i + 1]
            # BUG FIX: scope-name typo 'blcok_%d' -> 'block_%d'.
            # NOTE(review): this changes checkpoint variable names vs. the
            # original — confirm no checkpoints depend on the old spelling.
            net = conv2d(net, next_block.num_filters, k_h=3, k_w=3,
                         padding='SAME', stride_h=1, stride_w=1,
                         name='block_%d/conv_upscale' % block_i)
        except IndexError:
            pass  # last stage has no upscale conv

    # Global average pool over the remaining spatial extent.
    net = tf.nn.avg_pool(
        net,
        ksize=[1, net.get_shape().as_list()[1],
               net.get_shape().as_list()[2], 1],
        strides=[1, 1, 1, 1], padding='VALID')
    print_activations(net)
    # BUG FIX: original passed strides=/padding= kwargs to tf.reshape
    # (TypeError) and dropped the channel dim; flatten H*W*C instead.
    net = tf.reshape(
        net,
        [-1,
         net.get_shape().as_list()[1] *
         net.get_shape().as_list()[2] *
         net.get_shape().as_list()[3]])
    print_activations(net)

    return linear(net, n_outputs)
阅读全文
0 0
- resnet__残差神经网络搭建
- 残差神经网络
- 卷积神经网络残差计算
- CNN 卷积神经网络-- 残差计算
- 【Learning Notes】基于 boosting 原理训练深层残差神经网络
- 变种神经网络的典型代表:深度残差网络
- 转载:变种神经网络的典型代表:深度残差网络
- 深度三维残差神经网络:视频理解新突破
- ICCV | 深度三维残差神经网络:视频理解新突破
- ICCV | 深度三维残差神经网络:视频理解新突破
- Coursera Deep Learning 第四课 卷积神经网络 第二周 编程作业 残差神经网络 Residual Networks
- 使用Keras搭建深度残差网络
- 一文理解深度学习,卷积神经网络,循环神经网络的脉络和原理3-残差神经网络
- Deep Learning-TensorFlow (14) CNN卷积神经网络_深度残差网络 ResNet
- 深度学习之神经网络结构——残差网络ResNet
- Deep Spatio-Temporal Residual Networks(深度时空残差神经网络)
- 从零开始搭建 ResNet 之 残差网络(持续更新)
- tensorflow搭建自己的残差网络(ResNet)
- 网络编程-Socket
- POJ 3669 Meteor Shower
- [SCOI2005] BZOJ 1087 互不侵犯King
- Java基础(三):Java集合总结
- onselect事件在表单元素中的使用
- resnet__残差神经网络搭建
- NOIP 提高组 2006
- JavaScript对象创建的几种方式
- hibernate框架笔记之检索策略
- 使用Spring Boot框架导致存入汉字到MySQL数据库为乱码解决方案
- NYOJ【62】笨小熊【字符串】&&【统计元素】
- C++基础知识积累
- Struts2配置流程及处理请求过程
- The Accomodation of Students(HDU-2444)(二分图判定与最大匹配)