diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 8f34e21dcff..3c032268fe0 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -226,6 +226,7 @@ union PrimitiveType { InstanceNorm, Identity, LayerNorm, + While, } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 4ad406898e5..d8bb3233b14 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -1103,3 +1103,8 @@ table LayerNorm { elementwiseAffine : bool; } +table While { + condSubgraphIndex : int; + bodySubgraphIndex : int; +} + diff --git a/mindspore/lite/src/ops/populate/while_populate.cc b/mindspore/lite/src/ops/populate/while_populate.cc new file mode 100644 index 00000000000..efcb64d1776 --- /dev/null +++ b/mindspore/lite/src/ops/populate/while_populate.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/while.h" +#include "src/ops/primitive_c.h" +#include "src/ops/populate/populate_register.h" + +namespace mindspore { +namespace lite { + +typedef struct WhileParemeter { + OpParameter op_parameter_; + int body_subgraph_index; + int cond_subgraph_index; +} WhileParemeter; + +OpParameter *PopulateWhileParemeter(const mindspore::lite::PrimitiveC *primitive) { + WhileParemeter *while_paremeter = reinterpret_cast(malloc(sizeof(WhileParemeter))); + if (while_paremeter == nullptr) { + MS_LOG(ERROR) << "malloc WhileParemeter failed."; + return nullptr; + } + memset(while_paremeter, 0, sizeof(WhileParemeter)); + auto param = reinterpret_cast(const_cast(primitive)); + while_paremeter->op_parameter_.type_ = primitive->Type(); + while_paremeter->body_subgraph_index = param->GetBodySubgraphIndex(); + while_paremeter->cond_subgraph_index = param->GetCondSubgraphIndex(); + return reinterpret_cast(while_paremeter); +} +Registry WhileParemeterRegistry(schema::PrimitiveType_While, PopulateWhileParemeter); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index d3b4889aa4f..773531d7fda 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -144,6 +144,7 @@ #include "src/ops/mfcc.h" #include "src/ops/identity.h" #include "src/ops/instance_norm.h" +#include "src/ops/while.h" #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -499,6 +500,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Split") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "While") { + return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "OneHot") { return NewPrimitiveC(prim, inputs, quantType); @@ -793,6 +796,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new Mfcc(primitive); case schema::PrimitiveType_InstanceNorm: return new InstanceNorm(primitive); + case schema::PrimitiveType_While: + return new While(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: diff --git a/mindspore/lite/src/ops/while.cc b/mindspore/lite/src/ops/while.cc new file mode 100644 index 00000000000..ac5e4915a2c --- /dev/null +++ b/mindspore/lite/src/ops/while.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/while.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE + +void While::SetCondSubgraphIndex(const int cond_subgraph_index) { + this->primitive_->value.AsWhile()->condSubgraphIndex = cond_subgraph_index; +} +void While::SetBodySubgraphIndex(const int body_subgraph_index) { + this->primitive_->value.AsWhile()->bodySubgraphIndex = body_subgraph_index; +} + +int While::GetCondSubgraphIndex() const { return this->primitive_->value.AsWhile()->condSubgraphIndex; } +int While::GetBodySubgraphIndex() const { return this->primitive_->value.AsWhile()->bodySubgraphIndex; } + +int While::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_While; + } + if (this->primitive_->value.type != schema::PrimitiveType_While) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::WhileT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + attr->bodySubgraphIndex = GetValue(prim.GetAttr("body_subgraph_index")); + attr->condSubgraphIndex = GetValue(prim.GetAttr("cond_subgraph_index")); + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} + +#else + +int While::GetCondSubgraphIndex() const { return this->primitive_->value_as_While()->condSubgraphIndex(); } +int While::GetBodySubgraphIndex() const { return this->primitive_->value_as_While()->bodySubgraphIndex(); } + +int While::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_While(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_While return nullptr"; + return RET_ERROR; + } + auto cond_subgraph_index = attr->condSubgraphIndex(); + auto body_subgraph_index = attr->bodySubgraphIndex(); + auto val_offset = schema::CreateWhile(*fbb, body_subgraph_index, cond_subgraph_index); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_While, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +PrimitiveC *WhileCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry WhileRegistry(schema::PrimitiveType_While, WhileCreator); + +#endif + +int While::InferShape(std::vector inputs_, std::vector outputs_) { + if (inputs_.size() != outputs_.size()) { + MS_LOG(ERROR) << "The number of inputs and outputs varies"; + return RET_ERROR; + } + for (size_t i = 0; i < inputs_.size(); i++) { + outputs_[i]->set_data_type(inputs_[i]->data_type()); + outputs_[i]->SetFormat(inputs_[i]->GetFormat()); + outputs_[i]->set_shape(inputs_[i]->shape()); + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/while.h b/mindspore/lite/src/ops/while.h new file mode 100644 index 00000000000..1cfaa3378dc --- /dev/null +++ b/mindspore/lite/src/ops/while.h @@ -0,0 +1,51 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_WHILE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_WHILE_H_ + +#include +#include +#include +#include + +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class While : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(While, PrimitiveC); + While() = default; + explicit While(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; + void SetCondSubgraphIndex(const int cond_subgraph_index); + void SetBodySubgraphIndex(const int body_subgraph_index); + +#else + While() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; + int GetCondSubgraphIndex() const; + int GetBodySubgraphIndex() const; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_WHERE_H_ diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc index 456e6093461..9224d823f82 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc @@ -24,7 +24,8 @@ namespace mindspore { namespace lite { STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -75,10 +76,8 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_Activation; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h index cc849ed3304..908f44491d2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h @@ -30,7 +30,8 @@ class TfliteActivationParser : public TfliteNodeParser { TfliteActivationParser() : TfliteNodeParser("node_name") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; class TfliteReluParser : public TfliteActivationParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc index 18652601947..e766fb4a529 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteAddNParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -41,16 +42,14 @@ STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu return RET_NULL_PTR; } - attr->N = tflite_model->subgraphs[0]->tensors.size() - 1; + attr->N = tflite_subgraph->tensors.size() - 1; op->primitive->value.type = schema::PrimitiveType_AddN; op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h index edaf56e3050..4417fe98626 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_addn_parser.h @@ -30,7 +30,8 @@ class TfliteAddNParser : public TfliteNodeParser { TfliteAddNParser() : TfliteNodeParser("AddN") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc index e0638addb98..be78876a396 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -47,7 +48,7 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni // get axis attr auto axis_idx = tflite_op->inputs[1]; - auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer; + auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer; auto &buf_data = tflite_model->buffers[buffer_idx]; if (buf_data == nullptr) { MS_LOG(ERROR) << "the buf data is null"; @@ -63,10 +64,8 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni op->primitive->value.type = schema::PrimitiveType_ArgMax; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h index 013cc5ad2bd..60038edf3e5 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmax_parser.h @@ -30,7 +30,8 @@ class TfliteArgmaxParser : public TfliteNodeParser { TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc index d7017517eb0..4d87a4e7fb4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteArgminParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -47,7 +48,7 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni // get axis attr auto axis_idx = tflite_op->inputs[1]; - auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer; + auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer; auto &buf_data = tflite_model->buffers[buffer_idx]; if (buf_data == nullptr) { MS_LOG(ERROR) << "the buf data is null"; @@ -63,10 +64,8 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni op->primitive->value.type = schema::PrimitiveType_ArgMin; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h index ad4ed1c3c89..0422c0232c7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_argmin_parser.h @@ -30,7 +30,8 @@ class TfliteArgminParser : public TfliteNodeParser { TfliteArgminParser() : TfliteNodeParser("Argmin") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc index 62cf917175d..af5637e28a6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc @@ -24,7 +24,8 @@ namespace mindspore { namespace lite { STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -168,17 +169,16 @@ STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, // set input for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -305,16 +305,15 @@ STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.value = attr.release(); } - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -385,11 +384,9 @@ STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, } for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h index c52b0b98f1e..c9d20cddf21 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h @@ -30,7 +30,8 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser { TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; class TfliteAddParser : public TfliteDoubleInputOpParser { @@ -93,7 +94,8 @@ class TfliteSingleInputOpParser : public TfliteNodeParser { TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; class TfliteAbsParser : public TfliteSingleInputOpParser { @@ -161,7 +163,8 @@ class TfliteCompareOpParser : public TfliteNodeParser { TfliteCompareOpParser() : TfliteNodeParser("node_name") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; class TfliteEqualParser : public TfliteCompareOpParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc index a1c6797d3d0..4a8cf9e2f0f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.cc @@ -25,7 +25,8 @@ namespace mindspore { namespace lite { STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -51,12 +52,11 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, return RET_NULL_PTR; } - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, - attr->blockShape)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) { MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; return RET_ERROR; } - if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->crops)) { + if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->crops)) { MS_LOG(ERROR) << "get batchToSpace -> crops failed"; return RET_ERROR; } @@ -64,10 +64,8 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_BatchToSpace; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h index 50fd8edb292..a5c8c86201e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_batch_to_space_parser.h @@ -30,7 +30,8 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser { TfliteBatchToSpaceParser() : TfliteNodeParser("BatchToSpace") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc index 7c2b9c8fc9b..51977aae0c0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.cc @@ -24,7 +24,8 @@ namespace mindspore { namespace lite { STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -42,8 +43,7 @@ STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info, return RET_NULL_PTR; } - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, - attr->dst_shape)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dst_shape)) { MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; return RET_ERROR; } @@ -51,10 +51,8 @@ STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_BroadcastTo; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h index 364709ed358..fe72058c9f0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_broadcast_to_parser.h @@ -30,7 +30,8 @@ class TfliteBroadcastToParser : public TfliteNodeParser { TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc index da045a9bb76..6c7dabbd360 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteCastParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -40,13 +41,13 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu return RET_NULL_PTR; } - const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; + const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; if (in_tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; return RET_NULL_PTR; } attr->srcT = GetTfliteDataType(in_tensor->type); - const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; + const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; if (out_tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; return RET_NULL_PTR; @@ -56,10 +57,8 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu op->primitive->value.type = schema::PrimitiveType_Cast; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h index 17cf60ef055..8f9dd069060 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_cast_parser.h @@ -30,7 +30,8 @@ class TfliteCastParser : public TfliteNodeParser { TfliteCastParser() : TfliteNodeParser("Cast") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc index a2741c83656..1ae7041dbd4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteConcatParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -52,11 +53,9 @@ STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h index 4074d41f3a1..647d0a39182 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_concat_parser.h @@ -30,7 +30,8 @@ class TfliteConcatParser : public TfliteNodeParser { TfliteConcatParser() : TfliteNodeParser("Concat") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc index 956b1ee597d..1006d7d48c0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteConvParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -57,7 +58,7 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu // get the conv op weight tensor auto weight_index = tflite_op->inputs[1]; - const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index]; + const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; if (weight_tensor == nullptr) { MS_LOG(ERROR) << "the weight tensor is null"; return RET_NULL_PTR; @@ -70,7 +71,7 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu // calculate pad params auto data_index = tflite_op->inputs[0]; - const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; + const auto &data_tensor = tflite_subgraph->tensors[data_index]; std::vector params; int status = getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); @@ -87,14 +88,10 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu op->primitive->value.type = schema::PrimitiveType_Conv2D; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_KHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h index 4e226c0b982..308250edebc 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.h @@ -30,7 +30,8 @@ class TfliteConvParser : public TfliteNodeParser { TfliteConvParser() : TfliteNodeParser("Conv2D") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc index fcf73d03221..7cbea625653 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc @@ -139,14 +139,15 @@ STATUS TfliteCustomParser::ExtractFeatures(const std::vector &custom_at STATUS TfliteCustomParser::Rfft(const std::vector &custom_attr, schema::CNodeT *op, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph) { std::unique_ptr attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } std::vector fft_length; - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, fft_length)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, fft_length)) { MS_LOG(ERROR) << "rfft -> fftLength get failed"; return RET_ERROR; } @@ -181,7 +182,8 @@ STATUS TfliteCustomParser::FftImag(const std::vector &custom_attr, sche } STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteCustomParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -209,7 +211,7 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni } else if (custom_type == "Mfcc") { status = Mfcc(custom_attr, op, tflite_op); } else if (custom_type == "FlexRFFT") { - status = Rfft(custom_attr, op, tflite_op, tflite_model); + status = Rfft(custom_attr, op, tflite_op, tflite_model, tflite_subgraph); } else if (custom_type == "FlexReal") { status = FftReal(custom_attr, op, tflite_op); } else if (custom_type == "FlexImag") { @@ -222,12 +224,10 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni return status; } for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { - AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } return status; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h index 17ad6a515cb..ac336cead77 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h @@ -30,7 +30,8 @@ class TfliteCustomParser : public TfliteNodeParser { TfliteCustomParser() : TfliteNodeParser("Custom") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; STATUS DetectPostProcess(const std::vector &custom_attr, schema::CNodeT *op, const std::unique_ptr &tflite_op); @@ -51,7 +52,8 @@ class TfliteCustomParser : public TfliteNodeParser { const std::unique_ptr &tflite_op); STATUS Rfft(const std::vector &custom_attr, schema::CNodeT *op, - const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model); + const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph); STATUS FftReal(const std::vector &custom_attr, schema::CNodeT *op, const std::unique_ptr &tflite_op); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc index 670a8534d63..f97859bccb3 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -58,7 +59,7 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni // get the conv op weight tensor auto weight_index = tflite_op->inputs[1]; - const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index]; + const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; if (weight_tensor == nullptr) { MS_LOG(ERROR) << "the weight tensor is null"; return RET_NULL_PTR; @@ -71,7 +72,7 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni // calculate pad params auto data_index = tflite_op->inputs[2]; - const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; + const auto &data_tensor = tflite_subgraph->tensors[data_index]; std::vector params; int status = getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); @@ -88,12 +89,9 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni op->primitive->value.type = schema::PrimitiveType_DeConv2D; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_KHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h index 58c1a47b5d3..a52e89b7aa8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.h @@ -30,7 +30,8 @@ class TfliteDeConvParser : public TfliteNodeParser { TfliteDeConvParser() : TfliteNodeParser("DeConv2D") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc index 611c40e288e..acb680f614b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.cc @@ -24,7 +24,8 @@ namespace mindspore { namespace lite { STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; if (op == nullptr) { @@ -54,10 +55,8 @@ STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_DepthToSpace; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h index 880502cb8be..ae303db6571 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depth_to_space_parser.h @@ -30,7 +30,8 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser { TfliteDepthToSpaceParser() : TfliteNodeParser("DepthToSpace") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc index 043f9fff85a..947c4854349 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc @@ -23,7 +23,9 @@ namespace mindspore { namespace lite { STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -58,7 +60,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, // get the data tensor auto data_index = tflite_op->inputs[1]; - const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; + const auto &data_tensor = tflite_subgraph->tensors[data_index]; if (data_tensor == nullptr) { MS_LOG(ERROR) << "the data tensor is null"; return RET_NULL_PTR; @@ -68,7 +70,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, // get the weight tensor auto weight_index = tflite_op->inputs[1]; - const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index]; + const auto &weight_tensor = tflite_subgraph->tensors[weight_index]; if (weight_tensor == nullptr) { MS_LOG(ERROR) << "the weight tensor is null"; return RET_NULL_PTR; @@ -94,14 +96,10 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_KHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h index 73b0b25ea4a..22885dc466a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.h @@ -30,7 +30,8 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser { TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc index 472ed3ae117..b93bb749aee 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -34,12 +35,12 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, return RET_NULL_PTR; } - const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; + const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; if (in_tensor == nullptr) { MS_LOG(ERROR) << "input tensor is null"; return RET_NULL_PTR; } - const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; + const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; if (out_tensor == nullptr) { MS_LOG(ERROR) << "output tensor is null"; return RET_NULL_PTR; @@ -68,10 +69,8 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_Cast; } - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h index 2897b1857b1..61b3d0f25c4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.h @@ -29,7 +29,8 @@ class TfliteDequantizeParser : public TfliteNodeParser { TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc index 890b700f8dd..01a626a1af6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -41,17 +42,15 @@ STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info, return RET_NULL_PTR; } std::vector dims; - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, dims)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, dims)) { MS_LOG(ERROR) << "get expand_dims -> dim failed"; return RET_ERROR; } attr->dim = dims[0]; op->primitive->value.type = schema::PrimitiveType_ExpandDims; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h index 4832f1117bf..f4f2e6c551d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_expand_dims_parser.h @@ -30,7 +30,8 @@ class TfliteExpandDimsParser : public TfliteNodeParser { TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc index 3805f7f109f..9a426f00f16 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteFillParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -41,7 +42,7 @@ STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu } if (tflite_op->inputs.size() > 1) { - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->dims)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dims)) { MS_LOG(ERROR) << "get fill -> dims failed"; return RET_ERROR; } @@ -50,10 +51,8 @@ STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu op->primitive->value.type = schema::PrimitiveType_Fill; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h index 8af709f3a46..9703db3959a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fill_parser.h @@ -30,7 +30,8 @@ class TfliteFillParser : public TfliteNodeParser { TfliteFillParser() : TfliteNodeParser("Fill") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc index 8674e525678..2874c4de7c7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.cc @@ -23,7 +23,9 @@ namespace mindspore { namespace lite { STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -57,16 +59,12 @@ STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_FullConnection; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_KHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC); if (hasBias) { - AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h index 3da6407a60a..e9ee93336c0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_fullyconnected_parser.h @@ -30,7 +30,8 @@ class TfliteFullyConnectedParser : public TfliteNodeParser { TfliteFullyConnectedParser() : TfliteNodeParser("FullyConnected") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; class TfliteFakeQuantParser : public TfliteFullyConnectedParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc index b1073f8f27c..80389535836 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -46,11 +47,9 @@ STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::u op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h index 9f93547a0b7..6c3bb2a77e1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_nd_parser.h @@ -30,7 +30,8 @@ class TfliteGatherNdParser : public TfliteNodeParser { TfliteGatherNdParser() : TfliteNodeParser("GatherND") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc index 23f2f07611f..eb3247d641d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteGatherParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -52,11 +53,9 @@ STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h index 6ead6b01d2a..30e06be4477 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_gather_parser.h @@ -30,7 +30,8 @@ class TfliteGatherParser : public TfliteNodeParser { TfliteGatherParser() : TfliteNodeParser("Gather") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc index c4dc02eda9c..05957d9e69f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.cc @@ -23,7 +23,9 @@ namespace mindspore { namespace lite { STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteHashtableLookupParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -44,12 +46,10 @@ STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_HashtableLookup; op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { - AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h index 2ee55c841c1..d23157d7ad7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h @@ -30,7 +30,8 @@ class TfliteHashtableLookupParser : public TfliteNodeParser { TfliteHashtableLookupParser() : TfliteNodeParser("HashtableLookup") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc index 13b01226891..a177e21dde7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteL2NormParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -49,10 +50,8 @@ STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni op->primitive->value.value = attr.release(); // set input and output - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h index 3ddb116967f..4a929d163b0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.h @@ -30,7 +30,8 @@ class TfliteL2NormParser : public TfliteNodeParser { TfliteL2NormParser() : TfliteNodeParser("L2_NORMALIZATION") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc index da14bb93389..5157a16857d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -67,11 +68,9 @@ STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::un } for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h index b6a21aeeb49..45b45bebe1e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_logical_parser.h @@ -30,7 +30,8 @@ class TfliteLogicalParser : public TfliteNodeParser { TfliteLogicalParser() : TfliteNodeParser("node_name") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; class TfliteLogicalAndParser : public TfliteLogicalParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc index e33b2d5dd4f..09bde50e35e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteLRNParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -53,10 +54,8 @@ STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h index 492d677b63a..575aaa1fca7 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lrn_parser.h @@ -30,7 +30,8 @@ class TfliteLRNParser : public TfliteNodeParser { TfliteLRNParser() : TfliteNodeParser("LocalResponseNorm") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc index 489a2f7f0e4..dd7e3fa0385 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteLshProjectionParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -56,11 +57,9 @@ STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.h index 448ceb7fff1..c452e94b8d4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_lsh_projection_parser.h @@ -30,7 +30,8 @@ class TfliteLshProjectionParser : public TfliteNodeParser { TfliteLshProjectionParser() : TfliteNodeParser("LshProjection") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 4c347e04be1..898675707a8 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -116,7 +116,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflit continue; } if (status == RET_OK) { - status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get()); + status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, tflite_subgraph, op.get()); if (status != RET_OK) { if (status == RET_NOT_FIND_OP) { op_type = @@ -337,18 +337,10 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) return RET_OK; } -schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) { - // load graph - auto tflite_model = ReadTfliteModel(model_file.c_str()); - if (tflite_model == nullptr) { - MS_LOG(ERROR) << "read tflite model failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); - return nullptr; - } - - if (tflite_model->subgraphs.size() != 1) { - MS_LOG(ERROR) << "read tflite model subgraphs failed"; +std::unique_ptr TfliteModelParser::ConstructMainGraph( + const std::unique_ptr &tflite_model, const QuantType &quant_type) { + if (tflite_model->subgraphs.size() < 1) { + MS_LOG(ERROR) << "read tflite model main subgraphs failed"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); return nullptr; } @@ -394,7 +386,28 @@ schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, return nullptr; } - return meta_graph.release(); + return meta_graph; +} + +schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) { + // load graph + auto tflite_model = ReadTfliteModel(model_file.c_str()); + if (tflite_model == nullptr) { + MS_LOG(ERROR) << "read tflite model failed"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); + return nullptr; + } + + // construct main_meta_graph + auto main_meta_graph = ConstructMainGraph(tflite_model, quant_type); + if (main_meta_graph == nullptr) { + MS_LOG(ERROR) << "ConstructMainGraph failed"; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); + return nullptr; + } + + return main_meta_graph.release(); } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index 77a9fc89b41..389ad48365c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -64,6 +64,9 @@ class TfliteModelParser : public ModelParser { STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph); + std::unique_ptr ConstructMainGraph(const std::unique_ptr &tflite_model, + const QuantType &quant_type); + private: TfliteTensorsInfo tensorsInfo; std::vector tensors; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h index a426d70a788..d7c99297f48 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h @@ -39,7 +39,8 @@ class TfliteNodeParser { virtual ~TfliteNodeParser() = default; virtual STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) = 0; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) = 0; void AddOpInput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, schema::Format format) { int new_idx = tensors_info->tensorsId.size(); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc index 35984b1375b..d685b167056 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteOneHotParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -46,7 +47,7 @@ STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni return RET_NULL_PTR; } auto axis = tflite_attr->axis; - const auto &tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; + const auto &tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; if (tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; return RET_NULL_PTR; @@ -57,11 +58,9 @@ STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h index ea3ebe9fb46..518ad70878d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_one_hot_parser.h @@ -30,7 +30,8 @@ class TfliteOneHotParser : public TfliteNodeParser { TfliteOneHotParser() : TfliteNodeParser("OneHot") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc index b8dc009ceb7..38315ff5711 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TflitePadParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -51,8 +52,7 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique } attr->paddingMode = schema::PaddingMode_CONSTANT; attr->constantValue = 0.0f; - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, - attr->paddings)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->paddings)) { MS_LOG(ERROR) << "get pad -> paddings failed"; return RET_ERROR; } @@ -81,14 +81,11 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique op->primitive->value.type = schema::PrimitiveType_Pad; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); if (std::strcmp(node_name, "MirrorPad") == 0) { - AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h index d040aebe6bc..da2b6d26f7d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.h @@ -30,7 +30,8 @@ class TflitePadParser : public TfliteNodeParser { TflitePadParser() : TfliteNodeParser("Pad") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc index 57832f8ced4..a57fb77d91b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -69,7 +70,7 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un // calculate pad params auto data_index = tflite_op->inputs[0]; - const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; + const auto &data_tensor = tflite_subgraph->tensors[data_index]; std::vector params; int status = getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms); @@ -86,10 +87,8 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un op->primitive->value.type = schema::PrimitiveType_Pooling; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h index b6b8b25e511..c066b7ec112 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.h @@ -30,7 +30,8 @@ class TflitePoolingParser : public TfliteNodeParser { TflitePoolingParser() : TfliteNodeParser("node_name") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; class TfliteMeanPoolingParser : public TflitePoolingParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.cc index 0c96375c143..b71a5b08981 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TflitePReLUParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -44,12 +45,9 @@ STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq op->primitive->value.type = schema::PrimitiveType_PReLU; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.h index 35bd7936e64..ef4a9dc5c85 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_prelu_parser.h @@ -30,7 +30,8 @@ class TflitePReLUParser : public TfliteNodeParser { TflitePReLUParser() : TfliteNodeParser("PRELU") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc index 9f76c795c50..901aa6afa14 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc @@ -21,7 +21,8 @@ namespace mindspore { namespace lite { STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteQuantizeNParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -33,12 +34,12 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u return RET_NULL_PTR; } - const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; + const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; if (in_tensor == nullptr) { MS_LOG(ERROR) << "input tensor is null"; return RET_NULL_PTR; } - const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; + const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]]; if (out_tensor == nullptr) { MS_LOG(ERROR) << "output tensor is null"; return RET_NULL_PTR; @@ -67,10 +68,8 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u op->primitive->value.value = attr.release(); } - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.h index f0d29a76536..1b1cd3cc3c6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.h @@ -29,7 +29,8 @@ class TfliteQuantizeParser : public TfliteNodeParser { TfliteQuantizeParser() : TfliteNodeParser("Quantize") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc index 72ccac5f1e5..90084ba4e92 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteRangeParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -43,12 +44,12 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq attr->dType = 0; std::vector limit; std::vector delta; - int status = GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, limit); + int status = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, limit); if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "range -> limit get failed"; return RET_ERROR; } else if (status == RET_OK) { - status = GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, delta); + status = GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, delta); if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "stridedSlice -> end get failed"; return RET_ERROR; @@ -63,11 +64,9 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq int input_num = status == RET_OK ? 1 : 3; for (int i = 0; i < input_num; ++i) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h index 6b44f7d0f8f..eab88494499 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.h @@ -30,7 +30,8 @@ class TfliteRangeParser : public TfliteNodeParser { TfliteRangeParser() : TfliteNodeParser("Range") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc index 3bdf2f156e9..9b83c05c312 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteRankParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -43,10 +44,8 @@ STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu op->primitive->value.type = schema::PrimitiveType_Rank; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h index c732dcf9e77..d105387714e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_rank_parser.h @@ -30,7 +30,8 @@ class TfliteRankParser : public TfliteNodeParser { TfliteRankParser() : TfliteNodeParser("Rank") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc index 70a214c5b9f..bc0f90b75c6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -72,7 +73,7 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni return RET_NOT_SUPPORT; } - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axes)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axes)) { MS_LOG(ERROR) << "get reduce -> axes failed"; return RET_ERROR; } @@ -80,10 +81,8 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni op->primitive->value.type = schema::PrimitiveType_Reduce; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h index 0179c0f290f..d108f954743 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.h @@ -30,7 +30,8 @@ class TfliteReduceParser : public TfliteNodeParser { TfliteReduceParser() : TfliteNodeParser("node_name") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; class TfliteReduceMaxParser : public TfliteReduceParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc index 98b1813c9ee..7f467b75ef1 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteReshapeParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -47,7 +48,7 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un return RET_ERROR; } auto shape_tensor_index = tflite_op->inputs[1]; - const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[shape_tensor_index]; + const auto &shape_tensor = tflite_subgraph->tensors[shape_tensor_index]; if (shape_tensor == nullptr) { MS_LOG(ERROR) << "shape_tensor is null"; return RET_NULL_PTR; @@ -58,8 +59,7 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un return RET_NULL_PTR; } if (!buf_data->data.empty()) { - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, - attr->shape)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->shape)) { MS_LOG(ERROR) << "get reshape -> shape failed"; return RET_ERROR; } @@ -76,11 +76,9 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h index 582ba911ceb..3aa7380b8f3 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.h @@ -30,7 +30,8 @@ class TfliteReshapeParser : public TfliteNodeParser { TfliteReshapeParser() : TfliteNodeParser("Reshape") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc index 9c2080284f6..72085bc0213 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { if (op == nullptr) { MS_LOG(ERROR) << "op is null"; return RET_NULL_PTR; @@ -87,7 +88,7 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni attr->preserveAspectRatio = false; auto tfliteResizeTensorIndex = tflite_op->inputs[1]; - const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[tfliteResizeTensorIndex]; + const auto &shape_tensor = tflite_subgraph->tensors[tfliteResizeTensorIndex]; if (shape_tensor == nullptr) { MS_LOG(ERROR) << "shape_tensor is null"; return RET_NULL_PTR; @@ -109,14 +110,11 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni op->primitive->value.type = schema::PrimitiveType_Resize; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); if (buffData == nullptr) { - AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h index 3fae174d6cd..8151984eb26 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.h @@ -30,7 +30,8 @@ class TfliteResizeParser : public TfliteNodeParser { TfliteResizeParser() : TfliteNodeParser("node_name") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; class TfliteResizeBilinearParser : public TfliteResizeParser { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc index 4301b81b7ee..2eac035f99e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteReverseParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -40,7 +41,7 @@ STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::un return RET_NULL_PTR; } - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axis)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axis)) { MS_LOG(ERROR) << "get reverse -> axis failed"; return RET_ERROR; } @@ -48,10 +49,8 @@ STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::un op->primitive->value.type = schema::PrimitiveType_Reverse; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h index 34d59ae5017..278e3f7eceb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_parser.h @@ -30,7 +30,8 @@ class TfliteReverseParser : public TfliteNodeParser { TfliteReverseParser() : TfliteNodeParser("reverse") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc index cc98fbd5401..23a958b2697 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.cc @@ -24,7 +24,9 @@ namespace mindspore { namespace lite { STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteReverseSequenceParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -53,12 +55,9 @@ STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_ReverseSequence; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h index afd06219404..d183297ced4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reverse_sequence_parser.h @@ -30,7 +30,8 @@ class TfliteReverseSequenceParser : public TfliteNodeParser { TfliteReverseSequenceParser() : TfliteNodeParser("ReverseSequence") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc index df395a4f7a2..a91b32325d0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.cc @@ -24,7 +24,8 @@ namespace mindspore { namespace lite { STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteScatterNdParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -52,14 +53,10 @@ STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info, // in tflite, kIndices = 0, kUpdates = 1, kShape = 2 // in mslite, kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2; - AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h index cab92dd1f4b..064a3d70c71 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_scatter_nd_parser.h @@ -30,7 +30,8 @@ class TfliteScatterNdParser : public TfliteNodeParser { TfliteScatterNdParser() : TfliteNodeParser("ScatterNd") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc index a0f8c7e00c4..58e1f5b9f7e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteShapeParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -43,10 +44,8 @@ STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq op->primitive->value.type = schema::PrimitiveType_Shape; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h index 5020b44b5ec..42013f29a84 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_shape_parser.h @@ -30,7 +30,8 @@ class TfliteShapeParser : public TfliteNodeParser { TfliteShapeParser() : TfliteNodeParser("Shape") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc index 21abe5b5be7..22e875a003b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSkipGramParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -52,10 +53,8 @@ STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::u op->primitive->value.type = schema::PrimitiveType_SkipGram; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h index 29ece288206..56d80d05b89 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_skip_gram_parser.h @@ -30,7 +30,8 @@ class TfliteSkipGramParser : public TfliteNodeParser { TfliteSkipGramParser() : TfliteNodeParser("SkipGram") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc index 1e48e199991..3dffbd66423 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSliceParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -42,11 +43,11 @@ STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq attr->format = schema::Format::Format_NHWC; - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->begin)) { MS_LOG(ERROR) << "get slice -> begin failed"; return RET_ERROR; } - if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->size)) { + if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->size)) { MS_LOG(ERROR) << "get slice -> size failed"; return RET_ERROR; } @@ -59,10 +60,8 @@ STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq op->primitive->value.type = schema::PrimitiveType_Slice; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h index d363c453c81..48801162b69 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_slice_parser.h @@ -30,7 +30,8 @@ class TfliteSliceParser : public TfliteNodeParser { TfliteSliceParser() : TfliteNodeParser("Slice") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc index 0ef5b4d62b6..39718d1a25e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteSoftmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSoftmaxParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -45,10 +46,8 @@ STATUS TfliteSoftmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::un op->primitive->value.type = schema::PrimitiveType_SoftMax; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h index 30585bc0cdd..f63b77ccd38 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_softmax_parser.h @@ -30,7 +30,8 @@ class TfliteSoftmaxParser : public TfliteNodeParser { TfliteSoftmaxParser() : TfliteNodeParser("Softmax") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc index 58e4816dc36..0b6043ce279 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.cc @@ -24,7 +24,9 @@ namespace mindspore { namespace lite { STATUS TfliteSpaceToBatchNDParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, + schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSpaceToBatchNDParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -42,12 +44,11 @@ STATUS TfliteSpaceToBatchNDParser::Parse(TfliteTensorsInfo *tensors_info, return RET_NULL_PTR; } - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, - attr->blockShape)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) { MS_LOG(ERROR) << "get spaceToBatchND -> blockShape failed"; return RET_ERROR; } - if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->paddings)) { + if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->paddings)) { MS_LOG(ERROR) << "get spaceToBatchND -> paddings failed"; return RET_ERROR; } @@ -55,10 +56,8 @@ STATUS TfliteSpaceToBatchNDParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_SpaceToBatchND; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h index e63956ccfb7..cfd2794c07b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_batch_nd_parser.h @@ -30,7 +30,8 @@ class TfliteSpaceToBatchNDParser : public TfliteNodeParser { TfliteSpaceToBatchNDParser() : TfliteNodeParser("SpaceToBatchND") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc index 64750d62bb4..b74378dfb0b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.cc @@ -24,7 +24,8 @@ namespace mindspore { namespace lite { STATUS TfliteSpaceToDepthParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSpaceToDepthParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -53,10 +54,8 @@ STATUS TfliteSpaceToDepthParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_SpaceToDepth; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h index 4e6e9fd540b..b3b63708174 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_space_to_depth_parser.h @@ -30,7 +30,8 @@ class TfliteSpaceToDepthParser : public TfliteNodeParser { TfliteSpaceToDepthParser() : TfliteNodeParser("SpaceToDepth") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc index ed3f931de23..87fbf30b63e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.cc @@ -24,7 +24,8 @@ namespace mindspore { namespace lite { STATUS TfliteSparseToDenseParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSparseToDenseParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -46,16 +47,11 @@ STATUS TfliteSparseToDenseParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_SparseToDense; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[3], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h index 32361d19c9d..6f9a56f1ba6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_sparse_to_dense_parser.h @@ -30,7 +30,8 @@ class TfliteSparseToDenseParser : public TfliteNodeParser { TfliteSparseToDenseParser() : TfliteNodeParser("SparseToDense") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc index 351a2cf75b5..499fdd06830 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSplitParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -47,13 +48,13 @@ STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq } auto num_splits = tflite_attr->num_splits; - const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[1]]; + const auto &shape_tensor = tflite_subgraph->tensors[tflite_op->inputs[1]]; if (shape_tensor == nullptr) { MS_LOG(ERROR) << "shape_tensor is null"; return RET_NULL_PTR; } const auto tensor_shape = shape_tensor->shape; - const auto &axis_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; + const auto &axis_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; if (axis_tensor == nullptr) { MS_LOG(ERROR) << "axis_tensor is null"; return RET_NULL_PTR; @@ -81,11 +82,9 @@ STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq op->primitive->value.type = schema::PrimitiveType_Split; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); for (size_t i = 0; i < tflite_op->outputs.size(); i++) { - AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h index d2c85bbdfd9..a430eed7dd0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.h @@ -30,7 +30,8 @@ class TfliteSplitParser : public TfliteNodeParser { TfliteSplitParser() : TfliteNodeParser("Split") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc index dea7b0e9df8..9cbb69e7656 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteSplitVParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSplitVParser"; if (op == nullptr) { @@ -48,19 +49,18 @@ STATUS TfliteSplitVParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni } attr->numberSplit = tflite_attr->num_splits; - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, - attr->sizeSplits)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->sizeSplits)) { MS_LOG(ERROR) << "get spliteV -> sizeSplits failed"; return RET_ERROR; } - const auto &tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; + const auto &tensor = tflite_subgraph->tensors[tflite_op->inputs[0]]; if (tensor == nullptr) { MS_LOG(ERROR) << "tensor_shape is null"; return RET_NULL_PTR; } auto tensor_shape = tensor->shape; - const auto &axis_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[2]]; + const auto &axis_tensor = tflite_subgraph->tensors[tflite_op->inputs[2]]; if (axis_tensor == nullptr) { MS_LOG(ERROR) << "axis_tensor is null"; return RET_NULL_PTR; @@ -78,11 +78,9 @@ STATUS TfliteSplitVParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni op->primitive->value.type = schema::PrimitiveType_Split; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); for (size_t i = 0; i < tflite_op->outputs.size(); i++) { - AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h index 85427ceab61..002a8f7e9bb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_v_parser.h @@ -30,7 +30,8 @@ class TfliteSplitVParser : public TfliteNodeParser { TfliteSplitVParser() : TfliteNodeParser("SplitV") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc index 882f0fd5e46..f127f95d77b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteSqueezeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteSqueezeParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -50,10 +51,8 @@ STATUS TfliteSqueezeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un op->primitive->value.type = schema::PrimitiveType_Squeeze; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h index b486cf1e00e..538a4f42513 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_squeeze_parser.h @@ -30,7 +30,8 @@ class TfliteSqueezeParser : public TfliteNodeParser { TfliteSqueezeParser() : TfliteNodeParser("Squeeze") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc index ad3bd244e8a..f7c5e4f48e0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteStackParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteStackParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -47,18 +48,16 @@ STATUS TfliteStackParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq } attr->axis = tflite_attr->axis; attr->n = tflite_attr->values_count; - attr->isScale.assign(tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.begin(), - tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.end()); + attr->isScale.assign(tflite_subgraph->tensors[tflite_op->inputs[0]]->shape.begin(), + tflite_subgraph->tensors[tflite_op->inputs[0]]->shape.end()); op->primitive->value.type = schema::PrimitiveType_Stack; op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h index b30103b0e76..e67d5f47b5d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_stack_parser.h @@ -30,7 +30,8 @@ class TfliteStackParser : public TfliteNodeParser { TfliteStackParser() : TfliteNodeParser("Stack") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc index 6242d05e2b3..5437bd9c652 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TfliteStridedSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteStridedSliceParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -52,38 +53,34 @@ STATUS TfliteStridedSliceParser::Parse(TfliteTensorsInfo *tensors_info, attr->newAxisMask = tflite_attr->new_axis_mask; attr->shrinkAxisMask = tflite_attr->shrink_axis_mask; - int status = - GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin); + int status = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->begin); if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "stridedSlice -> begin get failed"; return RET_ERROR; } else if (status == RET_OK) { - status = GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->end); + status = GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->end); if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "stridedSlice -> end get failed"; return RET_ERROR; } else if (status == RET_OK) { - status = - GetTfliteData(tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->stride); + status = GetTfliteData(tflite_op->inputs[3], tflite_subgraph->tensors, tflite_model->buffers, attr->stride); if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "stridedSlice -> stride get failed"; return RET_ERROR; } } } - attr->isScale.assign(tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.begin(), - tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.end()); + attr->isScale.assign(tflite_subgraph->tensors[tflite_op->inputs[0]]->shape.begin(), + tflite_subgraph->tensors[tflite_op->inputs[0]]->shape.end()); op->primitive->value.type = schema::PrimitiveType_StridedSlice; op->primitive->value.value = attr.release(); int input_num = status == RET_OK ? 1 : 4; for (int i = 0; i < input_num; ++i) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h index 2fb2e6e3786..36171e20611 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.h @@ -30,7 +30,8 @@ class TfliteStridedSliceParser : public TfliteNodeParser { TfliteStridedSliceParser() : TfliteNodeParser("StridedSlice") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc index 0e39f06b41e..9569d7a838c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TfliteTileParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteTileParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -41,8 +42,7 @@ STATUS TfliteTileParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu return RET_NULL_PTR; } - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, - attr->multiples)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->multiples)) { MS_LOG(ERROR) << "get tile -> multiples failed"; return RET_ERROR; } @@ -54,10 +54,8 @@ STATUS TfliteTileParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu op->primitive->value.type = schema::PrimitiveType_Tile; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h index bb3a801d15d..11d703e2762 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_tile_parser.h @@ -30,7 +30,8 @@ class TfliteTileParser : public TfliteNodeParser { TfliteTileParser() : TfliteNodeParser("Tile") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc index 248745d24c1..a75354df32b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TfliteTopKV2Parser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteTopKV2Parser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -43,7 +44,7 @@ STATUS TfliteTopKV2Parser::Parse(TfliteTensorsInfo *tensors_info, const std::uni attr->sorted = true; std::vector k; - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, k)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, k)) { MS_LOG(ERROR) << "get topKV2 -> k failed"; return RET_ERROR; } @@ -52,11 +53,9 @@ STATUS TfliteTopKV2Parser::Parse(TfliteTensorsInfo *tensors_info, const std::uni op->primitive->value.type = schema::PrimitiveType_TopK; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); for (size_t i = 0; i < tflite_op->outputs.size(); i++) { - AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h index 1ab9a43e185..f9f2c9b83b3 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.h @@ -30,7 +30,8 @@ class TfliteTopKV2Parser : public TfliteNodeParser { TfliteTopKV2Parser() : TfliteNodeParser("TopKV2") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc index 22a5b8975a5..9a915bc9dc0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.cc @@ -22,7 +22,8 @@ namespace mindspore { namespace lite { STATUS TfliteTransposeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteTransposeParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -40,7 +41,7 @@ STATUS TfliteTransposeParser::Parse(TfliteTensorsInfo *tensors_info, return RET_NULL_PTR; } - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->perm)) { + if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->perm)) { MS_LOG(ERROR) << "get transpose -> perm failed"; return RET_ERROR; } @@ -49,12 +50,9 @@ STATUS TfliteTransposeParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_Transpose; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_KHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h index 4fe20528d1b..6babc266e1d 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_transpose_parser.h @@ -30,7 +30,8 @@ class TfliteTransposeParser : public TfliteNodeParser { TfliteTransposeParser() : TfliteNodeParser("Transpose") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc index 4c8cf480eef..68683e22c43 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TfliteUniqueParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteUniqueParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -51,11 +52,9 @@ STATUS TfliteUniqueParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni op->primitive->value.type = schema::PrimitiveType_Unique; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); for (size_t i = 0; i < tflite_op->outputs.size(); i++) { - AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h index ba86414d22f..4e7c98ac5bc 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unique_parser.h @@ -30,7 +30,8 @@ class TfliteUniqueParser : public TfliteNodeParser { TfliteUniqueParser() : TfliteNodeParser("Unique") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc index 11247e47208..59d1da4f45c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TfliteUnstackParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "paser TfliteUnstackParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -52,11 +53,9 @@ STATUS TfliteUnstackParser::Parse(TfliteTensorsInfo *tensors_info, const std::un op->primitive->value.type = schema::PrimitiveType_Unstack; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); for (size_t i = 0; i < tflite_op->outputs.size(); i++) { - AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h index 7d82dcdb942..873121b31bc 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_unstack_parser.h @@ -30,7 +30,8 @@ class TfliteUnstackParser : public TfliteNodeParser { TfliteUnstackParser() : TfliteNodeParser("Unstack") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc index 63347bee391..f8ed0aa10a6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc @@ -123,6 +123,7 @@ std::map tfMsOpTypeMap{ {tflite::BuiltinOperator_HASHTABLE_LOOKUP, "HashtableLookup"}, {tflite::BuiltinOperator_LSH_PROJECTION, "LshProjection"}, {tflite::BuiltinOperator_SKIP_GRAM, "SKipGram"}, + {tflite::BuiltinOperator_WHILE, "While"}, }; std::map tfMsActivationFunctionMap{ diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc index 7afea698760..e74977d8fd9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace lite { STATUS TfliteWhereParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteWhereParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -41,8 +42,7 @@ STATUS TfliteWhereParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq return RET_NULL_PTR; } - if (GetTfliteData(tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, - attr->condition)) { + if (GetTfliteData(tflite_op->inputs[0], tflite_subgraph->tensors, tflite_model->buffers, attr->condition)) { MS_LOG(ERROR) << "get where -> condition failed"; return RET_ERROR; } @@ -51,11 +51,9 @@ STATUS TfliteWhereParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); i++) { - AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); } - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h index 6bdfbbe9f94..dcb97bfedcb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_where_parser.h @@ -30,7 +30,8 @@ class TfliteWhereParser : public TfliteNodeParser { TfliteWhereParser() : TfliteNodeParser("Where") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.cc new file mode 100644 index 00000000000..9b01d346565 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * distributed under the License is distributed on an AS + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/tflite/tflite_while_parser.h" +#include +#include + +namespace mindspore { +namespace lite { +STATUS TfliteWhileParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { + MS_LOG(DEBUG) << "parse TfliteWhileParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + const auto &tflite_attr = tflite_op->builtin_options.AsWhileOptions(); + if (tflite_attr == nullptr) { + MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed"; + return RET_NULL_PTR; + } + + attr->condSubgraphIndex = tflite_attr->cond_subgraph_index; + attr->bodySubgraphIndex = tflite_attr->body_subgraph_index; + + op->primitive->value.type = schema::PrimitiveType_While; + op->primitive->value.value = attr.release(); + + for (size_t i = 0; i < tflite_op->inputs.size(); i++) { + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + } + for (size_t i = 0; i < tflite_op->outputs.size(); i++) { + AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + } + return RET_OK; +} + +TfliteNodeRegister g_tfliteWhileParser("While", new TfliteWhileParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.h new file mode 100644 index 00000000000..0cd2621ce4c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_while_parser.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_WHILE_PARSER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_WHILE_PARSER_H + +#include +#include +#include +#include "tools/converter/parser/tflite/tflite_node_parser.h" +#include "tools/converter/parser/tflite/tflite_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class TfliteWhileParser : public TfliteNodeParser { + public: + TfliteWhileParser() : TfliteNodeParser("While") {} + + STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_WHERE_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc index f0865edfa70..eac52201ee2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.cc @@ -24,7 +24,8 @@ namespace mindspore { namespace lite { STATUS TfliteZerosLikeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) { + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) { MS_LOG(DEBUG) << "parse TfliteZerosLikeParser"; if (op == nullptr) { MS_LOG(ERROR) << "op is null"; @@ -45,10 +46,8 @@ STATUS TfliteZerosLikeParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_ZerosLike; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); - AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); + AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC); return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h index 045a66ad859..9a412ee20db 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_zeros_like_parser.h @@ -30,7 +30,8 @@ class TfliteZerosLikeParser : public TfliteNodeParser { TfliteZerosLikeParser() : TfliteNodeParser("ZerosLike") {} STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_model, schema::CNodeT *op) override; + const std::unique_ptr &tflite_model, + const std::unique_ptr &tflite_subgraph, schema::CNodeT *op) override; }; } // namespace lite } // namespace mindspore