forked from mindspore-Ecosystem/mindspore
fix_moments_clipbyvalue
This commit is contained in:
parent
78d51aa323
commit
c3ab66afd9
|
@ -936,7 +936,7 @@ class Moments(Cell):
|
|||
self.squeeze = P.Squeeze(self.axis)
|
||||
|
||||
def construct(self, x):
|
||||
tensor_dtype = x.dtype
|
||||
tensor_dtype = F.dtype(x)
|
||||
_check_input_dtype("input x", tensor_dtype, [mstype.float16, mstype.float32], self.cls_name)
|
||||
if tensor_dtype == mstype.float16:
|
||||
x = self.cast(x, mstype.float32)
|
||||
|
|
|
@ -25,6 +25,12 @@ from mindspore._checkparam import Validator as validator
|
|||
from mindspore.ops.primitive import constexpr
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_shape(input_shape, out_shape):
|
||||
if input_shape != out_shape:
|
||||
raise ValueError("Cannot broadcast the shape of x to the shape of clip_value_min or clip_value_max.")
|
||||
|
||||
|
||||
def clip_by_value(x, clip_value_min, clip_value_max):
|
||||
"""
|
||||
Clips tensor values to a specified min and max.
|
||||
|
@ -63,6 +69,7 @@ def clip_by_value(x, clip_value_min, clip_value_max):
|
|||
max_op = P.Maximum()
|
||||
x_min = min_op(x, clip_value_max)
|
||||
x_max = max_op(x_min, clip_value_min)
|
||||
_check_shape(F.shape(x), F.shape(x_max))
|
||||
return x_max
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue