PyTorch的学习笔记-torch package
来源:互联网 发布:windows虚拟内存开启 编辑:程序博客网 时间:2024/06/05 12:42
主要内容源于对PyTorch的doc的阅读: Doc
仅记录了我认为比较常用和有用的API。
torch.is_tensor(obj),若obj为Tensor类型,那么返回True。
torch.numel(obj),返回Tensor对象中的元素总数。
torch.eye(n),返回一个单位方阵,和MATLAB的eye()非常像。还有其他参数。
torch.from_numpy(obj),利用一个numpy的array创建Tensor。注意,若obj原来是1列或者1行,无论obj是否为2维,所生成的Tensor都是一阶的,若需要2阶的Tensor,需要利用view()函数进行转换。
torch.linspace(start, end, steps),返回一个1维的Tensor。
torch.ones(),与MATLAB的ones很接近。
torch.ones_like(input),返回一个全1的Tensor,其维度与input相一致。
torch.arange(start, end, step),直接返回一个Tensor而不是一个迭代器。
torch.zeros(),与MATLAB的zeros很像。
torch.zeros_like(),与torch.ones_like()类似。
torch.cat(seq, dim),将tuple seq中描述的Tensor进行连接,通过实例说明用法。
torch.chunk(input, chunks, dim),与torch.cat()的作用相反。注意,返回值的数量会随chunks的值而发生变化.
torch.index_select(input, dim, index),注意,index是一个1D的Tensor。
torch.masked_select(input, mask),有点像MATLAB中利用bool类型矩阵进行索引的功能,要求mask是ByteTensor类型的Tensor。参考示例代码。注意,执行结果是一个1D的Tensor。
torch.squeeze(input),将input中维度数值为1的维度去除。可以指定某一维度。结果是共享input的内存的。
torch.t(input),将input进行转置,不是in place。输出的结果是共享内存的。要求input为2D。
torch.unsqeeze(input, dim),在input目前的dim维度上增加一维。
好多random sampling的函数借口,还有inplace的。
torch.save()和torch.load()
不常见的运算函数
torch.clamp(input, min, max),将input的值约束在min和max之间
torch.trunc(input),将input的小数部分舍去。
torch.norm()
还有一些统计功能的函数。
torch.eq(input, other),返回一个Tensor。
torch.equal(input, other),返回True,False。
还有一些用于比较的函数,包括ne(), kthmin(), topk()
torch.grad(),与MATLAB的diag可能不同,这个函数将返回一个与原Tensor维度相同的Tensor。
torch.trace(),
torch.tril()和torch.triu(),返回下三角和上三角Tensor。
有一些用于batch上乘法,加法的函数。
torch.btriface()和torch.btrisolve(),LU分解和线性求解。
torch.dot(), torch.eig(), torch.inverse(), torch.matmul(), torch.mv()等函数。有各种decomposition的函数。
简单示例代码
import numpy as npimport torchimport copyif __name__ == '__main__':t1 = torch.randn(4, 4)n2 = np.random.rand(3)n2 = n2.reshape(3, 1)tn2 = torch.from_numpy(n2)n3 = copy.deepcopy(n2)n3 = n3.reshape(1,3)tn3 = torch.from_numpy(n3)# cat().t4 = torch.cat( (tn3, tn3), dim = 0 )t5 = torch.cat( (tn3, tn3), dim = 1 )t4_2 = torch.cat( (t4, tn3), dim = 0 )# thunk().t6_0, t6_1 = torch.chunk( t1, 2, dim = 0 )t7_0, t7_1 = torch.chunk( t1, 2, dim = 1 )# masked_select().t8_mask = t1.ge(0.1)t8 = torch.masked_select( t1, t8_mask )
- PyTorch的学习笔记-torch package
- 莫烦PyTorch学习笔记(一)——Torch或Numpy
- PyTorch的concat也就是torch.cat实例
- PyTorch学习总结(二)——基于torch.utils.ffi的自定义C扩展
- Pytorch学习笔记(一):pytorch的安装-Ubuntu14.04
- 深度学习笔记1torch的安装
- 深度学习笔记4torch的rnn
- torch入门笔记15:nn package详解
- torch学习笔记
- torch学习笔记
- torch学习笔记<一>
- Torch学习笔记
- Torch学习笔记
- Torch学习笔记
- Torch学习笔记
- Torch学习笔记
- torch学习(六) rnn package
- pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解
- WCF netTcpBinding:如何将net.tcp协议寄宿到IIS
- Visual Studio 2013 Tools for Unity
- This ZooKeeper instance is not currently serving requests
- html中td或th标签的宽度设置
- Spring--《Spring实战》The temporary upload location [/tmp/uploads] is not valid
- PyTorch的学习笔记-torch package
- 为了自己的仅有一次一生 拼搏一把
- 【第六届蓝桥杯】立方尾不变
- 交换机应用点滴
- 海思3518c普通串口更换RS485通讯
- Spring Cloud构建微服务架构—服务消费Ribbon
- BZOJ3566: [SHOI2014]概率充电器(概率DP+容斥)
- 安卓自定义popupMenu样式
- MyEclipse Maven Spring Boot mybatis freemarker 配置实例DEMO