forked from mindspore-Ecosystem/mindspore
Change comments about LossBase.get_loss()
This commit is contained in:
parent
871f6f0a88
commit
554ceb2492
|
@ -72,10 +72,12 @@ class LossBase(Cell):
|
|||
|
||||
def get_loss(self, x, weights=1.0):
|
||||
"""
|
||||
Computes the weighted loss
|
||||
Computes the weighted loss.
|
||||
|
||||
Args:
|
||||
weights: Optional `Tensor` whose rank is either 0, or the same rank as inputs, and must be broadcastable to
|
||||
inputs (i.e., all dimensions must be either `1`, or the same as the corresponding inputs dimension).
|
||||
weights (Union[float, Tensor]): Optional `Tensor` whose rank is either 0, or the same rank as inputs,
|
||||
and must be broadcastable to inputs (i.e., all dimensions must be either `1`,
|
||||
or the same as the corresponding inputs dimension).
|
||||
"""
|
||||
input_dtype = x.dtype
|
||||
x = self.cast(x, mstype.float32)
|
||||
|
|
|
@ -72,7 +72,7 @@ class Loss(nn.Cell):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class CrossEntropy(LossBase):
|
||||
class CrossEntropy(Loss):
|
||||
"""CrossEntropy"""
|
||||
def __init__(self, smooth_factor=0., num_classes=1000):
|
||||
super(CrossEntropy, self).__init__()
|
||||
|
|
Loading…
Reference in New Issue