!8770 tflite parser supported to anf

From: @cjh9368
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-24 09:32:12 +08:00 committed by Gitee
commit f9e4af259a
141 changed files with 3053 additions and 925 deletions

View File

@ -47,13 +47,17 @@ using TensorPtr = std::shared_ptr<mindspore::tensor::Tensor>;
constexpr int kAnfPopulaterInputNumOne = 1;
constexpr int kAnfPopulaterInputNumTwo = 2;
constexpr int kAnfPopulaterInputNumThree = 3;
static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU", schema::ActivationType_RELU},
{"ReLU6", schema::ActivationType_RELU6},
{"Sigmoid", schema::ActivationType_SIGMOID},
{"HSwish", schema::ActivationType_HSWISH},
{"HSigmoid", schema::ActivationType_HSIGMOID}};
static std::map<std::string, schema::ActivationType> kActivationTypeMap{
{"ReLU", schema::ActivationType_RELU},
{"ReLU6", schema::ActivationType_RELU6},
{"Sigmoid", schema::ActivationType_SIGMOID},
{"HSwish", schema::ActivationType_HSWISH},
{"HSigmoid", schema::ActivationType_HSIGMOID},
{"Swish", schema::ActivationType_SWISH},
{"LeakyRelu", schema::ActivationType_LEAKY_RELU},
{"Tanh", schema::ActivationType_TANH},
{"Logistic", schema::ActivationType_SIGMOID}};
std::vector<int> CastToInt(const ValuePtr value, bool is_vector);
class PrimitiveC : public mindspore::Primitive {
public:
// Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().

View File

@ -104,8 +104,8 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
if (inputs_.size() != kSplitInputNum) {
MS_LOG(ERROR) << "inputs number is not equal to " << kSplitInputNum;
if (inputs_.size() < kSplitInputNum) {
MS_LOG(ERROR) << "inputs number is less to " << kSplitInputNum;
return RET_ERROR;
}
auto output = outputs_.front();

View File

@ -194,6 +194,8 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc
${LITE_DIR}/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc

View File

@ -135,6 +135,6 @@ mtk_convert_model.tflite
mtk_model_face_dress_fp16.tflite
smartreply.tflite
mindspore_text_classification_tflite.tflite
ml_location.tflite
# ml_location.tflite
ml_text_correction.tflite
ml_pic_shopping.tflite

View File

@ -49,6 +49,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc
../optimizer/graph/clip_convert_activation_pass.cc
../optimizer/graph/group_depthwise_op_convert_pass.cc
../optimizer/graph/tflite_inputs_order_exchange_pass.cc
../optimizer/graph/unused_cast_node_remove_pass.cc
../optimizer/graph/unused_transpose_node_remove_pass.cc
../optimizer/graph/identity_remove_pass.cc

View File

@ -25,7 +25,6 @@
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
#include "tools/optimizer/fusion/layer_norm_fusion.h"
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
@ -34,6 +33,8 @@
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include "tools/optimizer/graph/group_depthwise_op_convert_pass.h"
#include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
#include "tools/optimizer/graph/infershape_pass.h"
@ -43,8 +44,7 @@
#include "tools/converter/quantizer/weight_quantizer.h"
using std::string;
namespace mindspore {
namespace lite {
namespace mindspore::lite {
AnfTransform::AnfTransform() = default;
AnfTransform::~AnfTransform() = default;
@ -65,7 +65,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
cf_pm->AddPass(std::make_shared<opt::ConstFoldPass>());
// for now - trainning is not supporting fuse operations
if (config != nullptr && !config->trainModel) {
if (!config->trainModel) {
// remove quantdtype when awaretraining
pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
@ -119,6 +119,10 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
}
pm->AddPass(std::make_shared<opt::ConvConvFusion>());
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
if (config->fmk == lite::converter::FmkType_TFLITE) {
convert_pm->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>());
convert_pm->AddPass(std::make_shared<opt::TfliteInputsOrderExchangePass>());
}
optimizer->AddPassManager(cf_pm);
optimizer->AddPassManager(convert_pm);
optimizer->AddPassManager(pm);
@ -168,5 +172,4 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
return new_graph;
}
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

View File

@ -32,8 +32,9 @@ class ModelParser {
virtual ~ModelParser() = default;
virtual FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) {
auto *meta_graph = ParseToFb(modelFile, weightFile, quantType);
virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
auto *meta_graph = ParseToFb(model_file, weight_file, quant_type);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "parse model to fb failed";
return nullptr;
@ -43,8 +44,8 @@ class ModelParser {
return func_graph;
}
virtual schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) = 0;
virtual schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) = 0;
public:
static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) {

View File

@ -31,22 +31,22 @@ CaffeModelParser::~CaffeModelParser() {}
const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"};
schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
int status = ValidateFileStr(modelFile, ".prototxt");
schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
int status = ValidateFileStr(model_file, ".prototxt");
if (status != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
if (weightFile.empty()) {
if (weight_file.empty()) {
MS_LOG(ERROR) << "INPUT MISSING: weightFile is necessary";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
}
status = ValidateFileStr(weightFile, ".caffemodel");
status = ValidateFileStr(weight_file, ".caffemodel");
if (status != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile must be *.caffemodel";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -57,18 +57,18 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
TensorCache tensorCache;
caffe::NetParameter proto;
status = ReadProtoFromText((const char *)modelFile.c_str(), &proto);
status = ReadProtoFromText((const char *)model_file.c_str(), &proto);
if (status != RET_OK) {
MS_LOG(ERROR) << "Read prototxt file failed, model path: " << modelFile;
MS_LOG(ERROR) << "Read prototxt file failed, model path: " << model_file;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
metaGraph->name = proto.name();
caffe::NetParameter weight;
status = ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight);
status = ReadProtoFromBinaryFile((const char *)weight_file.c_str(), &weight);
if (status != RET_OK) {
MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weightFile;
MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weight_file;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
@ -81,7 +81,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
}
NoSupportOp::GetInstance()->SetFmkType("CAFFE");
status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quantType);
status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quant_type);
if (status != RET_OK) {
MS_LOG(ERROR) << "ParseLayer failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -97,7 +97,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
metaGraph->name = GetModelName(modelFile);
metaGraph->name = GetModelName(model_file);
SetAllTensors(tensorCache, metaGraph.get());

View File

@ -34,8 +34,8 @@ class CaffeModelParser : public ModelParser {
virtual ~CaffeModelParser();
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) override;
private:
STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache);

View File

@ -623,9 +623,9 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT
return RET_OK;
}
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
int status = ValidateFileStr(modelFile, ".onnx");
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
int status = ValidateFileStr(model_file, ".onnx");
if (status != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -633,9 +633,9 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
}
onnx::ModelProto onnx_model;
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model);
status = ReadProtoFromBinaryFile((const char *)model_file.c_str(), &onnx_model);
if (status != RET_OK) {
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
@ -645,13 +645,13 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
auto dst_graph = std::make_unique<schema::MetaGraphT>();
auto dst_sub_graph = std::make_unique<schema::SubGraphT>();
int ret = ParseGraph(dst_graph.get(), dst_sub_graph.get(), onnx_graph, quantType);
int ret = ParseGraph(dst_graph.get(), dst_sub_graph.get(), onnx_graph, quant_type);
dst_graph->subGraph.push_back(std::move(dst_sub_graph));
subGraphNum += 1;
if (ret == RET_ERROR) {
return nullptr;
}
dst_graph->name = GetModelName(modelFile);
dst_graph->name = GetModelName(model_file);
std::vector<uint32_t> input_temp_index;
for (size_t i = 0; i < dst_graph->subGraph.front()->inputIndices.size(); i++) {

View File

@ -45,8 +45,8 @@ class OnnxModelParser : public ModelParser {
int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph,
const QuantType &quantType);
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) override;
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);

View File

@ -1,273 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/model_parser_for_tflite.h"
#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include "src/param_value_lite.h"
namespace mindspore::lite {
FuncGraphPtr ModelParserForTflite::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
// load graph
tfliteModel = ReadTfliteModel(modelFile.c_str());
if (tfliteModel == nullptr) {
MS_LOG(ERROR) << "read tflite model failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
}
if (tfliteModel->subgraphs.size() != 1) {
MS_LOG(ERROR) << "read tflite model subgraphs failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
}
funcGraphPtr = std::make_shared<FuncGraph>();
auto status = ConvertGraphInputs();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph inputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
status = ConvertOps();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert ops failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
status = ConvertGraphOutputs();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph outputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
return funcGraphPtr;
}
STATUS ModelParserForTflite::ConvertOps() {
const auto &tfliteSubgraph = tfliteModel->subgraphs.front();
const auto &tfliteModelBuffers = tfliteModel->buffers;
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
STATUS status = RET_OK;
int opIdx = 0;
for (auto &op : tfliteSubgraph->operators) {
auto tfliteOpType = (tfliteModel->operator_codes[op->opcode_index])->builtin_code;
auto opType = GetMSOpType(tfliteOpType);
// parse primitive
auto nodeParser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType);
if (nodeParser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(opType);
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue;
}
PrimitiveC *primitiveC = nullptr;
if (status == RET_OK) {
status = nodeParser->Parse(op, tfliteModel, primitiveC);
if (status != RET_OK) {
if (status == RET_NOT_FIND_OP) {
opType = (opType != "Custom" ? opType : (tfliteModel->operator_codes[op->opcode_index])->custom_code);
NoSupportOp::GetInstance()->InsertOp(opType);
} else {
MS_LOG(ERROR) << "node " << opType.c_str() << " parser failed";
}
continue;
}
std::vector<AnfNodePtr> opInputs = {NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC))};
// parse inputs
for (auto inputIdx : op->inputs) {
const auto &inputTensor = tfliteSubgraph->tensors[inputIdx];
if (nodes.find(inputIdx) != nodes.end()) {
opInputs.emplace_back(nodes.at(inputIdx));
continue;
}
// const tensor
if (tfliteModelBuffers.at(inputTensor->buffer)->data.empty()) {
ParameterPtr parameter;
ConvertConstTensor(inputTensor.get(), parameter);
opInputs.emplace_back(parameter);
nodes.insert(std::pair(inputIdx, parameter));
continue;
}
MS_LOG(ERROR) << "tensor" << inputIdx << " is neither a node output nor a weight tensor.";
return RET_ERROR;
}
auto newCNode = funcGraphPtr->NewCNode(opInputs);
newCNode->set_fullname_with_scope(opType + "-" + std::to_string(opIdx++));
// parse outputs
status = ConvertOutputTensor(op.get(), newCNode);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output tensors for " << newCNode->fullname_with_scope() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return status;
}
}
}
return status;
}
STATUS ModelParserForTflite::ConvertGraphInputs() {
const auto &tfliteSubgraph = tfliteModel->subgraphs.front();
for (auto tfliteGraphInput : tfliteSubgraph->inputs) {
if (tfliteGraphInput < 0) {
tfliteGraphInput = tfliteGraphInput + tfliteSubgraph->tensors.size();
}
auto parameter = funcGraphPtr->add_parameter();
const auto &tensor = tfliteSubgraph->tensors.at(tfliteGraphInput);
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
parameter->set_abstract(abstract_tensor);
parameter->set_name("graph_input_" + std::to_string(tfliteGraphInput) + "_parameter");
nodes.insert(std::pair(tfliteGraphInput, parameter));
}
return RET_OK;
}
STATUS ModelParserForTflite::ConvertGraphOutputs() {
const auto &tfliteSubgraph = tfliteModel->subgraphs.front();
if (tfliteSubgraph->outputs.size() > 1) {
std::vector<AnfNodePtr> make_tuple_inputs;
auto make_tuple_prim_ptr = GetMakeTuplePrim();
if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
return RET_NULL_PTR;
}
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
make_tuple_inputs.emplace_back(make_tuple_prim);
for (auto outputNode : tfliteSubgraph->outputs) {
auto cnode = nodes.at(outputNode);
if (nullptr == cnode) {
MS_LOG(ERROR) << "Can't find input node.";
return RET_NOT_FIND_OP;
}
make_tuple_inputs.emplace_back(cnode);
}
auto make_tuple_cnode = funcGraphPtr->NewCNode(make_tuple_inputs);
make_tuple_cnode->set_fullname_with_scope("return tuple");
std::vector<AnfNodePtr> op_inputs;
auto return_prim_ptr = GetReturnPrim();
if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return RET_NULL_PTR;
}
auto value_node = NewValueNode(return_prim_ptr);
op_inputs.emplace_back(value_node);
op_inputs.emplace_back(make_tuple_cnode);
auto cnode = funcGraphPtr->NewCNode(op_inputs);
cnode->set_fullname_with_scope("return");
funcGraphPtr->set_return(cnode);
} else {
auto returnPrim = GetReturnPrim();
if (returnPrim == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return RET_NULL_PTR;
}
auto valueNode = NewValueNode(returnPrim);
std::vector<AnfNodePtr> opInputs{valueNode};
auto cnode = nodes.at(tfliteSubgraph->outputs.front());
if (nullptr == cnode) {
MS_LOG(ERROR) << "Can't find input node.";
return RET_NOT_FIND_OP;
}
opInputs.emplace_back(cnode);
auto returnCnode = funcGraphPtr->NewCNode(opInputs);
returnCnode->set_fullname_with_scope("return");
funcGraphPtr->set_return(returnCnode);
}
return RET_OK;
}
STATUS ModelParserForTflite::ConvertConstTensor(const tflite::TensorT *tensor, ParameterPtr parameter) {
parameter = funcGraphPtr->add_parameter();
const auto &tfliteModelBuffers = tfliteModel->buffers;
auto type_id = static_cast<TypeId>(tensor->type);
auto type_ptr = TypeIdToType(type_id);
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
parameter->set_abstract(abstract_tensor);
parameter->set_name("const_" + std::to_string(nodes.size()) + "_parameter");
ParamValueLitePtr paramValue = std::make_shared<ParamValueLite>();
MS_ASSERT(paramValue != nullptr);
paramValue->set_tensor_shape(tensor->shape);
paramValue->set_tensor_type(GetTfliteDataType(tensor->type));
paramValue->set_format(schema::Format::Format_NHWC);
const auto &data = tfliteModelBuffers.at(tensor->buffer)->data;
if (!data.empty()) {
auto size = data.size();
char *tensor_data = new (std::nothrow) char[size];
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "new char[] failed";
return RET_MEMORY_FAILED;
}
std::memcpy(tensor_data, data.data(), size);
paramValue->set_tensor_addr(tensor_data);
paramValue->set_tensor_size(size);
parameter->set_default_param(paramValue);
}
return RET_OK;
}
STATUS ModelParserForTflite::ConvertOutputTensor(const tflite::OperatorT *op, CNodePtr dstCNode) {
MS_ASSERT(op != nullptr);
MS_ASSERT(dstCNode != nullptr);
const auto &tfliteSubgraph = tfliteModel->subgraphs.front();
if (op->outputs.size() == 1) {
const auto &tensor = tfliteSubgraph->tensors.at(op->outputs.front());
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto typePtr = TypeIdToType(GetTfliteDataType(tensor->type));
dstCNode->set_abstract(std::make_shared<abstract::AbstractTensor>(typePtr, shape_vector));
nodes.insert(std::pair(op->outputs.front(), dstCNode));
} else {
AbstractBasePtrList abstractList;
for (auto outputIdx : op->outputs) {
const auto &tensor = tfliteSubgraph->tensors.at(outputIdx);
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto typePtr = TypeIdToType(GetTfliteDataType(tensor->type));
abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(typePtr, shape_vector));
auto tupleGetItemPrimPtr = GetTupleGetItemPrim();
if (tupleGetItemPrimPtr == nullptr) {
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
return RET_NULL_PTR;
}
auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr);
auto getItemValue = NewValueNode(MakeValue<int>(outputIdx));
std::vector<AnfNodePtr> inputs{tupleGetItemPrim, dstCNode, getItemValue};
CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs);
getItemCNode->set_fullname_with_scope(dstCNode->fullname_with_scope() + "_getitem_" + std::to_string(outputIdx));
nodes.insert(std::pair(outputIdx, getItemCNode));
}
dstCNode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList));
}
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -1,44 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MODEL_PARSER_FOR_TFLITE_H
#define LITE_MODEL_PARSER_FOR_TFLITE_H
#include <string>
#include <unordered_map>
#include <memory>
#include "tools/converter/parser/tflite/tflite_model_parser.h"
namespace mindspore::lite {
class ModelParserForTflite : public TfliteModelParser {
public:
ModelParserForTflite() = default;
~ModelParserForTflite() override = default;
FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) override;
private:
std::unordered_map<int, AnfNodePtr> nodes;
std::unique_ptr<tflite::ModelT> tfliteModel;
FuncGraphPtr funcGraphPtr;
STATUS ConvertConstTensor(const tflite::TensorT *tensor, ParameterPtr parameter);
STATUS ConvertOutputTensor(const tflite::OperatorT *op, CNodePtr dstCNode);
STATUS ConvertOps();
STATUS ConvertGraphInputs();
STATUS ConvertGraphOutputs();
};
} // namespace mindspore::lite
#endif // LITE_MODEL_PARSER_FOR_TFLITE_H

View File

@ -18,9 +18,11 @@
#include <memory>
#include <vector>
#include <string>
#include "src/ops/activation.h"
#include "src/ops/primitive_c.h"
#include "tools/converter/parser/tflite/tflite_util.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
@ -86,12 +88,40 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
return RET_OK;
}
TfliteNodeRegister g_tfliteReluParser("Relu", new TfliteActivationParser());
TfliteNodeRegister g_tfliteRelu6Parser("Relu6", new TfliteActivationParser());
TfliteNodeRegister g_tfliteTanhParser("Tanh", new TfliteActivationParser());
TfliteNodeRegister g_tfliteSwishParser("Swish", new TfliteActivationParser());
TfliteNodeRegister g_tfliteHardSwishParser("HardSwish", new TfliteActivationParser());
lite::PrimitiveC *TfliteActivationParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
auto ms_op_type = GetMSOpType(tflite_op_type);
if (kActivationTypeMap.find(ms_op_type) == kActivationTypeMap.end()) {
MS_LOG(ERROR) << ms_op_type << "is a not supported activation type";
return nullptr;
}
attr->type = kActivationTypeMap.find(GetMSOpType(tflite_op_type))->second;
if (attr->type == schema::ActivationType_LEAKY_RELU) {
const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << GetMSOpType(tflite_op_type) << " attr failed";
return nullptr;
}
attr->alpha = tflite_attr->alpha;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_Activation;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_TfliteReluParser("ReLU", new TfliteActivationParser());
TfliteNodeRegister g_TfliteRelu6Parser("ReLU6", new TfliteActivationParser());
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteActivationParser());
TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteActivationParser());
TfliteNodeRegister g_TfliteHardSwishParser("HSwish", new TfliteActivationParser());
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteActivationParser());
TfliteNodeRegister g_tfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser());
} // namespace lite
} // namespace mindspore
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser());
} // namespace mindspore::lite

View File

@ -23,8 +23,7 @@
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteActivationParser : public TfliteNodeParser {
public:
TfliteActivationParser() : TfliteNodeParser("node_name") {}
@ -32,9 +31,10 @@ class TfliteActivationParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H

View File

@ -18,9 +18,10 @@
#include "tools/converter/parser/tflite/tflite_addn_parser.h"
#include <vector>
#include <memory>
#include <map>
#include "src/ops/addn.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
@ -55,7 +56,18 @@ STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
lite::PrimitiveC *TfliteAddNParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto attr = std::make_unique<schema::AddNT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_AddN;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteAddNParser("AddN", new TfliteAddNParser());
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

View File

@ -32,6 +32,9 @@ class TfliteAddNParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -76,6 +76,39 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteArgmaxParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
const auto &tflite_subgraph = tflite_model->subgraphs.front();
std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->outMaxValue = false;
attr->topK = 1;
attr->keepDims = false;
attr->axisType = 1;
// get axis attr
auto axis_idx = tflite_op->inputs[1];
auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer;
auto &buf_data = tflite_model->buffers[buffer_idx];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null";
return nullptr;
}
auto data_ptr = buf_data->data.data();
if (data_ptr == nullptr) {
MS_LOG(ERROR) << "the data is null";
return nullptr;
}
attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr)));
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_ArgMax;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteArgmaxParser("Argmax", new TfliteArgmaxParser());
} // namespace lite

View File

@ -32,6 +32,9 @@ class TfliteArgmaxParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -76,6 +76,39 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteArgminParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
const auto &tflite_subgraph = tflite_model->subgraphs.front();
std::unique_ptr<schema::ArgMinT> attr = std::make_unique<schema::ArgMinT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->outMaxValue = false;
attr->topK = 1;
attr->keepDims = false;
attr->axisType = 1;
// get axis attr
auto axis_idx = tflite_op->inputs[1];
auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer;
auto &buf_data = tflite_model->buffers[buffer_idx];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "the buf data is null";
return nullptr;
}
auto data_ptr = buf_data->data.data();
if (data_ptr == nullptr) {
MS_LOG(ERROR) << "the data is null";
return nullptr;
}
attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr)));
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_ArgMin;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteArgminParser("Argmin", new TfliteArgminParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteArgminParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -179,6 +179,133 @@ STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteDoubleInputOpParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (tflite_op_type == tflite::BuiltinOperator_ADD) {
MS_LOG(DEBUG) << "parse TfliteAddParser";
auto attr = std::make_unique<schema::AddT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tfliteAttr = tflite_op->builtin_options.AsAddOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed";
return nullptr;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
primitive->value.type = schema::PrimitiveType_Add;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_SUB) {
MS_LOG(DEBUG) << "parse TfliteSubParser";
auto attr = std::make_unique<schema::SubT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tfliteAttr = tflite_op->builtin_options.AsSubOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed";
return nullptr;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
primitive->value.type = schema::PrimitiveType_Sub;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_MUL) {
MS_LOG(DEBUG) << "parse TfliteMulParser";
auto attr = std::make_unique<schema::MulT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tfliteAttr = tflite_op->builtin_options.AsMulOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed";
return nullptr;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
primitive->value.type = schema::PrimitiveType_Mul;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_DIV) {
MS_LOG(DEBUG) << "parse TfliteDivParser";
auto attr = std::make_unique<schema::DivT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tfliteAttr = tflite_op->builtin_options.AsDivOptions();
if (nullptr == tfliteAttr) {
MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed";
return nullptr;
}
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
primitive->value.type = schema::PrimitiveType_Div;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_FLOOR_DIV) {
MS_LOG(DEBUG) << "parse TfliteFloorDivParser";
std::unique_ptr<schema::FloorDivT> attr = std::make_unique<schema::FloorDivT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_FloorDiv;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_FLOOR_MOD) {
MS_LOG(DEBUG) << "parse TfliteFloorModParser";
auto attr = std::make_unique<schema::FloorModT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_FloorMod;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_SQUARED_DIFFERENCE) {
MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser";
auto attr = std::make_unique<schema::SquaredDifferenceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_SquaredDifference;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_POW) {
MS_LOG(DEBUG) << "parse TflitePowParser";
auto attr = std::make_unique<schema::PowerT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->power = 1.0f;
attr->scale = 1.0f;
attr->shift = 0.0f;
primitive->value.type = schema::PrimitiveType_Power;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_MAXIMUM) {
MS_LOG(DEBUG) << "parse TfliteMaximumParser";
auto attr = std::make_unique<schema::MaximumT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Maximum;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_MINIMUM) {
MS_LOG(DEBUG) << "parse TfliteMinimumParser";
auto attr = std::make_unique<schema::MinimumT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Minimum;
primitive->value.value = attr.release();
} else {
MS_LOG(ERROR) << "op hasn't been supported";
return nullptr;
}
return PrimitiveC::Create(primitive.release());
}
STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
@ -320,6 +447,124 @@ STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteSingleInputOpParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (tflite_op_type == tflite::BuiltinOperator_ABS) {
MS_LOG(DEBUG) << "parse TfliteAbsParser";
auto attr = std::make_unique<schema::AbsT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Abs;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_EXP) {
MS_LOG(DEBUG) << "parse TfliteExpParser";
auto attr = std::make_unique<schema::ExpT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->base = -1; // -1 represent base = e
attr->scale = 1;
attr->shift = 0;
primitive->value.type = schema::PrimitiveType_Exp;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_SQRT) {
MS_LOG(DEBUG) << "parse TfliteSqrtParser";
auto attr = std::make_unique<schema::SqrtT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Sqrt;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_RSQRT) {
MS_LOG(DEBUG) << "parse TfliteRsqrtParser";
auto attr = std::make_unique<schema::RsqrtT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Rsqrt;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_SQUARE) {
MS_LOG(DEBUG) << "parse TfliteSquareParser";
auto attr = std::make_unique<schema::SquareT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Square;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_SIN) {
MS_LOG(DEBUG) << "parse TfliteSinParser";
auto attr = std::make_unique<schema::SinT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Sin;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_COS) {
MS_LOG(DEBUG) << "parse TfliteCosParser";
std::unique_ptr<schema::CosT> attr = std::make_unique<schema::CosT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Cos;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_LOG) {
MS_LOG(DEBUG) << "parse TfliteLogParser";
auto attr = std::make_unique<schema::LogT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Log;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_ROUND) {
MS_LOG(DEBUG) << "parse TfliteRoundParser";
auto attr = std::make_unique<schema::RoundT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Round;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_CEIL) {
MS_LOG(DEBUG) << "parse TfliteCeilParser";
auto attr = std::make_unique<schema::CeilT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Ceil;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_FLOOR) {
MS_LOG(DEBUG) << "parse TfliteFloorParser";
auto attr = std::make_unique<schema::FloorT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Floor;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_NEG) {
MS_LOG(DEBUG) << "parse TfliteNegParser";
auto attr = std::make_unique<schema::NegT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Neg;
primitive->value.value = attr.release();
}
return PrimitiveC::Create(primitive.release());
}
STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
@ -406,29 +651,91 @@ STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteCompareOpParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
auto primitive = std::make_unique<schema::PrimitiveT>();
if (tflite_op_type == tflite::BuiltinOperator_EQUAL) {
MS_LOG(DEBUG) << "parse TfliteEqualParser";
std::unique_ptr<schema::EqualT> attr = std::make_unique<schema::EqualT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Equal;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_NOT_EQUAL) {
MS_LOG(DEBUG) << "parse TfliteNotEqualParser";
std::unique_ptr<schema::NotEqualT> attr = std::make_unique<schema::NotEqualT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_NotEqual;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_GREATER) {
MS_LOG(DEBUG) << "parse TfliteGreaterParser";
std::unique_ptr<schema::GreaterT> attr = std::make_unique<schema::GreaterT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Greater;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_GREATER_EQUAL) {
MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser";
std::unique_ptr<schema::GreaterEqualT> attr = std::make_unique<schema::GreaterEqualT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_GreaterEqual;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_LESS) {
MS_LOG(DEBUG) << "parse TfliteLessParser";
std::unique_ptr<schema::LessT> attr = std::make_unique<schema::LessT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Less;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_LESS_EQUAL) {
MS_LOG(DEBUG) << "parse TfliteLessEqualParser";
std::unique_ptr<schema::LessEqualT> attr = std::make_unique<schema::LessEqualT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_LessEqual;
primitive->value.value = attr.release();
}
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteAddParser("Add", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_tfliteMulParser("Mul", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_tfliteDivParser("Div", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_TfliteMulParser("Mul", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_TfliteDivParser("Div", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_tflitePowParser("Pow", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_TflitePowParser("Pow", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_tfliteMaximumParser("Maximum", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_tfliteMinimumParser("Minimum", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_TfliteMaximumParser("Maximum", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_TfliteMinimumParser("Minimum", new TfliteDoubleInputOpParser());
TfliteNodeRegister g_tfliteAbsParser("Abs", new TfliteSingleInputOpParser());
TfliteNodeRegister g_tfliteExpParser("Exp", new TfliteSingleInputOpParser());
TfliteNodeRegister g_tfliteSqrtParser("Sqrt", new TfliteSingleInputOpParser());
TfliteNodeRegister g_TfliteAbsParser("Abs", new TfliteSingleInputOpParser());
TfliteNodeRegister g_TfliteExpParser("Exp", new TfliteSingleInputOpParser());
TfliteNodeRegister g_TfliteSqrtParser("Sqrt", new TfliteSingleInputOpParser());
TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteSingleInputOpParser());
TfliteNodeRegister g_tfliteSquareParser("Square", new TfliteSingleInputOpParser());
TfliteNodeRegister g_tfliteSinParser("Sin", new TfliteSingleInputOpParser());
TfliteNodeRegister g_tfliteCosParser("Cos", new TfliteSingleInputOpParser());
TfliteNodeRegister g_tfliteLogParser("Log", new TfliteSingleInputOpParser());
TfliteNodeRegister g_TfliteSquareParser("Square", new TfliteSingleInputOpParser());
TfliteNodeRegister g_TfliteSinParser("Sin", new TfliteSingleInputOpParser());
TfliteNodeRegister g_TfliteCosParser("Cos", new TfliteSingleInputOpParser());
TfliteNodeRegister g_TfliteLogParser("Log", new TfliteSingleInputOpParser());
TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteSingleInputOpParser());
TfliteNodeRegister g_tfliteCeilParser("Ceil", new TfliteSingleInputOpParser());
TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteSingleInputOpParser());
TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteSingleInputOpParser());
TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteSingleInputOpParser());

View File

@ -32,6 +32,9 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
class TfliteSingleInputOpParser : public TfliteNodeParser {
@ -41,6 +44,9 @@ class TfliteSingleInputOpParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
class TfliteCompareOpParser : public TfliteNodeParser {
@ -50,7 +56,11 @@ class TfliteCompareOpParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -74,6 +74,29 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteBatchToSpaceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
const auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::BatchToSpaceT> attr = std::make_unique<schema::BatchToSpaceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) {
MS_LOG(ERROR) << "get batchToSpace -> blockShape failed";
return nullptr;
}
if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->crops)) {
MS_LOG(ERROR) << "get batchToSpace -> crops failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_BatchToSpace;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser());
TfliteNodeRegister g_tfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceParser());

View File

@ -32,7 +32,10 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -57,6 +57,28 @@ STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteBroadcastToParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return nullptr;
}
std::unique_ptr<schema::BroadcastToT> attr = std::make_unique<schema::BroadcastToT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dst_shape)) {
MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_BroadcastTo;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteBroadcastToParser("BroadcastTo", new TfliteBroadcastToParser());
} // namespace lite

View File

@ -23,8 +23,7 @@
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteBroadcastToParser : public TfliteNodeParser {
public:
TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {}
@ -32,8 +31,10 @@ class TfliteBroadcastToParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H

View File

@ -63,6 +63,32 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteCastParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";
return nullptr;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]];
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";
return nullptr;
}
attr->dstT = GetTfliteDataType(out_tensor->type);
primitive->value.type = schema::PrimitiveType_Cast;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteCastParser("Cast", new TfliteCastParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteCastParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -60,6 +60,26 @@ STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteConcatParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions();
if (tfliteAttr == nullptr) {
MS_LOG(ERROR) << "get op concat attr failed";
return nullptr;
}
attr->axis = tfliteAttr->axis;
attr->n = tflite_op->inputs.size();
primitive->value.type = schema::PrimitiveType_Concat;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteConcatParser("Concat", new TfliteConcatParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteConcatParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -18,8 +18,7 @@
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
namespace mindspore::lite {
STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
@ -74,7 +73,7 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
// calculate pad params
auto data_index = tflite_op->inputs[0];
const auto &data_tensor = tflite_subgraph->tensors[data_index];
std::vector<int> params;
std::vector<int64_t> params;
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
@ -96,7 +95,63 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
lite::PrimitiveC *TfliteConvParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
const auto &tflite_subgraph = tflite_model->subgraphs.front();
std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get conv attr failed";
return nullptr;
}
attr->group = 1;
attr->strideW = tflite_attr->stride_w;
attr->strideH = tflite_attr->stride_h;
attr->dilateH = tflite_attr->dilation_h_factor;
attr->dilateW = tflite_attr->dilation_w_factor;
attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format::Format_NHWC;
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
attr->hasBias = true;
// get the conv op weight tensor
auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_subgraph->tensors[weight_index];
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "the weight tensor is null";
return nullptr;
}
auto weight_shape = weight_tensor->shape;
attr->channelIn = weight_shape[3];
attr->channelOut = weight_shape[0];
attr->kernelH = weight_shape[1];
attr->kernelW = weight_shape[2];
// calculate pad params
auto data_index = tflite_op->inputs[0];
const auto &data_tensor = tflite_subgraph->tensors[data_index];
std::vector<int64_t> params;
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get padding params failed";
return nullptr;
} else if (status == RET_OK) {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
}
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_Conv2D;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteConv2DParser("Conv2D", new TfliteConvParser());
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

View File

@ -23,8 +23,7 @@
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteConvParser : public TfliteNodeParser {
public:
TfliteConvParser() : TfliteNodeParser("Conv2D") {}
@ -32,8 +31,9 @@ class TfliteConvParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H

View File

@ -15,9 +15,8 @@
*/
#include "tools/converter/parser/tflite/tflite_converter.h"
#include "tools/converter/parser/tflite/tflite_model_parser.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
TfliteConverter::TfliteConverter() { modelParser = new TfliteModelParser(); }
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

View File

@ -21,18 +21,15 @@
#include <memory>
#include <map>
#include "tools/converter/converter.h"
#include "tools/converter/parser/tflite/tflite_model_parser.h"
#include "tools/converter/graphdef_transform.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteConverter : public Converter {
public:
TfliteConverter();
~TfliteConverter() override = default;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_

View File

@ -271,6 +271,48 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
}
return status;
}
PrimitiveC *TfliteCustomParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto op = new schema::CNodeT;
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return nullptr;
}
const auto &custom_attr = tflite_op->custom_options;
const auto &opcode_index = tflite_op->opcode_index;
const auto &custom_type = tflite_model->operator_codes[opcode_index]->custom_code;
int status = RET_OK;
if (custom_type == "TFLite_Detection_PostProcess") {
status = DetectPostProcess(custom_attr, op, tflite_op);
} else if (custom_type == "Predict") {
status = Predict(custom_attr, op, tflite_op);
} else if (custom_type == "Normalize") {
status = Normalize(custom_attr, op, tflite_op);
} else if (custom_type == "ExtractFeatures") {
status = ExtractFeatures(custom_attr, op, tflite_op);
} else if (custom_type == "AudioSpectrogram") {
status = AudioSpectrogram(custom_attr, op, tflite_op);
} else if (custom_type == "Mfcc") {
status = Mfcc(custom_attr, op, tflite_op);
} else if (custom_type == "FlexRFFT") {
status = Rfft(custom_attr, op, tflite_op, tflite_model, tflite_subgraph);
} else if (custom_type == "FlexReal") {
status = FftReal(custom_attr, op, tflite_op);
} else if (custom_type == "FlexImag") {
status = FftImag(custom_attr, op, tflite_op);
} else {
MS_LOG(ERROR) << "the custom op hasn't been supported now";
status = RET_NOT_FIND_OP;
}
if (status != RET_OK) {
return nullptr;
}
auto primitive = op->primitive.release();
delete op;
return PrimitiveC::Create(primitive);
}
TfliteNodeRegister g_tfliteCustomParser("Custom", new TfliteCustomParser());
} // namespace lite

View File

@ -31,6 +31,8 @@ class TfliteCustomParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
static STATUS DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op);

View File

@ -75,7 +75,7 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
// calculate pad params
auto data_index = tflite_op->inputs[2];
const auto &data_tensor = tflite_subgraph->tensors[data_index];
std::vector<int> params;
std::vector<int64_t> params;
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
@ -96,6 +96,64 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteDeConvParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto &tflite_subgraph = tflite_model->subgraphs.front();
std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op deconv attr failed";
return nullptr;
}
attr->group = 1;
attr->strideW = tflite_attr->stride_w;
attr->strideH = tflite_attr->stride_h;
attr->dilateH = 1;
attr->dilateW = 1;
attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format::Format_NHWC;
attr->activationType = schema::ActivationType_NO_ACTIVATION;
attr->hasBias = true;
// get the conv op weight tensor
auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_subgraph->tensors[weight_index];
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "the weight tensor is null";
return nullptr;
}
auto weight_shape = weight_tensor->shape;
attr->channelIn = weight_shape[3];
attr->channelOut = weight_shape[0];
attr->kernelH = weight_shape[1];
attr->kernelW = weight_shape[2];
// calculate pad params
auto data_index = tflite_op->inputs[2];
const auto &data_tensor = tflite_subgraph->tensors[data_index];
std::vector<int64_t> params;
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get padding params failed";
return nullptr;
} else if (status == RET_OK) {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
}
primitive->value.type = schema::PrimitiveType_DeConv2D;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteDeConv2DParser("DeConv2D", new TfliteDeConvParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteDeConvParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -60,6 +60,26 @@ STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteDepthToSpaceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op depthtospace attr failed";
return nullptr;
}
attr->blockSize = tflite_attr->block_size;
attr->format = schema::Format::Format_NHWC;
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_Concat;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteDepthToSpaceParser("DepthToSpace", new TfliteDepthToSpaceParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -18,8 +18,7 @@
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
namespace mindspore::lite {
STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
@ -82,7 +81,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info,
attr->kernelW = weight_shape[2];
// calculate pad params
std::vector<int> params;
std::vector<int64_t> params;
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
@ -104,7 +103,71 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
lite::PrimitiveC *TfliteDepthwiseConv2DParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";
std::unique_ptr<schema::DepthwiseConv2DT> attr = std::make_unique<schema::DepthwiseConv2DT>();
const auto &tflite_subgraph = tflite_model->subgraphs.front();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op de attr failed";
return nullptr;
}
attr->strideW = tflite_attr->stride_w;
attr->strideH = tflite_attr->stride_h;
attr->dilateH = tflite_attr->dilation_h_factor;
attr->dilateW = tflite_attr->dilation_w_factor;
attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format::Format_NHWC;
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
attr->hasBias = true;
attr->channelMultiplier = tflite_attr->depth_multiplier;
// get the data tensor
auto data_index = tflite_op->inputs[1];
const auto &data_tensor = tflite_subgraph->tensors[data_index];
if (data_tensor == nullptr) {
MS_LOG(ERROR) << "the data tensor is null";
return nullptr;
}
auto data_shape = data_tensor->shape;
attr->channelIn = data_shape[3];
// get the weight tensor
auto weight_index = tflite_op->inputs[1];
const auto &weight_tensor = tflite_subgraph->tensors[weight_index];
if (weight_tensor == nullptr) {
MS_LOG(ERROR) << "the weight tensor is null";
return nullptr;
}
auto weight_shape = weight_tensor->shape;
attr->kernelH = weight_shape[1];
attr->kernelW = weight_shape[2];
// calculate pad params
std::vector<int64_t> params;
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get padding params failed";
return nullptr;
} else if (status == RET_OK) {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
}
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteDepthwiseConv2DParser("DepthwiseConv2D", new TfliteDepthwiseConv2DParser());
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

View File

@ -23,8 +23,7 @@
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteDepthwiseConv2DParser : public TfliteNodeParser {
public:
TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {}
@ -32,8 +31,10 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H

View File

@ -75,6 +75,45 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteDequantizeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) {
MS_LOG(ERROR) << "input tensor is null";
return nullptr;
}
const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]];
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "output tensor is null";
return nullptr;
}
if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) &&
(GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 ||
GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8)) {
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
attr->dstT = GetTfliteDataType(out_tensor->type);
primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
} else {
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
attr->dstT = GetTfliteDataType(out_tensor->type);
primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Cast;
}
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser());
} // namespace lite

View File

@ -31,6 +31,9 @@ class TfliteDequantizeParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -56,6 +56,30 @@ STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteExpandDimsParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return nullptr;
}
std::unique_ptr<schema::ExpandDimsT> attr = std::make_unique<schema::ExpandDimsT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
std::vector<int> dims;
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, dims)) {
MS_LOG(ERROR) << "get expand_dims -> dim failed";
return nullptr;
}
attr->dim = dims[0];
primitive->value.type = schema::PrimitiveType_ExpandDims;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser());
} // namespace lite
} // namespace mindspore

View File

@ -32,6 +32,8 @@ class TfliteExpandDimsParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -57,6 +57,32 @@ STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteFillParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return nullptr;
}
std::unique_ptr<schema::FillT> attr = std::make_unique<schema::FillT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
if (tflite_op->inputs.size() > 1) {
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dims)) {
MS_LOG(ERROR) << "get fill -> dims failed";
return nullptr;
}
}
primitive->value.type = schema::PrimitiveType_Fill;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteFillParser("Fill", new TfliteFillParser());
} // namespace lite

View File

@ -32,6 +32,9 @@ class TfliteFillParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -69,6 +69,37 @@ STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteFullyConnectedParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::FullConnectionT> attr = std::make_unique<schema::FullConnectionT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsFullyConnectedOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op fully connect attr failed";
return nullptr;
}
bool hasBias = tflite_op->inputs.size() > 2 && tflite_op->inputs[2] != -1;
attr->hasBias = hasBias;
attr->axis = 1;
attr->useAxis = false;
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
primitive->value.type = schema::PrimitiveType_FullConnection;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteFullyConnectedParser("FullyConnected", new TfliteFullyConnectedParser());
TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFullyConnectedParser());

View File

@ -32,7 +32,10 @@ class TfliteFullyConnectedParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -54,6 +54,26 @@ STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::u
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteGatherNdParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return nullptr;
}
std::unique_ptr<schema::GatherNdT> attr = std::make_unique<schema::GatherNdT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->batchDims = 0;
primitive->value.type = schema::PrimitiveType_GatherNd;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteGatherNdParser("GatherND", new TfliteGatherNdParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteGatherNdParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -60,6 +60,32 @@ STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteGatherParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return nullptr;
}
std::unique_ptr<schema::GatherT> attr = std::make_unique<schema::GatherT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsGatherOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op gather attr failed";
return nullptr;
}
attr->axis = tflite_attr->axis;
attr->batchDims = 0;
primitive->value.type = schema::PrimitiveType_Gather;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteGatherParser("Gather", new TfliteGatherParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteGatherParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -55,6 +55,24 @@ STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info,
}
return RET_OK;
}
PrimitiveC *TfliteHashtableLookupParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::HashtableLookupT> attr = std::make_unique<schema::HashtableLookupT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_HashtableLookup;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteHashtableLookupParser("HashtableLookup", new TfliteHashtableLookupParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteHashtableLookupParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -55,6 +55,27 @@ STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteL2NormParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
std::unique_ptr<schema::L2NormT> attr = std::make_unique<schema::L2NormT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsL2NormOptions();
attr->axis = {-1};
attr->epsilon = 1e-6f;
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_L2Norm;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteL2NormParser("L2_NORMALIZATION", new TfliteL2NormParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteL2NormParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -78,6 +78,44 @@ STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteLogicalParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
if (tflite_op_type == tflite::BuiltinOperator_LOGICAL_AND) {
MS_LOG(DEBUG) << "parse TfliteLogicalAndParser";
std::unique_ptr<schema::LogicalAndT> attr = std::make_unique<schema::LogicalAndT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_LogicalAnd;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_LOGICAL_NOT) {
MS_LOG(DEBUG) << "parse TfliteLogicalNotParser";
std::unique_ptr<schema::LogicalNotT> attr = std::make_unique<schema::LogicalNotT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_LogicalNot;
primitive->value.value = attr.release();
} else if (tflite_op_type == tflite::BuiltinOperator_LOGICAL_OR) {
MS_LOG(DEBUG) << "parse TfliteLogicalOrParser";
std::unique_ptr<schema::LogicalOrT> attr = std::make_unique<schema::LogicalOrT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_LogicalOr;
primitive->value.value = attr.release();
}
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteLogicalAndParser("LogicalAnd", new TfliteLogicalParser());
TfliteNodeRegister g_tfliteLogicalNotParser("LogicalNot", new TfliteLogicalParser());

View File

@ -23,8 +23,7 @@
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteLogicalParser : public TfliteNodeParser {
public:
TfliteLogicalParser() : TfliteNodeParser("node_name") {}
@ -32,8 +31,9 @@ class TfliteLogicalParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LOGICAL_PARSER_H

View File

@ -60,6 +60,34 @@ STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteLRNParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::LocalResponseNormalizationT> attr = std::make_unique<schema::LocalResponseNormalizationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsLocalResponseNormalizationOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op LRN attr failed";
return nullptr;
}
attr->depth_radius = tflite_attr->radius;
attr->alpha = tflite_attr->alpha;
attr->beta = tflite_attr->beta;
attr->bias = tflite_attr->bias;
primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteLRNParser("LocalResponseNorm", new TfliteLRNParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteLRNParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -64,6 +64,35 @@ STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteLshProjectionParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::LshProjectionT> attr = std::make_unique<schema::LshProjectionT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsLSHProjectionOptions();
switch (tflite_attr->type) {
case tflite::LSHProjectionType_SPARSE:
attr->type = schema::LshProjectionType_SPARSE;
break;
case tflite::LSHProjectionType_DENSE:
attr->type = schema::LshProjectionType_DENSE;
break;
default:
attr->type = schema::LshProjectionType_UNKNOWN;
}
primitive->value.type = schema::PrimitiveType_LshProjection;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteLshProjectionParser("LshProjection", new TfliteLshProjectionParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteLshProjectionParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,79 +13,167 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_model_parser.h"
#include <utility>
#include <memory>
#include <string>
#include <vector>
#include <set>
#include "tools/common/graph_util.h"
#include "tools/common/storage.h"
#include "flatbuffers/flatbuffers.h"
#include <memory>
#include <algorithm>
#include <utility>
#include "src/param_value_lite.h"
#include "src/common/file_utils.h"
#include "tools/common/node_util.h"
namespace mindspore {
namespace lite {
TfliteModelParser::TfliteModelParser() = default;
TfliteModelParser::~TfliteModelParser() { delete[](this->tfliteModelBuf); }
namespace mindspore::lite {
std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *model_path) {
size_t size = 0;
tfliteModelBuf = ReadFile(model_path, &size);
if (tfliteModelBuf == nullptr) {
tflite_model_buf_ = ReadFile(model_path, &size);
if (tflite_model_buf_ == nullptr) {
MS_LOG(ERROR) << "the file buffer is nullptr";
return nullptr;
}
flatbuffers::Verifier verify((const uint8_t *)tfliteModelBuf, size);
flatbuffers::Verifier verify((const uint8_t *)tflite_model_buf_, size);
if (!tflite::VerifyModelBuffer(verify)) {
MS_LOG(ERROR) << "the buffer is invalid and fail to create graph";
return nullptr;
}
return tflite::UnPackModel(tfliteModelBuf);
return tflite::UnPackModel(tflite_model_buf_);
}
STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) {
MS_ASSERT(tensor != nullptr);
MS_ASSERT(tflite_tensor != nullptr);
auto buffer_idx = tflite_tensor->buffer;
FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
// load graph
tflite_model_ = ReadTfliteModel(model_file.c_str());
if (tflite_model_ == nullptr) {
MS_LOG(ERROR) << "read tflite model failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
}
const auto &buf = tflite_model_buffer[buffer_idx];
if (buf == nullptr) {
MS_LOG(ERROR) << "tensor is null";
if (tflite_model_->subgraphs.size() != 1) {
MS_LOG(ERROR) << "read tflite model subgraphs failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
}
func_graph_ = std::make_shared<FuncGraph>();
auto status = ConvertGraphInputs();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph inputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
status = ConvertOps();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert ops failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
status = ConvertGraphOutputs();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph outputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
return func_graph_;
}
STATUS TfliteModelParser::ConvertOps() {
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
const auto &tflite_model_buffers = tflite_model_->buffers;
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
STATUS status = RET_OK;
int op_idx = 0;
for (auto &op : tflite_subgraph->operators) {
auto tfliteOpType = (tflite_model_->operator_codes[op->opcode_index])->builtin_code;
auto op_type = GetMSOpType(tfliteOpType);
auto op_name = op_type + "-" + std::to_string(op_idx);
op_idx++;
// parse primitive
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(op_type);
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue;
}
if (status != RET_OK) {
continue;
}
auto primitiveC = node_parser->ParseLitePrimitive(op, tflite_model_);
if (primitiveC == nullptr) {
MS_LOG(ERROR) << "parse node " << op_type.c_str() << " parser failed";
continue;
}
status = ConvertOpQuantParams(op.get(), primitiveC);
if (status != RET_OK) {
MS_LOG(ERROR) << "convert " << op_name << " quant param failed.";
return status;
}
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))};
// parse inputs
for (auto input_idx : op->inputs) {
if (input_idx < 0) {
input_idx += tflite_subgraph->tensors.size();
}
const auto &input_tensor = tflite_subgraph->tensors[input_idx];
if (nodes_.find(input_idx) != nodes_.end()) {
op_inputs.emplace_back(nodes_.at(input_idx));
continue;
}
// const tensor
if (!tflite_model_buffers.at(input_tensor->buffer)->data.empty()) {
auto parameter = func_graph_->add_parameter();
status = ConvertConstTensor(input_tensor.get(), parameter.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
return status;
}
op_inputs.emplace_back(parameter);
nodes_.insert(std::pair(input_idx, parameter));
continue;
}
MS_LOG(WARNING) << "tensor " << input_idx << " is neither a node output nor a weight tensor.";
}
auto new_cnode = func_graph_->NewCNode(op_inputs);
new_cnode->set_fullname_with_scope(op_name);
// parse outputs
status = ConvertOutputTensor(op.get(), new_cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output tensors for " << new_cnode->fullname_with_scope() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return status;
}
}
return status;
}
STATUS TfliteModelParser::SetTensorQuantParam(const tflite::TensorT *tflite_tensor,
std::vector<QuantParamT> *quant_params) {
if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tflite_tensor is null, set tensor quant params failed.";
return RET_NULL_PTR;
}
quant_params->clear();
if (!buf->data.empty()) {
auto data_size = buf->data.size();
tensor->data.resize(data_size);
if (memcpy_s(tensor->data.data(), data_size, buf->data.data(), data_size) != EOK) {
MS_LOG(ERROR) << "memcpy tensor data failed";
return RET_MEMORY_FAILED;
}
} else {
MS_LOG(ERROR) << "src tensor data is empty";
return RET_INPUT_TENSOR_ERROR;
if (tflite_tensor->quantization == nullptr ||
(tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() &&
tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty())) {
std::vector<schema::QuantParamT> notinited_quant_params(1);
*quant_params = notinited_quant_params;
return RET_OK;
}
return RET_OK;
}
void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor,
schema::TensorT *tensor) {
MS_ASSERT(tensor != nullptr);
tensor->quantParams.clear();
if (tflite_tensor->quantization == nullptr) {
MS_LOG(ERROR) << "tflite_tensor->quantization is null";
return;
}
for (size_t i = 0; i < tflite_tensor->quantization->scale.size(); i++) {
std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<QuantParamT>();
if (quant_param == nullptr) {
MS_LOG(ERROR) << "quant_param is null";
return;
return RET_NULL_PTR;
}
if (!tflite_tensor->quantization->scale.empty()) {
@ -104,364 +192,219 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor
quant_param->max = tflite_tensor->quantization->max[i];
}
quant_param->inited = true;
tensor->quantParams.emplace_back(std::move(quant_param));
}
}
STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const QuantType &quant_type, schema::MetaGraphT *sub_graph) {
MS_ASSERT(tflite_model != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
MS_ASSERT(sub_graph != nullptr);
int idx = 0;
int status = RET_OK;
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
for (const auto &tflite_op : tflite_subgraph->operators) {
const auto opcode_index = tflite_op->opcode_index;
const auto &operator_code = tflite_model->operator_codes[opcode_index];
if (operator_code == nullptr) {
MS_LOG(ERROR) << "operator_code is null";
return RET_ERROR;
}
auto op_type = GetMSOpType(operator_code->builtin_code);
auto op = std::make_unique<schema::CNodeT>();
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->name = op_type + "-" + std::to_string(idx++);
op->quantType = quant_type;
MS_LOG(INFO) << "parse op: " << op->name.c_str();
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(op_type);
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue;
}
if (status == RET_OK || op_type == "Custom") {
int status_node = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, tflite_subgraph, op.get());
status = (status == RET_OK ? status_node : status);
if (status_node != RET_OK) {
if (status_node == RET_NOT_FIND_OP) {
op_type = (op_type != "Custom" ? op_type : operator_code->custom_code);
NoSupportOp::GetInstance()->InsertOp(op_type);
} else {
MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed";
}
continue;
}
if (status != RET_OK) {
continue;
}
sub_graph->nodes.emplace_back(op.release());
opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get();
tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get();
}
}
return status;
}
STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::MetaGraphT *sub_graph) {
MS_ASSERT(tflite_subgraph != nullptr);
MS_ASSERT(sub_graph != nullptr);
std::set<int> output_index;
for (const auto &tflite_op : tflite_subgraph->operators) {
for (int idx : tflite_op->outputs) {
if (idx < 0) {
idx += tflite_subgraph->tensors.size();
}
output_index.insert(idx);
}
}
for (size_t i = 0; i < tensorsInfo.tensorsId.size(); i++) {
auto idx = tensorsInfo.tensorsId[i];
if (idx < 0) {
idx += tflite_subgraph->tensors.size();
}
const auto &tflite_tensor = tflite_subgraph->tensors[idx];
if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tflite_tensor is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>();
if (tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";
return RET_NULL_PTR;
}
tensor->format = tensorsInfo.tensorsFormat[i];
tensor->dataType = GetTfliteDataType(tflite_tensor->type);
tensor->dims = tflite_tensor->shape;
// if graph input tensor
bool isInput = false;
auto tflite_inputs = tflite_subgraph->inputs;
for (int tflite_input : tflite_inputs) {
if (idx == tflite_input) {
isInput = true;
break;
}
}
// add data for const tensor
auto &tensor_buffer = tflite_model_buffer.at(tflite_tensor->buffer);
if (tensor_buffer == nullptr) {
MS_LOG(ERROR) << "tensor_buffer is null";
return RET_NULL_PTR;
}
auto isConst = (!tensor_buffer->data.empty());
if (isConst) {
int status = CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "obtain const tensor failed";
return status;
}
}
// set tensor attr
if (isInput || isConst) {
tensor->nodeType = schema::NodeType::NodeType_ValueNode;
} else {
if (output_index.find(idx) == output_index.end() && tflite_tensor->shape[0] == 0) {
tensor->nodeType = schema::NodeType::NodeType_ValueNode;
} else {
tensor->nodeType = schema::NodeType_Parameter;
}
}
// quant param
if (tflite_tensor->quantization != nullptr &&
!(tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() &&
tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty())) {
SetTensorQuantParam(tflite_tensor, tensor.get());
}
tensors.push_back(tensor.release());
}
for (auto iter : tensors) {
std::unique_ptr<schema::TensorT> temp(iter);
sub_graph->allTensors.emplace_back(move(temp));
quant_params->emplace_back(*std::move(quant_param));
}
return RET_OK;
}
STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
schema::MetaGraphT *sub_graph) {
MS_ASSERT(sub_graph != nullptr);
MS_ASSERT(tflite_subgraph != nullptr);
// graph input
std::vector<int> graph_inputs;
for (size_t i = 0; i < tflite_subgraph->inputs.size(); i++) {
const int idx = tflite_subgraph->inputs[i];
int id = idx < 0 ? idx + tflite_subgraph->tensors.size() : idx;
auto iter = tensorsInfo.tensorsIdMap.find(id);
if (iter != tensorsInfo.tensorsIdMap.end()) {
graph_inputs.push_back(iter->second);
} else {
MS_LOG(ERROR) << "get graph input failed";
return RET_INPUT_TENSOR_ERROR;
}
STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, lite::PrimitiveC *primitive_c) {
if (op == nullptr) {
MS_LOG(ERROR) << "tflite op is null, get quant params failed.";
return RET_NULL_PTR;
}
sub_graph->inputIndex.assign(graph_inputs.begin(), graph_inputs.end());
// graph output
std::vector<int> graph_outputs;
for (size_t i = 0; i < tflite_subgraph->outputs.size(); i++) {
const int idx = tflite_subgraph->outputs[i];
int id = idx < 0 ? idx + tflite_subgraph->tensors.size() : idx;
auto iter = tensorsInfo.tensorsIdMap.find(id);
if (iter != tensorsInfo.tensorsIdMap.end()) {
graph_outputs.push_back(iter->second);
} else {
MS_LOG(ERROR) << "get graph output failed";
return RET_INPUT_TENSOR_ERROR;
}
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is null, get quant params failed.";
return RET_NULL_PTR;
}
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
for (auto input_idx : op->inputs) {
if (input_idx < 0) {
input_idx += tflite_subgraph->tensors.size();
}
const auto &input_tensor = tflite_subgraph->tensors[input_idx];
std::vector<schema::QuantParamT> quant_params;
auto status = SetTensorQuantParam(input_tensor.get(), &quant_params);
if (status != RET_OK) {
MS_LOG(ERROR) << "set input tensor quant param failed.";
return status;
}
primitive_c->AddInputQuantParam(quant_params);
}
for (auto output_idx : op->outputs) {
if (output_idx < 0) {
output_idx += tflite_subgraph->tensors.size();
}
const auto &output_tensor = tflite_subgraph->tensors.at(output_idx);
std::vector<schema::QuantParamT> quant_params;
auto status = SetTensorQuantParam(output_tensor.get(), &quant_params);
if (status != RET_OK) {
MS_LOG(ERROR) << "set output tensor quant param failed.";
return status;
}
primitive_c->AddOutputQuantParam(quant_params);
}
sub_graph->outputIndex.assign(graph_outputs.begin(), graph_outputs.end());
return RET_OK;
}
STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) {
MS_ASSERT(sub_graph != nullptr);
for (auto &op : sub_graph->nodes) {
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
STATUS TfliteModelParser::ConvertGraphInputs() {
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
for (auto tflite_graph_input : tflite_subgraph->inputs) {
if (tflite_graph_input < 0) {
tflite_graph_input = tflite_graph_input + tflite_subgraph->tensors.size();
}
auto parameter = func_graph_->add_parameter();
const auto &tensor = tflite_subgraph->tensors.at(tflite_graph_input);
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
parameter->set_abstract(abstract_tensor);
parameter->set_name("graph_input_" + std::to_string(tflite_graph_input) + "_parameter");
nodes_.insert(std::pair(tflite_graph_input, parameter));
}
return RET_OK;
}
STATUS TfliteModelParser::ConvertGraphOutputs() {
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
if (tflite_subgraph->outputs.size() > 1) {
std::vector<AnfNodePtr> make_tuple_inputs;
auto make_tuple_prim_ptr = GetMakeTuplePrim();
if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
return RET_NULL_PTR;
}
if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
auto attr = op->primitive->value.AsDepthwiseConv2D();
if (attr == nullptr) {
MS_LOG(ERROR) << "attr is null";
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
make_tuple_inputs.emplace_back(make_tuple_prim);
for (auto outputNode : tflite_subgraph->outputs) {
auto cnode = nodes_.at(outputNode);
if (nullptr == cnode) {
MS_LOG(ERROR) << "Can't find input node.";
return RET_NOT_FIND_OP;
}
make_tuple_inputs.emplace_back(cnode);
}
auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs);
make_tuple_cnode->set_fullname_with_scope("return tuple");
std::vector<AnfNodePtr> op_inputs;
auto return_prim_ptr = GetReturnPrim();
if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return RET_NULL_PTR;
}
auto value_node = NewValueNode(return_prim_ptr);
op_inputs.emplace_back(value_node);
op_inputs.emplace_back(make_tuple_cnode);
auto cnode = func_graph_->NewCNode(op_inputs);
cnode->set_fullname_with_scope("return");
func_graph_->set_return(cnode);
} else {
auto returnPrim = GetReturnPrim();
if (returnPrim == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return RET_NULL_PTR;
}
auto valueNode = NewValueNode(returnPrim);
std::vector<AnfNodePtr> op_inputs{valueNode};
auto cnode = nodes_.at(tflite_subgraph->outputs.front());
if (nullptr == cnode) {
MS_LOG(ERROR) << "Can't find input node.";
return RET_NOT_FIND_OP;
}
op_inputs.emplace_back(cnode);
auto returnCnode = func_graph_->NewCNode(op_inputs);
returnCnode->set_fullname_with_scope("return");
func_graph_->set_return(returnCnode);
}
return RET_OK;
}
STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter) {
if (tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null, get const tensor failed.";
return RET_NULL_PTR;
}
if (parameter == nullptr) {
MS_LOG(ERROR) << "parameter is null, get const tensor failed.";
return RET_NULL_PTR;
}
const auto &tfliteModelBuffers = tflite_model_->buffers;
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
parameter->set_abstract(abstract_tensor);
parameter->set_name("const_" + std::to_string(nodes_.size()) + "_parameter");
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
MS_ASSERT(param_value != nullptr);
param_value->set_tensor_shape(tensor->shape);
param_value->set_tensor_type(GetTfliteDataType(tensor->type));
param_value->set_format(schema::Format::Format_NHWC);
const auto &data = tfliteModelBuffers.at(tensor->buffer)->data;
if (!data.empty()) {
auto size = data.size();
char *tensor_data = new (std::nothrow) char[size];
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "new char[] failed";
return RET_MEMORY_FAILED;
}
std::memcpy(tensor_data, data.data(), size);
param_value->set_tensor_addr(tensor_data);
param_value->set_tensor_size(size);
parameter->set_default_param(param_value);
}
return RET_OK;
}
STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const CNodePtr &dst_cnode) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null, get output tensor failed.";
return RET_NULL_PTR;
}
if (dst_cnode == nullptr) {
MS_LOG(ERROR) << "parameter is null, get output tensor failed.";
return RET_NULL_PTR;
}
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
if (op->outputs.size() == 1) {
int output_idx =
op->outputs.front() < 0 ? tflite_subgraph->tensors.size() + op->outputs.front() : op->outputs.front();
const auto &tensor = tflite_subgraph->tensors.at(output_idx);
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
dst_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
nodes_.insert(std::pair(op->outputs.front(), dst_cnode));
} else {
AbstractBasePtrList abstract_list;
int op_idx = 0;
for (auto output_idx : op->outputs) {
if (output_idx < 0) {
output_idx = output_idx + tflite_subgraph->tensors.size();
}
const auto &tensor = tflite_subgraph->tensors.at(output_idx);
std::vector<int64_t> shape_vector;
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
[](const int32_t &value) { return static_cast<int64_t>(value); });
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
auto tuple_get_item_prim_ptr = GetTupleGetItemPrim();
if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
return RET_NULL_PTR;
}
if (attr->channelMultiplier > 1) {
// get channel attr
if (op->inputIndex.empty()) {
MS_LOG(ERROR) << "the input of DepthwiseConv2D is null";
return RET_NULL_PTR;
}
const auto data_id = op->inputIndex[0];
if (sub_graph->allTensors.size() <= data_id) {
MS_LOG(ERROR) << "the number of allTensors is less than " << data_id;
return RET_ERROR;
}
const auto &data_tensor = sub_graph->allTensors.at(data_id);
if (data_tensor == nullptr) {
MS_LOG(ERROR) << "the data tensor is null";
return RET_NULL_PTR;
}
auto data_shape = data_tensor->dims;
if (data_shape.empty()) {
MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain only when running";
return RET_NO_CHANGE;
}
std::unique_ptr<schema::Conv2DT> conv_attr = std::make_unique<schema::Conv2DT>();
if (conv_attr == nullptr) {
MS_LOG(ERROR) << "conv_attr is null";
return RET_NULL_PTR;
}
if (data_shape[3] == 1) {
conv_attr->channelIn = data_shape[3];
conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier;
// update attr
conv_attr->group = 1;
conv_attr->format = attr->format;
conv_attr->kernelH = attr->kernelH;
conv_attr->kernelW = attr->kernelW;
conv_attr->strideH = attr->strideH;
conv_attr->strideW = attr->strideW;
conv_attr->padMode = attr->padMode;
conv_attr->padUp = attr->padUp;
conv_attr->padDown = attr->padDown;
conv_attr->padLeft = attr->padLeft;
conv_attr->padRight = attr->padRight;
conv_attr->dilateH = attr->dilateH;
conv_attr->dilateW = attr->dilateW;
conv_attr->hasBias = attr->hasBias;
conv_attr->activationType = attr->activationType;
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = conv_attr.release();
// update weight
auto weight_id = op->inputIndex[1];
auto &weight_tensor = sub_graph->allTensors.at(weight_id);
if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) {
auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans depthwiseConv Filter schema::Format failed.";
return RET_ERROR;
}
} else if (weight_tensor->dataType == kNumberTypeInt8) {
auto status = TransFilterFormat<int8_t>(weight_tensor.get(), kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans filter format failed.";
return RET_ERROR;
}
} else if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) {
auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK);
if (status != RET_OK) {
MS_LOG(ERROR) << "Trans filter format failed.";
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "The dataType of weight tensor is unsupported.";
return RET_ERROR;
}
weight_tensor->format = schema::Format::Format_CHWK;
}
}
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
auto get_item_value = NewValueNode(MakeValue<int>(op_idx));
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, dst_cnode, get_item_value};
CNodePtr get_item_cnode = func_graph_->NewCNode(inputs);
get_item_cnode->set_fullname_with_scope(dst_cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx));
nodes_.insert(std::pair(output_idx, get_item_cnode));
op_idx++;
}
dst_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
}
return RET_OK;
}
std::unique_ptr<schema::MetaGraphT> TfliteModelParser::ConstructMainGraph(
const std::unique_ptr<tflite::ModelT> &tflite_model, const QuantType &quant_type) {
MS_ASSERT(tflite_model != nullptr);
if (tflite_model->subgraphs.empty()) {
MS_LOG(ERROR) << "read tflite model main subgraphs failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
}
const auto &tflite_subgraph = tflite_model->subgraphs[0];
auto meta_graph = std::make_unique<schema::MetaGraphT>();
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "new meta graph failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return nullptr;
}
meta_graph->name = "MS_model converted by TF-Lite";
quantType = quant_type;
// convert op
int status = ConvertOp(tflite_model, tflite_subgraph, quant_type, meta_graph.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "parse op failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
// convert tensor
status = ConvertTensor(tflite_subgraph, tflite_model->buffers, meta_graph.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "convert tensor failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
// set graph input/output
status = GetGraphInfo(tflite_subgraph, meta_graph.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "convert tensors failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
// update for depthwiseConv
status = ConvertGroupDepthwiseOp(meta_graph.get());
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "convert group depthwise conv failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
return meta_graph;
MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
return nullptr;
}
schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
if (model_file.empty()) {
MS_LOG(ERROR) << "model_file is empty";
return nullptr;
}
// load graph
auto tflite_model = ReadTfliteModel(model_file.c_str());
if (tflite_model == nullptr) {
MS_LOG(ERROR) << "read tflite model failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
}
// construct main_meta_graph
auto main_meta_graph = ConstructMainGraph(tflite_model, quant_type);
if (main_meta_graph == nullptr) {
MS_LOG(ERROR) << "ConstructMainGraph failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
return nullptr;
}
return main_meta_graph.release();
}
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,67 +13,42 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_TFLITE_MODEL_PARSER_H
#define LITE_TFLITE_MODEL_PARSER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H
#include <fcntl.h>
#include <unistd.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <string>
#include <vector>
#include <memory>
#include <map>
#include <unordered_map>
#include "securec/include/securec.h"
#include <memory>
#include <vector>
#include "tools/converter/model_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
#include "tools/common/tensor_util.h"
#include "schema/inner/model_generated.h"
namespace mindspore::lite {
class TfliteModelParser : public ModelParser {
public:
TfliteModelParser();
TfliteModelParser() = default;
~TfliteModelParser() override;
~TfliteModelParser() override = default;
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quantTyp) override;
protected:
std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path);
static STATUS CopyConstTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const tflite::TensorT *tflite_tensor, schema::TensorT *tensor);
static void SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor, schema::TensorT *tensor);
STATUS ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, const QuantType &quant_type,
schema::MetaGraphT *sub_graph);
STATUS ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::MetaGraphT *sub_graph);
STATUS GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph);
static STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph);
QuantType quantType = QuantType_QUANT_NONE;
char *tfliteModelBuf = nullptr;
std::unique_ptr<schema::MetaGraphT> ConstructMainGraph(const std::unique_ptr<tflite::ModelT> &tflite_model,
const QuantType &quant_type);
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override;
MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) override;
private:
TfliteTensorsInfo tensorsInfo;
std::vector<schema::TensorT *> tensors;
std::map<std::string, schema::CNodeT *> opMap;
std::map<const tflite::OperatorT *, schema::CNodeT *> tfliteOpMap;
std::unordered_map<int, AnfNodePtr> nodes_;
std::unique_ptr<tflite::ModelT> tflite_model_;
FuncGraphPtr func_graph_;
char *tflite_model_buf_ = nullptr;
std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path);
STATUS ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter);
STATUS ConvertOutputTensor(const tflite::OperatorT *op, const CNodePtr &dst_cnode);
STATUS ConvertOpQuantParams(const tflite::OperatorT *op, lite::PrimitiveC *primitive_c);
STATUS ConvertOps();
STATUS ConvertGraphInputs();
STATUS ConvertGraphOutputs();
STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params);
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H
#endif // LITE_TFLITE_MODEL_PARSER_H

View File

@ -41,9 +41,9 @@ class TfliteNodeParser {
virtual STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) = 0;
virtual STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, PrimitiveC *primitiveC) {
return RET_OK;
virtual lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
return nullptr;
}
static void AddOpInput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total,

View File

@ -64,6 +64,38 @@ STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteOneHotParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return nullptr;
}
std::unique_ptr<schema::OneHotT> attr = std::make_unique<schema::OneHotT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsOneHotOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op onehot attr failed";
return nullptr;
}
auto axis = tflite_attr->axis;
const auto &tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
if (tensor == nullptr) {
MS_LOG(ERROR) << "tensor is null";
return nullptr;
}
attr->axis = axis;
primitive->value.type = schema::PrimitiveType_OneHot;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteOneHotParser("OneHot", new TfliteOneHotParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteOneHotParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -90,6 +90,59 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TflitePadParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::PadT> attr = std::make_unique<schema::PadT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
if (tflite_op_type == tflite::BuiltinOperator_PAD) {
const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op pad attr failed";
return nullptr;
}
attr->paddingMode = schema::PaddingMode_CONSTANT;
attr->constantValue = 0.0f;
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->paddings)) {
MS_LOG(ERROR) << "get pad -> paddings failed";
return nullptr;
}
} else if (tflite_op_type == tflite::BuiltinOperator_MIRROR_PAD) {
const auto &tflite_attr = tflite_op->builtin_options.AsMirrorPadOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op pad attr failed";
return nullptr;
}
switch (tflite_attr->mode) {
case tflite::MirrorPadMode_REFLECT:
attr->paddingMode = schema::PaddingMode_REFLECT;
break;
case tflite::MirrorPadMode_SYMMETRIC:
attr->paddingMode = schema::PaddingMode_SYMMETRIC;
break;
default:
MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support";
return nullptr;
}
} else {
MS_LOG(ERROR) << "this pad:" << tflite_op_type << " hasn't been supported";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Pad;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tflitePadParser("Pad", new TflitePadParser());
TfliteNodeRegister g_tfliteMirorPadParser("MirrorPad", new TflitePadParser());

View File

@ -32,6 +32,8 @@ class TflitePadParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,8 +19,7 @@
#include <memory>
#include <string>
namespace mindspore {
namespace lite {
namespace mindspore::lite {
STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
@ -43,17 +42,13 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
return RET_NULL_PTR;
}
std::vector<std::string> node_name_str;
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "MeanPooling") == 0) {
MS_LOG(DEBUG) << "parser TfliteMeanPoolingParser";
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
if (tflite_op_type == tflite::BuiltinOperator_AVERAGE_POOL_2D) {
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
} else if (std::strcmp(node_name, "MaxPooling") == 0) {
MS_LOG(DEBUG) << "parse TfliteMaxPoolingParser";
} else if (tflite_op_type == tflite::BuiltinOperator_MAX_POOL_2D) {
attr->poolingMode = schema::PoolMode_MAX_POOLING;
} else {
MS_LOG(ERROR) << node_name << " hasn't been supported";
MS_LOG(ERROR) << "pooling mode " << tflite_op_type << " hasn't been supported";
return RET_NOT_FIND_OP;
}
@ -75,7 +70,7 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
// calculate pad params
auto data_index = tflite_op->inputs[0];
const auto &data_tensor = tflite_subgraph->tensors[data_index];
std::vector<int> params;
std::vector<int64_t> params;
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
@ -95,8 +90,58 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
lite::PrimitiveC *TflitePoolingParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
const auto &tflite_subgraph = tflite_model->subgraphs.front();
std::unique_ptr<schema::PoolingT> attr = std::make_unique<schema::PoolingT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
if (tflite_op_type == tflite::BuiltinOperator_AVERAGE_POOL_2D) {
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
} else if (tflite_op_type == tflite::BuiltinOperator_MAX_POOL_2D) {
attr->poolingMode = schema::PoolMode_MAX_POOLING;
}
const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op pooling attr failed";
return nullptr;
}
attr->windowW = tflite_attr->filter_width;
attr->windowH = tflite_attr->filter_height;
attr->strideW = tflite_attr->stride_w;
attr->strideH = tflite_attr->stride_h;
attr->padMode = GetPadMode(tflite_attr->padding);
attr->format = schema::Format::Format_NHWC;
attr->global = false;
attr->roundMode = schema::RoundMode_FLOOR;
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
// calculate pad params
auto data_index = tflite_op->inputs[0];
const auto &data_tensor = tflite_subgraph->tensors[data_index];
std::vector<int64_t> params;
int status =
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, &params);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get padding params failed";
return nullptr;
} else if (status == RET_OK) {
attr->padUp = params.at(0);
attr->padDown = params.at(1);
attr->padLeft = params.at(2);
attr->padRight = params.at(3);
}
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_Pooling;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteMeanPoolingParser("MeanPooling", new TflitePoolingParser());
TfliteNodeRegister g_tfliteMaxPoolingParser("MaxPooling", new TflitePoolingParser());
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

View File

@ -23,8 +23,7 @@
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TflitePoolingParser : public TfliteNodeParser {
public:
TflitePoolingParser() : TfliteNodeParser("node_name") {}
@ -32,8 +31,9 @@ class TflitePoolingParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_POOLING_PARSER_H

View File

@ -52,6 +52,24 @@ STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TflitePReLUParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->channelShared = true;
primitive->value.type = schema::PrimitiveType_PReLU;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tflitePReLUParser("PRELU", new TflitePReLUParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TflitePReLUParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -74,6 +74,50 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteQuantizeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) {
MS_LOG(ERROR) << "input tensor is null";
return nullptr;
}
const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]];
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "output tensor is null";
return nullptr;
}
if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) &&
(GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 ||
GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8)) {
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
attr->dstT = GetTfliteDataType(out_tensor->type);
primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
primitive->value.value = attr.release();
} else {
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
attr->dstT = GetTfliteDataType(out_tensor->type);
primitive->value.type = schema::PrimitiveType_Cast;
primitive->value.value = attr.release();
}
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteQuantizeParser("QUANTIZE", new TfliteQuantizeParser());
} // namespace lite

View File

@ -31,6 +31,8 @@ class TfliteQuantizeParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -71,6 +71,43 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteRangeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::RangeT> attr = std::make_unique<schema::RangeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->dType = 0;
std::vector<int> limit;
std::vector<int> delta;
int status = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, limit);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "range -> limit get failed";
return nullptr;
} else if (status == RET_OK) {
status = GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, delta);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "stridedSlice -> end get failed";
return nullptr;
}
}
if (status == RET_OK) {
attr->limit = limit.front();
attr->delta = delta.front();
}
primitive->value.type = schema::PrimitiveType_Range;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteRangeParser("Range", new TfliteRangeParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteRangeParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -50,6 +50,24 @@ STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteRankParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::RankT> attr = std::make_unique<schema::RankT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Rank;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteRankParser("Rank", new TfliteRankParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteRankParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -85,11 +85,65 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteReduceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
TfliteNodeRegister g_tfliteSumParser("Sum", new TfliteReduceParser());
TfliteNodeRegister g_tfliteMeanParser("Mean", new TfliteReduceParser());
TfliteNodeRegister g_tfliteReduceMaxParser("ReduceMax", new TfliteReduceParser());
TfliteNodeRegister g_tfliteReduceMinParser("ReduceMin", new TfliteReduceParser());
TfliteNodeRegister g_tfliteReduceProdParser("ReduceProd", new TfliteReduceParser());
std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsReducerOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op reduce attr failed";
return nullptr;
}
attr->keepDims = tflite_attr->keep_dims;
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
if (tflite_op_type == tflite::BuiltinOperator_REDUCE_MAX) {
MS_LOG(DEBUG) << "parse TfliteReduceMaxParser";
attr->mode = schema::ReduceMode_ReduceMax;
} else if (tflite_op_type == tflite::BuiltinOperator_REDUCE_MIN) {
MS_LOG(DEBUG) << "parse TfliteReduceMinParser";
attr->mode = schema::ReduceMode_ReduceMin;
} else if (tflite_op_type == tflite::BuiltinOperator_REDUCE_PROD) {
MS_LOG(DEBUG) << "parse TfliteReduceProdParser";
attr->mode = schema::ReduceMode_ReduceProd;
} else if (tflite_op_type == tflite::BuiltinOperator_SUM) {
MS_LOG(DEBUG) << "parse TfliteSumParser";
attr->mode = schema::ReduceMode_ReduceSum;
} else if (tflite_op_type == tflite::BuiltinOperator_MEAN) {
MS_LOG(DEBUG) << "parse TfliteMeanParser";
attr->mode = schema::ReduceMode_ReduceMean;
} else if (tflite_op_type == tflite::BuiltinOperator_REDUCE_ANY) {
// attr->mode;
MS_LOG(ERROR) << "ms-lite haven't supported REDUCE_ANY now";
return nullptr;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axes)) {
MS_LOG(ERROR) << "get reduce -> axes failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Reduce;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_TfliteSumParser("Sum", new TfliteReduceParser());
TfliteNodeRegister g_TfliteMeanParser("Mean", new TfliteReduceParser());
TfliteNodeRegister g_TfliteReduceMaxParser("ReduceMax", new TfliteReduceParser());
TfliteNodeRegister g_TfliteReduceMinParser("ReduceMin", new TfliteReduceParser());
TfliteNodeRegister g_TfliteReduceProdParser("ReduceProd", new TfliteReduceParser());
TfliteNodeRegister g_TfliteReduceAnyParser("ReduceAny", new TfliteReduceParser());
} // namespace lite
} // namespace mindspore

View File

@ -23,8 +23,7 @@
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteReduceParser : public TfliteNodeParser {
public:
TfliteReduceParser() : TfliteNodeParser("node_name") {}
@ -32,8 +31,9 @@ class TfliteReduceParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REDUCE_PARSER_H

View File

@ -18,8 +18,7 @@
#include <vector>
#include <memory>
namespace mindspore {
namespace lite {
namespace mindspore::lite {
STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
@ -43,8 +42,8 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
return RET_NULL_PTR;
}
const auto &tfliteAttr = tflite_op->builtin_options.AsReshapeOptions();
if (tfliteAttr == nullptr) {
const auto &tflite_attr = tflite_op->builtin_options.AsReshapeOptions();
if (tflite_attr == nullptr) {
if (tflite_op->inputs.size() < 2) {
MS_LOG(ERROR) << "expected two input tensors, but got: " << tflite_op->inputs.size();
return RET_ERROR;
@ -68,9 +67,9 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
}
} else {
attr->format = schema::Format::Format_NHWC;
attr->shape.resize(tfliteAttr->new_shape.size());
for (size_t i = 0; i < tfliteAttr->new_shape.size(); ++i) {
attr->shape[i] = tfliteAttr->new_shape[i];
attr->shape.resize(tflite_attr->new_shape.size());
for (size_t i = 0; i < tflite_attr->new_shape.size(); ++i) {
attr->shape[i] = tflite_attr->new_shape[i];
}
}
@ -83,7 +82,50 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
lite::PrimitiveC *TfliteReshapeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
const auto &tflite_subgraph = tflite_model->subgraphs.front();
std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsReshapeOptions();
if (tflite_attr == nullptr) {
if (tflite_op->inputs.size() < 2) {
MS_LOG(ERROR) << "expected two input tensors, but got: " << tflite_op->inputs.size();
return nullptr;
}
auto shape_tensor_index = tflite_op->inputs[1];
const auto &shape_tensor = tflite_subgraph->tensors[shape_tensor_index];
if (shape_tensor == nullptr) {
MS_LOG(ERROR) << "shape_tensor is null";
return nullptr;
}
auto &buf_data = tflite_model->buffers[shape_tensor->buffer];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "buf_data is null";
return nullptr;
}
if (!buf_data->data.empty()) {
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->shape)) {
MS_LOG(ERROR) << "get reshape -> shape failed";
return nullptr;
}
}
} else {
attr->format = schema::Format::Format_NHWC;
attr->shape.resize(tflite_attr->new_shape.size());
for (size_t i = 0; i < tflite_attr->new_shape.size(); ++i) {
attr->shape[i] = tflite_attr->new_shape[i];
}
}
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_Reshape;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteReshapeParser("Reshape", new TfliteReshapeParser());
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

View File

@ -23,8 +23,7 @@
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteReshapeParser : public TfliteNodeParser {
public:
TfliteReshapeParser() : TfliteNodeParser("Reshape") {}
@ -32,8 +31,10 @@ class TfliteReshapeParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESHAPE_PARSER_H

View File

@ -120,6 +120,89 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteResizeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::ResizeT> attr = std::make_unique<schema::ResizeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->coordinateTransformMode = schema::CoordinateTransformMode_COMMON;
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
if (tflite_op_type == tflite::BuiltinOperator_RESIZE_BILINEAR) {
MS_LOG(DEBUG) << "parse TfliteResizeBilinearParser";
const auto &tfliteAttr = tflite_op->builtin_options.AsResizeBilinearOptions();
if (tfliteAttr == nullptr) {
MS_LOG(ERROR) << "get op ResizeBilinear attr failed";
return nullptr;
}
if (tfliteAttr->align_corners) {
attr->alignCorners = tfliteAttr->align_corners;
attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS;
}
if (tfliteAttr->half_pixel_centers) {
attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON
? schema::CoordinateTransformMode_TF_HALF_PIXEL
: schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL);
}
attr->method = schema::ResizeMethod_LINEAR;
} else if (tflite_op_type == tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR) {
MS_LOG(DEBUG) << "parse TfliteResizeNearestNeighborParser";
const auto &tfliteAttr = tflite_op->builtin_options.AsResizeNearestNeighborOptions();
if (tfliteAttr == nullptr) {
MS_LOG(ERROR) << "get op ResizeNearestNeighbor attr failed";
return nullptr;
}
if (tfliteAttr->align_corners) {
attr->alignCorners = tfliteAttr->align_corners;
attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS;
}
if (tfliteAttr->half_pixel_centers) {
attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON
? schema::CoordinateTransformMode_TF_HALF_PIXEL
: schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL);
}
attr->method = schema::ResizeMethod_NEAREST;
attr->nearestMode = schema::NearestMode_NORMAL;
} else {
MS_LOG(ERROR) << "wrong resize type";
return nullptr;
}
attr->format = schema::Format::Format_NHWC;
attr->preserveAspectRatio = false;
auto tfliteResizeTensorIndex = tflite_op->inputs[1];
const auto &shape_tensor = tflite_subgraph->tensors[tfliteResizeTensorIndex];
if (shape_tensor == nullptr) {
MS_LOG(ERROR) << "shape_tensor is null";
return nullptr;
}
auto resizeTensorBufferIndex = shape_tensor->buffer;
const auto &buff = tflite_model->buffers.at(resizeTensorBufferIndex);
if (buff == nullptr) {
MS_LOG(ERROR) << "buff_data is null";
return nullptr;
}
auto buffData = reinterpret_cast<int32_t *>(buff->data.data());
if (buffData != nullptr) {
auto height = buffData[0];
auto width = buffData[1];
attr->newWidth = width;
attr->newHeight = height;
}
primitive->value.type = schema::PrimitiveType_Resize;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteResizeBilinearParser("ResizeBilinear", new TfliteResizeParser());
TfliteNodeRegister g_tfliteResizeNearestNeighborParser("NearestNeighbor", new TfliteResizeParser());

View File

@ -23,8 +23,7 @@
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
class TfliteResizeParser : public TfliteNodeParser {
public:
TfliteResizeParser() : TfliteNodeParser("node_name") {}
@ -32,8 +31,9 @@ class TfliteResizeParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESIZE_PARSER_H

View File

@ -55,6 +55,30 @@ STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteReverseParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::ReverseT> attr = std::make_unique<schema::ReverseT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axis)) {
MS_LOG(ERROR) << "get reverse -> axis failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Reverse;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteReverseParser("reverse", new TfliteReverseParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteReverseParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -62,6 +62,32 @@ STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteReverseSequenceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::ReverseSequenceT> attr = std::make_unique<schema::ReverseSequenceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsReverseSequenceOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op reverse attr failed";
return nullptr;
}
attr->seqAxis = tflite_attr->seq_dim;
attr->batchAxis = tflite_attr->batch_dim;
primitive->value.type = schema::PrimitiveType_ReverseSequence;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteReverseSequenceParser("ReverseSequence", new TfliteReverseSequenceParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteReverseSequenceParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -58,6 +58,29 @@ STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info,
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteScatterNdParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::ScatterNDT> attr = std::make_unique<schema::ScatterNDT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsScatterNdOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op ScatterNd attr failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_ScatterND;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteScatterNdParser("ScatterNd", new TfliteScatterNdParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteScatterNdParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -50,6 +50,24 @@ STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteShapeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_Shape;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteShapeParser("Shape", new TfliteShapeParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteShapeParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -59,6 +59,33 @@ STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::u
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteSkipGramParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::SkipGramT> attr = std::make_unique<schema::SkipGramT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
const auto &tflite_attr = tflite_op->builtin_options.AsSkipGramOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op attr failed";
return nullptr;
}
attr->includeAllGrams = tflite_attr->include_all_ngrams;
attr->maxSkipSize = tflite_attr->max_skip_size;
attr->ngramSize = tflite_attr->ngram_size;
primitive->value.type = schema::PrimitiveType_SkipGram;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteSkiGramParser("SKipGram", new TfliteSkipGramParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteSkipGramParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

View File

@ -66,6 +66,41 @@ STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
PrimitiveC *TfliteSliceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto &tflite_subgraph = tflite_model->subgraphs.front();
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is null";
return nullptr;
}
std::unique_ptr<schema::SliceT> attr = std::make_unique<schema::SliceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
attr->format = schema::Format::Format_NHWC;
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->begin)) {
MS_LOG(ERROR) << "get slice -> begin failed";
return nullptr;
}
if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->size)) {
MS_LOG(ERROR) << "get slice -> size failed";
return nullptr;
}
std::vector<int> axes;
axes.clear();
for (size_t i = 0; i < attr->begin.size(); ++i) {
axes.push_back(i);
}
attr->axes = axes;
primitive->value.type = schema::PrimitiveType_Slice;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
TfliteNodeRegister g_tfliteSliceParser("Slice", new TfliteSliceParser());
} // namespace lite

View File

@ -32,6 +32,8 @@ class TfliteSliceParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More