!56531 增加lite onnx parser 2(new)
Merge pull request !56531 from zhangdanyang/0710_master
This commit is contained in:
commit
a4e0b473a9
|
@ -6,6 +6,7 @@
|
|||
mindspore/mindspore/core/mindrt/src/thread/actor_threadpool.cc:mindspore::ActorWorker::RunWithSpin
|
||||
mindspore/mindspore/lite/src/common/ops/primitive_c.cc:mindspore::lite::PrimitiveC::Create
|
||||
mindspore/mindspore/lite/src/extendrt/convert/runtime_convert.cc:RuntimeConvert
|
||||
mindspore/mindspore/lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc:mindspore::lite::OnnxInputAdjust::Adjust
|
||||
mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc:mindspore::dataset::CsvOp::CsvParser::InitCsvParser
|
||||
mindspore/mindspore/lite/tools/converter/graphdef_transform.cc:mindspore::lite::GraphDefTransform::Transform
|
||||
mindspore/mindspore/lite/providers/nnie_proposal/src/proposal.cc:mindspore::proposal::Rpn
|
||||
|
|
|
@ -452,6 +452,7 @@ constexpr const char kNameReduceLogSumExp[] = "ReduceLogSumExp";
|
|||
constexpr const char kNameReduceLogSum[] = "ReduceLogSum";
|
||||
constexpr const char kNameSize[] = "Size";
|
||||
constexpr const char kNameTfIdfVectorizer[] = "TfIdfVectorizer";
|
||||
constexpr const char kNameMVNV2[] = "MVNV2";
|
||||
|
||||
class OpAdapterDesc;
|
||||
|
||||
|
|
|
@ -249,4 +249,11 @@ OUTPUT_MAP(SoftmaxGradExt) = {{0, OUTPUT_DESC(y)}};
|
|||
ATTR_MAP(SoftmaxGradExt) = {{"axis", ATTR_DESC(axes, AnyTraits<int64_t>(), AnyTraits<int64_t>())},
|
||||
{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>(), AnyTraits<bool>())}};
|
||||
REG_ADPT_DESC(SoftmaxGradExt, kSoftmaxGradExtOpName, ADPT_DESC(SoftmaxGradExt))
|
||||
|
||||
// MVNV2
|
||||
INPUT_MAP(MVNV2) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(MVNV2) = {{"eps", ATTR_DESC(eps, AnyTraits<float>())},
|
||||
{"axes", ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
|
||||
OUTPUT_MAP(MVNV2) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(MVNV2, kNameMVNV2, ADPT_DESC(MVNV2))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -115,4 +115,7 @@ DECLARE_OP_USE_OUTPUT(SoftmaxGradExt)
|
|||
|
||||
DECLARE_OP_ADAPTER(ConfusionSoftmaxGrad)
|
||||
DECLARE_OP_USE_OUTPUT(ConfusionSoftmaxGrad)
|
||||
|
||||
DECLARE_OP_ADAPTER(MVNV2)
|
||||
DECLARE_OP_USE_OUTPUT(MVNV2)
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_IMAGE_OPS_DECLARE_H_
|
||||
|
|
|
@ -116,15 +116,24 @@ TypePtr MultinomialInferType(const PrimitivePtr &prim, const std::vector<Abstrac
|
|||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_type);
|
||||
auto num_samples_type = input_args[1]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(num_samples_type);
|
||||
const std::set valid_types_1 = {kFloat16, kFloat32, kFloat64, kInt8, kInt16, kInt32,
|
||||
kInt64, kUInt8, kUInt16, kUInt32, kUInt64};
|
||||
const std::set valid_types_2 = {kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types_1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("num_samples", num_samples_type, valid_types_2, prim_name);
|
||||
auto dtype = GetValue<TypePtr>(prim->GetAttr("dtype"));
|
||||
const std::set valid_types_3 = {kInt32, kInt64};
|
||||
auto out_type = CheckAndConvertUtils::CheckTypeValid("dtype", dtype->cast<TypePtr>(), valid_types_3, prim->name());
|
||||
auto dtype_attr = prim->GetAttr("dtype");
|
||||
MS_EXCEPTION_IF_NULL(dtype_attr);
|
||||
if (!dtype_attr->isa<Type>()) {
|
||||
TypeId type_id = static_cast<TypeId>(GetValue<int64_t>(dtype_attr));
|
||||
auto out_type = CheckAndConvertUtils::CheckTypeValid("dtype", TypeIdToType(type_id), valid_types_3, prim->name());
|
||||
return out_type;
|
||||
}
|
||||
auto out_type =
|
||||
CheckAndConvertUtils::CheckTypeValid("dtype", dtype_attr->cast<TypePtr>(), valid_types_3, prim->name());
|
||||
return out_type;
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/adapter/acl/mapper/multinomial_mapper.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
||||
#include "src/common/log_util.h"
|
||||
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "tools/converter/adapter/acl/common/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS MultinomialMapper::Mapper(const CNodePtr &cnode) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ops::Multinomial multinomial_op;
|
||||
auto dst_prim = multinomial_op.GetPrim();
|
||||
CHECK_NULL_RETURN(dst_prim);
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
|
||||
auto dst_type = src_prim->GetAttr(ops::kOutputDType);
|
||||
if (dst_type != nullptr) {
|
||||
if (!dst_type->isa<Type>()) {
|
||||
auto type_id = static_cast<TypeId>(GetValue<int64_t>(dst_type));
|
||||
dst_prim->AddAttr("dtype", TypeIdToType(type_id));
|
||||
} else {
|
||||
dst_prim->AddAttr("dtype", TypeIdToType(acl::GetTypeFromNode(cnode)));
|
||||
}
|
||||
}
|
||||
if (src_prim->HasAttr(ops::kSeed)) {
|
||||
auto seed_attr = src_prim->GetAttr(ops::kSeed);
|
||||
dst_prim->AddAttr(ops::kSeed, seed_attr);
|
||||
dst_prim->AddAttr(ops::kSeed2, seed_attr);
|
||||
}
|
||||
if (src_prim->HasAttr(ops::kSeed2)) {
|
||||
auto seed2_attr = src_prim->GetAttr(ops::kSeed2);
|
||||
dst_prim->AddAttr(ops::kSeed2, seed2_attr);
|
||||
}
|
||||
auto func_graph = cnode->func_graph();
|
||||
CHECK_NULL_RETURN(func_graph);
|
||||
value_node->set_value(dst_prim);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_MAPPER(kNameMultinomial, MultinomialMapper)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MULTINOMIAL_MAPPER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MULTINOMIAL_MAPPER_H_
|
||||
|
||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
|
||||
#include "ops/multinomial.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
using mindspore::ops::kNameMultinomial;
|
||||
|
||||
class MultinomialMapper : public PrimitiveMapper {
|
||||
public:
|
||||
MultinomialMapper() : PrimitiveMapper(kNameMultinomial) {}
|
||||
|
||||
~MultinomialMapper() override = default;
|
||||
|
||||
STATUS Mapper(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MULTINOMIAL_MAPPER_H_
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/adapter/acl/mapper/mvn_mapper.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
||||
#include "src/common/log_util.h"
|
||||
#include "mindspore/core/ops/op_name.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
const auto kNameMVN = "MVN";
|
||||
|
||||
STATUS MVNMapper::Mapper(const CNodePtr &cnode) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto dst_prim = std::make_shared<acl::MVNV2>();
|
||||
CHECK_NULL_RETURN(dst_prim);
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
auto axes_ptr = src_prim->GetAttr(ops::kAxes);
|
||||
if (axes_ptr != nullptr) {
|
||||
dst_prim->AddAttr(ops::kAxes, MakeValue(axes_ptr));
|
||||
}
|
||||
|
||||
int64_t node_format = Format::NCHW;
|
||||
dst_prim->AddAttr(ops::kFormat, MakeValue(node_format));
|
||||
|
||||
value_node->set_value(dst_prim);
|
||||
auto func_graph = cnode->func_graph();
|
||||
CHECK_NULL_RETURN(func_graph);
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_MAPPER(kNameMVN, MVNMapper)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MVN_MAPPER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MVN_MAPPER_H_
|
||||
|
||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
|
||||
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class MVNMapper : public PrimitiveMapper {
|
||||
public:
|
||||
MVNMapper() : PrimitiveMapper(acl::kNameMVN) {}
|
||||
~MVNMapper() override = default;
|
||||
|
||||
STATUS Mapper(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MVN_MAPPER_H_
|
|
@ -73,6 +73,8 @@ ADD_CONVERTER_TBE_OP(Shrink)
|
|||
ADD_CONVERTER_TBE_OP(ReduceLogSumExp)
|
||||
ADD_CONVERTER_TBE_OP(ReduceLogSum)
|
||||
ADD_CONVERTER_TBE_OP(SplitV)
|
||||
ADD_CONVERTER_TBE_OP(MVN)
|
||||
ADD_CONVERTER_TBE_OP(MVNV2)
|
||||
} // namespace acl
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -53,6 +53,7 @@ ADD_CONVERTER_ONLY_OP(MegatronMakeViewlessTensor);
|
|||
ADD_CONVERTER_ONLY_OP(MegatronScaledMaskedSoftmax);
|
||||
ADD_CONVERTER_ONLY_OP(Shrink);
|
||||
ADD_CONVERTER_ONLY_OP(TfIdfVectorizer);
|
||||
ADD_CONVERTER_ONLY_OP(MVN);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "ops/concat.h"
|
||||
#include "ops/reshape.h"
|
||||
#include "ops/cast.h"
|
||||
#include "ops/multinomial.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
|
@ -38,31 +39,7 @@
|
|||
namespace mindspore::lite {
|
||||
namespace {
|
||||
const std::vector<int> kNH2NCPerm = {0, 3, 1, 2};
|
||||
const int kInputNum2 = 2;
|
||||
const int kInputNum3 = 3;
|
||||
const int kInputNum4 = 4;
|
||||
|
||||
CNodePtr NewReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &shape) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(input_node != nullptr);
|
||||
auto reshape_prim = std::make_shared<ops::Reshape>();
|
||||
if (reshape_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "create reshape failed.";
|
||||
return nullptr;
|
||||
}
|
||||
auto prim_c = reshape_prim->GetPrim();
|
||||
prim_c->set_attr("shape", MakeValue(shape));
|
||||
ValueNodePtr value_node = NewValueNode(prim_c);
|
||||
MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "Create valuenode return nullptr");
|
||||
auto new_parameter = opt::BuildIntVecParameterNode(func_graph, shape, input_node->fullname_with_scope() + "_reshape");
|
||||
MS_CHECK_TRUE_MSG(new_parameter != nullptr, nullptr, "Create parameter return nullptr");
|
||||
new_parameter->set_name(input_node->fullname_with_scope() + "_reshape");
|
||||
std::vector<AnfNodePtr> op_inputs = {value_node, input_node, new_parameter};
|
||||
auto reshape = func_graph->NewCNode(op_inputs);
|
||||
MS_CHECK_TRUE_MSG(reshape != nullptr, nullptr, "Create cnode return nullptr");
|
||||
reshape->set_fullname_with_scope(input_node->fullname_with_scope() + "_reshape");
|
||||
return reshape;
|
||||
}
|
||||
constexpr int kInputNum4 = 4;
|
||||
|
||||
STATUS AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int input_num,
|
||||
const std::string &attr_name) {
|
||||
|
@ -494,8 +471,8 @@ STATUS AdjustROIAlign(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
|||
MS_LOG(INFO) << "RoiAlign input size is not 3, does not need to adjust.";
|
||||
return RET_OK;
|
||||
}
|
||||
auto rois = cnode->inputs()[kInputNum2];
|
||||
auto batch_indices = cnode->inputs()[kInputNum3];
|
||||
auto rois = cnode->inputs()[THIRD_INPUT];
|
||||
auto batch_indices = cnode->inputs()[FOURTH_INPUT];
|
||||
auto abstract = batch_indices->abstract();
|
||||
auto cast_node =
|
||||
opt::GenCastNode(func_graph, batch_indices, cnode->fullname_with_scope() + "_Cast", kNumberTypeFloat32, abstract);
|
||||
|
@ -517,7 +494,7 @@ STATUS AdjustROIAlign(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
|||
}
|
||||
|
||||
std::vector<int> shape = {batch_shape_num, 1};
|
||||
auto new_reshape_node = NewReshapeNode(func_graph, cast_node, shape);
|
||||
auto new_reshape_node = opt::GenReshapeNode(func_graph, cast_node, shape, cnode->fullname_with_scope() + "_Reshape");
|
||||
if (new_reshape_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Create reshape node failed.";
|
||||
return RET_ERROR;
|
||||
|
@ -545,6 +522,32 @@ STATUS AdjustROIAlign(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
|||
opt::UpdateManager(func_graph);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AdjustMultinomial(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool *need_update_manager) {
|
||||
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
|
||||
auto multinomial_node = ops::GetOperator<ops::Multinomial>(cnode->input(0));
|
||||
MS_CHECK_TRUE_RET(multinomial_node != nullptr, RET_ERROR);
|
||||
|
||||
auto prim = multinomial_node->GetPrim();
|
||||
MS_CHECK_TRUE_RET(prim != nullptr, RET_ERROR);
|
||||
|
||||
MS_CHECK_TRUE_RET(prim->GetAttr("sample_size") != nullptr, RET_ERROR);
|
||||
int64_t sample_size = GetValue<int64_t>(prim->GetAttr("sample_size"));
|
||||
auto num_samples_val = static_cast<int32_t>(sample_size);
|
||||
|
||||
auto sample_parameter_ptr =
|
||||
mindspore::opt::BuildIntValueParameterNode(func_graph, num_samples_val, "num_samples", true);
|
||||
MS_CHECK_TRUE_RET(sample_parameter_ptr != nullptr, RET_ERROR);
|
||||
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
new_inputs.push_back(cnode->inputs()[FIRST_INPUT]);
|
||||
new_inputs.push_back(cnode->inputs()[SECOND_INPUT]);
|
||||
new_inputs.push_back(static_cast<AnfNodePtr>(sample_parameter_ptr));
|
||||
cnode->set_inputs(new_inputs);
|
||||
*need_update_manager = true;
|
||||
opt::UpdateManager(func_graph);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph, const converter::ConverterParameters &flag) {
|
||||
|
@ -592,6 +595,8 @@ bool OnnxInputAdjust::Adjust(const FuncGraphPtr &func_graph, const converter::Co
|
|||
status = AdjustUnsqueeze(&need_update_manager, cnode);
|
||||
} else if (opt::CheckPrimitiveType(node, prim::kPrimROIAlign)) {
|
||||
status = AdjustROIAlign(func_graph, cnode);
|
||||
} else if (opt::CheckPrimitiveType(node, prim::kPrimMultinomial)) {
|
||||
status = AdjustMultinomial(func_graph, cnode, &need_update_manager);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_multinomial_parser.h"
|
||||
#include <memory>
|
||||
#include "ops/multinomial.h"
|
||||
#include "mindapi/ir/type.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
PrimitiveCPtr OnnxMultinomialParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
|
||||
auto prim = std::make_unique<ops::Multinomial>();
|
||||
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
|
||||
auto prim_c = prim->GetPrim();
|
||||
MS_CHECK_TRUE_RET(prim_c != nullptr, nullptr);
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "seed") {
|
||||
prim->set_seed(onnx_node_attr.f());
|
||||
} else if (attribute_name == "sample_size") {
|
||||
int64_t sample_size = static_cast<int64_t>(onnx_node_attr.i());
|
||||
(void)prim_c->AddAttr("sample_size", MakeValue<int64_t>(sample_size));
|
||||
} else if (attribute_name == "dtype") {
|
||||
auto onnx_dtype = static_cast<onnx::TensorProto_DataType>(onnx_node_attr.i());
|
||||
auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(onnx_dtype);
|
||||
(void)prim_c->AddAttr("dtype", MakeValue<int64_t>(static_cast<int64_t>(data_type)));
|
||||
}
|
||||
}
|
||||
|
||||
return prim->GetPrim();
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxMultinomialParser("Multinomial", new OnnxMultinomialParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_MULTINOMIAL_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_MULTINOMIAL_PARSER_H_
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxMultinomialParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxMultinomialParser() : OnnxNodeParser("Multinomial") {}
|
||||
~OnnxMultinomialParser() override = default;
|
||||
|
||||
PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_MULTINOMIAL_PARSER_H_
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/onnx/onnx_mvn_parser.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "ir/value.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
PrimitiveCPtr OnnxMVNParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
|
||||
auto prim = std::make_unique<lite::MVN>();
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "new MVN prim failed.";
|
||||
return nullptr;
|
||||
}
|
||||
(void)prim->AddAttr(mindspore::ops::kOriginalFormat, MakeValue<int64_t>(mindspore::Format::NCHW));
|
||||
std::vector<int64_t> axes = {};
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
if (onnx_node_attr.name() == "axes") {
|
||||
const int &size = onnx_node_attr.ints_size();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
axes.push_back(onnx_node_attr.ints(i));
|
||||
}
|
||||
(void)prim->AddAttr("axes", MakeValue(axes));
|
||||
}
|
||||
}
|
||||
return prim;
|
||||
}
|
||||
OnnxNodeRegistrar g_onnxMeanVarianceNormalizationParser("MeanVarianceNormalization", new OnnxMVNParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_MVN_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_MVN_PARSER_H_
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxMVNParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxMVNParser() : OnnxNodeParser("MVN") {}
|
||||
~OnnxMVNParser() override = default;
|
||||
|
||||
PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_MVN_PARSER_H_
|
|
@ -33,6 +33,7 @@
|
|||
#include "ops/cast.h"
|
||||
#include "ops/gather.h"
|
||||
#include "ops/concat.h"
|
||||
#include "ops/reshape.h"
|
||||
#include "ops/tuple_get_item.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
|
@ -1058,6 +1059,28 @@ CNodePtr GenCastNode(const FuncGraphPtr &graph, const AnfNodePtr &input_node, co
|
|||
return cast_cnode;
|
||||
}
|
||||
|
||||
CNodePtr GenReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &shape,
|
||||
const std::string &cnode_name) {
|
||||
MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
|
||||
MS_CHECK_TRUE_RET(input_node != nullptr, nullptr);
|
||||
auto reshape_prim = std::make_shared<ops::Reshape>();
|
||||
if (reshape_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "create reshape failed.";
|
||||
return nullptr;
|
||||
}
|
||||
auto prim_c = reshape_prim->GetPrim();
|
||||
prim_c->set_attr("shape", MakeValue(shape));
|
||||
ValueNodePtr value_node = NewValueNode(prim_c);
|
||||
MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "Create value_node return nullptr");
|
||||
auto new_shape_node = opt::BuildIntVecParameterNode(func_graph, shape, cnode_name + "_shape");
|
||||
MS_CHECK_TRUE_MSG(new_shape_node != nullptr, nullptr, "Create shape parameter return nullptr");
|
||||
std::vector<AnfNodePtr> op_inputs = {value_node, input_node, new_shape_node};
|
||||
auto reshape_cnode = func_graph->NewCNode(op_inputs);
|
||||
MS_CHECK_TRUE_MSG(reshape_cnode != nullptr, nullptr, "Create cnode return nullptr");
|
||||
reshape_cnode->set_fullname_with_scope(cnode_name);
|
||||
return reshape_cnode;
|
||||
}
|
||||
|
||||
CNodePtr GenGatherNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &indices,
|
||||
const std::string &cnode_name, const std::vector<int> &axis) {
|
||||
if (func_graph == nullptr || input_node == nullptr) {
|
||||
|
|
|
@ -133,6 +133,9 @@ CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu
|
|||
CNodePtr GenCastNode(const FuncGraphPtr &graph, const AnfNodePtr &input_node, const std::string &cnode_name,
|
||||
const TypeId dst_type, const AbstractBasePtr &abstract);
|
||||
|
||||
CNodePtr GenReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &shape,
|
||||
const std::string &cnode_name);
|
||||
|
||||
CNodePtr GenGatherNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector<int> &indices,
|
||||
const std::string &cnode_name, const std::vector<int> &axis = {0});
|
||||
|
||||
|
|
Loading…
Reference in New Issue