Pytorch入门学习(八)-----自定义层的实现(甚至不可导operation的backward写法)
来源:互联网 发布:光电转换器淘宝网 编辑:程序博客网 时间:2024/06/05 16:39
总说
虽然pytorch可以自动求导,但是有时候一些操作是不可导的,这时候你需要自定义求导方式。也就是所谓的 “Extending torch.autograd”. 官网虽然给了例子,但是很简单。这里将会更好的说明。
扩展 torch.autograd
class LinearFunction(Function): # 必须是staticmethod @staticmethod # 第一个是ctx,第二个是input,其他是可选参数。 # ctx在这里类似self,ctx的属性可以在backward中调用。 def forward(ctx, input, weight, bias=None): ctx.save_for_backward(input, weight, bias) output = input.mm(weight.t()) if bias is not None: output += bias.unsqueeze(0).expand_as(output) return output @staticmethod def backward(ctx, grad_output): input, weight, bias = ctx.saved_variables grad_input = grad_weight = grad_bias = None if ctx.needs_input_grad[0]: grad_input = grad_output.mm(weight) if ctx.needs_input_grad[1]: grad_weight = grad_output.t().mm(input) if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0).squeeze(0) return grad_input, grad_weight, grad_bias
然后扩展module就很简单,需要重载 nn.Module中的__init__
和__forward__
,
class Linear(nn.Module): def __init__(self, input_features, output_features, bias=True): super(Linear, self).__init__() self.input_features = input_features self.output_features = output_features # nn.Parameter is a special kind of Variable, that will get # automatically registered as Module's parameter once it's assigned # 这个很重要! Parameters是默认需要梯度的! # as an attribute. Parameters and buffers need to be registered, or # they won't appear in .parameters() (doesn't apply to buffers), and # won't be converted when e.g. .cuda() is called. You can use # .register_buffer() to register buffers. # nn.Parameters can never be volatile and, different than Variables, # they require gradients by default. self.weight = nn.Parameter(torch.Tensor(output_features, input_features)) if bias: self.bias = nn.Parameter(torch.Tensor(output_features)) else: # You should always register all possible parameters, but the # optional ones can be None if you want. self.register_parameter('bias', None) # Not a very smart way to initialize weights self.weight.data.uniform_(-0.1, 0.1) if bias is not None: self.bias.data.uniform_(-0.1, 0.1) def forward(self, input): # See the autograd section for explanation of what happens here. return LinearFunction.apply(input, self.weight, self.bias)
forward的说明
- 虽然说一个网络的输入是Variable形式,那么每个网络层的输出也是Variable形式。但是,当自定义autograd时,在forward中,所有的Variable参数将会转成tensor!因此这里的input也是tensor.在forward中可以任意操作。因为forward中是不需要Variable变量的,这是因为你自定义了
backward
方式。在传入forward前,autograd engine会自动将Variable
unpack成Tensor。 - ctx是context,
ctx.save_for_backward
会将他们转换为Variable形式。比如
@staticmethod def backward(ctx, grad_output): input, = ctx.saved_variables
此时input已经是需要grad的Variable了。
3. save_for_backward
只能传入Variable或是Tensor的变量,如果是其他类型的,可以用 ctx.xyz = xyz
,使其在backward
中可以用。
backward说明
grad_output是variable
其实这个没啥好说的。就是默认情况下,你拿到grad_output时候,会发现它是一个Variable,至于requires_grad是否为True,取决于你在外面调用.backward
或是.grad
时候的那个Variable是不是需要grad的。如果那个Variable是需要grad的,那么我们这里反向的grad_ouput也是requires_grad为True,那么我们甚至可以计算二阶梯度。用WGAN-GP之类的。
backward中我能一开始就.data拿出数据进行操作吗?
虽然自定义操作,但是原则上在backward中我们只能进行Variable的操作, 这显然就要求我们在backward中的操作都是可自动求导的。因为默认情况下, By default, the backward function should always work with Variable and create a proper graph (similarly to the forward function of an nn.Module).
。所以如果我们的涉及到不可导的操作,那么我们就不能在bacjward函数中创建一个正确的图,值得注意的是,我们是
自动求导是根据每个op的backward创建的graph来进行的!
我去,知道真相的我眼泪掉下来!很奇怪吧!我的最初的想法就是forward记录input如何被操作,然后backward就会自动反向,根据forward创建的图进行!然而,当你print(type(input))
时你竟然发现类型是Tensor,根本不是Variable!那怎么记录graph?然而真实情况竟然是自动求导竟然是在backward的操作中创建的图!
这也就是为什么我们需要在backward中用全部用variable来操作,而forward就没必要,forward只需要用tensor操作就可以。
non-differential操作的backward怎么写?
当然你想个近似算法,弄出来,一般直接backward中第一句就是grad_output = grad_output.data
,这样我们就无法进行创建正确的graph了。
#加一个这个from torch.autograd.function import once_differentiable@staticmethod@once_differentiabledef backward(ctx, grad_output): print(type(grad_output))# 此时你会惊奇的发现,竟然是Tensor了! # 做点其他的操作得到grad_output_changed grad_input = grad_output_changed return grad_input
因为我们在backward中已经是直接拿出data进行操作的了,所以我们直接得到Tensor类型返回就行!!
- Pytorch入门学习(八)-----自定义层的实现(甚至不可导operation的backward写法)
- pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解
- pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解
- pytorch学习笔记(十三):backward过程的底层实现解析
- Pytorch的backward()相关理解
- Pytorch入门学习(三)---- NN包的使用
- Pytorch入门学习(四)---- 多GPU的使用
- PyTorch入门学习(一)
- Pytorch学习笔记(一):pytorch的安装-Ubuntu14.04
- 基于PyTorch的深度学习入门教程(八)——图像风格迁移
- 学习指针不可少的好文章 ( 八 )
- Pytorch学习入门(一)--- 从torch7跳坑至pytorch --- Tensor
- pytorch学习笔记(六):自定义Datasets
- pytorch学习笔记(六):自定义Datasets
- Pytorch学习入门(二)--- Autograd
- Pytorch入门学习(七)---- 数据加载和预处理的通用方法
- Pytorch入门学习(九)---detach()的作用(从GAN代码分析)
- 深度学习实践操作—从小白到大白(八):安装Pytorch到特定的Anaconda环境
- atom快捷键
- 创建类Student和对象
- Java基础知识汇总
- HDU 6172 Array Challenge 【线性递推式模板】
- 有记忆的电路——时序逻辑电路
- Pytorch入门学习(八)-----自定义层的实现(甚至不可导operation的backward写法)
- Apache Tomcat RCE if readonly set to false (CVE-2017-12617)
- 创建对象中的一些
- 剑指offer---和为S的两个数
- MySQL之Got fatal error 1236 from master when reading data from binary log
- python列表使用
- C/C++程序训练6---歌德巴赫猜想的证明
- 进度条
- Subscript of sum