Tensorflow 05: 导入预训练好的图模型

来源:互联网 发布:淘宝上的东西是正品吗 编辑:程序博客网 时间:2024/05/21 02:48

Tensorfow:导入.pb文件

示例代码

def create_model_graph(model_info):  """"  Creates a graph from saved GraphDef file and returns a Graph object.  Args:    model_info: Dictionary containing information about the model architecture.  Returns:    Graph holding the trained Inception network, and various tensors we'll be    manipulating.  """  with tf.Graph().as_default() as graph:    model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name'])    with gfile.FastGFile(model_path, 'rb') as f:      graph_def = tf.GraphDef()      graph_def.ParseFromString(f.read())      bottleneck_tensor, resized_input_tensor = (tf.import_graph_def(          graph_def,          name='',          return_elements=[              model_info['bottleneck_tensor_name'],              model_info['resized_input_tensor_name'],          ]))  return graph, bottleneck_tensor, resized_input_tensor

相关API含义

  1. gfile.FastGFile:
     google的文件操作,和python 里面的open函数功能类似

  2. import_graph_def(graph_def, input_map=None, return_elements=None, name=None, op_dict=None, producer_op_list=None):
     功能:导入参数graph_def中定义的tensorflow graph模型;
       Imports the TensorFlow graph in graph_def into the Python Graph
     返回值:返回参数return_elements中定义的一系列的Operation和Tensor对象。
       A list of
    Operationand/orTensorobjects from the imported graph,
    corresponding to the names in
    return_elements`.

阅读全文
0 0
原创粉丝点击