!26437 Optimize constant folding for pynative

Merge pull request !26437 from JoyLvliang/optimize_constant_folding_for_pynative
commit 8d05a3870b
Author: i-robot, 2021-12-15 11:36:14 +00:00 (committed by Gitee)
5 changed files with 97 additions and 42 deletions
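
In short: ConvertAbstractToPython gains an only_convert_value fast path so the PyNative forward executor can probe an op's inferred abstract for a constant-folded output without materializing shape/dtype metadata, and RecordGradOpInfo now derives the output shape from op_exec_info->abstract instead of converting the output value back into an abstract. A minimal sketch of the new probe, assuming the pybind11 context and the ATTR_VALUE key used throughout the patch (`abs` stands in for any inferred AbstractBasePtr):

    // Sketch only, not a verbatim excerpt from the patch.
    py::dict out = abstract::ConvertAbstractToPython(abs, /*only_convert_value=*/true);
    if (!out[ATTR_VALUE].is_none()) {
      // The whole output is already known at infer time; no kernel launch is needed.
    }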

========== changed file 1 of 5 ==========

@@ -304,9 +304,29 @@ py::object BuildValue(const ValuePtr &value_ptr) {
   }
 }
 
-py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
+py::object AbstractTupleValueToPython(const AbstractTuplePtr &tuple_abs) {
+  MS_EXCEPTION_IF_NULL(tuple_abs);
+  auto value = tuple_abs->BuildValue();
+  if (value->isa<AnyValue>()) {
+    return py::none();
+  }
+  const auto &elements = tuple_abs->elements();
+  size_t len = elements.size();
+  py::tuple value_tuple(len);
+  for (size_t i = 0; i < len; ++i) {
+    value_tuple[i] = ConvertAbstractToPython(elements[i], true)[ATTR_VALUE];
+  }
+  return std::move(value_tuple);
+}
+
+py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
   auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
   MS_EXCEPTION_IF_NULL(arg_tuple);
+  auto dic = py::dict();
+  if (only_convert_value) {
+    dic[ATTR_VALUE] = AbstractTupleValueToPython(arg_tuple);
+    return dic;
+  }
   size_t len = arg_tuple->size();
   py::tuple shape_tuple(len);
   py::tuple dtype_tuple(len);
@@ -319,8 +339,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
   bool dyn_value = false;
   for (size_t i = 0; i < len; i++) {
-    auto arg = arg_tuple->elements()[i];
-    py::dict out = ConvertAbstractToPython(arg);
+    py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
     shape_tuple[i] = out[ATTR_SHAPE];
     dtype_tuple[i] = out[ATTR_DTYPE];
     value_tuple[i] = out[ATTR_VALUE];
@@ -339,7 +358,6 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
       dyn_shape = true;
     }
   }
-  auto dic = py::dict();
   dic[ATTR_SHAPE] = shape_tuple;
   dic[ATTR_DTYPE] = dtype_tuple;
   MS_EXCEPTION_IF_NULL(arg_tuple->BuildValue());
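
The value-only tuple path is recursive and all-or-nothing: AbstractTuple::BuildValue yields AnyValue as soon as any element's value is unknown, so AbstractTupleValueToPython returns py::none() unless every leaf folds. A hedged illustration (the AbstractScalar/AbstractTuple constructors are assumed from mindspore/core/abstract and are not part of this diff):

    // Illustration only; assumes AbstractScalar has an int64_t constructor.
    abstract::AbstractBasePtrList elems = {
        std::make_shared<abstract::AbstractScalar>(static_cast<int64_t>(1)),
        std::make_shared<abstract::AbstractScalar>(static_cast<int64_t>(2))};
    auto tuple_abs = std::make_shared<abstract::AbstractTuple>(elems);
    py::object v = AbstractTupleValueToPython(tuple_abs);  // a py::tuple holding (1, 2)
    // If any element's BuildValue() were AnyValue, v would be py::none() instead.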
@@ -361,9 +379,29 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
   return dic;
 }
 
-py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
+py::object AbstractListValueToPython(const AbstractListPtr &list_abs) {
+  MS_EXCEPTION_IF_NULL(list_abs);
+  auto value = list_abs->BuildValue();
+  if (value->isa<AnyValue>()) {
+    return py::none();
+  }
+  const auto &elements = list_abs->elements();
+  size_t len = elements.size();
+  py::list value_list(len);
+  for (size_t i = 0; i < len; ++i) {
+    value_list[i] = ConvertAbstractToPython(elements[i], true)[ATTR_VALUE];
+  }
+  return std::move(value_list);
+}
+
+py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
   auto arg_list = dyn_cast<AbstractList>(abs_base);
   MS_EXCEPTION_IF_NULL(arg_list);
+  auto dic = py::dict();
+  if (only_convert_value) {
+    dic[ATTR_VALUE] = AbstractListValueToPython(arg_list);
+    return dic;
+  }
   size_t len = arg_list->size();
   py::list shape_list(len);
   py::list dtype_list(len);
@@ -385,7 +423,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
       dyn_shape = true;
     }
   }
-  auto dic = py::dict();
   dic[ATTR_SHAPE] = shape_list;
   dic[ATTR_DTYPE] = dtype_list;
   MS_EXCEPTION_IF_NULL(arg_list->BuildValue());
@@ -403,10 +441,14 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
   return dic;
 }
 
-void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
+void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, bool only_convert_value, py::dict *dic) {
   auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
   MS_EXCEPTION_IF_NULL(dic);
   MS_EXCEPTION_IF_NULL(arg_tensor);
+  if (only_convert_value) {
+    (*dic)[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
+    return;
+  }
   MS_EXCEPTION_IF_NULL(arg_tensor->shape());
   (*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
   const auto &min_shape = arg_tensor->shape()->min_shape();
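
For tensors, the value-only branch leans on the conversion contract of BuildValue at the top of this file: a tensor whose value is unknown at infer time builds an AnyValue, which the Python-side BuildValue maps to py::none(), so the caller's is_none() probe works uniformly. In sketch form (ConvertAbstractTensorToPython lives in the anonymous namespace; it is shown here as if directly callable, purely for illustration):

    // Sketch under the assumptions above.
    py::dict dic;
    ConvertAbstractTensorToPython(tensor_abs, /*only_convert_value=*/true, &dic);
    // dic[ATTR_VALUE] is a concrete Python value for a constant tensor,
    // and py::none() when the tensor is only known at run time.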
@@ -477,11 +519,26 @@ TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
   }
 }  // end anonymous namespace
 
-py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
+py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
   MS_EXCEPTION_IF_NULL(abs_base);
   auto dic = py::dict();
   if (abs_base->isa<AbstractTensor>()) {
-    ConvertAbstractTensorToPython(abs_base, &dic);
+    ConvertAbstractTensorToPython(abs_base, only_convert_value, &dic);
+  } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
+    ShapeVector shape;
+    dic[ATTR_SHAPE] = shape;
+    dic[ATTR_DTYPE] = abs_base->BuildType();
+    dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
+  } else if (abs_base->isa<AbstractTuple>()) {
+    return AbstractTupleToPython(abs_base, only_convert_value);
+  } else if (abs_base->isa<AbstractList>()) {
+    return AbstractListToPython(abs_base, only_convert_value);
+  } else if (abs_base->isa<AbstractSlice>()) {
+    auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
+    ShapeVector shape;
+    dic[ATTR_SHAPE] = shape;
+    dic[ATTR_DTYPE] = arg_slice->BuildType();
+    dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
   } else if (abs_base->isa<AbstractRowTensor>()) {
     auto arg = dyn_cast<AbstractRowTensor>(abs_base);
     dic[ATTR_SHAPE] = arg->shape()->shape();
@@ -497,25 +554,10 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     dic[ATTR_SHAPE] = arg->shape()->shape();
     dic[ATTR_DTYPE] = arg->BuildType();
     dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
-  } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
-    ShapeVector shape;
-    dic[ATTR_SHAPE] = shape;
-    dic[ATTR_DTYPE] = abs_base->BuildType();
-    dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
-  } else if (abs_base->isa<AbstractSlice>()) {
-    auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
-    ShapeVector shape;
-    dic[ATTR_SHAPE] = shape;
-    dic[ATTR_DTYPE] = arg_slice->BuildType();
-    dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
   } else if (abs_base->isa<AbstractEllipsis>()) {
     dic[ATTR_SHAPE] = py::none();
     dic[ATTR_DTYPE] = py::ellipsis();
     dic[ATTR_VALUE] = py::ellipsis();
-  } else if (abs_base->isa<AbstractTuple>()) {
-    return AbstractTupleToPython(abs_base);
-  } else if (abs_base->isa<AbstractList>()) {
-    return AbstractListToPython(abs_base);
   } else if (abs_base->isa<AbstractNone>()) {
     dic[ATTR_SHAPE] = py::none();
     dic[ATTR_DTYPE] = py::none();

========== changed file 2 of 5 ==========

@@ -174,7 +174,7 @@ bool IsSubtype(const AbstractBasePtr x, const TypePtr model);
 void ClearPrimEvaluatorMap();
 
-py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base);
+py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_convert_value = false);
 }  // namespace abstract
 }  // namespace mindspore
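
Because the new parameter defaults to false, every existing caller of ConvertAbstractToPython compiles and behaves as before; only the constant-folding probe opts in:

    // Both call forms are valid after this change.
    py::dict full = abstract::ConvertAbstractToPython(abs);             // shape + dtype + value
    py::dict value_only = abstract::ConvertAbstractToPython(abs, true); // ATTR_VALUE only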

========== changed file 3 of 5 ==========

@@ -1242,19 +1242,18 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
                                   const abstract::AbstractBasePtrList &args_spec_list, const CNodePtr &cnode,
                                   bool prim_cache_hit, py::object *ret) {
   MS_EXCEPTION_IF_NULL(op_exec_info);
-  auto prim = op_exec_info->py_primitive;
+  const auto &prim = op_exec_info->py_primitive;
   MS_EXCEPTION_IF_NULL(prim);
   // Infer output value by constant folding
   MS_EXCEPTION_IF_NULL(ret);
-  py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
-  if (!output["value"].is_none()) {
-    *ret = output["value"];
-    grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
+  py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract, true);
+  if (!output[ATTR_VALUE].is_none()) {
+    *ret = output[ATTR_VALUE];
+    grad()->RecordGradOpInfo(op_exec_info);
     return;
-  }
-  if (prim->is_const_prim()) {
+  } else if (prim->is_const_prim()) {
     *ret = py::cast("");
-    grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
+    grad()->RecordGradOpInfo(op_exec_info);
     return;
   }
@@ -1296,7 +1295,7 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
     node_abs_map_.clear();
   }
   // Record op info for judge whether the construct of cell has been changed
-  grad()->RecordGradOpInfo(op_exec_info, out_real_value);
+  grad()->RecordGradOpInfo(op_exec_info);
   grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real_value);
 }
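
Note that both early returns in GetOpOutput still call RecordGradOpInfo, so constant-folded ops keep their place in the op-info sequence used to detect changes in a cell's construct. Dropping the ValuePtr parameter is safe because the abstract is populated before recording; a hedged sketch of the resulting call-site contract (GradMsFunctionInner below sets the abstract explicitly for the ms_function case):

    // The abstract must be in place before recording; the shape is read from it.
    op_exec_info->abstract = inferred_abs;  // or actual_out_v->ToAbstract()
    grad()->RecordGradOpInfo(op_exec_info);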
@@ -1740,13 +1739,12 @@ void GradExecutor::EnableOpGraphCache(bool is_enable) {
   inst->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, is_enable);
 }
 
-void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out) {
+void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info) {
   if (!grad_flag_) {
     MS_LOG(DEBUG) << "Grad flag is set to false, no need to record op info";
     return;
   }
   MS_EXCEPTION_IF_NULL(op_exec_info);
-  MS_EXCEPTION_IF_NULL(op_out);
   std::string input_args_info;
   // Record input args info (weight or data)
   for (const auto mask : op_exec_info->inputs_mask) {
@@ -1761,11 +1759,12 @@ void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out) {
   const auto &curr_op_num = top_cell()->op_num();
   op_exec_info->op_info += op_exec_info->op_name + "-" + std::to_string(curr_op_num) + "-" + input_args_info;
   // The out shape is added to determine those ops that change the shape
-  auto out_abs = op_out->ToAbstract();
+  const auto &out_abs = op_exec_info->abstract;
   if (out_abs != nullptr) {
-    auto out_shape = out_abs->BuildShape()->ToString();
-    if (out_shape.find("()") == std::string::npos && out_shape.find("NoShape") == std::string::npos) {
-      op_exec_info->op_info += "-" + out_shape;
+    auto shape = out_abs->BuildShape();
+    MS_EXCEPTION_IF_NULL(shape);
+    if (!shape->isa<abstract::NoShape>() && !shape->IsDimZero()) {
+      op_exec_info->op_info += "-" + shape->ToString();
     }
   }
   top_cell()->all_op_info() += "-" + op_exec_info->op_info;
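
The out-shape tag previously came from substring checks on BuildShape()->ToString() (looking for "()" and "NoShape"); the new code queries the shape object directly, which skips the string building for shapeless outputs and extends to tuple outputs through SequenceShape::IsDimZero (added in the last file of this patch; see the sketch after that hunk). Schematically, under the lines above:

    // op_info layout after this function (shape suffix shown illustratively;
    // the exact ToString() formatting is not asserted here):
    //   "<op_name>-<op_num>-<input_args_info>"          // scalar / rank-0 output
    //   "<op_name>-<op_num>-<input_args_info>-<shape>"  // shaped output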
@@ -3230,7 +3229,8 @@ void GradExecutor::GradMsFunctionInner(const std::string &phase, const py::object
   // Identity op info for current running ms_func graph.
   OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
   op_exec_info->op_name = phase;
-  RecordGradOpInfo(op_exec_info, actual_out_v);
+  op_exec_info->abstract = actual_out_v->ToAbstract();
+  RecordGradOpInfo(op_exec_info);
   MS_LOG(DEBUG) << "ms_function cnode op info: " << op_exec_info->op_info;
   // Step 1: Update actual output tensors used in grad graph.

========== changed file 4 of 5 ==========

@@ -201,7 +201,7 @@ class GradExecutor {
   std::set<std::string> &forward_outputs_id() const { return top_cell()->outputs_id(); }
   AnfNodePtr GetInput(const py::object &obj, bool op_mask);
   std::string GetCellId(const py::object &obj, const py::args &args);
-  void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out);
+  void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info);
   bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; }
   // Construct grad graph for ms_function
   bool eliminate_forward() const { return eliminate_forward_; }

========== changed file 5 of 5 ==========

@@ -69,6 +69,11 @@ class MS_CORE_API BaseShape : public Base {
   /// \return True if the object's dimensions are dynamic, otherwise false.
   virtual bool IsDynamic() const = 0;
 
+  /// \brief Whether the object's dimension is zero.
+  ///
+  /// \return True if the object's dimension is zero, otherwise false.
+  virtual bool IsDimZero() const = 0;
+
   /// \brief Whether the object's dimensions are unknown.
   ///
   /// \return True if the object's dimensions are unknown, otherwise false.
@@ -97,6 +102,8 @@ class MS_CORE_API NoShape final : public BaseShape {
   bool IsDynamic() const override { return false; }
 
+  bool IsDimZero() const override { return true; };
+
   bool IsDimUnknown() const override { return false; }
 };
@@ -172,6 +179,8 @@ class MS_CORE_API Shape final : public BaseShape {
     return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < 0; });
   }
 
+  bool IsDimZero() const override { return shape_.empty(); };
+
   bool IsDimUnknown() const override {
     return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < -1; });
   }
@@ -236,6 +245,10 @@ class MS_CORE_API SequenceShape : public BaseShape {
     return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDynamic(); });
   }
 
+  bool IsDimZero() const override {
+    return std::all_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDimZero(); });
+  };
+
   bool IsDimUnknown() const override {
     return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDimUnknown(); });
   }
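
Taken together, IsDimZero gives the shape hierarchy a uniform "no real dimensions" predicate: NoShape is always dim-zero, Shape is dim-zero exactly when its dimension vector is empty (rank 0), and a SequenceShape is dim-zero only when every element shape is. A hedged usage sketch (Shape's ShapeVector constructor is assumed and is not part of this diff):

    // Illustration only.
    abstract::NoShape scalar_shape;            // IsDimZero() -> true
    abstract::Shape rank0(ShapeVector{});      // IsDimZero() -> true
    abstract::Shape rank2(ShapeVector{2, 3});  // IsDimZero() -> false
    // A SequenceShape over {rank0, rank0} is dim-zero; over {rank0, rank2} it is not.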