310 codex and online infer bugfix

This commit is contained in:
chenping 2022-02-18 18:01:37 +08:00
parent 25c3d507cd
commit 8e20d1e4c3
18 changed files with 123 additions and 152 deletions

View File

@ -576,9 +576,11 @@ int LiteSession::IsolateOutputTensor() {
new_tensor->set_init_ref_count(src_tensor->init_ref_count());
/* src tensor set for graph calculate */
#ifdef ENABLE_FP16
if (src_tensor->data_type() == kNumberTypeFloat16) {
src_tensor->set_data_type(kNumberTypeFloat32);
}
#endif
src_tensor->set_ref_count(1);
isolate_graph_output_map_.insert(std::make_pair(new_tensor, src_tensor));

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -32,6 +32,7 @@ STATUS AclOptionsParser::ParseAclOptions(const mindspore::Context *ctx, AclModel
CHECK_NULL_RETURN(acl_options);
auto context = const_cast<mindspore::Context *>(ctx);
CHECK_NULL_RETURN(context);
auto device_infos = context->MutableDeviceInfo();
if (device_infos.size() < 1) {
MS_LOG(WARNING) << "Context is not set device info, please check.";

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -24,6 +24,7 @@
#include "utils/utils.h"
#include "src/common/log_util.h"
#include "ir/func_graph.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
@ -35,24 +36,23 @@ constexpr size_t kInvalidSize = SIZE_MAX;
} // namespace
static size_t GetTupleGetItemOutIndex(const mindspore::CNodePtr &tuple_get_item) {
MS_ASSERT(tuple_get_item != nullptr);
if (tuple_get_item->size() != mindspore::kTupleGetItemInputSize) {
MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
return kInvalidSize;
}
MS_CHECK_TRUE_MSG(tuple_get_item != nullptr, kInvalidSize, "tuple_get_item is nullptr.");
MS_CHECK_TRUE_MSG(tuple_get_item->size() == mindspore::kTupleGetItemInputSize, kInvalidSize,
"The node tuple_get_item must have 3 inputs!");
auto output_index_value_node = tuple_get_item->input(mindspore::kInputNodeOutputIndexInTupleGetItem);
MS_ASSERT(output_index_value_node != nullptr);
MS_CHECK_TRUE_MSG(output_index_value_node != nullptr, kInvalidSize, "output_index_value_node is nullptr.");
auto value_node = output_index_value_node->cast<mindspore::ValueNodePtr>();
MS_ASSERT(value_node != nullptr);
return IntToSize(opt::CastToInt(value_node->value()).front());
MS_CHECK_TRUE_MSG(value_node != nullptr, kInvalidSize, "value_node is nullptr.");
auto values = opt::CastToInt(value_node->value());
MS_CHECK_TRUE_MSG(values.size() > 0, kInvalidSize, "value_node has no value.");
return IntToSize(values.front());
}
static bool CheckPrimitiveType(const mindspore::AnfNodePtr &node, const mindspore::PrimitivePtr &primitive_type) {
if (node == nullptr) {
return false;
}
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
if (node->isa<mindspore::CNode>()) {
auto cnode = node->cast<mindspore::CNodePtr>();
MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cnode is nullptr.");
return IsPrimitive(cnode->input(0), primitive_type);
} else if (node->isa<mindspore::ValueNode>()) {
return IsPrimitive(node, primitive_type);
@ -64,9 +64,9 @@ STATUS GetShapeVectorFromCNode(const mindspore::CNodePtr &cnode, std::vector<int
mindspore::AbstractBasePtr cnode_abstract;
if (CheckPrimitiveType(cnode, mindspore::prim::kPrimTupleGetItem)) {
auto tuple_inputs = cnode->inputs();
MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
MS_CHECK_TRUE_MSG(tuple_inputs.size() == kTupleGetItemInputSize, lite::RET_ERROR, "The node must has 3 inputs.");
auto get_item_input_cnode = tuple_inputs.at(kSecondIndex);
MS_ASSERT(get_item_input_cnode != nullptr);
MS_CHECK_TRUE_MSG(get_item_input_cnode != nullptr, lite::RET_ERROR, "input cnode is nullptr.");
auto idx = GetTupleGetItemOutIndex(cnode);
if (!mindspore::utils::isa<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
@ -106,8 +106,10 @@ STATUS GetShapeVectorFromCNode(const mindspore::CNodePtr &cnode, std::vector<int
TypeId GetTypeFromNode(const AnfNodePtr &node) {
TypeId type = kNumberTypeFloat32;
MS_CHECK_TRUE_MSG(node != nullptr, type, "node is nullptr.");
if (utils::isa<CNodePtr>(node)) {
auto cnode = node->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(cnode != nullptr, type, "cnode is nullptr.");
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
@ -115,6 +117,7 @@ TypeId GetTypeFromNode(const AnfNodePtr &node) {
return type;
}
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
MS_CHECK_TRUE_MSG(type_ptr != nullptr, type, "type_ptr is nullptr.");
type = type_ptr->type_id();
}
MS_LOG(INFO) << "node type id is " << type;
@ -124,31 +127,27 @@ TypeId GetTypeFromNode(const AnfNodePtr &node) {
std::vector<int> GetIntParameterData(const ParameterPtr &param_ptr) {
std::vector<int> result;
if (param_ptr == nullptr) {
MS_LOG(DEBUG) << "Param is nullptr.";
return result;
}
MS_CHECK_TRUE_MSG(param_ptr != nullptr, result, "Param is nullptr.");
if (!param_ptr->has_default()) {
MS_LOG(DEBUG) << "Param has not default.";
return result;
}
auto default_param = param_ptr->default_param();
MS_CHECK_TRUE_MSG(default_param != nullptr, result, "default_param is nullptr.");
if (!utils::isa<tensor::TensorPtr>(default_param)) {
MS_LOG(DEBUG) << "Tensor info is not tensor::TensorPtr.";
return result;
}
auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
if (default_param_ptr == nullptr) {
MS_LOG(DEBUG) << "Default param ptr is nullptr.";
return result;
}
MS_CHECK_TRUE_MSG(default_param_ptr != nullptr, result, "default_param_ptr is nullptr.");
if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
MS_LOG(DEBUG) << "Default param is not int.";
return result;
}
auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
MS_CHECK_TRUE_MSG(ptr != nullptr, result, "ptr is nullptr.");
int shape_size =
std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<int>());
for (int i = 0; i < shape_size; i++) {
@ -158,6 +157,7 @@ std::vector<int> GetIntParameterData(const ParameterPtr &param_ptr) {
}
bool IsCaseNode(const CNodePtr node) {
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
if (node->input(0) == nullptr) {
MS_LOG(WARNING) << "The input of node is nullptr.";
return false;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,6 +25,7 @@
#include "ops/relu6.h"
#include "ops/sigmoid.h"
#include "ops/tanh.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
@ -44,10 +45,7 @@ STATUS ActivationMapper::Mapper(const CNodePtr &cnode) {
return lite::RET_ERROR;
}
auto activate_prim = dynamic_cast<ops::Activation *>(src_prim.get());
if (activate_prim == nullptr) {
MS_LOG(ERROR) << "Dynamic cast activation failed.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(activate_prim != nullptr, lite::RET_ERROR, "Dynamic cast activation failed.");
PrimitivePtr dst_prim = nullptr;
ActivationType type = activate_prim->get_activation_type();
if (activation_type_map.find(type) != activation_type_map.end()) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@
#include "tools/converter/adapter/acl/mapper/cast_mapper.h"
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "tools/converter/adapter/acl/common/utils.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
@ -25,33 +26,26 @@ constexpr size_t kNameCastInputNum = 3;
} // namespace
STATUS CastMapper::Mapper(const CNodePtr &cnode) {
if (cnode == nullptr) {
MS_LOG(ERROR) << "Cnode is nullptr.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(cnode != nullptr, lite::RET_ERROR, "Cnode is nullptr.");
if (cnode->size() != kNameCastInputNum) {
MS_LOG(ERROR) << "Input size of cast must be " << kNameCastInputNum << ", real size: " << cnode->size();
return lite::RET_ERROR;
}
// convert last parameter to const value node
auto to_input = cnode->input(kNameCastInputNum - 1);
MS_CHECK_TRUE_MSG(to_input != nullptr, lite::RET_ERROR, "to_input is nullptr.");
if (!utils::isa<ParameterPtr>(to_input)) {
MS_LOG(ERROR) << "The to node is not parameter.";
return lite::RET_ERROR;
}
ParameterPtr to_param = to_input->cast<ParameterPtr>();
MS_CHECK_TRUE_MSG(to_param != nullptr, lite::RET_ERROR, "to_param is nullptr.");
auto data = acl::GetIntParameterData(to_param);
int dst_type = data.empty() ? kNumberTypeInt32 : data.front();
TypePtr type_ptr = TypeIdToType(TypeId(dst_type));
if (type_ptr == nullptr) {
MS_LOG(ERROR) << "New type ptr failed.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(type_ptr != nullptr, lite::RET_ERROR, "New type ptr failed.");
ValueNodePtr value_node = NewValueNode(type_ptr);
if (value_node == nullptr) {
MS_LOG(ERROR) << "New value node failed.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(value_node != nullptr, lite::RET_ERROR, "New value node failed.");
cnode->set_input(kNameCastInputNum - 1, value_node);
return lite::RET_OK;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,6 +21,7 @@
#include "tools/converter/adapter/acl/common/utils.h"
#include "include/registry/converter_context.h"
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
@ -44,7 +45,7 @@ STATUS Conv2dTransposeMapper::Mapper(const CNodePtr &cnode) {
} else {
dst_prim = std::make_shared<acl::Conv2DTransposeD>();
}
MS_ASSERT(dst_prim != nullptr);
MS_CHECK_TRUE_MSG(dst_prim != nullptr, RET_ERROR, "dst_prim is nullptr.");
dst_prim->SetAttrs(src_prim->attrs());
if (fmk_type != converter::kFmkTypeCaffe) {
if (AdjustGeAttr(cnode, dst_prim) != RET_OK) {
@ -62,6 +63,7 @@ STATUS Conv2dTransposeMapper::Mapper(const CNodePtr &cnode) {
}
STATUS Conv2dTransposeMapper::AdjustGeAttr(const CNodePtr &cnode, const PrimitivePtr &dst_prim) {
MS_CHECK_TRUE_MSG(dst_prim != nullptr, RET_ERROR, "dst_prim is nullptr.");
std::vector<int64_t> shape = {0, 0, 0, 0};
dst_prim->AddAttr("input_size", MakeValue(shape));

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,7 @@
#include <vector>
#include <map>
#include "ops/op_utils.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
@ -32,6 +33,7 @@ constexpr auto kNamePaddingMode = "padding";
STATUS ConvBaseMapper::AdjustAttrPad(const PrimitivePtr &prim) {
// attr pad val
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "prim is nullptr.");
auto pad_ptr = prim->GetAttr(ops::kPadList);
if (pad_ptr == nullptr) {
std::vector<int64_t> pad_list = {0, 0, 0, 0};

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@
#include "tools/converter/adapter/acl/mapper/gather_fusion_mapper.h"
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "tools/converter/adapter/acl/common/utils.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
@ -25,28 +26,24 @@ constexpr size_t kNameGatherInputNum = 4;
}
STATUS GatherMapper::Mapper(const CNodePtr &cnode) {
if (cnode == nullptr) {
MS_LOG(ERROR) << "Cnode is nullptr.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(cnode != nullptr, lite::RET_ERROR, "Cnode is nullptr.");
if (cnode->size() != kNameGatherInputNum) {
MS_LOG(ERROR) << "Input size of gather must be " << kNameGatherInputNum << ", real size: " << cnode->size();
return lite::RET_ERROR;
}
// convert last parameter to const value node
auto axis_input = cnode->input(kNameGatherInputNum - 1);
MS_CHECK_TRUE_MSG(axis_input != nullptr, lite::RET_ERROR, "axis_input is nullptr.");
if (!utils::isa<ParameterPtr>(axis_input)) {
MS_LOG(ERROR) << "The axis node is not parameter.";
return lite::RET_ERROR;
}
ParameterPtr axis_param = axis_input->cast<ParameterPtr>();
MS_CHECK_TRUE_MSG(axis_param != nullptr, lite::RET_ERROR, "axis_param is nullptr.");
auto data = acl::GetIntParameterData(axis_param);
int64_t axis = data.empty() ? 0 : static_cast<int64_t>(data.front());
ValueNodePtr value_node = NewValueNode<int64_t>(axis);
if (value_node == nullptr) {
MS_LOG(ERROR) << "New value node failed.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(value_node != nullptr, lite::RET_ERROR, "New value node failed.");
cnode->set_input(kNameGatherInputNum - 1, value_node);
return lite::RET_OK;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
#include "include/registry/converter_context.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
@ -40,10 +41,7 @@ STATUS MaxPoolFusionMapper::Mapper(const CNodePtr &cnode) {
} else {
dst_prim = std::make_shared<ops::MaxPool>();
}
if (dst_prim == nullptr) {
MS_LOG(ERROR) << "Get primitive by fmk type failed.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "Get primitive by fmk type failed.");
dst_prim->SetAttrs(src_prim->attrs());
if (AdjustPoolAttr(fmk_type, kNameMaxPoolFusion, dst_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Adjust pool attr failed.";

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -61,6 +61,7 @@ STATUS PrimitiveMapper::GetValueNodeAndPrimFromCnode(const CNodePtr &cnode, Valu
}
STATUS PrimitiveMapper::AttrAdjust(const PrimitivePtr &prim, const std::string &name) const {
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "prim is nullptr.");
auto value_ptr = prim->GetAttr(name);
if (value_ptr == nullptr) {
MS_LOG(WARNING) << prim->name() << " has no attr " << name;
@ -143,6 +144,7 @@ STATUS PrimitiveMapper::AdjustPoolAttr(int fmk_type, const std::string &src_prim
AdjustOnnxPoolAttr(dst_prim);
}
// adjust common attr
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
auto status = AttrAdjust(dst_prim, ops::kKernelSize);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Adjust kernel size failed.";
@ -163,10 +165,7 @@ STATUS PrimitiveMapper::MoveAttrMap(const CNodePtr &cnode, const PrimitivePtr &d
MS_LOG(ERROR) << "Get primitive from cnode failed.";
return lite::RET_ERROR;
}
if (dst_prim == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
dst_prim->SetAttrs(src_prim->attrs());
value_node->set_value(dst_prim);
return lite::RET_OK;
@ -174,6 +173,7 @@ STATUS PrimitiveMapper::MoveAttrMap(const CNodePtr &cnode, const PrimitivePtr &d
STATUS PrimitiveMapper::AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const PrimitivePtr &dst_prim, const std::string &attr_name, size_t flag) const {
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
auto attr_val = dst_prim->GetAttr(attr_name);
if (attr_val == nullptr) {
MS_LOG(INFO) << "There is no attr: " << attr_name;
@ -232,6 +232,7 @@ STATUS PrimitiveMapper::AddAttrForDynInputPrimitive(const CNodePtr &cnode, const
STATUS PrimitiveMapper::AdjustAttrFormat(const PrimitivePtr &prim, const std::string &name) const {
int64_t format = Format::NCHW;
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "prim is nullptr.");
if (prim->GetAttr(ops::kFormat) != nullptr) {
format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -45,18 +45,14 @@ const std::set<std::string> kCNodeWithDynamicInput = {kNamewiEltwise, ops::kName
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode) {
CNodePtr get_item_cnode = nullptr;
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "New TupleGetItem failed";
return nullptr;
}
MS_CHECK_TRUE_MSG(tuple_get_item_prim_ptr != nullptr, nullptr, "New TupleGetItem failed.");
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
MS_CHECK_TRUE_MSG(tuple_get_item_prim != nullptr, nullptr, "tuple_prim is nullptr.");
auto get_item_value = NewValueNode(MakeValue<int64_t>(0));
MS_CHECK_TRUE_MSG(get_item_value != nullptr, nullptr, "item_value is nullptr.");
AnfNodePtrList inputs{tuple_get_item_prim, input_cnode, get_item_value};
get_item_cnode = func_graph->NewCNode(inputs);
if (get_item_cnode == nullptr) {
MS_LOG(ERROR) << "New get item cnode failed.";
return nullptr;
}
MS_CHECK_TRUE_MSG(get_item_cnode != nullptr, nullptr, "New get item cnode failed.");
std::vector<int64_t> shape;
if (acl::GetShapeVectorFromCNode(input_cnode, &shape) != lite::RET_OK) {
@ -65,10 +61,7 @@ CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &
}
TypeId type = acl::GetTypeFromNode(input_cnode);
auto get_item_abstract = CreateTensorAbstract(shape, type);
if (get_item_abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstract failed.";
return nullptr;
}
MS_CHECK_TRUE_MSG(get_item_abstract != nullptr, nullptr, "Create tensor abstract failed.");
get_item_cnode->set_abstract(get_item_abstract);
get_item_cnode->set_fullname_with_scope(input_cnode->fullname_with_scope() + "_getitem");
return get_item_cnode;
@ -83,10 +76,12 @@ static STATUS AdapteNodeWithMultiOutputs(const FuncGraphPtr &func_graph, const C
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
auto input = cnode->input(i);
MS_CHECK_TRUE_MSG(input != nullptr, lite::RET_ERROR, "input is nullptr.");
if (!utils::isa<CNode>(input)) {
continue;
}
auto input_cnode = input->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(input_cnode != nullptr, lite::RET_ERROR, "input_cnode is nullptr.");
std::string input_func_name = GetCNodeFuncName(input_cnode);
if (kCNodeWithMultiOutputs.find(input_func_name) != kCNodeWithMultiOutputs.end()) {
MS_LOG(DEBUG) << "Input " << input_func_name << " of cnode " << cnode_func_name << " has multioutputs";
@ -111,20 +106,14 @@ static STATUS AdapteNodeWithDynamicInput(const FuncGraphPtr &func_graph, const C
}
MS_LOG(DEBUG) << "Adapter cnode with dynamic input: " << cnode_func_name;
auto make_tuple_val_node = NewValueNode(prim::kPrimMakeTuple);
if (make_tuple_val_node == nullptr) {
MS_LOG(ERROR) << "New make tuple val node failed.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(make_tuple_val_node != nullptr, lite::RET_ERROR, "New make tuple val node failed.");
AnfNodePtrList new_inputs = {make_tuple_val_node};
auto cnode_inputs = cnode->inputs();
if (cnode_inputs.size() >= kCnodeInputMinNum) {
new_inputs.insert(new_inputs.end(), cnode_inputs.begin() + 1, cnode_inputs.end());
}
auto make_tuple_cnode = func_graph->NewCNode(new_inputs);
if (make_tuple_cnode == nullptr) {
MS_LOG(ERROR) << "New make tuple cnode failed.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(make_tuple_cnode != nullptr, lite::RET_ERROR, "New make tuple cnode failed.");
const std::vector<AnfNodePtr> replace_node = {cnode_inputs[0], make_tuple_cnode};
cnode->set_inputs(replace_node);
@ -134,10 +123,7 @@ static STATUS AdapteNodeWithDynamicInput(const FuncGraphPtr &func_graph, const C
STATUS AdapteSpatialNode(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
auto cnodes = func_graph->GetOrderedCnodes();
for (const auto &cnode : cnodes) {
if (cnode == nullptr) {
MS_LOG(ERROR) << "Cnode is nullptr.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(cnode != nullptr, lite::RET_ERROR, "Cnode is nullptr.");
if (AdapteNodeWithMultiOutputs(func_graph, cnode, manager) != lite::RET_OK) {
MS_LOG(ERROR) << "Adapter node with multioutput failed.";
return lite::RET_ERROR;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
#include "include/registry/converter_context.h"
#include "ops/op_utils.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
@ -35,7 +36,7 @@ STATUS StridedSliceMapper::Mapper(const CNodePtr &cnode) {
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
if (fmk_type == converter::kFmkTypeOnnx) {
auto dst_prim = std::make_shared<acl::StridedSliceV2>();
MS_ASSERT(dst_prim != nullptr);
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
dst_prim->SetAttrs(src_prim->attrs());
value_node->set_value(dst_prim);
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@
#include <vector>
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "tools/converter/adapter/acl/common/utils.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
@ -27,30 +28,26 @@ constexpr size_t kCommonInputNum = 3;
}
STATUS TransposeMapper::Mapper(const CNodePtr &cnode) {
if (cnode == nullptr) {
MS_LOG(ERROR) << "Cnode is nullptr.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(cnode != nullptr, lite::RET_ERROR, "Cnode is nullptr.");
if (cnode->size() != kCommonInputNum) {
MS_LOG(ERROR) << "Input size of transpose must be " << kCommonInputNum << ", real size: " << cnode->size();
return lite::RET_ERROR;
}
// convert last parameter to const value node
auto perm_input = cnode->input(kCommonInputNum - 1);
MS_CHECK_TRUE_MSG(perm_input != nullptr, lite::RET_ERROR, "perm_input is nullptr.");
if (!utils::isa<ParameterPtr>(perm_input)) {
MS_LOG(ERROR) << "The perm node is not parameter.";
return lite::RET_ERROR;
}
ParameterPtr perm_param = perm_input->cast<ParameterPtr>();
MS_CHECK_TRUE_MSG(perm_param != nullptr, lite::RET_ERROR, "perm_param is nullptr.");
auto data = acl::GetIntParameterData(perm_param);
std::vector<int64_t> perm;
std::transform(data.begin(), data.end(), std::back_inserter(perm),
[](int32_t n) -> int64_t { return static_cast<int64_t>(n); });
ValueNodePtr value_node = NewValueNode<std::vector<int64_t>>(perm);
if (value_node == nullptr) {
MS_LOG(ERROR) << "New value node failed.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(value_node != nullptr, lite::RET_ERROR, "New value node failed.");
cnode->set_input(kCommonInputNum - 1, value_node);
return lite::RET_OK;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,6 +21,7 @@
#include "tools/optimizer/common/gllo_utils.h"
#include "src/common/log_util.h"
#include "ops/op_utils.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
@ -60,6 +61,8 @@ STATUS UpsampleMapper::Mapper(const CNodePtr &cnode) {
}
STATUS UpsampleMapper::AttrAdjust(const PrimitivePtr &src_prim, const ValueNodePtr &val_node) {
MS_CHECK_TRUE_MSG(src_prim != nullptr, RET_ERROR, "src_prim is nullptr.");
MS_CHECK_TRUE_MSG(val_node != nullptr, RET_ERROR, "val_node is nullptr.");
auto attr_val = src_prim->GetAttr("scale");
CHECK_NULL_RETURN(attr_val);
std::vector<float> scale = opt::CastToFloat(attr_val);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -166,6 +166,7 @@ STATUS AclPassImpl::PostProcGraph(const FuncGraphPtr &func_graph) {
}
std::string AclPassImpl::AdjustCnodeName(const PrimitivePtr &prim) {
MS_CHECK_TRUE_MSG(prim != nullptr, "", "prim is nullptr.");
std::string name = prim->name();
if (kAdjustCnodeName.find(name) != kAdjustCnodeName.end()) {
auto val_ptr = prim->GetAttr(ops::kOriginalOpName);
@ -180,7 +181,7 @@ std::string AclPassImpl::AdjustCnodeName(const PrimitivePtr &prim) {
STATUS AclPassImpl::RunPrimitiveMapper(const FuncGraphPtr &func_graph) {
MS_LOG(INFO) << "Deparser graph start.";
MS_ASSERT(func_graph != nullptr);
MS_CHECK_TRUE_MSG(func_graph != nullptr, lite::RET_ERROR, "func_graph is nullptr.");
std::set<FuncGraphPtr> all_func_graphs = {};
lite::GetAllFuncGraph(func_graph, &all_func_graphs);
for (auto graph : all_func_graphs) {
@ -190,6 +191,7 @@ STATUS AclPassImpl::RunPrimitiveMapper(const FuncGraphPtr &func_graph) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(cnode != nullptr, lite::RET_ERROR, "cnode is nullptr.");
auto prim = GetCNodePrimitive(cnode);
CHECK_NULL_RETURN(prim);
std::string name = AdjustCnodeName(prim);
@ -230,10 +232,7 @@ STATUS AclPassImpl::DeparseGraph(const FuncGraphPtr &func_graph, const FuncGraph
}
STATUS AclPassImpl::ConvertGraphToOm(const FuncGraphPtr &func_graph, Buffer *om_data) {
if (om_data == nullptr) {
MS_LOG(ERROR) << "Om data is nullptr.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(om_data != nullptr, lite::RET_ERROR, "Om data is nullptr.");
SetAclModelOptions(func_graph);
// call interface of cloud
ModelConverter model_converter;
@ -287,13 +286,9 @@ void AclPassImpl::SetAclModelBuildOptions(const std::shared_ptr<AscendDeviceInfo
std::shared_ptr<mindspore::Context> AclPassImpl::CreateModelContext() {
auto model_context = std::make_shared<mindspore::Context>();
if (model_context == nullptr) {
return nullptr;
}
MS_CHECK_TRUE_MSG(model_context != nullptr, nullptr, "model_context is nullptr.");
auto ascend_info = std::make_shared<AscendDeviceInfo>();
if (ascend_info == nullptr) {
return nullptr;
}
MS_CHECK_TRUE_MSG(ascend_info != nullptr, nullptr, "ascend_info is nullptr.");
ascend_info->SetDeviceID(user_options_cfg_.device_id);
SetAclModelInitOptions(ascend_info);
SetAclModelBuildOptions(ascend_info);
@ -304,6 +299,7 @@ std::shared_ptr<mindspore::Context> AclPassImpl::CreateModelContext() {
STATUS AclPassImpl::SetAclModelOptions(const FuncGraphPtr &func_graph) {
MS_LOG(INFO) << "Set acl model options start.";
MS_CHECK_TRUE_MSG(func_graph != nullptr, lite::RET_ERROR, "func_graph is nullptr.");
auto model_context = CreateModelContext();
CHECK_NULL_RETURN(model_context);
options_ = std::make_shared<AclModelOptions>(model_context);
@ -328,21 +324,23 @@ STATUS AclPassImpl::SetAclModelOptions(const FuncGraphPtr &func_graph) {
}
ParameterPtr AclPassImpl::CreateOmParameter(const FuncGraphPtr &func_graph, const Buffer &om_data) {
MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr.");
ParameterPtr om_parameter = func_graph->add_parameter();
MS_CHECK_TRUE_MSG(om_parameter != nullptr, nullptr, "om_parameter is nullptr.");
om_parameter->set_name("ACL_om_data");
auto type_ptr = TypeIdToType(kNumberTypeUInt8);
MS_CHECK_TRUE_MSG(type_ptr != nullptr, nullptr, "type_ptr is nullptr.");
ShapeVector shape_vector = {static_cast<int64_t>(om_data.DataSize())};
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, nullptr, "abstract_tensor is nullptr.");
om_parameter->set_abstract(abstract_tensor);
auto param_value =
std::make_shared<tensor::Tensor>(kNumberTypeUInt8, ShapeVector({static_cast<int64_t>(om_data.DataSize())}));
MS_CHECK_TRUE_MSG(param_value != nullptr, nullptr, "param_value is nullptr.");
auto tensor_data = param_value->data_c();
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "New Tensor failed.";
return nullptr;
}
MS_CHECK_TRUE_MSG(tensor_data != nullptr, nullptr, "New Tensor failed.");
if (param_value->Size() < om_data.DataSize()) {
MS_LOG(ERROR) << "Dst buff size " << param_value->Size() << " should be greater than src buff size "
<< om_data.DataSize();
@ -393,10 +391,7 @@ STATUS AclPassImpl::BuildGraph(const FuncGraphPtr &func_graph) {
return lite::RET_ERROR;
}
om_parameter_ = CreateOmParameter(func_graph, om_data);
if (om_parameter_ == nullptr) {
MS_LOG(ERROR) << "Convert graph to om failed.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(om_parameter_ != nullptr, lite::RET_ERROR, "Convert graph to om failed.");
if (!user_options_cfg_.insert_op_config_file_path.empty()) {
if (CreateGraphAippInput(func_graph, om_data) != lite::RET_OK) {
MS_LOG(ERROR) << "Create aipp input failed.";
@ -416,6 +411,7 @@ STATUS AclPassImpl::TraceOutput(const AnfNodePtr &node) {
CHECK_NULL_RETURN(tmp);
cur_node = tmp->input(kTupleGetItemFirstInputIdx);
}
CHECK_NULL_RETURN(cur_node);
auto cnode = cur_node->cast<CNodePtr>();
CHECK_NULL_RETURN(cnode);
std::string name = lite::acl::GetCNodeTargetFuncName(cnode);
@ -453,6 +449,7 @@ STATUS AclPassImpl::TraceOutput(const AnfNodePtr &node) {
}
STATUS AclPassImpl::GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph) {
MS_CHECK_TRUE_MSG(func_graph != nullptr, lite::RET_ERROR, "func_graph is nullptr.");
AnfNodePtr return_input = func_graph->output();
CHECK_NULL_RETURN(return_input);
if (TraceOutput(return_input) != lite::RET_OK) {
@ -468,10 +465,11 @@ STATUS AclPassImpl::GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph) {
return lite::RET_OK;
}
STATUS AclPassImpl::SetMultiOutputs(const CNodePtr &new_cnode, TypeId data_type) {
STATUS AclPassImpl::SetMultiOutputs(const CNodePtr &new_cnode, std::vector<TypeId> data_type) {
MS_CHECK_TRUE_MSG(new_cnode != nullptr, lite::RET_ERROR, "new_cnode is nullptr.");
AbstractBasePtrList abstract_list;
for (size_t j = 0; j < graph_outputs_.size(); j++) {
auto abstract_tensor = lite::CreateTensorAbstract(graph_output_dims_[j], data_type);
auto abstract_tensor = lite::CreateTensorAbstract(graph_output_dims_[j], data_type[j]);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Abstract tensor is nullptr for output " << j;
return lite::RET_ERROR;
@ -489,17 +487,20 @@ STATUS AclPassImpl::SetCustomOutputs(const FuncGraphPtr &func_graph, const CNode
return lite::RET_ERROR;
}
custom_node->AddAttr(kOutputNames, MakeValue(graph_output_names_));
TypeId type = lite::acl::GetTypeFromNode(graph_outputs_[0]);
TypeId type;
if (graph_outputs_.size() == 1) {
type = lite::acl::GetTypeFromNode(graph_outputs_[0]);
auto abstract_tensor = lite::CreateTensorAbstract(graph_output_dims_[0], type);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Abstract_tensor is nullptr.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, lite::RET_ERROR, "Abstract_tensor is nullptr.");
custom_node->set_abstract(abstract_tensor);
return lite::RET_OK;
}
if (SetMultiOutputs(custom_node, type) != lite::RET_OK) {
std::vector<TypeId> types;
for (size_t i = 0; i < graph_outputs_.size(); i++) {
type = lite::acl::GetTypeFromNode(graph_outputs_[i]);
types.emplace_back(type);
}
if (SetMultiOutputs(custom_node, types) != lite::RET_OK) {
MS_LOG(ERROR) << "Set multi graph output failed.";
return lite::RET_ERROR;
}
@ -521,18 +522,13 @@ void AclPassImpl::SetCustomAttrs(const std::shared_ptr<ops::Custom> &prim) {
}
CNodePtr AclPassImpl::CreateCustomNode(const FuncGraphPtr &func_graph) {
MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr.");
auto prim = std::make_shared<mindspore::ops::Custom>();
if (prim == nullptr) {
MS_LOG(ERROR) << "New custom op failed.";
return nullptr;
}
MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "New custom op failed.");
prim->set_type(kCustomPrimTypeACL);
auto graph_input = func_graph->get_inputs();
CNodePtr custom_node = func_graph->NewCNode(prim, graph_input);
if (custom_node == nullptr) {
MS_LOG(ERROR) << "Custom cnode failed.";
return nullptr;
}
MS_CHECK_TRUE_MSG(custom_node != nullptr, nullptr, "Custom cnode failed.");
custom_node->set_fullname_with_scope(kCustomNodeName);
custom_node->add_input(om_parameter_);
@ -574,7 +570,9 @@ STATUS AclPassImpl::ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, cons
return lite::RET_ERROR;
}
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
MS_CHECK_TRUE_MSG(tuple_get_item_prim != nullptr, lite::RET_ERROR, "item_prim is nullptr.");
auto get_item_value = NewValueNode(MakeValue<int>(j));
MS_CHECK_TRUE_MSG(get_item_value != nullptr, lite::RET_ERROR, "item_value is nullptr.");
AnfNodePtrList inputs{tuple_get_item_prim, custom_node, get_item_value};
CNodePtr get_item_cnode = func_graph->NewCNode(inputs);
if (get_item_cnode == nullptr) {
@ -600,15 +598,9 @@ STATUS AclPassImpl::ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, cons
bool AclPassImpl::Run(const FuncGraphPtr &func_graph) {
MS_LOG(INFO) << "Acl pass run start.";
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Func_graph is nullptr.";
return false;
}
MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "Func_graph is nullptr.");
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "Manager is nullptr.";
return false;
}
MS_CHECK_TRUE_MSG(manager != nullptr, false, "Manager is nullptr.");
if (PreProcGraph(func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "Pre proc graph failed.";
@ -626,10 +618,7 @@ bool AclPassImpl::Run(const FuncGraphPtr &func_graph) {
}
custom_node_ = CreateCustomNode(func_graph);
if (custom_node_ == nullptr) {
MS_LOG(ERROR) << "Create custom node failed.";
return false;
}
MS_CHECK_TRUE_MSG(custom_node_ != nullptr, false, "Create custom node failed.");
// prepare graph for export create
if (ModifyGraphByCustomNode(func_graph, manager, custom_node_) != lite::RET_OK) {
MS_LOG(ERROR) << "Modify func graph by custom failed.";

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -67,7 +67,7 @@ class AclPassImpl {
CNodePtr CreateCustomNode(const FuncGraphPtr &func_graph);
void SetCustomAttrs(const std::shared_ptr<ops::Custom> &prim);
STATUS SetCustomOutputs(const FuncGraphPtr &func_graph, const CNodePtr &custom_node);
STATUS SetMultiOutputs(const CNodePtr &new_cnode, TypeId data_type);
STATUS SetMultiOutputs(const CNodePtr &new_cnode, std::vector<TypeId> data_type);
STATUS GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph);
STATUS TraceOutput(const AnfNodePtr &node);

View File

@ -131,7 +131,7 @@ bool InputAdjust::Run(const FuncGraphPtr &func_graph) {
} else if (opt::CheckPrimitiveType(node, prim::kPrimReduceFusion)) {
MS_LOG(INFO) << "Adjust ReduceFusion";
status = AddAttrToInput(func_graph, cnode, opt::kInputIndexTwo, "axes", kBuildInputFlagTwo);
} else if (opt::CheckPrimitiveType(node, prim::kPrimPadFusion)) {
} else if (opt::CheckPrimitiveType(node, prim::kPrimPad) || opt::CheckPrimitiveType(node, prim::kPrimPadFusion)) {
MS_LOG(INFO) << "Adjust PadFusion";
status = AddAttrToInput(func_graph, cnode, opt::kInputIndexTwo, "paddings", kBuildInputFlagThree);
} else if (opt::CheckPrimitiveType(node, prim::kPrimPowFusion)) {

View File

@ -319,7 +319,7 @@ int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const
} else {
auto pad_prim = utils::cast<std::shared_ptr<mindspore::ops::PadFusion>>(primitive);
MS_ASSERT(pad_prim != nullptr);
MS_CHECK_TRUE_RET(pad_prim->GetAttr(ops::kPadding) != nullptr, lite::RET_ERROR);
MS_CHECK_TRUE_RET(pad_prim->GetAttr(ops::kPaddings) != nullptr, lite::RET_ERROR);
auto pad_data = pad_prim->get_paddings();
for (size_t i = 0; i < pad_data.size(); i++) {
for (size_t j = 0; j < pad_data[i].size(); j++) {