生成对抗网络(GAN)初探

来源:互联网 发布:玛丽王后知乎 编辑:程序博客网 时间:2024/05/22 14:24

参考网址:https://github.com/yunjey/pytorch-tutorial

数据参考网址: 后期补上

一、网络上的介绍

二、DCGAN,有两个神经网络

Generator:生成器。用来生成伪造的图像。生成器的输入是gauss随机噪声,输出是一张图像,随着训练epochs的增加,伪图像会越来越像真正的图像。

Discriminator:鉴别器。是鉴别一张图像的真假。输入是一张真实的图像,输出则是{0,1} 输出为0时 鉴别器认为该图像是假图像,否则反之。


三、训练办法


四、训练网络核心代码

def train(self):    """Train generator and discriminator."""    fixed_noise = self.to_variable(torch.randn(self.batch_size, self.z_dim))    total_step = len(self.data_loader)    for epoch in range(self.num_epochs):        for i, images in enumerate(self.data_loader):                        #===================== Train D =====================#            images = self.to_variable(images) # 生成Variable对象            batch_size = images.size(0)            noise = self.to_variable(torch.randn(batch_size, self.z_dim))                        # Train D to recognize real images as real.            outputs = self.discriminator(images)            real_loss = torch.mean((outputs - 1) ** 2)      # L2 loss instead of Binary cross entropy loss (this is optional for stable training)            # Train D to recognize fake images as fake.            fake_images = self.generator(noise)            outputs = self.discriminator(fake_images)            fake_loss = torch.mean(outputs ** 2)            # Backprop + optimize            d_loss = real_loss + fake_loss            self.reset_grad()            d_loss.backward()            self.d_optimizer.step()                        #===================== Train G =====================#            noise = self.to_variable(torch.randn(batch_size, self.z_dim))                        # Train G so that D recognizes G(z) as real.            fake_images = self.generator(noise)            outputs = self.discriminator(fake_images)            g_loss = torch.mean((outputs - 1) ** 2)            # Backprop + optimize            self.reset_grad()            g_loss.backward()            self.g_optimizer.step()            # print the log info            if (i+1) % self.log_step == 0:                print('Epoch [%d/%d], Step[%d/%d], d_real_loss: %.4f, '                       'd_fake_loss: %.4f, g_loss: %.4f'                       %(epoch+1, self.num_epochs, i+1, total_step,                         real_loss.data[0], fake_loss.data[0], g_loss.data[0]))            # save the sampled images            if (i+1) % self.sample_step == 0:                fake_images = self.generator(fixed_noise)                torchvision.utils.save_image(self.denorm(fake_images.data),                     os.path.join(self.sample_path,                                 'fake_samples-%d-%d.png' %(epoch+1, i+1)))                # save the model parameters for each epoch        g_path = os.path.join(self.model_path, 'generator-%d.pkl' %(epoch+1))        d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' %(epoch+1))        torch.save(self.generator.state_dict(), g_path)        torch.save(self.discriminator.state_dict(), d_path)

五、程序整体特性

给定gauss特征向量,就可以生成一张训练集类似的头像人脸。最终生成效果如下


原创粉丝点击