forked from mindspore-Ecosystem/mindspore
!26437 Optimize constant folding for pynative
Merge pull request !26437 from JoyLvliang/optimize_constant_folding_for_pynative
commit 8d05a3870b
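This merge adds a value-only fast path to abstract-to-Python conversion so that PyNative constant folding can hand an inferred output value back to Python without materializing shape and dtype entries or launching a kernel. The fragment below is only an illustration of the caller-side effect, assembled from names that appear in the hunks that follow (ConvertAbstractToPython, ATTR_VALUE); it is not code from this commit:

    // If type inference already produced a concrete value for the op's output,
    // return that value directly and skip the kernel launch.
    py::dict out = abstract::ConvertAbstractToPython(op_abstract, true);  // value-only mode
    if (!out[ATTR_VALUE].is_none()) {
      *ret = out[ATTR_VALUE];  // constant-folded result
      return;
    }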
@@ -304,9 +304,29 @@ py::object BuildValue(const ValuePtr &value_ptr) {
   }
 }
 
-py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
+py::object AbstractTupleValueToPython(const AbstractTuplePtr &tuple_abs) {
+  MS_EXCEPTION_IF_NULL(tuple_abs);
+  auto value = tuple_abs->BuildValue();
+  if (value->isa<AnyValue>()) {
+    return py::none();
+  }
+  const auto &elements = tuple_abs->elements();
+  size_t len = elements.size();
+  py::tuple value_tuple(len);
+  for (size_t i = 0; i < len; ++i) {
+    value_tuple[i] = ConvertAbstractToPython(elements[i], true)[ATTR_VALUE];
+  }
+  return std::move(value_tuple);
+}
+
+py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
   auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
   MS_EXCEPTION_IF_NULL(arg_tuple);
+  auto dic = py::dict();
+  if (only_convert_value) {
+    dic[ATTR_VALUE] = AbstractTupleValueToPython(arg_tuple);
+    return dic;
+  }
   size_t len = arg_tuple->size();
   py::tuple shape_tuple(len);
   py::tuple dtype_tuple(len);
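The new AbstractTupleValueToPython converts only the tuple's value, recursing element-wise with only_convert_value set, and returns py::none() to signal "not a compile-time constant". A hedged sketch of the observable behavior (the surrounding setup is hypothetical):

    // tuple_abs : an AbstractTuplePtr produced by type inference
    py::object v = AbstractTupleValueToPython(tuple_abs);
    if (v.is_none()) {
      // BuildValue() was AnyValue: at least one element's value is unknown,
      // so the op cannot be folded and must actually run.
    } else {
      // v is a py::tuple holding each element's Python value.
    }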
@@ -319,8 +339,7 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
   bool dyn_value = false;
 
   for (size_t i = 0; i < len; i++) {
-    auto arg = arg_tuple->elements()[i];
-    py::dict out = ConvertAbstractToPython(arg);
+    py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
     shape_tuple[i] = out[ATTR_SHAPE];
     dtype_tuple[i] = out[ATTR_DTYPE];
     value_tuple[i] = out[ATTR_VALUE];
@@ -339,7 +358,6 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
       dyn_shape = true;
     }
   }
-  auto dic = py::dict();
   dic[ATTR_SHAPE] = shape_tuple;
   dic[ATTR_DTYPE] = dtype_tuple;
   MS_EXCEPTION_IF_NULL(arg_tuple->BuildValue());
@@ -361,9 +379,29 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
   return dic;
 }
 
-py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
+py::object AbstractListValueToPython(const AbstractListPtr &list_abs) {
+  MS_EXCEPTION_IF_NULL(list_abs);
+  auto value = list_abs->BuildValue();
+  if (value->isa<AnyValue>()) {
+    return py::none();
+  }
+  const auto &elements = list_abs->elements();
+  size_t len = elements.size();
+  py::list value_list(len);
+  for (size_t i = 0; i < len; ++i) {
+    value_list[i] = ConvertAbstractToPython(elements[i], true)[ATTR_VALUE];
+  }
+  return std::move(value_list);
+}
+
+py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
   auto arg_list = dyn_cast<AbstractList>(abs_base);
   MS_EXCEPTION_IF_NULL(arg_list);
+  auto dic = py::dict();
+  if (only_convert_value) {
+    dic[ATTR_VALUE] = AbstractListValueToPython(arg_list);
+    return dic;
+  }
   size_t len = arg_list->size();
   py::list shape_list(len);
   py::list dtype_list(len);
@@ -385,7 +423,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
       dyn_shape = true;
     }
   }
-  auto dic = py::dict();
+
   dic[ATTR_SHAPE] = shape_list;
   dic[ATTR_DTYPE] = dtype_list;
   MS_EXCEPTION_IF_NULL(arg_list->BuildValue());
@@ -403,10 +441,14 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
   return dic;
 }
 
-void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
+void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, bool only_convert_value, py::dict *dic) {
   auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
   MS_EXCEPTION_IF_NULL(dic);
   MS_EXCEPTION_IF_NULL(arg_tensor);
+  if (only_convert_value) {
+    (*dic)[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
+    return;
+  }
   MS_EXCEPTION_IF_NULL(arg_tensor->shape());
   (*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
   const auto &min_shape = arg_tensor->shape()->min_shape();
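With the added branch, a tensor abstract in value-only mode is converted straight to ATTR_VALUE via BuildValue, and all shape/dtype handling below (including min/max shapes) is skipped. A small sketch under that assumption; tensor_abs is a hypothetical AbstractBasePtr holding an AbstractTensor:

    py::dict dic;
    ConvertAbstractTensorToPython(tensor_abs, true, &dic);
    // Only ATTR_VALUE is populated; it comes back as py::none() when the
    // tensor's value is not statically known, which callers treat as
    // "cannot constant-fold".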
@@ -477,11 +519,26 @@ TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
   }
 }  // end anonymous namespace
 
-py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
+py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
   MS_EXCEPTION_IF_NULL(abs_base);
   auto dic = py::dict();
   if (abs_base->isa<AbstractTensor>()) {
-    ConvertAbstractTensorToPython(abs_base, &dic);
+    ConvertAbstractTensorToPython(abs_base, only_convert_value, &dic);
+  } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
+    ShapeVector shape;
+    dic[ATTR_SHAPE] = shape;
+    dic[ATTR_DTYPE] = abs_base->BuildType();
+    dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
+  } else if (abs_base->isa<AbstractTuple>()) {
+    return AbstractTupleToPython(abs_base, only_convert_value);
+  } else if (abs_base->isa<AbstractList>()) {
+    return AbstractListToPython(abs_base, only_convert_value);
+  } else if (abs_base->isa<AbstractSlice>()) {
+    auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
+    ShapeVector shape;
+    dic[ATTR_SHAPE] = shape;
+    dic[ATTR_DTYPE] = arg_slice->BuildType();
+    dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
   } else if (abs_base->isa<AbstractRowTensor>()) {
     auto arg = dyn_cast<AbstractRowTensor>(abs_base);
     dic[ATTR_SHAPE] = arg->shape()->shape();
@@ -497,25 +554,10 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     dic[ATTR_SHAPE] = arg->shape()->shape();
     dic[ATTR_DTYPE] = arg->BuildType();
     dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
-  } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
-    ShapeVector shape;
-    dic[ATTR_SHAPE] = shape;
-    dic[ATTR_DTYPE] = abs_base->BuildType();
-    dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
-  } else if (abs_base->isa<AbstractSlice>()) {
-    auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
-    ShapeVector shape;
-    dic[ATTR_SHAPE] = shape;
-    dic[ATTR_DTYPE] = arg_slice->BuildType();
-    dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
   } else if (abs_base->isa<AbstractEllipsis>()) {
     dic[ATTR_SHAPE] = py::none();
     dic[ATTR_DTYPE] = py::ellipsis();
     dic[ATTR_VALUE] = py::ellipsis();
-  } else if (abs_base->isa<AbstractTuple>()) {
-    return AbstractTupleToPython(abs_base);
-  } else if (abs_base->isa<AbstractList>()) {
-    return AbstractListToPython(abs_base);
   } else if (abs_base->isa<AbstractNone>()) {
     dic[ATTR_SHAPE] = py::none();
     dic[ATTR_DTYPE] = py::none();
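Taken together, the last two hunks reorder ConvertAbstractToPython's dispatch: the scalar/type/ref-key, tuple, list, and slice branches move ahead of the row-tensor and ellipsis handling, so the common cases, including the two container cases that honor only_convert_value, are tested first. Condensed from the diff (comments only, not verbatim code):

    // Dispatch order after this change:
    //   AbstractTensor           -> ConvertAbstractTensorToPython(abs_base, only_convert_value, &dic)
    //   Scalar / Type / RefKey   -> shape = {}, dtype, value
    //   AbstractTuple            -> return AbstractTupleToPython(abs_base, only_convert_value)
    //   AbstractList             -> return AbstractListToPython(abs_base, only_convert_value)
    //   AbstractSlice            -> shape = {}, dtype, value
    //   RowTensor / Ellipsis / None / ... (unchanged)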
@@ -174,7 +174,7 @@ bool IsSubtype(const AbstractBasePtr x, const TypePtr model);
 
 void ClearPrimEvaluatorMap();
 
-py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base);
+py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_convert_value = false);
 }  // namespace abstract
 }  // namespace mindspore
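Defaulting the new parameter to false keeps every existing call site source-compatible; only callers that want the value-only fast path opt in. For example (both calls hypothetical):

    auto full = abstract::ConvertAbstractToPython(abs);        // shape + dtype + value, as before
    auto fast = abstract::ConvertAbstractToPython(abs, true);  // ATTR_VALUE only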
@@ -1242,19 +1242,18 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
                                   const abstract::AbstractBasePtrList &args_spec_list, const CNodePtr &cnode,
                                   bool prim_cache_hit, py::object *ret) {
   MS_EXCEPTION_IF_NULL(op_exec_info);
-  auto prim = op_exec_info->py_primitive;
+  const auto &prim = op_exec_info->py_primitive;
   MS_EXCEPTION_IF_NULL(prim);
   // Infer output value by constant folding
   MS_EXCEPTION_IF_NULL(ret);
-  py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
-  if (!output["value"].is_none()) {
-    *ret = output["value"];
-    grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
-    return;
-  }
-  if (prim->is_const_prim()) {
+  py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract, true);
+  if (!output[ATTR_VALUE].is_none()) {
+    *ret = output[ATTR_VALUE];
+    grad()->RecordGradOpInfo(op_exec_info);
+    return;
+  } else if (prim->is_const_prim()) {
     *ret = py::cast("");
-    grad()->RecordGradOpInfo(op_exec_info, PyObjToValue(*ret));
+    grad()->RecordGradOpInfo(op_exec_info);
     return;
   }
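Two things change in this prologue: the conversion runs in value-only mode, and RecordGradOpInfo no longer takes the output as a ValuePtr, which removes the PyObjToValue(*ret) round-trips on both early-return paths. The resulting control flow, summarized rather than quoted:

    // 1. Abstract already carries a concrete value -> return it (constant folded).
    // 2. Otherwise, a const prim                   -> return the "" placeholder.
    // 3. Otherwise                                 -> fall through to the normal
    //                                                 kernel-launch path below.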
@@ -1296,7 +1295,7 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
     node_abs_map_.clear();
   }
   // Record op info for judge whether the construct of cell has been changed
-  grad()->RecordGradOpInfo(op_exec_info, out_real_value);
+  grad()->RecordGradOpInfo(op_exec_info);
   grad()->UpdateForwardTensorInfoInBpropGraph(op_exec_info, out_real_value);
 }
@@ -1740,13 +1739,12 @@ void GradExecutor::EnableOpGraphCache(bool is_enable) {
   inst->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, is_enable);
 }
 
-void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out) {
+void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info) {
   if (!grad_flag_) {
     MS_LOG(DEBUG) << "Grad flag is set to false, no need to record op info";
     return;
   }
   MS_EXCEPTION_IF_NULL(op_exec_info);
-  MS_EXCEPTION_IF_NULL(op_out);
   std::string input_args_info;
   // Record input args info (weight or data)
   for (const auto mask : op_exec_info->inputs_mask) {
@@ -1761,11 +1759,12 @@ void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out) {
   const auto &curr_op_num = top_cell()->op_num();
   op_exec_info->op_info += op_exec_info->op_name + "-" + std::to_string(curr_op_num) + "-" + input_args_info;
   // The out shape is added to determine those ops that change the shape
-  auto out_abs = op_out->ToAbstract();
+  const auto &out_abs = op_exec_info->abstract;
   if (out_abs != nullptr) {
-    auto out_shape = out_abs->BuildShape()->ToString();
-    if (out_shape.find("()") == std::string::npos && out_shape.find("NoShape") == std::string::npos) {
-      op_exec_info->op_info += "-" + out_shape;
+    auto shape = out_abs->BuildShape();
+    MS_EXCEPTION_IF_NULL(shape);
+    if (!shape->isa<abstract::NoShape>() && !shape->IsDimZero()) {
+      op_exec_info->op_info += "-" + shape->ToString();
     }
   }
   top_cell()->all_op_info() += "-" + op_exec_info->op_info;
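RecordGradOpInfo now reads the output shape from the already-inferred op_exec_info->abstract instead of rebuilding an abstract from the runtime value, and the fragile substring tests on ToString() (looking for "()" and "NoShape") become the typed !isa<NoShape>() && !IsDimZero() check. The op_info strings it builds follow this pattern (op names and numbers invented for illustration):

    // op_name - op_num - input_args_info [- output shape, when non-scalar]
    // "MatMul-12-00-(16, 32)"   // tensor output: shape suffix appended
    // "Shape-13-0"              // zero-dim / no-shape output: suffix omitted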
@@ -3230,7 +3229,8 @@ void GradExecutor::GradMsFunctionInner(const std::string &phase, const py::object
   // Identity op info for current running ms_func graph.
   OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
   op_exec_info->op_name = phase;
-  RecordGradOpInfo(op_exec_info, actual_out_v);
+  op_exec_info->abstract = actual_out_v->ToAbstract();
+  RecordGradOpInfo(op_exec_info);
   MS_LOG(DEBUG) << "ms_function cnode op info: " << op_exec_info->op_info;
 
   // Step 1: Update actual output tensors used in grad graph.
@@ -201,7 +201,7 @@ class GradExecutor {
   std::set<std::string> &forward_outputs_id() const { return top_cell()->outputs_id(); }
   AnfNodePtr GetInput(const py::object &obj, bool op_mask);
   std::string GetCellId(const py::object &obj, const py::args &args);
-  void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out);
+  void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info);
   bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; }
   // Construct grad graph for ms_function
   bool eliminate_forward() const { return eliminate_forward_; }
@@ -69,6 +69,11 @@ class MS_CORE_API BaseShape : public Base {
   /// \return True if the object's dimensions are dynamic, otherwise false.
   virtual bool IsDynamic() const = 0;
 
+  /// \brief Whether the object's dimension is zero.
+  ///
+  /// \return True if the object's dimension is zero, otherwise false.
+  virtual bool IsDimZero() const = 0;
+
   /// \brief Whether the object's dimensions are unknown.
   ///
   /// \return True if the object's dimensions are unknown, otherwise false.
@@ -97,6 +102,8 @@ class MS_CORE_API NoShape final : public BaseShape {
 
   bool IsDynamic() const override { return false; }
 
+  bool IsDimZero() const override { return true; };
+
   bool IsDimUnknown() const override { return false; }
 };
@@ -172,6 +179,8 @@ class MS_CORE_API Shape final : public BaseShape {
     return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < 0; });
   }
 
+  bool IsDimZero() const override { return shape_.empty(); };
+
   bool IsDimUnknown() const override {
     return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < -1; });
   }
@@ -236,6 +245,10 @@ class MS_CORE_API SequenceShape : public BaseShape {
     return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDynamic(); });
   }
 
+  bool IsDimZero() const override {
+    return std::all_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDimZero(); });
+  };
+
   bool IsDimUnknown() const override {
     return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDimUnknown(); });
  }
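The new IsDimZero predicate gives every shape class a typed way to answer "is this rank zero?": NoShape always is, Shape is when its dimension vector is empty, and SequenceShape only when all of its element shapes are. A small sketch of the expected answers, assuming the classes declared in this header and that Shape is constructible from a ShapeVector:

    abstract::NoShape no_shape;
    abstract::Shape scalar(ShapeVector{});      // rank-0 shape
    abstract::Shape matrix(ShapeVector{2, 3});
    // no_shape.IsDimZero()  -> true
    // scalar.IsDimZero()    -> true
    // matrix.IsDimZero()    -> false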