Change comments about LossBase.get_loss()

This commit is contained in:
chenhaozhe 2021-07-06 11:16:28 +08:00
parent 871f6f0a88
commit 554ceb2492
2 changed files with 6 additions and 4 deletions

View File

@ -72,10 +72,12 @@ class LossBase(Cell):
def get_loss(self, x, weights=1.0):
"""
Computes the weighted loss
Computes the weighted loss.
Args:
weights: Optional `Tensor` whose rank is either 0, or the same rank as inputs, and must be broadcastable to
inputs (i.e., all dimensions must be either `1`, or the same as the corresponding inputs dimension).
weights (Union[float, Tensor]): Optional `Tensor` whose rank is either 0, or the same rank as inputs,
and must be broadcastable to inputs (i.e., all dimensions must be either `1`,
or the same as the corresponding inputs dimension).
"""
input_dtype = x.dtype
x = self.cast(x, mstype.float32)

View File

@ -72,7 +72,7 @@ class Loss(nn.Cell):
raise NotImplementedError
class CrossEntropy(LossBase):
class CrossEntropy(Loss):
"""CrossEntropy"""
def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropy, self).__init__()