enable constexpr to run as graph

This commit is contained in:
liangzhibo 2022-12-15 20:34:48 +08:00
parent 942a488d58
commit da060da2b9
12 changed files with 119 additions and 58 deletions

View File

@ -517,7 +517,7 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) { const AnfNodeConfigPtr &out_conf) {
if (args_conf_list.empty() && identifier_ != "MakeTupleEvaluator" && identifier_ != "MakeListEvaluator" && if (args_conf_list.empty() && identifier_ != "MakeTupleEvaluator" && identifier_ != "MakeListEvaluator" &&
identifier_ != "RaiseEvaluator") { identifier_ != "RaiseEvaluator" && identifier_ != "ConstexprEvaluator") {
MS_LOG(EXCEPTION) << "Size should be greater than 0, during running " << identifier_; MS_LOG(EXCEPTION) << "Size should be greater than 0, during running " << identifier_;
} }
AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list); AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);

View File

@ -1737,6 +1737,68 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
} }
} // namespace } // namespace
EvalResultPtr ConstexprEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
// Consider all primitive implemented python infer() real use the tuple/list arguments.
CheckSequenceArgumentForPythonPrimitive(prim_py_, args_spec_list);
MS_EXCEPTION_IF_NULL(prim_py_);
auto py_args = PreparePyInputs(prim_py_, args_spec_list);
prim_py_->BeginRecordAddAttr();
py::dict output = prim_py_->RunInfer(py_args);
prim_py_->EndRecordAddAttr();
if (output.contains("fn")) {
// The inputs contain variable, the constexpr will run as graph.
py::tuple values = output["fn"];
if (values.empty()) {
MS_LOG(EXCEPTION) << "Can not get origin function from constexpr.";
}
auto inner_val = parse::ParsePythonCode(values[0]);
MS_EXCEPTION_IF_NULL(inner_val);
auto inner_fg = dyn_cast<FuncGraph>(inner_val);
MS_EXCEPTION_IF_NULL(inner_fg);
auto mng = Manage(inner_fg, false);
inner_fg->set_manager(mng);
MS_EXCEPTION_IF_NULL(out_conf);
auto out_node = out_conf->node();
MS_EXCEPTION_IF_NULL(out_node);
auto out_cnode = dyn_cast<CNode>(out_node);
MS_EXCEPTION_IF_NULL(out_cnode);
FuncGraphPtr func_graph = out_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> new_cnode_inputs = {NewValueNode(inner_fg)};
const auto &out_cnode_inputs = out_cnode->inputs();
(void)std::copy(out_cnode_inputs.begin() + 1, out_cnode_inputs.end(), std::back_inserter(new_cnode_inputs));
auto new_node = func_graph->NewCNodeInOrder(new_cnode_inputs);
func_graph->ReplaceInOrder(out_node, new_node);
AnalysisEnginePtr eng = out_conf->engine();
MS_EXCEPTION_IF_NULL(eng);
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
return eng->ForwardConfig(out_conf, fn_conf);
}
// If all inputs are constant value, use python prim evaluator.
// Ensure input arguments are evaluated.
auto ret_abstract = EvalUndeterminedArgs(args_spec_list);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
return ret_abstract;
}
auto forbid_reuse = prim_py_->HasAttr(GRAPH_FLAG_FORBID_REUSE_RESULT);
if (!forbid_reuse) {
// Try to get infer result from evaluator cache.
EvalResultPtr eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
if (eval_result != nullptr) {
return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
}
}
const auto &added_attrs = prim_py_->evaluate_added_attrs();
MS_LOG(DEBUG) << "Output type is " << py::str(output);
auto res_abs = PyInferRes2Abstract(prim_py_, output);
MS_LOG(DEBUG) << "Python InferTensor result abstract: " << res_abs->ToString();
EvalResultPtr eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
return eval_result;
}
EvalResultPtr MakeTupleEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, EvalResultPtr MakeTupleEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
const ConfigPtr &, const AnfNodeConfigPtr &out_conf) { const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>(); std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();

View File

@ -160,6 +160,19 @@ class MixedPrecisionCastEvaluator final : public Evaluator {
PrimitivePtr prim_; PrimitivePtr prim_;
}; };
class ConstexprEvaluator : public TransitionPrimEvaluator {
public:
explicit ConstexprEvaluator(const PrimitivePyPtr primitive)
: TransitionPrimEvaluator("ConstexprEvaluator"), prim_py_(primitive) {}
~ConstexprEvaluator() override = default;
MS_DECLARE_PARENT(ConstexprEvaluator, TransitionPrimEvaluator)
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override;
private:
PrimitivePyPtr prim_py_;
};
class MakeTupleEvaluator : public TransitionPrimEvaluator { class MakeTupleEvaluator : public TransitionPrimEvaluator {
public: public:
MakeTupleEvaluator() : TransitionPrimEvaluator("MakeTupleEvaluator") {} MakeTupleEvaluator() : TransitionPrimEvaluator("MakeTupleEvaluator") {}

View File

@ -503,6 +503,29 @@ void AnalysisEngine::Clear() {
root_context_ = nullptr; root_context_ = nullptr;
} }
EvaluatorPtr GetPyEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
auto prim_py = dyn_cast<PrimitivePy>(prim);
if (prim_py != nullptr) {
auto is_constexpr = prim_py->HasAttr(GRAPH_FLAG_CONSTEXPR_PRIM);
if (is_constexpr) {
return std::make_shared<ConstexprEvaluator>(prim_py);
}
if (engine == nullptr) {
return std::make_shared<PythonPrimEvaluator>(prim_py);
}
const auto &iter = engine->prim_py_evaluators_.find(prim_py);
if (iter != engine->prim_py_evaluators_.end()) {
return iter->second;
}
auto evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
engine->prim_py_evaluators_[prim_py] = evaluator;
return evaluator;
}
MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive.";
return nullptr;
}
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) { EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
// Custom Primitive with python infer_shape, infer_type // Custom Primitive with python infer_shape, infer_type
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
@ -533,22 +556,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
// Use python infer function if the infer function not founded in the map return a python evaluator // Use python infer function if the infer function not founded in the map return a python evaluator
EvaluatorPtr evaluator = nullptr; EvaluatorPtr evaluator = nullptr;
if (prim->HasPyEvaluator()) { if (prim->HasPyEvaluator()) {
auto prim_py = dyn_cast<PrimitivePy>(prim); return GetPyEvaluator(prim, engine);
if (prim_py != nullptr) {
if (engine == nullptr) {
return std::make_shared<PythonPrimEvaluator>(prim_py);
}
const auto &iter = engine->prim_py_evaluators_.find(prim_py);
if (iter != engine->prim_py_evaluators_.end()) {
return iter->second;
}
evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
engine->prim_py_evaluators_[prim_py] = evaluator;
return evaluator;
}
MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive.";
return nullptr;
} }
// Return a default evaluator // Return a default evaluator

View File

@ -30,6 +30,7 @@ inline const char GRAPH_FLAG_SIDE_EFFECT_PROPAGATE[] = "side_effect_propagate";
inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP[] = "side_effect_backprop"; inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP[] = "side_effect_backprop";
inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP_MEM[] = "side_effect_backprop_mem"; inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP_MEM[] = "side_effect_backprop_mem";
inline const char GRAPH_FLAG_FORBID_REUSE_RESULT[] = "forbid_reuse_result"; inline const char GRAPH_FLAG_FORBID_REUSE_RESULT[] = "forbid_reuse_result";
inline const char GRAPH_FLAG_CONSTEXPR_PRIM[] = "constexpr_prim";
inline const char GRAPH_FLAG_IS_WHILE_HEADER[] = "is_while_header"; inline const char GRAPH_FLAG_IS_WHILE_HEADER[] = "is_while_header";
inline const char GRAPH_FLAG_ORDER_ENFORCE_SKIP[] = "order_enforce_skip"; inline const char GRAPH_FLAG_ORDER_ENFORCE_SKIP[] = "order_enforce_skip";
inline const char GRAPH_FLAG_BPROP_RETURN_SPARSE[] = "bprop_return_sparse"; inline const char GRAPH_FLAG_BPROP_RETURN_SPARSE[] = "bprop_return_sparse";

View File

@ -77,7 +77,7 @@ def _check_dtype(dtype):
@constexpr @constexpr
def _is_shape_empty(shp): def _is_shape_empty(shp):
"""Check whether shape contains zero""" """Check whether shape contains zero"""
if shp is None: if F.is_sequence_shape_unknown(shp):
return False return False
if isinstance(shp, int): if isinstance(shp, int):
return shp == 0 return shp == 0

View File

@ -432,7 +432,9 @@ def tensor_index_by_list(data, list_index):
if all(isinstance(i, bool) for i in list_index): if all(isinstance(i, bool) for i in list_index):
const_utils.raise_unimplemented_error( const_utils.raise_unimplemented_error(
"Not supported to the dynamic shape tensor slice by using list of Boolean type") "Not supported to the dynamic shape tensor slice by using list of Boolean type")
tensor_index = const_utils.sequence_to_index(list_index, data_shape[0]) tensor_index = const_utils.sequence_to_index(list_index, None)
else:
tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
if tensor_index is False: if tensor_index is False:
const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.") const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")
return F.gather(data, tensor_index, 0) return F.gather(data, tensor_index, 0)
@ -1109,6 +1111,8 @@ def format_list_indices(list_indices, length):
# If eyery element in list is bool, it's treated as 1-D bool tensor. # If eyery element in list is bool, it's treated as 1-D bool tensor.
# If every element in list is int(not all bool), it's treated as int tensor. # If every element in list is int(not all bool), it's treated as int tensor.
if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)): if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)):
if not F.isconstant(length):
return const_utils.sequence_to_index(list_indices, None)
return const_utils.sequence_to_index(list_indices, length) return const_utils.sequence_to_index(list_indices, length)
# If list contains other types(.../list/tuple/None), it's treated as a tuple # If list contains other types(.../list/tuple/None), it's treated as a tuple
return const_utils.deep_tuple(list_indices) return const_utils.deep_tuple(list_indices)

View File

@ -25,7 +25,6 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr
from mindspore import log as logger from mindspore import log as logger
ALL_TENSOR = 0 ALL_TENSOR = 0
@ -166,6 +165,7 @@ def _deep_tensor_to_nparray(array_like):
#TODO: remove comment #TODO: remove comment
#@constexpr(run_graph=False)
@constexpr @constexpr
def check_range(x, dim_size): def check_range(x, dim_size):
if dim_size is None: if dim_size is None:
@ -634,12 +634,6 @@ def _judge_order_continuous(order_sequence):
@constexpr @constexpr
def scalar_in_sequence(x, y): def scalar_in_sequence(x, y):
"""Determine whether the scalar in the sequence.""" """Determine whether the scalar in the sequence."""
if x is None:
raise ValueError("Judge scalar in tuple or list require scalar and sequence must be constant, "
"but the scalar is not.")
if y is None:
raise ValueError("Judge scalar in tuple or list require scalar and sequence must be constant, "
"but the sequence is not.")
return x in y return x in y

View File

@ -32,6 +32,7 @@ from mindspore.common import dtype as mstype
from mindspore.ops.primitive import constexpr, Primitive from mindspore.ops.primitive import constexpr, Primitive
from mindspore.ops.operations.array_ops import GatherNd, Coalesce from mindspore.ops.operations.array_ops import GatherNd, Coalesce
from mindspore.ops.operations import _csr_ops from mindspore.ops.operations import _csr_ops
from mindspore.ops import functional as F
from mindspore.common import CSRTensor, COOTensor, Tensor from mindspore.common import CSRTensor, COOTensor, Tensor
from mindspore.ops.composite.multitype_ops._constexpr_utils import raise_value_error, raise_type_error, make_tensor,\ from mindspore.ops.composite.multitype_ops._constexpr_utils import raise_value_error, raise_type_error, make_tensor,\
promote_binary_dtype promote_binary_dtype
@ -63,12 +64,11 @@ def _make_tensor_with_dtype(data, dtype):
return Tensor(data, dtype=dtype) return Tensor(data, dtype=dtype)
@constexpr
def _convert_shape(shape): def _convert_shape(shape):
"""Temporary solution to get shape value, will be removed when shape op is supported.""" """Temporary solution to get shape value, will be removed when shape op is supported."""
if shape is None: if F.is_sequence_shape_unknown(shape):
return (-2,) return (-2,)
shape = [-1 if i is None else i for i in shape] shape = [-1 if not F.isconstant(i) else i for i in shape]
return tuple(shape) return tuple(shape)

View File

@ -777,6 +777,7 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
PrimitiveWithInfer.__init__(self, op_name) PrimitiveWithInfer.__init__(self, op_name)
self.set_const_prim(True) self.set_const_prim(True)
self.fn = fn self.fn = fn
self.add_prim_attr('constexpr_prim', True)
if not reuse_result: if not reuse_result:
self.add_prim_attr('forbid_reuse_result', True) self.add_prim_attr('forbid_reuse_result', True)
@ -786,6 +787,7 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
if (item["dtype"] is not None and item["value"] is None and check): if (item["dtype"] is not None and item["value"] is None and check):
logger.warning("The \"" + self.name + "\" is a constexpr function." \ logger.warning("The \"" + self.name + "\" is a constexpr function." \
" The input arguments must be all constant value.") " The input arguments must be all constant value.")
return {'dtype': None, 'shape': None, 'value': None, 'fn': (fn,)}
value_args.append(item["value"]) value_args.append(item["value"])
return {'dtype': None, 'shape': None, 'value': fn(*value_args)} return {'dtype': None, 'shape': None, 'value': fn(*value_args)}

View File

@ -16,6 +16,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from types import FunctionType from types import FunctionType
from collections.abc import Iterable from collections.abc import Iterable
from mindspore.ops import functional as F
from .. import context from .. import context
from ..ops.primitive import constexpr from ..ops.primitive import constexpr
from ..common import Tensor, CSRTensor from ..common import Tensor, CSRTensor
@ -29,13 +30,17 @@ def _callable_const(x):
@constexpr @constexpr
def is_pynative():
return context.get_context("mode") == context.PYNATIVE_MODE
def is_within_graph(x): def is_within_graph(x):
""" """
Returns true if x is None. It's aim to check whether the call is within MindSpore graph. Returns true if x is None. It's aim to check whether the call is within MindSpore graph.
Because in graph mode, x should be None in constexpr when x is a variable of MindSpore. Because in graph mode, x should be None in constexpr when x is a variable of MindSpore.
Note that always return true if the call is in pynative mode. Note that always return true if the call is in pynative mode.
""" """
return context.get_context("mode") == context.PYNATIVE_MODE or x is None return is_pynative() or not F.isconstant(x) or x is None
@constexpr @constexpr

View File

@ -19,7 +19,6 @@ from mindspore.ops import functional as F
from mindspore import Tensor from mindspore import Tensor
from mindspore import jit from mindspore import jit
from mindspore import context from mindspore import context
from mindspore.ops.primitive import constexpr
def test_generate_mutable_sequence_with_dynamic_length_with_jit(): def test_generate_mutable_sequence_with_dynamic_length_with_jit():
@ -137,33 +136,6 @@ def test_dynamic_length_sequence_length_sequence_value_shape_unknown_2():
assert not ret2 assert not ret2
def test_dynamic_length_sequence_length_sequence_with_constexpr():
"""
Feature: Mutable with dynamic length.
Description: Dynamic length sequence should be convert to None with passing to constexpr.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE)
@constexpr
def test(x):
return x
@jit
def foo(x):
return test(x)
x = mutable([Tensor([1]), Tensor([2])], True)
ret = foo(x)
assert ret is None
y = mutable([Tensor([1]), Tensor([2])], False)
ret = foo(y)
assert len(ret) == 2
assert ret[0] is None
assert ret[1] is None
def test_dynamic_length_sequence_getitem(): def test_dynamic_length_sequence_getitem():
""" """
Feature: Mutable with dynamic length. Feature: Mutable with dynamic length.