From cf40e4f55f78735df170e82959c9cee3d9683c05 Mon Sep 17 00:00:00 2001
From: "7347157+joylvliang@user.noreply.gitee.com" <7347157+joylvliang@user.noreply.gitee.com>
Date: Wed, 17 Nov 2021 15:49:37 +0800
Subject: [PATCH] optimize_constant_folding_for_pynative

---
 .../pipeline/jit/static_analysis/prim.cc      | 90 ++++++++++++++-----
 .../ccsrc/pipeline/jit/static_analysis/prim.h |  2 +-
 .../pipeline/pynative/pynative_execute.cc     | 32 +++----
 .../pipeline/pynative/pynative_execute.h      |  2 +-
 mindspore/core/abstract/dshape.h              | 13 +++
 5 files changed, 97 insertions(+), 42 deletions(-)

diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
index cc9dd0a2a90..4d60a4c5bcc 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
@@ -304,9 +304,29 @@ py::object BuildValue(const ValuePtr &value_ptr) {
   }
 }
 
-py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
+py::object AbstractTupleValueToPython(const AbstractTuplePtr &tuple_abs) {
+  MS_EXCEPTION_IF_NULL(tuple_abs);
+  auto value = tuple_abs->BuildValue();
+  if (value->isa<AnyValue>()) {
+    return py::none();
+  }
+  const auto &elements = tuple_abs->elements();
+  size_t len = elements.size();
+  py::tuple value_tuple(len);
+  for (size_t i = 0; i < len; ++i) {
+    value_tuple[i] = ConvertAbstractToPython(elements[i], true)[ATTR_VALUE];
+  }
+  return std::move(value_tuple);
+}
+
+py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
   auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
   MS_EXCEPTION_IF_NULL(arg_tuple);
+  auto dic = py::dict();
+  if (only_convert_value) {
+    dic[ATTR_VALUE] = AbstractTupleValueToPython(arg_tuple);
+    return dic;
+  }
   size_t len = arg_tuple->size();
   py::tuple shape_tuple(len);
   py::tuple dtype_tuple(len);
@@ -319,8 +339,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
 
   bool dyn_value = false;
   for (size_t i = 0; i < len; i++) {
-    auto arg = arg_tuple->elements()[i];
-    py::dict out = ConvertAbstractToPython(arg);
+    py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
     shape_tuple[i] = out[ATTR_SHAPE];
     dtype_tuple[i] = out[ATTR_DTYPE];
     value_tuple[i] = out[ATTR_VALUE];
@@ -339,7 +358,6 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
       dyn_shape = true;
     }
   }
-  auto dic = py::dict();
   dic[ATTR_SHAPE] = shape_tuple;
   dic[ATTR_DTYPE] = dtype_tuple;
   MS_EXCEPTION_IF_NULL(arg_tuple->BuildValue());
@@ -361,9 +379,29 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
   return dic;
 }
 
-py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
+py::object AbstractListValueToPython(const AbstractListPtr &list_abs) {
+  MS_EXCEPTION_IF_NULL(list_abs);
+  auto value = list_abs->BuildValue();
+  if (value->isa<AnyValue>()) {
+    return py::none();
+  }
+  const auto &elements = list_abs->elements();
+  size_t len = elements.size();
+  py::list value_list(len);
+  for (size_t i = 0; i < len; ++i) {
+    value_list[i] = ConvertAbstractToPython(elements[i], true)[ATTR_VALUE];
+  }
+  return std::move(value_list);
+}
+
+py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
   auto arg_list = dyn_cast<AbstractList>(abs_base);
   MS_EXCEPTION_IF_NULL(arg_list);
+  auto dic = py::dict();
+  if (only_convert_value) {
+    dic[ATTR_VALUE] = AbstractListValueToPython(arg_list);
+    return dic;
+  }
   size_t len = arg_list->size();
   py::list shape_list(len);
   py::list dtype_list(len);
@@ -385,7 +423,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
       dyn_shape = true;
     }
   }
-  auto dic = py::dict();
+
   dic[ATTR_SHAPE] = shape_list;
   dic[ATTR_DTYPE] = dtype_list;
   MS_EXCEPTION_IF_NULL(arg_list->BuildValue());
@@ -403,10 +441,14 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
   return dic;
 }
 
-void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
+void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, bool only_convert_value, py::dict *dic) {
   auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
   MS_EXCEPTION_IF_NULL(dic);
   MS_EXCEPTION_IF_NULL(arg_tensor);
+  if (only_convert_value) {
+    (*dic)[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
+    return;
+  }
   MS_EXCEPTION_IF_NULL(arg_tensor->shape());
   (*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
   const auto &min_shape = arg_tensor->shape()->min_shape();
@@ -477,11 +519,26 @@ TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
   }
 }
 }  // end anonymous namespace
 
-py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
+py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
   MS_EXCEPTION_IF_NULL(abs_base);
   auto dic = py::dict();
   if (abs_base->isa<AbstractTensor>()) {
-    ConvertAbstractTensorToPython(abs_base, &dic);
+    ConvertAbstractTensorToPython(abs_base, only_convert_value, &dic);
+  } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
+    ShapeVector shape;
+    dic[ATTR_SHAPE] = shape;
+    dic[ATTR_DTYPE] = abs_base->BuildType();
+    dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
+  } else if (abs_base->isa<AbstractTuple>()) {
+    return AbstractTupleToPython(abs_base, only_convert_value);
+  } else if (abs_base->isa<AbstractList>()) {
+    return AbstractListToPython(abs_base, only_convert_value);
+  } else if (abs_base->isa<AbstractSlice>()) {
+    auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
+    ShapeVector shape;
+    dic[ATTR_SHAPE] = shape;
+    dic[ATTR_DTYPE] = arg_slice->BuildType();
+    dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
   } else if (abs_base->isa<AbstractRowTensor>()) {
     auto arg = dyn_cast<AbstractRowTensor>(abs_base);
     dic[ATTR_SHAPE] = arg->shape()->shape();
@@ -497,25 +554,10 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     dic[ATTR_SHAPE] = arg->shape()->shape();
     dic[ATTR_DTYPE] = arg->BuildType();
     dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
-  } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
-    ShapeVector shape;
-    dic[ATTR_SHAPE] = shape;
-    dic[ATTR_DTYPE] = abs_base->BuildType();
-    dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
-  } else if (abs_base->isa<AbstractSlice>()) {
-    auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
-    ShapeVector shape;
-    dic[ATTR_SHAPE] = shape;
-    dic[ATTR_DTYPE] = arg_slice->BuildType();
-    dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
   } else if (abs_base->isa<AbstractEllipsis>()) {
     dic[ATTR_SHAPE] = py::none();
     dic[ATTR_DTYPE] = py::ellipsis();
     dic[ATTR_VALUE] = py::ellipsis();
-  } else if (abs_base->isa<AbstractTuple>()) {
-    return AbstractTupleToPython(abs_base);
-  } else if (abs_base->isa<AbstractList>()) {
-    return AbstractListToPython(abs_base);
   } else if (abs_base->isa<AbstractNone>()) {
     dic[ATTR_SHAPE] = py::none();
     dic[ATTR_DTYPE] = py::none();
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
index 2ac1977e55a..fbff7a3f480 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
@@ -174,7 +174,7 @@ bool IsSubtype(const AbstractBasePtr x, const TypePtr model);
 
 void ClearPrimEvaluatorMap();
 
-py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base);
+py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_convert_value = false);
 
 }  // namespace abstract
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index 2214207cc93..d454d2b0788 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -1225,19 +1225,18 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
                                   const abstract::AbstractBasePtrList &args_spec_list, const CNodePtr &cnode,
                                   bool prim_cache_hit, py::object *ret) {
   MS_EXCEPTION_IF_NULL(op_exec_info);
-  auto prim = op_exec_info->py_primitive;
+  const auto &prim = op_exec_info->py_primitive;
   MS_EXCEPTION_IF_NULL(prim);
   // Infer output value by constant folding
   MS_EXCEPTION_IF_NULL(ret);
-  py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
-  if (!output["value"].is_none()) {
-    *ret = output["value"];
-    grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
+  py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract, true);
+  if (!output[ATTR_VALUE].is_none()) {
+    *ret = output[ATTR_VALUE];
+    grad()->RecordGradOpInfo(op_exec_info);
     return;
-  }
-  if (prim->is_const_prim()) {
+  } else if (prim->is_const_prim()) {
     *ret = py::cast("");
-    grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
+    grad()->RecordGradOpInfo(op_exec_info);
     return;
   }
 
@@ -1279,7 +1278,7 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
     node_abs_map_.clear();
   }
   // Record op info for judge whether the construct of cell has been changed
-  grad()->RecordGradOpInfo(op_exec_info, out_real_value);
+  grad()->RecordGradOpInfo(op_exec_info);
   grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real_value);
 }
 
@@ -1723,13 +1722,12 @@ void GradExecutor::EnableOpGraphCache(bool is_enable) {
   inst->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, is_enable);
 }
 
-void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out) {
+void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info) {
   if (!grad_flag_) {
     MS_LOG(DEBUG) << "Grad flag is set to false, no need to record op info";
     return;
   }
   MS_EXCEPTION_IF_NULL(op_exec_info);
-  MS_EXCEPTION_IF_NULL(op_out);
   std::string input_args_info;
   // Record input args info (weight or data)
   for (const auto mask : op_exec_info->inputs_mask) {
@@ -1744,11 +1742,12 @@ void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out) {
   const auto &curr_op_num = top_cell()->op_num();
   op_exec_info->op_info += op_exec_info->op_name + "-" + std::to_string(curr_op_num) + "-" + input_args_info;
   // The out shape is added to determine those ops that change the shape
-  auto out_abs = op_out->ToAbstract();
+  const auto &out_abs = op_exec_info->abstract;
   if (out_abs != nullptr) {
-    auto out_shape = out_abs->BuildShape()->ToString();
-    if (out_shape.find("()") == std::string::npos && out_shape.find("NoShape") == std::string::npos) {
-      op_exec_info->op_info += "-" + out_shape;
+    auto shape = out_abs->BuildShape();
+    MS_EXCEPTION_IF_NULL(shape);
+    if (!shape->isa<abstract::NoShape>() && !shape->IsDimZero()) {
+      op_exec_info->op_info += "-" + shape->ToString();
     }
   }
   top_cell()->all_op_info() += "-" + op_exec_info->op_info;
@@ -3213,7 +3212,8 @@ void GradExecutor::GradMsFunctionInner(const std::string &phase, const py::object
   // Identity op info for current running ms_func graph.
   OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
   op_exec_info->op_name = phase;
-  RecordGradOpInfo(op_exec_info, actual_out_v);
+  op_exec_info->abstract = actual_out_v->ToAbstract();
+  RecordGradOpInfo(op_exec_info);
   MS_LOG(DEBUG) << "ms_function cnode op info: " << op_exec_info->op_info;
 
   // Step 1: Update actual output tensors used in grad graph.
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
index 5a43cf5fd60..659904689d0 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
@@ -201,7 +201,7 @@ class GradExecutor {
   std::set<std::string> &forward_outputs_id() const { return top_cell()->outputs_id(); }
   AnfNodePtr GetInput(const py::object &obj, bool op_mask);
   std::string GetCellId(const py::object &obj, const py::args &args);
-  void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out);
+  void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info);
   bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; }
   // Construct grad graph for ms_function
   bool eliminate_forward() const { return eliminate_forward_; }
diff --git a/mindspore/core/abstract/dshape.h b/mindspore/core/abstract/dshape.h
index 497d060aaa4..498864e4ec4 100644
--- a/mindspore/core/abstract/dshape.h
+++ b/mindspore/core/abstract/dshape.h
@@ -69,6 +69,11 @@ class MS_CORE_API BaseShape : public Base {
   /// \return True if the object's dimensions are dynamic, otherwise false.
   virtual bool IsDynamic() const = 0;
 
+  /// \brief Whether the object's dimension is zero.
+  ///
+  /// \return True if the object's dimension is zero, otherwise false.
+  virtual bool IsDimZero() const = 0;
+
   /// \brief Whether the object's dimensions are unknown.
   ///
   /// \return True if the object's dimensions are unknown, otherwise false.
@@ -97,6 +102,8 @@ class MS_CORE_API NoShape final : public BaseShape {
 
   bool IsDynamic() const override { return false; }
 
+  bool IsDimZero() const override { return true; };
+
   bool IsDimUnknown() const override { return false; }
 };
 
@@ -172,6 +179,8 @@ class MS_CORE_API Shape final : public BaseShape {
     return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < 0; });
   }
 
+  bool IsDimZero() const override { return shape_.empty(); };
+
   bool IsDimUnknown() const override {
     return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < -1; });
   }
@@ -236,6 +245,10 @@ class MS_CORE_API SequenceShape : public BaseShape {
     return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDynamic(); });
   }
 
+  bool IsDimZero() const override {
+    return std::all_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDimZero(); });
+  };
+
   bool IsDimUnknown() const override {
     return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDimUnknown(); });
   }
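
Editor's note (illustration, not part of the patch): the heart of the change is that ForwardExecutor::GetOpOutput now asks ConvertAbstractToPython for the value only (only_convert_value = true) and, whenever abstract inference already produced a concrete value, returns it directly and skips the device kernel launch. A minimal self-contained sketch of that short-circuit, using hypothetical stand-ins (InferValue for the ATTR_VALUE lookup, RunKernel for the real launch):

  #include <iostream>
  #include <optional>

  // Hypothetical stand-ins: InferValue plays the role of
  // ConvertAbstractToPython(abs, /*only_convert_value=*/true)[ATTR_VALUE],
  // returning a value only when inference can fold it; RunKernel plays
  // the role of the real device launch.
  std::optional<int> InferValue(int lhs, int rhs, bool statically_known) {
    return statically_known ? std::optional<int>(lhs + rhs) : std::nullopt;
  }

  int RunKernel(int lhs, int rhs) {
    std::cout << "kernel launched\n";
    return lhs + rhs;
  }

  int Add(int lhs, int rhs, bool statically_known) {
    if (auto folded = InferValue(lhs, rhs, statically_known)) {
      // Constant folded: no kernel launch; the op is still recorded
      // for the grad graph, as RecordGradOpInfo does in the patch.
      return *folded;
    }
    return RunKernel(lhs, rhs);
  }

  int main() {
    std::cout << Add(1, 2, true) << '\n';   // folded, prints 3 with no launch
    std::cout << Add(1, 2, false) << '\n';  // launches the kernel, prints 3
    return 0;
  }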
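Likewise, the new IsDimZero() predicate in dshape.h replaces RecordGradOpInfo's old string probes (out_shape.find("()") and find("NoShape")) with a virtual method on the shape hierarchy. A simplified, self-contained model of the patched classes (stand-ins, not the real MindSpore declarations) shows the intended semantics:

  #include <algorithm>
  #include <cstdint>
  #include <iostream>
  #include <memory>
  #include <utility>
  #include <vector>

  // Simplified stand-ins for BaseShape/NoShape/Shape/SequenceShape.
  struct BaseShape {
    virtual ~BaseShape() = default;
    virtual bool IsDimZero() const = 0;
  };
  using BaseShapePtr = std::shared_ptr<BaseShape>;

  struct NoShape : BaseShape {
    bool IsDimZero() const override { return true; }  // no shape at all
  };

  struct Shape : BaseShape {
    explicit Shape(std::vector<int64_t> s) : shape_(std::move(s)) {}
    bool IsDimZero() const override { return shape_.empty(); }  // rank-0 scalar
    std::vector<int64_t> shape_;
  };

  struct SequenceShape : BaseShape {
    explicit SequenceShape(std::vector<BaseShapePtr> p) : p_shapes_(std::move(p)) {}
    bool IsDimZero() const override {  // zero-dim only if every element is
      return std::all_of(p_shapes_.begin(), p_shapes_.end(),
                         [](const BaseShapePtr &bs) { return bs->IsDimZero(); });
    }
    std::vector<BaseShapePtr> p_shapes_;
  };

  int main() {
    auto none = std::make_shared<NoShape>();
    auto scalar = std::make_shared<Shape>(std::vector<int64_t>{});
    auto matrix = std::make_shared<Shape>(std::vector<int64_t>{2, 3});
    auto seq = std::make_shared<SequenceShape>(std::vector<BaseShapePtr>{scalar, matrix});
    // RecordGradOpInfo now appends the shape to op_info only when
    // !shape->IsDimZero(), instead of substring-matching ToString().
    std::cout << std::boolalpha << none->IsDimZero() << ' ' << scalar->IsDimZero() << ' '
              << matrix->IsDimZero() << ' ' << seq->IsDimZero() << '\n';
    // prints: true true false false
    return 0;
  }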