您的位置:首页 > 运维架构

torch.nn.CrossEntropyLoss的相关

2018-10-13 20:22 603 查看

 参数可以看下面

[code]class CrossEntropyLoss(_WeightedLoss):

def __init__(self, weight=None, size_average=True, ignore_index=-100, reduce=True):
pass
def forward(self, input, target):
pass

解释:网上的一些解释,该损失的计算公式:https://blog.csdn.net/tmk_01/article/details/80839810

这里面的公式的2个例子举的很好,该例子是判断每一个实例的分类,比如图片是属于哪一个类的。所以每一个实例的真实标签是一维的。做不一样的任务,实例的标签也是不一样的,有1维或者是多维的。举例:如果输入的是图像,任务是分类每个图像对应一个类别即数字是一维的,总共有class个类别,那么input=【n,c,h,w】,target=【n】是一维的,元素个数为n;如果是做语义分割任务是得到一张【h,w】的语义图,那么input=【n,c,h,w】,target=【n,h,w】。

[code]r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class.

It is useful when training a classification problem with `C` classes.
If provided, the optional argument :attr:`weight` should be a 1D `Tensor`
assigning weight to each of the classes.
This is particularly useful when you have an unbalanced training set.

The `input` is expected to contain scores for each class.

`input` has to be a Tensor of size either :math:`(minibatch, C)` or
:math:`(minibatch, C, d_1, d_2, ..., d_K)`
with :math:`K \geq 2` for the `K`-dimensional case (described later).

This criterion expects a class index (0 to `C-1`) as the
`target` for each value of a 1D tensor of size `minibatch`

The loss can be described as:

.. math::
loss(x, class) = -x[class] + log(sum_j (exp(x[j])))

or in the case of the `weight` argument being specified:

.. math::
loss(x, class) = weight[class] (-x[class] + log(sum_j (exp(x[j]))))

The losses are averaged across observations for each minibatch.

Can also be used for higher dimension inputs, such as 2D images, by providing
an input of size :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 2`,
where :math:`K` is the number of dimensions, and a target of appropriate shape
(see below).
"""

参数:

[code]"""

Args:
weight (Tensor, optional): a manual rescaling weight given to each class.
If given, has to be a Tensor of size `C`
size_average (bool, optional): By default, the losses are averaged over observations for each minibatch.
However, if the field `size_average` is set to ``False``, the losses are
instead summed for each minibatch. Ignored if reduce is ``False``.
ignore_index (int, optional): Specifies a target value that is ignored
and does not contribute to the input gradient. When `size_average` is
``True``, the loss is averaged over non-ignored targets.
reduce (bool, optional): By default, the losses are averaged or summed over
observations for each minibatch depending on `size_average`. When reduce
is ``False``, returns a loss per batch instead and ignores
size_average. Default: ``True``

Shape:
- Input: :math:`(N, C)` where `C = number of classes`, or
:math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 2`
in the case of `K`-dimensional loss.
- Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or
:math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 2` in the case of
K-dimensional loss.
- Output: scalar. If reduce is ``False``, then the same size
as the target: :math:`(N)`, or
:math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 2` in the case
of K-dimensional loss.

Examples::

>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
"""

注意上面的target的类型一定是torch.long类型.还有情况要注意:如下图

[code]import torch.nn as nn
import torch.nn.functional as F
#可以的格式
#首先要定义这个函数,也就是实例化才能使用
loss = nn.CrossEntropyLoss()(input,target)

#这个函数已经进行了定义为criterion
criterion = nn.CrossEntropyLoss()
loss = criterion(input,target)

#F函数里面有这个函数的相关定义,直接调用就可以
loss = F.cross_entropy(input,target)

#上述三个的输出结果是一样的,loss 类型为torch.Tensor是一个可求导的tensor; loss.item类型为float是python类型的常数值.

#不可以的格式
loss = nn.CrossEntropyLoss(input,target)
#得到的loss类型为torch.nn.modules.loss.CrossEntropyLoss,相当于干函数的实例化,而不是上面的tensor类型

 也就是说loss是一个可求导的tensor常数,而loss.item是一个float类型的常数,不能进行求导

讲解完了这个函数,接下来看一下具体的应用https://mp.csdn.net/postedit/82885569

阅读更多
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: