Perplexity
来源:互联网 发布:java 回文字符串 编辑:程序博客网 时间:2024/05/21 15:27
其实就是求一个batch里,平均每个词的p(x)
class Perplexity(nn.Module): def __init__(self): super(Perplexity, self).__init__() def forward(self, logits, target): """ :param logits: tensor with shape of [batch_size, seq_len, input_size] :param target: tensor with shape of [batch_size, seq_len] of Long type filled with indexes to gather from logits :return: tensor with shape of [batch_size] with perplexity evaluation """ #将一个句子里每个step真正词对应的log_prob相加取平均,再在前面加上一个负号 [batch_size, seq_len, input_size] = logits.size() logits = logits.view(-1, input_size) log_probs = F.log_softmax(logits) del logits log_probs = log_probs.view(batch_size, seq_len, input_size) #在第二维加一个维度 target = target.unsqueeze(2) #从log_probs里根据index挑选出指定的element out = t.gather(log_probs, dim=2, index=target).squeeze(2).neg() #在sequence方向取平均值,也就是每个词的平均exp(log_likely_h) ppl = out.mean(1).exp() return ppl
basic VAE里 KL的计算
#t.pow(mu, 2)求mu的2次方kld = (-0.5 * t.sum(logvar - t.pow(mu, 2) - t.exp(logvar) + 1, 1)).mean().squeeze()
另一种方式
nll_per_word = self.mle_loss(output, targets) avg_lengths = tf.cast(tf.reduce_mean(self.lengths), tf.float32) #negtive loglik. 注意这里除以的是batch的长度,没有包含所有句子的长度 self.nll = tf.reduce_sum(nll_per_word) / cfg.batch_size self.perplexity = tf.exp(self.nll/avg_lengths)
阅读全文
0 0
- perplexity
- Perplexity
- Perplexity详解
- Perplexity定义
- language model perplexity计算
- LDA perplexity计算
- 概率分布的 perplexity
- Perplexity(困惑度)
- 语言模型的评估-Perplexity
- 语言模型评价指标Perplexity
- Cross Entropy and Perplexity in NLP
- LDA主题模型评估方法--Perplexity
- LDA主题模型评估方法--Perplexity
- LDA主题模型评估方法--Perplexity
- LDA主题模型评估方法--Perplexity
- linux crontab
- angular内置服务interval和timeout
- Exception Handling in Java
- 谷歌又新增了两个意想不到的功能
- php基础
- Perplexity
- BZOJ 2006: [NOI2010]超级钢琴
- CentOS 6.5 下Nginx的配置
- Emacs
- 用户空间和内核空间,进程上下文
- html标签
- LinkedBlockingDeque
- 使用c3p0链接数据库
- 【JavaScript】随笔2