forked from mindspore-Ecosystem/mindspore
!10307 add type check support to sampledsoftmaxloss
From: @TFbunny Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @liangchenghui
This commit is contained in:
commit
7a873d10a1
|
@ -281,6 +281,10 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
|
|||
x = self.softmax_cross_entropy(logits, labels)[0]
|
||||
return self.get_loss(x)
|
||||
|
||||
@constexpr
|
||||
def _check_label_dtype(labels_dtype, cls_name):
|
||||
validator.check_type_name("labels", labels_dtype, [mstype.int32, mstype.int64], cls_name)
|
||||
|
||||
|
||||
class SampledSoftmaxLoss(_Loss):
|
||||
r"""
|
||||
|
@ -373,8 +377,11 @@ class SampledSoftmaxLoss(_Loss):
|
|||
self.zeros_like = P.ZerosLike()
|
||||
self.mul = P.Mul()
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.dtype = P.DType()
|
||||
|
||||
def construct(self, weights, biases, labels, inputs):
|
||||
_check_label_dtype(self.dtype(labels), self.cls_name)
|
||||
|
||||
logits, labels = self._compute_sampled_logits(
|
||||
weights=weights,
|
||||
biases=biases,
|
||||
|
@ -424,6 +431,7 @@ class SampledSoftmaxLoss(_Loss):
|
|||
`[batch_size, num_true + num_sampled]`
|
||||
out_labels: A Tensor object with the same shape as `out_logits`.
|
||||
"""
|
||||
|
||||
if not labels.dtype == mstype.int32:
|
||||
labels = self.cast(labels, mstype.int32)
|
||||
labels = self.reshape(labels, (-1, num_true))
|
||||
|
|
Loading…
Reference in New Issue