From da060da2b98b8429cbb5d71d10440ae7ef5235af Mon Sep 17 00:00:00 2001 From: liangzhibo Date: Thu, 15 Dec 2022 20:34:48 +0800 Subject: [PATCH] enable constexpr to run as graph --- .../pipeline/jit/static_analysis/evaluator.cc | 2 +- .../pipeline/jit/static_analysis/prim.cc | 62 +++++++++++++++++++ .../ccsrc/pipeline/jit/static_analysis/prim.h | 13 ++++ .../jit/static_analysis/static_analysis.cc | 40 +++++++----- mindspore/core/utils/flags.h | 1 + .../python/mindspore/numpy/utils_const.py | 2 +- .../composite/multitype_ops/_compile_utils.py | 6 +- .../multitype_ops/_constexpr_utils.py | 8 +-- .../mindspore/ops/function/sparse_func.py | 6 +- mindspore/python/mindspore/ops/primitive.py | 2 + .../python/mindspore/scipy/utils_const.py | 7 ++- .../test_mutable_variable_length.py | 28 --------- 12 files changed, 119 insertions(+), 58 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index a2cd57f31d6..9ed652cf52f 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -517,7 +517,7 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, const AnfNodeConfigPtr &out_conf) { if (args_conf_list.empty() && identifier_ != "MakeTupleEvaluator" && identifier_ != "MakeListEvaluator" && - identifier_ != "RaiseEvaluator") { + identifier_ != "RaiseEvaluator" && identifier_ != "ConstexprEvaluator") { MS_LOG(EXCEPTION) << "Size should be greater than 0, during running " << identifier_; } AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 84740885f00..dde98645acb 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -1737,6 +1737,68 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt } } // namespace +EvalResultPtr ConstexprEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, + const ConfigPtr &, const AnfNodeConfigPtr &out_conf) { + // Consider all primitive implemented python infer() real use the tuple/list arguments. + CheckSequenceArgumentForPythonPrimitive(prim_py_, args_spec_list); + MS_EXCEPTION_IF_NULL(prim_py_); + auto py_args = PreparePyInputs(prim_py_, args_spec_list); + prim_py_->BeginRecordAddAttr(); + py::dict output = prim_py_->RunInfer(py_args); + prim_py_->EndRecordAddAttr(); + if (output.contains("fn")) { + // The inputs contain variable, the constexpr will run as graph. + py::tuple values = output["fn"]; + if (values.empty()) { + MS_LOG(EXCEPTION) << "Can not get origin function from constexpr."; + } + auto inner_val = parse::ParsePythonCode(values[0]); + MS_EXCEPTION_IF_NULL(inner_val); + auto inner_fg = dyn_cast(inner_val); + MS_EXCEPTION_IF_NULL(inner_fg); + auto mng = Manage(inner_fg, false); + inner_fg->set_manager(mng); + MS_EXCEPTION_IF_NULL(out_conf); + auto out_node = out_conf->node(); + MS_EXCEPTION_IF_NULL(out_node); + auto out_cnode = dyn_cast(out_node); + MS_EXCEPTION_IF_NULL(out_cnode); + FuncGraphPtr func_graph = out_node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + std::vector new_cnode_inputs = {NewValueNode(inner_fg)}; + const auto &out_cnode_inputs = out_cnode->inputs(); + (void)std::copy(out_cnode_inputs.begin() + 1, out_cnode_inputs.end(), std::back_inserter(new_cnode_inputs)); + auto new_node = func_graph->NewCNodeInOrder(new_cnode_inputs); + func_graph->ReplaceInOrder(out_node, new_node); + AnalysisEnginePtr eng = out_conf->engine(); + MS_EXCEPTION_IF_NULL(eng); + AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph()); + return eng->ForwardConfig(out_conf, fn_conf); + } + // If all inputs are constant value, use python prim evaluator. + // Ensure input arguments are evaluated. + auto ret_abstract = EvalUndeterminedArgs(args_spec_list); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined"; + return ret_abstract; + } + auto forbid_reuse = prim_py_->HasAttr(GRAPH_FLAG_FORBID_REUSE_RESULT); + if (!forbid_reuse) { + // Try to get infer result from evaluator cache. + EvalResultPtr eval_result = evaluator_cache_mgr_->GetValue(args_spec_list); + if (eval_result != nullptr) { + return std::make_shared(eval_result->abstract()->Clone(), eval_result->attribute()); + } + } + const auto &added_attrs = prim_py_->evaluate_added_attrs(); + MS_LOG(DEBUG) << "Output type is " << py::str(output); + auto res_abs = PyInferRes2Abstract(prim_py_, output); + MS_LOG(DEBUG) << "Python InferTensor result abstract: " << res_abs->ToString(); + EvalResultPtr eval_result = std::make_shared(res_abs, std::make_shared(added_attrs)); + evaluator_cache_mgr_->SetValue(args_spec_list, eval_result); + return eval_result; +} + EvalResultPtr MakeTupleEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &, const AnfNodeConfigPtr &out_conf) { std::shared_ptr sequence_nodes = std::make_shared(); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h index 53a92371480..af99b824394 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h @@ -160,6 +160,19 @@ class MixedPrecisionCastEvaluator final : public Evaluator { PrimitivePtr prim_; }; +class ConstexprEvaluator : public TransitionPrimEvaluator { + public: + explicit ConstexprEvaluator(const PrimitivePyPtr primitive) + : TransitionPrimEvaluator("ConstexprEvaluator"), prim_py_(primitive) {} + ~ConstexprEvaluator() override = default; + MS_DECLARE_PARENT(ConstexprEvaluator, TransitionPrimEvaluator) + EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, + const AnfNodeConfigPtr &out_conf) override; + + private: + PrimitivePyPtr prim_py_; +}; + class MakeTupleEvaluator : public TransitionPrimEvaluator { public: MakeTupleEvaluator() : TransitionPrimEvaluator("MakeTupleEvaluator") {} diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 5d544462928..81eeb93dbe9 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -503,6 +503,29 @@ void AnalysisEngine::Clear() { root_context_ = nullptr; } +EvaluatorPtr GetPyEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) { + auto prim_py = dyn_cast(prim); + if (prim_py != nullptr) { + auto is_constexpr = prim_py->HasAttr(GRAPH_FLAG_CONSTEXPR_PRIM); + if (is_constexpr) { + return std::make_shared(prim_py); + } + if (engine == nullptr) { + return std::make_shared(prim_py); + } + + const auto &iter = engine->prim_py_evaluators_.find(prim_py); + if (iter != engine->prim_py_evaluators_.end()) { + return iter->second; + } + auto evaluator = std::make_shared(prim_py); + engine->prim_py_evaluators_[prim_py] = evaluator; + return evaluator; + } + MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive."; + return nullptr; +} + EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) { // Custom Primitive with python infer_shape, infer_type MS_EXCEPTION_IF_NULL(prim); @@ -533,22 +556,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr // Use python infer function if the infer function not founded in the map return a python evaluator EvaluatorPtr evaluator = nullptr; if (prim->HasPyEvaluator()) { - auto prim_py = dyn_cast(prim); - if (prim_py != nullptr) { - if (engine == nullptr) { - return std::make_shared(prim_py); - } - - const auto &iter = engine->prim_py_evaluators_.find(prim_py); - if (iter != engine->prim_py_evaluators_.end()) { - return iter->second; - } - evaluator = std::make_shared(prim_py); - engine->prim_py_evaluators_[prim_py] = evaluator; - return evaluator; - } - MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive."; - return nullptr; + return GetPyEvaluator(prim, engine); } // Return a default evaluator diff --git a/mindspore/core/utils/flags.h b/mindspore/core/utils/flags.h index dd7fc5e2a81..05cb9bffbcb 100644 --- a/mindspore/core/utils/flags.h +++ b/mindspore/core/utils/flags.h @@ -30,6 +30,7 @@ inline const char GRAPH_FLAG_SIDE_EFFECT_PROPAGATE[] = "side_effect_propagate"; inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP[] = "side_effect_backprop"; inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP_MEM[] = "side_effect_backprop_mem"; inline const char GRAPH_FLAG_FORBID_REUSE_RESULT[] = "forbid_reuse_result"; +inline const char GRAPH_FLAG_CONSTEXPR_PRIM[] = "constexpr_prim"; inline const char GRAPH_FLAG_IS_WHILE_HEADER[] = "is_while_header"; inline const char GRAPH_FLAG_ORDER_ENFORCE_SKIP[] = "order_enforce_skip"; inline const char GRAPH_FLAG_BPROP_RETURN_SPARSE[] = "bprop_return_sparse"; diff --git a/mindspore/python/mindspore/numpy/utils_const.py b/mindspore/python/mindspore/numpy/utils_const.py index ea5d508096d..1ccf2caf157 100644 --- a/mindspore/python/mindspore/numpy/utils_const.py +++ b/mindspore/python/mindspore/numpy/utils_const.py @@ -77,7 +77,7 @@ def _check_dtype(dtype): @constexpr def _is_shape_empty(shp): """Check whether shape contains zero""" - if shp is None: + if F.is_sequence_shape_unknown(shp): return False if isinstance(shp, int): return shp == 0 diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py index d0d45b9796c..ac5d7c97cf2 100644 --- a/mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/python/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -432,7 +432,9 @@ def tensor_index_by_list(data, list_index): if all(isinstance(i, bool) for i in list_index): const_utils.raise_unimplemented_error( "Not supported to the dynamic shape tensor slice by using list of Boolean type") - tensor_index = const_utils.sequence_to_index(list_index, data_shape[0]) + tensor_index = const_utils.sequence_to_index(list_index, None) + else: + tensor_index = const_utils.sequence_to_index(list_index, data_shape[0]) if tensor_index is False: const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.") return F.gather(data, tensor_index, 0) @@ -1109,6 +1111,8 @@ def format_list_indices(list_indices, length): # If eyery element in list is bool, it's treated as 1-D bool tensor. # If every element in list is int(not all bool), it's treated as int tensor. if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)): + if not F.isconstant(length): + return const_utils.sequence_to_index(list_indices, None) return const_utils.sequence_to_index(list_indices, length) # If list contains other types(.../list/tuple/None), it's treated as a tuple return const_utils.deep_tuple(list_indices) diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index ef36d389c51..484b23333ff 100644 --- a/mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/python/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -25,7 +25,6 @@ from mindspore.common._register_for_tensor import tensor_operator_registry from mindspore.common.tensor import Tensor from mindspore.ops import operations as P from mindspore.ops.primitive import constexpr - from mindspore import log as logger ALL_TENSOR = 0 @@ -166,6 +165,7 @@ def _deep_tensor_to_nparray(array_like): #TODO: remove comment +#@constexpr(run_graph=False) @constexpr def check_range(x, dim_size): if dim_size is None: @@ -634,12 +634,6 @@ def _judge_order_continuous(order_sequence): @constexpr def scalar_in_sequence(x, y): """Determine whether the scalar in the sequence.""" - if x is None: - raise ValueError("Judge scalar in tuple or list require scalar and sequence must be constant, " - "but the scalar is not.") - if y is None: - raise ValueError("Judge scalar in tuple or list require scalar and sequence must be constant, " - "but the sequence is not.") return x in y diff --git a/mindspore/python/mindspore/ops/function/sparse_func.py b/mindspore/python/mindspore/ops/function/sparse_func.py index a663c7aac88..cf4a886004c 100644 --- a/mindspore/python/mindspore/ops/function/sparse_func.py +++ b/mindspore/python/mindspore/ops/function/sparse_func.py @@ -32,6 +32,7 @@ from mindspore.common import dtype as mstype from mindspore.ops.primitive import constexpr, Primitive from mindspore.ops.operations.array_ops import GatherNd, Coalesce from mindspore.ops.operations import _csr_ops +from mindspore.ops import functional as F from mindspore.common import CSRTensor, COOTensor, Tensor from mindspore.ops.composite.multitype_ops._constexpr_utils import raise_value_error, raise_type_error, make_tensor,\ promote_binary_dtype @@ -63,12 +64,11 @@ def _make_tensor_with_dtype(data, dtype): return Tensor(data, dtype=dtype) -@constexpr def _convert_shape(shape): """Temporary solution to get shape value, will be removed when shape op is supported.""" - if shape is None: + if F.is_sequence_shape_unknown(shape): return (-2,) - shape = [-1 if i is None else i for i in shape] + shape = [-1 if not F.isconstant(i) else i for i in shape] return tuple(shape) diff --git a/mindspore/python/mindspore/ops/primitive.py b/mindspore/python/mindspore/ops/primitive.py index 81072bb6e9a..b18274ccd77 100644 --- a/mindspore/python/mindspore/ops/primitive.py +++ b/mindspore/python/mindspore/ops/primitive.py @@ -777,6 +777,7 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr PrimitiveWithInfer.__init__(self, op_name) self.set_const_prim(True) self.fn = fn + self.add_prim_attr('constexpr_prim', True) if not reuse_result: self.add_prim_attr('forbid_reuse_result', True) @@ -786,6 +787,7 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr if (item["dtype"] is not None and item["value"] is None and check): logger.warning("The \"" + self.name + "\" is a constexpr function." \ " The input arguments must be all constant value.") + return {'dtype': None, 'shape': None, 'value': None, 'fn': (fn,)} value_args.append(item["value"]) return {'dtype': None, 'shape': None, 'value': fn(*value_args)} diff --git a/mindspore/python/mindspore/scipy/utils_const.py b/mindspore/python/mindspore/scipy/utils_const.py index c974f5b1366..d29604f1068 100644 --- a/mindspore/python/mindspore/scipy/utils_const.py +++ b/mindspore/python/mindspore/scipy/utils_const.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from types import FunctionType from collections.abc import Iterable +from mindspore.ops import functional as F from .. import context from ..ops.primitive import constexpr from ..common import Tensor, CSRTensor @@ -29,13 +30,17 @@ def _callable_const(x): @constexpr +def is_pynative(): + return context.get_context("mode") == context.PYNATIVE_MODE + + def is_within_graph(x): """ Returns true if x is None. It's aim to check whether the call is within MindSpore graph. Because in graph mode, x should be None in constexpr when x is a variable of MindSpore. Note that always return true if the call is in pynative mode. """ - return context.get_context("mode") == context.PYNATIVE_MODE or x is None + return is_pynative() or not F.isconstant(x) or x is None @constexpr diff --git a/tests/ut/python/dynamic_sequence/test_mutable_variable_length.py b/tests/ut/python/dynamic_sequence/test_mutable_variable_length.py index c0d30ecf070..0d436f88abb 100644 --- a/tests/ut/python/dynamic_sequence/test_mutable_variable_length.py +++ b/tests/ut/python/dynamic_sequence/test_mutable_variable_length.py @@ -19,7 +19,6 @@ from mindspore.ops import functional as F from mindspore import Tensor from mindspore import jit from mindspore import context -from mindspore.ops.primitive import constexpr def test_generate_mutable_sequence_with_dynamic_length_with_jit(): @@ -137,33 +136,6 @@ def test_dynamic_length_sequence_length_sequence_value_shape_unknown_2(): assert not ret2 -def test_dynamic_length_sequence_length_sequence_with_constexpr(): - """ - Feature: Mutable with dynamic length. - Description: Dynamic length sequence should be convert to None with passing to constexpr. - Expectation: No exception. - """ - context.set_context(mode=context.GRAPH_MODE) - - @constexpr - def test(x): - return x - - @jit - def foo(x): - return test(x) - - x = mutable([Tensor([1]), Tensor([2])], True) - ret = foo(x) - assert ret is None - - y = mutable([Tensor([1]), Tensor([2])], False) - ret = foo(y) - assert len(ret) == 2 - assert ret[0] is None - assert ret[1] is None - - def test_dynamic_length_sequence_getitem(): """ Feature: Mutable with dynamic length.