生成对抗网络(GAN)初探
来源:互联网 发布:玛丽王后知乎 编辑:程序博客网 时间:2024/05/22 14:24
参考网址:https://github.com/yunjey/pytorch-tutorial
数据参考网址: 后期补上
一、网络上的介绍
二、DCGAN,有两个神经网络
Generator:生成器。用来生成伪造的图像。生成器的输入是gauss随机噪声,输出是一张图像,随着训练epochs的增加,伪图像会越来越像真正的图像。
Discriminator:鉴别器。是鉴别一张图像的真假。输入是一张真实的图像,输出则是{0,1} 输出为0时 鉴别器认为该图像是假图像,否则反之。
三、训练办法
四、训练网络核心代码
def train(self): """Train generator and discriminator.""" fixed_noise = self.to_variable(torch.randn(self.batch_size, self.z_dim)) total_step = len(self.data_loader) for epoch in range(self.num_epochs): for i, images in enumerate(self.data_loader): #===================== Train D =====================# images = self.to_variable(images) # 生成Variable对象 batch_size = images.size(0) noise = self.to_variable(torch.randn(batch_size, self.z_dim)) # Train D to recognize real images as real. outputs = self.discriminator(images) real_loss = torch.mean((outputs - 1) ** 2) # L2 loss instead of Binary cross entropy loss (this is optional for stable training) # Train D to recognize fake images as fake. fake_images = self.generator(noise) outputs = self.discriminator(fake_images) fake_loss = torch.mean(outputs ** 2) # Backprop + optimize d_loss = real_loss + fake_loss self.reset_grad() d_loss.backward() self.d_optimizer.step() #===================== Train G =====================# noise = self.to_variable(torch.randn(batch_size, self.z_dim)) # Train G so that D recognizes G(z) as real. fake_images = self.generator(noise) outputs = self.discriminator(fake_images) g_loss = torch.mean((outputs - 1) ** 2) # Backprop + optimize self.reset_grad() g_loss.backward() self.g_optimizer.step() # print the log info if (i+1) % self.log_step == 0: print('Epoch [%d/%d], Step[%d/%d], d_real_loss: %.4f, ' 'd_fake_loss: %.4f, g_loss: %.4f' %(epoch+1, self.num_epochs, i+1, total_step, real_loss.data[0], fake_loss.data[0], g_loss.data[0])) # save the sampled images if (i+1) % self.sample_step == 0: fake_images = self.generator(fixed_noise) torchvision.utils.save_image(self.denorm(fake_images.data), os.path.join(self.sample_path, 'fake_samples-%d-%d.png' %(epoch+1, i+1))) # save the model parameters for each epoch g_path = os.path.join(self.model_path, 'generator-%d.pkl' %(epoch+1)) d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' %(epoch+1)) torch.save(self.generator.state_dict(), g_path) torch.save(self.discriminator.state_dict(), d_path)
五、程序整体特性
给定gauss特征向量,就可以生成一张训练集类似的头像人脸。最终生成效果如下
阅读全文
0 0
- 生成对抗网络(GAN)初探
- 生成对抗网络(GAN)
- 生成对抗网络GAN
- GAN生成对抗网络
- 生成对抗网络-GAN
- Gan 生成对抗网络
- [生成对抗网络] GAN
- 生成对抗网络(GAN)
- 生成式对抗网络(GAN)资源
- 贝叶斯生成对抗网络(GAN)
- 7 什么是GAN(生成对抗网络)?
- 了解生成对抗网络GAN
- 浅谈GAN生成对抗网络
- pytorch GAN生成对抗网络
- GAN 生成式对抗网络
- 生成对抗网络GAN学习
- GAN—生成对抗网络
- 生成式对抗网络GAN研究进展(一)
- iOS-js互相调用
- Nginx学习(一)
- Odoo 8.0深入浅出开发教程
- Android Handler 、 Looper 、Message
- poj 3292 (前缀和)
- 生成对抗网络(GAN)初探
- WebView中给url添加cookie的值
- JEECMS源代码基本结构及相关技术简介
- Math.prototype.concat详解及二维数组扁平化方法
- ISO普及
- Spring AOP 完成日志记录
- gcc中-pthread和-lpthread的区别
- 设计模式之——原型设计模式
- 20:计算2的幂