fix_moments_clipbyvalue

This commit is contained in:
jiangzhenguang 2021-01-18 20:42:32 +08:00
parent 78d51aa323
commit c3ab66afd9
2 changed files with 8 additions and 1 deletions

View File

@ -936,7 +936,7 @@ class Moments(Cell):
self.squeeze = P.Squeeze(self.axis)
def construct(self, x):
tensor_dtype = x.dtype
tensor_dtype = F.dtype(x)
_check_input_dtype("input x", tensor_dtype, [mstype.float16, mstype.float32], self.cls_name)
if tensor_dtype == mstype.float16:
x = self.cast(x, mstype.float32)

View File

@ -25,6 +25,12 @@ from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
@constexpr
def _check_shape(input_shape, out_shape):
if input_shape != out_shape:
raise ValueError("Cannot broadcast the shape of x to the shape of clip_value_min or clip_value_max.")
def clip_by_value(x, clip_value_min, clip_value_max):
"""
Clips tensor values to a specified min and max.
@ -63,6 +69,7 @@ def clip_by_value(x, clip_value_min, clip_value_max):
max_op = P.Maximum()
x_min = min_op(x, clip_value_max)
x_max = max_op(x_min, clip_value_min)
_check_shape(F.shape(x), F.shape(x_max))
return x_max