【阿里云教程】使用PAI深度学习tensorflow读取OSS教程
来源:互联网 发布:智慧教室 知乎 编辑:程序博客网 时间:2024/06/05 04:09
在PAI上, 使用TensorFlow读取OSS文件
作者: 万千钧
转载的出处
本文适合有一定TensorFlow基础, 且准备使用PAI的同学阅读
目录
1. 如何PAI上读取数据
2. 如何减少读取的费用开支
3. 使用OSS需要注意的问题
1. 在PAI上读取数据
Python不支持读取oss的数据, 故所有调用 python Open(), os.path.exist() 等文件, 文件夹操作的函数的代码都无法执行.
如Scipy.misc.imread(), numpy.load() 等
那如何在PAI读取数据呢, 通常我们采用两种办法.
如果只是简单的读取一张图片, 或者一个文本等, 可以使用tf.gfile下的函数, 具体成员函数如下
tf.gfile.Copy(oldpath, newpath, overwrite=False) # 拷贝文件tf.gfile.DeleteRecursively(dirname) # 递归删除目录下所有文件tf.gfile.Exists(filename) # 文件是否存在tf.gfile.FastGFile(name, mode='r') # 无阻塞读写文件tf.gfile.GFile(name, mode='r') # 读写文件tf.gfile.Glob(filename) # 列出文件夹下所有文件, 支持patterntf.gfile.IsDirectory(dirname) # 返回dirname是否为一个目录tf.gfile.ListDirectory(dirname) # 列出dirname下所有文件tf.gfile.MakeDirs(dirname) # 在dirname下创建一个文件夹, 如果父目录不存在, 会自动创建父目录. 如果文件夹已经存在, 且文件夹可写, 会返回成功tf.gfile.MkDir(dirname) # 在dirname处创建一个文件夹tf.gfile.Remove(filename) # 删除filenametf.gfile.Rename(oldname, newname, overwrite=False) # 重命名tf.gfile.Stat(dirname) # 返回目录的统计数据tf.gfile.Walk(top, inOrder=True) # 返回目录的文件树
具体的文档可以参照这里(可能需要翻墙)
如果是一批一批的读取文件, 一般会采用tf.WholeFileReader() 和 tf.train.batch() / tf.train.shuffer_batch()
接下来会重点介绍常用的 tf.gfile.Glob, tf.gfile.FastGFile, tf.WholeFileReader() 和 tf.train.shuffer_batch()
读取文件一般有两步
1. 获取文件列表 2. 读取文件
如果是批量读取, 还有第三步
3. 创建batch
从代码上手: 在使用PAI的时候, 通常需要在右侧设置读取目录, 代码文件等参数, 这些参数都会通过--XXX的形式传入
tf.flags可以提供了这个功能
import tensorflow as tfimport osFLAGS = tf.flags.FLAGS# 前面的buckets, checkpointDir都是固定的, 不建议更改tf.flags.DEFINE_string('buckets', 'oss://XXX', '训练图片所在文件夹')tf.flags.DEFINE_string('batch_size', '15', 'batch大小')# 获取文件列表files = tf.gfile.Glob(os.path.join(FLAGS.buckets,'*.jpg')) # 如我想列出buckets下所有jpg文件路径
接下来就分两种情况了
1. (小规模读取时建议) tf.gfile.FastGfile()for path in files:
file_content = tf.gfile.FastGFile(path, 'rb').read() # 一定记得使用rb读取, 不然很多情况下都会报错 image = tf.image.decode_jpeg(file_content, channels=3) # 本教程以JPG图片为例
2. (大批量读取时建议) tf.WholeFileReader()
reader = tf.WholeFileReader() # 实例化一个readerfileQueue = tf.train.string_input_producer(files) # 创建一个供reader读取的队列file_name, file_content = reader.read(fileQueue) # 使reader从队列中读取一个文件image = tf.image.decode_jpeg(file_content, channels=3) # 讲读取结果解码为图片label = XXX # 这里省略处理label的过程batch = tf.train.shuffle_batch([label, image], batch_size=FLAGS.batch_size, num_threads=4, capacity=1000 + 3 * FLAGS.batch_size, min_after_dequeue=1000)sess = tf.Session() # 创建Sessiontf.train.start_queue_runners(sess=sess) # 重要!!! 这个函数是启动队列, 不加这句线程会一直阻塞labels, images = sess.run(batch) # 获取结果
解释下其中重要的部分tf.train.string_input_producer, 这个是把files转换成一个队列, 并且需要 tf.train.start_queue_runners 来启动队列
tf.train.shuffle_batch 参数解释
batch_size 批大小, 每次运行这个batch, 返回多少个数据
num_threads 运行线程数, 在PAI上4个就好
capacity 随机取文件范围, 比如你的数据集有10000个数据, 你想从5000个数据中随机取, capacity就设置成5000.
min_after_dequeue 维持队列的最小长度, 这里只要注意不要大于capacity即可
2.费用开支
这里只讨论读取文件所需要的费用开支
原则上来说, PAI不跨区域读取OSS是不收费的, 但是OSS的API是收费的. PAI在使用 tf.gile.Glob 的时候 会产生GET请求, 在写入 tensorboard 的时候, 也会产生PUT请求. 这两种请求都是按次收费的, 具体价格如下
标准型单价: 0.01元/万次
低频访问型单价: 0.1元/万次
归档型单价: 0.1元/万次
当数据集有几十万图片, 通过 tf.gile.Glob 一次就需要几毛钱. 所以减少费用开支的方法就是减少GET请求次数
这里给出几种解决思路
1. 最好的解决思路, 把所有会使用到的数据, 一并上传传到OSS, 然后使用tensorflow拷贝到运行时目录, 最后通过tensorflow读取, 这样是最节省开支的.
2. 通过tfrecords, 在本地, 提前把几十上百张图片通过tfrecords存下来, 这样读取的时候可以减少GET请求
3. 把训练使用的图片随着代码的压缩包一起传上去, 不走OSS读取
三种方法都可以显著的减少开支.
3.使用中需要注意的
事实上, 每次读取传过来的地址就是 oss://你的buckets名字/XXX, 本以为不需要在PAI界面上 设置, 直接读取这个目录就好, 事实上并不如此.
PAI没有权限读取不在数据源目录和输出目录下的文件, 所以在使用路径前, 确保他们已经在控制台右侧设置过.
另外如果需要写入文件到OSS, 可以使用 tf.gfile.fastGfile('OSS路径', 'wb').write('内容')
OSS路径推荐使用
FLAGS.checkpointDir
FLAGS.summaryDIr
这样的形式传入, 经过测试好像也只有这两个目录下有写权限, FLAGS.buckets有读权限
- 【阿里云教程】使用PAI深度学习tensorflow读取OSS教程
- 阿里云机器学习平台PAI的视频介绍(其中tensorflow高级教程有tf的代码优化讲解)
- 阿里云PAI深度学习TensorFlow图像识别例子完整流程及出错案例
- 阿里云PAI读取图像
- 深度学习tensorflow教程-DNNClassifer
- TensorFlow教程05:MNIST深度学习初探
- 深度学习之TensorFlow进阶教程一
- TensorFlow教程05:MNIST深度学习初探
- PAI深度学习Tensorflow框架多机多卡多PS Server使用说明
- 深度学习笔记——深度学习框架TensorFlow(七)[TensorFlow广度&深度教程]
- 阿里云OSS使用-Python
- 阿里云PAI
- 阿里云 机器学习pai的使用数据的使用以及模型的存储
- 阿里云服务器使用教程
- 阿里云服务器使用教程
- 阿里云虚拟主机使用教程
- 阿里云code使用教程
- 如何使用深度学习自动识别限速标志?这里有一份Keras和TensorFlow教程
- orcal 数据库基础
- cstring转char*方法, 以及wchar转char方法
- NIO之FileChannel
- spring boot+mybatis 多数据源切换
- Spark RDD 读书笔记
- 【阿里云教程】使用PAI深度学习tensorflow读取OSS教程
- Nginx是什么
- AMD OpenCL 大学课程
- 如何使用siege对接口进行性能测试
- 详解java定时任务
- 循环
- NIO之Selector
- DFS学习借鉴的博客
- sublime连接linux