How to write triplet loss in PyTorch

Source: Internet  Editor: 程序博客网  Date: 2024/05/22 13:56

triplet loss

PyTorch already ships a class that defines the triplet loss criterion: TripletMarginLoss in torch.nn. An excerpt of its definition and docstring:

class TripletMarginLoss(Module):
    r"""Creates a criterion that measures the triplet loss given an input
    tensors x1, x2, x3 and a margin with a value greater than 0.
    This is used for measuring a relative similarity between samples. A triplet
    is composed by `a`, `p` and `n`: anchor, positive examples and negative
    example respectively. The shape of all input variables should be
    :math:`(N, D)`.

    The distance swap is described in detail in the paper `Learning shallow
    convolutional feature descriptors with triplet losses`_ by
    V. Balntas, E. Riba et al.

    Args:
        anchor: anchor input tensor
        positive: positive input tensor
        negative: negative input tensor
        p: the norm degree. Default: 2

    Shape:
        - Input: :math:`(N, D)` where `D = vector dimension`
        - Output: :math:`(N, 1)`
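To make the criterion concrete, here is a minimal sketch that writes the loss out by hand as max(d(a, p) - d(a, n) + margin, 0) and compares it with the built-in class. The helper name manual_triplet_loss is made up for illustration and is not part of PyTorch. The sketch assumes a recent PyTorch release, where the default reduction averages over the batch and returns a scalar rather than the (N, 1) output mentioned in the older docstring.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def manual_triplet_loss(anchor, positive, negative, margin=1.0, p=2):
    """Hypothetical helper: triplet loss written out step by step."""
    # Per-sample distance between anchor and positive, shape (N,)
    d_ap = F.pairwise_distance(anchor, positive, p=p)
    # Per-sample distance between anchor and negative, shape (N,)
    d_an = F.pairwise_distance(anchor, negative, p=p)
    # Hinge: only triplets that violate the margin contribute to the loss
    per_sample = torch.clamp(d_ap - d_an + margin, min=0.0)
    # Average over the batch, matching the default reduction="mean"
    return per_sample.mean()


torch.manual_seed(0)
anchor = torch.randn(100, 128)
positive = torch.randn(100, 128)
negative = torch.randn(100, 128)

manual = manual_triplet_loss(anchor, positive, negative)
builtin = nn.TripletMarginLoss(margin=1.0, p=2)(anchor, positive, negative)
print(manual.item(), builtin.item())
```

The two printed values should agree up to floating-point error, which is a quick way to check that the built-in criterion does what the docstring says.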

Usage example:

    >>> import torch
    >>> import torch.nn as nn
    >>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
    >>> input1 = torch.randn(100, 128, requires_grad=True)  # anchor
    >>> input2 = torch.randn(100, 128, requires_grad=True)  # positive
    >>> input3 = torch.randn(100, 128, requires_grad=True)  # negative
    >>> output = triplet_loss(input1, input2, input3)
    >>> output.backward()  # needs requires_grad=True on the inputs
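The docstring quoted earlier mentions the distance swap. As a rough sketch, assuming a recent PyTorch release where TripletMarginLoss accepts the swap and reduction arguments, the swap replaces d(a, n) with the smaller of d(a, n) and d(p, n). The variable names below are chosen only for this illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(100, 128)  # anchor embeddings
p = torch.randn(100, 128)  # positive embeddings
n = torch.randn(100, 128)  # negative embeddings

# reduction="none" keeps one loss value per triplet instead of averaging
swap_loss = nn.TripletMarginLoss(margin=1.0, p=2, swap=True, reduction="none")
per_triplet = swap_loss(a, p, n)

# The same quantity written out by hand
d_ap = F.pairwise_distance(a, p)
d_an = F.pairwise_distance(a, n)
d_pn = F.pairwise_distance(p, n)
# Distance swap: use the harder (smaller) of the two negative distances
d_neg = torch.min(d_an, d_pn)
by_hand = torch.clamp(d_ap - d_neg + 1.0, min=0.0)

print(per_triplet.shape)
```

Keeping per-triplet values with reduction="none" is convenient when you want to inspect or mine hard triplets before averaging.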

Reference

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py
