【Tensorflow】怎样为你的网络预加工和打包训练数据?(二):小数据集的处理方案

来源:互联网 发布:淘宝设置不包邮地区 编辑:程序博客网 时间:2024/05/17 07:22

实验环境:python2.7

第二篇我们来讲一讲小数据集的处理方法,小数据集一般多以文本存储为主,csv是一种流行的数据格式,另外也有txt等。当然也会有.mat或者.npy这种经过处理的格式。


一.处理csv格式数据集

实验数据集是鸢尾花卉数据集iris,格式是.csv,需要的同学可以到这里下载

为了工程需要我直接介绍读取该类型数据的最快方法,通过一些库,我们是可以用很少的步骤就读取进来训练的,这里用到的是一个各种数据操作方法的集合库,pandas。

下载pandas:

sudo pip install pandas

然后导入:

import pandas

使用read_csv函数快速读取一个csv文件,到底有多方便?一句话就够了

data = pandas.read_csv("iris.csv")

此时返回的data我们可以看看它是长什么样的:

我们再对比一下,csv文件中的数据:

这时候你应该发现问题了,读取csv的时候默认把第一行作为列标题读进来了,导致后续的数据就不对了,显然一句话搞定的东西会出现很多问题。注意数据集的特殊性,iris数据集是不带有标题列的,所以我们就要说明一下,添加这一个参数:

data = pandas.read_csv("iris.csv", header=None)


现在输出就对了,可以看到系统自动为列生成了一组索引,当然我们可以自定义索引的名字:

data = pandas.read_csv("iris.csv", header=None, prefix='col')


在数字前面加字符串

也可以分别指定具体的名字:

data = pandas.read_csv("iris.csv", header=None,                        names=['atr1','atr2','atr3','atr4','label'])


让我们打印数据的格式看看:

print type(data)print type(data["atr1"])print type(data["atr1"][0])


可以看到具体元素的值是numpy的,但是其余的都还是pandas的自带格式,怎么转换呢,如下:

train_data = data.as_matrix(columns=['atr1','atr2','atr3','atr4'])label = data.as_matrix(columns=['label'])print train_data,label

这样我们就把指定的几列转换为numpy数组了,但是,还是会出现一个问题,读取csv默认的元素type是np.float64,也就是说label也是np.float64类型的,处理方案可以对读取完毕的numpy数组处理,也可以读取的时候处理,如下:

data = pandas.read_csv("iris.csv", header=None,                        names=['atr1','atr2','atr3','atr4','label'],                        dtype={'label':np.int8})

完整程序如下,这里我用了np.squeeze来去掉长度为1的维度,这个应该好理解:

import pandasimport numpy as npdata = pandas.read_csv("iris.csv", header=None,                        names=['atr1','atr2','atr3','atr4','label'],                        dtype={'label':np.int8})train_data = data.as_matrix(columns=['atr1','atr2','atr3','atr4'])label = data.as_matrix(columns=['label'])label = np.squeeze(label)

就这么几行,数据集就导入了!


二.txt的处理方法

和上面类似,txt文件也是可以用read_csv来处理的,因为两者的根本区别只是分隔符不同而已,举一个例子:在我的用tensorflow实现usps和mnist数据集的迁移学习使用到的数据集usps,我们将它下载下来,手工删除第一行10 256的分类说明和尾行的-1


因为这两行会影响我们结果的生成,然后调用:

data = pandas.read_csv("usps_train.jf", sep='\s+', header=None)

数据就生成好了,这里我们指定了sep分割符的类型是空格或者多于一个空格,总共7291个样本,第一列为标签,后面256列分别表示像素值。

当然你也可以像我在用tensorflow实现usps和mnist数据集的迁移学习中的做法一样,用python原生的方法读取,秀一秀你的代码技术偷笑偷笑,但是做工程的话,还是以方便为主,一句话就搞定的事,何乐而不为呢?


三.延伸

补充一下,遇到csv较大内存不够的情况,可以尝试使用read_csv中的分成chunk分块读取的方案,这里我就不描述了(搞deep learning的我相信大家的内存都很大,不会被小小几个G难住吧,哈哈)

附上分块读取的解决方案,和read_csv函数参数的详解

参数详解:http://www.cnblogs.com/datablog/p/6127000.html

分块读取csv:http://blog.csdn.net/zm714981790/article/details/51375475



阅读全文
0 0