pytorch学习笔记(十二):详解 Module 类
来源:互联网 发布:linux编程中exit() 编辑:程序博客网 时间:2024/06/05 12:44
Module 是 pytorch
提供的一个基类,每次我们要 搭建 自己的神经网络的时候都要继承这个类,继承这个类会使得我们 搭建网络的过程变得异常简单。
本文主要关注 Module
类的内部是怎么样的。
初始化方法中做了什么
def __init__(self): self._backend = thnn_backend self._parameters = OrderedDict() self._buffers = OrderedDict() self._backward_hooks = OrderedDict() self._forward_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict() self._modules = OrderedDict() self.training = True
这是 Module
的初始化方法:
self._parameters
用来存放注册的Parameter
对象self._buffers
用来存放注册的Buffer
对象。(pytorch 中 buffer 的概念就是 不需要反向传导更新的值)self._modules
用来保存注册的Module
对象。self.training
标志位,用来表示是不是在training
状态下...hooks
用来保存 注册的hook
__setattr__
与 __getattr__
__setattr__
每次给属性赋值的时候,都会调用这个方法。
__setattr__
的代码比较多,我们一点一点看。
remove_from
:工具函数, 用来从self.__dict__, self._buffers, self._modules
中删除对象。
第一种情况: value
的类型是 Paramter
- 从 三大 字典中将 同名的 对象删掉
- 然后,注册
paramter
第二种情况: value
不是 Parameter
对象, name
在 self._parameter
中
self._parameters[name] = None
已经考虑了 value
是 Parameter
对象,剩下的就是考虑 value
为 buffer
或 Module
了
第三种情况:value
不是 Parameter
对象, value
为 Module
对象
- 从三大字典里面移除同名 对象
- 然后直接向
self._modules
字典里添加value
第四种情况:value
不是Parameter
对象, value
不为 Module
对象, 但是 name
在 self._modules
里
self._modules[name]=None
第五种情况:value
不是Parameter
对象, value
不为 Module
对象, name
存在 self._buffers
里
self._buffers[name]=None
最后一种情况: 就是 普通的属性了。
def __setattr__(self, name, value): def remove_from(*dicts): for d in dicts: if name in d: del d[name] params = self.__dict__.get('_parameters') if isinstance(value, Parameter): if params is None: raise AttributeError( "cannot assign parameters before Module.__init__() call") remove_from(self.__dict__, self._buffers, self._modules) self.register_parameter(name, value) elif params is not None and name in params: if value is not None: raise TypeError("cannot assign '{}' as parameter '{}' " "(torch.nn.Parameter or None expected)" .format(torch.typename(value), name)) self.register_parameter(name, value) else: modules = self.__dict__.get('_modules') if isinstance(value, Module): if modules is None: raise AttributeError( "cannot assign module before Module.__init__() call") remove_from(self.__dict__, self._parameters, self._buffers) modules[name] = value elif modules is not None and name in modules: if value is not None: raise TypeError("cannot assign '{}' as child module '{}' " "(torch.nn.Module or None expected)" .format(torch.typename(value), name)) modules[name] = value else: buffers = self.__dict__.get('_buffers') if buffers is not None and name in buffers: if value is not None and not torch.is_tensor(value): raise TypeError("cannot assign '{}' as buffer '{}' " "(torch.Tensor or None expected)" .format(torch.typename(value), name)) buffers[name] = value else: object.__setattr__(self, name, value)
__getattr__
: 当获取self.__dict__
中没有的键所对应的值的时候,就会调用这个方法因为
parameter, module, buffer
的键值对存在与self._parameters, self._modules, self.buffer
中,所以,当想获取这些 值时, 就会调用这个方法。
def __getattr__(self, name): if '_parameters' in self.__dict__: _parameters = self.__dict__['_parameters'] if name in _parameters: return _parameters[name] if '_buffers' in self.__dict__: _buffers = self.__dict__['_buffers'] if name in _buffers: return _buffers[name] if '_modules' in self.__dict__: modules = self.__dict__['_modules'] if name in modules: return modules[name] raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, name))
register_parameter
向模型中注册 Parameter
def register_parameter(self, name, param): """Adds a parameter to the module. The parameter can be accessed as an attribute using given name. """ if '_parameters' not in self.__dict__: raise AttributeError( "cannot assign parameter before Module.__init__() call") if param is None: self._parameters[name] = None elif not isinstance(param, Parameter): raise TypeError("cannot assign '{}' object to parameter '{}' " "(torch.nn.Parameter or None required)" .format(torch.typename(param), name)) elif param.grad_fn: raise ValueError( "Cannot assign non-leaf Variable to parameter '{0}'. Model " "parameters must be created explicitly. To express '{0}' " "as a function of another variable, compute the value in " "the forward() method.".format(name)) else: self._parameters[name] = param
Module.training 标志 如何影响 前向过程
从nn.Dropout
来看 Module.training
class Dropout(Module): def __init__(self, p=0.5, inplace=False): super(Dropout, self).__init__() if p < 0 or p > 1: raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) self.p = p self.inplace = inplace def forward(self, input): return F.dropout(input, self.p, self.training, self.inplace)
可以看出,在forward
过程中,直接获取,父类的training
的值。
我们 通常通过 module.train()
和 module.eval()
来切换模型的 训练测试阶段。
def train(self, mode=True): """Sets the module in training mode. This has any effect only on modules such as Dropout or BatchNorm. """ self.training = mode for module in self.children(): # 递归调用子模块 train 函数, 来设定所有 module 的 training 值。 module.train(mode) return self
需要注意的是:module.eval()
仅仅设置 module
的 training
属性,如果我们想获得最快的推断速度, 还需要 设置 输入 Variable
的volatile
属性为 True
。
参考资料
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
- pytorch学习笔记(十二):详解 Module 类
- pytorch 学习笔记之自定义 Module
- Pytorch 学习笔记之自定义 Module
- pytorch学习笔记(1)--pytorch张量
- Pytorch学习笔记(一)
- Pytorch学习笔记(二)
- Pytorch学习笔记(三)
- Pytorch学习笔记(四)
- Pytorch学习笔记(五)
- Pytorch学习笔记(六)
- C学习笔记(十二)函数详解
- pytorch学习笔记(九):PyTorch结构介绍
- pytorch学习笔记(九):PyTorch结构介绍
- Pytorch学习笔记(一):pytorch的安装-Ubuntu14.04
- pytorch学习笔记(十七):python 端扩展 pytorch
- pytorch学习笔记(十八):C 语言扩展 pytorch
- pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解
- pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解
- Solve the Equation问题及解法
- 第958期机器学习日报(2017-05-03)
- 第七周 负数把正数移出队列
- 第957期机器学习日报(2017-05-02)
- 第959期机器学习日报(2017-05-04)
- pytorch学习笔记(十二):详解 Module 类
- 前端三板斧,后端两条路
- Doubango_SIP协议栈参考资料
- jupyter 常用快捷键
- 秒杀的架构设计和解决方案
- Mapstruct简介和简单例子
- 第八周 项目2.1 试编写算法实现将字符串S中所有值为c1的字符换成值为c2的字符: void Trans(SqString *&s, char c1, char c2);
- C++学习之路(14)---vector报错:vector iterator not incrementable"或"vector iterator + offset out of range"
- 我与网易云音乐 之 【零】沙暴送赞