PyTorch parameter initialization
Source: Internet  Editor: 程序博客网  Time: 2024/06/10 00:14
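torch.nn.init provides ready-made initializers. A parameter can also be set in place through its .data tensor, e.g. for a layer m:
m.weight.data.fill_(1)               # every weight = 1
m.bias.data.fill_(0)                 # every bias = 0
m.weight.data.uniform_(-stdv, stdv)  # uniform on [-stdv, stdv); stdv is a bound you pick, e.g. 1/sqrt(fan_in)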
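As a sketch (assuming a recent PyTorch), the in-place .data calls above have direct counterparts in torch.nn.init, whose functions modify a tensor in place without recording the operation for autograd. The layers const_layer and unif_layer are hypothetical, and the bound stdv = 1/sqrt(fan_in) is one common choice, not something the snippets above prescribe.

```python
import math

import torch
import torch.nn as nn

# Hypothetical layers, 4 inputs -> 3 outputs.
const_layer = nn.Linear(4, 3)
unif_layer = nn.Linear(4, 3)

# Equivalent of weight.data.fill_(1) and bias.data.fill_(0).
nn.init.constant_(const_layer.weight, 1.0)
nn.init.zeros_(const_layer.bias)

# Equivalent of weight.data.uniform_(-stdv, stdv).
# Assumption: stdv = 1/sqrt(fan_in), where fan_in = number of input features.
fan_in = unif_layer.weight.size(1)
stdv = 1.0 / math.sqrt(fan_in)
nn.init.uniform_(unif_layer.weight, -stdv, stdv)

print(const_layer.weight)
print(unif_layer.weight.abs().max().item() <= stdv)  # True
```

The parameters stay nn.Parameter objects with requires_grad=True, so the layers train normally afterwards.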
1. All learnable parameters of a network:
params = list(net.parameters())
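A minimal illustration of what item 1 returns, using a hypothetical three-layer nn.Sequential in place of net; named_parameters() is the companion method that also yields each parameter's qualified name.

```python
import torch.nn as nn

# Hypothetical network standing in for `net`.
net = nn.Sequential(
    nn.Linear(4, 8),   # weight (8, 4), bias (8,)
    nn.ReLU(),         # no parameters
    nn.Linear(8, 2),   # weight (2, 8), bias (2,)
)

params = list(net.parameters())
print(len(params))  # 4: two weights and two biases

# Names are "<index in Sequential>.<parameter name>".
for name, p in net.named_parameters():
    print(name, tuple(p.shape))

total = sum(p.numel() for p in params)
print(total)  # 58 = 32 + 8 + 16 + 2
```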
2. Parameters of one layer (here conv2): the kernel comes first, the bias second:
conv2params = list(net.conv2.parameters())
kernels = conv2params[0]
bias = conv2params[1]
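A sketch of item 2 with a hypothetical LeNet-style class that exposes a conv2 attribute. For a Conv2d, parameters() yields the kernel first and the bias second, and the list holds the layer's own Parameter objects rather than copies, so modifying kernels in place modifies net.conv2.weight.

```python
import torch.nn as nn

class Net(nn.Module):
    # Hypothetical model; only the attribute name conv2 matters here.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)

net = Net()
conv2params = list(net.conv2.parameters())
kernels = conv2params[0]  # shape (out_channels, in_channels, kH, kW)
bias = conv2params[1]     # shape (out_channels,)

print(tuple(kernels.shape))  # (16, 6, 5, 5)
print(tuple(bias.shape))     # (16,)

# Same objects as the module attributes.
print(kernels is net.conv2.weight, bias is net.conv2.bias)  # True True
```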
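3. Loop over self.modules() (typically at the end of a model's __init__) and initialize by layer type: He initialization for convolutions, weight 1 and bias 0 for batch norm:
import math
import torch.nn as nn

# inside the model's __init__, after all layers are defined
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        # n = kH * kW * out_channels (the fan-out); weights ~ N(0, 2/n)
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()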
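The loop in item 3 is He (Kaiming) initialization computed from the fan-out. Assuming a PyTorch version that provides nn.init.kaiming_normal_, the sketch below compares the manual formula with the built-in call on two identical hypothetical Conv2d layers; both should give a sample standard deviation close to sqrt(2/n).

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# Two copies of a hypothetical conv layer: 64 -> 128 channels, 3x3 kernel.
manual = nn.Conv2d(64, 128, kernel_size=3)
builtin = nn.Conv2d(64, 128, kernel_size=3)

# Manual version, as in item 3: n = kH * kW * out_channels = 1152.
n = manual.kernel_size[0] * manual.kernel_size[1] * manual.out_channels
with torch.no_grad():
    manual.weight.normal_(0, math.sqrt(2.0 / n))

# Built-in version: mode="fan_out" uses fan_out = out_channels * kH * kW,
# and nonlinearity="relu" gives gain sqrt(2), so std = sqrt(2 / fan_out).
nn.init.kaiming_normal_(builtin.weight, mode="fan_out", nonlinearity="relu")

target = math.sqrt(2.0 / n)
print(target, manual.weight.std().item(), builtin.weight.std().item())
```

The built-in call avoids recomputing the fan by hand and also works for other layer shapes, such as Linear weights.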
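4. A function for net.apply() that matches layers by class name (the scheme used in the PyTorch DCGAN example):
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # conv weights ~ N(0, 0.02)
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # batch-norm scale ~ N(1, 0.02), shift = 0
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)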
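Item 4 matches by class-name substring, so 'Conv' also catches ConvTranspose2d and 'BatchNorm' catches BatchNorm1d and BatchNorm2d. A sketch applying it to a hypothetical generator-style block:

```python
import torch
import torch.nn as nn

def weights_init(m):
    # Same rule as item 4: match layers by class name.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

torch.manual_seed(0)
# Hypothetical block: transposed conv -> batch norm -> ReLU.
block = nn.Sequential(
    nn.ConvTranspose2d(100, 64, kernel_size=4, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
)
block.apply(weights_init)

conv, bn = block[0], block[1]
print(conv.weight.std().item())    # close to 0.02
print(bn.weight.mean().item())     # close to 1.0
print(bn.bias.abs().max().item())  # 0.0
```

One caveat of name matching: a custom container whose class name contains "Conv" (say, a hypothetical MyConvNet) would also match, and m.weight would raise AttributeError if that container has no weight of its own. Checking isinstance, as item 3 does, avoids this.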
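5. Xavier (Glorot) normal initialization, for Linear layers only:
import numpy as np
import torch.nn as nn

def weight_init(m):
    if isinstance(m, nn.Linear):
        size = m.weight.size()
        fan_out = size[0]  # number of rows
        fan_in = size[1]   # number of columns
        std = np.sqrt(2.0 / (fan_in + fan_out))  # a standard deviation, not a variance
        m.weight.data.normal_(0.0, std)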
net = Residual()        # create an instance of your network class
net.apply(weight_init)  # apply the weight init
The apply function recursively visits every submodule of the network, as well as the network itself, and calls the function on each of them. So every Linear layer in your model is initialized by this one call.
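To see this traversal, the sketch records the class name of every module apply() visits in a hypothetical nested network, then applies item 5's weight_init (rewritten with math.sqrt instead of np.sqrt to drop the NumPy dependency). apply() handles children before their parent, so the outer Sequential is visited last. For reference, nn.init.xavier_normal_ uses the same standard deviation, sqrt(2/(fan_in + fan_out)), with its default gain of 1.

```python
import math

import torch
import torch.nn as nn

def weight_init(m):
    # Item 5: Xavier-style normal init, Linear layers only.
    if isinstance(m, nn.Linear):
        fan_out, fan_in = m.weight.size()
        std = math.sqrt(2.0 / (fan_in + fan_out))
        m.weight.data.normal_(0.0, std)

visited = []

def record(m):
    # Log the class name of each module apply() calls us on.
    visited.append(type(m).__name__)

torch.manual_seed(0)
# Hypothetical network with a nested container.
net = nn.Sequential(
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Sequential(nn.Linear(128, 64), nn.ReLU()),
)

net.apply(record)
print(visited)
# ['Linear', 'ReLU', 'Linear', 'ReLU', 'Sequential', 'Sequential']

net.apply(weight_init)  # reaches the nested Linear too
print(net[0].weight.std().item(), math.sqrt(2.0 / (256 + 128)))
print(net[2][0].weight.std().item(), math.sqrt(2.0 / (128 + 64)))
```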
---------------------------------------------------reference--------------------------------
1. https://discuss.pytorch.org/t/weight-initilzation/157/2