diff --git a/mindspore/lite/tools/anf_exporter/fetch_content.cc b/mindspore/lite/tools/anf_exporter/fetch_content.cc index ce2d9239a27..b60bcdebec4 100644 --- a/mindspore/lite/tools/anf_exporter/fetch_content.cc +++ b/mindspore/lite/tools/anf_exporter/fetch_content.cc @@ -360,38 +360,6 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy return ret; } -int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag, - DataInfo *data_info) { - data_info->format_ = mindspore::NHWC; - MS_CHECK_TRUE_MSG(cnode->input(index) != nullptr, RET_ERROR, "input is nullptr"); - auto input_node_prim = GetValueNode((cnode->input(index)->cast()->input(0))); - MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "GetValueNode failed"); - if (input_node_prim->GetAttr(mindspore::ops::kFormat) != nullptr) { - auto value = input_node_prim->GetAttr(mindspore::ops::kFormat); - if (value->isa()) { - data_info->format_ = GetValue(value); - } - } - if (opt::CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) { - std::vector perm; - if (opt::GetTransposePerm(cnode->input(index)->cast(), &perm) != RET_OK) { - return RET_ERROR; - } - if (perm.size() < kNumTransposePermSize) { - return RET_OK; - } - // NHWC to NCHW: perm is {0, 3, 1, 2} - // NCHW to NHWC: perm is {0, 2, 3, 1} - if (perm[0] == 0 && perm[1] == 3 && perm[2] == 1 && perm[3] == 2 && - (data_info->format_ == NHWC || data_info->format_ == KHWC)) { - data_info->format_ = NCHW; - } else if (perm[0] == 0 && perm[1] == 2 && perm[2] == 3 && perm[3] == 1 && data_info->format_ == NCHW) { - data_info->format_ = NHWC; - } - } - return RET_OK; -} - int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag, DataInfo *data_info) { MS_ASSERT(cnode != nullptr && data_info != nullptr); @@ -414,11 +382,13 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType f } auto shape_vector = utils::cast(abstract_tensor->BuildShape())->shape(); std::vector dims(shape_vector.begin(), shape_vector.end()); - auto ret = SetFormatForCnode(cnode, index, fmk_type, train_flag, data_info); + Format format{mindspore::NHWC}; + auto ret = opt::DetermineCertainVarInputFormat(cnode, index, &format); if (ret != RET_OK) { MS_LOG(ERROR) << "set format for cnode failed"; return RET_ERROR; } + data_info->format_ = format; data_info->data_type_ = type_ptr->type_id(); data_info->shape_ = dims; data_info->node_type_ = NodeType_CNode; diff --git a/mindspore/lite/tools/optimizer/common/format_utils.cc b/mindspore/lite/tools/optimizer/common/format_utils.cc index b4ba512e031..8b50ed1d299 100644 --- a/mindspore/lite/tools/optimizer/common/format_utils.cc +++ b/mindspore/lite/tools/optimizer/common/format_utils.cc @@ -223,5 +223,52 @@ bool IsSpecialType(const CNodePtr &cnode) { CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, kPrimMakeTupleV2) || CheckPrimitiveType(cnode, prim::kPrimReturn); } + +int DetermineCertainVarInputFormat(const CNodePtr &cnode, size_t index, Format *format) { + MS_CHECK_TRUE_MSG(cnode != nullptr && format != nullptr, RET_ERROR, "function's parameter is nullptr."); + auto var_input_info = GetRealCertainVarInput(cnode, index); + if (var_input_info.first == nullptr) { + MS_LOG(ERROR) << "cannot get the real var input."; + return RET_ERROR; + } + *format = mindspore::NHWC; + auto real_input_cnode = var_input_info.first; + auto item_index = var_input_info.second; + auto input_node_prim = GetValueNode((real_input_cnode->input(0))); + MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "get primitive failed"); + auto value_ptr = input_node_prim->GetAttr(ops::kFormat); + if (value_ptr != nullptr) { + MS_CHECK_TRUE_MSG(value_ptr->isa(), RET_ERROR, "format attr must be an int64_t val."); + auto value = GetValue(value_ptr); + MS_CHECK_TRUE_MSG(value >= NCHW && value <= NCW, RET_ERROR, "format val is out of enum's range."); + *format = static_cast(value); + } + value_ptr = input_node_prim->GetAttr(kOutputsFormat); + if (value_ptr != nullptr) { + MS_CHECK_TRUE_MSG(value_ptr->isa(), RET_ERROR, "outputs_format attr should be sequence."); + auto formats = CastToInt(value_ptr); + if (item_index >= 0 && static_cast(item_index) < formats.size()) { + MS_CHECK_TRUE_MSG(formats[item_index] >= NCHW && formats[item_index] <= NCW, RET_ERROR, + "format val is out of enum's range."); + *format = static_cast(formats[item_index]); + } + } + if (CheckPrimitiveType(real_input_cnode, prim::kPrimTranspose)) { + std::vector perm; + if (GetTransposePerm(real_input_cnode, &perm) != RET_OK) { + MS_LOG(ERROR) << "fetch transpose's perm failed."; + return RET_ERROR; + } + if (perm.size() != kNC2NH.size()) { + return RET_OK; + } + if (perm == kNH2NC && (*format == NHWC || *format == KHWC)) { + *format = NCHW; + } else if (perm == opt::kNC2NH && *format == NCHW) { + *format = NHWC; + } + } + return RET_OK; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/common/format_utils.h b/mindspore/lite/tools/optimizer/common/format_utils.h index c0b826199ba..e16ec85439a 100644 --- a/mindspore/lite/tools/optimizer/common/format_utils.h +++ b/mindspore/lite/tools/optimizer/common/format_utils.h @@ -25,7 +25,7 @@ namespace mindspore { namespace opt { -constexpr auto kInferDone = "infer_done"; +constexpr auto kOutputsFormat = "outputs_format"; enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE }; struct TransTypePair { FormatTransNodeType pre_; @@ -42,6 +42,7 @@ STATUS GetTransposePerm(const CNodePtr &cnode, std::vector *perm); void RemoveIfMonad(const CNodePtr &cnode); bool IsMonadNode(const AnfNodePtr &node); bool IsSpecialType(const CNodePtr &cnode); +int DetermineCertainVarInputFormat(const CNodePtr &cnode, size_t index, Format *format); } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index 79842fed04b..273585031fd 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -39,7 +39,7 @@ namespace opt { namespace { constexpr auto kAnfPrimitiveIndex = 0; constexpr auto kDeviceTypeNone = -1; -int DeduceDimConvertion(schema::Format src_format, schema::Format dst_format, std::vector *perm) { +int DeduceDimConvertion(schema::Format src_format, schema::Format dst_format, std::vector *const perm) { MS_ASSERT(perm != nullptr); auto src_format_str = std::string(schema::EnumNameFormat(src_format)); auto dst_format_str = std::string(schema::EnumNameFormat(dst_format)); @@ -66,10 +66,16 @@ int DeduceDimConvertion(schema::Format src_format, schema::Format dst_format, st template void TransposeData(const ShapeVector &origin_shape, const ShapeVector &cur_shape, const std::vector &perm, - T *weight_data, std::vector *buf) { + T *const weight_data, std::vector *buf) { MS_ASSERT(weight_data != nullptr && buf != nullptr); MS_ASSERT(origin_shape.size() == cur_shape.size() && cur_shape.size() == perm.size()); - int count = std::accumulate(origin_shape.begin(), origin_shape.end(), 1, std::multiplies()); + int count = 1; + for (const auto &dat : origin_shape) { + if (INT_MUL_OVERFLOW(count, static_cast(dat))) { + return; + } + count *= static_cast(dat); + } ShapeVector post_multiply(cur_shape.size()); std::unordered_map dim_map; for (int i = cur_shape.size() - 1; i >= 0; --i) { @@ -88,10 +94,17 @@ void TransposeData(const ShapeVector &origin_shape, const ShapeVector &cur_shape position_map[j] = temp % origin_shape[j]; temp /= origin_shape[j]; } - int64_t new_pos = std::accumulate(position_map.begin(), position_map.end(), 0, - [&post_multiply, &dim_map](int64_t res, const std::pair &pair_y) { - return res + post_multiply[dim_map[pair_y.first]] * pair_y.second; - }); + int64_t new_pos = 0; + for (const auto &pair_y : position_map) { + if (INT_MUL_OVERFLOW(post_multiply[dim_map[pair_y.first]], pair_y.second)) { + return; + } + if (INT_ADD_OVERFLOW(new_pos, post_multiply[dim_map[pair_y.first]] * pair_y.second)) { + return; + } + new_pos += post_multiply[dim_map[pair_y.first]] * pair_y.second; + } + buf->at(new_pos) = weight_data[i]; } } @@ -121,7 +134,14 @@ STATUS DoTransposeData(const tensor::TensorPtr &tensor, schema::Format src_forma } new_shape.push_back(origin_shape[val]); } - auto count = std::accumulate(origin_shape.begin(), origin_shape.end(), 1LL, std::multiplies()); + int64_t count = 1; + for (const auto &dat : origin_shape) { + if (INT_MUL_OVERFLOW(count, dat)) { + MS_LOG(ERROR) << "Int mul overflow"; + return RET_ERROR; + } + count *= dat; + } if (count <= 0 || count > static_cast(INT32_MAX)) { MS_LOG(ERROR) << "tensor element num is too big, which should be smaller than int32_max."; return RET_ERROR; @@ -174,6 +194,38 @@ bool IsRealKernel(const AnfNodePtr &node) { #endif return !is_virtual_node; } + +int CopyTensorDataFromTensorInfo(const tensor::TensorPtr &tensor_info, + const std::shared_ptr &tensor_info_dst, size_t data_count) { + if (tensor_info->data_type() == kNumberTypeInt64) { + auto *tensor_data = reinterpret_cast(tensor_info_dst->data_c()); + if (tensor_data == nullptr) { + MS_LOG(ERROR) << "new data failed"; + return RET_ERROR; + } + auto *origin_data = reinterpret_cast(tensor_info->data_c()); + for (size_t i = 0; i < data_count; ++i) { + if (origin_data[i] > static_cast(INT32_MAX) || origin_data[i] < static_cast(INT32_MIN)) { + MS_LOG(WARNING) << "int64 data " << origin_data[i] << "too big to fit into int32"; + tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN; + } else { + tensor_data[i] = static_cast(origin_data[i]); + } + } + } else { + tensor_info_dst->set_data_type(tensor_info->data_type()); + auto *tensor_data = reinterpret_cast(tensor_info_dst->data_c()); + if (tensor_data == nullptr) { + MS_LOG(ERROR) << "new data failed"; + return RET_ERROR; + } + if (memcpy_s(tensor_data, tensor_info_dst->Size(), tensor_info->data_c(), tensor_info->Size()) != lite::RET_OK) { + MS_LOG(ERROR) << "memcpy data failed."; + return RET_ERROR; + } + } + return RET_OK; +} } // namespace bool CheckInputs(const CNodePtr &cnode) { @@ -351,7 +403,7 @@ bool IsGraphKernel(const AnfNodePtr &node) { return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); } -ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id) { +ParameterPtr AddNewBiasNode(const float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id) { if (bias_data == nullptr || func_graph == nullptr) { MS_LOG(ERROR) << "input parameter is nullptr."; return nullptr; @@ -461,7 +513,7 @@ bool IsParamNode(const BaseRef &n) { return tensor->data_c() != nullptr; } -STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr &cnode, size_t index) { +STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *const tensor_info, const CNodePtr &cnode, size_t index) { CHECK_NULL_RETURN(tensor_info); CHECK_NULL_RETURN(cnode); AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, index); @@ -673,7 +725,18 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr MS_LOG(ERROR) << "new tensor::Tensor failed."; return nullptr; } - size_t data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + int data_count = 1; + for (const auto &dat : shape) { + if (INT_MUL_OVERFLOW(data_count, static_cast(dat))) { + MS_LOG(ERROR) << "Int mul overflow."; + return nullptr; + } + data_count *= static_cast(dat); + } + if (data_count < 0) { + MS_LOG(ERROR) << "Invalid shape."; + return nullptr; + } if (tensor_info->Size() == 0) { auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new); if (status != RET_OK) { @@ -682,33 +745,12 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr } return param_node; } - if (tensor_info->data_type() == kNumberTypeInt64) { - auto *tensor_data = reinterpret_cast(tensor_info_new->data_c()); - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "new data failed"; - return nullptr; - } - auto *origin_data = reinterpret_cast(tensor_info->data_c()); - for (size_t i = 0; i < data_count; ++i) { - if (origin_data[i] > static_cast(INT32_MAX) || origin_data[i] < static_cast(INT32_MIN)) { - MS_LOG(WARNING) << "int64 data " << origin_data[i] << "too big to fit into int32"; - tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN; - } else { - tensor_data[i] = static_cast(origin_data[i]); - } - } - } else { - tensor_info_new->set_data_type(tensor_info->data_type()); - auto *tensor_data = reinterpret_cast(tensor_info_new->data_c()); - if (tensor_data == nullptr) { - MS_LOG(ERROR) << "new data failed"; - return nullptr; - } - if (memcpy_s(tensor_data, tensor_info_new->Size(), tensor_info->data_c(), tensor_info->Size()) != lite::RET_OK) { - MS_LOG(ERROR) << "memcpy data failed."; - return nullptr; - } + + if (CopyTensorDataFromTensorInfo(tensor_info, tensor_info_new, static_cast(data_count)) != RET_OK) { + MS_LOG(ERROR) << "copy tensor data failed"; + return nullptr; } + auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new); if (status != RET_OK) { MS_LOG(ERROR) << "init parameter from tensor info failed"; @@ -1012,6 +1054,7 @@ bool IsQuantParameterNode(const PrimitiveCPtr &prim) { } return false; } + void UpdateManager(const FuncGraphPtr &func_graph) { auto manager = func_graph->manager(); if (manager == nullptr) { @@ -1026,5 +1069,62 @@ void UpdateManager(const FuncGraphPtr &func_graph) { manager->AddFuncGraph(one_func_graph); } } + +std::pair GetRealCertainVarInput(const CNodePtr &cnode, size_t index) { + MS_CHECK_TRUE_MSG(cnode != nullptr, {}, "function's parameter is nullptr."); + MS_CHECK_TRUE_MSG(cnode->input(index) != nullptr, {}, "required input is nullptr"); + auto real_input_cnode = cnode->input(index)->cast(); + MS_CHECK_TRUE_MSG(real_input_cnode != nullptr, {}, "input node is not a cnode."); + int item_index = 0; + if (opt::CheckPrimitiveType(real_input_cnode, prim::kPrimTupleGetItem)) { + auto index_node = real_input_cnode->input(opt::kInputIndexTwo); + MS_CHECK_TRUE_MSG(index_node != nullptr, {}, "tuple_get_item's second input is nullptr."); + MS_CHECK_TRUE_MSG(index_node->isa(), {}, "tuple_get_item's second input should be valuenode."); + auto index_ptr = index_node->cast()->value(); + MS_CHECK_TRUE_MSG(index_ptr != nullptr, {}, "tuple_get_item's second input val is nullptr."); + auto value = CastToInt(index_ptr); + MS_CHECK_TRUE_MSG(value.size() == 1, {}, "tuple_get_item's second input is invalid."); + item_index = value.front(); + MS_CHECK_TRUE_MSG(real_input_cnode->input(1) != nullptr, {}, "tuple_get_item's first input is nullptr"); + real_input_cnode = real_input_cnode->input(1)->cast(); + MS_CHECK_TRUE_MSG(real_input_cnode != nullptr, {}, "tuple_get_item first input is not cnode."); + } + return {real_input_cnode, item_index}; +} + +int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ) { + MS_CHECK_TRUE_MSG(cnode != nullptr && infer_succ != nullptr, RET_ERROR, "function's parameter is nullptr."); + auto var_input_info = GetRealCertainVarInput(cnode, index); + if (var_input_info.first == nullptr) { + MS_LOG(ERROR) << "cannot get the real var input."; + return RET_ERROR; + } + auto real_input_cnode = var_input_info.first; + auto item_index = var_input_info.second; + auto input_node_prim = GetValueNode((real_input_cnode->input(0))); + MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "get primitive failed."); + *infer_succ = false; + auto value_ptr = input_node_prim->GetAttr(kInferDone); + if (value_ptr != nullptr) { + MS_CHECK_TRUE_MSG(value_ptr->isa(), RET_ERROR, "value is not a boolean."); + *infer_succ = GetValue(value_ptr); + } + value_ptr = input_node_prim->GetAttr(kInferFlags); + if (value_ptr == nullptr) { + return RET_OK; + } + MS_CHECK_TRUE_MSG(value_ptr->isa(), RET_ERROR, "infer flag should be a vector."); + auto value_sequence = value_ptr->cast(); + auto elements = value_sequence->value(); + MS_CHECK_TRUE_MSG(!elements.empty(), RET_ERROR, "infer_info has no content."); + auto first_element = elements.front(); + MS_CHECK_TRUE_MSG(first_element != nullptr, RET_ERROR, "element is a nullptr."); + MS_CHECK_TRUE_MSG(first_element->isa(), RET_ERROR, "each element is not a boolean."); + auto infer_infos = GetValue>(value_ptr); + MS_CHECK_TRUE_MSG(item_index >= 0 && static_cast(item_index) < infer_infos.size(), RET_ERROR, + "item index is out of range."); + *infer_succ = infer_infos[item_index]; + return RET_OK; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index 003c001d836..3ea19b19ca5 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -19,6 +19,7 @@ #include #include +#include #include #include "ops/primitive_c.h" #include "ir/anf.h" @@ -35,6 +36,10 @@ using mindspore::lite::RET_OK; using mindspore::lite::STATUS; namespace mindspore { namespace opt { +// used for common op, which corresponding value is a boolean. +constexpr auto kInferDone = "infer_done"; +// used for control_flow op(while and if), which corresponding value is a boolean vec. +constexpr auto kInferFlags = "infer_flags"; inline constexpr int kInputIndexTwo = 2; inline constexpr int kInputIndexThree = 3; inline constexpr int kInputIndexFour = 4; @@ -64,7 +69,7 @@ bool IsGraphKernel(const AnfNodePtr &node); bool CheckInputs(const CNodePtr &cnode); -ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id); +ParameterPtr AddNewBiasNode(const float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id); bool IsParamNode(const BaseRef &n); @@ -130,6 +135,10 @@ bool IsQuantParameterNode(const PrimitiveCPtr &prim); void UpdateManager(const FuncGraphPtr &func_graph); +std::pair GetRealCertainVarInput(const CNodePtr &cnode, size_t index); + +int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ); + template inline bool IsSpecifiedNode(const BaseRef &n) { if (utils::isa(n)) { diff --git a/mindspore/lite/tools/optimizer/format/delete_redundant_transpose.cc b/mindspore/lite/tools/optimizer/format/delete_redundant_transpose.cc index 985d92112cf..e01d5c8d9a4 100644 --- a/mindspore/lite/tools/optimizer/format/delete_redundant_transpose.cc +++ b/mindspore/lite/tools/optimizer/format/delete_redundant_transpose.cc @@ -148,7 +148,27 @@ STATUS DeleteRedundantTranspose::UpdateNodeFormat(const CNodePtr &cnode) { if (prim->GetAttr(ops::kFormat) == nullptr) { return lite::RET_OK; } - auto format = GetValue(prim->GetAttr(ops::kFormat)); + auto forward_format = GetValue(prim->GetAttr(ops::kFormat)); + const int max_search_depth{3}; + int loop{0}; + auto search_node = cnode->input(1); + while (loop < max_search_depth) { + MS_CHECK_TRUE_RET(search_node != nullptr, lite::RET_ERROR); + auto search_cnode = search_node->cast(); + if (search_cnode == nullptr) { + break; + } + auto primitive = GetCNodePrimitive(search_cnode); + if (primitive == nullptr) { + break; + } + if (primitive->GetAttr(ops::kFormat) != nullptr) { + forward_format = GetValue(primitive->GetAttr(ops::kFormat)); + break; + } + search_node = search_cnode->input(1); + ++loop; + } auto node_users = manager_->node_users()[cnode]; for (auto &node_user : node_users) { if (node_user.second != 1) { @@ -161,7 +181,7 @@ STATUS DeleteRedundantTranspose::UpdateNodeFormat(const CNodePtr &cnode) { auto post_cnode = node_user.first->cast(); auto post_prim = GetValueNode(post_cnode->input(0)); MS_ASSERT(post_prim != nullptr); - post_prim->AddAttr(ops::kFormat, MakeValue(format)); + post_prim->AddAttr(ops::kFormat, MakeValue(forward_format)); } return lite::RET_OK; } diff --git a/mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.cc b/mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.cc index 4a6dd35fd6f..173ed6156d8 100644 --- a/mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.cc +++ b/mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.cc @@ -506,36 +506,34 @@ int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGra MS_LOG(ERROR) << "Get index failed: " << e.what(); return lite::RET_ERROR; } - param_node->set_abstract(GetCNodeInputAbstract(cnode, index)->Clone()); + auto abstract = GetCNodeInputAbstract(cnode, index); + MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is a nullptr."); + param_node->set_abstract(abstract->Clone()); if (utils::isa(cnode->input(index))) { ShapeVector shape_vec = {-1}; - auto out_cnode = cnode->input(index)->cast(); - MS_ASSERT(out_cnode != nullptr); - MS_ASSERT(trans_cnode != nullptr); - auto out_prim = GetValueNode(out_cnode->input(0)); - MS_CHECK_TRUE_MSG(out_prim != nullptr, lite::RET_ERROR, "GetValueNode failed"); - if (out_prim->GetAttr(kInferDone) == nullptr || !GetValue(out_prim->GetAttr(kInferDone))) { - param_node->abstract()->set_shape(std::make_shared(shape_vec)); + bool has_inferred{false}; + auto ret = DetermineCertainVarInputHasInferred(cnode, index, &has_inferred); + MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed."); + if (!has_inferred) { + auto abstract_shape = std::make_shared(shape_vec); + MS_CHECK_TRUE_MSG(abstract_shape != nullptr, RET_ERROR, "create shape failed."); + param_node->abstract()->set_shape(abstract_shape); } - } else { + } + if (utils::isa(cnode->input(index))) { + param_node->set_default_param(cnode->input(index)->cast()->default_param()); + } + if (utils::isa(cnode->input(index))) { lite::DataInfo data_info; - if (utils::isa(cnode->input(index))) { - if (cnode->input(index)->cast()->has_default()) { - param_node->set_default_param(cnode->input(index)->cast()->default_param()); - } - continue; - } auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info, true); if (status != lite::RET_OK) { continue; } ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end()); - if (data_info.data_.empty()) { - param_node->set_default_param(std::make_shared((TypeId)data_info.data_type_, shape_vec)); - } else { - param_node->set_default_param(std::make_shared((TypeId)data_info.data_type_, shape_vec, - data_info.data_.data(), data_info.data_.size())); - } + auto tensor_info = lite::CreateTensorInfo(data_info.data_.data(), data_info.data_.size(), shape_vec, + static_cast(data_info.data_type_)); + MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_ERROR, "create a tensor failed."); + param_node->set_default_param(tensor_info); } } return lite::RET_OK; diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc index a3ef10a293f..ed84fa98eb4 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc @@ -38,37 +38,11 @@ int GetCNodeCertainInputFormat(const CNodePtr cnode, int index, mindspore::Forma cnode->set_inputs(origin_inputs); return lite::RET_NO_CHANGE; } - auto real_cnode = cnode->input(index)->cast(); - MS_ASSERT(real_cnode != nullptr); - if (CheckPrimitiveType(real_cnode, prim::kPrimTupleGetItem)) { - real_cnode = real_cnode->input(1)->cast(); + if (DetermineCertainVarInputFormat(cnode, index, format) != RET_OK) { + MS_LOG(ERROR) << "determine certain var-input's format failed."; + return RET_ERROR; } cnode->set_inputs(origin_inputs); - MS_ASSERT(real_cnode != nullptr); - auto primitive = GetValueNode(real_cnode->input(0)); - MS_CHECK_TRUE_MSG(primitive != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed"); - if (primitive->GetAttr(ops::kFormat) == nullptr) { - MS_LOG(DEBUG) << "cnode has no format attr. " << real_cnode->fullname_with_scope(); - return lite::RET_NO_CHANGE; - } - auto format_attr = primitive->GetAttr(ops::kFormat); - MS_CHECK_TRUE_MSG(format_attr != nullptr, lite::RET_NULL_PTR, "GetAttr Failed"); - *format = static_cast(GetValue(format_attr)); - if (CheckPrimitiveType(real_cnode, prim::kPrimTranspose)) { - std::vector perm; - if (GetTransposePerm(real_cnode, &perm) != lite::RET_OK) { - MS_LOG(ERROR) << "get transpose perm failed."; - return lite::RET_ERROR; - } - if (perm.size() != DIMENSION_4D) { - return RET_OK; - } - if (perm == kNH2NC && *format == mindspore::NHWC) { - *format = mindspore::NCHW; - } else if (perm == kNC2NH && *format == mindspore::NCHW) { - *format = mindspore::NHWC; - } - } return lite::RET_OK; } @@ -93,6 +67,94 @@ int ModifySubGraphInputCNodeFormat(const FuncGraphPtr &sub_graph, const Paramete } return lite::RET_OK; } + +int JudgeControlFlowCertainOutputHasInferred(const CNodePtr &return_cnode, size_t index, bool *infer_info) { + MS_ASSERT(return_cnode != nullptr && infer_info != nullptr); + MS_CHECK_TRUE_MSG(index < return_cnode->size(), RET_ERROR, "input index is out of range."); + *infer_info = true; + auto abstract_base = GetCNodeInputAbstract(return_cnode, index); + MS_CHECK_TRUE_MSG(abstract_base != nullptr, RET_ERROR, "anfnode has no abstract."); + ShapeVector shape; + auto ret = FetchShapeFromAbstract(abstract_base, &shape); + MS_CHECK_TRUE_MSG(ret == lite::RET_OK, RET_ERROR, "fetch shape from abstract failed."); + if (std::find(shape.begin(), shape.end(), -1) != shape.end()) { + *infer_info = false; + return RET_OK; + } + if (utils::isa(return_cnode->input(index))) { + ret = DetermineCertainVarInputHasInferred(return_cnode, index, infer_info); + MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed."); + } + return RET_OK; +} + +int ModifyWhileBodyGraphInputs(const CNodePtr &cnode, const FuncGraphPtr &sub_graph, const ParameterPtr &graph_input, + size_t input_index) { + MS_ASSERT(cnode != nullptr && sub_graph != nullptr && graph_input != nullptr); + if (!CheckPrimitiveType(cnode, prim::kPrimWhile)) { + return RET_OK; + } + auto body_graph = GetValueNode(cnode->input(kInputIndexTwo)); + MS_ASSERT(body_graph != nullptr); + if (body_graph.get() != sub_graph.get()) { + MS_LOG(DEBUG) << "sub_graph is not body graph."; + return RET_OK; + } + auto return_cnode = sub_graph->get_return(); + MS_CHECK_TRUE_MSG(return_cnode != nullptr, RET_ERROR, "return node is a nullptr."); + auto origin_outputs = return_cnode->inputs(); + auto ret = lite::RemoveIfDepend(return_cnode); + if (ret != RET_OK) { + return_cnode->set_inputs(origin_outputs); + MS_LOG(ERROR) << "remove depend node failed."; + return RET_ERROR; + } + ret = lite::RemoveIfMakeTuple(return_cnode); + if (ret != RET_OK) { + return_cnode->set_inputs(origin_outputs); + MS_LOG(ERROR) << "remove make_tuple node failed."; + return RET_ERROR; + } + RemoveIfMonad(return_cnode); + if (return_cnode->size() == 0 || input_index >= return_cnode->size() - 1) { + return_cnode->set_inputs(origin_outputs); + MS_LOG(ERROR) << "input index is out of range."; + return RET_ERROR; + } + auto output = return_cnode->input(input_index + 1); + return_cnode->set_inputs(origin_outputs); + MS_CHECK_TRUE_MSG(output != nullptr, RET_ERROR, "output node is a nullptr."); + if (output->isa()) { + graph_input->set_default_param(nullptr); + } + return RET_OK; +} + +int MergeTwoBranchOfIfOp(const CNodePtr &cnode, const CNodePtr &return_cnode, size_t index, bool *true_branch) { + MS_ASSERT(cnode != nullptr && return_cnode != nullptr && true_branch != nullptr); + *true_branch = true; + if (!CheckPrimitiveType(cnode, prim::kPrimIf)) { + return RET_OK; + } + bool infer_info{false}; + // judge true branch. + if (JudgeControlFlowCertainOutputHasInferred(return_cnode, index, &infer_info) != RET_OK) { + MS_LOG(ERROR) << "determine certain output has inferred failed."; + return RET_ERROR; + } + if (infer_info) { + return RET_OK; + } + // judge false branch. + if (JudgeControlFlowCertainOutputHasInferred(cnode, index + kInputSizeThree, &infer_info) != RET_OK) { + MS_LOG(ERROR) << "determine certain output has inferred failed."; + return RET_ERROR; + } + if (infer_info) { + *true_branch = false; + } + return RET_OK; +} } // namespace bool InferShapePass::Run(const FuncGraphPtr &func_graph) { @@ -190,10 +252,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) { auto sub_func_graph = GetValueNode(cnode->input(1)); if (sub_func_graph == nullptr) { lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); - return lite::RET_ERROR; + return RET_ERROR; } auto ret = SetSubGraphInput(cnode, sub_func_graph); - if (ret != lite::RET_OK) { + if (ret != RET_OK) { MS_LOG(ERROR) << "SetSubGraphInput failed: " << ret; return RET_ERROR; } @@ -255,21 +317,22 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt auto last_underline = node_name.find_last_of("_"); node_name = node_name.substr(0, last_underline); last_underline = node_name.find_last_of("_"); - auto index = 0; + size_t index = 0; try { - index = std::stoi(node_name.substr(last_underline + 1)) + 3; + index = static_cast(std::stoi(node_name.substr(last_underline + 1))) + kInputSizeThree; } catch (const std::exception &e) { MS_LOG(ERROR) << "Get index failed: " << e.what(); return RET_ERROR; } - param_node->set_abstract(opt::GetCNodeInputAbstract(cnode, index)->Clone()); - if (utils::isa(cnode->input(index))) { + auto abstract = GetCNodeInputAbstract(cnode, index); + MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is a nullptr."); + param_node->set_abstract(abstract->Clone()); + if (utils::isa(cnode->input(index))) { ShapeVector shape_vec = {-1}; - auto out_cnode = cnode->input(index)->cast(); - MS_ASSERT(trans_cnode != nullptr); - auto out_prim = GetValueNode(out_cnode->input(0)); - MS_CHECK_TRUE_MSG(out_prim != nullptr, lite::RET_ERROR, "GetValueNode Failed"); - if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue(out_prim->GetAttr(opt::kInferDone))) { + bool has_inferred{false}; + auto ret = DetermineCertainVarInputHasInferred(cnode, index, &has_inferred); + MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed."); + if (!has_inferred) { auto abstract_shape = std::make_shared(shape_vec); CHECK_NULL_RETURN(abstract_shape); param_node->abstract()->set_shape(abstract_shape); @@ -282,6 +345,7 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt if (ModifySubGraphInputCNodeFormat(sub_graph, param_node, format) != lite::RET_OK) { MS_LOG(DEBUG) << "modify subgraph input cnode format failed." << cnode->func_graph_as_var(); } + continue; } if (utils::isa(cnode->input(index))) { param_node->set_default_param(cnode->input(index)->cast()->default_param()); @@ -298,6 +362,11 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_ERROR, "created tensor is a nullptr."); param_node->set_default_param(tensor_info); } + // while's body graph:if the corresponding output is a variable, the corresponding input's data will be set to NULL. + if (ModifyWhileBodyGraphInputs(cnode, sub_graph, param_node, index - kInputSizeThree) != RET_OK) { + MS_LOG(ERROR) << "modify while body graph's certain input failed."; + return RET_ERROR; + } } return RET_OK; } @@ -342,47 +411,39 @@ STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGrap lite::RemoveIfDepend(return_node); lite::RemoveIfMakeTuple(return_node); AbstractBasePtrList abstract_list; - bool infer_done = true; + std::vector infer_infos; for (size_t i = 1; i < return_node->size(); ++i) { - auto abstract_base = opt::GetCNodeInputAbstract(return_node, i); - MS_CHECK_TRUE_RET(abstract_base != nullptr, lite::RET_ERROR); - abstract_list.emplace_back(abstract_base->Clone()); - auto abstract_tensor = abstract_base->cast(); - MS_ASSERT(abstract_tensor != nullptr); - auto shape_ptr = utils::cast(abstract_tensor->BuildShape()); - MS_ASSERT(shape_ptr != nullptr); - auto shape = shape_ptr->shape(); - if (std::find(shape.begin(), shape.end(), -1) != shape.end()) { - infer_done = false; - } - if (utils::isa(return_node->input(i))) { - auto input_cnode = return_node->input(i)->cast(); - MS_ASSERT(input_cnode != nullptr); - if (opt::CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) { - input_cnode = input_cnode->input(1)->cast(); - } - auto input_prim = GetValueNode(input_cnode->input(0)); - CHECK_NULL_RETURN(input_prim); - if (input_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue(input_prim->GetAttr(opt::kInferDone))) { - infer_done = false; + bool true_branch{false}; + auto ret = MergeTwoBranchOfIfOp(cnode, return_node, i, &true_branch); + MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "decide to fetch which branch failed."); + AbstractBasePtr abstract; + bool infer_info; + if (true_branch) { + abstract = GetCNodeInputAbstract(return_node, i); + if (JudgeControlFlowCertainOutputHasInferred(return_node, i, &infer_info) != lite::RET_OK) { + MS_LOG(ERROR) << "determine certain output has inferred failed."; + return lite::RET_ERROR; } + } else { + abstract = GetCNodeInputAbstract(cnode, i + kInputSizeThree); + infer_info = true; } + MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "get a nullptr abstract."); + abstract_list.emplace_back(abstract->Clone()); + infer_infos.push_back(infer_info); } return_node->set_inputs(origin_inputs); if (utils::isa(cnode->abstract())) { auto abstract_tuple = std::make_shared(abstract_list); - CHECK_NULL_RETURN(abstract_tuple); + MS_CHECK_TRUE_MSG(abstract_tuple != nullptr, RET_ERROR, "created AbstractTuple is a nullptr."); cnode->set_abstract(abstract_tuple); } else { - if (abstract_list.size() != 1) { - MS_LOG(ERROR) << "cnode output is invalid."; - } + MS_CHECK_TRUE_MSG(abstract_list.size() == 1, RET_ERROR, "cnode output is invalid."); cnode->set_abstract(abstract_list.front()); } auto prim = GetValueNode(cnode->input(0)); - CHECK_NULL_RETURN(prim); - infer_done = CheckPrimitiveType(cnode, prim::kPrimWhile) ? false : infer_done; - prim->AddAttr(opt::kInferDone, MakeValue(infer_done)); + MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "cnode's input0 is not a primitive."); + prim->AddAttr(kInferFlags, MakeValue(infer_infos)); return RET_OK; } diff --git a/mindspore/lite/tools/optimizer/graph/lite_tensor_extractor.cc b/mindspore/lite/tools/optimizer/graph/lite_tensor_extractor.cc index 7f1694be11e..0d205a7bf10 100644 --- a/mindspore/lite/tools/optimizer/graph/lite_tensor_extractor.cc +++ b/mindspore/lite/tools/optimizer/graph/lite_tensor_extractor.cc @@ -197,18 +197,10 @@ int GetCNodeVarInput(const CNodePtr &cnode, std::vector *var_ms_input return lite::RET_ERROR; } tensor->set_format((Format)(data_info.format_)); - MS_ASSERT(cnode->input(i) != nullptr); - auto input_cnode = cnode->input(i)->cast(); - MS_ASSERT(input_cnode != nullptr); - auto input_prim = GetValueNode(input_cnode->input(0)); - if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) { - MS_CHECK_TRUE_RET(input_cnode->input(1) != nullptr, lite::RET_NULL_PTR); - auto item_input_cnode = input_cnode->input(1)->cast(); - MS_CHECK_TRUE_RET(item_input_cnode != nullptr, lite::RET_NULL_PTR); - input_prim = GetValueNode(item_input_cnode->input(0)); - } - MS_CHECK_TRUE_RET(input_prim != nullptr, lite::RET_NULL_PTR); - if (input_prim->GetAttr(kInferDone) == nullptr || !GetValue(input_prim->GetAttr(kInferDone))) { + bool has_inferred{false}; + auto ret = DetermineCertainVarInputHasInferred(cnode, i, &has_inferred); + MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed."); + if (!has_inferred) { tensor->set_shape({-1}); } var_ms_inputs->emplace_back(tensor); diff --git a/mindspore/lite/tools/optimizer/graph/node_infershape.cc b/mindspore/lite/tools/optimizer/graph/node_infershape.cc index a8a42f77213..df48e85de06 100644 --- a/mindspore/lite/tools/optimizer/graph/node_infershape.cc +++ b/mindspore/lite/tools/optimizer/graph/node_infershape.cc @@ -142,9 +142,7 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) { } if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) { auto set_status = SetCNodeAbstract(cnode, outputs, ret); - auto cnode_prim = GetValueNode(cnode->input(0)); - MS_CHECK_TRUE_MSG(cnode_prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed"); - cnode_prim->AddAttr(ops::kFormat, MakeValue(inputs[0]->format())); + anf_prim->AddAttr(ops::kFormat, MakeValue(inputs[0]->format())); if (set_status != lite::RET_OK) { MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope(); return set_status; @@ -152,6 +150,12 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) { } else { MS_LOG(WARNING) << "infer shape failed."; } + if (CheckPrimitiveType(cnode, prim::kPrimCustom)) { + std::vector outputs_format; + std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_format), + [](const lite::Tensor *output) { return output->format(); }); + anf_prim->AddAttr(kOutputsFormat, MakeValue(outputs_format)); + } return ret; }