forked from mindspore-Ecosystem/mindspore
enable constexpr to run as graph
This commit is contained in:
parent
942a488d58
commit
da060da2b9
|
@ -517,7 +517,7 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
if (args_conf_list.empty() && identifier_ != "MakeTupleEvaluator" && identifier_ != "MakeListEvaluator" &&
|
||||
identifier_ != "RaiseEvaluator") {
|
||||
identifier_ != "RaiseEvaluator" && identifier_ != "ConstexprEvaluator") {
|
||||
MS_LOG(EXCEPTION) << "Size should be greater than 0, during running " << identifier_;
|
||||
}
|
||||
AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
|
||||
|
|
|
@ -1737,6 +1737,68 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
}
|
||||
} // namespace
|
||||
|
||||
EvalResultPtr ConstexprEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||
const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
|
||||
// Consider all primitive implemented python infer() real use the tuple/list arguments.
|
||||
CheckSequenceArgumentForPythonPrimitive(prim_py_, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(prim_py_);
|
||||
auto py_args = PreparePyInputs(prim_py_, args_spec_list);
|
||||
prim_py_->BeginRecordAddAttr();
|
||||
py::dict output = prim_py_->RunInfer(py_args);
|
||||
prim_py_->EndRecordAddAttr();
|
||||
if (output.contains("fn")) {
|
||||
// The inputs contain variable, the constexpr will run as graph.
|
||||
py::tuple values = output["fn"];
|
||||
if (values.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Can not get origin function from constexpr.";
|
||||
}
|
||||
auto inner_val = parse::ParsePythonCode(values[0]);
|
||||
MS_EXCEPTION_IF_NULL(inner_val);
|
||||
auto inner_fg = dyn_cast<FuncGraph>(inner_val);
|
||||
MS_EXCEPTION_IF_NULL(inner_fg);
|
||||
auto mng = Manage(inner_fg, false);
|
||||
inner_fg->set_manager(mng);
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
auto out_node = out_conf->node();
|
||||
MS_EXCEPTION_IF_NULL(out_node);
|
||||
auto out_cnode = dyn_cast<CNode>(out_node);
|
||||
MS_EXCEPTION_IF_NULL(out_cnode);
|
||||
FuncGraphPtr func_graph = out_node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> new_cnode_inputs = {NewValueNode(inner_fg)};
|
||||
const auto &out_cnode_inputs = out_cnode->inputs();
|
||||
(void)std::copy(out_cnode_inputs.begin() + 1, out_cnode_inputs.end(), std::back_inserter(new_cnode_inputs));
|
||||
auto new_node = func_graph->NewCNodeInOrder(new_cnode_inputs);
|
||||
func_graph->ReplaceInOrder(out_node, new_node);
|
||||
AnalysisEnginePtr eng = out_conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
|
||||
return eng->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
// If all inputs are constant value, use python prim evaluator.
|
||||
// Ensure input arguments are evaluated.
|
||||
auto ret_abstract = EvalUndeterminedArgs(args_spec_list);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
auto forbid_reuse = prim_py_->HasAttr(GRAPH_FLAG_FORBID_REUSE_RESULT);
|
||||
if (!forbid_reuse) {
|
||||
// Try to get infer result from evaluator cache.
|
||||
EvalResultPtr eval_result = evaluator_cache_mgr_->GetValue(args_spec_list);
|
||||
if (eval_result != nullptr) {
|
||||
return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
|
||||
}
|
||||
}
|
||||
const auto &added_attrs = prim_py_->evaluate_added_attrs();
|
||||
MS_LOG(DEBUG) << "Output type is " << py::str(output);
|
||||
auto res_abs = PyInferRes2Abstract(prim_py_, output);
|
||||
MS_LOG(DEBUG) << "Python InferTensor result abstract: " << res_abs->ToString();
|
||||
EvalResultPtr eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
|
||||
return eval_result;
|
||||
}
|
||||
|
||||
EvalResultPtr MakeTupleEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
|
||||
const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
|
||||
std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
|
||||
|
|
|
@ -160,6 +160,19 @@ class MixedPrecisionCastEvaluator final : public Evaluator {
|
|||
PrimitivePtr prim_;
|
||||
};
|
||||
|
||||
class ConstexprEvaluator : public TransitionPrimEvaluator {
|
||||
public:
|
||||
explicit ConstexprEvaluator(const PrimitivePyPtr primitive)
|
||||
: TransitionPrimEvaluator("ConstexprEvaluator"), prim_py_(primitive) {}
|
||||
~ConstexprEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(ConstexprEvaluator, TransitionPrimEvaluator)
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||
const AnfNodeConfigPtr &out_conf) override;
|
||||
|
||||
private:
|
||||
PrimitivePyPtr prim_py_;
|
||||
};
|
||||
|
||||
class MakeTupleEvaluator : public TransitionPrimEvaluator {
|
||||
public:
|
||||
MakeTupleEvaluator() : TransitionPrimEvaluator("MakeTupleEvaluator") {}
|
||||
|
|
|
@ -503,6 +503,29 @@ void AnalysisEngine::Clear() {
|
|||
root_context_ = nullptr;
|
||||
}
|
||||
|
||||
EvaluatorPtr GetPyEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
|
||||
auto prim_py = dyn_cast<PrimitivePy>(prim);
|
||||
if (prim_py != nullptr) {
|
||||
auto is_constexpr = prim_py->HasAttr(GRAPH_FLAG_CONSTEXPR_PRIM);
|
||||
if (is_constexpr) {
|
||||
return std::make_shared<ConstexprEvaluator>(prim_py);
|
||||
}
|
||||
if (engine == nullptr) {
|
||||
return std::make_shared<PythonPrimEvaluator>(prim_py);
|
||||
}
|
||||
|
||||
const auto &iter = engine->prim_py_evaluators_.find(prim_py);
|
||||
if (iter != engine->prim_py_evaluators_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
auto evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
|
||||
engine->prim_py_evaluators_[prim_py] = evaluator;
|
||||
return evaluator;
|
||||
}
|
||||
MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) {
|
||||
// Custom Primitive with python infer_shape, infer_type
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
|
@ -533,22 +556,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
|||
// Use python infer function if the infer function not founded in the map return a python evaluator
|
||||
EvaluatorPtr evaluator = nullptr;
|
||||
if (prim->HasPyEvaluator()) {
|
||||
auto prim_py = dyn_cast<PrimitivePy>(prim);
|
||||
if (prim_py != nullptr) {
|
||||
if (engine == nullptr) {
|
||||
return std::make_shared<PythonPrimEvaluator>(prim_py);
|
||||
}
|
||||
|
||||
const auto &iter = engine->prim_py_evaluators_.find(prim_py);
|
||||
if (iter != engine->prim_py_evaluators_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
|
||||
engine->prim_py_evaluators_[prim_py] = evaluator;
|
||||
return evaluator;
|
||||
}
|
||||
MS_LOG(ERROR) << "The primitive with python evaluator should be a python primitive.";
|
||||
return nullptr;
|
||||
return GetPyEvaluator(prim, engine);
|
||||
}
|
||||
|
||||
// Return a default evaluator
|
||||
|
|
|
@ -30,6 +30,7 @@ inline const char GRAPH_FLAG_SIDE_EFFECT_PROPAGATE[] = "side_effect_propagate";
|
|||
inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP[] = "side_effect_backprop";
|
||||
inline const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP_MEM[] = "side_effect_backprop_mem";
|
||||
inline const char GRAPH_FLAG_FORBID_REUSE_RESULT[] = "forbid_reuse_result";
|
||||
inline const char GRAPH_FLAG_CONSTEXPR_PRIM[] = "constexpr_prim";
|
||||
inline const char GRAPH_FLAG_IS_WHILE_HEADER[] = "is_while_header";
|
||||
inline const char GRAPH_FLAG_ORDER_ENFORCE_SKIP[] = "order_enforce_skip";
|
||||
inline const char GRAPH_FLAG_BPROP_RETURN_SPARSE[] = "bprop_return_sparse";
|
||||
|
|
|
@ -77,7 +77,7 @@ def _check_dtype(dtype):
|
|||
@constexpr
|
||||
def _is_shape_empty(shp):
|
||||
"""Check whether shape contains zero"""
|
||||
if shp is None:
|
||||
if F.is_sequence_shape_unknown(shp):
|
||||
return False
|
||||
if isinstance(shp, int):
|
||||
return shp == 0
|
||||
|
|
|
@ -432,7 +432,9 @@ def tensor_index_by_list(data, list_index):
|
|||
if all(isinstance(i, bool) for i in list_index):
|
||||
const_utils.raise_unimplemented_error(
|
||||
"Not supported to the dynamic shape tensor slice by using list of Boolean type")
|
||||
tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
|
||||
tensor_index = const_utils.sequence_to_index(list_index, None)
|
||||
else:
|
||||
tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
|
||||
if tensor_index is False:
|
||||
const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")
|
||||
return F.gather(data, tensor_index, 0)
|
||||
|
@ -1109,6 +1111,8 @@ def format_list_indices(list_indices, length):
|
|||
# If eyery element in list is bool, it's treated as 1-D bool tensor.
|
||||
# If every element in list is int(not all bool), it's treated as int tensor.
|
||||
if const_utils.judge_indexes_types(indices_types, mstype.int_type + (mstype.bool_,)):
|
||||
if not F.isconstant(length):
|
||||
return const_utils.sequence_to_index(list_indices, None)
|
||||
return const_utils.sequence_to_index(list_indices, length)
|
||||
# If list contains other types(.../list/tuple/None), it's treated as a tuple
|
||||
return const_utils.deep_tuple(list_indices)
|
||||
|
|
|
@ -25,7 +25,6 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
|
|||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.primitive import constexpr
|
||||
|
||||
from mindspore import log as logger
|
||||
|
||||
ALL_TENSOR = 0
|
||||
|
@ -166,6 +165,7 @@ def _deep_tensor_to_nparray(array_like):
|
|||
|
||||
|
||||
#TODO: remove comment
|
||||
#@constexpr(run_graph=False)
|
||||
@constexpr
|
||||
def check_range(x, dim_size):
|
||||
if dim_size is None:
|
||||
|
@ -634,12 +634,6 @@ def _judge_order_continuous(order_sequence):
|
|||
@constexpr
|
||||
def scalar_in_sequence(x, y):
|
||||
"""Determine whether the scalar in the sequence."""
|
||||
if x is None:
|
||||
raise ValueError("Judge scalar in tuple or list require scalar and sequence must be constant, "
|
||||
"but the scalar is not.")
|
||||
if y is None:
|
||||
raise ValueError("Judge scalar in tuple or list require scalar and sequence must be constant, "
|
||||
"but the sequence is not.")
|
||||
return x in y
|
||||
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.ops.primitive import constexpr, Primitive
|
||||
from mindspore.ops.operations.array_ops import GatherNd, Coalesce
|
||||
from mindspore.ops.operations import _csr_ops
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common import CSRTensor, COOTensor, Tensor
|
||||
from mindspore.ops.composite.multitype_ops._constexpr_utils import raise_value_error, raise_type_error, make_tensor,\
|
||||
promote_binary_dtype
|
||||
|
@ -63,12 +64,11 @@ def _make_tensor_with_dtype(data, dtype):
|
|||
return Tensor(data, dtype=dtype)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _convert_shape(shape):
|
||||
"""Temporary solution to get shape value, will be removed when shape op is supported."""
|
||||
if shape is None:
|
||||
if F.is_sequence_shape_unknown(shape):
|
||||
return (-2,)
|
||||
shape = [-1 if i is None else i for i in shape]
|
||||
shape = [-1 if not F.isconstant(i) else i for i in shape]
|
||||
return tuple(shape)
|
||||
|
||||
|
||||
|
|
|
@ -777,6 +777,7 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
|
|||
PrimitiveWithInfer.__init__(self, op_name)
|
||||
self.set_const_prim(True)
|
||||
self.fn = fn
|
||||
self.add_prim_attr('constexpr_prim', True)
|
||||
if not reuse_result:
|
||||
self.add_prim_attr('forbid_reuse_result', True)
|
||||
|
||||
|
@ -786,6 +787,7 @@ def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=Tr
|
|||
if (item["dtype"] is not None and item["value"] is None and check):
|
||||
logger.warning("The \"" + self.name + "\" is a constexpr function." \
|
||||
" The input arguments must be all constant value.")
|
||||
return {'dtype': None, 'shape': None, 'value': None, 'fn': (fn,)}
|
||||
value_args.append(item["value"])
|
||||
return {'dtype': None, 'shape': None, 'value': fn(*value_args)}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
from __future__ import absolute_import
|
||||
from types import FunctionType
|
||||
from collections.abc import Iterable
|
||||
from mindspore.ops import functional as F
|
||||
from .. import context
|
||||
from ..ops.primitive import constexpr
|
||||
from ..common import Tensor, CSRTensor
|
||||
|
@ -29,13 +30,17 @@ def _callable_const(x):
|
|||
|
||||
|
||||
@constexpr
|
||||
def is_pynative():
|
||||
return context.get_context("mode") == context.PYNATIVE_MODE
|
||||
|
||||
|
||||
def is_within_graph(x):
|
||||
"""
|
||||
Returns true if x is None. It's aim to check whether the call is within MindSpore graph.
|
||||
Because in graph mode, x should be None in constexpr when x is a variable of MindSpore.
|
||||
Note that always return true if the call is in pynative mode.
|
||||
"""
|
||||
return context.get_context("mode") == context.PYNATIVE_MODE or x is None
|
||||
return is_pynative() or not F.isconstant(x) or x is None
|
||||
|
||||
|
||||
@constexpr
|
||||
|
|
|
@ -19,7 +19,6 @@ from mindspore.ops import functional as F
|
|||
from mindspore import Tensor
|
||||
from mindspore import jit
|
||||
from mindspore import context
|
||||
from mindspore.ops.primitive import constexpr
|
||||
|
||||
|
||||
def test_generate_mutable_sequence_with_dynamic_length_with_jit():
|
||||
|
@ -137,33 +136,6 @@ def test_dynamic_length_sequence_length_sequence_value_shape_unknown_2():
|
|||
assert not ret2
|
||||
|
||||
|
||||
def test_dynamic_length_sequence_length_sequence_with_constexpr():
|
||||
"""
|
||||
Feature: Mutable with dynamic length.
|
||||
Description: Dynamic length sequence should be convert to None with passing to constexpr.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
@constexpr
|
||||
def test(x):
|
||||
return x
|
||||
|
||||
@jit
|
||||
def foo(x):
|
||||
return test(x)
|
||||
|
||||
x = mutable([Tensor([1]), Tensor([2])], True)
|
||||
ret = foo(x)
|
||||
assert ret is None
|
||||
|
||||
y = mutable([Tensor([1]), Tensor([2])], False)
|
||||
ret = foo(y)
|
||||
assert len(ret) == 2
|
||||
assert ret[0] is None
|
||||
assert ret[1] is None
|
||||
|
||||
|
||||
def test_dynamic_length_sequence_getitem():
|
||||
"""
|
||||
Feature: Mutable with dynamic length.
|
||||
|
|
Loading…
Reference in New Issue