可视化tensorflow中间层

来源:互联网 发布:蜂窝网络有电信的吗 编辑:程序博客网 时间:2024/06/08 02:42

如何可视化使用TensorFlow框架的网络中间层呢?网上找到的答案都是使用TensorBoard,但是我想要的是把网络的中间层以一张一张图片的形式显示并保存下来,下面附上代码:

补充一句:这里导入的项目模块(如fast_rcnn、utils、networks、datasets)都来自Faster R-CNN的TensorFlow版本代码,如果想用这段代码,需要先把faster_rcnn的代码clone下来才能运行!

#-*-coding:utf-8-*-
#此函数用来可视化tensorflow中间层
import _init_paths
from fast_rcnn.config import cfg
import argparse
from utils.timer import Timer
import numpy as np
import cv2
import sys
from utils.cython_nms import nms
from utils.transform import lidar_3d_to_corners, corners_to_bv, lidar_cnr_to_img_single, lidar_cnr_to_img
from utils.draw import show_lidar_corners, show_image_boxes, scale_to_255
import cPickle
from utils.blob import im_list_to_blob
import os

import math
from networks.factory import get_network
import tensorflow as tf
import matplotlib.pyplot as plt
import time
import mayavi.mlab as mlab
from utils.draw import draw_lidar, draw_gt_boxes3d
#from fast_rcnn.test_mv import box_detect
from datasets.factory import get_imdb
from fast_rcnn.train_mv import get_training_roidb
import argparse

def parse_args():
    """Parse the command-line options of the visualisation script.

    Options: ``--ceng`` (name of the layer to visualise), ``--gpu`` (stored as
    ``gpu_id``), ``--net`` (stored as ``demo_net``) and ``--model`` (stored as
    ``model``).  If the script is started without any argument, the help text
    is printed and the process exits with status 1.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    # (flag, keyword arguments) pairs, registered in this order.
    option_table = (
        ('--ceng', dict(dest='ceng', help='which ceng to visulize',
                        default='None', type=str)),
        ('--gpu', dict(dest='gpu_id', help='GPU device id to use [0]',
                       default=0, type=int)),
        ('--net', dict(dest='demo_net', help='Network to use [vgg16]',
                       default='MV3D_test')),
        ('--model', dict(dest='model', help='Model path',
                         default='/home/lingck/MV3D_TF/output/faster_rcnn_end2end/train/VGGnet_fast_rcnn_iter_50000.ckpt.meta')),
    )

    cli = argparse.ArgumentParser(description='Visulize the MV3D network')
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    # Only the program name on the command line: show usage and stop.
    if len(sys.argv) == 1:
        cli.print_help()
        sys.exit(1)

    return cli.parse_args()

def check_file(path):
    """Make sure the directory ``path`` exists, creating it if necessary.

    Missing parent directories are created as well.  Nothing is done when
    ``path`` already exists.

    Args:
        path: directory path to create.

    Raises:
        OSError: if the directory cannot be created.
    """
    if os.path.exists(path):
        return
    try:
        os.makedirs(path)
    except OSError:
        # Another process may have created the directory between the check
        # above and makedirs(); only re-raise when it is really missing.
        if not os.path.isdir(path):
            raise


def visulize(sess, net, im,image_name, bv, calib,  ceng,
             max_featuremaps=48,
             out_root='/home/lingck/MV3D_TF/tools/vis_feat/'):
    """Run one forward pass and save the feature maps of layer ``ceng`` as JPEGs.

    Each saved image is written to
    ``<out_root>/<image_name>/<ceng>/<ceng>_<index>.jpg``.

    Arguments:
        sess: session object; its ``run`` method evaluates the layer output.
        net: network object exposing the placeholders ``lidar_bv_data``,
            ``image_data``, ``im_info``, ``calib``, ``keep_prob`` and the
            method ``get_output(name)``.
        im (ndarray): 3-D image array; ``cfg.PIXEL_MEANS`` is subtracted and a
            leading batch axis is added before feeding it.
        image_name: identifier used as the name of the output sub-directory.
        bv (ndarray): 3-D lidar bird-view array; a leading batch axis is added.
        calib: calibration data fed to ``net.calib`` unchanged.
        ceng (str): name of the layer to visualise (passed to
            ``net.get_output``); also used in output directory and file names.
        max_featuremaps (int or None): upper bound on the number of channels
            saved; ``None`` saves every channel.
        out_root (str): root directory for the saved images.

    Raises:
        ValueError: if the layer output is not a 4-D array, since it is
            indexed as ``[0, :, :, channel]``.
    """
    # Subtract the configured pixel means from the image.
    im_blob = im - cfg.PIXEL_MEANS
    lidar_bv_blob = bv

    # Add a batch axis of size 1 in front of both inputs.
    im_blob = im_blob.reshape((1, im_blob.shape[0], im_blob.shape[1], im_blob.shape[2]))
    lidar_bv_blob = lidar_bv_blob.reshape((1, lidar_bv_blob.shape[0], lidar_bv_blob.shape[1], lidar_bv_blob.shape[2]))

    # im_info describes the bird-view input (axis 1 size, axis 2 size, scale 1).
    im_info = np.array(
        [[lidar_bv_blob.shape[1], lidar_bv_blob.shape[2], 1]],
        dtype=np.float32)

    feed_dict = {net.lidar_bv_data: lidar_bv_blob,
                 net.image_data: im_blob,
                 net.im_info: im_info,
                 net.calib: calib,
                 net.keep_prob: 1.0}

    # Forward pass up to the requested layer.
    activation = sess.run(net.get_output(ceng), feed_dict=feed_dict)

    print("ceng->shape: %s" % (activation.shape,))

    if activation.ndim != 4:
        raise ValueError(
            'layer %s has output shape %s; a 4-D output is required to '
            'visualise feature maps' % (ceng, activation.shape))

    img_path = os.path.join(out_root, str(image_name))
    check_file(img_path)
    feat_path = os.path.join(img_path, str(ceng))
    check_file(feat_path)

    # Never index past the number of channels the layer actually has.
    num_channels = activation.shape[3]
    if max_featuremaps is None:
        num_maps = num_channels
    else:
        num_maps = min(max_featuremaps, num_channels)

    for featuremap in range(num_maps):
        fig = plt.figure()
        plt.title('FeatureMap ' + str(featuremap))
        plt.axis('off')
        plt.imshow(activation[0, :, :, featuremap], interpolation="nearest", cmap="jet")
        plt.savefig(os.path.join(feat_path, str(ceng) + '_' + str(featuremap) + ".jpg"),
                    dpi=400, bbox_inches="tight")
        # Release the figure right away; otherwise every saved map stays in memory.
        plt.close(fig)


def make_calib(calib_dir):
    """Read a calibration text file into a (4, 12) array.

    The file is read line by line; on each used line the first
    whitespace-separated token (the label) is dropped and the remaining
    tokens are parsed as float32 values:

    * file line index 2 (``P2``)             -> row 0, 12 values
    * file line index 3 (``P3``)             -> row 1, 12 values
    * file line index 4 (``R0``)             -> row 2, first 9 values
    * file line index 5 (``Tr_velo_to_cam``) -> row 3, 12 values

    The last three entries of row 2 are zero.

    Args:
        calib_dir: path of the calibration *file* (despite the name).

    Returns:
        ndarray of shape (4, 12).

    Raises:
        IndexError: if the file has fewer than six lines.
        ValueError: if a line does not hold the expected number of values.
    """
    with open(calib_dir) as fi:
        lines = fi.readlines()

    def _values(line_idx):
        # split() without argument tolerates repeated spaces and tabs.
        return np.array(lines[line_idx].split()[1:], dtype=np.float32)

    P2 = _values(2)
    P3 = _values(3)
    R0 = _values(4)
    Tr_velo_to_cam = _values(5)

    # zeros (not empty) so the unused tail of the R0 row is deterministic.
    calib = np.zeros((4, 12))
    calib[0, :] = P2.reshape(12)
    calib[1, :] = P3.reshape(12)
    calib[2, :9] = R0.reshape(9)
    calib[3, :] = Tr_velo_to_cam.reshape(12)

    return calib


def make_bird_view(velo_file):
    """Load a raw float32 file and view it as rows of four values.

    Args:
        velo_file: path of the binary file to read.

    Returns:
        tuple ``(scan, bird_view)`` where ``scan`` is a float32 array of shape
        (N, 4) and ``bird_view`` is always an empty list (no bird view is
        computed here).
    """
    points = np.fromfile(velo_file, dtype=np.float32).reshape((-1, 4))
    return points, []

def demo(sess, net, root_dir,im_name, image_name, ceng):
    """Load one sample from ``root_dir`` and visualise layer ``ceng`` for it.

    The following files are read:

    * ``<root_dir>/image_2/<image_name>.png``  (image, via ``cv2.imread``)
    * ``<root_dir>/lidar_bv/<image_name>.npy`` (bird view, via ``np.load``)
    * ``<root_dir>/calib/<image_name>.txt``    (calibration, via ``make_calib``)

    Args:
        sess: session object forwarded to ``visulize``.
        net: network object forwarded to ``visulize``.
        root_dir: dataset directory containing the sub-directories above.
        im_name: sample identifier forwarded to ``visulize`` (used there to
            name the output directory).
        image_name (str): file stem of the sample inside ``root_dir``.
        ceng (str): name of the layer to visualise.

    Raises:
        IOError: if the image file cannot be read.
    """
    im_file = os.path.join(root_dir, 'image_2', image_name + '.png')
    calib_file = os.path.join(root_dir, 'calib', image_name + '.txt')
    bv_file = os.path.join(root_dir, 'lidar_bv', image_name + '.npy')

    im = cv2.imread(im_file)
    # cv2.imread signals failure by returning None instead of raising.
    if im is None:
        raise IOError('Error: could not read image {}'.format(im_file))

    bv = np.load(bv_file)
    calib = make_calib(calib_file)

    visulize(sess, net, im, im_name, bv, calib, ceng)

    


if __name__=='__main__':
    # Script entry point: build the network, restore the model given by
    # --model and save the feature maps of the layer given by --ceng for one
    # hard-coded sample.

    cfg.TEST.HAS_RPN = True
    args = parse_args()

    # Reject an empty or whitespace-only model path before touching TensorFlow.
    if not args.model or not args.model.strip():
        raise IOError(('Error: Model not found.\n'))

    # allow_soft_placement lets TensorFlow pick a device automatically when
    # the requested one does not exist.
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    try:
        # Build the network.
        net = get_network(args.demo_net)
        # Saver object used to restore the checkpoint.
        saver = tf.train.Saver(max_to_keep=5)
        net.load(args.model, sess, saver, True)

        root_dir = '/home/lingck/MV3D_TF/data/KITTI/object/training'

        # Sample index; zero-padded to six digits to form the file stem.
        im_name = 21
        print('~~~~~~~~~~~~visulize feature map~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, root_dir, im_name, str(im_name).zfill(6), args.ceng)
        print("~~~DONE~~~")
    finally:
        # Release the session's resources even when visualisation fails.
        sess.close()

这段代码的重点在于visulize函数:通过--ceng参数传入的中间层名字,必须和网络定义中对应层的名字一致。

下面看几张效果图



















原创粉丝点击