From e6296d21aeda50e269274bfe9d458fce865cf927 Mon Sep 17 00:00:00 2001 From: liangzhibo Date: Fri, 28 Oct 2022 12:24:01 +0800 Subject: [PATCH] Change value to python for tuple/list --- .../pipeline/jit/static_analysis/prim.cc | 28 +---- .../mindspore/ops/operations/array_ops.py | 4 +- .../python/mindspore/scipy/sparse/linalg.py | 4 +- .../ut/python/graph_syntax/test_constexpr.py | 107 ++++++++++++++++++ 4 files changed, 114 insertions(+), 29 deletions(-) create mode 100644 tests/ut/python/graph_syntax/test_constexpr.py diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 0c334deb4b5..7f8531d7ec2 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -361,10 +361,6 @@ py::object BuildValue(const ValuePtr &value_ptr) { py::object AbstractTupleValueToPython(const AbstractTuple *tuple_abs) { MS_EXCEPTION_IF_NULL(tuple_abs); - auto value = tuple_abs->BuildValue(); - if (value->isa()) { - return py::none(); - } const auto &elements = tuple_abs->elements(); size_t len = elements.size(); py::tuple value_tuple(len); @@ -446,12 +442,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base, bool only_conver } dic[ATTR_SHAPE] = shape_tuple; dic[ATTR_DTYPE] = dtype_tuple; - MS_EXCEPTION_IF_NULL(arg_tuple->BuildValue()); - if (arg_tuple->BuildValue()->isa()) { - dic[ATTR_VALUE] = py::none(); - } else { - dic[ATTR_VALUE] = value_tuple; - } + dic[ATTR_VALUE] = value_tuple; if (dyn_value) { dic[ATTR_MIN_VALUE] = min_value_tuple; @@ -505,20 +496,12 @@ py::dict AbstractDictionaryToPython(const AbstractBasePtr &abs_base) { dic[ATTR_SHAPE] = shape_list; dic[ATTR_DTYPE] = dtype_list; MS_EXCEPTION_IF_NULL(arg_dict->BuildValue()); - if (arg_dict->BuildValue()->isa()) { - dic[ATTR_VALUE] = py::none(); - } else { - dic[ATTR_VALUE] = value_dict; - } + dic[ATTR_VALUE] = value_dict; return dic; } py::object AbstractListValueToPython(const AbstractList *list_abs) { MS_EXCEPTION_IF_NULL(list_abs); - auto value = list_abs->BuildValue(); - if (value->isa()) { - return py::none(); - } const auto &elements = list_abs->elements(); size_t len = elements.size(); py::list value_list(len); @@ -585,12 +568,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert dic[ATTR_SHAPE] = shape_list; dic[ATTR_DTYPE] = dtype_list; - MS_EXCEPTION_IF_NULL(arg_list->BuildValue()); - if (arg_list->BuildValue()->isa()) { - dic[ATTR_VALUE] = py::none(); - } else { - dic[ATTR_VALUE] = value_list; - } + dic[ATTR_VALUE] = value_list; if (dyn_value) { dic[ATTR_MIN_VALUE] = min_value_list; diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 2fd12bb2886..b2936702145 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -2566,7 +2566,7 @@ class Concat(PrimitiveWithCheck): def infer_value(self, input_x): """Implement Concat infer value""" value = None - if input_x is not None: + if input_x is not None and None not in input_x: value = Tensor(np.concatenate([x.asnumpy() for x in input_x], axis=self.axis)) return value @@ -2782,7 +2782,7 @@ class Stack(PrimitiveWithInfer): tuple_value = value['value'] input_array = [] infered_value = None - if tuple_value is not None: + if tuple_value is not None and None not in tuple_value: for item in tuple_value: npy_item = item.asnumpy() input_array.append(npy_item) diff --git a/mindspore/python/mindspore/scipy/sparse/linalg.py b/mindspore/python/mindspore/scipy/sparse/linalg.py index 26dbd55f8ac..7cc2a6e8a37 100644 --- a/mindspore/python/mindspore/scipy/sparse/linalg.py +++ b/mindspore/python/mindspore/scipy/sparse/linalg.py @@ -347,7 +347,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, restart=20, maxiter=None, _value_check(func_name, callback_type, None, 'callback_type', op='is', fmt='todo') if restart > size: restart = size - if not is_within_graph(A): + if not is_within_graph(b): x, info = GMRES(A, M, solve_method)(b, x0, tol, restart, maxiter, atol) else: x, info = GMRESV2(solve_method)(A, b, x0, tol, restart, maxiter, M, atol) @@ -516,7 +516,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, callback=None _type_check(func_name, maxiter, int, 'maxiter') _value_check(func_name, callback, None, 'callback', op='is', fmt='todo') - if not is_within_graph(A): + if not is_within_graph(b): x, info = CG(A, M)(b, x0, tol, atol, maxiter) else: x, info = CGv2()(A, b, x0, tol, atol, maxiter, M) diff --git a/tests/ut/python/graph_syntax/test_constexpr.py b/tests/ut/python/graph_syntax/test_constexpr.py new file mode 100644 index 00000000000..99c125a5538 --- /dev/null +++ b/tests/ut/python/graph_syntax/test_constexpr.py @@ -0,0 +1,107 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_dict_get """ +from mindspore import Tensor, jit, context, mutable +from mindspore.ops.primitive import constexpr + +context.set_context(mode=context.GRAPH_MODE) + + +@constexpr +def count_none(arg): + if arg is None: + raise ValueError("The arg is None") + a = 0 + for e in arg: + if e is None: + a += 1 + elif isinstance(e, (tuple, list)) and None in e: + a += 1 + return a + + +def test_constexpr_input_with_variable_element_tuple(): + """ + Feature: constexpr with variable element tuple input. + Description: If tuple is used as constexpr input, the variable element will be converted to None. + Expectation: No exception. + """ + @jit + def foo(x): + arg = (1, 2, x, x+1) + return count_none(arg) + + out = foo(Tensor([1])) + assert out == 2 + + +def test_constexpr_input_with_variable_element_tuple_2(): + """ + Feature: constexpr with variable element tuple input. + Description: If tuple is used as constexpr input, the variable element will be converted to None. + Expectation: No exception. + """ + @jit + def foo(x): + arg = (1, 2, x, (x, 1, 2)) + return count_none(arg) + + out = foo(Tensor([1])) + assert out == 2 + + +def test_constexpr_input_with_variable_element_list(): + """ + Feature: constexpr with variable element list input. + Description: If list is used as constexpr input, the variable element will be converted to None. + Expectation: No exception. + """ + @jit + def foo(x): + arg = [1, 2, x, x+1] + return count_none(arg) + + out = foo(Tensor([1])) + assert out == 2 + + +def test_constexpr_input_with_variable_element_list_2(): + """ + Feature: constexpr with variable element list input. + Description: If list is used as constexpr input, the variable element will be converted to None. + Expectation: No exception. + """ + @jit + def foo(x): + arg = [1, 2, x, [x, 1, 2]] + return count_none(arg) + + out = foo(Tensor([1])) + assert out == 2 + + +def test_constexpr_input_with_mutable_list(): + """ + Feature: constexpr with mutable list. + Description: If mutable list is used as constexpr input, all elements will be converted to None. + Expectation: No exception. + """ + @jit + def foo(x): + arg = mutable([Tensor([1]), Tensor([2]), x]) + return count_none(arg) + + out = foo(Tensor([1])) + assert out == 3