enable constexpr to run as graph

liangzhibo 2022-12-15 20:34:48 +08:00
parent 942a488d58
commit da060da2b9
12 changed files with 119 additions and 58 deletions
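In effect, a @constexpr function whose inputs turn out to be graph variables is no longer constant-folded (previously such inputs arrived as None); instead, its original Python source is parsed into a FuncGraph and evaluated on those inputs. A minimal sketch of the intended behavior, assuming the mutable/jit APIs as exercised by the test deleted at the bottom of this commit:

from mindspore import Tensor, jit, mutable
from mindspore.ops.primitive import constexpr

@constexpr
def first_item(seq):
    return seq[0]

@jit
def foo(seq):
    # Before this commit: seq arrives as None inside the constexpr when it is
    # a variable. After: first_item is parsed into a FuncGraph and run on the
    # variable input instead of being folded at compile time.
    return first_item(seq)

out = foo(mutable([Tensor([1]), Tensor([2])], True))  # dynamic-length sequence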


@@ -517,7 +517,7 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                           const AnfNodeConfigPtr &out_conf) {
  if (args_conf_list.empty() && identifier_ != "MakeTupleEvaluator" && identifier_ != "MakeListEvaluator" &&
      identifier_ != "RaiseEvaluator") {
      identifier_ != "RaiseEvaluator" && identifier_ != "ConstexprEvaluator") {
    MS_LOG(EXCEPTION) << "Size should be greater than 0 while running " << identifier_;
  }
  AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);


@@ -1737,6 +1737,68 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
}
} // namespace
EvalResultPtr ConstexprEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                                           const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
  // Assume every primitive that implements a Python infer() actually uses the tuple/list arguments.
  CheckSequenceArgumentForPythonPrimitive(prim_py_, args_spec_list);
  MS_EXCEPTION_IF_NULL(prim_py_);
  auto py_args = PreparePyInputs(prim_py_, args_spec_list);
  prim_py_->BeginRecordAddAttr();
  py::dict output = prim_py_->RunInfer(py_args);
  prim_py_->EndRecordAddAttr();
  if (output.contains("fn")) {
    // The inputs contain variables, so the constexpr will run as a graph.
    py::tuple values = output["fn"];
    if (values.empty()) {
      MS_LOG(EXCEPTION) << "Cannot get the original function from constexpr.";
    }
    auto inner_val = parse::ParsePythonCode(values[0]);
    MS_EXCEPTION_IF_NULL(inner_val);
    auto inner_fg = dyn_cast<FuncGraph>(inner_val);
    MS_EXCEPTION_IF_NULL(inner_fg);
    auto mng = Manage(inner_fg, false);
    inner_fg->set_manager(mng);
    MS_EXCEPTION_IF_NULL(out_conf);
    auto out_node = out_conf->node();
    MS_EXCEPTION_IF_NULL(out_node);
    auto out_cnode = dyn_cast<CNode>(out_node);
    MS_EXCEPTION_IF_NULL(out_cnode);
    FuncGraphPtr func_graph = out_node->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    std::vector<AnfNodePtr> new_cnode_inputs = {NewValueNode(inner_fg)};
    const auto &out_cnode_inputs = out_cnode->inputs();
    (void)std::copy(out_cnode_inputs.begin() + 1, out_cnode_inputs.end(), std::back_inserter(new_cnode_inputs));
    auto new_node = func_graph->NewCNodeInOrder(new_cnode_inputs);
    func_graph->ReplaceInOrder(out_node, new_node);
    AnalysisEnginePtr eng = out_conf->engine();
    MS_EXCEPTION_IF_NULL(eng);
    AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
    return eng->ForwardConfig(out_conf, fn_conf);
  }
  // If all inputs are constant values, use the Python prim evaluator.
  // Ensure the input arguments have been evaluated.
  auto ret_abstract = EvalUndeterminedArgs(args_spec_list);
  if (ret_abstract != nullptr) {
    MS_LOG(DEBUG) << "ConstexprEvaluator eval Undetermined";
    return ret_abstract;
  }
  auto forbid_reuse = prim_py_->HasAttr(GRAPH_FLAG_FORBID_REUSE_RESULT);
  if (!forbid_reuse) {
    // Try to get the infer result from the evaluator cache.
    EvalResultPtr eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
    if (eval_result != nullptr) {
      return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
    }
  }
  const auto &added_attrs = prim_py_->evaluate_added_attrs();
  MS_LOG(DEBUG) << "Output type is " << py::str(output);
  auto res_abs = PyInferRes2Abstract(prim_py_, output);
  MS_LOG(DEBUG) << "Python infer result abstract: " << res_abs->ToString();
  EvalResultPtr eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
  evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
  return eval_result;
}
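For context, the "fn" key consumed above is produced by the Python-side infer: when it sees a non-constant argument it returns the original function instead of a folded value (see the primitive.py hunk later in this commit). A condensed sketch of that contract, with names borrowed from that hunk and the surrounding class omitted:

def infer_value(self, *args_info):
    value_args = []
    for item in args_info:
        if item["dtype"] is not None and item["value"] is None:
            # Variable input: hand the original Python function back so that
            # ConstexprEvaluator can parse it into a FuncGraph.
            return {'dtype': None, 'shape': None, 'value': None, 'fn': (self.fn,)}
        value_args.append(item["value"])
    # All inputs constant: fold the call at compile time, as before.
    return {'dtype': None, 'shape': None, 'value': self.fn(*value_args)}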
EvalResultPtr MakeTupleEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
                                           const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
  std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();


@@ -160,6 +160,19 @@ class MixedPrecisionCastEvaluator final : public Evaluator {
  PrimitivePtr prim_;
};
class ConstexprEvaluator : public TransitionPrimEvaluator {
 public:
  explicit ConstexprEvaluator(const PrimitivePyPtr primitive)
      : TransitionPrimEvaluator("ConstexprEvaluator"), prim_py_(primitive) {}
  ~ConstexprEvaluator() override = default;
  MS_DECLARE_PARENT(ConstexprEvaluator, TransitionPrimEvaluator)
  EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
                         const AnfNodeConfigPtr &out_conf) override;

 private:
  PrimitivePyPtr prim_py_;
};
class MakeTupleEvaluator : public TransitionPrimEvaluator {
 public:
  MakeTupleEvaluator() : TransitionPrimEvaluator("MakeTupleEvaluator") {}


@@ -503,6 +503,29 @@ void AnalysisEngine::Clear() {
  root_context_ = nullptr;
}
EvaluatorPtr GetPyEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
  auto prim_py = dyn_cast<PrimitivePy>(prim);
  if (prim_py != nullptr) {
    auto is_constexpr = prim_py->HasAttr(GRAPH_FLAG_CONSTEXPR_PRIM);
    if (is_constexpr) {
      return std::make_shared<ConstexprEvaluator>(prim_py);
    }
    if (engine == nullptr) {
      return std::make_shared<PythonPrimEvaluator>(prim_py);
    }
    const auto &iter = engine->prim_py_evaluators_.find(prim_py);
    if (iter != engine->prim_py_evaluators_.end()) {
      return iter->second;
    }
    auto evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
    engine->prim_py_evaluators_[prim_py] = evaluator;
    return evaluator;
  }
  MS_LOG(ERROR) << "A primitive with a Python evaluator should be a Python primitive.";
  return nullptr;
}
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
  // Custom primitive with Python infer_shape and infer_type.
  MS_EXCEPTION_IF_NULL(prim);
@@ -533,22 +556,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
  // Use the Python infer function if the infer function is not found in the map; return a Python evaluator.
  EvaluatorPtr evaluator = nullptr;
  if (prim->HasPyEvaluator()) {
    auto prim_py = dyn_cast<PrimitivePy>(prim);
    if (prim_py != nullptr) {
      if (engine == nullptr) {
        return std::make_shared<PythonPrimEvaluator>(prim_py);
      }
      const auto &iter = engine->prim_py_evaluators_.find(prim_py);
      if (iter != engine->prim_py_evaluators_.end()) {
        return iter->second;
      }
      evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
      engine->prim_py_evaluators_[prim_py] = evaluator;
      return evaluator;
    }
    MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive.";
    return nullptr;
    return GetPyEvaluator(prim, engine);
  }
  // Return a default evaluator


@@ -30,6 +30,7 @@ inline const char GRAPH_FLAG_SIDE_EFFECT_PROPAGATE[] = "side_effect_propagate";
inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP[] = "side_effect_backprop";
inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP_MEM[] = "side_effect_backprop_mem";
inline const char GRAPH_FLAG_FORBID_REUSE_RESULT[] = "forbid_reuse_result";
inline const char GRAPH_FLAG_CONSTEXPR_PRIM[] = "constexpr_prim";
inline const char GRAPH_FLAG_IS_WHILE_HEADER[] = "is_while_header";
inline const char GRAPH_FLAG_ORDER_ENFORCE_SKIP[] = "order_enforce_skip";
inline const char GRAPH_FLAG_BPROP_RETURN_SPARSE[] = "bprop_return_sparse";


@@ -77,7 +77,7 @@ def _check_dtype(dtype):
@constexpr
def _is_shape_empty(shp):
    """Check whether shape contains zero"""
    if shp is None:
    if F.is_sequence_shape_unknown(shp):
        return False
    if isinstance(shp, int):
        return shp == 0


@@ -432,6 +432,8 @@ def tensor_index_by_list(data, list_index):
        if all(isinstance(i, bool) for i in list_index):
            const_utils.raise_unimplemented_error(
                "Slicing a dynamic shape tensor with a list of booleans is not supported")
        tensor_index = const_utils.sequence_to_index(list_index, None)
    else:
        tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
    if tensor_index is False:
        const_utils.raise_index_error("When a tensor is indexed by a list, the list can't be empty.")
@@ -1109,6 +1111,8 @@ def format_list_indices(list_indices, length):
    # If every element in the list is bool, it's treated as a 1-D bool tensor.
    # If every element in the list is int (not all bool), it's treated as an int tensor.
    if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)):
        if not F.isconstant(length):
            return const_utils.sequence_to_index(list_indices, None)
        return const_utils.sequence_to_index(list_indices, length)
    # If the list contains other types (.../list/tuple/None), it's treated as a tuple.
    return const_utils.deep_tuple(list_indices)
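A plain-Python sketch of the dispatch those comments describe (illustrative only; the real helpers live in const_utils and build tensor indices rather than strings):

def classify_list_index(indices):
    # Order matters: bool is a subclass of int in Python, so the
    # all-bool case must be checked first.
    if all(isinstance(i, bool) for i in indices):
        return "bool tensor"  # treated as a 1-D bool tensor index
    if all(isinstance(i, int) for i in indices):
        return "int tensor"   # treated as an int tensor index
    return "tuple"            # other element types fall back to tuple handling

assert classify_list_index([True, False]) == "bool tensor"
assert classify_list_index([0, 2, 1]) == "int tensor"
assert classify_list_index([0, (1, 2)]) == "tuple"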


@@ -25,7 +25,6 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops.primitive import constexpr
from mindspore import log as logger
ALL_TENSOR = 0
@@ -166,6 +165,7 @@ def _deep_tensor_to_nparray(array_like):
#TODO: remove comment
#@constexpr(run_graph=False)
@constexpr
def check_range(x, dim_size):
    if dim_size is None:
@@ -634,12 +634,6 @@ def _judge_order_continuous(order_sequence):
@constexpr
def scalar_in_sequence(x, y):
    """Determine whether the scalar is in the sequence."""
    if x is None:
        raise ValueError("Judge scalar in tuple or list require scalar and sequence must be constant, "
                         "but the scalar is not.")
    if y is None:
        raise ValueError("Judge scalar in tuple or list require scalar and sequence must be constant, "
                         "but the sequence is not.")
    return x in y


@@ -32,6 +32,7 @@ from mindspore.common import dtype as mstype
from mindspore.ops.primitive import constexpr, Primitive
from mindspore.ops.operations.array_ops import GatherNd, Coalesce
from mindspore.ops.operations import _csr_ops
from mindspore.ops import functional as F
from mindspore.common import CSRTensor, COOTensor, Tensor
from mindspore.ops.composite.multitype_ops._constexpr_utils import raise_value_error, raise_type_error, make_tensor,\
    promote_binary_dtype
@@ -63,12 +64,11 @@ def _make_tensor_with_dtype(data, dtype):
    return Tensor(data, dtype=dtype)
@constexpr
def _convert_shape(shape):
    """Temporary solution to get the shape value; will be removed when the shape op is supported."""
    if shape is None:
    if F.is_sequence_shape_unknown(shape):
        return (-2,)
    shape = [-1 if i is None else i for i in shape]
    shape = [-1 if not F.isconstant(i) else i for i in shape]
    return tuple(shape)
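For illustration, a plain-Python analogue of the encoding above, with F.is_sequence_shape_unknown and F.isconstant stubbed out as ordinary predicates (hypothetical stand-ins): an unknown-rank shape maps to (-2,) and each non-constant dimension to -1.

def convert_shape_py(shape, is_unknown, is_const):
    # Plain-Python analogue of _convert_shape, for illustration only.
    if is_unknown(shape):
        return (-2,)  # the rank itself is unknown
    return tuple(-1 if not is_const(d) else d for d in shape)

assert convert_shape_py(None, lambda s: s is None, lambda d: True) == (-2,)
assert convert_shape_py((2, None, 4), lambda s: s is None, lambda d: d is not None) == (2, -1, 4)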


@@ -777,6 +777,7 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
        PrimitiveWithInfer.__init__(self, op_name)
        self.set_const_prim(True)
        self.fn = fn
        self.add_prim_attr('constexpr_prim', True)
        if not reuse_result:
            self.add_prim_attr('forbid_reuse_result', True)
@@ -786,6 +787,7 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
            if (item["dtype"] is not None and item["value"] is None and check):
                logger.warning("The \"" + self.name + "\" is a constexpr function." \
                               " The input arguments must be all constant value.")
                return {'dtype': None, 'shape': None, 'value': None, 'fn': (fn,)}
            value_args.append(item["value"])
        return {'dtype': None, 'shape': None, 'value': fn(*value_args)}


@@ -16,6 +16,7 @@
from __future__ import absolute_import
from types import FunctionType
from collections.abc import Iterable
from mindspore.ops import functional as F
from .. import context
from ..ops.primitive import constexpr
from ..common import Tensor, CSRTensor
@@ -29,13 +30,17 @@ def _callable_const(x):
@constexpr
def is_pynative():
    return context.get_context("mode") == context.PYNATIVE_MODE
def is_within_graph(x):
    """
    Return True if x is None or not constant, which indicates the call is within a MindSpore graph.
    In graph mode, x is None (or not constant) inside a constexpr when x is a MindSpore variable.
    Note that this always returns True in PyNative mode.
    """
    return context.get_context("mode") == context.PYNATIVE_MODE or x is None
    return is_pynative() or not F.isconstant(x) or x is None
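A hedged usage sketch: helpers like is_within_graph let constexpr code bail out when an argument is a graph variable instead of raising at compile time (check_axis and its validation logic are hypothetical):

@constexpr
def check_axis(axis, rank):
    if is_within_graph(axis):
        # axis is a graph variable (or we are in PyNative mode):
        # defer validation to runtime instead of failing here.
        return None
    if axis < -rank or axis >= rank:
        raise ValueError("axis should be in [-rank, rank), but got " + str(axis))
    return axis % rank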
@constexpr


@@ -19,7 +19,6 @@ from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore import jit
from mindspore import context
from mindspore.ops.primitive import constexpr
def test_generate_mutable_sequence_with_dynamic_length_with_jit():
@@ -137,33 +136,6 @@ def test_dynamic_length_sequence_length_sequence_value_shape_unknown_2():
    assert not ret2
def test_dynamic_length_sequence_length_sequence_with_constexpr():
    """
    Feature: Mutable with dynamic length.
    Description: Dynamic length sequences should be converted to None when passed to constexpr.
    Expectation: No exception.
    """
    context.set_context(mode=context.GRAPH_MODE)
    @constexpr
    def test(x):
        return x
    @jit
    def foo(x):
        return test(x)
    x = mutable([Tensor([1]), Tensor([2])], True)
    ret = foo(x)
    assert ret is None
    y = mutable([Tensor([1]), Tensor([2])], False)
    ret = foo(y)
    assert len(ret) == 2
    assert ret[0] is None
    assert ret[1] is None
def test_dynamic_length_sequence_getitem():
    """
    Feature: Mutable with dynamic length.