forked from mindspore-Ecosystem/mindspore
!7757 [MS][LITE][DEVELOP] add while op parser
Merge pull request !7757 from mengyuanli/while
This commit is contained in:
commit
d378044f37
|
@ -226,6 +226,7 @@ union PrimitiveType {
|
|||
InstanceNorm,
|
||||
Identity,
|
||||
LayerNorm,
|
||||
While,
|
||||
}
|
||||
|
||||
enum QuantType: int {
|
||||
|
|
|
@ -1103,3 +1103,8 @@ table LayerNorm {
|
|||
elementwiseAffine : bool;
|
||||
}
|
||||
|
||||
table While {
|
||||
condSubgraphIndex : int;
|
||||
bodySubgraphIndex : int;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/while.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
typedef struct WhileParemeter {
|
||||
OpParameter op_parameter_;
|
||||
int body_subgraph_index;
|
||||
int cond_subgraph_index;
|
||||
} WhileParemeter;
|
||||
|
||||
OpParameter *PopulateWhileParemeter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
WhileParemeter *while_paremeter = reinterpret_cast<WhileParemeter *>(malloc(sizeof(WhileParemeter)));
|
||||
if (while_paremeter == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc WhileParemeter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(while_paremeter, 0, sizeof(WhileParemeter));
|
||||
auto param = reinterpret_cast<mindspore::lite::While *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
while_paremeter->op_parameter_.type_ = primitive->Type();
|
||||
while_paremeter->body_subgraph_index = param->GetBodySubgraphIndex();
|
||||
while_paremeter->cond_subgraph_index = param->GetCondSubgraphIndex();
|
||||
return reinterpret_cast<OpParameter *>(while_paremeter);
|
||||
}
|
||||
Registry WhileParemeterRegistry(schema::PrimitiveType_While, PopulateWhileParemeter);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -144,6 +144,7 @@
|
|||
#include "src/ops/mfcc.h"
|
||||
#include "src/ops/identity.h"
|
||||
#include "src/ops/instance_norm.h"
|
||||
#include "src/ops/while.h"
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
#include "src/ops/neg_grad.h"
|
||||
|
@ -499,6 +500,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
|
|||
return NewPrimitiveC<Maximum>(prim, inputs, quantType);
|
||||
} else if (op_type == "Split") {
|
||||
return NewPrimitiveC<Split>(prim, inputs, quantType);
|
||||
} else if (op_type == "While") {
|
||||
return NewPrimitiveC<While>(prim, inputs, quantType);
|
||||
} else if (op_type == "OneHot") {
|
||||
return NewPrimitiveC<OneHot>(prim, inputs, quantType);
|
||||
|
||||
|
@ -793,6 +796,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
|
|||
return new Mfcc(primitive);
|
||||
case schema::PrimitiveType_InstanceNorm:
|
||||
return new InstanceNorm(primitive);
|
||||
case schema::PrimitiveType_While:
|
||||
return new While(primitive);
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
case schema::PrimitiveType_ActivationGrad:
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/while.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
||||
void While::SetCondSubgraphIndex(const int cond_subgraph_index) {
|
||||
this->primitive_->value.AsWhile()->condSubgraphIndex = cond_subgraph_index;
|
||||
}
|
||||
void While::SetBodySubgraphIndex(const int body_subgraph_index) {
|
||||
this->primitive_->value.AsWhile()->bodySubgraphIndex = body_subgraph_index;
|
||||
}
|
||||
|
||||
int While::GetCondSubgraphIndex() const { return this->primitive_->value.AsWhile()->condSubgraphIndex; }
|
||||
int While::GetBodySubgraphIndex() const { return this->primitive_->value.AsWhile()->bodySubgraphIndex; }
|
||||
|
||||
int While::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_While;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_While) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::WhileT();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->bodySubgraphIndex = GetValue<bool>(prim.GetAttr("body_subgraph_index"));
|
||||
attr->condSubgraphIndex = GetValue<bool>(prim.GetAttr("cond_subgraph_index"));
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive value is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
int While::GetCondSubgraphIndex() const { return this->primitive_->value_as_While()->condSubgraphIndex(); }
|
||||
int While::GetBodySubgraphIndex() const { return this->primitive_->value_as_While()->bodySubgraphIndex(); }
|
||||
|
||||
int While::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_While();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_While return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto cond_subgraph_index = attr->condSubgraphIndex();
|
||||
auto body_subgraph_index = attr->bodySubgraphIndex();
|
||||
auto val_offset = schema::CreateWhile(*fbb, body_subgraph_index, cond_subgraph_index);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_While, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *WhileCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<While>(primitive); }
|
||||
Registry WhileRegistry(schema::PrimitiveType_While, WhileCreator);
|
||||
|
||||
#endif
|
||||
|
||||
int While::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
if (inputs_.size() != outputs_.size()) {
|
||||
MS_LOG(ERROR) << "The number of inputs and outputs varies";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (size_t i = 0; i < inputs_.size(); i++) {
|
||||
outputs_[i]->set_data_type(inputs_[i]->data_type());
|
||||
outputs_[i]->SetFormat(inputs_[i]->GetFormat());
|
||||
outputs_[i]->set_shape(inputs_[i]->shape());
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_WHILE_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_WHILE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class While : public PrimitiveC {
|
||||
public:
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(While, PrimitiveC);
|
||||
While() = default;
|
||||
explicit While(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
void SetCondSubgraphIndex(const int cond_subgraph_index);
|
||||
void SetBodySubgraphIndex(const int body_subgraph_index);
|
||||
|
||||
#else
|
||||
While() = default;
|
||||
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
int GetCondSubgraphIndex() const;
|
||||
int GetBodySubgraphIndex() const;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_WHERE_H_
|
|
@ -24,7 +24,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -75,10 +76,8 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteActivationParser : public TfliteNodeParser {
|
|||
TfliteActivationParser() : TfliteNodeParser("node_name") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class TfliteReluParser : public TfliteActivationParser {
|
||||
|
|
|
@ -23,7 +23,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteAddNParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -41,16 +42,14 @@ STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
attr->N = tflite_model->subgraphs[0]->tensors.size() - 1;
|
||||
attr->N = tflite_subgraph->tensors.size() - 1;
|
||||
op->primitive->value.type = schema::PrimitiveType_AddN;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteAddNParser : public TfliteNodeParser {
|
|||
TfliteAddNParser() : TfliteNodeParser("AddN") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteArgmaxParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -47,7 +48,7 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
|
||||
// get axis attr
|
||||
auto axis_idx = tflite_op->inputs[1];
|
||||
auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer;
|
||||
auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer;
|
||||
auto &buf_data = tflite_model->buffers[buffer_idx];
|
||||
if (buf_data == nullptr) {
|
||||
MS_LOG(ERROR) << "the buf data is null";
|
||||
|
@ -63,10 +64,8 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
op->primitive->value.type = schema::PrimitiveType_ArgMax;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteArgmaxParser : public TfliteNodeParser {
|
|||
TfliteArgmaxParser() : TfliteNodeParser("Argmax") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteArgminParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -47,7 +48,7 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
|
||||
// get axis attr
|
||||
auto axis_idx = tflite_op->inputs[1];
|
||||
auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer;
|
||||
auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer;
|
||||
auto &buf_data = tflite_model->buffers[buffer_idx];
|
||||
if (buf_data == nullptr) {
|
||||
MS_LOG(ERROR) << "the buf data is null";
|
||||
|
@ -63,10 +64,8 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
op->primitive->value.type = schema::PrimitiveType_ArgMin;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteArgminParser : public TfliteNodeParser {
|
|||
TfliteArgminParser() : TfliteNodeParser("Argmin") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,7 +24,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -168,17 +169,16 @@ STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
|
||||
// set input
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -305,16 +305,15 @@ STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -385,11 +384,9 @@ STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
}
|
||||
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser {
|
|||
TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class TfliteAddParser : public TfliteDoubleInputOpParser {
|
||||
|
@ -93,7 +94,8 @@ class TfliteSingleInputOpParser : public TfliteNodeParser {
|
|||
TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class TfliteAbsParser : public TfliteSingleInputOpParser {
|
||||
|
@ -161,7 +163,8 @@ class TfliteCompareOpParser : public TfliteNodeParser {
|
|||
TfliteCompareOpParser() : TfliteNodeParser("node_name") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class TfliteEqualParser : public TfliteCompareOpParser {
|
||||
|
|
|
@ -25,7 +25,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -51,12 +52,11 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers,
|
||||
attr->blockShape)) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) {
|
||||
MS_LOG(ERROR) << "get batchToSpace -> blockShape failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->crops)) {
|
||||
if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->crops)) {
|
||||
MS_LOG(ERROR) << "get batchToSpace -> crops failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -64,10 +64,8 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.type = schema::PrimitiveType_BatchToSpace;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser {
|
|||
TfliteBatchToSpaceParser() : TfliteNodeParser("BatchToSpace") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser {
|
||||
|
|
|
@ -24,7 +24,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteBroadcastToParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -42,8 +43,7 @@ STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers,
|
||||
attr->dst_shape)) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dst_shape)) {
|
||||
MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -51,10 +51,8 @@ STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteBroadcastToParser : public TfliteNodeParser {
|
|||
TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteCastParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -40,13 +41,13 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]];
|
||||
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
|
||||
if (in_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->srcT = GetTfliteDataType(in_tensor->type);
|
||||
const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]];
|
||||
const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]];
|
||||
if (out_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -56,10 +57,8 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
op->primitive->value.type = schema::PrimitiveType_Cast;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteCastParser : public TfliteNodeParser {
|
|||
TfliteCastParser() : TfliteNodeParser("Cast") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteConcatParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -52,11 +53,9 @@ STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
op->primitive->value.value = attr.release();
|
||||
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteConcatParser : public TfliteNodeParser {
|
|||
TfliteConcatParser() : TfliteNodeParser("Concat") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteConvParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -57,7 +58,7 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
|
||||
// get the conv op weight tensor
|
||||
auto weight_index = tflite_op->inputs[1];
|
||||
const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index];
|
||||
const auto &weight_tensor = tflite_subgraph->tensors[weight_index];
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the weight tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -70,7 +71,7 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
|
||||
// calculate pad params
|
||||
auto data_index = tflite_op->inputs[0];
|
||||
const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index];
|
||||
const auto &data_tensor = tflite_subgraph->tensors[data_index];
|
||||
std::vector<int> params;
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms);
|
||||
|
@ -87,14 +88,10 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_KHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteConvParser : public TfliteNodeParser {
|
|||
TfliteConvParser() : TfliteNodeParser("Conv2D") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -139,14 +139,15 @@ STATUS TfliteCustomParser::ExtractFeatures(const std::vector<uint8_t> &custom_at
|
|||
|
||||
STATUS TfliteCustomParser::Rfft(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph) {
|
||||
std::unique_ptr<schema::RfftT> attr = std::make_unique<schema::RfftT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<int> fft_length;
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, fft_length)) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, fft_length)) {
|
||||
MS_LOG(ERROR) << "rfft -> fftLength get failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -181,7 +182,8 @@ STATUS TfliteCustomParser::FftImag(const std::vector<uint8_t> &custom_attr, sche
|
|||
}
|
||||
|
||||
STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteCustomParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -209,7 +211,7 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
} else if (custom_type == "Mfcc") {
|
||||
status = Mfcc(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "FlexRFFT") {
|
||||
status = Rfft(custom_attr, op, tflite_op, tflite_model);
|
||||
status = Rfft(custom_attr, op, tflite_op, tflite_model, tflite_subgraph);
|
||||
} else if (custom_type == "FlexReal") {
|
||||
status = FftReal(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "FlexImag") {
|
||||
|
@ -222,12 +224,10 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
return status;
|
||||
}
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); ++i) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
for (size_t i = 0; i < tflite_op->outputs.size(); ++i) {
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteCustomParser : public TfliteNodeParser {
|
|||
TfliteCustomParser() : TfliteNodeParser("Custom") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
|
||||
STATUS DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op);
|
||||
|
@ -51,7 +52,8 @@ class TfliteCustomParser : public TfliteNodeParser {
|
|||
const std::unique_ptr<tflite::OperatorT> &tflite_op);
|
||||
|
||||
STATUS Rfft(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op, const std::unique_ptr<tflite::ModelT> &tflite_model);
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op, const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph);
|
||||
|
||||
STATUS FftReal(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op);
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -58,7 +59,7 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
|
||||
// get the conv op weight tensor
|
||||
auto weight_index = tflite_op->inputs[1];
|
||||
const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index];
|
||||
const auto &weight_tensor = tflite_subgraph->tensors[weight_index];
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the weight tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -71,7 +72,7 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
|
||||
// calculate pad params
|
||||
auto data_index = tflite_op->inputs[2];
|
||||
const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index];
|
||||
const auto &data_tensor = tflite_subgraph->tensors[data_index];
|
||||
std::vector<int> params;
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms);
|
||||
|
@ -88,12 +89,9 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_KHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteDeConvParser : public TfliteNodeParser {
|
|||
TfliteDeConvParser() : TfliteNodeParser("DeConv2D") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,7 +24,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser";
|
||||
|
||||
if (op == nullptr) {
|
||||
|
@ -54,10 +55,8 @@ STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.type = schema::PrimitiveType_DepthToSpace;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser {
|
|||
TfliteDepthToSpaceParser() : TfliteNodeParser("DepthToSpace") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,9 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -58,7 +60,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
|
||||
// get the data tensor
|
||||
auto data_index = tflite_op->inputs[1];
|
||||
const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index];
|
||||
const auto &data_tensor = tflite_subgraph->tensors[data_index];
|
||||
if (data_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the data tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -68,7 +70,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
|
||||
// get the weight tensor
|
||||
auto weight_index = tflite_op->inputs[1];
|
||||
const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index];
|
||||
const auto &weight_tensor = tflite_subgraph->tensors[weight_index];
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the weight tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -94,14 +96,10 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_KHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser {
|
|||
TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteDequantizeNParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -34,12 +35,12 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]];
|
||||
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
|
||||
if (in_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "input tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]];
|
||||
const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]];
|
||||
if (out_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "output tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -68,10 +69,8 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.type = schema::PrimitiveType_Cast;
|
||||
}
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -29,7 +29,8 @@ class TfliteDequantizeParser : public TfliteNodeParser {
|
|||
TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteExpandDimsParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -41,17 +42,15 @@ STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
std::vector<int> dims;
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, dims)) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, dims)) {
|
||||
MS_LOG(ERROR) << "get expand_dims -> dim failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->dim = dims[0];
|
||||
op->primitive->value.type = schema::PrimitiveType_ExpandDims;
|
||||
op->primitive->value.value = attr.release();
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser());
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteExpandDimsParser : public TfliteNodeParser {
|
|||
TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteFillParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -41,7 +42,7 @@ STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
}
|
||||
|
||||
if (tflite_op->inputs.size() > 1) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->dims)) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dims)) {
|
||||
MS_LOG(ERROR) << "get fill -> dims failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -50,10 +51,8 @@ STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
op->primitive->value.type = schema::PrimitiveType_Fill;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteFillParser : public TfliteNodeParser {
|
|||
TfliteFillParser() : TfliteNodeParser("Fill") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,9 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -57,16 +59,12 @@ STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.type = schema::PrimitiveType_FullConnection;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_KHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_KHWC);
|
||||
if (hasBias) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteFullyConnectedParser : public TfliteNodeParser {
|
|||
TfliteFullyConnectedParser() : TfliteNodeParser("FullyConnected") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class TfliteFakeQuantParser : public TfliteFullyConnectedParser {
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteGatherNdParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -46,11 +47,9 @@ STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::u
|
|||
op->primitive->value.value = attr.release();
|
||||
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteGatherNdParser : public TfliteNodeParser {
|
|||
TfliteGatherNdParser() : TfliteNodeParser("GatherND") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteGatherParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -52,11 +53,9 @@ STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
op->primitive->value.value = attr.release();
|
||||
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteGatherParser : public TfliteNodeParser {
|
|||
TfliteGatherParser() : TfliteNodeParser("Gather") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,9 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteHashtableLookupParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -44,12 +46,10 @@ STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.type = schema::PrimitiveType_HashtableLookup;
|
||||
op->primitive->value.value = attr.release();
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); ++i) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
for (size_t i = 0; i < tflite_op->outputs.size(); ++i) {
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteHashtableLookupParser : public TfliteNodeParser {
|
|||
TfliteHashtableLookupParser() : TfliteNodeParser("HashtableLookup") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteL2NormParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -49,10 +50,8 @@ STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
op->primitive->value.value = attr.release();
|
||||
|
||||
// set input and output
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteL2NormParser : public TfliteNodeParser {
|
|||
TfliteL2NormParser() : TfliteNodeParser("L2_NORMALIZATION") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -67,11 +68,9 @@ STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
}
|
||||
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteLogicalParser : public TfliteNodeParser {
|
|||
TfliteLogicalParser() : TfliteNodeParser("node_name") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class TfliteLogicalAndParser : public TfliteLogicalParser {
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLRNParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -53,10 +54,8 @@ STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique
|
|||
op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteLRNParser : public TfliteNodeParser {
|
|||
TfliteLRNParser() : TfliteNodeParser("LocalResponseNorm") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLshProjectionParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -56,11 +57,9 @@ STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.value = attr.release();
|
||||
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); ++i) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteLshProjectionParser : public TfliteNodeParser {
|
|||
TfliteLshProjectionParser() : TfliteNodeParser("LshProjection") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -116,7 +116,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
|
|||
continue;
|
||||
}
|
||||
if (status == RET_OK) {
|
||||
status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get());
|
||||
status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, tflite_subgraph, op.get());
|
||||
if (status != RET_OK) {
|
||||
if (status == RET_NOT_FIND_OP) {
|
||||
op_type =
|
||||
|
@ -337,18 +337,10 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph)
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) {
|
||||
// load graph
|
||||
auto tflite_model = ReadTfliteModel(model_file.c_str());
|
||||
if (tflite_model == nullptr) {
|
||||
MS_LOG(ERROR) << "read tflite model failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (tflite_model->subgraphs.size() != 1) {
|
||||
MS_LOG(ERROR) << "read tflite model subgraphs failed";
|
||||
std::unique_ptr<schema::MetaGraphT> TfliteModelParser::ConstructMainGraph(
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, const QuantType &quant_type) {
|
||||
if (tflite_model->subgraphs.size() < 1) {
|
||||
MS_LOG(ERROR) << "read tflite model main subgraphs failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -394,7 +386,28 @@ schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
return meta_graph.release();
|
||||
return meta_graph;
|
||||
}
|
||||
|
||||
schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) {
|
||||
// load graph
|
||||
auto tflite_model = ReadTfliteModel(model_file.c_str());
|
||||
if (tflite_model == nullptr) {
|
||||
MS_LOG(ERROR) << "read tflite model failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// construct main_meta_graph
|
||||
auto main_meta_graph = ConstructMainGraph(tflite_model, quant_type);
|
||||
if (main_meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "ConstructMainGraph failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return main_meta_graph.release();
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -64,6 +64,9 @@ class TfliteModelParser : public ModelParser {
|
|||
|
||||
STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph);
|
||||
|
||||
std::unique_ptr<schema::MetaGraphT> ConstructMainGraph(const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const QuantType &quant_type);
|
||||
|
||||
private:
|
||||
TfliteTensorsInfo tensorsInfo;
|
||||
std::vector<schema::TensorT *> tensors;
|
||||
|
|
|
@ -39,7 +39,8 @@ class TfliteNodeParser {
|
|||
virtual ~TfliteNodeParser() = default;
|
||||
|
||||
virtual STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) = 0;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) = 0;
|
||||
|
||||
void AddOpInput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, schema::Format format) {
|
||||
int new_idx = tensors_info->tensorsId.size();
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteOneHotParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -46,7 +47,7 @@ STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
auto axis = tflite_attr->axis;
|
||||
const auto &tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]];
|
||||
const auto &tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -57,11 +58,9 @@ STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
op->primitive->value.value = attr.release();
|
||||
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteOneHotParser : public TfliteNodeParser {
|
|||
TfliteOneHotParser() : TfliteNodeParser("OneHot") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TflitePadParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -51,8 +52,7 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique
|
|||
}
|
||||
attr->paddingMode = schema::PaddingMode_CONSTANT;
|
||||
attr->constantValue = 0.0f;
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers,
|
||||
attr->paddings)) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->paddings)) {
|
||||
MS_LOG(ERROR) << "get pad -> paddings failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -81,14 +81,11 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique
|
|||
op->primitive->value.type = schema::PrimitiveType_Pad;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
if (std::strcmp(node_name, "MirrorPad") == 0) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TflitePadParser : public TfliteNodeParser {
|
|||
TflitePadParser() : TfliteNodeParser("Pad") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -69,7 +70,7 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
|
||||
// calculate pad params
|
||||
auto data_index = tflite_op->inputs[0];
|
||||
const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index];
|
||||
const auto &data_tensor = tflite_subgraph->tensors[data_index];
|
||||
std::vector<int> params;
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms);
|
||||
|
@ -86,10 +87,8 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
op->primitive->value.type = schema::PrimitiveType_Pooling;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TflitePoolingParser : public TfliteNodeParser {
|
|||
TflitePoolingParser() : TfliteNodeParser("node_name") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class TfliteMeanPoolingParser : public TflitePoolingParser {
|
||||
|
|
|
@ -23,7 +23,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TflitePReLUParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -44,12 +45,9 @@ STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
op->primitive->value.type = schema::PrimitiveType_PReLU;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TflitePReLUParser : public TfliteNodeParser {
|
|||
TflitePReLUParser() : TfliteNodeParser("PRELU") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,7 +21,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteQuantizeNParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -33,12 +34,12 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]];
|
||||
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
|
||||
if (in_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "input tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]];
|
||||
const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]];
|
||||
if (out_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "output tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -67,10 +68,8 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u
|
|||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -29,7 +29,8 @@ class TfliteQuantizeParser : public TfliteNodeParser {
|
|||
TfliteQuantizeParser() : TfliteNodeParser("Quantize") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteRangeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -43,12 +44,12 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
attr->dType = 0;
|
||||
std::vector<int> limit;
|
||||
std::vector<int> delta;
|
||||
int status = GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, limit);
|
||||
int status = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, limit);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "range -> limit get failed";
|
||||
return RET_ERROR;
|
||||
} else if (status == RET_OK) {
|
||||
status = GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, delta);
|
||||
status = GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, delta);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "stridedSlice -> end get failed";
|
||||
return RET_ERROR;
|
||||
|
@ -63,11 +64,9 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
|
||||
int input_num = status == RET_OK ? 1 : 3;
|
||||
for (int i = 0; i < input_num; ++i) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteRangeParser : public TfliteNodeParser {
|
|||
TfliteRangeParser() : TfliteNodeParser("Range") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteRankParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -43,10 +44,8 @@ STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
op->primitive->value.type = schema::PrimitiveType_Rank;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteRankParser : public TfliteNodeParser {
|
|||
TfliteRankParser() : TfliteNodeParser("Rank") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -72,7 +73,7 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axes)) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axes)) {
|
||||
MS_LOG(ERROR) << "get reduce -> axes failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -80,10 +81,8 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
op->primitive->value.type = schema::PrimitiveType_Reduce;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteReduceParser : public TfliteNodeParser {
|
|||
TfliteReduceParser() : TfliteNodeParser("node_name") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class TfliteReduceMaxParser : public TfliteReduceParser {
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteReshapeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -47,7 +48,7 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto shape_tensor_index = tflite_op->inputs[1];
|
||||
const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[shape_tensor_index];
|
||||
const auto &shape_tensor = tflite_subgraph->tensors[shape_tensor_index];
|
||||
if (shape_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "shape_tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -58,8 +59,7 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
if (!buf_data->data.empty()) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers,
|
||||
attr->shape)) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->shape)) {
|
||||
MS_LOG(ERROR) << "get reshape -> shape failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -76,11 +76,9 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
op->primitive->value.value = attr.release();
|
||||
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteReshapeParser : public TfliteNodeParser {
|
|||
TfliteReshapeParser() : TfliteNodeParser("Reshape") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -87,7 +88,7 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
attr->preserveAspectRatio = false;
|
||||
|
||||
auto tfliteResizeTensorIndex = tflite_op->inputs[1];
|
||||
const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[tfliteResizeTensorIndex];
|
||||
const auto &shape_tensor = tflite_subgraph->tensors[tfliteResizeTensorIndex];
|
||||
if (shape_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "shape_tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -109,14 +110,11 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
op->primitive->value.type = schema::PrimitiveType_Resize;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
if (buffData == nullptr) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteResizeParser : public TfliteNodeParser {
|
|||
TfliteResizeParser() : TfliteNodeParser("node_name") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class TfliteResizeBilinearParser : public TfliteResizeParser {
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteReverseParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -40,7 +41,7 @@ STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axis)) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axis)) {
|
||||
MS_LOG(ERROR) << "get reverse -> axis failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -48,10 +49,8 @@ STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
op->primitive->value.type = schema::PrimitiveType_Reverse;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteReverseParser : public TfliteNodeParser {
|
|||
TfliteReverseParser() : TfliteNodeParser("reverse") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,7 +24,9 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteReverseSequenceParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -53,12 +55,9 @@ STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.type = schema::PrimitiveType_ReverseSequence;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteReverseSequenceParser : public TfliteNodeParser {
|
|||
TfliteReverseSequenceParser() : TfliteNodeParser("ReverseSequence") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,7 +24,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteScatterNdParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -52,14 +53,10 @@ STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
|
||||
// in tflite, kIndices = 0, kUpdates = 1, kShape = 2
|
||||
// in mslite, kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2;
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteScatterNdParser : public TfliteNodeParser {
|
|||
TfliteScatterNdParser() : TfliteNodeParser("ScatterNd") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteShapeParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -43,10 +44,8 @@ STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
op->primitive->value.type = schema::PrimitiveType_Shape;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteShapeParser : public TfliteNodeParser {
|
|||
TfliteShapeParser() : TfliteNodeParser("Shape") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSkipGramParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -52,10 +53,8 @@ STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::u
|
|||
op->primitive->value.type = schema::PrimitiveType_SkipGram;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteSkipGramParser : public TfliteNodeParser {
|
|||
TfliteSkipGramParser() : TfliteNodeParser("SkipGram") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSliceParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -42,11 +43,11 @@ STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
|
||||
attr->format = schema::Format::Format_NHWC;
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin)) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->begin)) {
|
||||
MS_LOG(ERROR) << "get slice -> begin failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->size)) {
|
||||
if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->size)) {
|
||||
MS_LOG(ERROR) << "get slice -> size failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -59,10 +60,8 @@ STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
op->primitive->value.type = schema::PrimitiveType_Slice;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteSliceParser : public TfliteNodeParser {
|
|||
TfliteSliceParser() : TfliteNodeParser("Slice") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteSoftmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSoftmaxParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -45,10 +46,8 @@ STATUS TfliteSoftmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
op->primitive->value.type = schema::PrimitiveType_SoftMax;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteSoftmaxParser : public TfliteNodeParser {
|
|||
TfliteSoftmaxParser() : TfliteNodeParser("Softmax") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,7 +24,9 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteSpaceToBatchNDParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSpaceToBatchNDParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -42,12 +44,11 @@ STATUS TfliteSpaceToBatchNDParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers,
|
||||
attr->blockShape)) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) {
|
||||
MS_LOG(ERROR) << "get spaceToBatchND -> blockShape failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->paddings)) {
|
||||
if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->paddings)) {
|
||||
MS_LOG(ERROR) << "get spaceToBatchND -> paddings failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -55,10 +56,8 @@ STATUS TfliteSpaceToBatchNDParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.type = schema::PrimitiveType_SpaceToBatchND;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteSpaceToBatchNDParser : public TfliteNodeParser {
|
|||
TfliteSpaceToBatchNDParser() : TfliteNodeParser("SpaceToBatchND") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,7 +24,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteSpaceToDepthParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSpaceToDepthParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -53,10 +54,8 @@ STATUS TfliteSpaceToDepthParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.type = schema::PrimitiveType_SpaceToDepth;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteSpaceToDepthParser : public TfliteNodeParser {
|
|||
TfliteSpaceToDepthParser() : TfliteNodeParser("SpaceToDepth") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,7 +24,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
STATUS TfliteSparseToDenseParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSparseToDenseParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -46,16 +47,11 @@ STATUS TfliteSparseToDenseParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
op->primitive->value.type = schema::PrimitiveType_SparseToDense;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[3], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class TfliteSparseToDenseParser : public TfliteNodeParser {
|
|||
TfliteSparseToDenseParser() : TfliteNodeParser("SparseToDense") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSplitParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
|
@ -47,13 +48,13 @@ STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
}
|
||||
auto num_splits = tflite_attr->num_splits;
|
||||
|
||||
const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[1]];
|
||||
const auto &shape_tensor = tflite_subgraph->tensors[tflite_op->inputs[1]];
|
||||
if (shape_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "shape_tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
const auto tensor_shape = shape_tensor->shape;
|
||||
const auto &axis_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]];
|
||||
const auto &axis_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
|
||||
if (axis_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "axis_tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -81,11 +82,9 @@ STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
op->primitive->value.type = schema::PrimitiveType_Split;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
for (size_t i = 0; i < tflite_op->outputs.size(); i++) {
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(),
|
||||
schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue