optimize the infer shape of control-flow models and the acquisition of tensor format
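
This change replaces the local SetFormatForCnode helper with the shared opt::DetermineCertainVarInputFormat, which resolves TupleGetItem indirection through GetRealCertainVarInput, reads the kFormat and kOutputsFormat attributes (defaulting to NHWC), and compensates for a leading Transpose. Control-flow ops (While and If) now record per-output infer status in a kInferFlags vector instead of a single kInferDone boolean, and several std::accumulate reductions are rewritten as overflow-checked loops. A minimal caller sketch of the new format query, condensed from the call sites in this diff:

Format format{mindspore::NHWC};  // default when no format attr is found
if (opt::DetermineCertainVarInputFormat(cnode, index, &format) != RET_OK) {
  MS_LOG(ERROR) << "determine certain var-input's format failed.";
  return RET_ERROR;
}
data_info->format_ = format;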

xuanyue 2021-12-18 15:39:14 +08:00
parent 7a1c25c08d
commit ecd207422c
10 changed files with 381 additions and 179 deletions

View File

@@ -360,38 +360,6 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
return ret;
}
int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info) {
data_info->format_ = mindspore::NHWC;
MS_CHECK_TRUE_MSG(cnode->input(index) != nullptr, RET_ERROR, "input is nullptr");
auto input_node_prim = GetValueNode<PrimitivePtr>((cnode->input(index)->cast<CNodePtr>()->input(0)));
MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "GetValueNode failed");
if (input_node_prim->GetAttr(mindspore::ops::kFormat) != nullptr) {
auto value = input_node_prim->GetAttr(mindspore::ops::kFormat);
if (value->isa<mindspore::Int64Imm>()) {
data_info->format_ = GetValue<int64_t>(value);
}
}
if (opt::CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) {
std::vector<int> perm;
if (opt::GetTransposePerm(cnode->input(index)->cast<CNodePtr>(), &perm) != RET_OK) {
return RET_ERROR;
}
if (perm.size() < kNumTransposePermSize) {
return RET_OK;
}
// NHWC to NCHW: perm is {0, 3, 1, 2}
// NCHW to NHWC: perm is {0, 2, 3, 1}
if (perm[0] == 0 && perm[1] == 3 && perm[2] == 1 && perm[3] == 2 &&
(data_info->format_ == NHWC || data_info->format_ == KHWC)) {
data_info->format_ = NCHW;
} else if (perm[0] == 0 && perm[1] == 2 && perm[2] == 3 && perm[3] == 1 && data_info->format_ == NCHW) {
data_info->format_ = NHWC;
}
}
return RET_OK;
}
int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info) {
MS_ASSERT(cnode != nullptr && data_info != nullptr);
@@ -414,11 +382,13 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType f
}
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
auto ret = SetFormatForCnode(cnode, index, fmk_type, train_flag, data_info);
Format format{mindspore::NHWC};
auto ret = opt::DetermineCertainVarInputFormat(cnode, index, &format);
if (ret != RET_OK) {
MS_LOG(ERROR) << "set format for cnode failed";
return RET_ERROR;
}
data_info->format_ = format;
data_info->data_type_ = type_ptr->type_id();
data_info->shape_ = dims;
data_info->node_type_ = NodeType_CNode;

View File

@@ -223,5 +223,52 @@ bool IsSpecialType(const CNodePtr &cnode) {
CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, kPrimMakeTupleV2) ||
CheckPrimitiveType(cnode, prim::kPrimReturn);
}
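// Determine the format of cnode's var input at the given index: resolve the real producing cnode
// (looking through TupleGetItem), default to NHWC, then apply the kFormat/kOutputsFormat attrs and
// adjust for a leading 4-D Transpose.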
int DetermineCertainVarInputFormat(const CNodePtr &cnode, size_t index, Format *format) {
MS_CHECK_TRUE_MSG(cnode != nullptr && format != nullptr, RET_ERROR, "function's parameter is nullptr.");
auto var_input_info = GetRealCertainVarInput(cnode, index);
if (var_input_info.first == nullptr) {
MS_LOG(ERROR) << "cannot get the real var input.";
return RET_ERROR;
}
*format = mindspore::NHWC;
auto real_input_cnode = var_input_info.first;
auto item_index = var_input_info.second;
auto input_node_prim = GetValueNode<PrimitivePtr>((real_input_cnode->input(0)));
MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "get primitive failed");
auto value_ptr = input_node_prim->GetAttr(ops::kFormat);
if (value_ptr != nullptr) {
MS_CHECK_TRUE_MSG(value_ptr->isa<mindspore::Int64Imm>(), RET_ERROR, "format attr must be an int64_t val.");
auto value = GetValue<int64_t>(value_ptr);
MS_CHECK_TRUE_MSG(value >= NCHW && value <= NCW, RET_ERROR, "format val is out of enum's range.");
*format = static_cast<Format>(value);
}
value_ptr = input_node_prim->GetAttr(kOutputsFormat);
if (value_ptr != nullptr) {
MS_CHECK_TRUE_MSG(value_ptr->isa<ValueSequeue>(), RET_ERROR, "outputs_format attr should be sequence.");
auto formats = CastToInt(value_ptr);
if (item_index >= 0 && static_cast<size_t>(item_index) < formats.size()) {
MS_CHECK_TRUE_MSG(formats[item_index] >= NCHW && formats[item_index] <= NCW, RET_ERROR,
"format val is out of enum's range.");
*format = static_cast<Format>(formats[item_index]);
}
}
if (CheckPrimitiveType(real_input_cnode, prim::kPrimTranspose)) {
std::vector<int> perm;
if (GetTransposePerm(real_input_cnode, &perm) != RET_OK) {
MS_LOG(ERROR) << "fetch transpose's perm failed.";
return RET_ERROR;
}
if (perm.size() != kNC2NH.size()) {
return RET_OK;
}
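// NHWC to NCHW: perm is {0, 3, 1, 2}; NCHW to NHWC: perm is {0, 2, 3, 1}.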
if (perm == kNH2NC && (*format == NHWC || *format == KHWC)) {
*format = NCHW;
} else if (perm == kNC2NH && *format == NCHW) {
*format = NHWC;
}
}
return RET_OK;
}
} // namespace opt
} // namespace mindspore

View File

@@ -25,7 +25,7 @@
namespace mindspore {
namespace opt {
constexpr auto kInferDone = "infer_done";
constexpr auto kOutputsFormat = "outputs_format";
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE };
struct TransTypePair {
FormatTransNodeType pre_;
@@ -42,6 +42,7 @@ STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm);
void RemoveIfMonad(const CNodePtr &cnode);
bool IsMonadNode(const AnfNodePtr &node);
bool IsSpecialType(const CNodePtr &cnode);
int DetermineCertainVarInputFormat(const CNodePtr &cnode, size_t index, Format *format);
} // namespace opt
} // namespace mindspore

View File

@@ -39,7 +39,7 @@ namespace opt {
namespace {
constexpr auto kAnfPrimitiveIndex = 0;
constexpr auto kDeviceTypeNone = -1;
int DeduceDimConvertion(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
int DeduceDimConvertion(schema::Format src_format, schema::Format dst_format, std::vector<int> *const perm) {
MS_ASSERT(perm != nullptr);
auto src_format_str = std::string(schema::EnumNameFormat(src_format));
auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
@@ -66,10 +66,16 @@ int DeduceDimConvertion(schema::Format src_format, schema::Format dst_format, st
template <typename T>
void TransposeData(const ShapeVector &origin_shape, const ShapeVector &cur_shape, const std::vector<int> &perm,
T *weight_data, std::vector<T> *buf) {
T *const weight_data, std::vector<T> *buf) {
MS_ASSERT(weight_data != nullptr && buf != nullptr);
MS_ASSERT(origin_shape.size() == cur_shape.size() && cur_shape.size() == perm.size());
int count = std::accumulate(origin_shape.begin(), origin_shape.end(), 1, std::multiplies<int>());
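// Accumulate the element count manually so that each multiplication can be overflow-checked.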
int count = 1;
for (const auto &dat : origin_shape) {
if (INT_MUL_OVERFLOW(count, static_cast<int>(dat))) {
return;
}
count *= static_cast<int>(dat);
}
ShapeVector post_multiply(cur_shape.size());
std::unordered_map<int, int> dim_map;
for (int i = cur_shape.size() - 1; i >= 0; --i) {
@@ -88,10 +94,17 @@ void TransposeData(const ShapeVector &origin_shape, const ShapeVector &cur_shape
position_map[j] = temp % origin_shape[j];
temp /= origin_shape[j];
}
int64_t new_pos = std::accumulate(position_map.begin(), position_map.end(), 0,
[&post_multiply, &dim_map](int64_t res, const std::pair<int, int> &pair_y) {
return res + post_multiply[dim_map[pair_y.first]] * pair_y.second;
});
int64_t new_pos = 0;
for (const auto &pair_y : position_map) {
if (INT_MUL_OVERFLOW(post_multiply[dim_map[pair_y.first]], pair_y.second)) {
return;
}
if (INT_ADD_OVERFLOW(new_pos, post_multiply[dim_map[pair_y.first]] * pair_y.second)) {
return;
}
new_pos += post_multiply[dim_map[pair_y.first]] * pair_y.second;
}
buf->at(new_pos) = weight_data[i];
}
}
@@ -121,7 +134,14 @@ STATUS DoTransposeData(const tensor::TensorPtr &tensor, schema::Format src_forma
}
new_shape.push_back(origin_shape[val]);
}
auto count = std::accumulate(origin_shape.begin(), origin_shape.end(), 1LL, std::multiplies<int64_t>());
int64_t count = 1;
for (const auto &dat : origin_shape) {
if (INT_MUL_OVERFLOW(count, dat)) {
MS_LOG(ERROR) << "Int mul overflow";
return RET_ERROR;
}
count *= dat;
}
if (count <= 0 || count > static_cast<int64_t>(INT32_MAX)) {
MS_LOG(ERROR) << "tensor element num is too big, which should be smaller than int32_max.";
return RET_ERROR;
@@ -174,6 +194,38 @@ bool IsRealKernel(const AnfNodePtr &node) {
#endif
return !is_virtual_node;
}
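// Copy tensor_info's data into tensor_info_dst; int64 payloads are narrowed to int32, saturating
// values that fall outside the int32 range.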
int CopyTensorDataFromTensorInfo(const tensor::TensorPtr &tensor_info,
const std::shared_ptr<tensor::Tensor> &tensor_info_dst, size_t data_count) {
if (tensor_info->data_type() == kNumberTypeInt64) {
auto *tensor_data = reinterpret_cast<int *>(tensor_info_dst->data_c());
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "new data failed";
return RET_ERROR;
}
auto *origin_data = reinterpret_cast<int64_t *>(tensor_info->data_c());
for (size_t i = 0; i < data_count; ++i) {
if (origin_data[i] > static_cast<int64_t>(INT32_MAX) || origin_data[i] < static_cast<int64_t>(INT32_MIN)) {
MS_LOG(WARNING) << "int64 data " << origin_data[i] << "too big to fit into int32";
tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN;
} else {
tensor_data[i] = static_cast<int>(origin_data[i]);
}
}
} else {
tensor_info_dst->set_data_type(tensor_info->data_type());
auto *tensor_data = reinterpret_cast<int8_t *>(tensor_info_dst->data_c());
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "new data failed";
return RET_ERROR;
}
if (memcpy_s(tensor_data, tensor_info_dst->Size(), tensor_info->data_c(), tensor_info->Size()) != lite::RET_OK) {
MS_LOG(ERROR) << "memcpy data failed.";
return RET_ERROR;
}
}
return RET_OK;
}
} // namespace
bool CheckInputs(const CNodePtr &cnode) {
@@ -351,7 +403,7 @@ bool IsGraphKernel(const AnfNodePtr &node) {
return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id) {
ParameterPtr AddNewBiasNode(const float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id) {
if (bias_data == nullptr || func_graph == nullptr) {
MS_LOG(ERROR) << "input parameter is nullptr.";
return nullptr;
@@ -461,7 +513,7 @@ bool IsParamNode(const BaseRef &n) {
return tensor->data_c() != nullptr;
}
STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr &cnode, size_t index) {
STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *const tensor_info, const CNodePtr &cnode, size_t index) {
CHECK_NULL_RETURN(tensor_info);
CHECK_NULL_RETURN(cnode);
AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, index);
@@ -673,7 +725,18 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr
MS_LOG(ERROR) << "new tensor::Tensor failed.";
return nullptr;
}
size_t data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
int data_count = 1;
for (const auto &dat : shape) {
if (INT_MUL_OVERFLOW(data_count, static_cast<int>(dat))) {
MS_LOG(ERROR) << "Int mul overflow.";
return nullptr;
}
data_count *= static_cast<int>(dat);
}
if (data_count < 0) {
MS_LOG(ERROR) << "Invalid shape.";
return nullptr;
}
if (tensor_info->Size() == 0) {
auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new);
if (status != RET_OK) {
@@ -682,33 +745,12 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr
}
return param_node;
}
if (tensor_info->data_type() == kNumberTypeInt64) {
auto *tensor_data = reinterpret_cast<int *>(tensor_info_new->data_c());
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "new data failed";
return nullptr;
}
auto *origin_data = reinterpret_cast<int64_t *>(tensor_info->data_c());
for (size_t i = 0; i < data_count; ++i) {
if (origin_data[i] > static_cast<int64_t>(INT32_MAX) || origin_data[i] < static_cast<int64_t>(INT32_MIN)) {
MS_LOG(WARNING) << "int64 data " << origin_data[i] << "too big to fit into int32";
tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN;
} else {
tensor_data[i] = static_cast<int>(origin_data[i]);
}
}
} else {
tensor_info_new->set_data_type(tensor_info->data_type());
auto *tensor_data = reinterpret_cast<int8_t *>(tensor_info_new->data_c());
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "new data failed";
return nullptr;
}
if (memcpy_s(tensor_data, tensor_info_new->Size(), tensor_info->data_c(), tensor_info->Size()) != lite::RET_OK) {
MS_LOG(ERROR) << "memcpy data failed.";
return nullptr;
}
if (CopyTensorDataFromTensorInfo(tensor_info, tensor_info_new, static_cast<size_t>(data_count)) != RET_OK) {
MS_LOG(ERROR) << "copy tensor data failed";
return nullptr;
}
auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new);
if (status != RET_OK) {
MS_LOG(ERROR) << "init parameter from tensor info failed";
@@ -1012,6 +1054,7 @@ bool IsQuantParameterNode(const PrimitiveCPtr &prim) {
}
return false;
}
void UpdateManager(const FuncGraphPtr &func_graph) {
auto manager = func_graph->manager();
if (manager == nullptr) {
@@ -1026,5 +1069,62 @@ void UpdateManager(const FuncGraphPtr &func_graph) {
manager->AddFuncGraph(one_func_graph);
}
}
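// Find the cnode that really produces cnode's input at the given index, looking through
// TupleGetItem; returns the producing cnode together with the output item index.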
std::pair<CNodePtr, int> GetRealCertainVarInput(const CNodePtr &cnode, size_t index) {
MS_CHECK_TRUE_MSG(cnode != nullptr, {}, "function's parameter is nullptr.");
MS_CHECK_TRUE_MSG(cnode->input(index) != nullptr, {}, "required input is nullptr");
auto real_input_cnode = cnode->input(index)->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(real_input_cnode != nullptr, {}, "input node is not a cnode.");
int item_index = 0;
if (opt::CheckPrimitiveType(real_input_cnode, prim::kPrimTupleGetItem)) {
auto index_node = real_input_cnode->input(opt::kInputIndexTwo);
MS_CHECK_TRUE_MSG(index_node != nullptr, {}, "tuple_get_item's second input is nullptr.");
MS_CHECK_TRUE_MSG(index_node->isa<ValueNode>(), {}, "tuple_get_item's second input should be valuenode.");
auto index_ptr = index_node->cast<ValueNodePtr>()->value();
MS_CHECK_TRUE_MSG(index_ptr != nullptr, {}, "tuple_get_item's second input val is nullptr.");
auto value = CastToInt(index_ptr);
MS_CHECK_TRUE_MSG(value.size() == 1, {}, "tuple_get_item's second input is invalid.");
item_index = value.front();
MS_CHECK_TRUE_MSG(real_input_cnode->input(1) != nullptr, {}, "tuple_get_item's first input is nullptr");
real_input_cnode = real_input_cnode->input(1)->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(real_input_cnode != nullptr, {}, "tuple_get_item first input is not cnode.");
}
return {real_input_cnode, item_index};
}
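// Query whether the var input at the given index has finished shape inference: common ops store a
// boolean kInferDone attr, control-flow ops a per-output kInferFlags vector.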
int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ) {
MS_CHECK_TRUE_MSG(cnode != nullptr && infer_succ != nullptr, RET_ERROR, "function's parameter is nullptr.");
auto var_input_info = GetRealCertainVarInput(cnode, index);
if (var_input_info.first == nullptr) {
MS_LOG(ERROR) << "cannot get the real var input.";
return RET_ERROR;
}
auto real_input_cnode = var_input_info.first;
auto item_index = var_input_info.second;
auto input_node_prim = GetValueNode<PrimitivePtr>((real_input_cnode->input(0)));
MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "get primitive failed.");
*infer_succ = false;
auto value_ptr = input_node_prim->GetAttr(kInferDone);
if (value_ptr != nullptr) {
MS_CHECK_TRUE_MSG(value_ptr->isa<BoolImm>(), RET_ERROR, "value is not a boolean.");
*infer_succ = GetValue<bool>(value_ptr);
}
value_ptr = input_node_prim->GetAttr(kInferFlags);
if (value_ptr == nullptr) {
return RET_OK;
}
MS_CHECK_TRUE_MSG(value_ptr->isa<ValueSequeue>(), RET_ERROR, "infer flag should be a vector.");
auto value_sequence = value_ptr->cast<ValueSequeuePtr>();
auto elements = value_sequence->value();
MS_CHECK_TRUE_MSG(!elements.empty(), RET_ERROR, "infer_info has no content.");
auto first_element = elements.front();
MS_CHECK_TRUE_MSG(first_element != nullptr, RET_ERROR, "element is a nullptr.");
MS_CHECK_TRUE_MSG(first_element->isa<BoolImm>(), RET_ERROR, "each element is not a boolean.");
auto infer_infos = GetValue<std::vector<bool>>(value_ptr);
MS_CHECK_TRUE_MSG(item_index >= 0 && static_cast<size_t>(item_index) < infer_infos.size(), RET_ERROR,
"item index is out of range.");
*infer_succ = infer_infos[item_index];
return RET_OK;
}
} // namespace opt
} // namespace mindspore

View File

@@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ops/primitive_c.h"
#include "ir/anf.h"
@@ -35,6 +36,10 @@ using mindspore::lite::RET_OK;
using mindspore::lite::STATUS;
namespace mindspore {
namespace opt {
// Used for common ops; the corresponding attr value is a boolean.
constexpr auto kInferDone = "infer_done";
// Used for control-flow ops (While and If); the corresponding attr value is a vector of booleans.
constexpr auto kInferFlags = "infer_flags";
inline constexpr int kInputIndexTwo = 2;
inline constexpr int kInputIndexThree = 3;
inline constexpr int kInputIndexFour = 4;
@@ -64,7 +69,7 @@ bool IsGraphKernel(const AnfNodePtr &node);
bool CheckInputs(const CNodePtr &cnode);
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id);
ParameterPtr AddNewBiasNode(const float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id);
bool IsParamNode(const BaseRef &n);
@@ -130,6 +135,10 @@ bool IsQuantParameterNode(const PrimitiveCPtr &prim);
void UpdateManager(const FuncGraphPtr &func_graph);
std::pair<CNodePtr, int> GetRealCertainVarInput(const CNodePtr &cnode, size_t index);
int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ);
template <const PrimitivePtr *prim = nullptr>
inline bool IsSpecifiedNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {

View File

@@ -148,7 +148,27 @@ STATUS DeleteRedundantTranspose::UpdateNodeFormat(const CNodePtr &cnode) {
if (prim->GetAttr(ops::kFormat) == nullptr) {
return lite::RET_OK;
}
auto format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
auto forward_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
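// The format attr may live a few producers upstream; walk at most max_search_depth first inputs to
// find the nearest one.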
const int max_search_depth{3};
int loop{0};
auto search_node = cnode->input(1);
while (loop < max_search_depth) {
MS_CHECK_TRUE_RET(search_node != nullptr, lite::RET_ERROR);
auto search_cnode = search_node->cast<CNodePtr>();
if (search_cnode == nullptr) {
break;
}
auto primitive = GetCNodePrimitive(search_cnode);
if (primitive == nullptr) {
break;
}
if (primitive->GetAttr(ops::kFormat) != nullptr) {
forward_format = GetValue<int64_t>(primitive->GetAttr(ops::kFormat));
break;
}
search_node = search_cnode->input(1);
++loop;
}
auto node_users = manager_->node_users()[cnode];
for (auto &node_user : node_users) {
if (node_user.second != 1) {
@@ -161,7 +181,7 @@ STATUS DeleteRedundantTranspose::UpdateNodeFormat(const CNodePtr &cnode) {
auto post_cnode = node_user.first->cast<CNodePtr>();
auto post_prim = GetValueNode<PrimitivePtr>(post_cnode->input(0));
MS_ASSERT(post_prim != nullptr);
post_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(format));
post_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(forward_format));
}
return lite::RET_OK;
}

View File

@@ -506,36 +506,34 @@ int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGra
MS_LOG(ERROR) << "Get index failed: " << e.what();
return lite::RET_ERROR;
}
param_node->set_abstract(GetCNodeInputAbstract(cnode, index)->Clone());
auto abstract = GetCNodeInputAbstract(cnode, index);
MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is a nullptr.");
param_node->set_abstract(abstract->Clone());
if (utils::isa<CNodePtr>(cnode->input(index))) {
ShapeVector shape_vec = {-1};
auto out_cnode = cnode->input(index)->cast<CNodePtr>();
MS_ASSERT(out_cnode != nullptr);
MS_ASSERT(trans_cnode != nullptr);
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
MS_CHECK_TRUE_MSG(out_prim != nullptr, lite::RET_ERROR, "GetValueNode failed");
if (out_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(kInferDone))) {
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
bool has_inferred{false};
auto ret = DetermineCertainVarInputHasInferred(cnode, index, &has_inferred);
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed.");
if (!has_inferred) {
auto abstract_shape = std::make_shared<abstract::Shape>(shape_vec);
MS_CHECK_TRUE_MSG(abstract_shape != nullptr, RET_ERROR, "create shape failed.");
param_node->abstract()->set_shape(abstract_shape);
}
} else {
}
if (utils::isa<Parameter>(cnode->input(index))) {
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
}
if (utils::isa<ValueNode>(cnode->input(index))) {
lite::DataInfo data_info;
if (utils::isa<ParameterPtr>(cnode->input(index))) {
if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
}
continue;
}
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info, true);
if (status != lite::RET_OK) {
continue;
}
ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
if (data_info.data_.empty()) {
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec));
} else {
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
data_info.data_.data(), data_info.data_.size()));
}
auto tensor_info = lite::CreateTensorInfo(data_info.data_.data(), data_info.data_.size(), shape_vec,
static_cast<TypeId>(data_info.data_type_));
MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_ERROR, "create a tensor failed.");
param_node->set_default_param(tensor_info);
}
}
return lite::RET_OK;

View File

@@ -38,37 +38,11 @@ int GetCNodeCertainInputFormat(const CNodePtr cnode, int index, mindspore::Forma
cnode->set_inputs(origin_inputs);
return lite::RET_NO_CHANGE;
}
auto real_cnode = cnode->input(index)->cast<CNodePtr>();
MS_ASSERT(real_cnode != nullptr);
if (CheckPrimitiveType(real_cnode, prim::kPrimTupleGetItem)) {
real_cnode = real_cnode->input(1)->cast<CNodePtr>();
if (DetermineCertainVarInputFormat(cnode, index, format) != RET_OK) {
MS_LOG(ERROR) << "determine certain var-input's format failed.";
return RET_ERROR;
}
cnode->set_inputs(origin_inputs);
MS_ASSERT(real_cnode != nullptr);
auto primitive = GetValueNode<PrimitivePtr>(real_cnode->input(0));
MS_CHECK_TRUE_MSG(primitive != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
if (primitive->GetAttr(ops::kFormat) == nullptr) {
MS_LOG(DEBUG) << "cnode has no format attr. " << real_cnode->fullname_with_scope();
return lite::RET_NO_CHANGE;
}
auto format_attr = primitive->GetAttr(ops::kFormat);
MS_CHECK_TRUE_MSG(format_attr != nullptr, lite::RET_NULL_PTR, "GetAttr Failed");
*format = static_cast<mindspore::Format>(GetValue<int64_t>(format_attr));
if (CheckPrimitiveType(real_cnode, prim::kPrimTranspose)) {
std::vector<int> perm;
if (GetTransposePerm(real_cnode, &perm) != lite::RET_OK) {
MS_LOG(ERROR) << "get transpose perm failed.";
return lite::RET_ERROR;
}
if (perm.size() != DIMENSION_4D) {
return RET_OK;
}
if (perm == kNH2NC && *format == mindspore::NHWC) {
*format = mindspore::NCHW;
} else if (perm == kNC2NH && *format == mindspore::NCHW) {
*format = mindspore::NHWC;
}
}
return lite::RET_OK;
}
@@ -93,6 +67,94 @@ int ModifySubGraphInputCNodeFormat(const FuncGraphPtr &sub_graph, const Paramete
}
return lite::RET_OK;
}
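// An output counts as inferred when its shape contains no -1 and, when it is produced by a cnode,
// the producer's infer flag is set.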
int JudgeControlFlowCertainOutputHasInferred(const CNodePtr &return_cnode, size_t index, bool *infer_info) {
MS_ASSERT(return_cnode != nullptr && infer_info != nullptr);
MS_CHECK_TRUE_MSG(index < return_cnode->size(), RET_ERROR, "input index is out of range.");
*infer_info = true;
auto abstract_base = GetCNodeInputAbstract(return_cnode, index);
MS_CHECK_TRUE_MSG(abstract_base != nullptr, RET_ERROR, "anfnode has no abstract.");
ShapeVector shape;
auto ret = FetchShapeFromAbstract(abstract_base, &shape);
MS_CHECK_TRUE_MSG(ret == lite::RET_OK, RET_ERROR, "fetch shape from abstract failed.");
if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
*infer_info = false;
return RET_OK;
}
if (utils::isa<CNodePtr>(return_cnode->input(index))) {
ret = DetermineCertainVarInputHasInferred(return_cnode, index, infer_info);
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed.");
}
return RET_OK;
}
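// Only for While's body graph: when the output paired with this input is produced by a cnode (its
// value varies per iteration), clear the input parameter's default data.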
int ModifyWhileBodyGraphInputs(const CNodePtr &cnode, const FuncGraphPtr &sub_graph, const ParameterPtr &graph_input,
size_t input_index) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr && graph_input != nullptr);
if (!CheckPrimitiveType(cnode, prim::kPrimWhile)) {
return RET_OK;
}
auto body_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
MS_ASSERT(body_graph != nullptr);
if (body_graph.get() != sub_graph.get()) {
MS_LOG(DEBUG) << "sub_graph is not body graph.";
return RET_OK;
}
auto return_cnode = sub_graph->get_return();
MS_CHECK_TRUE_MSG(return_cnode != nullptr, RET_ERROR, "return node is a nullptr.");
auto origin_outputs = return_cnode->inputs();
auto ret = lite::RemoveIfDepend(return_cnode);
if (ret != RET_OK) {
return_cnode->set_inputs(origin_outputs);
MS_LOG(ERROR) << "remove depend node failed.";
return RET_ERROR;
}
ret = lite::RemoveIfMakeTuple(return_cnode);
if (ret != RET_OK) {
return_cnode->set_inputs(origin_outputs);
MS_LOG(ERROR) << "remove make_tuple node failed.";
return RET_ERROR;
}
RemoveIfMonad(return_cnode);
if (return_cnode->size() == 0 || input_index >= return_cnode->size() - 1) {
return_cnode->set_inputs(origin_outputs);
MS_LOG(ERROR) << "input index is out of range.";
return RET_ERROR;
}
auto output = return_cnode->input(input_index + 1);
return_cnode->set_inputs(origin_outputs);
MS_CHECK_TRUE_MSG(output != nullptr, RET_ERROR, "output node is a nullptr.");
if (output->isa<CNode>()) {
graph_input->set_default_param(nullptr);
}
return RET_OK;
}
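// For an If op, decide which branch supplies the output abstract at this index: keep the true
// branch if it has inferred; otherwise fall back to the false branch.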
int MergeTwoBranchOfIfOp(const CNodePtr &cnode, const CNodePtr &return_cnode, size_t index, bool *true_branch) {
MS_ASSERT(cnode != nullptr && return_cnode != nullptr && true_branch != nullptr);
*true_branch = true;
if (!CheckPrimitiveType(cnode, prim::kPrimIf)) {
return RET_OK;
}
bool infer_info{false};
// Judge the true branch.
if (JudgeControlFlowCertainOutputHasInferred(return_cnode, index, &infer_info) != RET_OK) {
MS_LOG(ERROR) << "determine certain output has inferred failed.";
return RET_ERROR;
}
if (infer_info) {
return RET_OK;
}
// Judge the false branch.
if (JudgeControlFlowCertainOutputHasInferred(cnode, index + kInputSizeThree, &infer_info) != RET_OK) {
MS_LOG(ERROR) << "determine certain output has inferred failed.";
return RET_ERROR;
}
if (infer_info) {
*true_branch = false;
}
return RET_OK;
}
} // namespace
bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
@@ -190,10 +252,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return lite::RET_ERROR;
return RET_ERROR;
}
auto ret = SetSubGraphInput(cnode, sub_func_graph);
if (ret != lite::RET_OK) {
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetSubGraphInput failed: " << ret;
return RET_ERROR;
}
@@ -255,21 +317,22 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
auto last_underline = node_name.find_last_of("_");
node_name = node_name.substr(0, last_underline);
last_underline = node_name.find_last_of("_");
auto index = 0;
size_t index = 0;
try {
index = std::stoi(node_name.substr(last_underline + 1)) + 3;
index = static_cast<size_t>(std::stoi(node_name.substr(last_underline + 1))) + kInputSizeThree;
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Get index failed: " << e.what();
return RET_ERROR;
}
param_node->set_abstract(opt::GetCNodeInputAbstract(cnode, index)->Clone());
if (utils::isa<CNodePtr>(cnode->input(index))) {
auto abstract = GetCNodeInputAbstract(cnode, index);
MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is a nullptr.");
param_node->set_abstract(abstract->Clone());
if (utils::isa<CNode>(cnode->input(index))) {
ShapeVector shape_vec = {-1};
auto out_cnode = cnode->input(index)->cast<CNodePtr>();
MS_ASSERT(trans_cnode != nullptr);
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
MS_CHECK_TRUE_MSG(out_prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(opt::kInferDone))) {
bool has_inferred{false};
auto ret = DetermineCertainVarInputHasInferred(cnode, index, &has_inferred);
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed.");
if (!has_inferred) {
auto abstract_shape = std::make_shared<abstract::Shape>(shape_vec);
CHECK_NULL_RETURN(abstract_shape);
param_node->abstract()->set_shape(abstract_shape);
@@ -282,6 +345,7 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
if (ModifySubGraphInputCNodeFormat(sub_graph, param_node, format) != lite::RET_OK) {
MS_LOG(DEBUG) << "modify subgraph input cnode format failed." << cnode->func_graph_as_var();
}
continue;
}
if (utils::isa<Parameter>(cnode->input(index))) {
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
@@ -298,6 +362,11 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_ERROR, "created tensor is a nullptr.");
param_node->set_default_param(tensor_info);
}
// While's body graph: if the corresponding output is a variable, the corresponding input's default data is set to nullptr.
if (ModifyWhileBodyGraphInputs(cnode, sub_graph, param_node, index - kInputSizeThree) != RET_OK) {
MS_LOG(ERROR) << "modify while body graph's certain input failed.";
return RET_ERROR;
}
}
return RET_OK;
}
@@ -342,47 +411,39 @@ STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGrap
lite::RemoveIfDepend(return_node);
lite::RemoveIfMakeTuple(return_node);
AbstractBasePtrList abstract_list;
bool infer_done = true;
std::vector<bool> infer_infos;
for (size_t i = 1; i < return_node->size(); ++i) {
auto abstract_base = opt::GetCNodeInputAbstract(return_node, i);
MS_CHECK_TRUE_RET(abstract_base != nullptr, lite::RET_ERROR);
abstract_list.emplace_back(abstract_base->Clone());
auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
MS_ASSERT(abstract_tensor != nullptr);
auto shape_ptr = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape());
MS_ASSERT(shape_ptr != nullptr);
auto shape = shape_ptr->shape();
if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
infer_done = false;
}
if (utils::isa<CNodePtr>(return_node->input(i))) {
auto input_cnode = return_node->input(i)->cast<CNodePtr>();
MS_ASSERT(input_cnode != nullptr);
if (opt::CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
input_cnode = input_cnode->input(1)->cast<CNodePtr>();
}
auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
CHECK_NULL_RETURN(input_prim);
if (input_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(opt::kInferDone))) {
infer_done = false;
bool true_branch{false};
auto ret = MergeTwoBranchOfIfOp(cnode, return_node, i, &true_branch);
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "decide to fetch which branch failed.");
AbstractBasePtr abstract;
bool infer_info;
if (true_branch) {
abstract = GetCNodeInputAbstract(return_node, i);
if (JudgeControlFlowCertainOutputHasInferred(return_node, i, &infer_info) != lite::RET_OK) {
MS_LOG(ERROR) << "determine certain output has inferred failed.";
return lite::RET_ERROR;
}
} else {
abstract = GetCNodeInputAbstract(cnode, i + kInputSizeThree);
infer_info = true;
}
MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "get a nullptr abstract.");
abstract_list.emplace_back(abstract->Clone());
infer_infos.push_back(infer_info);
}
return_node->set_inputs(origin_inputs);
if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) {
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
CHECK_NULL_RETURN(abstract_tuple);
MS_CHECK_TRUE_MSG(abstract_tuple != nullptr, RET_ERROR, "created AbstractTuple is a nullptr.");
cnode->set_abstract(abstract_tuple);
} else {
if (abstract_list.size() != 1) {
MS_LOG(ERROR) << "cnode output is invalid.";
}
MS_CHECK_TRUE_MSG(abstract_list.size() == 1, RET_ERROR, "cnode output is invalid.");
cnode->set_abstract(abstract_list.front());
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
CHECK_NULL_RETURN(prim);
infer_done = CheckPrimitiveType(cnode, prim::kPrimWhile) ? false : infer_done;
prim->AddAttr(opt::kInferDone, MakeValue<bool>(infer_done));
MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "cnode's input0 is not a primitive.");
prim->AddAttr(kInferFlags, MakeValue(infer_infos));
return RET_OK;
}

View File

@@ -197,18 +197,10 @@ int GetCNodeVarInput(const CNodePtr &cnode, std::vector<TensorPtr> *var_ms_input
return lite::RET_ERROR;
}
tensor->set_format((Format)(data_info.format_));
MS_ASSERT(cnode->input(i) != nullptr);
auto input_cnode = cnode->input(i)->cast<CNodePtr>();
MS_ASSERT(input_cnode != nullptr);
auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
MS_CHECK_TRUE_RET(input_cnode->input(1) != nullptr, lite::RET_NULL_PTR);
auto item_input_cnode = input_cnode->input(1)->cast<CNodePtr>();
MS_CHECK_TRUE_RET(item_input_cnode != nullptr, lite::RET_NULL_PTR);
input_prim = GetValueNode<PrimitivePtr>(item_input_cnode->input(0));
}
MS_CHECK_TRUE_RET(input_prim != nullptr, lite::RET_NULL_PTR);
if (input_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(kInferDone))) {
bool has_inferred{false};
auto ret = DetermineCertainVarInputHasInferred(cnode, i, &has_inferred);
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed.");
if (!has_inferred) {
tensor->set_shape({-1});
}
var_ms_inputs->emplace_back(tensor);

View File

@@ -142,9 +142,7 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
}
if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) {
auto set_status = SetCNodeAbstract(cnode, outputs, ret);
auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_CHECK_TRUE_MSG(cnode_prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
cnode_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(inputs[0]->format()));
anf_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(inputs[0]->format()));
if (set_status != lite::RET_OK) {
MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope();
return set_status;
@@ -152,6 +150,12 @@
} else {
MS_LOG(WARNING) << "infer shape failed.";
}
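// A Custom op may produce outputs with differing formats, so record all of them via the
// kOutputsFormat attr.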
if (CheckPrimitiveType(cnode, prim::kPrimCustom)) {
std::vector<int64_t> outputs_format;
std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_format),
[](const lite::Tensor *output) { return output->format(); });
anf_prim->AddAttr(kOutputsFormat, MakeValue(outputs_format));
}
return ret;
}