DARLA Source Code Walkthrough

Tags (space-separated): reinforcement-learning algorithms, source code


DARLA trains its vision module in two stages: it first pretrains a denoising autoencoder (DAE), then trains a β-VAE whose reconstruction error is measured in the DAE's feature space rather than in pixel space. The DAE and its training loop:

```python
'''
Implementation of DARLA preprocessing, as found in
"DARLA: Improving Zero-Shot Transfer in Reinforcement Learning"
by Higgins and Pal et al. (https://arxiv.org/pdf/1707.08475.pdf):

DAE:       X_noisy --J--> Z ----> X_hat
           minimizing (X_noisy - X_hat)^2
Beta-VAE:  X ----> Z ----> X_hat
           minimizing (J(X) - J(X_hat))^2 + beta * KL(Q(Z|X) || P(Z))

Right now this just trains the model on the MNIST dataset.
PyTorch version.
'''
import torch
import torch.nn as nn
import torchvision.datasets as dsets
from torchvision import transforms
from torch.autograd import Variable
from utils import *  # provides output_after_conv(), used to size the FC layers
from torch.nn import functional as F


class DAE(nn.Module):
    def __init__(self):
        super(DAE, self).__init__()
        self.image_dim = 28    # a 28x28 image gives a 4x4 feature map before the FC layer,
                               # a 64x64 image gives 13x13; compute this with output_after_conv() in utils.py
        self.latent_dim = 100  # dimension of the latent representation from the paper
        self.noise_scale = 0.001
        self.batch_size = 50

        ########## Encoder ##########
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU())
        self.fc1 = nn.Linear(32 * 4 * 4, self.latent_dim)

        ########## Decoder ##########
        self.fc2 = nn.Linear(self.latent_dim, 32 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=1),
            nn.ReLU())

    def forward(self, x):
        # Encode-decode forward pass; returns the latent representation and the reconstruction.
        # Corrupt the input with Gaussian noise (randn_like keeps the shape correct
        # even on a final batch smaller than self.batch_size).
        x = x + self.noise_scale * torch.randn_like(x)
        z = self.encoder(x)
        z = z.view(-1, 32 * 4 * 4)
        z = self.fc1(z)
        x_hat = self.fc2(z)
        x_hat = x_hat.view(-1, 32, 4, 4)
        x_hat = self.decoder(x_hat)
        return z, x_hat

    def encode(self, x):
        # Return only the latent representation.
        # x = x.unsqueeze(0)
        z, x_hat = self.forward(x)
        return z


def train_dae(num_epochs=1, batch_size=128, learning_rate=1e-3):
    train_dataset = dsets.MNIST(root='./data/',  # testing that it works with MNIST data
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True)

    dae = DAE()
    dae.batch_size = batch_size
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(dae.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            x = Variable(images)
            # Forward + backward + optimize
            optimizer.zero_grad()
            z, x_hat = dae(x)
            loss = criterion(x_hat, x)  # difference between the reconstruction and the original image
            loss.backward()
            optimizer.step()
            if (i + 1) % 1 == 0:  # print every iteration
                print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
                      % (epoch + 1, num_epochs, i + 1,
                         len(train_dataset) // batch_size, loss.item()))
    torch.save(dae.state_dict(), 'dae-test-model.pkl')
```
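The `32*4*4` size of the FC layers comes from the encoder's conv stack. `utils.py` is not shown in the post, so the `output_after_conv()` helper the comments reference is an assumption here; a minimal sketch of what it plausibly computes, reproducing the 28 → 4 and 64 → 13 sizes quoted above:

```python
# Hypothetical sketch of utils.output_after_conv(); the real utils.py is not
# shown in the post, so the signature and body here are assumptions.
def output_after_conv(size, kernel_sizes=(4, 4, 4, 4), strides=(1, 1, 2, 2)):
    """Spatial size of the feature map after the encoder's conv stack (no padding)."""
    for k, s in zip(kernel_sizes, strides):
        size = (size - k) // s + 1  # standard conv output-size formula
    return size

print(output_after_conv(28))  # 4  -> FC layers are 32*4*4
print(output_after_conv(64))  # 13 -> FC layers would be 32*13*13
```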
Unlike the DAE, the β-VAE encoder outputs two values: the mean and the (log) standard deviation of the latent representation, from which the code is sampled via the reparameterization trick.

```python
class BetaVAE(nn.Module):
    def __init__(self):
        super(BetaVAE, self).__init__()
        self.image_dim = 28    # a 28x28 image gives a 4x4 feature map before the FC layer,
                               # a 64x64 image gives 13x13; compute this with output_after_conv() in utils.py
        self.latent_dim = 100
        self.batch_size = 50

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU())
        self.fc_mu = nn.Linear(32 * 4 * 4, self.latent_dim)
        self.fc_sigma = nn.Linear(32 * 4 * 4, self.latent_dim)

        self.fc_up = nn.Linear(self.latent_dim, 32 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=1),
            nn.ReLU())

    def forward(self, x):
        z = self.encoder(x)
        z = z.view(-1, 32 * 4 * 4)
        mu_z = self.fc_mu(z)
        log_sigma_z = self.fc_sigma(z)
        # Reparameterization trick: sample the latent representation from
        # N(mu, sigma^2) by adding noise scaled by the std to the mean
        # (randn_like keeps the shape correct for any batch size).
        sample_z = mu_z + log_sigma_z.exp() * torch.randn_like(mu_z)
        x_hat = self.fc_up(sample_z)
        x_hat = x_hat.view(-1, 32, 4, 4)
        x_hat = self.decoder(x_hat)
        return mu_z, log_sigma_z, x_hat


def bvae_loss_function(z_hat, z, mu, logvar, beta=1, batch_size=128):
    # Reconstruction loss measured in the DAE's latent space, i.e. (J(X) - J(X_hat))^2
    RCL = F.mse_loss(z, z_hat)
    # KL divergence between Q(Z|X) = N(mu, sigma^2) and the prior P(Z) = N(0, I)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalise by the same number of elements as in the reconstruction
    KLD /= batch_size
    return RCL + beta * KLD


def train_bvae(num_epochs=100, batch_size=128, learning_rate=1e-4):
    train_dataset = dsets.MNIST(root='./data/',  # testing that it works with MNIST data
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True)

    bvae = BetaVAE()
    bvae.batch_size = batch_size

    # Load the pretrained DAE; it provides the fixed feature map J(.) for the loss.
    dae = DAE()
    dae.load_state_dict(torch.load('dae-test-model.pkl'))
    dae.batch_size = batch_size
    dae.eval()

    optimizer = torch.optim.Adam(bvae.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            x = Variable(images)
            # Forward + backward + optimize
            optimizer.zero_grad()
            mu_z, log_sigma_z, x_hat = bvae(x)
            # Beta-VAE loss; fc_sigma predicts log(sigma), so logvar = 2 * log(sigma).
            loss = bvae_loss_function(dae.encode(x_hat), dae.encode(x),
                                      mu_z, 2 * log_sigma_z, batch_size=batch_size)
            loss.backward()
            optimizer.step()
            if (i + 1) % 1 == 0:  # print every iteration
                print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
                      % (epoch + 1, num_epochs, i + 1,
                         len(train_dataset) // batch_size, loss.item()))
    torch.save(bvae.state_dict(), 'bvae-test-model.pkl')
```
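One detail worth pausing on is the `2 * log_sigma_z` passed as `logvar`: `fc_sigma` predicts log σ, while the analytic KL formula expects log σ² = 2 log σ. A quick sanity check of `bvae_loss_function`, run in the same session as the definitions above (the random tensors are stand-ins for real DAE codes): when mu = 0 and logvar = 0, Q(Z|X) equals the prior N(0, I), so the KL term vanishes and the loss reduces to the reconstruction term.

```python
# Sanity check: with mu = 0 and logvar = 0 the KL term is exactly zero,
# so the loss equals the latent-space reconstruction error alone.
torch.manual_seed(0)
z = torch.randn(128, 100)      # stand-in for dae.encode(x)
z_hat = torch.randn(128, 100)  # stand-in for dae.encode(x_hat)
mu = torch.zeros(128, 100)
logvar = torch.zeros(128, 100)

loss = bvae_loss_function(z_hat, z, mu, logvar, beta=1, batch_size=128)
print(torch.allclose(loss, F.mse_loss(z, z_hat)))  # True
```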
Finally, the main block pretrains the DAE, trains the β-VAE on top of it, and runs a few shape checks:

```python
if __name__ == '__main__':
    train_dae()   # pretrain the DAE (an ordinary denoising autoencoder)
    train_bvae()  # train the beta-VAE network

    # Shape checks (note: this BetaVAE instance is freshly initialised,
    # not the one trained above)
    dae = DAE()
    dae.load_state_dict(torch.load('dae-test-model.pkl'))
    x = Variable(torch.randn(1, 1, 28, 28))
    bvae = BetaVAE()
    m, s, x_hat = bvae(x)
    print(m.size())
    z = dae.encode(x)
    print(z.size())
    print(x_hat.size())

    dae = DAE()
    x = Variable(torch.randn(1, 1, 28, 28))
    z, x_hat = dae(x)
```
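In the paper, this trained β-VAE encoder is then frozen and the mean mu_z of Q(Z|X) is used as the agent's state representation for the RL stage, which is what enables zero-shot transfer. The RL agent itself is out of this post's scope, but a hedged sketch of that inference step (the file name is taken from the save call above; the random observation is a stand-in) might look like:

```python
# Sketch of using the trained beta-VAE as DARLA's frozen vision module:
# encode an observation and hand mu_z to the RL agent as the state.
bvae = BetaVAE()
bvae.load_state_dict(torch.load('bvae-test-model.pkl'))
bvae.eval()

with torch.no_grad():
    obs = torch.rand(1, 1, 28, 28)   # a raw observation (random stand-in here)
    mu_z, log_sigma_z, _ = bvae(obs)
    state = mu_z                     # s = mean of Q(Z|X), the disentangled state
print(state.size())                  # torch.Size([1, 100])
```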