pytorch使用(二)自定义网络
来源:互联网 发布:苹果手机设置2g网络 编辑:程序博客网 时间:2024/06/06 01:55
pytorch使用:目录
pytorch使用(二)自定义网络
首先参考pytorch的官方手册中关于torch.nn的说明。
1. 定义网络结构
搭建自己的网路使用class torch.nn.Module
,在官方手册中有一个非常简单的例子:
import torch.nn as nnimport torch.nn.functional as Fclass Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x))
这个例子定义了一个只有两层的网络Model。其中两个函数:
- 初始化函数 __init__(self)
定义了具体网络有什么层,这里实际上没有决定网络的结构,也就是说将上面的例子中的self.conv1
和self.conv2
定义的前后顺序调换是完全没有影响的。
- forward
函数定义了网络的前向传播的顺序。
pytorch中具体支持的不同的层请参考官方手册torch.nn
2. 网络参数初始化
pytorch官方提供了多种初始化函数,具体参考官方手册:
- torch.nn.init.uniform(tensor, a=0, b=1)
- torch.nn.init.normal(tensor, mean=0, std=1)
- torch.nn.init.constant(tensor, val)
- torch.nn.init.xavier_uniform(tensor, gain=1)
初始化函数可以直接作用于神经网络参数
- 对网络的某一层参数进行初始化
import torch.nn as nnimport torch.nn.init as initconv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)init.xavier_uniform(conv1.weight)#卷积参数init.constant(conv1.bias, 0.1)#偏重
- 对整个网络的参数进行初始化
def weights_init(m): if isinstance(m, nn.Conv2d): xavier(m.weight.data) xavier(m.bias.data)
下面举一个例子,定义一个网络MyNet
,网络由6层的卷积构成:
import torchimport torch.nn as nnimport torch.nn.init as initclass MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.conv1 = nn.Conv2d(22, 64, 7, padding=3) self.relu1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(64, 128, 7, padding=3) self.relu2 = nn.ReLU(inplace=True) self.conv3 = nn.Conv2d(128, 256, 5, padding=2) self.relu3 = nn.ReLU(inplace=True) self.conv4 = nn.Conv2d(256, 128, 5, padding=2) self.relu4 = nn.ReLU(inplace=True) self.conv5= nn.Conv2d(128, 64, 3, padding=1) self.relu5 = nn.ReLU(inplace=True) self.conv6 = nn.Conv2d(64, 6, 3, padding=1) self.relu6 = nn.ReLU(inplace=True) for m in self.modules(): if isinstance(m, nn.Conv2d): init.xavier_uniform(m.weight.data) init.constant(m.bias.data,0.1) def forward(self, x): x = self.relu1(self.conv1(x)) x = self.relu2(self.conv2(x)) x = self.relu3(self.conv3(x)) x = self.relu4(self.conv4(x)) x = self.relu5(self.conv5(x)) f = self.relu6(self.conv6(x)) return f
阅读全文
0 0
- pytorch使用(二)自定义网络
- PyTorch(二)——搭建和自定义网络
- pytorch使用(三)网络结构可视化
- pytorch使用(四)训练网络
- Pytorch学习笔记(二)
- pytorch学习笔记(六):自定义Datasets
- pytorch学习笔记(六):自定义Datasets
- pytorch 使用
- pytorch学习笔记(二):gradient
- pytorch学习笔记(二):gradient
- pytorch学习笔记(二):gradient
- Pytorch学习入门(二)--- Autograd
- pytorch学习笔记(二):gradient
- 我在读pyTorch文档(二)
- pytorch自定义Dataset并使用torchvision的Transform
- 网络请求框架(二):volley使用之自定义请求
- pytorch入门(二)——自动求导函数
- PyTorch学习总结(二)——基于torch.utils.ffi的自定义C扩展
- jstl常用标签
- CentOS安装MySQL-5.6.23
- bss,data,text段
- Spring框架搭建
- 环境变量配置(Java、Python、Tomcat、Maven)
- pytorch使用(二)自定义网络
- java 集合HashSet
- 线性分类器定义和局限性
- intent.putExtra()数组 传入数组名称!不含[]
- C++ IOCP2
- 最大子段和
- 设计模式遵循的七大原则
- C语言小游戏入门之三子棋
- 关于Maven项目build时出现No compiler is provided in this environment的处理