keras使用稀疏矩阵输入

来源:互联网 发布:sql update 多行数据 编辑:程序博客网 时间:2024/05/21 15:01

Keras的数据存储是Numpy系的方式,本身不支持稀疏矩阵的输入输出。这个问题就很尴尬,高维数据就难在小内存的机器上运行。

解决方案也是有的,Keras内部是Theano或Tensorflow的,这两个都支持稀疏矩阵的输入输出,所以就是可以解决的。解决方案主要是参考这里:http://www.jianshu.com/p/bf1b637acf5a


1. 数据维数不算太高:

用稀疏矩阵读入之后,用X.todense()把X从scipy.csr式的存储转成numpy.veriable就可以用了。

2. 数据维数很高:

数据高维时用X.todense()会报MemeryError错。需要修改Keras源码:keras/engine/train.py

首先以稀疏scipy.csc_matrix作为fit或train的输入。

line: 766

len(ins[0]) -> ins[0].shape[0]

len(val_ins[0]) -> val_ins[0].shape[0]

line: 872

nb_sample = len(ins[0]) -> nb_sample = ins[0].shape[0]

line: 925

同line 872

line: 815

在slice_X()函数后面加入Code如下

import scipy

from scipy.sparse import csc_matrix, csr_matrix

if scipy.sparse.issparse(ins_batch[0]):

ins_batch[0] = ins_batch[0].toarray()

if scipy.sparse.issparse(ins_batch[1]):

ins_batch[1] = ins_batch[1].toarray()

line: 886

同815 在slice_X()函数后面加入上述Code

line: 929

同815

0 0