pytorch使用(二)自定义网络

来源:互联网 发布:苹果手机设置2g网络 编辑:程序博客网 时间:2024/06/06 01:55

pytorch使用:目录


pytorch使用(二)自定义网络

首先参考pytorch的官方手册中关于torch.nn的说明。

1. 定义网络结构

搭建自己的网路使用class torch.nn.Module,在官方手册中有一个非常简单的例子:

import torch.nn as nnimport torch.nn.functional as Fclass Model(nn.Module):    def __init__(self):        super(Model, self).__init__()        self.conv1 = nn.Conv2d(1, 20, 5)        self.conv2 = nn.Conv2d(20, 20, 5)    def forward(self, x):       x = F.relu(self.conv1(x))       return F.relu(self.conv2(x))

这个例子定义了一个只有两层的网络Model。其中两个函数:
- 初始化函数 __init__(self)定义了具体网络有什么层,这里实际上没有决定网络的结构,也就是说将上面的例子中的self.conv1self.conv2定义的前后顺序调换是完全没有影响的。
- forward函数定义了网络的前向传播的顺序。

pytorch中具体支持的不同的层请参考官方手册torch.nn

2. 网络参数初始化

pytorch官方提供了多种初始化函数,具体参考官方手册:
- torch.nn.init.uniform(tensor, a=0, b=1)
- torch.nn.init.normal(tensor, mean=0, std=1)
- torch.nn.init.constant(tensor, val)
- torch.nn.init.xavier_uniform(tensor, gain=1)

初始化函数可以直接作用于神经网络参数

  1. 对网络的某一层参数进行初始化
import torch.nn as nnimport torch.nn.init as initconv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)init.xavier_uniform(conv1.weight)#卷积参数init.constant(conv1.bias, 0.1)#偏重
  1. 对整个网络的参数进行初始化
def weights_init(m):    if isinstance(m, nn.Conv2d):        xavier(m.weight.data)        xavier(m.bias.data)  

下面举一个例子,定义一个网络MyNet,网络由6层的卷积构成:

import torchimport torch.nn as nnimport torch.nn.init as initclass MyNet(nn.Module):    def __init__(self):        super(MyNet, self).__init__()        self.conv1 = nn.Conv2d(22, 64, 7, padding=3)        self.relu1 = nn.ReLU(inplace=True)        self.conv2 = nn.Conv2d(64, 128, 7, padding=3)        self.relu2 = nn.ReLU(inplace=True)        self.conv3 = nn.Conv2d(128, 256, 5, padding=2)        self.relu3 = nn.ReLU(inplace=True)        self.conv4 = nn.Conv2d(256, 128, 5, padding=2)        self.relu4 = nn.ReLU(inplace=True)        self.conv5= nn.Conv2d(128, 64, 3, padding=1)        self.relu5 = nn.ReLU(inplace=True)        self.conv6 = nn.Conv2d(64, 6, 3, padding=1)        self.relu6 = nn.ReLU(inplace=True)        for m in self.modules():            if isinstance(m, nn.Conv2d):                init.xavier_uniform(m.weight.data)                init.constant(m.bias.data,0.1)    def forward(self, x):        x = self.relu1(self.conv1(x))        x = self.relu2(self.conv2(x))        x = self.relu3(self.conv3(x))        x = self.relu4(self.conv4(x))        x = self.relu5(self.conv5(x))        f = self.relu6(self.conv6(x))        return f
原创粉丝点击