使用 PyTorch 创建用于分类的神经网络

来源:互联网 发布:健身软件哪个好 编辑:程序博客网 时间:2024/06/05 06:21

感谢周莫烦同学的教程：https://morvanzhou.github.io/tutorials/machine-learning/torch/3-03-fast-nn/

方法一：继承 torch.nn.Module 自定义网络，并训练一个二分类器：

"""
Binary classification of a 2-D toy data set with a small PyTorch MLP.

View more, visit my tutorial page: https://morvanzhou.github.io/tutorials/
My Youtube Channel: https://www.youtube.com/user/MorvanZhou

Dependencies:
torch: >= 0.4 (tensors are used directly; the old ``Variable`` wrapper is not needed)
matplotlib (only for the animated plot produced by ``main``)
"""
import torch
import torch.nn.functional as F


def make_data(n_per_class=100, seed=1):
    """Build a reproducible two-cluster data set in the plane.

    Class 0 points are drawn with mean (2, 2) and class 1 points with mean
    (-2, -2), each coordinate having standard deviation 1.

    Args:
        n_per_class: number of samples generated for each class.
        seed: value passed to ``torch.manual_seed`` so runs are reproducible.

    Returns:
        ``(x, y)``: ``x`` is a float32 tensor of shape ``(2 * n_per_class, 2)``
        (class 0 rows first), ``y`` is an int64 tensor of shape
        ``(2 * n_per_class,)`` holding the labels 0 and 1.
    """
    torch.manual_seed(seed)    # reproducible
    n_data = torch.ones(n_per_class, 2)           # all-ones template, shape (n, 2)
    x0 = torch.normal(2 * n_data, 1)              # class 0 features (x, y coordinates)
    y0 = torch.zeros(n_per_class)                 # class 0 labels: all 0
    x1 = torch.normal(-2 * n_data, 1)             # class 1 features
    y1 = torch.ones(n_per_class)                  # class 1 labels: all 1
    # Stack the two classes vertically (row-wise).
    x = torch.cat((x0, x1), 0).type(torch.FloatTensor)
    # CrossEntropyLoss wants integer class indices (int64), not one-hot rows.
    y = torch.cat((y0, y1), 0).type(torch.LongTensor)
    return x, y


class Net(torch.nn.Module):
    """One-hidden-layer MLP: Linear -> ReLU -> Linear.

    ``forward`` returns raw class scores (logits); no softmax is applied,
    because ``torch.nn.CrossEntropyLoss`` applies log-softmax itself.
    """

    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
        self.out = torch.nn.Linear(n_hidden, n_output)   # output layer

    def forward(self, x):
        x = F.relu(self.hidden(x))      # activation function for hidden layer
        return self.out(x)


def predict(logits):
    """Return the index of the highest-scoring class for each row of ``logits``."""
    # softmax is monotonic, so arg-max of the raw logits equals arg-max of
    # softmax(logits); the original softmax call (without ``dim``) is unnecessary.
    return logits.argmax(dim=1)


def train(net, x, y, steps=100, lr=0.02, callback=None):
    """Train ``net`` on ``(x, y)`` with full-batch SGD and cross-entropy loss.

    Args:
        net: module mapping ``x`` to per-class logits.
        x: input features.
        y: integer class labels (NOT one-hot).
        steps: number of optimisation steps.
        lr: SGD learning rate.
        callback: optional ``callback(step, out)`` invoked after every step;
            ``out`` is the logits tensor computed before that step's update.

    Returns:
        List of the loss values (Python floats), one per step.
    """
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss_func = torch.nn.CrossEntropyLoss()  # the target label is NOT one-hotted
    losses = []
    for t in range(steps):
        out = net(x)                 # input x and predict based on x
        loss = loss_func(out, y)     # must be (1. nn output, 2. target)
        optimizer.zero_grad()        # clear gradients for next train
        loss.backward()              # backpropagation, compute gradients
        optimizer.step()             # apply gradients
        losses.append(loss.item())
        if callback is not None:
            callback(t, out)
    return losses


def main():
    """Train on the toy data and animate the predicted labels and accuracy."""
    # Imported here so the data/model/training helpers above can be imported
    # and reused without matplotlib being installed.
    import matplotlib.pyplot as plt

    x, y = make_data()
    # 2 input features (x and y coordinates), 10 hidden units, 2 output classes.
    net = Net(n_feature=2, n_hidden=10, n_output=2)     # define the network
    print(net)  # net architecture

    target_y = y.numpy()
    xs, ys = x[:, 0].numpy(), x[:, 1].numpy()

    def plot_step(t, out):
        # Redraw every other step to show the learning process.
        if t % 2 != 0:
            return
        plt.cla()
        pred_y = predict(out).numpy()
        plt.scatter(xs, ys, c=pred_y, s=100, lw=0, cmap='RdYlGn')
        # Fraction correct, using the real sample count instead of a hard-coded 200.
        accuracy = float((pred_y == target_y).mean())
        plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color':  'red'})
        plt.pause(0.1)

    plt.ion()   # interactive mode so the figure refreshes during training
    train(net, x, y, steps=100, lr=0.02, callback=plot_step)
    plt.ioff()
    plt.show()


if __name__ == "__main__":
    main()


方法二：对比两种搭建同一网络的方式——自定义 torch.nn.Module 类与快速搭建的 torch.nn.Sequential：

"""
Two equivalent ways to build the same 1 -> 10 -> 1 network in PyTorch.

View more, visit my tutorial page: https://morvanzhou.github.io/tutorials/
My Youtube Channel: https://www.youtube.com/user/MorvanZhou

Dependencies:
torch: 0.1.11
"""
import torch
import torch.nn.functional as F


class Net(torch.nn.Module):
    """Hand-written network: hidden Linear -> ReLU -> output Linear.

    Args:
        n_feature: size of each input row.
        n_hidden: number of hidden units.
        n_output: size of each output row.
    """

    def __init__(self, n_feature, n_hidden, n_output):
        super().__init__()
        # The attribute names become the layer labels shown by print(net).
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
        self.predict = torch.nn.Linear(n_hidden, n_output)   # output layer

    def forward(self, x):
        # F.relu is a plain function, so it needs no attribute of its own.
        activated = F.relu(self.hidden(x))
        # Linear output: no activation on the last layer.
        return self.predict(activated)


# Way 1: instantiate the hand-written class.
net1 = Net(1, 10, 1)

# Way 2 (easy and fast): torch.nn.Sequential passes the input through the
# listed layers in order. Here ReLU is a module instance (a class), unlike F.relu.
net2 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)

print(net1)     # net1 architecture
# Typical output on recent PyTorch versions:
# Net(
#   (hidden): Linear(in_features=1, out_features=10, bias=True)
#   (predict): Linear(in_features=10, out_features=1, bias=True)
# )

print(net2)     # net2 architecture
# Typical output on recent PyTorch versions (layers are labelled by position):
# Sequential(
#   (0): Linear(in_features=1, out_features=10, bias=True)
#   (1): ReLU()
#   (2): Linear(in_features=10, out_features=1, bias=True)
# )

torch.nn.Sequential 表示可以在其中按顺序堆叠（罗列）各层：输入会依次经过列出的每一层。