TensorFlow 1.0 空间变换网络(Spatial Transformer Networks)实现
来源:互联网 发布:c语言超市管理系统 编辑:程序博客网 时间:2024/06/13 23:41
空间变换网络简单介绍:
- 通过定位网络(localisation net)从输入图像中回归出仿射变换的参数 theta;
- 根据输出图像的 width 和 height 以及仿射变换(或者 TPS)的参数 theta,为输出网格中的每个目标位置生成其在输入图像(U)中对应的采样位置(即与输入图像坐标一致的目标索引);
该映射由批量矩阵乘法(batch matrix-matrix product)生成:本文代码中为 tf.matmul,PyTorch 中对应 torch.bmm。
- 根据目标在输入图像中的对应位置(索引矩阵),利用双线性插值得到目标输出。
先给大家展示一下效果:
代码部分:
import tensorflow as tf


def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
    """Spatial Transformer layer (affine variant).

    Warps the feature map ``U`` with one affine matrix per example and
    resamples it with bilinear interpolation.

    Args:
        U: tensor of shape [num_batch, height, width, num_channels].
        theta: tensor holding 6 affine parameters per example; it is reshaped
            to [-1, 2, 3], so its leading size must match ``U``'s batch size.
            The identity transform is [[1, 0, 0], [0, 1, 0]].
        out_size: (out_height, out_width) of the output; Python ints are
            expected because they are used in Python arithmetic below.
        name: variable scope name.
        **kwargs: accepted for backward compatibility; ignored.

    Returns:
        Float32 tensor of shape [num_batch, out_height, out_width, num_channels].
    """

    def _repeat(x, n_repeats):
        """Repeat each element of the 1-D int tensor ``x`` ``n_repeats`` times."""
        with tf.variable_scope('_repeat'):
            x = tf.tile(tf.expand_dims(x, 1), tf.stack([1, n_repeats]))
            return tf.reshape(x, [-1])

    def _interpolate(im, x, y, out_size):
        """Bilinearly sample ``im`` at normalized coordinates (x, y).

        ``x`` and ``y`` are flat vectors of length
        num_batch * out_height * out_width (batch-major), with -1 / +1 meaning
        the centre of the first / last pixel. Samples outside the image take
        the value at the nearest border.
        """
        with tf.variable_scope('_interpolate'):
            num_batch = tf.shape(im)[0]
            height = tf.shape(im)[1]
            width = tf.shape(im)[2]
            channels = tf.shape(im)[3]
            out_height = out_size[0]
            out_width = out_size[1]

            max_x = width - 1
            max_y = height - 1
            max_x_f = tf.cast(max_x, 'float32')
            max_y_f = tf.cast(max_y, 'float32')

            # Scale [-1, 1] onto pixel centres [0, W-1] / [0, H-1]; scaling by
            # W / H instead would put x = 1 one pixel past the image edge.
            x = (tf.cast(x, 'float32') + 1.0) * max_x_f / 2.0
            y = (tf.cast(y, 'float32') + 1.0) * max_y_f / 2.0
            # Clamp first so border weights stay valid (they always sum to 1)
            # and out-of-range samples replicate the edge pixels.
            x = tf.clip_by_value(x, 0.0, max_x_f)
            y = tf.clip_by_value(y, 0.0, max_y_f)

            x0_f = tf.floor(x)
            y0_f = tf.floor(y)
            x1_f = x0_f + 1.0
            y1_f = y0_f + 1.0

            # Weights use the unclipped neighbours; only the indices are clipped.
            wa = tf.expand_dims((x1_f - x) * (y1_f - y), 1)  # (x0, y0)
            wb = tf.expand_dims((x1_f - x) * (y - y0_f), 1)  # (x0, y1)
            wc = tf.expand_dims((x - x0_f) * (y1_f - y), 1)  # (x1, y0)
            wd = tf.expand_dims((x - x0_f) * (y - y0_f), 1)  # (x1, y1)

            x0 = tf.cast(x0_f, 'int32')
            y0 = tf.cast(y0_f, 'int32')
            x1 = tf.minimum(tf.cast(x1_f, 'int32'), max_x)
            y1 = tf.minimum(tf.cast(y1_f, 'int32'), max_y)

            # Flat indices into im reshaped to [num_batch*height*width, channels].
            dim2 = width
            dim1 = width * height
            base = _repeat(tf.range(num_batch) * dim1, out_height * out_width)
            base_y0 = base + y0 * dim2
            base_y1 = base + y1 * dim2
            idx_a = base_y0 + x0
            idx_b = base_y1 + x0
            idx_c = base_y0 + x1
            idx_d = base_y1 + x1

            im_flat = tf.cast(tf.reshape(im, tf.stack([-1, channels])), 'float32')
            Ia = tf.gather(im_flat, idx_a)
            Ib = tf.gather(im_flat, idx_b)
            Ic = tf.gather(im_flat, idx_c)
            Id = tf.gather(im_flat, idx_d)

            return tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])

    def _meshgrid(height, width):
        """Return a [3, height*width] grid of normalized (x_t, y_t, 1) columns.

        Equivalent to:
            x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
                                   np.linspace(-1, 1, height))
            np.vstack([x_t.flatten(), y_t.flatten(), np.ones(height * width)])
        """
        with tf.variable_scope('_meshgrid'):
            x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])),
                            tf.expand_dims(tf.linspace(-1.0, 1.0, width), 0))
            y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
                            tf.ones(shape=tf.stack([1, width])))
            x_t_flat = tf.reshape(x_t, (1, -1))
            y_t_flat = tf.reshape(y_t, (1, -1))
            ones = tf.ones_like(x_t_flat)
            return tf.concat([x_t_flat, y_t_flat, ones], 0)

    def _transform(theta, input_dim, out_size):
        """Map the output grid through theta and sample ``input_dim`` there."""
        with tf.variable_scope('_transform'):
            num_batch = tf.shape(input_dim)[0]
            num_channels = tf.shape(input_dim)[3]
            theta = tf.cast(tf.reshape(theta, (-1, 2, 3)), 'float32')
            out_height = out_size[0]
            out_width = out_size[1]

            # Same (x_t, y_t, 1) target grid for every example:
            # [num_batch, 3, out_height*out_width].
            grid = _meshgrid(out_height, out_width)
            grid = tf.tile(tf.expand_dims(grid, 0), tf.stack([num_batch, 1, 1]))

            # Batched A x (x_t, y_t, 1)^T -> (x_s, y_s), eq. (1) of the STN paper.
            T_g = tf.matmul(theta, grid)
            x_s_flat = tf.reshape(tf.slice(T_g, [0, 0, 0], [-1, 1, -1]), [-1])
            y_s_flat = tf.reshape(tf.slice(T_g, [0, 1, 0], [-1, 1, -1]), [-1])

            input_transformed = _interpolate(
                input_dim, x_s_flat, y_s_flat, out_size)
            return tf.reshape(
                input_transformed,
                tf.stack([num_batch, out_height, out_width, num_channels]))

    with tf.variable_scope(name):
        return _transform(theta, U, out_size)


def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
    """Apply several spatial transforms to each input example.

    Args:
        U: tensor of shape [num_batch, height, width, num_channels].
        thetas: tensor of shape [num_batch, num_transforms, 6]; the first two
            dimensions must be statically known.
        out_size: (out_height, out_width) of each output.
        name: variable scope name.

    Returns:
        Tensor of shape [num_batch * num_transforms, out_height, out_width,
        num_channels], ordered example-major (all transforms of example 0
        first).
    """
    with tf.variable_scope(name):
        num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
        # Repeat each example once per transform so it lines up with the
        # flattened thetas (range, not xrange, so this runs on Python 3).
        indices = [[i] * num_transforms for i in range(num_batch)]
        input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
        return transformer(input_repeated, thetas, out_size)

# Demo code follows:
# Demo: warp a test image with a fixed affine transform through the STN layer.
from scipy import ndimage
import tensorflow as tf
from STN_tf_01 import transformer
import numpy as np
import matplotlib.pyplot as plt
import cv2

# scipy.ndimage.imread was removed in SciPy 1.2, so read with OpenCV instead.
img_path = 'C:\\Users\\hasee\\Desktop\\cat.jpg'
im = cv2.imread(img_path)
if im is None:
    raise IOError('cannot read image: ' + img_path)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # OpenCV loads channels as BGR
im = (im / 255.).astype('float32')
height, width, channels = im.shape  # no hard-coded 1200x1600 any more

out_size = (height // 2, width // 2)  # (600, 800) for a 1200x1600 image
num_batch = 3
batch = np.repeat(im[np.newaxis], num_batch, axis=0)  # [3, H, W, C]

# Keep the placeholder and feed the batch through it; the original rebound
# `x`, which left the placeholder unused.
x = tf.placeholder(tf.float32, [None, height, width, channels])
with tf.variable_scope('spatial_transformer_0'):
    # Localisation net: a single fully connected layer. With zero weights the
    # output equals the bias, i.e. a 0.5x zoom onto the image centre.
    n_fc = 6
    w_fc1 = tf.Variable(tf.zeros([height * width * channels, n_fc]),
                        name='W_fc1')
    initial = np.array([[0.5, 0, 0], [0, 0.5, 0]]).astype('float32').flatten()
    b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
    h_fc1 = tf.matmul(tf.reshape(x, [-1, height * width * channels]),
                      w_fc1) + b_fc1
    h_trans = transformer(x, h_fc1, out_size)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    y = sess.run(h_trans, feed_dict={x: batch})

plt.imshow(y[0])
plt.show()
1 0
- Tensorflow1.0空间变换网络(SpatialTransformer Networks)实现
- 论文笔记:Spatial Transformer Networks(空间变换网络)
- 论文笔记:Spatial Transformer Networks(空间变换网络)
- tensorflow1.0 LSTM实现
- Spatial Transformer Networks(空间变换神经网络)
- Spatial Transformer Networks(空间变换神经网络)
- 空间映射网络--Spatial Transformer Networks
- 空间映射网络--Spatial Transformer Networks
- 深度学习(六十三)空间变换网络
- 空间变换网络--spatial transform network
- (Tensorflow1.0)强化学习实现游戏AI(Demo_1)
- Anaconda+tensorflow1.0安装
- tensorflow1.0安装
- ubuntu安装tensorflow1.0
- TensorFlow1.0安装
- Mac安装tensorflow1.0
- tensorflow1.0变化
- 安装tensorflow1.0
- AI学习之路(13): 创建随机张量3
- 海思hi3516、hi3519 中 online 与 offline 有什么区别
- 2017-03-01
- HALCON常用算子(HALCON13.0)
- iOS 多个网络请求并行/并发处理
- Tensorflow1.0空间变换网络(SpatialTransformer Networks)实现
- 蚁群算法解决TSP问题
- 数学——Tikonov stablizer
- 阿里云LNMPA的SSL安装与配置
- BarTender出现条码打印位置不准的情况怎么办
- android事件分发
- CentOS7安装问题
- preventDefault()、stopPropagation()、return false 之间的区别
- AbstractPlatformTransactionManager(Spring事务底层核心类)API讲解翻译