TensorFlow pretrained model

Source: Internet | Editor: 程序博客网 | Time: 2024/05/16 02:52

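A simple way to load a trained model in TensorFlow. Any variable whose shape does not match the shape saved in the checkpoint is skipped automatically, and so is any variable the checkpoint does not contain. Alternatively, you can reshape the saved value.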
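Reshaping only helps when the saved tensor and the new variable hold the same number of elements. The sketch below shows that check with NumPy; `reshape_if_compatible` is a hypothetical helper, not part of TensorFlow. In TF 1.x the saved value could presumably be read with `reader.get_tensor(name)` and written into the graph with `var.load(value, session)`, but that wiring is an assumption and is not shown here.

```python
import numpy as np


def reshape_if_compatible(saved_value, target_shape):
    """Hypothetical helper: return saved_value reshaped to target_shape,
    or None when the element counts differ and reshaping is impossible."""
    saved_value = np.asarray(saved_value)
    if saved_value.size != int(np.prod(target_shape)):
        return None  # e.g. a 10-class output layer cannot become a 5-class one
    return saved_value.reshape(target_shape)  # row-major (C order) reshape


kernel = np.arange(12, dtype=np.float32).reshape(2, 6)
print(reshape_if_compatible(kernel, [3, 4]).shape)  # (3, 4): 12 elements each
print(reshape_if_compatible(kernel, [3, 5]))        # None: 12 vs 15 elements
```

A reshape keeps the values in row-major order. Whether the reshaped weights are meaningful for the new layer is a separate question, which is why skipping is the safer default.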

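    import tensorflow as tf  # TensorFlow 1.x graph-mode API


    def optimistic_restore(session, save_file):
        """Restore only the variables whose name and shape match the checkpoint.

        Variables missing from the checkpoint or with a different shape are
        skipped, so run the variable initializer before calling this function.
        """
        reader = tf.train.NewCheckpointReader(save_file)
        saved_shapes = reader.get_variable_to_shape_map()  # {name: [dims]}
        # Map checkpoint-style names (graph name without the ":0" suffix) to
        # variables. A dict lookup, unlike tf.get_variable() under reuse=True,
        # also finds variables that were created with tf.Variable().
        graph_vars = {var.name.split(':')[0]: var for var in tf.global_variables()}
        restore_vars = []
        for saved_var_name in sorted(graph_vars):
            if saved_var_name not in saved_shapes:
                continue  # new variable, not present in the checkpoint
            curr_var = graph_vars[saved_var_name]
            var_shape = curr_var.get_shape().as_list()
            if var_shape == saved_shapes[saved_var_name]:
                restore_vars.append(curr_var)  # same name and shape: restore it
        if restore_vars:  # tf.train.Saver raises ValueError for an empty list
            saver = tf.train.Saver(restore_vars)
            saver.restore(session, save_file)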
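The filtering rule inside `optimistic_restore` can be checked without TensorFlow. The sketch below is a hypothetical, framework-free restatement. `select_restorable` and the example variable names are made up for illustration. Plain dicts stand in for the checkpoint's shape map and for the graph's variables.

```python
def select_restorable(saved_shapes, graph_shapes):
    """Return the sorted names that would be restored.

    saved_shapes: {checkpoint_name: shape_list}, like the result of
        reader.get_variable_to_shape_map()
    graph_shapes: {variable name without ":0": shape_list} for the new graph
    """
    return sorted(
        name
        for name, shape in graph_shapes.items()
        if name in saved_shapes and saved_shapes[name] == shape
    )


# Hypothetical fine-tuning case: the output layer shrank from 10 to 5 classes,
# and a new variable was added that the checkpoint does not contain.
saved = {"conv1/weights": [3, 3, 3, 64], "fc/weights": [1024, 10], "fc/biases": [10]}
graph = {"conv1/weights": [3, 3, 3, 64], "fc/weights": [1024, 5], "fc/biases": [5],
         "extra/weights": [5, 5]}
print(select_restorable(saved, graph))  # ['conv1/weights']
```

Only `conv1/weights` survives. The `fc` variables fail the shape check, and `extra/weights` is absent from the checkpoint, so all three keep their initialized values.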

reference
https://github.com/tensorflow/tensorflow/issues/312
