阿里云 机器学习pai的使用数据的使用以及模型的存储

来源:互联网 发布:网络设计收获与体会 编辑:程序博客网 时间:2024/05/16 03:30

1.数据的使用  读取pickle

import osimport sysimport argparseimport tensorflow as tfimport picklefrom tensorflow.python.lib.io import file_ioFLAGS = Nonedef main(_):      dir = os.path.join(FLAGS.buckets, 'Parsing.pickle')    object = file_io.read_file_to_string(dir,True)    result = pickle.loads(object)    training_records = result['training']    validation_records = result['validation']    print(len(training_records))    print("good")if __name__ == '__main__':    parser = argparse.ArgumentParser()    parser.add_argument('--buckets', type=str, default='',                        help='input data path')    parser.add_argument('--checkpointDir', type=str, default='',                        help='output model path')    FLAGS, _ = parser.parse_known_args()    tf.app.run(main=main)
注意点1:buckets的定义,而且是缺省值不用定义具体的oss地址

注意点2:使用tensorflow进行读取,Python的open方法在pai上不能使用

注意点3:pickle存储dump时协议要用2,以为pai上的Python是2.7

2.模型的存储

import tensorflow as tfFLAGS = tf.flags.FLAGStf.flags.DEFINE_string("checkpointDir", "model/test.ckpt", "path to logs directory")w1 = tf.placeholder("float", name="w1")w2 = tf.placeholder("float", name="w2")b1= tf.Variable(2.0,name="bias")feed_dict ={w1:4,w2:8}w3 = tf.add(w1,w2)w4 = tf.multiply(w3,b1,name="op_to_restore")sess = tf.Session()sess.run(tf.global_variables_initializer())saver = tf.train.Saver()print (sess.run(w4,feed_dict))saver.save(sess,FLAGS.checkpointDir)
注意点1:要定义checkpointDir



阅读全文
0 0
原创粉丝点击