PyTorch parameter initialization

来源:互联网 发布:酷q机器人php源码 编辑:程序博客网 时间:2024/06/10 00:14


torch.nn.init

# Set every element of the weight tensor to 1, in place.
weight.data.fill_(1)
# Set every element of the bias tensor to 0, in place.
bias.data.fill_(0)

# Overwrite the weight tensor in place with values drawn uniformly
# between -stdv and stdv (stdv is defined elsewhere).
weight.data.uniform_(-stdv, stdv)

1. Get all parameters of the network as a list:

params = list(net.parameters())


2. Get the parameters of a single layer (here `conv2`):

conv2params = list(net.conv2.parameters())

kernels = conv2params[0]

bias = conv2params[1]


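3. Initialize each layer according to its type while iterating over `self.modules()` (this snippet lives inside a method of the model):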

        # Visit every module returned by self.modules() and initialise it in
        # place according to its layer type.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # n = kernel_height * kernel_width * out_channels.
                # Weights are drawn from a normal distribution with mean 0 and
                # standard deviation sqrt(2 / n).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                # Batch norm starts with weight 1 and bias 0.
                m.weight.data.fill_(1)
                m.bias.data.zero_()


4. Initialize layers by matching on the layer's class name:

def weights_init(m):
    """Initialise convolution and batch-norm layers in place, by class name.

    Intended to be passed to ``net.apply(weights_init)``.

    * Class name contains ``'Conv'``: weight is drawn from N(0.0, 0.02).
      The bias is left untouched.
    * Class name contains ``'BatchNorm'``: weight is drawn from N(1.0, 0.02)
      and bias is filled with 0.
    * Anything else is left untouched.

    Args:
        m: The module to initialise.

    Returns:
        None. The module's parameters are modified in place.
    """
    classname = m.__class__.__name__

    # Matching on the class name also hits modules that merely have 'Conv' or
    # 'BatchNorm' in their name (e.g. a user-defined ``ConvBlock`` container)
    # and layers built without affine parameters, whose weight/bias is absent
    # or None. Skip those instead of raising AttributeError.
    if 'Conv' in classname:
        weight = getattr(m, 'weight', None)
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        weight = getattr(m, 'weight', None)
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            bias.data.fill_(0)

5. Initialize `nn.Linear` layers with a function and apply it to the whole network:

def weight_init(m):
    """Initialise the weight of an ``nn.Linear`` layer in place.

    The weight is drawn from a normal distribution with mean 0 and standard
    deviation ``sqrt(2 / (fan_in + fan_out))``. The bias is left untouched,
    and modules that are not ``nn.Linear`` are ignored.

    Args:
        m: The module to initialise.

    Returns:
        None. The module's weight is modified in place.
    """
    if not isinstance(m, nn.Linear):
        return

    size = m.weight.size()
    fan_out = size[0]  # number of rows
    fan_in = size[1]  # number of columns
    # This value is passed to normal_() as the standard deviation, so it is
    # named std rather than variance.
    std = (2.0 / (fan_in + fan_out)) ** 0.5
    m.weight.data.normal_(0.0, std)

net = Residual()  # create a network instance from the Residual class
# Apply the weight_init function defined above to the modules of net.
net.apply(weight_init)

The apply function will search recursively for all the modules inside your network, and will call the function on each of them. So all Linear layers you have in your model will be initialized by weight_init using this one call.






---------------------------------------------------reference--------------------------------

1. https://discuss.pytorch.org/t/weight-initilzation/157/2