DARLA 源码解析
来源:互联网 发布:单片机地址寄存器 编辑:程序博客网 时间:2024/06/05 19:14
DARLA 源码解析
标签(空格分隔): 增强学习算法 源码
'''Implementation of DARLA preprocessing, as found in
"DARLA: Improving Zero-Shot Transfer in Reinforcement Learning"
by Higgins and Pal et al. (https://arxiv.org/pdf/1707.08475.pdf):

DAE:
    X_noisy --J--> Z ----> X_hat
    minimizing (X_noisy - X_hat)^2

Beta VAE:
    X ----> Z ----> X_hat
    minimizing (J(X) - J(X_hat))^2 + beta * KL(Q(Z|X) || P(Z))

Right now this just trains the models using the MNIST dataset.
PyTorch version.
'''
import torch
import torch.nn as nn
from torch.nn import functional as F

# NOTE(review): the original file did `from utils import *`; utils is only
# referenced in comments (output_after_conv), so the import is dropped to keep
# this module importable on its own.


class DAE(nn.Module):
    """Denoising autoencoder J: adds Gaussian noise to the input image,
    encodes it to a latent vector, and decodes back to image space."""

    def __init__(self):
        super(DAE, self).__init__()
        # A 28x28 image maps to 4x4 after the conv stack; a 64x64 image would
        # map to 13x13 (can be computed with output_after_conv() in utils.py).
        self.image_dim = 28
        self.latent_dim = 100   # size of the latent feature representation
        self.noise_scale = 0.001
        # Kept for backward compatibility with existing callers; forward()
        # no longer depends on it (noise is sized from the actual input).
        self.batch_size = 50
        # Encoder: 1x28x28 -> 32x4x4
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU())
        self.fc1 = nn.Linear(32 * 4 * 4, self.latent_dim)
        # Decoder: latent -> 32x4x4 -> 1x28x28
        self.fc2 = nn.Linear(self.latent_dim, 32 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=1),
            nn.ReLU())

    def forward(self, x):
        """Encode-decode forward pass.

        Args:
            x: image batch of shape (N, 1, 28, 28).
        Returns:
            (z, x_hat): latent features (N, latent_dim) and the
            reconstruction (N, 1, 28, 28).
        """
        # randn_like sizes the noise from the actual batch; the original drew
        # a fixed self.batch_size worth of noise, which broke on the last
        # (smaller) DataLoader batch and on single-image inputs.
        x = x + self.noise_scale * torch.randn_like(x)
        z = self.encoder(x)
        z = z.view(-1, 32 * 4 * 4)
        z = self.fc1(z)
        x_hat = self.fc2(z)
        x_hat = x_hat.view(-1, 32, 4, 4)
        x_hat = self.decoder(x_hat)
        return z, x_hat

    def encode(self, x):
        """Return only the latent representation of x."""
        z, _ = self.forward(x)
        return z


def train_dae(num_epochs=1, batch_size=128, learning_rate=1e-3):
    """Pre-train the DAE on MNIST and save weights to 'dae-test-model.pkl'.

    Args:
        num_epochs: number of passes over the training set.
        batch_size: mini-batch size for the DataLoader.
        learning_rate: Adam learning rate.
    """
    # Local import so the module stays importable without torchvision.
    import torchvision.datasets as dsets
    from torchvision import transforms

    train_dataset = dsets.MNIST(root='./data/',  # testing that it works with MNIST data
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True)
    # The original built a second MNIST dataset at '../data' just for the
    # loader; reuse the one above so the data is downloaded only once.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    dae = DAE()
    dae.batch_size = batch_size
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(dae.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            x = images
            # Forward + Backward + Optimize
            optimizer.zero_grad()
            z, x_hat = dae(x)
            # Reconstruction loss against the clean (un-noised) image.
            loss = criterion(x_hat, x)
            loss.backward()
            optimizer.step()
            if (i + 1) % 1 == 0:
                # loss.item() replaces loss.data[0], removed in PyTorch >= 0.5.
                print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
                      % (epoch + 1, num_epochs, i + 1,
                         len(train_dataset) // batch_size, loss.item()))
    torch.save(dae.state_dict(), 'dae-test-model.pkl')


class BetaVAE(nn.Module):
    """Beta-VAE: the encoder outputs a mean and a log-std for the latent
    distribution; a latent sample is drawn via the reparameterization trick
    and decoded back to image space."""

    def __init__(self):
        super(BetaVAE, self).__init__()
        # A 28x28 image maps to 4x4 after the conv stack; a 64x64 image would
        # map to 13x13 (can be computed with output_after_conv() in utils.py).
        self.image_dim = 28
        self.latent_dim = 100
        # Kept for backward compatibility; forward() no longer depends on it.
        self.batch_size = 50
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU())
        # Two heads: mean and log-std of Q(Z|X).
        self.fc_mu = nn.Linear(32 * 4 * 4, self.latent_dim)
        self.fc_sigma = nn.Linear(32 * 4 * 4, self.latent_dim)
        self.fc_up = nn.Linear(self.latent_dim, 32 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=1),
            nn.ReLU())

    def forward(self, x):
        """Forward pass.

        Args:
            x: image batch of shape (N, 1, 28, 28).
        Returns:
            (mu_z, log_sigma_z, x_hat): latent mean, latent log-std, and the
            decoded reconstruction (N, 1, 28, 28).
        """
        z = self.encoder(x)
        z = z.view(-1, 32 * 4 * 4)
        mu_z = self.fc_mu(z)
        log_sigma_z = self.fc_sigma(z)
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        # randn_like sizes the noise from the actual batch (see DAE.forward).
        sample_z = mu_z + log_sigma_z.exp() * torch.randn_like(mu_z)
        x_hat = self.fc_up(sample_z)
        x_hat = x_hat.view(-1, 32, 4, 4)
        x_hat = self.decoder(x_hat)
        return mu_z, log_sigma_z, x_hat


def bvae_loss_function(z_hat, z, mu, logvar, beta=1, batch_size=128):
    """DARLA beta-VAE loss: feature-space reconstruction + beta-weighted KL.

    Args:
        z_hat: DAE features of the beta-VAE reconstruction.
        z: DAE features of the original image (the target).
        mu: latent mean from the beta-VAE encoder.
        logvar: latent log-variance (callers pass 2 * log_sigma).
        beta: weight on the KL term.
        batch_size: used to normalise the KL term.
    Returns:
        scalar loss tensor.
    """
    # Reconstruction loss measured in DAE feature space (J(X) vs J(X_hat)).
    RCL = F.mse_loss(z, z_hat)
    # KL( N(mu, exp(logvar)) || N(0, I) ), summed over batch and latent dims.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalise by the same number of elements as the reconstruction term.
    KLD /= batch_size
    return RCL + beta * KLD


def train_bvae(num_epochs=100, batch_size=128, learning_rate=1e-4):
    """Train the beta-VAE against a frozen, pre-trained DAE.

    Loads DAE weights from 'dae-test-model.pkl' (train_dae must have run
    first) and saves beta-VAE weights to 'bvae-test-model.pkl'.
    """
    # Local import so the module stays importable without torchvision.
    import torchvision.datasets as dsets
    from torchvision import transforms

    train_dataset = dsets.MNIST(root='./data/',  # testing that it works with MNIST data
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    bvae = BetaVAE()
    bvae.batch_size = batch_size
    # Frozen DAE provides the feature space for the reconstruction loss.
    dae = DAE()
    dae.load_state_dict(torch.load('dae-test-model.pkl'))
    dae.batch_size = batch_size
    dae.eval()
    optimizer = torch.optim.Adam(bvae.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            x = images
            # Forward + Backward + Optimize
            optimizer.zero_grad()
            mu_z, log_sigma_z, x_hat = bvae(x)
            # Target features carry no useful gradient (the DAE is frozen and
            # x requires no grad), so detach them.
            target_z = dae.encode(x).detach()
            # logvar = 2 * log_sigma, hence the factor of 2 below.
            loss = bvae_loss_function(dae.encode(x_hat), target_z,
                                      mu_z, 2 * log_sigma_z,
                                      batch_size=batch_size)
            loss.backward()
            optimizer.step()
            if (i + 1) % 1 == 0:
                print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
                      % (epoch + 1, num_epochs, i + 1,
                         len(train_dataset) // batch_size, loss.item()))
    torch.save(bvae.state_dict(), 'bvae-test-model.pkl')


if __name__ == '__main__':
    train_dae()   # pre-train the DAE (plain denoising autoencoder)
    train_bvae()  # train the beta-VAE against the frozen DAE

    # Quick shape smoke test.
    dae = DAE()
    dae.load_state_dict(torch.load('dae-test-model.pkl'))
    x = torch.randn(1, 1, 28, 28)
    bvae = BetaVAE()
    m, s, x_hat = bvae(x)
    print(m.size())
    z = dae.encode(x)
    print(z.size())
    print(x_hat.size())
    dae = DAE()
    x = torch.randn(1, 1, 28, 28)
    z, x_hat = dae(x)
阅读全文
0 0
- DARLA 源码解析
- 源码解析
- 源码解析
- DARLA: Improving Zero-Shot Transfer in Reinforcement Learning 阅读笔记
- 【JDk源码解析之一】ArrayList源码解析
- 【源码解析】-- ArrayList的源码解析
- EventBus源码解析(史上最全的源码解析)
- 【源码】Vector、Stack源码解析
- Sping源码解析-源码下载
- <Android源码>IntentService源码解析
- JAVA源码解析-String源码
- JAVA源码解析-ArrayList源码
- JAVA源码解析-LinkedList源码
- Spark源码-SparkContext源码解析
- Jboss源码解析
- 网页病毒源码解析
- strlen源码解析
- chrome源码解析系列
- 第11周 第三项 数据结构例程——图的遍历
- 自定义三级联动地址选择器
- WebSocket发送文字图片功能
- android wifi 的链接
- 路由音乐切换
- DARLA 源码解析
- imooc的疯狂的蚂蚁的课程《Python操作MySQL数据库》 python3+pymysql模块来操作mysql数据库
- 第13周项目1- 验证算法(3)
- 无限轮播
- 云星数据---Scala实战系列(精品版)】:Scala入门教程007-Scala数组详解006
- 第十二周项目一C/C++验证算法
- RN入门-新建rn项目
- Tp day 2
- 正则表达式