310 codex and online infer bugfix
This commit is contained in:
parent
25c3d507cd
commit
8e20d1e4c3
|
@ -576,9 +576,11 @@ int LiteSession::IsolateOutputTensor() {
|
|||
new_tensor->set_init_ref_count(src_tensor->init_ref_count());
|
||||
|
||||
/* src tensor set for graph calculate */
|
||||
#ifdef ENABLE_FP16
|
||||
if (src_tensor->data_type() == kNumberTypeFloat16) {
|
||||
src_tensor->set_data_type(kNumberTypeFloat32);
|
||||
}
|
||||
#endif
|
||||
src_tensor->set_ref_count(1);
|
||||
|
||||
isolate_graph_output_map_.insert(std::make_pair(new_tensor, src_tensor));
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -32,6 +32,7 @@ STATUS AclOptionsParser::ParseAclOptions(const mindspore::Context *ctx, AclModel
|
|||
CHECK_NULL_RETURN(acl_options);
|
||||
|
||||
auto context = const_cast<mindspore::Context *>(ctx);
|
||||
CHECK_NULL_RETURN(context);
|
||||
auto device_infos = context->MutableDeviceInfo();
|
||||
if (device_infos.size() < 1) {
|
||||
MS_LOG(WARNING) << "Context is not set device info, please check.";
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -24,6 +24,7 @@
|
|||
#include "utils/utils.h"
|
||||
#include "src/common/log_util.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -35,24 +36,23 @@ constexpr size_t kInvalidSize = SIZE_MAX;
|
|||
} // namespace
|
||||
|
||||
static size_t GetTupleGetItemOutIndex(const mindspore::CNodePtr &tuple_get_item) {
|
||||
MS_ASSERT(tuple_get_item != nullptr);
|
||||
if (tuple_get_item->size() != mindspore::kTupleGetItemInputSize) {
|
||||
MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
|
||||
return kInvalidSize;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(tuple_get_item != nullptr, kInvalidSize, "tuple_get_item is nullptr.");
|
||||
MS_CHECK_TRUE_MSG(tuple_get_item->size() == mindspore::kTupleGetItemInputSize, kInvalidSize,
|
||||
"The node tuple_get_item must have 3 inputs!");
|
||||
auto output_index_value_node = tuple_get_item->input(mindspore::kInputNodeOutputIndexInTupleGetItem);
|
||||
MS_ASSERT(output_index_value_node != nullptr);
|
||||
MS_CHECK_TRUE_MSG(output_index_value_node != nullptr, kInvalidSize, "output_index_value_node is nullptr.");
|
||||
auto value_node = output_index_value_node->cast<mindspore::ValueNodePtr>();
|
||||
MS_ASSERT(value_node != nullptr);
|
||||
return IntToSize(opt::CastToInt(value_node->value()).front());
|
||||
MS_CHECK_TRUE_MSG(value_node != nullptr, kInvalidSize, "value_node is nullptr.");
|
||||
auto values = opt::CastToInt(value_node->value());
|
||||
MS_CHECK_TRUE_MSG(values.size() > 0, kInvalidSize, "value_node has no value.");
|
||||
return IntToSize(values.front());
|
||||
}
|
||||
|
||||
static bool CheckPrimitiveType(const mindspore::AnfNodePtr &node, const mindspore::PrimitivePtr &primitive_type) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
|
||||
if (node->isa<mindspore::CNode>()) {
|
||||
auto cnode = node->cast<mindspore::CNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cnode is nullptr.");
|
||||
return IsPrimitive(cnode->input(0), primitive_type);
|
||||
} else if (node->isa<mindspore::ValueNode>()) {
|
||||
return IsPrimitive(node, primitive_type);
|
||||
|
@ -64,9 +64,9 @@ STATUS GetShapeVectorFromCNode(const mindspore::CNodePtr &cnode, std::vector<int
|
|||
mindspore::AbstractBasePtr cnode_abstract;
|
||||
if (CheckPrimitiveType(cnode, mindspore::prim::kPrimTupleGetItem)) {
|
||||
auto tuple_inputs = cnode->inputs();
|
||||
MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
|
||||
MS_CHECK_TRUE_MSG(tuple_inputs.size() == kTupleGetItemInputSize, lite::RET_ERROR, "The node must has 3 inputs.");
|
||||
auto get_item_input_cnode = tuple_inputs.at(kSecondIndex);
|
||||
MS_ASSERT(get_item_input_cnode != nullptr);
|
||||
MS_CHECK_TRUE_MSG(get_item_input_cnode != nullptr, lite::RET_ERROR, "input cnode is nullptr.");
|
||||
auto idx = GetTupleGetItemOutIndex(cnode);
|
||||
if (!mindspore::utils::isa<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
|
||||
MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
|
||||
|
@ -106,8 +106,10 @@ STATUS GetShapeVectorFromCNode(const mindspore::CNodePtr &cnode, std::vector<int
|
|||
|
||||
TypeId GetTypeFromNode(const AnfNodePtr &node) {
|
||||
TypeId type = kNumberTypeFloat32;
|
||||
MS_CHECK_TRUE_MSG(node != nullptr, type, "node is nullptr.");
|
||||
if (utils::isa<CNodePtr>(node)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, type, "cnode is nullptr.");
|
||||
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
|
||||
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
|
||||
|
@ -115,6 +117,7 @@ TypeId GetTypeFromNode(const AnfNodePtr &node) {
|
|||
return type;
|
||||
}
|
||||
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
|
||||
MS_CHECK_TRUE_MSG(type_ptr != nullptr, type, "type_ptr is nullptr.");
|
||||
type = type_ptr->type_id();
|
||||
}
|
||||
MS_LOG(INFO) << "node type id is " << type;
|
||||
|
@ -124,31 +127,27 @@ TypeId GetTypeFromNode(const AnfNodePtr &node) {
|
|||
|
||||
std::vector<int> GetIntParameterData(const ParameterPtr ¶m_ptr) {
|
||||
std::vector<int> result;
|
||||
if (param_ptr == nullptr) {
|
||||
MS_LOG(DEBUG) << "Param is nullptr.";
|
||||
return result;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(param_ptr != nullptr, result, "Param is nullptr.");
|
||||
|
||||
if (!param_ptr->has_default()) {
|
||||
MS_LOG(DEBUG) << "Param has not default.";
|
||||
return result;
|
||||
}
|
||||
auto default_param = param_ptr->default_param();
|
||||
MS_CHECK_TRUE_MSG(default_param != nullptr, result, "default_param is nullptr.");
|
||||
if (!utils::isa<tensor::TensorPtr>(default_param)) {
|
||||
MS_LOG(DEBUG) << "Tensor info is not tensor::TensorPtr.";
|
||||
return result;
|
||||
}
|
||||
auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
|
||||
if (default_param_ptr == nullptr) {
|
||||
MS_LOG(DEBUG) << "Default param ptr is nullptr.";
|
||||
return result;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(default_param_ptr != nullptr, result, "default_param_ptr is nullptr.");
|
||||
if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
|
||||
MS_LOG(DEBUG) << "Default param is not int.";
|
||||
return result;
|
||||
}
|
||||
|
||||
auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
|
||||
MS_CHECK_TRUE_MSG(ptr != nullptr, result, "ptr is nullptr.");
|
||||
int shape_size =
|
||||
std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<int>());
|
||||
for (int i = 0; i < shape_size; i++) {
|
||||
|
@ -158,6 +157,7 @@ std::vector<int> GetIntParameterData(const ParameterPtr ¶m_ptr) {
|
|||
}
|
||||
|
||||
bool IsCaseNode(const CNodePtr node) {
|
||||
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
|
||||
if (node->input(0) == nullptr) {
|
||||
MS_LOG(WARNING) << "The input of node is nullptr.";
|
||||
return false;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -25,6 +25,7 @@
|
|||
#include "ops/relu6.h"
|
||||
#include "ops/sigmoid.h"
|
||||
#include "ops/tanh.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -44,10 +45,7 @@ STATUS ActivationMapper::Mapper(const CNodePtr &cnode) {
|
|||
return lite::RET_ERROR;
|
||||
}
|
||||
auto activate_prim = dynamic_cast<ops::Activation *>(src_prim.get());
|
||||
if (activate_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Dynamic cast activation failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(activate_prim != nullptr, lite::RET_ERROR, "Dynamic cast activation failed.");
|
||||
PrimitivePtr dst_prim = nullptr;
|
||||
ActivationType type = activate_prim->get_activation_type();
|
||||
if (activation_type_map.find(type) != activation_type_map.end()) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,6 +17,7 @@
|
|||
#include "tools/converter/adapter/acl/mapper/cast_mapper.h"
|
||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
||||
#include "tools/converter/adapter/acl/common/utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -25,33 +26,26 @@ constexpr size_t kNameCastInputNum = 3;
|
|||
} // namespace
|
||||
|
||||
STATUS CastMapper::Mapper(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Cnode is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, lite::RET_ERROR, "Cnode is nullptr.");
|
||||
if (cnode->size() != kNameCastInputNum) {
|
||||
MS_LOG(ERROR) << "Input size of cast must be " << kNameCastInputNum << ", real size: " << cnode->size();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
// convert last parameter to const value node
|
||||
auto to_input = cnode->input(kNameCastInputNum - 1);
|
||||
MS_CHECK_TRUE_MSG(to_input != nullptr, lite::RET_ERROR, "to_input is nullptr.");
|
||||
if (!utils::isa<ParameterPtr>(to_input)) {
|
||||
MS_LOG(ERROR) << "The to node is not parameter.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ParameterPtr to_param = to_input->cast<ParameterPtr>();
|
||||
MS_CHECK_TRUE_MSG(to_param != nullptr, lite::RET_ERROR, "to_param is nullptr.");
|
||||
auto data = acl::GetIntParameterData(to_param);
|
||||
int dst_type = data.empty() ? kNumberTypeInt32 : data.front();
|
||||
TypePtr type_ptr = TypeIdToType(TypeId(dst_type));
|
||||
if (type_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "New type ptr failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(type_ptr != nullptr, lite::RET_ERROR, "New type ptr failed.");
|
||||
ValueNodePtr value_node = NewValueNode(type_ptr);
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(ERROR) << "New value node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(value_node != nullptr, lite::RET_ERROR, "New value node failed.");
|
||||
cnode->set_input(kNameCastInputNum - 1, value_node);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,6 +21,7 @@
|
|||
#include "tools/converter/adapter/acl/common/utils.h"
|
||||
#include "include/registry/converter_context.h"
|
||||
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -44,7 +45,7 @@ STATUS Conv2dTransposeMapper::Mapper(const CNodePtr &cnode) {
|
|||
} else {
|
||||
dst_prim = std::make_shared<acl::Conv2DTransposeD>();
|
||||
}
|
||||
MS_ASSERT(dst_prim != nullptr);
|
||||
MS_CHECK_TRUE_MSG(dst_prim != nullptr, RET_ERROR, "dst_prim is nullptr.");
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
if (fmk_type != converter::kFmkTypeCaffe) {
|
||||
if (AdjustGeAttr(cnode, dst_prim) != RET_OK) {
|
||||
|
@ -62,6 +63,7 @@ STATUS Conv2dTransposeMapper::Mapper(const CNodePtr &cnode) {
|
|||
}
|
||||
|
||||
STATUS Conv2dTransposeMapper::AdjustGeAttr(const CNodePtr &cnode, const PrimitivePtr &dst_prim) {
|
||||
MS_CHECK_TRUE_MSG(dst_prim != nullptr, RET_ERROR, "dst_prim is nullptr.");
|
||||
std::vector<int64_t> shape = {0, 0, 0, 0};
|
||||
dst_prim->AddAttr("input_size", MakeValue(shape));
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,6 +18,7 @@
|
|||
#include <vector>
|
||||
#include <map>
|
||||
#include "ops/op_utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -32,6 +33,7 @@ constexpr auto kNamePaddingMode = "padding";
|
|||
|
||||
STATUS ConvBaseMapper::AdjustAttrPad(const PrimitivePtr &prim) {
|
||||
// attr pad val
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "prim is nullptr.");
|
||||
auto pad_ptr = prim->GetAttr(ops::kPadList);
|
||||
if (pad_ptr == nullptr) {
|
||||
std::vector<int64_t> pad_list = {0, 0, 0, 0};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,6 +17,7 @@
|
|||
#include "tools/converter/adapter/acl/mapper/gather_fusion_mapper.h"
|
||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
||||
#include "tools/converter/adapter/acl/common/utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -25,28 +26,24 @@ constexpr size_t kNameGatherInputNum = 4;
|
|||
}
|
||||
|
||||
STATUS GatherMapper::Mapper(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Cnode is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, lite::RET_ERROR, "Cnode is nullptr.");
|
||||
if (cnode->size() != kNameGatherInputNum) {
|
||||
MS_LOG(ERROR) << "Input size of gather must be " << kNameGatherInputNum << ", real size: " << cnode->size();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
// convert last parameter to const value node
|
||||
auto axis_input = cnode->input(kNameGatherInputNum - 1);
|
||||
MS_CHECK_TRUE_MSG(axis_input != nullptr, lite::RET_ERROR, "axis_input is nullptr.");
|
||||
if (!utils::isa<ParameterPtr>(axis_input)) {
|
||||
MS_LOG(ERROR) << "The axis node is not parameter.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ParameterPtr axis_param = axis_input->cast<ParameterPtr>();
|
||||
MS_CHECK_TRUE_MSG(axis_param != nullptr, lite::RET_ERROR, "axis_param is nullptr.");
|
||||
auto data = acl::GetIntParameterData(axis_param);
|
||||
int64_t axis = data.empty() ? 0 : static_cast<int64_t>(data.front());
|
||||
ValueNodePtr value_node = NewValueNode<int64_t>(axis);
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(ERROR) << "New value node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(value_node != nullptr, lite::RET_ERROR, "New value node failed.");
|
||||
cnode->set_input(kNameGatherInputNum - 1, value_node);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,6 +19,7 @@
|
|||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
||||
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
|
||||
#include "include/registry/converter_context.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -40,10 +41,7 @@ STATUS MaxPoolFusionMapper::Mapper(const CNodePtr &cnode) {
|
|||
} else {
|
||||
dst_prim = std::make_shared<ops::MaxPool>();
|
||||
}
|
||||
if (dst_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Get primitive by fmk type failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "Get primitive by fmk type failed.");
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
if (AdjustPoolAttr(fmk_type, kNameMaxPoolFusion, dst_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Adjust pool attr failed.";
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -61,6 +61,7 @@ STATUS PrimitiveMapper::GetValueNodeAndPrimFromCnode(const CNodePtr &cnode, Valu
|
|||
}
|
||||
|
||||
STATUS PrimitiveMapper::AttrAdjust(const PrimitivePtr &prim, const std::string &name) const {
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "prim is nullptr.");
|
||||
auto value_ptr = prim->GetAttr(name);
|
||||
if (value_ptr == nullptr) {
|
||||
MS_LOG(WARNING) << prim->name() << " has no attr " << name;
|
||||
|
@ -143,6 +144,7 @@ STATUS PrimitiveMapper::AdjustPoolAttr(int fmk_type, const std::string &src_prim
|
|||
AdjustOnnxPoolAttr(dst_prim);
|
||||
}
|
||||
// adjust common attr
|
||||
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
|
||||
auto status = AttrAdjust(dst_prim, ops::kKernelSize);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Adjust kernel size failed.";
|
||||
|
@ -163,10 +165,7 @@ STATUS PrimitiveMapper::MoveAttrMap(const CNodePtr &cnode, const PrimitivePtr &d
|
|||
MS_LOG(ERROR) << "Get primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (dst_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Primitive is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
value_node->set_value(dst_prim);
|
||||
return lite::RET_OK;
|
||||
|
@ -174,6 +173,7 @@ STATUS PrimitiveMapper::MoveAttrMap(const CNodePtr &cnode, const PrimitivePtr &d
|
|||
|
||||
STATUS PrimitiveMapper::AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const PrimitivePtr &dst_prim, const std::string &attr_name, size_t flag) const {
|
||||
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
|
||||
auto attr_val = dst_prim->GetAttr(attr_name);
|
||||
if (attr_val == nullptr) {
|
||||
MS_LOG(INFO) << "There is no attr: " << attr_name;
|
||||
|
@ -232,6 +232,7 @@ STATUS PrimitiveMapper::AddAttrForDynInputPrimitive(const CNodePtr &cnode, const
|
|||
|
||||
STATUS PrimitiveMapper::AdjustAttrFormat(const PrimitivePtr &prim, const std::string &name) const {
|
||||
int64_t format = Format::NCHW;
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "prim is nullptr.");
|
||||
if (prim->GetAttr(ops::kFormat) != nullptr) {
|
||||
format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -45,18 +45,14 @@ const std::set<std::string> kCNodeWithDynamicInput = {kNamewiEltwise, ops::kName
|
|||
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode) {
|
||||
CNodePtr get_item_cnode = nullptr;
|
||||
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||
if (tuple_get_item_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "New TupleGetItem failed";
|
||||
return nullptr;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(tuple_get_item_prim_ptr != nullptr, nullptr, "New TupleGetItem failed.");
|
||||
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
|
||||
MS_CHECK_TRUE_MSG(tuple_get_item_prim != nullptr, nullptr, "tuple_prim is nullptr.");
|
||||
auto get_item_value = NewValueNode(MakeValue<int64_t>(0));
|
||||
MS_CHECK_TRUE_MSG(get_item_value != nullptr, nullptr, "item_value is nullptr.");
|
||||
AnfNodePtrList inputs{tuple_get_item_prim, input_cnode, get_item_value};
|
||||
get_item_cnode = func_graph->NewCNode(inputs);
|
||||
if (get_item_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "New get item cnode failed.";
|
||||
return nullptr;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(get_item_cnode != nullptr, nullptr, "New get item cnode failed.");
|
||||
|
||||
std::vector<int64_t> shape;
|
||||
if (acl::GetShapeVectorFromCNode(input_cnode, &shape) != lite::RET_OK) {
|
||||
|
@ -65,10 +61,7 @@ CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &
|
|||
}
|
||||
TypeId type = acl::GetTypeFromNode(input_cnode);
|
||||
auto get_item_abstract = CreateTensorAbstract(shape, type);
|
||||
if (get_item_abstract == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tensor abstract failed.";
|
||||
return nullptr;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(get_item_abstract != nullptr, nullptr, "Create tensor abstract failed.");
|
||||
get_item_cnode->set_abstract(get_item_abstract);
|
||||
get_item_cnode->set_fullname_with_scope(input_cnode->fullname_with_scope() + "_getitem");
|
||||
return get_item_cnode;
|
||||
|
@ -83,10 +76,12 @@ static STATUS AdapteNodeWithMultiOutputs(const FuncGraphPtr &func_graph, const C
|
|||
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
auto input = cnode->input(i);
|
||||
MS_CHECK_TRUE_MSG(input != nullptr, lite::RET_ERROR, "input is nullptr.");
|
||||
if (!utils::isa<CNode>(input)) {
|
||||
continue;
|
||||
}
|
||||
auto input_cnode = input->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(input_cnode != nullptr, lite::RET_ERROR, "input_cnode is nullptr.");
|
||||
std::string input_func_name = GetCNodeFuncName(input_cnode);
|
||||
if (kCNodeWithMultiOutputs.find(input_func_name) != kCNodeWithMultiOutputs.end()) {
|
||||
MS_LOG(DEBUG) << "Input " << input_func_name << " of cnode " << cnode_func_name << " has multioutputs";
|
||||
|
@ -111,20 +106,14 @@ static STATUS AdapteNodeWithDynamicInput(const FuncGraphPtr &func_graph, const C
|
|||
}
|
||||
MS_LOG(DEBUG) << "Adapter cnode with dynamic input: " << cnode_func_name;
|
||||
auto make_tuple_val_node = NewValueNode(prim::kPrimMakeTuple);
|
||||
if (make_tuple_val_node == nullptr) {
|
||||
MS_LOG(ERROR) << "New make tuple val node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(make_tuple_val_node != nullptr, lite::RET_ERROR, "New make tuple val node failed.");
|
||||
AnfNodePtrList new_inputs = {make_tuple_val_node};
|
||||
auto cnode_inputs = cnode->inputs();
|
||||
if (cnode_inputs.size() >= kCnodeInputMinNum) {
|
||||
new_inputs.insert(new_inputs.end(), cnode_inputs.begin() + 1, cnode_inputs.end());
|
||||
}
|
||||
auto make_tuple_cnode = func_graph->NewCNode(new_inputs);
|
||||
if (make_tuple_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "New make tuple cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(make_tuple_cnode != nullptr, lite::RET_ERROR, "New make tuple cnode failed.");
|
||||
|
||||
const std::vector<AnfNodePtr> replace_node = {cnode_inputs[0], make_tuple_cnode};
|
||||
cnode->set_inputs(replace_node);
|
||||
|
@ -134,10 +123,7 @@ static STATUS AdapteNodeWithDynamicInput(const FuncGraphPtr &func_graph, const C
|
|||
STATUS AdapteSpatialNode(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
|
||||
auto cnodes = func_graph->GetOrderedCnodes();
|
||||
for (const auto &cnode : cnodes) {
|
||||
if (cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Cnode is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, lite::RET_ERROR, "Cnode is nullptr.");
|
||||
if (AdapteNodeWithMultiOutputs(func_graph, cnode, manager) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Adapter node with multioutput failed.";
|
||||
return lite::RET_ERROR;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -20,6 +20,7 @@
|
|||
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
|
||||
#include "include/registry/converter_context.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -35,7 +36,7 @@ STATUS StridedSliceMapper::Mapper(const CNodePtr &cnode) {
|
|||
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
|
||||
if (fmk_type == converter::kFmkTypeOnnx) {
|
||||
auto dst_prim = std::make_shared<acl::StridedSliceV2>();
|
||||
MS_ASSERT(dst_prim != nullptr);
|
||||
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
value_node->set_value(dst_prim);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,6 +19,7 @@
|
|||
#include <vector>
|
||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
||||
#include "tools/converter/adapter/acl/common/utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -27,30 +28,26 @@ constexpr size_t kCommonInputNum = 3;
|
|||
}
|
||||
|
||||
STATUS TransposeMapper::Mapper(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Cnode is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, lite::RET_ERROR, "Cnode is nullptr.");
|
||||
if (cnode->size() != kCommonInputNum) {
|
||||
MS_LOG(ERROR) << "Input size of transpose must be " << kCommonInputNum << ", real size: " << cnode->size();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
// convert last parameter to const value node
|
||||
auto perm_input = cnode->input(kCommonInputNum - 1);
|
||||
MS_CHECK_TRUE_MSG(perm_input != nullptr, lite::RET_ERROR, "perm_input is nullptr.");
|
||||
if (!utils::isa<ParameterPtr>(perm_input)) {
|
||||
MS_LOG(ERROR) << "The perm node is not parameter.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ParameterPtr perm_param = perm_input->cast<ParameterPtr>();
|
||||
MS_CHECK_TRUE_MSG(perm_param != nullptr, lite::RET_ERROR, "perm_param is nullptr.");
|
||||
auto data = acl::GetIntParameterData(perm_param);
|
||||
std::vector<int64_t> perm;
|
||||
std::transform(data.begin(), data.end(), std::back_inserter(perm),
|
||||
[](int32_t n) -> int64_t { return static_cast<int64_t>(n); });
|
||||
ValueNodePtr value_node = NewValueNode<std::vector<int64_t>>(perm);
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(ERROR) << "New value node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(value_node != nullptr, lite::RET_ERROR, "New value node failed.");
|
||||
cnode->set_input(kCommonInputNum - 1, value_node);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,6 +21,7 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "src/common/log_util.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -60,6 +61,8 @@ STATUS UpsampleMapper::Mapper(const CNodePtr &cnode) {
|
|||
}
|
||||
|
||||
STATUS UpsampleMapper::AttrAdjust(const PrimitivePtr &src_prim, const ValueNodePtr &val_node) {
|
||||
MS_CHECK_TRUE_MSG(src_prim != nullptr, RET_ERROR, "src_prim is nullptr.");
|
||||
MS_CHECK_TRUE_MSG(val_node != nullptr, RET_ERROR, "val_node is nullptr.");
|
||||
auto attr_val = src_prim->GetAttr("scale");
|
||||
CHECK_NULL_RETURN(attr_val);
|
||||
std::vector<float> scale = opt::CastToFloat(attr_val);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -166,6 +166,7 @@ STATUS AclPassImpl::PostProcGraph(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
|
||||
std::string AclPassImpl::AdjustCnodeName(const PrimitivePtr &prim) {
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, "", "prim is nullptr.");
|
||||
std::string name = prim->name();
|
||||
if (kAdjustCnodeName.find(name) != kAdjustCnodeName.end()) {
|
||||
auto val_ptr = prim->GetAttr(ops::kOriginalOpName);
|
||||
|
@ -180,7 +181,7 @@ std::string AclPassImpl::AdjustCnodeName(const PrimitivePtr &prim) {
|
|||
|
||||
STATUS AclPassImpl::RunPrimitiveMapper(const FuncGraphPtr &func_graph) {
|
||||
MS_LOG(INFO) << "Deparser graph start.";
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_CHECK_TRUE_MSG(func_graph != nullptr, lite::RET_ERROR, "func_graph is nullptr.");
|
||||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
lite::GetAllFuncGraph(func_graph, &all_func_graphs);
|
||||
for (auto graph : all_func_graphs) {
|
||||
|
@ -190,6 +191,7 @@ STATUS AclPassImpl::RunPrimitiveMapper(const FuncGraphPtr &func_graph) {
|
|||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, lite::RET_ERROR, "cnode is nullptr.");
|
||||
auto prim = GetCNodePrimitive(cnode);
|
||||
CHECK_NULL_RETURN(prim);
|
||||
std::string name = AdjustCnodeName(prim);
|
||||
|
@ -230,10 +232,7 @@ STATUS AclPassImpl::DeparseGraph(const FuncGraphPtr &func_graph, const FuncGraph
|
|||
}
|
||||
|
||||
STATUS AclPassImpl::ConvertGraphToOm(const FuncGraphPtr &func_graph, Buffer *om_data) {
|
||||
if (om_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Om data is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(om_data != nullptr, lite::RET_ERROR, "Om data is nullptr.");
|
||||
SetAclModelOptions(func_graph);
|
||||
// call interface of cloud
|
||||
ModelConverter model_converter;
|
||||
|
@ -287,13 +286,9 @@ void AclPassImpl::SetAclModelBuildOptions(const std::shared_ptr<AscendDeviceInfo
|
|||
|
||||
std::shared_ptr<mindspore::Context> AclPassImpl::CreateModelContext() {
|
||||
auto model_context = std::make_shared<mindspore::Context>();
|
||||
if (model_context == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(model_context != nullptr, nullptr, "model_context is nullptr.");
|
||||
auto ascend_info = std::make_shared<AscendDeviceInfo>();
|
||||
if (ascend_info == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(ascend_info != nullptr, nullptr, "ascend_info is nullptr.");
|
||||
ascend_info->SetDeviceID(user_options_cfg_.device_id);
|
||||
SetAclModelInitOptions(ascend_info);
|
||||
SetAclModelBuildOptions(ascend_info);
|
||||
|
@ -304,6 +299,7 @@ std::shared_ptr<mindspore::Context> AclPassImpl::CreateModelContext() {
|
|||
|
||||
STATUS AclPassImpl::SetAclModelOptions(const FuncGraphPtr &func_graph) {
|
||||
MS_LOG(INFO) << "Set acl model options start.";
|
||||
MS_CHECK_TRUE_MSG(func_graph != nullptr, lite::RET_ERROR, "func_graph is nullptr.");
|
||||
auto model_context = CreateModelContext();
|
||||
CHECK_NULL_RETURN(model_context);
|
||||
options_ = std::make_shared<AclModelOptions>(model_context);
|
||||
|
@ -328,21 +324,23 @@ STATUS AclPassImpl::SetAclModelOptions(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
|
||||
ParameterPtr AclPassImpl::CreateOmParameter(const FuncGraphPtr &func_graph, const Buffer &om_data) {
|
||||
MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr.");
|
||||
ParameterPtr om_parameter = func_graph->add_parameter();
|
||||
MS_CHECK_TRUE_MSG(om_parameter != nullptr, nullptr, "om_parameter is nullptr.");
|
||||
om_parameter->set_name("ACL_om_data");
|
||||
|
||||
auto type_ptr = TypeIdToType(kNumberTypeUInt8);
|
||||
MS_CHECK_TRUE_MSG(type_ptr != nullptr, nullptr, "type_ptr is nullptr.");
|
||||
ShapeVector shape_vector = {static_cast<int64_t>(om_data.DataSize())};
|
||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
||||
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, nullptr, "abstract_tensor is nullptr.");
|
||||
om_parameter->set_abstract(abstract_tensor);
|
||||
|
||||
auto param_value =
|
||||
std::make_shared<tensor::Tensor>(kNumberTypeUInt8, ShapeVector({static_cast<int64_t>(om_data.DataSize())}));
|
||||
MS_CHECK_TRUE_MSG(param_value != nullptr, nullptr, "param_value is nullptr.");
|
||||
auto tensor_data = param_value->data_c();
|
||||
if (tensor_data == nullptr) {
|
||||
MS_LOG(ERROR) << "New Tensor failed.";
|
||||
return nullptr;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(tensor_data != nullptr, nullptr, "New Tensor failed.");
|
||||
if (param_value->Size() < om_data.DataSize()) {
|
||||
MS_LOG(ERROR) << "Dst buff size " << param_value->Size() << " should be greater than src buff size "
|
||||
<< om_data.DataSize();
|
||||
|
@ -393,10 +391,7 @@ STATUS AclPassImpl::BuildGraph(const FuncGraphPtr &func_graph) {
|
|||
return lite::RET_ERROR;
|
||||
}
|
||||
om_parameter_ = CreateOmParameter(func_graph, om_data);
|
||||
if (om_parameter_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert graph to om failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(om_parameter_ != nullptr, lite::RET_ERROR, "Convert graph to om failed.");
|
||||
if (!user_options_cfg_.insert_op_config_file_path.empty()) {
|
||||
if (CreateGraphAippInput(func_graph, om_data) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Create aipp input failed.";
|
||||
|
@ -416,6 +411,7 @@ STATUS AclPassImpl::TraceOutput(const AnfNodePtr &node) {
|
|||
CHECK_NULL_RETURN(tmp);
|
||||
cur_node = tmp->input(kTupleGetItemFirstInputIdx);
|
||||
}
|
||||
CHECK_NULL_RETURN(cur_node);
|
||||
auto cnode = cur_node->cast<CNodePtr>();
|
||||
CHECK_NULL_RETURN(cnode);
|
||||
std::string name = lite::acl::GetCNodeTargetFuncName(cnode);
|
||||
|
@ -453,6 +449,7 @@ STATUS AclPassImpl::TraceOutput(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
STATUS AclPassImpl::GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph) {
|
||||
MS_CHECK_TRUE_MSG(func_graph != nullptr, lite::RET_ERROR, "func_graph is nullptr.");
|
||||
AnfNodePtr return_input = func_graph->output();
|
||||
CHECK_NULL_RETURN(return_input);
|
||||
if (TraceOutput(return_input) != lite::RET_OK) {
|
||||
|
@ -468,10 +465,11 @@ STATUS AclPassImpl::GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph) {
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS AclPassImpl::SetMultiOutputs(const CNodePtr &new_cnode, TypeId data_type) {
|
||||
STATUS AclPassImpl::SetMultiOutputs(const CNodePtr &new_cnode, std::vector<TypeId> data_type) {
|
||||
MS_CHECK_TRUE_MSG(new_cnode != nullptr, lite::RET_ERROR, "new_cnode is nullptr.");
|
||||
AbstractBasePtrList abstract_list;
|
||||
for (size_t j = 0; j < graph_outputs_.size(); j++) {
|
||||
auto abstract_tensor = lite::CreateTensorAbstract(graph_output_dims_[j], data_type);
|
||||
auto abstract_tensor = lite::CreateTensorAbstract(graph_output_dims_[j], data_type[j]);
|
||||
if (abstract_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract tensor is nullptr for output " << j;
|
||||
return lite::RET_ERROR;
|
||||
|
@ -489,17 +487,20 @@ STATUS AclPassImpl::SetCustomOutputs(const FuncGraphPtr &func_graph, const CNode
|
|||
return lite::RET_ERROR;
|
||||
}
|
||||
custom_node->AddAttr(kOutputNames, MakeValue(graph_output_names_));
|
||||
TypeId type = lite::acl::GetTypeFromNode(graph_outputs_[0]);
|
||||
TypeId type;
|
||||
if (graph_outputs_.size() == 1) {
|
||||
type = lite::acl::GetTypeFromNode(graph_outputs_[0]);
|
||||
auto abstract_tensor = lite::CreateTensorAbstract(graph_output_dims_[0], type);
|
||||
if (abstract_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract_tensor is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, lite::RET_ERROR, "Abstract_tensor is nullptr.");
|
||||
custom_node->set_abstract(abstract_tensor);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (SetMultiOutputs(custom_node, type) != lite::RET_OK) {
|
||||
std::vector<TypeId> types;
|
||||
for (size_t i = 0; i < graph_outputs_.size(); i++) {
|
||||
type = lite::acl::GetTypeFromNode(graph_outputs_[i]);
|
||||
types.emplace_back(type);
|
||||
}
|
||||
if (SetMultiOutputs(custom_node, types) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Set multi graph output failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
@ -521,18 +522,13 @@ void AclPassImpl::SetCustomAttrs(const std::shared_ptr<ops::Custom> &prim) {
|
|||
}
|
||||
|
||||
CNodePtr AclPassImpl::CreateCustomNode(const FuncGraphPtr &func_graph) {
|
||||
MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr.");
|
||||
auto prim = std::make_shared<mindspore::ops::Custom>();
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "New custom op failed.";
|
||||
return nullptr;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "New custom op failed.");
|
||||
prim->set_type(kCustomPrimTypeACL);
|
||||
auto graph_input = func_graph->get_inputs();
|
||||
CNodePtr custom_node = func_graph->NewCNode(prim, graph_input);
|
||||
if (custom_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Custom cnode failed.";
|
||||
return nullptr;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(custom_node != nullptr, nullptr, "Custom cnode failed.");
|
||||
custom_node->set_fullname_with_scope(kCustomNodeName);
|
||||
custom_node->add_input(om_parameter_);
|
||||
|
||||
|
@ -574,7 +570,9 @@ STATUS AclPassImpl::ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, cons
|
|||
return lite::RET_ERROR;
|
||||
}
|
||||
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
|
||||
MS_CHECK_TRUE_MSG(tuple_get_item_prim != nullptr, lite::RET_ERROR, "item_prim is nullptr.");
|
||||
auto get_item_value = NewValueNode(MakeValue<int>(j));
|
||||
MS_CHECK_TRUE_MSG(get_item_value != nullptr, lite::RET_ERROR, "item_value is nullptr.");
|
||||
AnfNodePtrList inputs{tuple_get_item_prim, custom_node, get_item_value};
|
||||
CNodePtr get_item_cnode = func_graph->NewCNode(inputs);
|
||||
if (get_item_cnode == nullptr) {
|
||||
|
@ -600,15 +598,9 @@ STATUS AclPassImpl::ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, cons
|
|||
|
||||
bool AclPassImpl::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_LOG(INFO) << "Acl pass run start.";
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Func_graph is nullptr.";
|
||||
return false;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "Func_graph is nullptr.");
|
||||
auto manager = Manage(func_graph, true);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(ERROR) << "Manager is nullptr.";
|
||||
return false;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(manager != nullptr, false, "Manager is nullptr.");
|
||||
|
||||
if (PreProcGraph(func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Pre proc graph failed.";
|
||||
|
@ -626,10 +618,7 @@ bool AclPassImpl::Run(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
|
||||
custom_node_ = CreateCustomNode(func_graph);
|
||||
if (custom_node_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Create custom node failed.";
|
||||
return false;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(custom_node_ != nullptr, false, "Create custom node failed.");
|
||||
// prepare graph for export create
|
||||
if (ModifyGraphByCustomNode(func_graph, manager, custom_node_) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Modify func graph by custom failed.";
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -67,7 +67,7 @@ class AclPassImpl {
|
|||
CNodePtr CreateCustomNode(const FuncGraphPtr &func_graph);
|
||||
void SetCustomAttrs(const std::shared_ptr<ops::Custom> &prim);
|
||||
STATUS SetCustomOutputs(const FuncGraphPtr &func_graph, const CNodePtr &custom_node);
|
||||
STATUS SetMultiOutputs(const CNodePtr &new_cnode, TypeId data_type);
|
||||
STATUS SetMultiOutputs(const CNodePtr &new_cnode, std::vector<TypeId> data_type);
|
||||
STATUS GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph);
|
||||
STATUS TraceOutput(const AnfNodePtr &node);
|
||||
|
||||
|
|
|
@ -131,7 +131,7 @@ bool InputAdjust::Run(const FuncGraphPtr &func_graph) {
|
|||
} else if (opt::CheckPrimitiveType(node, prim::kPrimReduceFusion)) {
|
||||
MS_LOG(INFO) << "Adjust ReduceFusion";
|
||||
status = AddAttrToInput(func_graph, cnode, opt::kInputIndexTwo, "axes", kBuildInputFlagTwo);
|
||||
} else if (opt::CheckPrimitiveType(node, prim::kPrimPadFusion)) {
|
||||
} else if (opt::CheckPrimitiveType(node, prim::kPrimPad) || opt::CheckPrimitiveType(node, prim::kPrimPadFusion)) {
|
||||
MS_LOG(INFO) << "Adjust PadFusion";
|
||||
status = AddAttrToInput(func_graph, cnode, opt::kInputIndexTwo, "paddings", kBuildInputFlagThree);
|
||||
} else if (opt::CheckPrimitiveType(node, prim::kPrimPowFusion)) {
|
||||
|
|
|
@ -319,7 +319,7 @@ int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const
|
|||
} else {
|
||||
auto pad_prim = utils::cast<std::shared_ptr<mindspore::ops::PadFusion>>(primitive);
|
||||
MS_ASSERT(pad_prim != nullptr);
|
||||
MS_CHECK_TRUE_RET(pad_prim->GetAttr(ops::kPadding) != nullptr, lite::RET_ERROR);
|
||||
MS_CHECK_TRUE_RET(pad_prim->GetAttr(ops::kPaddings) != nullptr, lite::RET_ERROR);
|
||||
auto pad_data = pad_prim->get_paddings();
|
||||
for (size_t i = 0; i < pad_data.size(); i++) {
|
||||
for (size_t j = 0; j < pad_data[i].size(); j++) {
|
||||
|
|
Loading…
Reference in New Issue