From b57841785c5097c56c41832f92060b0959cc537c Mon Sep 17 00:00:00 2001 From: yeyunpeng Date: Mon, 11 Jan 2021 10:27:07 +0800 Subject: [PATCH] add tf parser --- mindspore/lite/schema/model.fbs | 2 + mindspore/lite/schema/ops.fbs | 6 + mindspore/lite/src/ops/invert_permutation.cc | 58 ++++++++++ mindspore/lite/src/ops/invert_permutation.h | 43 +++++++ mindspore/lite/src/ops/primitive_c.cc | 6 + mindspore/lite/src/ops/size.cc | 64 +++++++++++ mindspore/lite/src/ops/size.h | 43 +++++++ .../parser/tf/tf_arithmetic_self_parser.cc | 88 ++++++++++++++ .../parser/tf/tf_arithmetic_self_parser.h | 36 ++++++ .../converter/parser/tf/tf_conv_parser.cc | 4 - .../parser/tf/tf_crop_and_resize_parser.cc | 107 ++++++++++++++++++ .../parser/tf/tf_crop_and_resize_parser.h | 36 ++++++ .../parser/tf/tf_gather_nd_parser.cc | 68 +++++++++++ .../converter/parser/tf/tf_gather_nd_parser.h | 37 ++++++ .../parser/tf/tf_invert_permutation_parser.cc | 63 +++++++++++ .../parser/tf/tf_invert_permutation_parser.h | 37 ++++++ .../converter/parser/tf/tf_model_parser.cc | 34 +++++- .../converter/parser/tf/tf_model_parser.h | 4 +- .../tf/tf_non_max_suppression_parser.cc | 83 ++++++++++++++ .../parser/tf/tf_non_max_suppression_parser.h | 37 ++++++ .../converter/parser/tf/tf_pad_parser.cc | 89 +++++++++++++++ .../tools/converter/parser/tf/tf_pad_parser.h | 37 ++++++ .../converter/parser/tf/tf_reverse_parser.cc | 86 ++++++++++++++ .../converter/parser/tf/tf_reverse_parser.h | 37 ++++++ .../converter/parser/tf/tf_size_parser.cc | 63 +++++++++++ .../converter/parser/tf/tf_size_parser.h | 37 ++++++ .../converter/parser/tf/tf_slice_parser.cc | 105 +++++++++++++++++ .../converter/parser/tf/tf_slice_parser.h | 37 ++++++ .../converter/parser/tf/tf_topk_parser.cc | 76 +++++++++++++ .../converter/parser/tf/tf_topk_parser.h | 37 ++++++ 30 files changed, 1450 insertions(+), 10 deletions(-) create mode 100644 mindspore/lite/src/ops/invert_permutation.cc create mode 100644 mindspore/lite/src/ops/invert_permutation.h create mode 100644 mindspore/lite/src/ops/size.cc create mode 100644 mindspore/lite/src/ops/size.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_crop_and_resize_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_crop_and_resize_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_gather_nd_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_gather_nd_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_invert_permutation_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_invert_permutation_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_non_max_suppression_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_non_max_suppression_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_pad_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_pad_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_reverse_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_reverse_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_size_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_size_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_slice_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_slice_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_topk_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_topk_parser.h diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 56fe89c4df9..7d45162d36f 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -267,6 +267,8 @@ union PrimitiveType { GeLU, Gru, NonZero, + InvertPermutation, + Size, } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index cc5fedb3091..72b6a503c15 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -1240,3 +1240,9 @@ table GeLU { table NonZero { } + +table InvertPermutation { +} + +table Size { +} \ No newline at end of file diff --git a/mindspore/lite/src/ops/invert_permutation.cc b/mindspore/lite/src/ops/invert_permutation.cc new file mode 100644 index 00000000000..4da78d86344 --- /dev/null +++ b/mindspore/lite/src/ops/invert_permutation.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/invert_permutation.h" +#include "src/common/common.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { + +#ifdef PRIMITIVE_WRITEABLE +#else +int InvertPermutation::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateSize(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_InvertPermutation, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +PrimitiveC *InvertPermutationCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry InvertPermutationRegistry(schema::PrimitiveType_InvertPermutation, InvertPermutationCreator); +#endif + +int InvertPermutation::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(this->primitive_ != nullptr); + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_format(input->format()); + output->set_data_type(input->data_type()); + if (!infer_flag()) { + return RET_INFER_INVALID; + } + output->set_shape(input->shape()); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/invert_permutation.h b/mindspore/lite/src/ops/invert_permutation.h new file mode 100644 index 00000000000..a79f4814a44 --- /dev/null +++ b/mindspore/lite/src/ops/invert_permutation.h @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_INVERTPERMUTATION_H_ +#define LITE_MINDSPORE_LITE_C_OPS_INVERTPERMUTATION_H_ + +#include +#include +#include + +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class InvertPermutation : public PrimitiveC { + public: + InvertPermutation() = default; + ~InvertPermutation() = default; +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(InvertPermutation, PrimitiveC); + explicit InvertPermutation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#else + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_INVERTPERMUTATION_H_ diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 991ca08ed10..badc5e85362 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -164,6 +164,8 @@ #include "src/ops/select.h" #include "src/ops/gelu.h" #include "src/ops/gru.h" +#include "src/ops/size.h" +#include "src/ops/invert_permutation.h" #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -1004,6 +1006,10 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) Select(primitive); case schema::PrimitiveType_Gru: return new (std::nothrow) Gru(primitive); + case schema::PrimitiveType_Size: + return new (std::nothrow) Size(primitive); + case schema::PrimitiveType_InvertPermutation: + return new (std::nothrow) InvertPermutation(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: return new (std::nothrow) ActivationGrad(primitive); diff --git a/mindspore/lite/src/ops/size.cc b/mindspore/lite/src/ops/size.cc new file mode 100644 index 00000000000..189128fb2ba --- /dev/null +++ b/mindspore/lite/src/ops/size.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/size.h" +#include "src/common/common.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +constexpr int kShapeInputNum = 1; +constexpr int kShapeOutputNum = 1; +#ifdef PRIMITIVE_WRITEABLE +#else +int Size::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateSize(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Size, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +PrimitiveC *SizeCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry SizeRegistry(schema::PrimitiveType_Size, SizeCreator); +#endif + +int Size::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != kShapeInputNum) { + MS_LOG(ERROR) << "inputs to Shape operator should be 1, but " << inputs_.size() << " is given."; + return RET_ERROR; + } + if (outputs_.size() != kShapeOutputNum) { + MS_LOG(ERROR) << "outputs to Shape operator should be 1, but " << outputs_.size() << " is given."; + return RET_ERROR; + } + auto in_tensor = inputs_.front(); + auto out_tensor = outputs_.front(); + out_tensor->set_data_type(kNumberTypeInt32); + out_tensor->set_format(in_tensor->format()); + if (!infer_flag()) { + return RET_INFER_INVALID; + } + std::vector out_shape; + out_shape.push_back(static_cast(in_tensor->shape().size())); + out_tensor->set_shape(out_shape); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/size.h b/mindspore/lite/src/ops/size.h new file mode 100644 index 00000000000..48a3ac21525 --- /dev/null +++ b/mindspore/lite/src/ops/size.h @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_SIZE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SIZE_H_ + +#include +#include +#include + +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class Size : public PrimitiveC { + public: + Size() = default; + ~Size() = default; +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(Size, PrimitiveC); + explicit Size(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#else + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_SIZE_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc new file mode 100644 index 00000000000..76d5812d826 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_arithmetic_self_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { + +template +int CreateOperator(const std::unique_ptr &primitive, schema::PrimitiveType type) { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = type; + primitive->value.value = attr.release(); + return RET_OK; +} + +STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF ArithmeticParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + + int status = RET_ERROR; + if (tf_op.op() == "Ceil") { + status = CreateOperator(primitive, schema::PrimitiveType_Ceil); + } else if (tf_op.op() == "Exp") { + status = CreateOperator(primitive, schema::PrimitiveType_Exp); + } else if (tf_op.op() == "Floor") { + status = CreateOperator(primitive, schema::PrimitiveType_Floor); + } else if (tf_op.op() == "Log") { + status = CreateOperator(primitive, schema::PrimitiveType_Log); + } else if (tf_op.op() == "Sqrt") { + status = CreateOperator(primitive, schema::PrimitiveType_Sqrt); + } + if (status != RET_OK) { + return status; + } + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + return status; +} +TFNodeRegistrar g_tfCeilParser("Ceil", new TFArithmeticSelfParser()); +TFNodeRegistrar g_tfExpParser("Exp", new TFArithmeticSelfParser()); +TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser()); +TFNodeRegistrar g_tfLogParser("Log", new TFArithmeticSelfParser()); +TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFArithmeticSelfParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.h new file mode 100644 index 00000000000..16acb588b7b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_SELF_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_SELF_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFArithmeticSelfParser : public TFNodeParser { + public: + TFArithmeticSelfParser() = default; + ~TFArithmeticSelfParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_SELF_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc index 33521fb4166..0e896621824 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc @@ -45,10 +45,6 @@ STATUS TFConvParser::Parse(const tensorflow::NodeDef &tf_op, attr->group = 1; attr->format = TensorFlowUtils::ParseNodeFormat(tf_op); - if (attr->format == schema::Format_NCHW) { - MS_LOG(ERROR) << "TF Conv2D with data_format=NCHW is not supported now"; - return RET_ERROR; - } std::vector dilations(2); auto status = ParseDilations(tf_op, attr->format, &dilations); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_crop_and_resize_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_crop_and_resize_parser.cc new file mode 100644 index 00000000000..744df19bffc --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_crop_and_resize_parser.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_crop_and_resize_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFCropAndResizeParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF ResizeParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + tensorflow::AttrValue attr_value; + attr->format = schema::Format_NHWC; + + attr->coordinateTransformMode = schema::CoordinateTransformMode_CROP_AND_RESIZE; + + // align_corners + if (TensorFlowUtils::FindAttrValue(tf_op, "align_corners", &attr_value)) { + attr->alignCorners = true; + } + + // extrapolation_value + if (!TensorFlowUtils::FindAttrValue(tf_op, "extrapolation_value", &attr_value)) { + MS_LOG(ERROR) << "The align_corners attr should be specified"; + return RET_ERROR; + } + attr->extrapolationValue = attr_value.f(); + + // method + if (!TensorFlowUtils::FindAttrValue(tf_op, "method", &attr_value)) { + MS_LOG(ERROR) << "The align_corners attr should be specified"; + return RET_ERROR; + } + if (attr_value.s() == "bilinear") { + attr->method = schema::ResizeMethod_LINEAR; + } else if (attr_value.s() == "nearest_neighbor") { + attr->method = schema::ResizeMethod_NEAREST; + } else { + MS_LOG(ERROR) << "Do not support method: " << attr_value.s(); + } + + primitive->value.type = schema::PrimitiveType_Resize; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + status = AddOpInput(tf_op, 1, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + status = AddOpInput(tf_op, 2, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + status = AddOpInput(tf_op, 3, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + return status; +} +TFNodeRegistrar g_tfCropAndResizeParser("CropAndResize", new TFCropAndResizeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_crop_and_resize_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_crop_and_resize_parser.h new file mode 100644 index 00000000000..61645df1c5c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_crop_and_resize_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CROP_AND_RESIZE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CROP_AND_RESIZE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFCropAndResizeParser : public TFNodeParser { + public: + TFCropAndResizeParser() = default; + ~TFCropAndResizeParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CROP_AND_RESIZE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_gather_nd_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_gather_nd_parser.cc new file mode 100644 index 00000000000..07b796b29bb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_gather_nd_parser.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_gather_nd_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFGatherNDParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF GatherNDParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_GatherNd; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + status = AddOpInput(tf_op, 1, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + return status; +} +TFNodeRegistrar g_tfGatherNDParser("GatherNd", new TFGatherNDParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_gather_nd_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_gather_nd_parser.h new file mode 100644 index 00000000000..c082cff4dfb --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_gather_nd_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_GATHER_ND_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_GATHER_ND_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFGatherNDParser : public TFNodeParser { + public: + TFGatherNDParser() = default; + ~TFGatherNDParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_GATHER_ND_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_invert_permutation_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_invert_permutation_parser.cc new file mode 100644 index 00000000000..08c6e27d0d6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_invert_permutation_parser.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_invert_permutation_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFInvertPermutationParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF SizeParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_InvertPermutation; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + return status; +} +TFNodeRegistrar g_tfInvertPermutationParser("InvertPermutation", new TFInvertPermutationParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_invert_permutation_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_invert_permutation_parser.h new file mode 100644 index 00000000000..9510cae4548 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_invert_permutation_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_INVERT_PERMUTATION_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_INVERT_PERMUTATION_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFInvertPermutationParser : public TFNodeParser { + public: + TFInvertPermutationParser() = default; + ~TFInvertPermutationParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_INVERT_PERMUTATION_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index bb73cd0332e..48d0547c7e8 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -175,8 +175,9 @@ STATUS TFModelParser::ConvertConstVariant(const tensorflow::TensorProto &tensor_ return RET_OK; } -STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, - const ParameterPtr ¶meter, std::vector *shape_vector) { +STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value, + const TypeId &type, const ParameterPtr ¶meter, + std::vector *shape_vector) { MS_ASSERT(parameter != nullptr); MS_ASSERT(shape_vector != nullptr); const tensorflow::TensorProto &tensor_proto = attr_value.tensor(); @@ -258,6 +259,23 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value } tensor_size = (*tensor_data).size(); param_value->SetTensorData(tensor_data, tensor_size); + } else if (type == kNumberTypeInt64) { + param_value->set_tensor_type(kNumberTypeInt32); + auto *tensor_data = new (std::nothrow) int[shape_size]; + if (tensor_data == nullptr) { + MS_LOG(ERROR) << "new data failed"; + return RET_ERROR; + } + const auto origin_data = reinterpret_cast(tensor_proto.tensor_content().data()); + for (int i = 0; i < shape_size; ++i) { + if (origin_data[i] > static_cast(INT32_MAX) || origin_data[i] < static_cast(INT32_MIN)) { + MS_LOG(WARNING) << "int64 data " << origin_data[i] << "too big to fit into int32"; + tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN; + } else { + tensor_data[i] = static_cast(origin_data[i]); + } + } + param_value->SetTensorData(tensor_data, shape_size * sizeof(int32_t)); } else { MS_LOG(ERROR) << "Unsupport dataType: " << type; return RET_ERROR; @@ -266,7 +284,15 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value std::vector param_shape(shape_vector->begin(), shape_vector->end()); param_value->set_tensor_shape(param_shape); param_value->set_tensor_type(type); - param_value->set_format(schema::Format::Format_NHWC); + if (TensorFlowUtils::FindAttrValue(node_def, "data_format", const_cast(&attr_value))) { + auto format = mindspore::lite::TensorFlowUtils::ParseNodeFormat(node_def); + if (format == schema::Format_NUM_OF_FORMAT) { + MS_LOG(ERROR) << "Do not support data format: " << attr_value.s(); + } + param_value->set_format(format); + } else { + param_value->set_format(schema::Format::Format_NHWC); + } parameter->set_default_param(param_value); return RET_OK; } @@ -294,7 +320,7 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { MS_LOG(INFO) << "Found value attr, means it has default value"; - auto status = ConvertConstTensor(attr_value, type, parameter, &shape_vector); + auto status = ConvertConstTensor(node, attr_value, type, parameter, &shape_vector); if (status != RET_OK) { return status; } diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index eb77fbb6746..0f20d7cbe94 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -45,8 +45,8 @@ class TFModelParser : public ModelParser { private: STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, const ParamValueLitePtr ¶m_value); - STATUS ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, const ParameterPtr ¶meter, - std::vector *shape_vector); + STATUS ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value, + const TypeId &type, const ParameterPtr ¶meter, std::vector *shape_vector); STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter, std::unordered_map *anf_node_map); STATUS ConvertGraphInputsAndConsts(const std::map &tf_graph_nodes, diff --git a/mindspore/lite/tools/converter/parser/tf/tf_non_max_suppression_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_non_max_suppression_parser.cc new file mode 100644 index 00000000000..c261e3e32f1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_non_max_suppression_parser.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_non_max_suppression_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFNonMaxSuppressionParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF NonMaxSuppressionParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + attr->centerPointBox = 0; + primitive->value.type = schema::PrimitiveType_NonMaxSuppression; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + status = AddOpInput(tf_op, 1, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + status = AddOpInput(tf_op, 2, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + status = AddOpInput(tf_op, 3, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + status = AddOpInput(tf_op, 4, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + return status; +} +TFNodeRegistrar g_tfNonMaxSuppressionV3Parser("NonMaxSuppressionV3", new TFNonMaxSuppressionParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_non_max_suppression_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_non_max_suppression_parser.h new file mode 100644 index 00000000000..3471e80e74d --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_non_max_suppression_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_NON_MAX_SUPPRESSION_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_NON_MAX_SUPPRESSION_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFNonMaxSuppressionParser : public TFNodeParser { + public: + TFNonMaxSuppressionParser() = default; + ~TFNonMaxSuppressionParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_NON_MAX_SUPPRESSION_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_pad_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_pad_parser.cc new file mode 100644 index 00000000000..87e14126a58 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_pad_parser.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_pad_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFPadParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF PadParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + if (tf_op.op() == "Pad") { + attr->paddingMode = schema::PaddingMode_CONSTANT; + attr->constantValue = 0.0f; + + } else if (tf_op.op() == "MirrorPad") { + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "mode", &attr_value)) { + MS_LOG(ERROR) << "The axis attr should be specified"; + return RET_ERROR; + } + + if (attr_value.s() == "SYMMETRIC") { + attr->paddingMode = schema::PaddingMode_SYMMETRIC; + } else if (attr_value.s() == "REFLECT") { + attr->paddingMode = schema::PaddingMode_REFLECT; + } else { + MS_LOG(ERROR) << "padding mode:" << attr_value.s() << " don't support"; + return RET_ERROR; + } + } + primitive->value.type = schema::PrimitiveType_Pad; + primitive->value.value = attr.release(); + + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + status = AddOpInput(tf_op, 1, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + return status; +} +TFNodeRegistrar g_tfPadParser("Pad", new TFPadParser()); +TFNodeRegistrar g_tfMirrorPadParser("MirrorPad", new TFPadParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_pad_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_pad_parser.h new file mode 100644 index 00000000000..633b376b231 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_pad_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_PAD_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_PAD_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFPadParser : public TFNodeParser { + public: + TFPadParser() = default; + ~TFPadParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_PAD_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_reverse_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_reverse_parser.cc new file mode 100644 index 00000000000..ba41439a097 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_reverse_parser.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_reverse_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFReverseParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF ReverseParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + auto axis = GetConstInputNode(tf_node_map, tf_op.input(1)); + if (axis == nullptr) { + MS_LOG(ERROR) << "Find axis failed"; + return RET_ERROR; + } + if (!TensorFlowUtils::FindAttrValue(*axis, "value", &attr_value)) { + MS_LOG(ERROR) << "The value attr should be specified"; + return RET_ERROR; + } + auto tensor_proto = attr_value.tensor(); + if (tensor_proto.int_val_size() > 0) { + for (int i = 0; i < tensor_proto.int_val_size(); ++i) { + attr->axis.push_back(tensor_proto.int_val(i)); + } + } else { + auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); + auto data = reinterpret_cast(tensor_proto.tensor_content().data()); + for (size_t i = 0; i < data_num; ++i) { + attr->axis.push_back(data[i]); + } + } + + primitive->value.type = schema::PrimitiveType_Reverse; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + return status; +} +TFNodeRegistrar g_tfReverseV2Parser("ReverseV2", new TFReverseParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_reverse_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_reverse_parser.h new file mode 100644 index 00000000000..9b99271b2ee --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_reverse_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REVERSE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REVERSE_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFReverseParser : public TFNodeParser { + public: + TFReverseParser() = default; + ~TFReverseParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REVERSE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_size_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_size_parser.cc new file mode 100644 index 00000000000..2bcade2fc09 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_size_parser.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_size_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFSizeParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF SizeParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_Size; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + return status; +} +TFNodeRegistrar g_tfSizeParser("Size", new TFSizeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_size_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_size_parser.h new file mode 100644 index 00000000000..e7c8879a670 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_size_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SIZE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SIZE_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFSizeParser : public TFNodeParser { + public: + TFSizeParser() = default; + ~TFSizeParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SIZE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_slice_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_slice_parser.cc new file mode 100644 index 00000000000..3f16e4f24f8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_slice_parser.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_slice_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFSliceParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF SliceParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + // begin + tensorflow::AttrValue attr_value; + auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(1)); + if (begin_node == nullptr) { + MS_LOG(ERROR) << "Find StridedSlice input begin failed"; + return RET_ERROR; + } + if (!TensorFlowUtils::FindAttrValue(*begin_node, "value", &attr_value)) { + MS_LOG(ERROR) << "The value attr should be specified"; + return RET_ERROR; + } + auto tensor_proto = attr_value.tensor(); + if (tensor_proto.int_val_size() > 0) { + for (int i = 0; i < tensor_proto.int_val_size(); ++i) { + attr->begin.push_back(tensor_proto.int_val(i)); + } + } else { + auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t); + auto data = reinterpret_cast(tensor_proto.tensor_content().data()); + for (size_t i = 0; i < data_num; ++i) { + attr->begin.push_back(data[i]); + } + } + + // axes + std::vector axes; + axes.clear(); + for (size_t i = 0; i < attr->begin.size(); ++i) { + axes.push_back(i); + } + attr->axes = axes; + + primitive->value.type = schema::PrimitiveType_Slice; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + status = AddOpInput(tf_op, 1, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + status = AddOpInput(tf_op, 2, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + return status; +} +TFNodeRegistrar g_tfSliceParser("Slice", new TFSliceParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_slice_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_slice_parser.h new file mode 100644 index 00000000000..88390825b09 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_slice_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SLICE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SLICE_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFSliceParser : public TFNodeParser { + public: + TFSliceParser() = default; + ~TFSliceParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SLICE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_topk_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_topk_parser.cc new file mode 100644 index 00000000000..5fd007583c1 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_topk_parser.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_topk_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFTopKParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF TopKParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + // sorted + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "sorted", &attr_value)) { + MS_LOG(ERROR) << "The begin_mask attr should be specified"; + return RET_ERROR; + } + attr->sorted = attr_value.i(); + + primitive->value.type = schema::PrimitiveType_TopK; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 2; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + status = AddOpInput(tf_op, 1, inputs); + if (status != RET_OK) { + MS_LOG(ERROR) << "Add Op input failed."; + return status; + } + return status; +} +TFNodeRegistrar g_tfTopKV2Parser("TopKV2", new TFTopKParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_topk_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_topk_parser.h new file mode 100644 index 00000000000..addb43536d5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_topk_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TOPK_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TOPK_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFTopKParser : public TFNodeParser { + public: + TFTopKParser() = default; + ~TFTopKParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TOPK_PARSER_H_