forked from mindspore-Ecosystem/mindspore
optimize the infershape of control_flow model and the acquisition of format
This commit is contained in:
parent
7a1c25c08d
commit
ecd207422c
|
@ -360,38 +360,6 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
|
|||
return ret;
|
||||
}
|
||||
|
||||
int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
|
||||
DataInfo *data_info) {
|
||||
data_info->format_ = mindspore::NHWC;
|
||||
MS_CHECK_TRUE_MSG(cnode->input(index) != nullptr, RET_ERROR, "input is nullptr");
|
||||
auto input_node_prim = GetValueNode<PrimitivePtr>((cnode->input(index)->cast<CNodePtr>()->input(0)));
|
||||
MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "GetValueNode failed");
|
||||
if (input_node_prim->GetAttr(mindspore::ops::kFormat) != nullptr) {
|
||||
auto value = input_node_prim->GetAttr(mindspore::ops::kFormat);
|
||||
if (value->isa<mindspore::Int64Imm>()) {
|
||||
data_info->format_ = GetValue<int64_t>(value);
|
||||
}
|
||||
}
|
||||
if (opt::CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) {
|
||||
std::vector<int> perm;
|
||||
if (opt::GetTransposePerm(cnode->input(index)->cast<CNodePtr>(), &perm) != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (perm.size() < kNumTransposePermSize) {
|
||||
return RET_OK;
|
||||
}
|
||||
// NHWC to NCHW: perm is {0, 3, 1, 2}
|
||||
// NCHW to NHWC: perm is {0, 2, 3, 1}
|
||||
if (perm[0] == 0 && perm[1] == 3 && perm[2] == 1 && perm[3] == 2 &&
|
||||
(data_info->format_ == NHWC || data_info->format_ == KHWC)) {
|
||||
data_info->format_ = NCHW;
|
||||
} else if (perm[0] == 0 && perm[1] == 2 && perm[2] == 3 && perm[3] == 1 && data_info->format_ == NCHW) {
|
||||
data_info->format_ = NHWC;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
|
||||
DataInfo *data_info) {
|
||||
MS_ASSERT(cnode != nullptr && data_info != nullptr);
|
||||
|
@ -414,11 +382,13 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType f
|
|||
}
|
||||
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
|
||||
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
|
||||
auto ret = SetFormatForCnode(cnode, index, fmk_type, train_flag, data_info);
|
||||
Format format{mindspore::NHWC};
|
||||
auto ret = opt::DetermineCertainVarInputFormat(cnode, index, &format);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "set format for cnode failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
data_info->format_ = format;
|
||||
data_info->data_type_ = type_ptr->type_id();
|
||||
data_info->shape_ = dims;
|
||||
data_info->node_type_ = NodeType_CNode;
|
||||
|
|
|
@ -223,5 +223,52 @@ bool IsSpecialType(const CNodePtr &cnode) {
|
|||
CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, kPrimMakeTupleV2) ||
|
||||
CheckPrimitiveType(cnode, prim::kPrimReturn);
|
||||
}
|
||||
|
||||
int DetermineCertainVarInputFormat(const CNodePtr &cnode, size_t index, Format *format) {
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr && format != nullptr, RET_ERROR, "function's parameter is nullptr.");
|
||||
auto var_input_info = GetRealCertainVarInput(cnode, index);
|
||||
if (var_input_info.first == nullptr) {
|
||||
MS_LOG(ERROR) << "cannot get the real var input.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
*format = mindspore::NHWC;
|
||||
auto real_input_cnode = var_input_info.first;
|
||||
auto item_index = var_input_info.second;
|
||||
auto input_node_prim = GetValueNode<PrimitivePtr>((real_input_cnode->input(0)));
|
||||
MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "get primitive failed");
|
||||
auto value_ptr = input_node_prim->GetAttr(ops::kFormat);
|
||||
if (value_ptr != nullptr) {
|
||||
MS_CHECK_TRUE_MSG(value_ptr->isa<mindspore::Int64Imm>(), RET_ERROR, "format attr must be an int64_t val.");
|
||||
auto value = GetValue<int64_t>(value_ptr);
|
||||
MS_CHECK_TRUE_MSG(value >= NCHW && value <= NCW, RET_ERROR, "format val is out of enum's range.");
|
||||
*format = static_cast<Format>(value);
|
||||
}
|
||||
value_ptr = input_node_prim->GetAttr(kOutputsFormat);
|
||||
if (value_ptr != nullptr) {
|
||||
MS_CHECK_TRUE_MSG(value_ptr->isa<ValueSequeue>(), RET_ERROR, "outputs_format attr should be sequence.");
|
||||
auto formats = CastToInt(value_ptr);
|
||||
if (item_index >= 0 && static_cast<size_t>(item_index) < formats.size()) {
|
||||
MS_CHECK_TRUE_MSG(formats[item_index] >= NCHW && formats[item_index] <= NCW, RET_ERROR,
|
||||
"format val is out of enum's range.");
|
||||
*format = static_cast<Format>(formats[item_index]);
|
||||
}
|
||||
}
|
||||
if (CheckPrimitiveType(real_input_cnode, prim::kPrimTranspose)) {
|
||||
std::vector<int> perm;
|
||||
if (GetTransposePerm(real_input_cnode, &perm) != RET_OK) {
|
||||
MS_LOG(ERROR) << "fetch transpose's perm failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (perm.size() != kNC2NH.size()) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (perm == kNH2NC && (*format == NHWC || *format == KHWC)) {
|
||||
*format = NCHW;
|
||||
} else if (perm == opt::kNC2NH && *format == NCHW) {
|
||||
*format = NHWC;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
constexpr auto kInferDone = "infer_done";
|
||||
constexpr auto kOutputsFormat = "outputs_format";
|
||||
enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE };
|
||||
struct TransTypePair {
|
||||
FormatTransNodeType pre_;
|
||||
|
@ -42,6 +42,7 @@ STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm);
|
|||
void RemoveIfMonad(const CNodePtr &cnode);
|
||||
bool IsMonadNode(const AnfNodePtr &node);
|
||||
bool IsSpecialType(const CNodePtr &cnode);
|
||||
int DetermineCertainVarInputFormat(const CNodePtr &cnode, size_t index, Format *format);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ namespace opt {
|
|||
namespace {
|
||||
constexpr auto kAnfPrimitiveIndex = 0;
|
||||
constexpr auto kDeviceTypeNone = -1;
|
||||
int DeduceDimConvertion(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
|
||||
int DeduceDimConvertion(schema::Format src_format, schema::Format dst_format, std::vector<int> *const perm) {
|
||||
MS_ASSERT(perm != nullptr);
|
||||
auto src_format_str = std::string(schema::EnumNameFormat(src_format));
|
||||
auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
|
||||
|
@ -66,10 +66,16 @@ int DeduceDimConvertion(schema::Format src_format, schema::Format dst_format, st
|
|||
|
||||
template <typename T>
|
||||
void TransposeData(const ShapeVector &origin_shape, const ShapeVector &cur_shape, const std::vector<int> &perm,
|
||||
T *weight_data, std::vector<T> *buf) {
|
||||
T *const weight_data, std::vector<T> *buf) {
|
||||
MS_ASSERT(weight_data != nullptr && buf != nullptr);
|
||||
MS_ASSERT(origin_shape.size() == cur_shape.size() && cur_shape.size() == perm.size());
|
||||
int count = std::accumulate(origin_shape.begin(), origin_shape.end(), 1, std::multiplies<int>());
|
||||
int count = 1;
|
||||
for (const auto &dat : origin_shape) {
|
||||
if (INT_MUL_OVERFLOW(count, static_cast<int>(dat))) {
|
||||
return;
|
||||
}
|
||||
count *= static_cast<int>(dat);
|
||||
}
|
||||
ShapeVector post_multiply(cur_shape.size());
|
||||
std::unordered_map<int, int> dim_map;
|
||||
for (int i = cur_shape.size() - 1; i >= 0; --i) {
|
||||
|
@ -88,10 +94,17 @@ void TransposeData(const ShapeVector &origin_shape, const ShapeVector &cur_shape
|
|||
position_map[j] = temp % origin_shape[j];
|
||||
temp /= origin_shape[j];
|
||||
}
|
||||
int64_t new_pos = std::accumulate(position_map.begin(), position_map.end(), 0,
|
||||
[&post_multiply, &dim_map](int64_t res, const std::pair<int, int> &pair_y) {
|
||||
return res + post_multiply[dim_map[pair_y.first]] * pair_y.second;
|
||||
});
|
||||
int64_t new_pos = 0;
|
||||
for (const auto &pair_y : position_map) {
|
||||
if (INT_MUL_OVERFLOW(post_multiply[dim_map[pair_y.first]], pair_y.second)) {
|
||||
return;
|
||||
}
|
||||
if (INT_ADD_OVERFLOW(new_pos, post_multiply[dim_map[pair_y.first]] * pair_y.second)) {
|
||||
return;
|
||||
}
|
||||
new_pos += post_multiply[dim_map[pair_y.first]] * pair_y.second;
|
||||
}
|
||||
|
||||
buf->at(new_pos) = weight_data[i];
|
||||
}
|
||||
}
|
||||
|
@ -121,7 +134,14 @@ STATUS DoTransposeData(const tensor::TensorPtr &tensor, schema::Format src_forma
|
|||
}
|
||||
new_shape.push_back(origin_shape[val]);
|
||||
}
|
||||
auto count = std::accumulate(origin_shape.begin(), origin_shape.end(), 1LL, std::multiplies<int64_t>());
|
||||
int64_t count = 1;
|
||||
for (const auto &dat : origin_shape) {
|
||||
if (INT_MUL_OVERFLOW(count, dat)) {
|
||||
MS_LOG(ERROR) << "Int mul overflow";
|
||||
return RET_ERROR;
|
||||
}
|
||||
count *= dat;
|
||||
}
|
||||
if (count <= 0 || count > static_cast<int64_t>(INT32_MAX)) {
|
||||
MS_LOG(ERROR) << "tensor element num is too big, which should be smaller than int32_max.";
|
||||
return RET_ERROR;
|
||||
|
@ -174,6 +194,38 @@ bool IsRealKernel(const AnfNodePtr &node) {
|
|||
#endif
|
||||
return !is_virtual_node;
|
||||
}
|
||||
|
||||
int CopyTensorDataFromTensorInfo(const tensor::TensorPtr &tensor_info,
|
||||
const std::shared_ptr<tensor::Tensor> &tensor_info_dst, size_t data_count) {
|
||||
if (tensor_info->data_type() == kNumberTypeInt64) {
|
||||
auto *tensor_data = reinterpret_cast<int *>(tensor_info_dst->data_c());
|
||||
if (tensor_data == nullptr) {
|
||||
MS_LOG(ERROR) << "new data failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto *origin_data = reinterpret_cast<int64_t *>(tensor_info->data_c());
|
||||
for (size_t i = 0; i < data_count; ++i) {
|
||||
if (origin_data[i] > static_cast<int64_t>(INT32_MAX) || origin_data[i] < static_cast<int64_t>(INT32_MIN)) {
|
||||
MS_LOG(WARNING) << "int64 data " << origin_data[i] << "too big to fit into int32";
|
||||
tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN;
|
||||
} else {
|
||||
tensor_data[i] = static_cast<int>(origin_data[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tensor_info_dst->set_data_type(tensor_info->data_type());
|
||||
auto *tensor_data = reinterpret_cast<int8_t *>(tensor_info_dst->data_c());
|
||||
if (tensor_data == nullptr) {
|
||||
MS_LOG(ERROR) << "new data failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (memcpy_s(tensor_data, tensor_info_dst->Size(), tensor_info->data_c(), tensor_info->Size()) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "memcpy data failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool CheckInputs(const CNodePtr &cnode) {
|
||||
|
@ -351,7 +403,7 @@ bool IsGraphKernel(const AnfNodePtr &node) {
|
|||
return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
}
|
||||
|
||||
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id) {
|
||||
ParameterPtr AddNewBiasNode(const float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id) {
|
||||
if (bias_data == nullptr || func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "input parameter is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -461,7 +513,7 @@ bool IsParamNode(const BaseRef &n) {
|
|||
return tensor->data_c() != nullptr;
|
||||
}
|
||||
|
||||
STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr &cnode, size_t index) {
|
||||
STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *const tensor_info, const CNodePtr &cnode, size_t index) {
|
||||
CHECK_NULL_RETURN(tensor_info);
|
||||
CHECK_NULL_RETURN(cnode);
|
||||
AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, index);
|
||||
|
@ -673,7 +725,18 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr
|
|||
MS_LOG(ERROR) << "new tensor::Tensor failed.";
|
||||
return nullptr;
|
||||
}
|
||||
size_t data_count = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
|
||||
int data_count = 1;
|
||||
for (const auto &dat : shape) {
|
||||
if (INT_MUL_OVERFLOW(data_count, static_cast<int>(dat))) {
|
||||
MS_LOG(ERROR) << "Int mul overflow.";
|
||||
return nullptr;
|
||||
}
|
||||
data_count *= static_cast<int>(dat);
|
||||
}
|
||||
if (data_count < 0) {
|
||||
MS_LOG(ERROR) << "Invalid shape.";
|
||||
return nullptr;
|
||||
}
|
||||
if (tensor_info->Size() == 0) {
|
||||
auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new);
|
||||
if (status != RET_OK) {
|
||||
|
@ -682,33 +745,12 @@ ParameterPtr BuildParameterNode(const FuncGraphPtr &func_graph, const AnfNodePtr
|
|||
}
|
||||
return param_node;
|
||||
}
|
||||
if (tensor_info->data_type() == kNumberTypeInt64) {
|
||||
auto *tensor_data = reinterpret_cast<int *>(tensor_info_new->data_c());
|
||||
if (tensor_data == nullptr) {
|
||||
MS_LOG(ERROR) << "new data failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto *origin_data = reinterpret_cast<int64_t *>(tensor_info->data_c());
|
||||
for (size_t i = 0; i < data_count; ++i) {
|
||||
if (origin_data[i] > static_cast<int64_t>(INT32_MAX) || origin_data[i] < static_cast<int64_t>(INT32_MIN)) {
|
||||
MS_LOG(WARNING) << "int64 data " << origin_data[i] << "too big to fit into int32";
|
||||
tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN;
|
||||
} else {
|
||||
tensor_data[i] = static_cast<int>(origin_data[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tensor_info_new->set_data_type(tensor_info->data_type());
|
||||
auto *tensor_data = reinterpret_cast<int8_t *>(tensor_info_new->data_c());
|
||||
if (tensor_data == nullptr) {
|
||||
MS_LOG(ERROR) << "new data failed";
|
||||
return nullptr;
|
||||
}
|
||||
if (memcpy_s(tensor_data, tensor_info_new->Size(), tensor_info->data_c(), tensor_info->Size()) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "memcpy data failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (CopyTensorDataFromTensorInfo(tensor_info, tensor_info_new, static_cast<size_t>(data_count)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "copy tensor data failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info_new);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "init parameter from tensor info failed";
|
||||
|
@ -1012,6 +1054,7 @@ bool IsQuantParameterNode(const PrimitiveCPtr &prim) {
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void UpdateManager(const FuncGraphPtr &func_graph) {
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
|
@ -1026,5 +1069,62 @@ void UpdateManager(const FuncGraphPtr &func_graph) {
|
|||
manager->AddFuncGraph(one_func_graph);
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<CNodePtr, int> GetRealCertainVarInput(const CNodePtr &cnode, size_t index) {
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, {}, "function's parameter is nullptr.");
|
||||
MS_CHECK_TRUE_MSG(cnode->input(index) != nullptr, {}, "required input is nullptr");
|
||||
auto real_input_cnode = cnode->input(index)->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(real_input_cnode != nullptr, {}, "input node is not a cnode.");
|
||||
int item_index = 0;
|
||||
if (opt::CheckPrimitiveType(real_input_cnode, prim::kPrimTupleGetItem)) {
|
||||
auto index_node = real_input_cnode->input(opt::kInputIndexTwo);
|
||||
MS_CHECK_TRUE_MSG(index_node != nullptr, {}, "tuple_get_item's second input is nullptr.");
|
||||
MS_CHECK_TRUE_MSG(index_node->isa<ValueNode>(), {}, "tuple_get_item's second input should be valuenode.");
|
||||
auto index_ptr = index_node->cast<ValueNodePtr>()->value();
|
||||
MS_CHECK_TRUE_MSG(index_ptr != nullptr, {}, "tuple_get_item's second input val is nullptr.");
|
||||
auto value = CastToInt(index_ptr);
|
||||
MS_CHECK_TRUE_MSG(value.size() == 1, {}, "tuple_get_item's second input is invalid.");
|
||||
item_index = value.front();
|
||||
MS_CHECK_TRUE_MSG(real_input_cnode->input(1) != nullptr, {}, "tuple_get_item's first input is nullptr");
|
||||
real_input_cnode = real_input_cnode->input(1)->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(real_input_cnode != nullptr, {}, "tuple_get_item first input is not cnode.");
|
||||
}
|
||||
return {real_input_cnode, item_index};
|
||||
}
|
||||
|
||||
int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ) {
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr && infer_succ != nullptr, RET_ERROR, "function's parameter is nullptr.");
|
||||
auto var_input_info = GetRealCertainVarInput(cnode, index);
|
||||
if (var_input_info.first == nullptr) {
|
||||
MS_LOG(ERROR) << "cannot get the real var input.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto real_input_cnode = var_input_info.first;
|
||||
auto item_index = var_input_info.second;
|
||||
auto input_node_prim = GetValueNode<PrimitivePtr>((real_input_cnode->input(0)));
|
||||
MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "get primitive failed.");
|
||||
*infer_succ = false;
|
||||
auto value_ptr = input_node_prim->GetAttr(kInferDone);
|
||||
if (value_ptr != nullptr) {
|
||||
MS_CHECK_TRUE_MSG(value_ptr->isa<BoolImm>(), RET_ERROR, "value is not a boolean.");
|
||||
*infer_succ = GetValue<bool>(value_ptr);
|
||||
}
|
||||
value_ptr = input_node_prim->GetAttr(kInferFlags);
|
||||
if (value_ptr == nullptr) {
|
||||
return RET_OK;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(value_ptr->isa<ValueSequeue>(), RET_ERROR, "infer flag should be a vector.");
|
||||
auto value_sequence = value_ptr->cast<ValueSequeuePtr>();
|
||||
auto elements = value_sequence->value();
|
||||
MS_CHECK_TRUE_MSG(!elements.empty(), RET_ERROR, "infer_info has no content.");
|
||||
auto first_element = elements.front();
|
||||
MS_CHECK_TRUE_MSG(first_element != nullptr, RET_ERROR, "element is a nullptr.");
|
||||
MS_CHECK_TRUE_MSG(first_element->isa<BoolImm>(), RET_ERROR, "each element is not a boolean.");
|
||||
auto infer_infos = GetValue<std::vector<bool>>(value_ptr);
|
||||
MS_CHECK_TRUE_MSG(item_index >= 0 && static_cast<size_t>(item_index) < infer_infos.size(), RET_ERROR,
|
||||
"item index is out of range.");
|
||||
*infer_succ = infer_infos[item_index];
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ir/anf.h"
|
||||
|
@ -35,6 +36,10 @@ using mindspore::lite::RET_OK;
|
|||
using mindspore::lite::STATUS;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// used for common op, which corresponding value is a boolean.
|
||||
constexpr auto kInferDone = "infer_done";
|
||||
// used for control_flow op(while and if), which corresponding value is a boolean vec.
|
||||
constexpr auto kInferFlags = "infer_flags";
|
||||
inline constexpr int kInputIndexTwo = 2;
|
||||
inline constexpr int kInputIndexThree = 3;
|
||||
inline constexpr int kInputIndexFour = 4;
|
||||
|
@ -64,7 +69,7 @@ bool IsGraphKernel(const AnfNodePtr &node);
|
|||
|
||||
bool CheckInputs(const CNodePtr &cnode);
|
||||
|
||||
ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id);
|
||||
ParameterPtr AddNewBiasNode(const float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, TypeId type_id);
|
||||
|
||||
bool IsParamNode(const BaseRef &n);
|
||||
|
||||
|
@ -130,6 +135,10 @@ bool IsQuantParameterNode(const PrimitiveCPtr &prim);
|
|||
|
||||
void UpdateManager(const FuncGraphPtr &func_graph);
|
||||
|
||||
std::pair<CNodePtr, int> GetRealCertainVarInput(const CNodePtr &cnode, size_t index);
|
||||
|
||||
int DetermineCertainVarInputHasInferred(const CNodePtr &cnode, size_t index, bool *infer_succ);
|
||||
|
||||
template <const PrimitivePtr *prim = nullptr>
|
||||
inline bool IsSpecifiedNode(const BaseRef &n) {
|
||||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
|
|
|
@ -148,7 +148,27 @@ STATUS DeleteRedundantTranspose::UpdateNodeFormat(const CNodePtr &cnode) {
|
|||
if (prim->GetAttr(ops::kFormat) == nullptr) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
auto format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
|
||||
auto forward_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
|
||||
const int max_search_depth{3};
|
||||
int loop{0};
|
||||
auto search_node = cnode->input(1);
|
||||
while (loop < max_search_depth) {
|
||||
MS_CHECK_TRUE_RET(search_node != nullptr, lite::RET_ERROR);
|
||||
auto search_cnode = search_node->cast<CNodePtr>();
|
||||
if (search_cnode == nullptr) {
|
||||
break;
|
||||
}
|
||||
auto primitive = GetCNodePrimitive(search_cnode);
|
||||
if (primitive == nullptr) {
|
||||
break;
|
||||
}
|
||||
if (primitive->GetAttr(ops::kFormat) != nullptr) {
|
||||
forward_format = GetValue<int64_t>(primitive->GetAttr(ops::kFormat));
|
||||
break;
|
||||
}
|
||||
search_node = search_cnode->input(1);
|
||||
++loop;
|
||||
}
|
||||
auto node_users = manager_->node_users()[cnode];
|
||||
for (auto &node_user : node_users) {
|
||||
if (node_user.second != 1) {
|
||||
|
@ -161,7 +181,7 @@ STATUS DeleteRedundantTranspose::UpdateNodeFormat(const CNodePtr &cnode) {
|
|||
auto post_cnode = node_user.first->cast<CNodePtr>();
|
||||
auto post_prim = GetValueNode<PrimitivePtr>(post_cnode->input(0));
|
||||
MS_ASSERT(post_prim != nullptr);
|
||||
post_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(format));
|
||||
post_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(forward_format));
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
|
|
@ -506,36 +506,34 @@ int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGra
|
|||
MS_LOG(ERROR) << "Get index failed: " << e.what();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
param_node->set_abstract(GetCNodeInputAbstract(cnode, index)->Clone());
|
||||
auto abstract = GetCNodeInputAbstract(cnode, index);
|
||||
MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is a nullptr.");
|
||||
param_node->set_abstract(abstract->Clone());
|
||||
if (utils::isa<CNodePtr>(cnode->input(index))) {
|
||||
ShapeVector shape_vec = {-1};
|
||||
auto out_cnode = cnode->input(index)->cast<CNodePtr>();
|
||||
MS_ASSERT(out_cnode != nullptr);
|
||||
MS_ASSERT(trans_cnode != nullptr);
|
||||
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
|
||||
MS_CHECK_TRUE_MSG(out_prim != nullptr, lite::RET_ERROR, "GetValueNode failed");
|
||||
if (out_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(kInferDone))) {
|
||||
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec));
|
||||
bool has_inferred{false};
|
||||
auto ret = DetermineCertainVarInputHasInferred(cnode, index, &has_inferred);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed.");
|
||||
if (!has_inferred) {
|
||||
auto abstract_shape = std::make_shared<abstract::Shape>(shape_vec);
|
||||
MS_CHECK_TRUE_MSG(abstract_shape != nullptr, RET_ERROR, "create shape failed.");
|
||||
param_node->abstract()->set_shape(abstract_shape);
|
||||
}
|
||||
} else {
|
||||
}
|
||||
if (utils::isa<Parameter>(cnode->input(index))) {
|
||||
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
|
||||
}
|
||||
if (utils::isa<ValueNode>(cnode->input(index))) {
|
||||
lite::DataInfo data_info;
|
||||
if (utils::isa<ParameterPtr>(cnode->input(index))) {
|
||||
if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
|
||||
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info, true);
|
||||
if (status != lite::RET_OK) {
|
||||
continue;
|
||||
}
|
||||
ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
|
||||
if (data_info.data_.empty()) {
|
||||
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec));
|
||||
} else {
|
||||
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
|
||||
data_info.data_.data(), data_info.data_.size()));
|
||||
}
|
||||
auto tensor_info = lite::CreateTensorInfo(data_info.data_.data(), data_info.data_.size(), shape_vec,
|
||||
static_cast<TypeId>(data_info.data_type_));
|
||||
MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_ERROR, "create a tensor failed.");
|
||||
param_node->set_default_param(tensor_info);
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
|
|
|
@ -38,37 +38,11 @@ int GetCNodeCertainInputFormat(const CNodePtr cnode, int index, mindspore::Forma
|
|||
cnode->set_inputs(origin_inputs);
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto real_cnode = cnode->input(index)->cast<CNodePtr>();
|
||||
MS_ASSERT(real_cnode != nullptr);
|
||||
if (CheckPrimitiveType(real_cnode, prim::kPrimTupleGetItem)) {
|
||||
real_cnode = real_cnode->input(1)->cast<CNodePtr>();
|
||||
if (DetermineCertainVarInputFormat(cnode, index, format) != RET_OK) {
|
||||
MS_LOG(ERROR) << "determine certain var-input's format failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
cnode->set_inputs(origin_inputs);
|
||||
MS_ASSERT(real_cnode != nullptr);
|
||||
auto primitive = GetValueNode<PrimitivePtr>(real_cnode->input(0));
|
||||
MS_CHECK_TRUE_MSG(primitive != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
|
||||
if (primitive->GetAttr(ops::kFormat) == nullptr) {
|
||||
MS_LOG(DEBUG) << "cnode has no format attr. " << real_cnode->fullname_with_scope();
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto format_attr = primitive->GetAttr(ops::kFormat);
|
||||
MS_CHECK_TRUE_MSG(format_attr != nullptr, lite::RET_NULL_PTR, "GetAttr Failed");
|
||||
*format = static_cast<mindspore::Format>(GetValue<int64_t>(format_attr));
|
||||
if (CheckPrimitiveType(real_cnode, prim::kPrimTranspose)) {
|
||||
std::vector<int> perm;
|
||||
if (GetTransposePerm(real_cnode, &perm) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "get transpose perm failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (perm.size() != DIMENSION_4D) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (perm == kNH2NC && *format == mindspore::NHWC) {
|
||||
*format = mindspore::NCHW;
|
||||
} else if (perm == kNC2NH && *format == mindspore::NCHW) {
|
||||
*format = mindspore::NHWC;
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
|
@ -93,6 +67,94 @@ int ModifySubGraphInputCNodeFormat(const FuncGraphPtr &sub_graph, const Paramete
|
|||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
int JudgeControlFlowCertainOutputHasInferred(const CNodePtr &return_cnode, size_t index, bool *infer_info) {
|
||||
MS_ASSERT(return_cnode != nullptr && infer_info != nullptr);
|
||||
MS_CHECK_TRUE_MSG(index < return_cnode->size(), RET_ERROR, "input index is out of range.");
|
||||
*infer_info = true;
|
||||
auto abstract_base = GetCNodeInputAbstract(return_cnode, index);
|
||||
MS_CHECK_TRUE_MSG(abstract_base != nullptr, RET_ERROR, "anfnode has no abstract.");
|
||||
ShapeVector shape;
|
||||
auto ret = FetchShapeFromAbstract(abstract_base, &shape);
|
||||
MS_CHECK_TRUE_MSG(ret == lite::RET_OK, RET_ERROR, "fetch shape from abstract failed.");
|
||||
if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
|
||||
*infer_info = false;
|
||||
return RET_OK;
|
||||
}
|
||||
if (utils::isa<CNodePtr>(return_cnode->input(index))) {
|
||||
ret = DetermineCertainVarInputHasInferred(return_cnode, index, infer_info);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed.");
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ModifyWhileBodyGraphInputs(const CNodePtr &cnode, const FuncGraphPtr &sub_graph, const ParameterPtr &graph_input,
|
||||
size_t input_index) {
|
||||
MS_ASSERT(cnode != nullptr && sub_graph != nullptr && graph_input != nullptr);
|
||||
if (!CheckPrimitiveType(cnode, prim::kPrimWhile)) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto body_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
|
||||
MS_ASSERT(body_graph != nullptr);
|
||||
if (body_graph.get() != sub_graph.get()) {
|
||||
MS_LOG(DEBUG) << "sub_graph is not body graph.";
|
||||
return RET_OK;
|
||||
}
|
||||
auto return_cnode = sub_graph->get_return();
|
||||
MS_CHECK_TRUE_MSG(return_cnode != nullptr, RET_ERROR, "return node is a nullptr.");
|
||||
auto origin_outputs = return_cnode->inputs();
|
||||
auto ret = lite::RemoveIfDepend(return_cnode);
|
||||
if (ret != RET_OK) {
|
||||
return_cnode->set_inputs(origin_outputs);
|
||||
MS_LOG(ERROR) << "remove depend node failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
ret = lite::RemoveIfMakeTuple(return_cnode);
|
||||
if (ret != RET_OK) {
|
||||
return_cnode->set_inputs(origin_outputs);
|
||||
MS_LOG(ERROR) << "remove make_tuple node failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
RemoveIfMonad(return_cnode);
|
||||
if (return_cnode->size() == 0 || input_index >= return_cnode->size() - 1) {
|
||||
return_cnode->set_inputs(origin_outputs);
|
||||
MS_LOG(ERROR) << "input index is out of range.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto output = return_cnode->input(input_index + 1);
|
||||
return_cnode->set_inputs(origin_outputs);
|
||||
MS_CHECK_TRUE_MSG(output != nullptr, RET_ERROR, "output node is a nullptr.");
|
||||
if (output->isa<CNode>()) {
|
||||
graph_input->set_default_param(nullptr);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int MergeTwoBranchOfIfOp(const CNodePtr &cnode, const CNodePtr &return_cnode, size_t index, bool *true_branch) {
|
||||
MS_ASSERT(cnode != nullptr && return_cnode != nullptr && true_branch != nullptr);
|
||||
*true_branch = true;
|
||||
if (!CheckPrimitiveType(cnode, prim::kPrimIf)) {
|
||||
return RET_OK;
|
||||
}
|
||||
bool infer_info{false};
|
||||
// judge true branch.
|
||||
if (JudgeControlFlowCertainOutputHasInferred(return_cnode, index, &infer_info) != RET_OK) {
|
||||
MS_LOG(ERROR) << "determine certain output has inferred failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (infer_info) {
|
||||
return RET_OK;
|
||||
}
|
||||
// judge false branch.
|
||||
if (JudgeControlFlowCertainOutputHasInferred(cnode, index + kInputSizeThree, &infer_info) != RET_OK) {
|
||||
MS_LOG(ERROR) << "determine certain output has inferred failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (infer_info) {
|
||||
*true_branch = false;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
|
||||
|
@ -190,10 +252,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) {
|
|||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
if (sub_func_graph == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return lite::RET_ERROR;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = SetSubGraphInput(cnode, sub_func_graph);
|
||||
if (ret != lite::RET_OK) {
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetSubGraphInput failed: " << ret;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -255,21 +317,22 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
|
|||
auto last_underline = node_name.find_last_of("_");
|
||||
node_name = node_name.substr(0, last_underline);
|
||||
last_underline = node_name.find_last_of("_");
|
||||
auto index = 0;
|
||||
size_t index = 0;
|
||||
try {
|
||||
index = std::stoi(node_name.substr(last_underline + 1)) + 3;
|
||||
index = static_cast<size_t>(std::stoi(node_name.substr(last_underline + 1))) + kInputSizeThree;
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "Get index failed: " << e.what();
|
||||
return RET_ERROR;
|
||||
}
|
||||
param_node->set_abstract(opt::GetCNodeInputAbstract(cnode, index)->Clone());
|
||||
if (utils::isa<CNodePtr>(cnode->input(index))) {
|
||||
auto abstract = GetCNodeInputAbstract(cnode, index);
|
||||
MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is a nullptr.");
|
||||
param_node->set_abstract(abstract->Clone());
|
||||
if (utils::isa<CNode>(cnode->input(index))) {
|
||||
ShapeVector shape_vec = {-1};
|
||||
auto out_cnode = cnode->input(index)->cast<CNodePtr>();
|
||||
MS_ASSERT(trans_cnode != nullptr);
|
||||
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0));
|
||||
MS_CHECK_TRUE_MSG(out_prim != nullptr, lite::RET_ERROR, "GetValueNode Failed");
|
||||
if (out_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(opt::kInferDone))) {
|
||||
bool has_inferred{false};
|
||||
auto ret = DetermineCertainVarInputHasInferred(cnode, index, &has_inferred);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed.");
|
||||
if (!has_inferred) {
|
||||
auto abstract_shape = std::make_shared<abstract::Shape>(shape_vec);
|
||||
CHECK_NULL_RETURN(abstract_shape);
|
||||
param_node->abstract()->set_shape(abstract_shape);
|
||||
|
@ -282,6 +345,7 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
|
|||
if (ModifySubGraphInputCNodeFormat(sub_graph, param_node, format) != lite::RET_OK) {
|
||||
MS_LOG(DEBUG) << "modify subgraph input cnode format failed." << cnode->func_graph_as_var();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (utils::isa<Parameter>(cnode->input(index))) {
|
||||
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
|
||||
|
@ -298,6 +362,11 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
|
|||
MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_ERROR, "created tensor is a nullptr.");
|
||||
param_node->set_default_param(tensor_info);
|
||||
}
|
||||
// while's body graph:if the corresponding output is a variable, the corresponding input's data will be set to NULL.
|
||||
if (ModifyWhileBodyGraphInputs(cnode, sub_graph, param_node, index - kInputSizeThree) != RET_OK) {
|
||||
MS_LOG(ERROR) << "modify while body graph's certain input failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -342,47 +411,39 @@ STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGrap
|
|||
lite::RemoveIfDepend(return_node);
|
||||
lite::RemoveIfMakeTuple(return_node);
|
||||
AbstractBasePtrList abstract_list;
|
||||
bool infer_done = true;
|
||||
std::vector<bool> infer_infos;
|
||||
for (size_t i = 1; i < return_node->size(); ++i) {
|
||||
auto abstract_base = opt::GetCNodeInputAbstract(return_node, i);
|
||||
MS_CHECK_TRUE_RET(abstract_base != nullptr, lite::RET_ERROR);
|
||||
abstract_list.emplace_back(abstract_base->Clone());
|
||||
auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
|
||||
MS_ASSERT(abstract_tensor != nullptr);
|
||||
auto shape_ptr = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape());
|
||||
MS_ASSERT(shape_ptr != nullptr);
|
||||
auto shape = shape_ptr->shape();
|
||||
if (std::find(shape.begin(), shape.end(), -1) != shape.end()) {
|
||||
infer_done = false;
|
||||
}
|
||||
if (utils::isa<CNodePtr>(return_node->input(i))) {
|
||||
auto input_cnode = return_node->input(i)->cast<CNodePtr>();
|
||||
MS_ASSERT(input_cnode != nullptr);
|
||||
if (opt::CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
|
||||
input_cnode = input_cnode->input(1)->cast<CNodePtr>();
|
||||
}
|
||||
auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
|
||||
CHECK_NULL_RETURN(input_prim);
|
||||
if (input_prim->GetAttr(opt::kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(opt::kInferDone))) {
|
||||
infer_done = false;
|
||||
bool true_branch{false};
|
||||
auto ret = MergeTwoBranchOfIfOp(cnode, return_node, i, &true_branch);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "decide to fetch which branch failed.");
|
||||
AbstractBasePtr abstract;
|
||||
bool infer_info;
|
||||
if (true_branch) {
|
||||
abstract = GetCNodeInputAbstract(return_node, i);
|
||||
if (JudgeControlFlowCertainOutputHasInferred(return_node, i, &infer_info) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "determine certain output has inferred failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
abstract = GetCNodeInputAbstract(cnode, i + kInputSizeThree);
|
||||
infer_info = true;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "get a nullptr abstract.");
|
||||
abstract_list.emplace_back(abstract->Clone());
|
||||
infer_infos.push_back(infer_info);
|
||||
}
|
||||
return_node->set_inputs(origin_inputs);
|
||||
if (utils::isa<abstract::AbstractTuplePtr>(cnode->abstract())) {
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
CHECK_NULL_RETURN(abstract_tuple);
|
||||
MS_CHECK_TRUE_MSG(abstract_tuple != nullptr, RET_ERROR, "created AbstractTuple is a nullptr.");
|
||||
cnode->set_abstract(abstract_tuple);
|
||||
} else {
|
||||
if (abstract_list.size() != 1) {
|
||||
MS_LOG(ERROR) << "cnode output is invalid.";
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(abstract_list.size() == 1, RET_ERROR, "cnode output is invalid.");
|
||||
cnode->set_abstract(abstract_list.front());
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
CHECK_NULL_RETURN(prim);
|
||||
infer_done = CheckPrimitiveType(cnode, prim::kPrimWhile) ? false : infer_done;
|
||||
prim->AddAttr(opt::kInferDone, MakeValue<bool>(infer_done));
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "cnode's input0 is not a primitive.");
|
||||
prim->AddAttr(kInferFlags, MakeValue(infer_infos));
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -197,18 +197,10 @@ int GetCNodeVarInput(const CNodePtr &cnode, std::vector<TensorPtr> *var_ms_input
|
|||
return lite::RET_ERROR;
|
||||
}
|
||||
tensor->set_format((Format)(data_info.format_));
|
||||
MS_ASSERT(cnode->input(i) != nullptr);
|
||||
auto input_cnode = cnode->input(i)->cast<CNodePtr>();
|
||||
MS_ASSERT(input_cnode != nullptr);
|
||||
auto input_prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
|
||||
if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
|
||||
MS_CHECK_TRUE_RET(input_cnode->input(1) != nullptr, lite::RET_NULL_PTR);
|
||||
auto item_input_cnode = input_cnode->input(1)->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_RET(item_input_cnode != nullptr, lite::RET_NULL_PTR);
|
||||
input_prim = GetValueNode<PrimitivePtr>(item_input_cnode->input(0));
|
||||
}
|
||||
MS_CHECK_TRUE_RET(input_prim != nullptr, lite::RET_NULL_PTR);
|
||||
if (input_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(input_prim->GetAttr(kInferDone))) {
|
||||
bool has_inferred{false};
|
||||
auto ret = DetermineCertainVarInputHasInferred(cnode, i, &has_inferred);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "determine infer flag failed.");
|
||||
if (!has_inferred) {
|
||||
tensor->set_shape({-1});
|
||||
}
|
||||
var_ms_inputs->emplace_back(tensor);
|
||||
|
|
|
@ -142,9 +142,7 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
|
|||
}
|
||||
if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) {
|
||||
auto set_status = SetCNodeAbstract(cnode, outputs, ret);
|
||||
auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_CHECK_TRUE_MSG(cnode_prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
|
||||
cnode_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(inputs[0]->format()));
|
||||
anf_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(inputs[0]->format()));
|
||||
if (set_status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope();
|
||||
return set_status;
|
||||
|
@ -152,6 +150,12 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
|
|||
} else {
|
||||
MS_LOG(WARNING) << "infer shape failed.";
|
||||
}
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimCustom)) {
|
||||
std::vector<int64_t> outputs_format;
|
||||
std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_format),
|
||||
[](const lite::Tensor *output) { return output->format(); });
|
||||
anf_prim->AddAttr(kOutputsFormat, MakeValue(outputs_format));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue