!56531 增加lite onnx parser 2(new)

Merge pull request !56531 from zhangdanyang/0710_master
This commit is contained in:
i-robot 2023-07-13 06:09:19 +00:00 committed by Gitee
commit a4e0b473a9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
18 changed files with 436 additions and 30 deletions

View File

@ -6,6 +6,7 @@
mindspore/mindspore/core/mindrt/src/thread/actor_threadpool.cc:mindspore::ActorWorker::RunWithSpin
mindspore/mindspore/lite/src/common/ops/primitive_c.cc:mindspore::lite::PrimitiveC::Create
mindspore/mindspore/lite/src/extendrt/convert/runtime_convert.cc:RuntimeConvert
mindspore/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc:mindspore::lite::OnnxInputAdjust::Adjust
mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc:mindspore::dataset::CsvOp::CsvParser::InitCsvParser
mindspore/mindspore/lite/tools/converter/graphdef_transform.cc:mindspore::lite::GraphDefTransform::Transform
mindspore/mindspore/lite/providers/nnie_proposal/src/proposal.cc:mindspore::proposal::Rpn

View File

@ -452,6 +452,7 @@ constexpr const char kNameReduceLogSumExp[] = "ReduceLogSumExp";
constexpr const char kNameReduceLogSum[] = "ReduceLogSum";
constexpr const char kNameSize[] = "Size";
constexpr const char kNameTfIdfVectorizer[] = "TfIdfVectorizer";
constexpr const char kNameMVNV2[] = "MVNV2";
class OpAdapterDesc;

View File

@ -249,4 +249,11 @@ OUTPUT_MAP(SoftmaxGradExt) = {{0, OUTPUT_DESC(y)}};
ATTR_MAP(SoftmaxGradExt) = {{"axis", ATTR_DESC(axes, AnyTraits<int64_t>(), AnyTraits<int64_t>())},
{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>(), AnyTraits<bool>())}};
REG_ADPT_DESC(SoftmaxGradExt, kSoftmaxGradExtOpName, ADPT_DESC(SoftmaxGradExt))
// MVNV2
INPUT_MAP(MVNV2) = {{1, INPUT_DESC(x)}};
ATTR_MAP(MVNV2) = {{"eps", ATTR_DESC(eps, AnyTraits<float>())},
{"axes", ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(MVNV2) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(MVNV2, kNameMVNV2, ADPT_DESC(MVNV2))
} // namespace mindspore::transform

View File

@ -115,4 +115,7 @@ DECLARE_OP_USE_OUTPUT(SoftmaxGradExt)
DECLARE_OP_ADAPTER(ConfusionSoftmaxGrad)
DECLARE_OP_USE_OUTPUT(ConfusionSoftmaxGrad)
DECLARE_OP_ADAPTER(MVNV2)
DECLARE_OP_USE_OUTPUT(MVNV2)
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_IMAGE_OPS_DECLARE_H_

View File

@ -116,15 +116,24 @@ TypePtr MultinomialInferType(const PrimitivePtr &prim, const std::vector<Abstrac
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto x_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(x_type);
auto num_samples_type = input_args[1]->BuildType();
MS_EXCEPTION_IF_NULL(num_samples_type);
const std::set valid_types_1 = {kFloat16, kFloat32, kFloat64, kInt8, kInt16, kInt32,
kInt64, kUInt8, kUInt16, kUInt32, kUInt64};
const std::set valid_types_2 = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types_1, prim_name);
(void)CheckAndConvertUtils::CheckTypeValid("num_samples", num_samples_type, valid_types_2, prim_name);
auto dtype = GetValue<TypePtr>(prim->GetAttr("dtype"));
const std::set valid_types_3 = {kInt32, kInt64};
auto out_type = CheckAndConvertUtils::CheckTypeValid("dtype", dtype->cast<TypePtr>(), valid_types_3, prim->name());
auto dtype_attr = prim->GetAttr("dtype");
MS_EXCEPTION_IF_NULL(dtype_attr);
if (!dtype_attr->isa<Type>()) {
TypeId type_id = static_cast<TypeId>(GetValue<int64_t>(dtype_attr));
auto out_type = CheckAndConvertUtils::CheckTypeValid("dtype", TypeIdToType(type_id), valid_types_3, prim->name());
return out_type;
}
auto out_type =
CheckAndConvertUtils::CheckTypeValid("dtype", dtype_attr->cast<TypePtr>(), valid_types_3, prim->name());
return out_type;
}
} // namespace

View File

@ -0,0 +1,65 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/adapter/acl/mapper/multinomial_mapper.h"
#include <memory>
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "src/common/log_util.h"
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
#include "ops/op_utils.h"
#include "tools/converter/adapter/acl/common/utils.h"
namespace mindspore {
namespace lite {
STATUS MultinomialMapper::Mapper(const CNodePtr &cnode) {
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get primitive from cnode failed.";
return lite::RET_ERROR;
}
ops::Multinomial multinomial_op;
auto dst_prim = multinomial_op.GetPrim();
CHECK_NULL_RETURN(dst_prim);
dst_prim->SetAttrs(src_prim->attrs());
auto dst_type = src_prim->GetAttr(ops::kOutputDType);
if (dst_type != nullptr) {
if (!dst_type->isa<Type>()) {
auto type_id = static_cast<TypeId>(GetValue<int64_t>(dst_type));
dst_prim->AddAttr("dtype", TypeIdToType(type_id));
} else {
dst_prim->AddAttr("dtype", TypeIdToType(acl::GetTypeFromNode(cnode)));
}
}
if (src_prim->HasAttr(ops::kSeed)) {
auto seed_attr = src_prim->GetAttr(ops::kSeed);
dst_prim->AddAttr(ops::kSeed, seed_attr);
dst_prim->AddAttr(ops::kSeed2, seed_attr);
}
if (src_prim->HasAttr(ops::kSeed2)) {
auto seed2_attr = src_prim->GetAttr(ops::kSeed2);
dst_prim->AddAttr(ops::kSeed2, seed2_attr);
}
auto func_graph = cnode->func_graph();
CHECK_NULL_RETURN(func_graph);
value_node->set_value(dst_prim);
return lite::RET_OK;
}
REGISTER_PRIMITIVE_MAPPER(kNameMultinomial, MultinomialMapper)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MULTINOMIAL_MAPPER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MULTINOMIAL_MAPPER_H_
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
#include "ops/multinomial.h"
namespace mindspore {
namespace lite {
using mindspore::ops::kNameMultinomial;
class MultinomialMapper : public PrimitiveMapper {
public:
MultinomialMapper() : PrimitiveMapper(kNameMultinomial) {}
~MultinomialMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MULTINOMIAL_MAPPER_H_

View File

@ -0,0 +1,54 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/adapter/acl/mapper/mvn_mapper.h"
#include <memory>
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "src/common/log_util.h"
#include "mindspore/core/ops/op_name.h"
namespace mindspore {
namespace lite {
const auto kNameMVN = "MVN";
STATUS MVNMapper::Mapper(const CNodePtr &cnode) {
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get primitive from cnode failed.";
return lite::RET_ERROR;
}
auto dst_prim = std::make_shared<acl::MVNV2>();
CHECK_NULL_RETURN(dst_prim);
dst_prim->SetAttrs(src_prim->attrs());
auto axes_ptr = src_prim->GetAttr(ops::kAxes);
if (axes_ptr != nullptr) {
dst_prim->AddAttr(ops::kAxes, MakeValue(axes_ptr));
}
int64_t node_format = Format::NCHW;
dst_prim->AddAttr(ops::kFormat, MakeValue(node_format));
value_node->set_value(dst_prim);
auto func_graph = cnode->func_graph();
CHECK_NULL_RETURN(func_graph);
return RET_OK;
}
REGISTER_PRIMITIVE_MAPPER(kNameMVN, MVNMapper)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,34 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MVN_MAPPER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MVN_MAPPER_H_
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
namespace mindspore {
namespace lite {
class MVNMapper : public PrimitiveMapper {
public:
MVNMapper() : PrimitiveMapper(acl::kNameMVN) {}
~MVNMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MVN_MAPPER_H_

View File

@ -73,6 +73,8 @@ ADD_CONVERTER_TBE_OP(Shrink)
ADD_CONVERTER_TBE_OP(ReduceLogSumExp)
ADD_CONVERTER_TBE_OP(ReduceLogSum)
ADD_CONVERTER_TBE_OP(SplitV)
ADD_CONVERTER_TBE_OP(MVN)
ADD_CONVERTER_TBE_OP(MVNV2)
} // namespace acl
} // namespace lite
} // namespace mindspore

View File

@ -53,6 +53,7 @@ ADD_CONVERTER_ONLY_OP(MegatronMakeViewlessTensor);
ADD_CONVERTER_ONLY_OP(MegatronScaledMaskedSoftmax);
ADD_CONVERTER_ONLY_OP(Shrink);
ADD_CONVERTER_ONLY_OP(TfIdfVectorizer);
ADD_CONVERTER_ONLY_OP(MVN);
} // namespace lite
} // namespace mindspore

View File

@ -29,6 +29,7 @@
#include "ops/concat.h"
#include "ops/reshape.h"
#include "ops/cast.h"
#include "ops/multinomial.h"
#include "include/errorcode.h"
#include "nnacl/op_base.h"
#include "tools/common/tensor_util.h"
@ -38,31 +39,7 @@
namespace mindspore::lite {
namespace {
const std::vector<int> kNH2NCPerm = {0, 3, 1, 2};
const int kInputNum2 = 2;
const int kInputNum3 = 3;
const int kInputNum4 = 4;
CNodePtr NewReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &shape) {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(input_node != nullptr);
auto reshape_prim = std::make_shared<ops::Reshape>();
if (reshape_prim == nullptr) {
MS_LOG(ERROR) << "create reshape failed.";
return nullptr;
}
auto prim_c = reshape_prim->GetPrim();
prim_c->set_attr("shape", MakeValue(shape));
ValueNodePtr value_node = NewValueNode(prim_c);
MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "Create valuenode return nullptr");
auto new_parameter = opt::BuildIntVecParameterNode(func_graph, shape, input_node->fullname_with_scope() + "_reshape");
MS_CHECK_TRUE_MSG(new_parameter != nullptr, nullptr, "Create parameter return nullptr");
new_parameter->set_name(input_node->fullname_with_scope() + "_reshape");
std::vector<AnfNodePtr> op_inputs = {value_node, input_node, new_parameter};
auto reshape = func_graph->NewCNode(op_inputs);
MS_CHECK_TRUE_MSG(reshape != nullptr, nullptr, "Create cnode return nullptr");
reshape->set_fullname_with_scope(input_node->fullname_with_scope() + "_reshape");
return reshape;
}
constexpr int kInputNum4 = 4;
STATUS AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int input_num,
const std::string &attr_name) {
@ -494,8 +471,8 @@ STATUS AdjustROIAlign(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_LOG(INFO) << "RoiAlign input size is not 3, does not need to adjust.";
return RET_OK;
}
auto rois = cnode->inputs()[kInputNum2];
auto batch_indices = cnode->inputs()[kInputNum3];
auto rois = cnode->inputs()[THIRD_INPUT];
auto batch_indices = cnode->inputs()[FOURTH_INPUT];
auto abstract = batch_indices->abstract();
auto cast_node =
opt::GenCastNode(func_graph, batch_indices, cnode->fullname_with_scope() + "_Cast", kNumberTypeFloat32, abstract);
@ -517,7 +494,7 @@ STATUS AdjustROIAlign(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
}
std::vector<int> shape = {batch_shape_num, 1};
auto new_reshape_node = NewReshapeNode(func_graph, cast_node, shape);
auto new_reshape_node = opt::GenReshapeNode(func_graph, cast_node, shape, cnode->fullname_with_scope() + "_Reshape");
if (new_reshape_node == nullptr) {
MS_LOG(ERROR) << "Create reshape node failed.";
return RET_ERROR;
@ -545,6 +522,32 @@ STATUS AdjustROIAlign(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
opt::UpdateManager(func_graph);
return RET_OK;
}
STATUS AdjustMultinomial(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool *need_update_manager) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
auto multinomial_node = ops::GetOperator<ops::Multinomial>(cnode->input(0));
MS_CHECK_TRUE_RET(multinomial_node != nullptr, RET_ERROR);
auto prim = multinomial_node->GetPrim();
MS_CHECK_TRUE_RET(prim != nullptr, RET_ERROR);
MS_CHECK_TRUE_RET(prim->GetAttr("sample_size") != nullptr, RET_ERROR);
int64_t sample_size = GetValue<int64_t>(prim->GetAttr("sample_size"));
auto num_samples_val = static_cast<int32_t>(sample_size);
auto sample_parameter_ptr =
mindspore::opt::BuildIntValueParameterNode(func_graph, num_samples_val, "num_samples", true);
MS_CHECK_TRUE_RET(sample_parameter_ptr != nullptr, RET_ERROR);
std::vector<AnfNodePtr> new_inputs;
new_inputs.push_back(cnode->inputs()[FIRST_INPUT]);
new_inputs.push_back(cnode->inputs()[SECOND_INPUT]);
new_inputs.push_back(static_cast<AnfNodePtr>(sample_parameter_ptr));
cnode->set_inputs(new_inputs);
*need_update_manager = true;
opt::UpdateManager(func_graph);
return RET_OK;
}
} // namespace
bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph, const converter::ConverterParameters &flag) {
@ -592,6 +595,8 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph, const converter::Co
status = AdjustUnsqueeze(&need_update_manager, cnode);
} else if (opt::CheckPrimitiveType(node, prim::kPrimROIAlign)) {
status = AdjustROIAlign(func_graph, cnode);
} else if (opt::CheckPrimitiveType(node, prim::kPrimMultinomial)) {
status = AdjustMultinomial(func_graph, cnode, &need_update_manager);
} else {
continue;
}

View File

@ -0,0 +1,48 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_multinomial_parser.h"
#include <memory>
#include "ops/multinomial.h"
#include "mindapi/ir/type.h"
namespace mindspore {
namespace lite {
PrimitiveCPtr OnnxMultinomialParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto prim = std::make_unique<ops::Multinomial>();
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
auto prim_c = prim->GetPrim();
MS_CHECK_TRUE_RET(prim_c != nullptr, nullptr);
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "seed") {
prim->set_seed(onnx_node_attr.f());
} else if (attribute_name == "sample_size") {
int64_t sample_size = static_cast<int64_t>(onnx_node_attr.i());
(void)prim_c->AddAttr("sample_size", MakeValue<int64_t>(sample_size));
} else if (attribute_name == "dtype") {
auto onnx_dtype = static_cast<onnx::TensorProto_DataType>(onnx_node_attr.i());
auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(onnx_dtype);
(void)prim_c->AddAttr("dtype", MakeValue<int64_t>(static_cast<int64_t>(data_type)));
}
}
return prim->GetPrim();
}
OnnxNodeRegistrar g_onnxMultinomialParser("Multinomial", new OnnxMultinomialParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,34 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_MULTINOMIAL_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_MULTINOMIAL_PARSER_H_
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxMultinomialParser : public OnnxNodeParser {
public:
OnnxMultinomialParser() : OnnxNodeParser("Multinomial") {}
~OnnxMultinomialParser() override = default;
PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_MULTINOMIAL_PARSER_H_

View File

@ -0,0 +1,45 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_mvn_parser.h"
#include <memory>
#include <vector>
#include "tools/converter/ops/ops_def.h"
#include "ir/value.h"
namespace mindspore {
namespace lite {
PrimitiveCPtr OnnxMVNParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto prim = std::make_unique<lite::MVN>();
if (prim == nullptr) {
MS_LOG(ERROR) << "new MVN prim failed.";
return nullptr;
}
(void)prim->AddAttr(mindspore::ops::kOriginalFormat, MakeValue<int64_t>(mindspore::Format::NCHW));
std::vector<int64_t> axes = {};
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "axes") {
const int &size = onnx_node_attr.ints_size();
for (int i = 0; i < size; ++i) {
axes.push_back(onnx_node_attr.ints(i));
}
(void)prim->AddAttr("axes", MakeValue(axes));
}
}
return prim;
}
OnnxNodeRegistrar g_onnxMeanVarianceNormalizationParser("MeanVarianceNormalization", new OnnxMVNParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,34 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_MVN_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_MVN_PARSER_H_
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxMVNParser : public OnnxNodeParser {
public:
OnnxMVNParser() : OnnxNodeParser("MVN") {}
~OnnxMVNParser() override = default;
PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_MVN_PARSER_H_

View File

@ -33,6 +33,7 @@
#include "ops/cast.h"
#include "ops/gather.h"
#include "ops/concat.h"
#include "ops/reshape.h"
#include "ops/tuple_get_item.h"
#include "tools/common/tensor_util.h"
#include "frontend/operator/ops.h"
@ -1058,6 +1059,28 @@ CNodePtr GenCastNode(const FuncGraphPtr &graph, const AnfNodePtr &input_node, co
return cast_cnode;
}
CNodePtr GenReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &shape,
const std::string &cnode_name) {
MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
MS_CHECK_TRUE_RET(input_node != nullptr, nullptr);
auto reshape_prim = std::make_shared<ops::Reshape>();
if (reshape_prim == nullptr) {
MS_LOG(ERROR) << "create reshape failed.";
return nullptr;
}
auto prim_c = reshape_prim->GetPrim();
prim_c->set_attr("shape", MakeValue(shape));
ValueNodePtr value_node = NewValueNode(prim_c);
MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "Create value_node return nullptr");
auto new_shape_node = opt::BuildIntVecParameterNode(func_graph, shape, cnode_name + "_shape");
MS_CHECK_TRUE_MSG(new_shape_node != nullptr, nullptr, "Create shape parameter return nullptr");
std::vector<AnfNodePtr> op_inputs = {value_node, input_node, new_shape_node};
auto reshape_cnode = func_graph->NewCNode(op_inputs);
MS_CHECK_TRUE_MSG(reshape_cnode != nullptr, nullptr, "Create cnode return nullptr");
reshape_cnode->set_fullname_with_scope(cnode_name);
return reshape_cnode;
}
CNodePtr GenGatherNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &indices,
const std::string &cnode_name, const std::vector<int> &axis) {
if (func_graph == nullptr || input_node == nullptr) {

View File

@ -133,6 +133,9 @@ CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu
CNodePtr GenCastNode(const FuncGraphPtr &graph, const AnfNodePtr &input_node, const std::string &cnode_name,
const TypeId dst_type, const AbstractBasePtr &abstract);
CNodePtr GenReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &shape,
const std::string &cnode_name);
CNodePtr GenGatherNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &indices,
const std::string &cnode_name, const std::vector<int> &axis = {0});