diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index fb616709b46..4ea4a24eea9 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -352,7 +352,7 @@ class _Reduce(PrimitiveWithInfer): if np_reduce_func is not None: value = input_x['value'].asnumpy() - if not axis_v: + if not axis_v and axis_v != 0: axis_v = [i for i in range(len(input_x['shape']))] axis_v = tuple(axis_v) value = np_reduce_func(value, axis_v, keepdims=self.keep_dims)