Perplexity

来源:互联网 发布:java 回文字符串 编辑:程序博客网 时间:2024/05/21 15:27

其实就是求一个batch里,平均每个词的p(x)

class Perplexity(nn.Module):    def __init__(self):        super(Perplexity, self).__init__()    def forward(self, logits, target):        """        :param logits: tensor with shape of [batch_size, seq_len, input_size]        :param target: tensor with shape of [batch_size, seq_len] of Long type filled with indexes to gather from logits        :return: tensor with shape of [batch_size] with perplexity evaluation        """        #将一个句子里每个step真正词对应的log_prob相加取平均,再在前面加上一个负号        [batch_size, seq_len, input_size] = logits.size()        logits = logits.view(-1, input_size)        log_probs = F.log_softmax(logits)        del logits        log_probs = log_probs.view(batch_size, seq_len, input_size)        #在第二维加一个维度        target = target.unsqueeze(2)        #从log_probs里根据index挑选出指定的element        out = t.gather(log_probs, dim=2, index=target).squeeze(2).neg()        #在sequence方向取平均值,也就是每个词的平均exp(log_likely_h)        ppl = out.mean(1).exp()        return ppl

basic VAE里 KL的计算

#t.pow(mu, 2)求mu的2次方kld = (-0.5 * t.sum(logvar - t.pow(mu, 2) - t.exp(logvar) + 1, 1)).mean().squeeze()

另一种方式

            nll_per_word = self.mle_loss(output, targets)            avg_lengths = tf.cast(tf.reduce_mean(self.lengths), tf.float32)            #negtive loglik. 注意这里除以的是batch的长度,没有包含所有句子的长度            self.nll = tf.reduce_sum(nll_per_word) / cfg.batch_size            self.perplexity = tf.exp(self.nll/avg_lengths)
原创粉丝点击