nd4j mmul and repeat

来源:互联网 发布:java开发restful接口 编辑:程序博客网 时间:2024/06/05 14:45

最近使用dl4j下的nd4j做机器学习相关算法的移植任务,总结下使用nd4j矩阵运算遇到的问题和解决方法


开发平台: Win7 X64

开发工具:IntelliJ IDEA + Maven

dl4j: java下的一个机器学习开源项目

nd4j: dl4j使用的底层的算法库,实现的几乎所有的矩阵相关的操作

nd4j version:0.6.0


1. 矩阵的点成操作输出矩阵为 ‘f’ order,导致使用矩阵进行后续的repeat操作时结果不正确,尝试使用setOrder(‘c’),依然不行.

解决方法:矩阵mmul(2维矩阵点成)完成后使用dup函数复制一份即可ok.


2. INDArray的repeat操作运行效率很低,查看源码在做repeat操作是使用的矩阵length的loop完成的,不清楚如果使用了OpenBlas或者GPU后效率如何

解决方法:自己写repeat算法(假设需要repeat的维度的长度为1),步骤为:

1.将需要repeat的维度转置到第零维(permute)

2.创建临时矩阵变量,第零维的长度为repeat的次数(create)

3.使用增加Row的接口将转制后的Row添加到临时矩阵变量中,此处需要repeat 次数的loop操作(getRow,putRow)

4.将临时变量需要repeat的维度和第零维进行转置(permute)

我使用此算法将我的程序执行效率提升了30倍


3. 还有关于dimShuffle,zeros接口使用的问题,稍后更新

0 0
原创粉丝点击