From 09f1a4bbafddb39fd293edd7c716620327e94a2a Mon Sep 17 00:00:00 2001 From: buxue Date: Mon, 17 Aug 2020 17:53:02 +0800 Subject: [PATCH] support axis is None for all and any interface in graph mode. --- mindspore/_extends/parse/standard_method.py | 5 ++++- .../pipeline/infer/test_interface_all_and_any_of_tensor.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index 763a4da780d..c99024dcceb 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -44,6 +44,8 @@ def all_(x, axis=(), keep_dims=False): Tensor, has the same data type as x. """ + if axis is None: + axis = () reduce_all = P.ReduceAll(keep_dims) return reduce_all(x, axis) @@ -60,7 +62,8 @@ def any_(x, axis=(), keep_dims=False): Returns: Tensor, has the same data type as x. """ - + if axis is None: + axis = () reduce_any = P.ReduceAny(keep_dims) return reduce_any(x, axis) diff --git a/tests/ut/python/pipeline/infer/test_interface_all_and_any_of_tensor.py b/tests/ut/python/pipeline/infer/test_interface_all_and_any_of_tensor.py index 9781164038f..97bcd598bb4 100644 --- a/tests/ut/python/pipeline/infer/test_interface_all_and_any_of_tensor.py +++ b/tests/ut/python/pipeline/infer/test_interface_all_and_any_of_tensor.py @@ -28,8 +28,8 @@ def test_all_and_any_of_tensor_in_graph(): def construct(self, x): all_ = x.all() any_ = x.any() - all_0 = x.all(0, True) - any_0 = x.any(0, True) + all_0 = x.all(None, True) + any_0 = x.any(None, True) return all_, any_, all_0, any_0 net = Net()