forked from mindspore-Ecosystem/mindspore
!8770 tflite parser supported to anf
From: @cjh9368 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
f9e4af259a
|
@ -47,13 +47,17 @@ using TensorPtr = std::shared_ptr<mindspore::tensor::Tensor>;
|
|||
constexpr int kAnfPopulaterInputNumOne = 1;
|
||||
constexpr int kAnfPopulaterInputNumTwo = 2;
|
||||
constexpr int kAnfPopulaterInputNumThree = 3;
|
||||
static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU", schema::ActivationType_RELU},
|
||||
{"ReLU6", schema::ActivationType_RELU6},
|
||||
{"Sigmoid", schema::ActivationType_SIGMOID},
|
||||
{"HSwish", schema::ActivationType_HSWISH},
|
||||
{"HSigmoid", schema::ActivationType_HSIGMOID}};
|
||||
static std::map<std::string, schema::ActivationType> kActivationTypeMap{
|
||||
{"ReLU", schema::ActivationType_RELU},
|
||||
{"ReLU6", schema::ActivationType_RELU6},
|
||||
{"Sigmoid", schema::ActivationType_SIGMOID},
|
||||
{"HSwish", schema::ActivationType_HSWISH},
|
||||
{"HSigmoid", schema::ActivationType_HSIGMOID},
|
||||
{"Swish", schema::ActivationType_SWISH},
|
||||
{"LeakyRelu", schema::ActivationType_LEAKY_RELU},
|
||||
{"Tanh", schema::ActivationType_TANH},
|
||||
{"Logistic", schema::ActivationType_SIGMOID}};
|
||||
std::vector<int> CastToInt(const ValuePtr value, bool is_vector);
|
||||
|
||||
class PrimitiveC : public mindspore::Primitive {
|
||||
public:
|
||||
// Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().
|
||||
|
|
|
@ -104,8 +104,8 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
|
|||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
if (inputs_.size() != kSplitInputNum) {
|
||||
MS_LOG(ERROR) << "inputs number is not equal to " << kSplitInputNum;
|
||||
if (inputs_.size() < kSplitInputNum) {
|
||||
MS_LOG(ERROR) << "inputs number is less to " << kSplitInputNum;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto output = outputs_.front();
|
||||
|
|
|
@ -194,6 +194,8 @@ if(ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc
|
||||
|
|
|
@ -135,6 +135,6 @@ mtk_convert_model.tflite
|
|||
mtk_model_face_dress_fp16.tflite
|
||||
smartreply.tflite
|
||||
mindspore_text_classification_tflite.tflite
|
||||
ml_location.tflite
|
||||
# ml_location.tflite
|
||||
ml_text_correction.tflite
|
||||
ml_pic_shopping.tflite
|
||||
|
|
|
@ -49,6 +49,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/graph/weight_format_transform_pass.cc
|
||||
../optimizer/graph/weight_format_hardcode_pass.cc
|
||||
../optimizer/graph/clip_convert_activation_pass.cc
|
||||
../optimizer/graph/group_depthwise_op_convert_pass.cc
|
||||
../optimizer/graph/tflite_inputs_order_exchange_pass.cc
|
||||
../optimizer/graph/unused_cast_node_remove_pass.cc
|
||||
../optimizer/graph/unused_transpose_node_remove_pass.cc
|
||||
../optimizer/graph/identity_remove_pass.cc
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
#include "tools/optimizer/fusion/conv_bn_fusion.h"
|
||||
#include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
|
||||
#include "tools/optimizer/fusion/constant_folding_fusion.h"
|
||||
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
|
||||
#include "tools/optimizer/fusion/layer_norm_fusion.h"
|
||||
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
|
||||
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
|
||||
|
@ -34,6 +33,8 @@
|
|||
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
|
||||
#include "tools/optimizer/graph/weight_format_transform_pass.h"
|
||||
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
|
||||
#include "tools/optimizer/graph/group_depthwise_op_convert_pass.h"
|
||||
#include "tools/optimizer/graph/tflite_inputs_order_exchange_pass.h"
|
||||
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
|
||||
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
|
||||
#include "tools/optimizer/graph/infershape_pass.h"
|
||||
|
@ -43,8 +44,7 @@
|
|||
#include "tools/converter/quantizer/weight_quantizer.h"
|
||||
|
||||
using std::string;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
AnfTransform::AnfTransform() = default;
|
||||
|
||||
AnfTransform::~AnfTransform() = default;
|
||||
|
@ -65,7 +65,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
|||
cf_pm->AddPass(std::make_shared<opt::ConstFoldPass>());
|
||||
|
||||
// for now - trainning is not supporting fuse operations
|
||||
if (config != nullptr && !config->trainModel) {
|
||||
if (!config->trainModel) {
|
||||
// remove quantdtype when awaretraining
|
||||
pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
|
||||
pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
|
||||
|
@ -119,6 +119,10 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
|||
}
|
||||
pm->AddPass(std::make_shared<opt::ConvConvFusion>());
|
||||
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
|
||||
if (config->fmk == lite::converter::FmkType_TFLITE) {
|
||||
convert_pm->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>());
|
||||
convert_pm->AddPass(std::make_shared<opt::TfliteInputsOrderExchangePass>());
|
||||
}
|
||||
optimizer->AddPassManager(cf_pm);
|
||||
optimizer->AddPassManager(convert_pm);
|
||||
optimizer->AddPassManager(pm);
|
||||
|
@ -168,5 +172,4 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
|||
|
||||
return new_graph;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -32,8 +32,9 @@ class ModelParser {
|
|||
|
||||
virtual ~ModelParser() = default;
|
||||
|
||||
virtual FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) {
|
||||
auto *meta_graph = ParseToFb(modelFile, weightFile, quantType);
|
||||
virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) {
|
||||
auto *meta_graph = ParseToFb(model_file, weight_file, quant_type);
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "parse model to fb failed";
|
||||
return nullptr;
|
||||
|
@ -43,8 +44,8 @@ class ModelParser {
|
|||
return func_graph;
|
||||
}
|
||||
|
||||
virtual schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType = QuantType_QUANT_NONE) = 0;
|
||||
virtual schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type = QuantType_QUANT_NONE) = 0;
|
||||
|
||||
public:
|
||||
static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) {
|
||||
|
|
|
@ -31,22 +31,22 @@ CaffeModelParser::~CaffeModelParser() {}
|
|||
|
||||
const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"};
|
||||
|
||||
schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType) {
|
||||
int status = ValidateFileStr(modelFile, ".prototxt");
|
||||
schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) {
|
||||
int status = ValidateFileStr(model_file, ".prototxt");
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (weightFile.empty()) {
|
||||
if (weight_file.empty()) {
|
||||
MS_LOG(ERROR) << "INPUT MISSING: weightFile is necessary";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = ValidateFileStr(weightFile, ".caffemodel");
|
||||
status = ValidateFileStr(weight_file, ".caffemodel");
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: weightFile must be *.caffemodel";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
|
@ -57,18 +57,18 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
|
|||
TensorCache tensorCache;
|
||||
|
||||
caffe::NetParameter proto;
|
||||
status = ReadProtoFromText((const char *)modelFile.c_str(), &proto);
|
||||
status = ReadProtoFromText((const char *)model_file.c_str(), &proto);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Read prototxt file failed, model path: " << modelFile;
|
||||
MS_LOG(ERROR) << "Read prototxt file failed, model path: " << model_file;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
metaGraph->name = proto.name();
|
||||
|
||||
caffe::NetParameter weight;
|
||||
status = ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight);
|
||||
status = ReadProtoFromBinaryFile((const char *)weight_file.c_str(), &weight);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weightFile;
|
||||
MS_LOG(ERROR) << "Read caffemodel file failed, model path: " << weight_file;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -81,7 +81,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
|
|||
}
|
||||
|
||||
NoSupportOp::GetInstance()->SetFmkType("CAFFE");
|
||||
status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quantType);
|
||||
status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quant_type);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "ParseLayer failed " << status;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
|
@ -97,7 +97,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
metaGraph->name = GetModelName(modelFile);
|
||||
metaGraph->name = GetModelName(model_file);
|
||||
|
||||
SetAllTensors(tensorCache, metaGraph.get());
|
||||
|
||||
|
|
|
@ -34,8 +34,8 @@ class CaffeModelParser : public ModelParser {
|
|||
|
||||
virtual ~CaffeModelParser();
|
||||
|
||||
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType = QuantType_QUANT_NONE) override;
|
||||
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type = QuantType_QUANT_NONE) override;
|
||||
|
||||
private:
|
||||
STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache);
|
||||
|
|
|
@ -623,9 +623,9 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType) {
|
||||
int status = ValidateFileStr(modelFile, ".onnx");
|
||||
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) {
|
||||
int status = ValidateFileStr(model_file, ".onnx");
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
|
@ -633,9 +633,9 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
|
|||
}
|
||||
|
||||
onnx::ModelProto onnx_model;
|
||||
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model);
|
||||
status = ReadProtoFromBinaryFile((const char *)model_file.c_str(), &onnx_model);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
|
||||
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -645,13 +645,13 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
|
|||
|
||||
auto dst_graph = std::make_unique<schema::MetaGraphT>();
|
||||
auto dst_sub_graph = std::make_unique<schema::SubGraphT>();
|
||||
int ret = ParseGraph(dst_graph.get(), dst_sub_graph.get(), onnx_graph, quantType);
|
||||
int ret = ParseGraph(dst_graph.get(), dst_sub_graph.get(), onnx_graph, quant_type);
|
||||
dst_graph->subGraph.push_back(std::move(dst_sub_graph));
|
||||
subGraphNum += 1;
|
||||
if (ret == RET_ERROR) {
|
||||
return nullptr;
|
||||
}
|
||||
dst_graph->name = GetModelName(modelFile);
|
||||
dst_graph->name = GetModelName(model_file);
|
||||
|
||||
std::vector<uint32_t> input_temp_index;
|
||||
for (size_t i = 0; i < dst_graph->subGraph.front()->inputIndices.size(); i++) {
|
||||
|
|
|
@ -45,8 +45,8 @@ class OnnxModelParser : public ModelParser {
|
|||
int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph,
|
||||
const QuantType &quantType);
|
||||
|
||||
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType = QuantType_QUANT_NONE) override;
|
||||
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type = QuantType_QUANT_NONE) override;
|
||||
|
||||
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
|
||||
|
||||
|
|
|
@ -1,273 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tflite/model_parser_for_tflite.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "src/param_value_lite.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
|
||||
FuncGraphPtr ModelParserForTflite::Parse(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType) {
|
||||
// load graph
|
||||
tfliteModel = ReadTfliteModel(modelFile.c_str());
|
||||
if (tfliteModel == nullptr) {
|
||||
MS_LOG(ERROR) << "read tflite model failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (tfliteModel->subgraphs.size() != 1) {
|
||||
MS_LOG(ERROR) << "read tflite model subgraphs failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
funcGraphPtr = std::make_shared<FuncGraph>();
|
||||
|
||||
auto status = ConvertGraphInputs();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert graph inputs failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = ConvertOps();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert ops failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = ConvertGraphOutputs();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert graph outputs failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
return funcGraphPtr;
|
||||
}
|
||||
|
||||
STATUS ModelParserForTflite::ConvertOps() {
|
||||
const auto &tfliteSubgraph = tfliteModel->subgraphs.front();
|
||||
const auto &tfliteModelBuffers = tfliteModel->buffers;
|
||||
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
|
||||
STATUS status = RET_OK;
|
||||
int opIdx = 0;
|
||||
for (auto &op : tfliteSubgraph->operators) {
|
||||
auto tfliteOpType = (tfliteModel->operator_codes[op->opcode_index])->builtin_code;
|
||||
auto opType = GetMSOpType(tfliteOpType);
|
||||
|
||||
// parse primitive
|
||||
auto nodeParser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType);
|
||||
if (nodeParser == nullptr) {
|
||||
NoSupportOp::GetInstance()->InsertOp(opType);
|
||||
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
|
||||
continue;
|
||||
}
|
||||
PrimitiveC *primitiveC = nullptr;
|
||||
if (status == RET_OK) {
|
||||
status = nodeParser->Parse(op, tfliteModel, primitiveC);
|
||||
if (status != RET_OK) {
|
||||
if (status == RET_NOT_FIND_OP) {
|
||||
opType = (opType != "Custom" ? opType : (tfliteModel->operator_codes[op->opcode_index])->custom_code);
|
||||
NoSupportOp::GetInstance()->InsertOp(opType);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "node " << opType.c_str() << " parser failed";
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> opInputs = {NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC))};
|
||||
// parse inputs
|
||||
for (auto inputIdx : op->inputs) {
|
||||
const auto &inputTensor = tfliteSubgraph->tensors[inputIdx];
|
||||
if (nodes.find(inputIdx) != nodes.end()) {
|
||||
opInputs.emplace_back(nodes.at(inputIdx));
|
||||
continue;
|
||||
}
|
||||
// const tensor
|
||||
if (tfliteModelBuffers.at(inputTensor->buffer)->data.empty()) {
|
||||
ParameterPtr parameter;
|
||||
ConvertConstTensor(inputTensor.get(), parameter);
|
||||
opInputs.emplace_back(parameter);
|
||||
nodes.insert(std::pair(inputIdx, parameter));
|
||||
continue;
|
||||
}
|
||||
MS_LOG(ERROR) << "tensor" << inputIdx << " is neither a node output nor a weight tensor.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto newCNode = funcGraphPtr->NewCNode(opInputs);
|
||||
newCNode->set_fullname_with_scope(opType + "-" + std::to_string(opIdx++));
|
||||
|
||||
// parse outputs
|
||||
status = ConvertOutputTensor(op.get(), newCNode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert output tensors for " << newCNode->fullname_with_scope() << " failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return status;
|
||||
}
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
STATUS ModelParserForTflite::ConvertGraphInputs() {
|
||||
const auto &tfliteSubgraph = tfliteModel->subgraphs.front();
|
||||
for (auto tfliteGraphInput : tfliteSubgraph->inputs) {
|
||||
if (tfliteGraphInput < 0) {
|
||||
tfliteGraphInput = tfliteGraphInput + tfliteSubgraph->tensors.size();
|
||||
}
|
||||
auto parameter = funcGraphPtr->add_parameter();
|
||||
const auto &tensor = tfliteSubgraph->tensors.at(tfliteGraphInput);
|
||||
std::vector<int64_t> shape_vector;
|
||||
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
||||
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
||||
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
|
||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
||||
parameter->set_abstract(abstract_tensor);
|
||||
parameter->set_name("graph_input_" + std::to_string(tfliteGraphInput) + "_parameter");
|
||||
nodes.insert(std::pair(tfliteGraphInput, parameter));
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
STATUS ModelParserForTflite::ConvertGraphOutputs() {
|
||||
const auto &tfliteSubgraph = tfliteModel->subgraphs.front();
|
||||
if (tfliteSubgraph->outputs.size() > 1) {
|
||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||
auto make_tuple_prim_ptr = GetMakeTuplePrim();
|
||||
if (make_tuple_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
|
||||
make_tuple_inputs.emplace_back(make_tuple_prim);
|
||||
for (auto outputNode : tfliteSubgraph->outputs) {
|
||||
auto cnode = nodes.at(outputNode);
|
||||
if (nullptr == cnode) {
|
||||
MS_LOG(ERROR) << "Can't find input node.";
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
make_tuple_inputs.emplace_back(cnode);
|
||||
}
|
||||
auto make_tuple_cnode = funcGraphPtr->NewCNode(make_tuple_inputs);
|
||||
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
||||
|
||||
std::vector<AnfNodePtr> op_inputs;
|
||||
auto return_prim_ptr = GetReturnPrim();
|
||||
if (return_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto value_node = NewValueNode(return_prim_ptr);
|
||||
op_inputs.emplace_back(value_node);
|
||||
op_inputs.emplace_back(make_tuple_cnode);
|
||||
auto cnode = funcGraphPtr->NewCNode(op_inputs);
|
||||
cnode->set_fullname_with_scope("return");
|
||||
funcGraphPtr->set_return(cnode);
|
||||
} else {
|
||||
auto returnPrim = GetReturnPrim();
|
||||
if (returnPrim == nullptr) {
|
||||
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto valueNode = NewValueNode(returnPrim);
|
||||
std::vector<AnfNodePtr> opInputs{valueNode};
|
||||
auto cnode = nodes.at(tfliteSubgraph->outputs.front());
|
||||
if (nullptr == cnode) {
|
||||
MS_LOG(ERROR) << "Can't find input node.";
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
opInputs.emplace_back(cnode);
|
||||
auto returnCnode = funcGraphPtr->NewCNode(opInputs);
|
||||
returnCnode->set_fullname_with_scope("return");
|
||||
funcGraphPtr->set_return(returnCnode);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelParserForTflite::ConvertConstTensor(const tflite::TensorT *tensor, ParameterPtr parameter) {
|
||||
parameter = funcGraphPtr->add_parameter();
|
||||
const auto &tfliteModelBuffers = tfliteModel->buffers;
|
||||
auto type_id = static_cast<TypeId>(tensor->type);
|
||||
auto type_ptr = TypeIdToType(type_id);
|
||||
std::vector<int64_t> shape_vector;
|
||||
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
||||
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
||||
parameter->set_abstract(abstract_tensor);
|
||||
parameter->set_name("const_" + std::to_string(nodes.size()) + "_parameter");
|
||||
|
||||
ParamValueLitePtr paramValue = std::make_shared<ParamValueLite>();
|
||||
MS_ASSERT(paramValue != nullptr);
|
||||
paramValue->set_tensor_shape(tensor->shape);
|
||||
paramValue->set_tensor_type(GetTfliteDataType(tensor->type));
|
||||
paramValue->set_format(schema::Format::Format_NHWC);
|
||||
const auto &data = tfliteModelBuffers.at(tensor->buffer)->data;
|
||||
if (!data.empty()) {
|
||||
auto size = data.size();
|
||||
char *tensor_data = new (std::nothrow) char[size];
|
||||
if (tensor_data == nullptr) {
|
||||
MS_LOG(ERROR) << "new char[] failed";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
std::memcpy(tensor_data, data.data(), size);
|
||||
paramValue->set_tensor_addr(tensor_data);
|
||||
paramValue->set_tensor_size(size);
|
||||
parameter->set_default_param(paramValue);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelParserForTflite::ConvertOutputTensor(const tflite::OperatorT *op, CNodePtr dstCNode) {
|
||||
MS_ASSERT(op != nullptr);
|
||||
MS_ASSERT(dstCNode != nullptr);
|
||||
const auto &tfliteSubgraph = tfliteModel->subgraphs.front();
|
||||
if (op->outputs.size() == 1) {
|
||||
const auto &tensor = tfliteSubgraph->tensors.at(op->outputs.front());
|
||||
std::vector<int64_t> shape_vector;
|
||||
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
||||
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
||||
auto typePtr = TypeIdToType(GetTfliteDataType(tensor->type));
|
||||
dstCNode->set_abstract(std::make_shared<abstract::AbstractTensor>(typePtr, shape_vector));
|
||||
nodes.insert(std::pair(op->outputs.front(), dstCNode));
|
||||
} else {
|
||||
AbstractBasePtrList abstractList;
|
||||
for (auto outputIdx : op->outputs) {
|
||||
const auto &tensor = tfliteSubgraph->tensors.at(outputIdx);
|
||||
std::vector<int64_t> shape_vector;
|
||||
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
||||
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
||||
auto typePtr = TypeIdToType(GetTfliteDataType(tensor->type));
|
||||
abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(typePtr, shape_vector));
|
||||
auto tupleGetItemPrimPtr = GetTupleGetItemPrim();
|
||||
if (tupleGetItemPrimPtr == nullptr) {
|
||||
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr);
|
||||
auto getItemValue = NewValueNode(MakeValue<int>(outputIdx));
|
||||
std::vector<AnfNodePtr> inputs{tupleGetItemPrim, dstCNode, getItemValue};
|
||||
CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs);
|
||||
getItemCNode->set_fullname_with_scope(dstCNode->fullname_with_scope() + "_getitem_" + std::to_string(outputIdx));
|
||||
nodes.insert(std::pair(outputIdx, getItemCNode));
|
||||
}
|
||||
dstCNode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList));
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite
|
|
@ -1,44 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef LITE_MODEL_PARSER_FOR_TFLITE_H
|
||||
#define LITE_MODEL_PARSER_FOR_TFLITE_H
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/tflite/tflite_model_parser.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class ModelParserForTflite : public TfliteModelParser {
|
||||
public:
|
||||
ModelParserForTflite() = default;
|
||||
|
||||
~ModelParserForTflite() override = default;
|
||||
|
||||
FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) override;
|
||||
|
||||
private:
|
||||
std::unordered_map<int, AnfNodePtr> nodes;
|
||||
std::unique_ptr<tflite::ModelT> tfliteModel;
|
||||
FuncGraphPtr funcGraphPtr;
|
||||
STATUS ConvertConstTensor(const tflite::TensorT *tensor, ParameterPtr parameter);
|
||||
STATUS ConvertOutputTensor(const tflite::OperatorT *op, CNodePtr dstCNode);
|
||||
STATUS ConvertOps();
|
||||
STATUS ConvertGraphInputs();
|
||||
STATUS ConvertGraphOutputs();
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
#endif // LITE_MODEL_PARSER_FOR_TFLITE_H
|
|
@ -18,9 +18,11 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "src/ops/activation.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "tools/converter/parser/tflite/tflite_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
|
@ -86,12 +88,40 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteReluParser("Relu", new TfliteActivationParser());
|
||||
TfliteNodeRegister g_tfliteRelu6Parser("Relu6", new TfliteActivationParser());
|
||||
TfliteNodeRegister g_tfliteTanhParser("Tanh", new TfliteActivationParser());
|
||||
TfliteNodeRegister g_tfliteSwishParser("Swish", new TfliteActivationParser());
|
||||
TfliteNodeRegister g_tfliteHardSwishParser("HardSwish", new TfliteActivationParser());
|
||||
lite::PrimitiveC *TfliteActivationParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
|
||||
auto ms_op_type = GetMSOpType(tflite_op_type);
|
||||
if (kActivationTypeMap.find(ms_op_type) == kActivationTypeMap.end()) {
|
||||
MS_LOG(ERROR) << ms_op_type << "is a not supported activation type";
|
||||
return nullptr;
|
||||
}
|
||||
attr->type = kActivationTypeMap.find(GetMSOpType(tflite_op_type))->second;
|
||||
if (attr->type == schema::ActivationType_LEAKY_RELU) {
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsLeakyReluOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op: " << GetMSOpType(tflite_op_type) << " attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->alpha = tflite_attr->alpha;
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
primitive->value.type = schema::PrimitiveType_Activation;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_TfliteReluParser("ReLU", new TfliteActivationParser());
|
||||
TfliteNodeRegister g_TfliteRelu6Parser("ReLU6", new TfliteActivationParser());
|
||||
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteActivationParser());
|
||||
TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteActivationParser());
|
||||
TfliteNodeRegister g_TfliteHardSwishParser("HSwish", new TfliteActivationParser());
|
||||
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteActivationParser());
|
||||
TfliteNodeRegister g_tfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteActivationParser());
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -23,8 +23,7 @@
|
|||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
class TfliteActivationParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteActivationParser() : TfliteNodeParser("node_name") {}
|
||||
|
@ -32,9 +31,10 @@ class TfliteActivationParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_ACTIVATION_PARSER_H
|
||||
|
|
|
@ -18,9 +18,10 @@
|
|||
#include "tools/converter/parser/tflite/tflite_addn_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "src/ops/addn.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
|
@ -55,7 +56,18 @@ STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
lite::PrimitiveC *TfliteAddNParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto attr = std::make_unique<schema::AddNT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
primitive->value.type = schema::PrimitiveType_AddN;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteAddNParser("AddN", new TfliteAddNParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -32,6 +32,9 @@ class TfliteAddNParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
|
||||
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -76,6 +76,39 @@ STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteArgmaxParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
const auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
std::unique_ptr<schema::ArgMaxT> attr = std::make_unique<schema::ArgMaxT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->outMaxValue = false;
|
||||
attr->topK = 1;
|
||||
attr->keepDims = false;
|
||||
attr->axisType = 1;
|
||||
|
||||
// get axis attr
|
||||
auto axis_idx = tflite_op->inputs[1];
|
||||
auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer;
|
||||
auto &buf_data = tflite_model->buffers[buffer_idx];
|
||||
if (buf_data == nullptr) {
|
||||
MS_LOG(ERROR) << "the buf data is null";
|
||||
return nullptr;
|
||||
}
|
||||
auto data_ptr = buf_data->data.data();
|
||||
if (data_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "the data is null";
|
||||
return nullptr;
|
||||
}
|
||||
attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr)));
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
primitive->value.type = schema::PrimitiveType_ArgMax;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteArgmaxParser("Argmax", new TfliteArgmaxParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,9 @@ class TfliteArgmaxParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -76,6 +76,39 @@ STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteArgminParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
const auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
std::unique_ptr<schema::ArgMinT> attr = std::make_unique<schema::ArgMinT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->outMaxValue = false;
|
||||
attr->topK = 1;
|
||||
attr->keepDims = false;
|
||||
attr->axisType = 1;
|
||||
|
||||
// get axis attr
|
||||
auto axis_idx = tflite_op->inputs[1];
|
||||
auto buffer_idx = tflite_subgraph->tensors[axis_idx]->buffer;
|
||||
auto &buf_data = tflite_model->buffers[buffer_idx];
|
||||
if (buf_data == nullptr) {
|
||||
MS_LOG(ERROR) << "the buf data is null";
|
||||
return nullptr;
|
||||
}
|
||||
auto data_ptr = buf_data->data.data();
|
||||
if (data_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "the data is null";
|
||||
return nullptr;
|
||||
}
|
||||
attr->axis = *(static_cast<int32_t *>(static_cast<void *>(data_ptr)));
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
primitive->value.type = schema::PrimitiveType_ArgMin;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteArgminParser("Argmin", new TfliteArgminParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteArgminParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -179,6 +179,133 @@ STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteDoubleInputOpParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (tflite_op_type == tflite::BuiltinOperator_ADD) {
|
||||
MS_LOG(DEBUG) << "parse TfliteAddParser";
|
||||
auto attr = std::make_unique<schema::AddT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsAddOptions();
|
||||
if (nullptr == tfliteAttr) {
|
||||
MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||
primitive->value.type = schema::PrimitiveType_Add;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_SUB) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSubParser";
|
||||
auto attr = std::make_unique<schema::SubT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsSubOptions();
|
||||
if (nullptr == tfliteAttr) {
|
||||
MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||
primitive->value.type = schema::PrimitiveType_Sub;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_MUL) {
|
||||
MS_LOG(DEBUG) << "parse TfliteMulParser";
|
||||
auto attr = std::make_unique<schema::MulT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsMulOptions();
|
||||
if (nullptr == tfliteAttr) {
|
||||
MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||
primitive->value.type = schema::PrimitiveType_Mul;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_DIV) {
|
||||
MS_LOG(DEBUG) << "parse TfliteDivParser";
|
||||
auto attr = std::make_unique<schema::DivT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsDivOptions();
|
||||
if (nullptr == tfliteAttr) {
|
||||
MS_LOG(ERROR) << "get op: " << tflite_op_type << " attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->activationType = GetActivationFunctionType(tfliteAttr->fused_activation_function);
|
||||
primitive->value.type = schema::PrimitiveType_Div;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_FLOOR_DIV) {
|
||||
MS_LOG(DEBUG) << "parse TfliteFloorDivParser";
|
||||
std::unique_ptr<schema::FloorDivT> attr = std::make_unique<schema::FloorDivT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_FloorDiv;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_FLOOR_MOD) {
|
||||
MS_LOG(DEBUG) << "parse TfliteFloorModParser";
|
||||
auto attr = std::make_unique<schema::FloorModT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_FloorMod;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_SQUARED_DIFFERENCE) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser";
|
||||
auto attr = std::make_unique<schema::SquaredDifferenceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_SquaredDifference;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_POW) {
|
||||
MS_LOG(DEBUG) << "parse TflitePowParser";
|
||||
auto attr = std::make_unique<schema::PowerT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->power = 1.0f;
|
||||
attr->scale = 1.0f;
|
||||
attr->shift = 0.0f;
|
||||
primitive->value.type = schema::PrimitiveType_Power;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_MAXIMUM) {
|
||||
MS_LOG(DEBUG) << "parse TfliteMaximumParser";
|
||||
auto attr = std::make_unique<schema::MaximumT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Maximum;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_MINIMUM) {
|
||||
MS_LOG(DEBUG) << "parse TfliteMinimumParser";
|
||||
auto attr = std::make_unique<schema::MinimumT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Minimum;
|
||||
primitive->value.value = attr.release();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "op hasn't been supported";
|
||||
return nullptr;
|
||||
}
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
|
@ -320,6 +447,124 @@ STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteSingleInputOpParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (tflite_op_type == tflite::BuiltinOperator_ABS) {
|
||||
MS_LOG(DEBUG) << "parse TfliteAbsParser";
|
||||
auto attr = std::make_unique<schema::AbsT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Abs;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_EXP) {
|
||||
MS_LOG(DEBUG) << "parse TfliteExpParser";
|
||||
auto attr = std::make_unique<schema::ExpT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->base = -1; // -1 represent base = e
|
||||
attr->scale = 1;
|
||||
attr->shift = 0;
|
||||
primitive->value.type = schema::PrimitiveType_Exp;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_SQRT) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSqrtParser";
|
||||
auto attr = std::make_unique<schema::SqrtT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Sqrt;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_RSQRT) {
|
||||
MS_LOG(DEBUG) << "parse TfliteRsqrtParser";
|
||||
auto attr = std::make_unique<schema::RsqrtT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Rsqrt;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_SQUARE) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSquareParser";
|
||||
auto attr = std::make_unique<schema::SquareT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Square;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_SIN) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSinParser";
|
||||
auto attr = std::make_unique<schema::SinT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Sin;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_COS) {
|
||||
MS_LOG(DEBUG) << "parse TfliteCosParser";
|
||||
std::unique_ptr<schema::CosT> attr = std::make_unique<schema::CosT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Cos;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_LOG) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLogParser";
|
||||
auto attr = std::make_unique<schema::LogT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Log;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_ROUND) {
|
||||
MS_LOG(DEBUG) << "parse TfliteRoundParser";
|
||||
auto attr = std::make_unique<schema::RoundT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Round;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_CEIL) {
|
||||
MS_LOG(DEBUG) << "parse TfliteCeilParser";
|
||||
auto attr = std::make_unique<schema::CeilT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Ceil;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_FLOOR) {
|
||||
MS_LOG(DEBUG) << "parse TfliteFloorParser";
|
||||
auto attr = std::make_unique<schema::FloorT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Floor;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_NEG) {
|
||||
MS_LOG(DEBUG) << "parse TfliteNegParser";
|
||||
auto attr = std::make_unique<schema::NegT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Neg;
|
||||
primitive->value.value = attr.release();
|
||||
}
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
|
@ -406,29 +651,91 @@ STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteCompareOpParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
|
||||
if (tflite_op_type == tflite::BuiltinOperator_EQUAL) {
|
||||
MS_LOG(DEBUG) << "parse TfliteEqualParser";
|
||||
std::unique_ptr<schema::EqualT> attr = std::make_unique<schema::EqualT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Equal;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_NOT_EQUAL) {
|
||||
MS_LOG(DEBUG) << "parse TfliteNotEqualParser";
|
||||
std::unique_ptr<schema::NotEqualT> attr = std::make_unique<schema::NotEqualT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_NotEqual;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_GREATER) {
|
||||
MS_LOG(DEBUG) << "parse TfliteGreaterParser";
|
||||
std::unique_ptr<schema::GreaterT> attr = std::make_unique<schema::GreaterT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Greater;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_GREATER_EQUAL) {
|
||||
MS_LOG(DEBUG) << "parse TfliteGreaterEqualParser";
|
||||
std::unique_ptr<schema::GreaterEqualT> attr = std::make_unique<schema::GreaterEqualT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_GreaterEqual;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_LESS) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLessParser";
|
||||
std::unique_ptr<schema::LessT> attr = std::make_unique<schema::LessT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Less;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_LESS_EQUAL) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLessEqualParser";
|
||||
std::unique_ptr<schema::LessEqualT> attr = std::make_unique<schema::LessEqualT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_LessEqual;
|
||||
primitive->value.value = attr.release();
|
||||
}
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteAddParser("Add", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteSubParser("Sub", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteMulParser("Mul", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteDivParser("Div", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_TfliteMulParser("Mul", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_TfliteDivParser("Div", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteFloorDivParser("FloorDiv", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteFloorModParser("FloorMod", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteRealDivParser("RealDiv", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_tflitePowParser("Pow", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_TflitePowParser("Pow", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteSquaredDifferenceParser("SquaredDifference", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteMaximumParser("Maximum", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteMinimumParser("Minimum", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_TfliteMaximumParser("Maximum", new TfliteDoubleInputOpParser());
|
||||
TfliteNodeRegister g_TfliteMinimumParser("Minimum", new TfliteDoubleInputOpParser());
|
||||
|
||||
TfliteNodeRegister g_tfliteAbsParser("Abs", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteExpParser("Exp", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteSqrtParser("Sqrt", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_TfliteAbsParser("Abs", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_TfliteExpParser("Exp", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_TfliteSqrtParser("Sqrt", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteRsqrtParser("Rsqrt", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteSquareParser("Square", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteSinParser("Sin", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteCosParser("Cos", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteLogParser("Log", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_TfliteSquareParser("Square", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_TfliteSinParser("Sin", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_TfliteCosParser("Cos", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_TfliteLogParser("Log", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteCeilParser("Ceil", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteSingleInputOpParser());
|
||||
TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteSingleInputOpParser());
|
||||
|
||||
|
|
|
@ -32,6 +32,9 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
|
||||
class TfliteSingleInputOpParser : public TfliteNodeParser {
|
||||
|
@ -41,6 +44,9 @@ class TfliteSingleInputOpParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
|
||||
class TfliteCompareOpParser : public TfliteNodeParser {
|
||||
|
@ -50,7 +56,11 @@ class TfliteCompareOpParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -74,6 +74,29 @@ STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteBatchToSpaceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
const auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
std::unique_ptr<schema::BatchToSpaceT> attr = std::make_unique<schema::BatchToSpaceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->blockShape)) {
|
||||
MS_LOG(ERROR) << "get batchToSpace -> blockShape failed";
|
||||
return nullptr;
|
||||
}
|
||||
if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->crops)) {
|
||||
MS_LOG(ERROR) << "get batchToSpace -> crops failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_BatchToSpace;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteBatchToSpaceParser("BatchToSpace", new TfliteBatchToSpaceParser());
|
||||
TfliteNodeRegister g_tfliteBatchToSpaceNDParser("BatchToSpaceND", new TfliteBatchToSpaceParser());
|
||||
|
|
|
@ -32,7 +32,10 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -57,6 +57,28 @@ STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteBroadcastToParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<schema::BroadcastToT> attr = std::make_unique<schema::BroadcastToT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dst_shape)) {
|
||||
MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_BroadcastTo;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteBroadcastToParser("BroadcastTo", new TfliteBroadcastToParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -23,8 +23,7 @@
|
|||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
class TfliteBroadcastToParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {}
|
||||
|
@ -32,8 +31,10 @@ class TfliteBroadcastToParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BROADCAST_TO_PARSER_H
|
||||
|
|
|
@ -63,6 +63,32 @@ STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteCastParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
|
||||
if (in_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
attr->srcT = GetTfliteDataType(in_tensor->type);
|
||||
const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]];
|
||||
if (out_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
attr->dstT = GetTfliteDataType(out_tensor->type);
|
||||
primitive->value.type = schema::PrimitiveType_Cast;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteCastParser("Cast", new TfliteCastParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteCastParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -60,6 +60,26 @@ STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteConcatParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
std::unique_ptr<schema::ConcatT> attr = std::make_unique<schema::ConcatT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsConcatenationOptions();
|
||||
if (tfliteAttr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op concat attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->axis = tfliteAttr->axis;
|
||||
attr->n = tflite_op->inputs.size();
|
||||
primitive->value.type = schema::PrimitiveType_Concat;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteConcatParser("Concat", new TfliteConcatParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteConcatParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,8 +18,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
|
@ -74,7 +73,7 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
// calculate pad params
|
||||
auto data_index = tflite_op->inputs[0];
|
||||
const auto &data_tensor = tflite_subgraph->tensors[data_index];
|
||||
std::vector<int> params;
|
||||
std::vector<int64_t> params;
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
|
@ -96,7 +95,63 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
lite::PrimitiveC *TfliteConvParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
const auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
std::unique_ptr<schema::Conv2DT> attr = std::make_unique<schema::Conv2DT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsConv2DOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get conv attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->group = 1;
|
||||
attr->strideW = tflite_attr->stride_w;
|
||||
attr->strideH = tflite_attr->stride_h;
|
||||
attr->dilateH = tflite_attr->dilation_h_factor;
|
||||
attr->dilateW = tflite_attr->dilation_w_factor;
|
||||
attr->padMode = GetPadMode(tflite_attr->padding);
|
||||
attr->format = schema::Format::Format_NHWC;
|
||||
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
|
||||
attr->hasBias = true;
|
||||
|
||||
// get the conv op weight tensor
|
||||
auto weight_index = tflite_op->inputs[1];
|
||||
const auto &weight_tensor = tflite_subgraph->tensors[weight_index];
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the weight tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
auto weight_shape = weight_tensor->shape;
|
||||
attr->channelIn = weight_shape[3];
|
||||
attr->channelOut = weight_shape[0];
|
||||
attr->kernelH = weight_shape[1];
|
||||
attr->kernelW = weight_shape[2];
|
||||
|
||||
// calculate pad params
|
||||
auto data_index = tflite_op->inputs[0];
|
||||
const auto &data_tensor = tflite_subgraph->tensors[data_index];
|
||||
std::vector<int64_t> params;
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get padding params failed";
|
||||
return nullptr;
|
||||
} else if (status == RET_OK) {
|
||||
attr->padUp = params.at(0);
|
||||
attr->padDown = params.at(1);
|
||||
attr->padLeft = params.at(2);
|
||||
attr->padRight = params.at(3);
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteConv2DParser("Conv2D", new TfliteConvParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -23,8 +23,7 @@
|
|||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
class TfliteConvParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteConvParser() : TfliteNodeParser("Conv2D") {}
|
||||
|
@ -32,8 +31,9 @@ class TfliteConvParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
|
||||
|
|
|
@ -15,9 +15,8 @@
|
|||
*/
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_converter.h"
|
||||
#include "tools/converter/parser/tflite/tflite_model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
TfliteConverter::TfliteConverter() { modelParser = new TfliteModelParser(); }
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -21,18 +21,15 @@
|
|||
#include <memory>
|
||||
#include <map>
|
||||
#include "tools/converter/converter.h"
|
||||
#include "tools/converter/parser/tflite/tflite_model_parser.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
class TfliteConverter : public Converter {
|
||||
public:
|
||||
TfliteConverter();
|
||||
|
||||
~TfliteConverter() override = default;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_
|
||||
|
|
|
@ -271,6 +271,48 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
}
|
||||
return status;
|
||||
}
|
||||
PrimitiveC *TfliteCustomParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto op = new schema::CNodeT;
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
const auto &custom_attr = tflite_op->custom_options;
|
||||
const auto &opcode_index = tflite_op->opcode_index;
|
||||
const auto &custom_type = tflite_model->operator_codes[opcode_index]->custom_code;
|
||||
int status = RET_OK;
|
||||
if (custom_type == "TFLite_Detection_PostProcess") {
|
||||
status = DetectPostProcess(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "Predict") {
|
||||
status = Predict(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "Normalize") {
|
||||
status = Normalize(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "ExtractFeatures") {
|
||||
status = ExtractFeatures(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "AudioSpectrogram") {
|
||||
status = AudioSpectrogram(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "Mfcc") {
|
||||
status = Mfcc(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "FlexRFFT") {
|
||||
status = Rfft(custom_attr, op, tflite_op, tflite_model, tflite_subgraph);
|
||||
} else if (custom_type == "FlexReal") {
|
||||
status = FftReal(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "FlexImag") {
|
||||
status = FftImag(custom_attr, op, tflite_op);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "the custom op hasn't been supported now";
|
||||
status = RET_NOT_FIND_OP;
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
auto primitive = op->primitive.release();
|
||||
delete op;
|
||||
return PrimitiveC::Create(primitive);
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteCustomParser("Custom", new TfliteCustomParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -31,6 +31,8 @@ class TfliteCustomParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
|
||||
static STATUS DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op);
|
||||
|
|
|
@ -75,7 +75,7 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
// calculate pad params
|
||||
auto data_index = tflite_op->inputs[2];
|
||||
const auto &data_tensor = tflite_subgraph->tensors[data_index];
|
||||
std::vector<int> params;
|
||||
std::vector<int64_t> params;
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
|
@ -96,6 +96,64 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteDeConvParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
std::unique_ptr<schema::DeConv2DT> attr = std::make_unique<schema::DeConv2DT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsTransposeConvOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op deconv attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->group = 1;
|
||||
attr->strideW = tflite_attr->stride_w;
|
||||
attr->strideH = tflite_attr->stride_h;
|
||||
attr->dilateH = 1;
|
||||
attr->dilateW = 1;
|
||||
attr->padMode = GetPadMode(tflite_attr->padding);
|
||||
attr->format = schema::Format::Format_NHWC;
|
||||
attr->activationType = schema::ActivationType_NO_ACTIVATION;
|
||||
attr->hasBias = true;
|
||||
|
||||
// get the conv op weight tensor
|
||||
auto weight_index = tflite_op->inputs[1];
|
||||
const auto &weight_tensor = tflite_subgraph->tensors[weight_index];
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the weight tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
auto weight_shape = weight_tensor->shape;
|
||||
attr->channelIn = weight_shape[3];
|
||||
attr->channelOut = weight_shape[0];
|
||||
attr->kernelH = weight_shape[1];
|
||||
attr->kernelW = weight_shape[2];
|
||||
|
||||
// calculate pad params
|
||||
auto data_index = tflite_op->inputs[2];
|
||||
const auto &data_tensor = tflite_subgraph->tensors[data_index];
|
||||
std::vector<int64_t> params;
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get padding params failed";
|
||||
return nullptr;
|
||||
} else if (status == RET_OK) {
|
||||
attr->padUp = params.at(0);
|
||||
attr->padDown = params.at(1);
|
||||
attr->padLeft = params.at(2);
|
||||
attr->padRight = params.at(3);
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_DeConv2D;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteDeConv2DParser("DeConv2D", new TfliteDeConvParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteDeConvParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -60,6 +60,26 @@ STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteDepthToSpaceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
std::unique_ptr<schema::DepthToSpaceT> attr = std::make_unique<schema::DepthToSpaceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsDepthToSpaceOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op depthtospace attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->blockSize = tflite_attr->block_size;
|
||||
attr->format = schema::Format::Format_NHWC;
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
primitive->value.type = schema::PrimitiveType_Concat;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteDepthToSpaceParser("DepthToSpace", new TfliteDepthToSpaceParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,8 +18,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
|
@ -82,7 +81,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
attr->kernelW = weight_shape[2];
|
||||
|
||||
// calculate pad params
|
||||
std::vector<int> params;
|
||||
std::vector<int64_t> params;
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
|
@ -104,7 +103,71 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
lite::PrimitiveC *TfliteDepthwiseConv2DParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser";
|
||||
std::unique_ptr<schema::DepthwiseConv2DT> attr = std::make_unique<schema::DepthwiseConv2DT>();
|
||||
const auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsDepthwiseConv2DOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op de attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->strideW = tflite_attr->stride_w;
|
||||
attr->strideH = tflite_attr->stride_h;
|
||||
attr->dilateH = tflite_attr->dilation_h_factor;
|
||||
attr->dilateW = tflite_attr->dilation_w_factor;
|
||||
attr->padMode = GetPadMode(tflite_attr->padding);
|
||||
attr->format = schema::Format::Format_NHWC;
|
||||
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
|
||||
attr->hasBias = true;
|
||||
attr->channelMultiplier = tflite_attr->depth_multiplier;
|
||||
|
||||
// get the data tensor
|
||||
auto data_index = tflite_op->inputs[1];
|
||||
const auto &data_tensor = tflite_subgraph->tensors[data_index];
|
||||
if (data_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the data tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
auto data_shape = data_tensor->shape;
|
||||
attr->channelIn = data_shape[3];
|
||||
|
||||
// get the weight tensor
|
||||
auto weight_index = tflite_op->inputs[1];
|
||||
const auto &weight_tensor = tflite_subgraph->tensors[weight_index];
|
||||
if (weight_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the weight tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
auto weight_shape = weight_tensor->shape;
|
||||
attr->kernelH = weight_shape[1];
|
||||
attr->kernelW = weight_shape[2];
|
||||
|
||||
// calculate pad params
|
||||
std::vector<int64_t> params;
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get padding params failed";
|
||||
return nullptr;
|
||||
} else if (status == RET_OK) {
|
||||
attr->padUp = params.at(0);
|
||||
attr->padDown = params.at(1);
|
||||
attr->padLeft = params.at(2);
|
||||
attr->padRight = params.at(3);
|
||||
}
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteDepthwiseConv2DParser("DepthwiseConv2D", new TfliteDepthwiseConv2DParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -23,8 +23,7 @@
|
|||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
class TfliteDepthwiseConv2DParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {}
|
||||
|
@ -32,8 +31,10 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
|
||||
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONV_PARSER_H
|
||||
|
|
|
@ -75,6 +75,45 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteDequantizeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
|
||||
if (in_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "input tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]];
|
||||
if (out_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "output tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) &&
|
||||
(GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 ||
|
||||
GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8)) {
|
||||
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->srcT = GetTfliteDataType(in_tensor->type);
|
||||
attr->dstT = GetTfliteDataType(out_tensor->type);
|
||||
primitive->value.value = attr.release();
|
||||
primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
|
||||
} else {
|
||||
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->srcT = GetTfliteDataType(in_tensor->type);
|
||||
attr->dstT = GetTfliteDataType(out_tensor->type);
|
||||
primitive->value.value = attr.release();
|
||||
primitive->value.type = schema::PrimitiveType_Cast;
|
||||
}
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteDequantizeParser("DEQUANTIZE", new TfliteDequantizeParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -31,6 +31,9 @@ class TfliteDequantizeParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -56,6 +56,30 @@ STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteExpandDimsParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ExpandDimsT> attr = std::make_unique<schema::ExpandDimsT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<int> dims;
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, dims)) {
|
||||
MS_LOG(ERROR) << "get expand_dims -> dim failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->dim = dims[0];
|
||||
primitive->value.type = schema::PrimitiveType_ExpandDims;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteExpandDimsParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -57,6 +57,32 @@ STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteFillParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::FillT> attr = std::make_unique<schema::FillT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (tflite_op->inputs.size() > 1) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->dims)) {
|
||||
MS_LOG(ERROR) << "get fill -> dims failed";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Fill;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteFillParser("Fill", new TfliteFillParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,9 @@ class TfliteFillParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -69,6 +69,37 @@ STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteFullyConnectedParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::FullConnectionT> attr = std::make_unique<schema::FullConnectionT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsFullyConnectedOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op fully connect attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool hasBias = tflite_op->inputs.size() > 2 && tflite_op->inputs[2] != -1;
|
||||
|
||||
attr->hasBias = hasBias;
|
||||
attr->axis = 1;
|
||||
attr->useAxis = false;
|
||||
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_FullConnection;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteFullyConnectedParser("FullyConnected", new TfliteFullyConnectedParser());
|
||||
TfliteNodeRegister g_tfliteFakeQuantParser("FakeQuant", new TfliteFullyConnectedParser());
|
||||
|
|
|
@ -32,7 +32,10 @@ class TfliteFullyConnectedParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -54,6 +54,26 @@ STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::u
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteGatherNdParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::GatherNdT> attr = std::make_unique<schema::GatherNdT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->batchDims = 0;
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_GatherNd;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteGatherNdParser("GatherND", new TfliteGatherNdParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteGatherNdParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -60,6 +60,32 @@ STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteGatherParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::GatherT> attr = std::make_unique<schema::GatherT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsGatherOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op gather attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->axis = tflite_attr->axis;
|
||||
attr->batchDims = 0;
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Gather;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteGatherParser("Gather", new TfliteGatherParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteGatherParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -55,6 +55,24 @@ STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteHashtableLookupParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::HashtableLookupT> attr = std::make_unique<schema::HashtableLookupT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_HashtableLookup;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteHashtableLookupParser("HashtableLookup", new TfliteHashtableLookupParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteHashtableLookupParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -55,6 +55,27 @@ STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteL2NormParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
std::unique_ptr<schema::L2NormT> attr = std::make_unique<schema::L2NormT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsL2NormOptions();
|
||||
attr->axis = {-1};
|
||||
attr->epsilon = 1e-6f;
|
||||
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
|
||||
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_L2Norm;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteL2NormParser("L2_NORMALIZATION", new TfliteL2NormParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteL2NormParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -78,6 +78,44 @@ STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteLogicalParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
|
||||
if (tflite_op_type == tflite::BuiltinOperator_LOGICAL_AND) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLogicalAndParser";
|
||||
std::unique_ptr<schema::LogicalAndT> attr = std::make_unique<schema::LogicalAndT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_LogicalAnd;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_LOGICAL_NOT) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLogicalNotParser";
|
||||
std::unique_ptr<schema::LogicalNotT> attr = std::make_unique<schema::LogicalNotT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_LogicalNot;
|
||||
primitive->value.value = attr.release();
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_LOGICAL_OR) {
|
||||
MS_LOG(DEBUG) << "parse TfliteLogicalOrParser";
|
||||
std::unique_ptr<schema::LogicalOrT> attr = std::make_unique<schema::LogicalOrT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_LogicalOr;
|
||||
primitive->value.value = attr.release();
|
||||
}
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteLogicalAndParser("LogicalAnd", new TfliteLogicalParser());
|
||||
TfliteNodeRegister g_tfliteLogicalNotParser("LogicalNot", new TfliteLogicalParser());
|
||||
|
|
|
@ -23,8 +23,7 @@
|
|||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
class TfliteLogicalParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteLogicalParser() : TfliteNodeParser("node_name") {}
|
||||
|
@ -32,8 +31,9 @@ class TfliteLogicalParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LOGICAL_PARSER_H
|
||||
|
|
|
@ -60,6 +60,34 @@ STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteLRNParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::LocalResponseNormalizationT> attr = std::make_unique<schema::LocalResponseNormalizationT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsLocalResponseNormalizationOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op LRN attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->depth_radius = tflite_attr->radius;
|
||||
attr->alpha = tflite_attr->alpha;
|
||||
attr->beta = tflite_attr->beta;
|
||||
attr->bias = tflite_attr->bias;
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteLRNParser("LocalResponseNorm", new TfliteLRNParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteLRNParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -64,6 +64,35 @@ STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteLshProjectionParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::LshProjectionT> attr = std::make_unique<schema::LshProjectionT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsLSHProjectionOptions();
|
||||
switch (tflite_attr->type) {
|
||||
case tflite::LSHProjectionType_SPARSE:
|
||||
attr->type = schema::LshProjectionType_SPARSE;
|
||||
break;
|
||||
case tflite::LSHProjectionType_DENSE:
|
||||
attr->type = schema::LshProjectionType_DENSE;
|
||||
break;
|
||||
default:
|
||||
attr->type = schema::LshProjectionType_UNKNOWN;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_LshProjection;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteLshProjectionParser("LshProjection", new TfliteLshProjectionParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteLshProjectionParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,79 +13,167 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_model_parser.h"
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/storage.h"
|
||||
#include "flatbuffers/flatbuffers.h"
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "src/param_value_lite.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "tools/common/node_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
TfliteModelParser::TfliteModelParser() = default;
|
||||
|
||||
TfliteModelParser::~TfliteModelParser() { delete[](this->tfliteModelBuf); }
|
||||
namespace mindspore::lite {
|
||||
|
||||
std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *model_path) {
|
||||
size_t size = 0;
|
||||
tfliteModelBuf = ReadFile(model_path, &size);
|
||||
if (tfliteModelBuf == nullptr) {
|
||||
tflite_model_buf_ = ReadFile(model_path, &size);
|
||||
if (tflite_model_buf_ == nullptr) {
|
||||
MS_LOG(ERROR) << "the file buffer is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
flatbuffers::Verifier verify((const uint8_t *)tfliteModelBuf, size);
|
||||
flatbuffers::Verifier verify((const uint8_t *)tflite_model_buf_, size);
|
||||
if (!tflite::VerifyModelBuffer(verify)) {
|
||||
MS_LOG(ERROR) << "the buffer is invalid and fail to create graph";
|
||||
return nullptr;
|
||||
}
|
||||
return tflite::UnPackModel(tfliteModelBuf);
|
||||
return tflite::UnPackModel(tflite_model_buf_);
|
||||
}
|
||||
|
||||
STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) {
|
||||
MS_ASSERT(tensor != nullptr);
|
||||
MS_ASSERT(tflite_tensor != nullptr);
|
||||
auto buffer_idx = tflite_tensor->buffer;
|
||||
FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) {
|
||||
// load graph
|
||||
tflite_model_ = ReadTfliteModel(model_file.c_str());
|
||||
if (tflite_model_ == nullptr) {
|
||||
MS_LOG(ERROR) << "read tflite model failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &buf = tflite_model_buffer[buffer_idx];
|
||||
if (buf == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor is null";
|
||||
if (tflite_model_->subgraphs.size() != 1) {
|
||||
MS_LOG(ERROR) << "read tflite model subgraphs failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
func_graph_ = std::make_shared<FuncGraph>();
|
||||
|
||||
auto status = ConvertGraphInputs();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert graph inputs failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = ConvertOps();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert ops failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = ConvertGraphOutputs();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert graph outputs failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
return func_graph_;
|
||||
}
|
||||
|
||||
STATUS TfliteModelParser::ConvertOps() {
|
||||
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
|
||||
const auto &tflite_model_buffers = tflite_model_->buffers;
|
||||
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
|
||||
STATUS status = RET_OK;
|
||||
int op_idx = 0;
|
||||
for (auto &op : tflite_subgraph->operators) {
|
||||
auto tfliteOpType = (tflite_model_->operator_codes[op->opcode_index])->builtin_code;
|
||||
auto op_type = GetMSOpType(tfliteOpType);
|
||||
auto op_name = op_type + "-" + std::to_string(op_idx);
|
||||
op_idx++;
|
||||
// parse primitive
|
||||
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
|
||||
if (node_parser == nullptr) {
|
||||
NoSupportOp::GetInstance()->InsertOp(op_type);
|
||||
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (status != RET_OK) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto primitiveC = node_parser->ParseLitePrimitive(op, tflite_model_);
|
||||
if (primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "parse node " << op_type.c_str() << " parser failed";
|
||||
continue;
|
||||
}
|
||||
|
||||
status = ConvertOpQuantParams(op.get(), primitiveC);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "convert " << op_name << " quant param failed.";
|
||||
return status;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<lite::PrimitiveC>(primitiveC))};
|
||||
// parse inputs
|
||||
for (auto input_idx : op->inputs) {
|
||||
if (input_idx < 0) {
|
||||
input_idx += tflite_subgraph->tensors.size();
|
||||
}
|
||||
const auto &input_tensor = tflite_subgraph->tensors[input_idx];
|
||||
if (nodes_.find(input_idx) != nodes_.end()) {
|
||||
op_inputs.emplace_back(nodes_.at(input_idx));
|
||||
continue;
|
||||
}
|
||||
// const tensor
|
||||
if (!tflite_model_buffers.at(input_tensor->buffer)->data.empty()) {
|
||||
auto parameter = func_graph_->add_parameter();
|
||||
status = ConvertConstTensor(input_tensor.get(), parameter.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
|
||||
return status;
|
||||
}
|
||||
op_inputs.emplace_back(parameter);
|
||||
nodes_.insert(std::pair(input_idx, parameter));
|
||||
continue;
|
||||
}
|
||||
MS_LOG(WARNING) << "tensor " << input_idx << " is neither a node output nor a weight tensor.";
|
||||
}
|
||||
auto new_cnode = func_graph_->NewCNode(op_inputs);
|
||||
new_cnode->set_fullname_with_scope(op_name);
|
||||
|
||||
// parse outputs
|
||||
status = ConvertOutputTensor(op.get(), new_cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert output tensors for " << new_cnode->fullname_with_scope() << " failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return status;
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
STATUS TfliteModelParser::SetTensorQuantParam(const tflite::TensorT *tflite_tensor,
|
||||
std::vector<QuantParamT> *quant_params) {
|
||||
if (tflite_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tflite_tensor is null, set tensor quant params failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
quant_params->clear();
|
||||
|
||||
if (!buf->data.empty()) {
|
||||
auto data_size = buf->data.size();
|
||||
tensor->data.resize(data_size);
|
||||
if (memcpy_s(tensor->data.data(), data_size, buf->data.data(), data_size) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy tensor data failed";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "src tensor data is empty";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
if (tflite_tensor->quantization == nullptr ||
|
||||
(tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() &&
|
||||
tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty())) {
|
||||
std::vector<schema::QuantParamT> notinited_quant_params(1);
|
||||
*quant_params = notinited_quant_params;
|
||||
return RET_OK;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor,
|
||||
schema::TensorT *tensor) {
|
||||
MS_ASSERT(tensor != nullptr);
|
||||
tensor->quantParams.clear();
|
||||
|
||||
if (tflite_tensor->quantization == nullptr) {
|
||||
MS_LOG(ERROR) << "tflite_tensor->quantization is null";
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < tflite_tensor->quantization->scale.size(); i++) {
|
||||
std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<QuantParamT>();
|
||||
if (quant_param == nullptr) {
|
||||
MS_LOG(ERROR) << "quant_param is null";
|
||||
return;
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (!tflite_tensor->quantization->scale.empty()) {
|
||||
|
@ -104,364 +192,219 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor
|
|||
quant_param->max = tflite_tensor->quantization->max[i];
|
||||
}
|
||||
quant_param->inited = true;
|
||||
tensor->quantParams.emplace_back(std::move(quant_param));
|
||||
}
|
||||
}
|
||||
|
||||
STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
const QuantType &quant_type, schema::MetaGraphT *sub_graph) {
|
||||
MS_ASSERT(tflite_model != nullptr);
|
||||
MS_ASSERT(tflite_subgraph != nullptr);
|
||||
MS_ASSERT(sub_graph != nullptr);
|
||||
|
||||
int idx = 0;
|
||||
int status = RET_OK;
|
||||
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
|
||||
for (const auto &tflite_op : tflite_subgraph->operators) {
|
||||
const auto opcode_index = tflite_op->opcode_index;
|
||||
const auto &operator_code = tflite_model->operator_codes[opcode_index];
|
||||
if (operator_code == nullptr) {
|
||||
MS_LOG(ERROR) << "operator_code is null";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto op_type = GetMSOpType(operator_code->builtin_code);
|
||||
|
||||
auto op = std::make_unique<schema::CNodeT>();
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->name = op_type + "-" + std::to_string(idx++);
|
||||
op->quantType = quant_type;
|
||||
MS_LOG(INFO) << "parse op: " << op->name.c_str();
|
||||
|
||||
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
|
||||
if (node_parser == nullptr) {
|
||||
NoSupportOp::GetInstance()->InsertOp(op_type);
|
||||
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
|
||||
continue;
|
||||
}
|
||||
if (status == RET_OK || op_type == "Custom") {
|
||||
int status_node = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, tflite_subgraph, op.get());
|
||||
status = (status == RET_OK ? status_node : status);
|
||||
if (status_node != RET_OK) {
|
||||
if (status_node == RET_NOT_FIND_OP) {
|
||||
op_type = (op_type != "Custom" ? op_type : operator_code->custom_code);
|
||||
NoSupportOp::GetInstance()->InsertOp(op_type);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed";
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
continue;
|
||||
}
|
||||
sub_graph->nodes.emplace_back(op.release());
|
||||
opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get();
|
||||
tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get();
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::MetaGraphT *sub_graph) {
|
||||
MS_ASSERT(tflite_subgraph != nullptr);
|
||||
MS_ASSERT(sub_graph != nullptr);
|
||||
std::set<int> output_index;
|
||||
for (const auto &tflite_op : tflite_subgraph->operators) {
|
||||
for (int idx : tflite_op->outputs) {
|
||||
if (idx < 0) {
|
||||
idx += tflite_subgraph->tensors.size();
|
||||
}
|
||||
output_index.insert(idx);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < tensorsInfo.tensorsId.size(); i++) {
|
||||
auto idx = tensorsInfo.tensorsId[i];
|
||||
if (idx < 0) {
|
||||
idx += tflite_subgraph->tensors.size();
|
||||
}
|
||||
const auto &tflite_tensor = tflite_subgraph->tensors[idx];
|
||||
if (tflite_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tflite_tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>();
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
tensor->format = tensorsInfo.tensorsFormat[i];
|
||||
tensor->dataType = GetTfliteDataType(tflite_tensor->type);
|
||||
tensor->dims = tflite_tensor->shape;
|
||||
|
||||
// if graph input tensor
|
||||
bool isInput = false;
|
||||
auto tflite_inputs = tflite_subgraph->inputs;
|
||||
for (int tflite_input : tflite_inputs) {
|
||||
if (idx == tflite_input) {
|
||||
isInput = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// add data for const tensor
|
||||
auto &tensor_buffer = tflite_model_buffer.at(tflite_tensor->buffer);
|
||||
if (tensor_buffer == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_buffer is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto isConst = (!tensor_buffer->data.empty());
|
||||
if (isConst) {
|
||||
int status = CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "obtain const tensor failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
// set tensor attr
|
||||
if (isInput || isConst) {
|
||||
tensor->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
} else {
|
||||
if (output_index.find(idx) == output_index.end() && tflite_tensor->shape[0] == 0) {
|
||||
tensor->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
} else {
|
||||
tensor->nodeType = schema::NodeType_Parameter;
|
||||
}
|
||||
}
|
||||
|
||||
// quant param
|
||||
if (tflite_tensor->quantization != nullptr &&
|
||||
!(tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() &&
|
||||
tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty())) {
|
||||
SetTensorQuantParam(tflite_tensor, tensor.get());
|
||||
}
|
||||
|
||||
tensors.push_back(tensor.release());
|
||||
}
|
||||
|
||||
for (auto iter : tensors) {
|
||||
std::unique_ptr<schema::TensorT> temp(iter);
|
||||
sub_graph->allTensors.emplace_back(move(temp));
|
||||
quant_params->emplace_back(*std::move(quant_param));
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
schema::MetaGraphT *sub_graph) {
|
||||
MS_ASSERT(sub_graph != nullptr);
|
||||
MS_ASSERT(tflite_subgraph != nullptr);
|
||||
// graph input
|
||||
std::vector<int> graph_inputs;
|
||||
for (size_t i = 0; i < tflite_subgraph->inputs.size(); i++) {
|
||||
const int idx = tflite_subgraph->inputs[i];
|
||||
int id = idx < 0 ? idx + tflite_subgraph->tensors.size() : idx;
|
||||
auto iter = tensorsInfo.tensorsIdMap.find(id);
|
||||
if (iter != tensorsInfo.tensorsIdMap.end()) {
|
||||
graph_inputs.push_back(iter->second);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "get graph input failed";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, lite::PrimitiveC *primitive_c) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "tflite op is null, get quant params failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
sub_graph->inputIndex.assign(graph_inputs.begin(), graph_inputs.end());
|
||||
|
||||
// graph output
|
||||
std::vector<int> graph_outputs;
|
||||
for (size_t i = 0; i < tflite_subgraph->outputs.size(); i++) {
|
||||
const int idx = tflite_subgraph->outputs[i];
|
||||
int id = idx < 0 ? idx + tflite_subgraph->tensors.size() : idx;
|
||||
auto iter = tensorsInfo.tensorsIdMap.find(id);
|
||||
if (iter != tensorsInfo.tensorsIdMap.end()) {
|
||||
graph_outputs.push_back(iter->second);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "get graph output failed";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive_c is null, get quant params failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
|
||||
for (auto input_idx : op->inputs) {
|
||||
if (input_idx < 0) {
|
||||
input_idx += tflite_subgraph->tensors.size();
|
||||
}
|
||||
const auto &input_tensor = tflite_subgraph->tensors[input_idx];
|
||||
std::vector<schema::QuantParamT> quant_params;
|
||||
auto status = SetTensorQuantParam(input_tensor.get(), &quant_params);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "set input tensor quant param failed.";
|
||||
return status;
|
||||
}
|
||||
primitive_c->AddInputQuantParam(quant_params);
|
||||
}
|
||||
for (auto output_idx : op->outputs) {
|
||||
if (output_idx < 0) {
|
||||
output_idx += tflite_subgraph->tensors.size();
|
||||
}
|
||||
const auto &output_tensor = tflite_subgraph->tensors.at(output_idx);
|
||||
std::vector<schema::QuantParamT> quant_params;
|
||||
auto status = SetTensorQuantParam(output_tensor.get(), &quant_params);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "set output tensor quant param failed.";
|
||||
return status;
|
||||
}
|
||||
primitive_c->AddOutputQuantParam(quant_params);
|
||||
}
|
||||
sub_graph->outputIndex.assign(graph_outputs.begin(), graph_outputs.end());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) {
|
||||
MS_ASSERT(sub_graph != nullptr);
|
||||
for (auto &op : sub_graph->nodes) {
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
STATUS TfliteModelParser::ConvertGraphInputs() {
|
||||
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
|
||||
for (auto tflite_graph_input : tflite_subgraph->inputs) {
|
||||
if (tflite_graph_input < 0) {
|
||||
tflite_graph_input = tflite_graph_input + tflite_subgraph->tensors.size();
|
||||
}
|
||||
auto parameter = func_graph_->add_parameter();
|
||||
const auto &tensor = tflite_subgraph->tensors.at(tflite_graph_input);
|
||||
std::vector<int64_t> shape_vector;
|
||||
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
||||
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
||||
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
|
||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
||||
parameter->set_abstract(abstract_tensor);
|
||||
parameter->set_name("graph_input_" + std::to_string(tflite_graph_input) + "_parameter");
|
||||
nodes_.insert(std::pair(tflite_graph_input, parameter));
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
STATUS TfliteModelParser::ConvertGraphOutputs() {
|
||||
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
|
||||
if (tflite_subgraph->outputs.size() > 1) {
|
||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||
auto make_tuple_prim_ptr = GetMakeTuplePrim();
|
||||
if (make_tuple_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
|
||||
auto attr = op->primitive->value.AsDepthwiseConv2D();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "attr is null";
|
||||
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
|
||||
make_tuple_inputs.emplace_back(make_tuple_prim);
|
||||
for (auto outputNode : tflite_subgraph->outputs) {
|
||||
auto cnode = nodes_.at(outputNode);
|
||||
if (nullptr == cnode) {
|
||||
MS_LOG(ERROR) << "Can't find input node.";
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
make_tuple_inputs.emplace_back(cnode);
|
||||
}
|
||||
auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs);
|
||||
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
||||
|
||||
std::vector<AnfNodePtr> op_inputs;
|
||||
auto return_prim_ptr = GetReturnPrim();
|
||||
if (return_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto value_node = NewValueNode(return_prim_ptr);
|
||||
op_inputs.emplace_back(value_node);
|
||||
op_inputs.emplace_back(make_tuple_cnode);
|
||||
auto cnode = func_graph_->NewCNode(op_inputs);
|
||||
cnode->set_fullname_with_scope("return");
|
||||
func_graph_->set_return(cnode);
|
||||
} else {
|
||||
auto returnPrim = GetReturnPrim();
|
||||
if (returnPrim == nullptr) {
|
||||
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto valueNode = NewValueNode(returnPrim);
|
||||
std::vector<AnfNodePtr> op_inputs{valueNode};
|
||||
auto cnode = nodes_.at(tflite_subgraph->outputs.front());
|
||||
if (nullptr == cnode) {
|
||||
MS_LOG(ERROR) << "Can't find input node.";
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
op_inputs.emplace_back(cnode);
|
||||
auto returnCnode = func_graph_->NewCNode(op_inputs);
|
||||
returnCnode->set_fullname_with_scope("return");
|
||||
func_graph_->set_return(returnCnode);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteModelParser::ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter) {
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor is null, get const tensor failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter is null, get const tensor failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
const auto &tfliteModelBuffers = tflite_model_->buffers;
|
||||
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
|
||||
std::vector<int64_t> shape_vector;
|
||||
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
||||
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
||||
parameter->set_abstract(abstract_tensor);
|
||||
parameter->set_name("const_" + std::to_string(nodes_.size()) + "_parameter");
|
||||
|
||||
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
||||
MS_ASSERT(param_value != nullptr);
|
||||
param_value->set_tensor_shape(tensor->shape);
|
||||
param_value->set_tensor_type(GetTfliteDataType(tensor->type));
|
||||
param_value->set_format(schema::Format::Format_NHWC);
|
||||
const auto &data = tfliteModelBuffers.at(tensor->buffer)->data;
|
||||
if (!data.empty()) {
|
||||
auto size = data.size();
|
||||
char *tensor_data = new (std::nothrow) char[size];
|
||||
if (tensor_data == nullptr) {
|
||||
MS_LOG(ERROR) << "new char[] failed";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
std::memcpy(tensor_data, data.data(), size);
|
||||
param_value->set_tensor_addr(tensor_data);
|
||||
param_value->set_tensor_size(size);
|
||||
parameter->set_default_param(param_value);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const CNodePtr &dst_cnode) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null, get output tensor failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
if (dst_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter is null, get output tensor failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
const auto &tflite_subgraph = tflite_model_->subgraphs.front();
|
||||
if (op->outputs.size() == 1) {
|
||||
int output_idx =
|
||||
op->outputs.front() < 0 ? tflite_subgraph->tensors.size() + op->outputs.front() : op->outputs.front();
|
||||
const auto &tensor = tflite_subgraph->tensors.at(output_idx);
|
||||
std::vector<int64_t> shape_vector;
|
||||
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
||||
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
||||
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
|
||||
dst_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
||||
nodes_.insert(std::pair(op->outputs.front(), dst_cnode));
|
||||
} else {
|
||||
AbstractBasePtrList abstract_list;
|
||||
int op_idx = 0;
|
||||
for (auto output_idx : op->outputs) {
|
||||
if (output_idx < 0) {
|
||||
output_idx = output_idx + tflite_subgraph->tensors.size();
|
||||
}
|
||||
const auto &tensor = tflite_subgraph->tensors.at(output_idx);
|
||||
std::vector<int64_t> shape_vector;
|
||||
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
||||
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
||||
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
|
||||
abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
||||
auto tuple_get_item_prim_ptr = GetTupleGetItemPrim();
|
||||
if (tuple_get_item_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (attr->channelMultiplier > 1) {
|
||||
// get channel attr
|
||||
if (op->inputIndex.empty()) {
|
||||
MS_LOG(ERROR) << "the input of DepthwiseConv2D is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
const auto data_id = op->inputIndex[0];
|
||||
if (sub_graph->allTensors.size() <= data_id) {
|
||||
MS_LOG(ERROR) << "the number of allTensors is less than " << data_id;
|
||||
return RET_ERROR;
|
||||
}
|
||||
const auto &data_tensor = sub_graph->allTensors.at(data_id);
|
||||
if (data_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "the data tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto data_shape = data_tensor->dims;
|
||||
if (data_shape.empty()) {
|
||||
MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain only when running";
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
std::unique_ptr<schema::Conv2DT> conv_attr = std::make_unique<schema::Conv2DT>();
|
||||
if (conv_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "conv_attr is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (data_shape[3] == 1) {
|
||||
conv_attr->channelIn = data_shape[3];
|
||||
conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier;
|
||||
|
||||
// update attr
|
||||
conv_attr->group = 1;
|
||||
conv_attr->format = attr->format;
|
||||
conv_attr->kernelH = attr->kernelH;
|
||||
conv_attr->kernelW = attr->kernelW;
|
||||
conv_attr->strideH = attr->strideH;
|
||||
conv_attr->strideW = attr->strideW;
|
||||
conv_attr->padMode = attr->padMode;
|
||||
conv_attr->padUp = attr->padUp;
|
||||
conv_attr->padDown = attr->padDown;
|
||||
conv_attr->padLeft = attr->padLeft;
|
||||
conv_attr->padRight = attr->padRight;
|
||||
conv_attr->dilateH = attr->dilateH;
|
||||
conv_attr->dilateW = attr->dilateW;
|
||||
conv_attr->hasBias = attr->hasBias;
|
||||
conv_attr->activationType = attr->activationType;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
op->primitive->value.value = conv_attr.release();
|
||||
|
||||
// update weight
|
||||
auto weight_id = op->inputIndex[1];
|
||||
auto &weight_tensor = sub_graph->allTensors.at(weight_id);
|
||||
if (weight_tensor->dataType == TypeId::kNumberTypeUInt8) {
|
||||
auto status = TransFilterFormat<uint8_t>(weight_tensor.get(), kKHWC2CHWK);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Trans depthwiseConv Filter schema::Format failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (weight_tensor->dataType == kNumberTypeInt8) {
|
||||
auto status = TransFilterFormat<int8_t>(weight_tensor.get(), kKHWC2CHWK);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Trans filter format failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (weight_tensor->dataType == kNumberTypeFloat32 || weight_tensor->dataType == kNumberTypeFloat) {
|
||||
auto status = TransFilterFormat<float>(weight_tensor.get(), kKHWC2CHWK);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Trans filter format failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "The dataType of weight tensor is unsupported.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
weight_tensor->format = schema::Format::Format_CHWK;
|
||||
}
|
||||
}
|
||||
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
|
||||
auto get_item_value = NewValueNode(MakeValue<int>(op_idx));
|
||||
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, dst_cnode, get_item_value};
|
||||
CNodePtr get_item_cnode = func_graph_->NewCNode(inputs);
|
||||
get_item_cnode->set_fullname_with_scope(dst_cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx));
|
||||
nodes_.insert(std::pair(output_idx, get_item_cnode));
|
||||
op_idx++;
|
||||
}
|
||||
dst_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::MetaGraphT> TfliteModelParser::ConstructMainGraph(
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, const QuantType &quant_type) {
|
||||
MS_ASSERT(tflite_model != nullptr);
|
||||
if (tflite_model->subgraphs.empty()) {
|
||||
MS_LOG(ERROR) << "read tflite model main subgraphs failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
const auto &tflite_subgraph = tflite_model->subgraphs[0];
|
||||
|
||||
auto meta_graph = std::make_unique<schema::MetaGraphT>();
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "new meta graph failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
|
||||
return nullptr;
|
||||
}
|
||||
meta_graph->name = "MS_model converted by TF-Lite";
|
||||
quantType = quant_type;
|
||||
// convert op
|
||||
int status = ConvertOp(tflite_model, tflite_subgraph, quant_type, meta_graph.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "parse op failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// convert tensor
|
||||
status = ConvertTensor(tflite_subgraph, tflite_model->buffers, meta_graph.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "convert tensor failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// set graph input/output
|
||||
status = GetGraphInfo(tflite_subgraph, meta_graph.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "convert tensors failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// update for depthwiseConv
|
||||
status = ConvertGroupDepthwiseOp(meta_graph.get());
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "convert group depthwise conv failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return meta_graph;
|
||||
MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) {
|
||||
if (model_file.empty()) {
|
||||
MS_LOG(ERROR) << "model_file is empty";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// load graph
|
||||
auto tflite_model = ReadTfliteModel(model_file.c_str());
|
||||
if (tflite_model == nullptr) {
|
||||
MS_LOG(ERROR) << "read tflite model failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// construct main_meta_graph
|
||||
auto main_meta_graph = ConstructMainGraph(tflite_model, quant_type);
|
||||
if (main_meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "ConstructMainGraph failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return main_meta_graph.release();
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,67 +13,42 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef LITE_TFLITE_MODEL_PARSER_H
|
||||
#define LITE_TFLITE_MODEL_PARSER_H
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <unistd.h>
|
||||
#include <google/protobuf/io/coded_stream.h>
|
||||
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include "securec/include/securec.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class TfliteModelParser : public ModelParser {
|
||||
public:
|
||||
TfliteModelParser();
|
||||
TfliteModelParser() = default;
|
||||
|
||||
~TfliteModelParser() override;
|
||||
~TfliteModelParser() override = default;
|
||||
|
||||
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quantTyp) override;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path);
|
||||
|
||||
static STATUS CopyConstTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
const tflite::TensorT *tflite_tensor, schema::TensorT *tensor);
|
||||
|
||||
static void SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor, schema::TensorT *tensor);
|
||||
|
||||
STATUS ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, const QuantType &quant_type,
|
||||
schema::MetaGraphT *sub_graph);
|
||||
|
||||
STATUS ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::MetaGraphT *sub_graph);
|
||||
|
||||
STATUS GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph);
|
||||
|
||||
static STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph);
|
||||
|
||||
QuantType quantType = QuantType_QUANT_NONE;
|
||||
char *tfliteModelBuf = nullptr;
|
||||
std::unique_ptr<schema::MetaGraphT> ConstructMainGraph(const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const QuantType &quant_type);
|
||||
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) override;
|
||||
MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) override;
|
||||
|
||||
private:
|
||||
TfliteTensorsInfo tensorsInfo;
|
||||
std::vector<schema::TensorT *> tensors;
|
||||
|
||||
std::map<std::string, schema::CNodeT *> opMap;
|
||||
std::map<const tflite::OperatorT *, schema::CNodeT *> tfliteOpMap;
|
||||
std::unordered_map<int, AnfNodePtr> nodes_;
|
||||
std::unique_ptr<tflite::ModelT> tflite_model_;
|
||||
FuncGraphPtr func_graph_;
|
||||
char *tflite_model_buf_ = nullptr;
|
||||
std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path);
|
||||
STATUS ConvertConstTensor(const tflite::TensorT *tensor, Parameter *parameter);
|
||||
STATUS ConvertOutputTensor(const tflite::OperatorT *op, const CNodePtr &dst_cnode);
|
||||
STATUS ConvertOpQuantParams(const tflite::OperatorT *op, lite::PrimitiveC *primitive_c);
|
||||
STATUS ConvertOps();
|
||||
STATUS ConvertGraphInputs();
|
||||
STATUS ConvertGraphOutputs();
|
||||
STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<QuantParamT> *quant_params);
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MODEL_PARSER_H
|
||||
#endif // LITE_TFLITE_MODEL_PARSER_H
|
||||
|
|
|
@ -41,9 +41,9 @@ class TfliteNodeParser {
|
|||
virtual STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) = 0;
|
||||
virtual STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model, PrimitiveC *primitiveC) {
|
||||
return RET_OK;
|
||||
virtual lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static void AddOpInput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total,
|
||||
|
|
|
@ -64,6 +64,38 @@ STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteOneHotParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::OneHotT> attr = std::make_unique<schema::OneHotT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsOneHotOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op onehot attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto axis = tflite_attr->axis;
|
||||
const auto &tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
attr->axis = axis;
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_OneHot;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteOneHotParser("OneHot", new TfliteOneHotParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteOneHotParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -90,6 +90,59 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TflitePadParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::PadT> attr = std::make_unique<schema::PadT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
|
||||
if (tflite_op_type == tflite::BuiltinOperator_PAD) {
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op pad attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->paddingMode = schema::PaddingMode_CONSTANT;
|
||||
attr->constantValue = 0.0f;
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->paddings)) {
|
||||
MS_LOG(ERROR) << "get pad -> paddings failed";
|
||||
return nullptr;
|
||||
}
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_MIRROR_PAD) {
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsMirrorPadOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op pad attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
switch (tflite_attr->mode) {
|
||||
case tflite::MirrorPadMode_REFLECT:
|
||||
attr->paddingMode = schema::PaddingMode_REFLECT;
|
||||
break;
|
||||
case tflite::MirrorPadMode_SYMMETRIC:
|
||||
attr->paddingMode = schema::PaddingMode_SYMMETRIC;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support";
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "this pad:" << tflite_op_type << " hasn't been supported";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Pad;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tflitePadParser("Pad", new TflitePadParser());
|
||||
TfliteNodeRegister g_tfliteMirorPadParser("MirrorPad", new TflitePadParser());
|
||||
|
|
|
@ -32,6 +32,8 @@ class TflitePadParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,8 +19,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
|
@ -43,17 +42,13 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::vector<std::string> node_name_str;
|
||||
Split(op->name, &node_name_str, "-");
|
||||
const char *node_name = node_name_str.data()->c_str();
|
||||
if (std::strcmp(node_name, "MeanPooling") == 0) {
|
||||
MS_LOG(DEBUG) << "parser TfliteMeanPoolingParser";
|
||||
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
|
||||
if (tflite_op_type == tflite::BuiltinOperator_AVERAGE_POOL_2D) {
|
||||
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
|
||||
} else if (std::strcmp(node_name, "MaxPooling") == 0) {
|
||||
MS_LOG(DEBUG) << "parse TfliteMaxPoolingParser";
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_MAX_POOL_2D) {
|
||||
attr->poolingMode = schema::PoolMode_MAX_POOLING;
|
||||
} else {
|
||||
MS_LOG(ERROR) << node_name << " hasn't been supported";
|
||||
MS_LOG(ERROR) << "pooling mode " << tflite_op_type << " hasn't been supported";
|
||||
return RET_NOT_FIND_OP;
|
||||
}
|
||||
|
||||
|
@ -75,7 +70,7 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
// calculate pad params
|
||||
auto data_index = tflite_op->inputs[0];
|
||||
const auto &data_tensor = tflite_subgraph->tensors[data_index];
|
||||
std::vector<int> params;
|
||||
std::vector<int64_t> params;
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
|
@ -95,8 +90,58 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
lite::PrimitiveC *TflitePoolingParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
const auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
std::unique_ptr<schema::PoolingT> attr = std::make_unique<schema::PoolingT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
|
||||
if (tflite_op_type == tflite::BuiltinOperator_AVERAGE_POOL_2D) {
|
||||
attr->poolingMode = schema::PoolMode_MEAN_POOLING;
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_MAX_POOL_2D) {
|
||||
attr->poolingMode = schema::PoolMode_MAX_POOLING;
|
||||
}
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsPool2DOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op pooling attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->windowW = tflite_attr->filter_width;
|
||||
attr->windowH = tflite_attr->filter_height;
|
||||
attr->strideW = tflite_attr->stride_w;
|
||||
attr->strideH = tflite_attr->stride_h;
|
||||
attr->padMode = GetPadMode(tflite_attr->padding);
|
||||
attr->format = schema::Format::Format_NHWC;
|
||||
|
||||
attr->global = false;
|
||||
attr->roundMode = schema::RoundMode_FLOOR;
|
||||
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
|
||||
|
||||
// calculate pad params
|
||||
auto data_index = tflite_op->inputs[0];
|
||||
const auto &data_tensor = tflite_subgraph->tensors[data_index];
|
||||
std::vector<int64_t> params;
|
||||
int status =
|
||||
getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get padding params failed";
|
||||
return nullptr;
|
||||
} else if (status == RET_OK) {
|
||||
attr->padUp = params.at(0);
|
||||
attr->padDown = params.at(1);
|
||||
attr->padLeft = params.at(2);
|
||||
attr->padRight = params.at(3);
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
primitive->value.type = schema::PrimitiveType_Pooling;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteMeanPoolingParser("MeanPooling", new TflitePoolingParser());
|
||||
TfliteNodeRegister g_tfliteMaxPoolingParser("MaxPooling", new TflitePoolingParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -23,8 +23,7 @@
|
|||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
class TflitePoolingParser : public TfliteNodeParser {
|
||||
public:
|
||||
TflitePoolingParser() : TfliteNodeParser("node_name") {}
|
||||
|
@ -32,8 +31,9 @@ class TflitePoolingParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_POOLING_PARSER_H
|
||||
|
|
|
@ -52,6 +52,24 @@ STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TflitePReLUParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::PReLUT> attr = std::make_unique<schema::PReLUT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->channelShared = true;
|
||||
primitive->value.type = schema::PrimitiveType_PReLU;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tflitePReLUParser("PRELU", new TflitePReLUParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TflitePReLUParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -74,6 +74,50 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteQuantizeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &in_tensor = tflite_subgraph->tensors[tflite_op->inputs[0]];
|
||||
if (in_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "input tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
const auto &out_tensor = tflite_subgraph->tensors[tflite_op->outputs[0]];
|
||||
if (out_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "output tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
if (GetTfliteDataType(in_tensor->type) != GetTfliteDataType(out_tensor->type) &&
|
||||
(GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 ||
|
||||
GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8)) {
|
||||
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->srcT = GetTfliteDataType(in_tensor->type);
|
||||
attr->dstT = GetTfliteDataType(out_tensor->type);
|
||||
primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
|
||||
primitive->value.value = attr.release();
|
||||
} else {
|
||||
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->srcT = GetTfliteDataType(in_tensor->type);
|
||||
attr->dstT = GetTfliteDataType(out_tensor->type);
|
||||
primitive->value.type = schema::PrimitiveType_Cast;
|
||||
primitive->value.value = attr.release();
|
||||
}
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteQuantizeParser("QUANTIZE", new TfliteQuantizeParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -31,6 +31,8 @@ class TfliteQuantizeParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -71,6 +71,43 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteRangeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::RangeT> attr = std::make_unique<schema::RangeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->dType = 0;
|
||||
std::vector<int> limit;
|
||||
std::vector<int> delta;
|
||||
int status = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, limit);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "range -> limit get failed";
|
||||
return nullptr;
|
||||
} else if (status == RET_OK) {
|
||||
status = GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, delta);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "stridedSlice -> end get failed";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
if (status == RET_OK) {
|
||||
attr->limit = limit.front();
|
||||
attr->delta = delta.front();
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_Range;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteRangeParser("Range", new TfliteRangeParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteRangeParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -50,6 +50,24 @@ STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteRankParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::RankT> attr = std::make_unique<schema::RankT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Rank;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteRankParser("Rank", new TfliteRankParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteRankParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -85,11 +85,65 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteReduceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteSumParser("Sum", new TfliteReduceParser());
|
||||
TfliteNodeRegister g_tfliteMeanParser("Mean", new TfliteReduceParser());
|
||||
TfliteNodeRegister g_tfliteReduceMaxParser("ReduceMax", new TfliteReduceParser());
|
||||
TfliteNodeRegister g_tfliteReduceMinParser("ReduceMin", new TfliteReduceParser());
|
||||
TfliteNodeRegister g_tfliteReduceProdParser("ReduceProd", new TfliteReduceParser());
|
||||
std::unique_ptr<schema::ReduceT> attr = std::make_unique<schema::ReduceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsReducerOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op reduce attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->keepDims = tflite_attr->keep_dims;
|
||||
|
||||
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
|
||||
if (tflite_op_type == tflite::BuiltinOperator_REDUCE_MAX) {
|
||||
MS_LOG(DEBUG) << "parse TfliteReduceMaxParser";
|
||||
attr->mode = schema::ReduceMode_ReduceMax;
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_REDUCE_MIN) {
|
||||
MS_LOG(DEBUG) << "parse TfliteReduceMinParser";
|
||||
attr->mode = schema::ReduceMode_ReduceMin;
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_REDUCE_PROD) {
|
||||
MS_LOG(DEBUG) << "parse TfliteReduceProdParser";
|
||||
attr->mode = schema::ReduceMode_ReduceProd;
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_SUM) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSumParser";
|
||||
attr->mode = schema::ReduceMode_ReduceSum;
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_MEAN) {
|
||||
MS_LOG(DEBUG) << "parse TfliteMeanParser";
|
||||
attr->mode = schema::ReduceMode_ReduceMean;
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_REDUCE_ANY) {
|
||||
// attr->mode;
|
||||
MS_LOG(ERROR) << "ms-lite haven't supported REDUCE_ANY now";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axes)) {
|
||||
MS_LOG(ERROR) << "get reduce -> axes failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Reduce;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_TfliteSumParser("Sum", new TfliteReduceParser());
|
||||
TfliteNodeRegister g_TfliteMeanParser("Mean", new TfliteReduceParser());
|
||||
TfliteNodeRegister g_TfliteReduceMaxParser("ReduceMax", new TfliteReduceParser());
|
||||
TfliteNodeRegister g_TfliteReduceMinParser("ReduceMin", new TfliteReduceParser());
|
||||
TfliteNodeRegister g_TfliteReduceProdParser("ReduceProd", new TfliteReduceParser());
|
||||
TfliteNodeRegister g_TfliteReduceAnyParser("ReduceAny", new TfliteReduceParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,8 +23,7 @@
|
|||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
class TfliteReduceParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteReduceParser() : TfliteNodeParser("node_name") {}
|
||||
|
@ -32,8 +31,9 @@ class TfliteReduceParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_REDUCE_PARSER_H
|
||||
|
|
|
@ -18,8 +18,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
|
@ -43,8 +42,8 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsReshapeOptions();
|
||||
if (tfliteAttr == nullptr) {
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsReshapeOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
if (tflite_op->inputs.size() < 2) {
|
||||
MS_LOG(ERROR) << "expected two input tensors, but got: " << tflite_op->inputs.size();
|
||||
return RET_ERROR;
|
||||
|
@ -68,9 +67,9 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
}
|
||||
} else {
|
||||
attr->format = schema::Format::Format_NHWC;
|
||||
attr->shape.resize(tfliteAttr->new_shape.size());
|
||||
for (size_t i = 0; i < tfliteAttr->new_shape.size(); ++i) {
|
||||
attr->shape[i] = tfliteAttr->new_shape[i];
|
||||
attr->shape.resize(tflite_attr->new_shape.size());
|
||||
for (size_t i = 0; i < tflite_attr->new_shape.size(); ++i) {
|
||||
attr->shape[i] = tflite_attr->new_shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -83,7 +82,50 @@ STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
lite::PrimitiveC *TfliteReshapeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
const auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
std::unique_ptr<schema::ReshapeT> attr = std::make_unique<schema::ReshapeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsReshapeOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
if (tflite_op->inputs.size() < 2) {
|
||||
MS_LOG(ERROR) << "expected two input tensors, but got: " << tflite_op->inputs.size();
|
||||
return nullptr;
|
||||
}
|
||||
auto shape_tensor_index = tflite_op->inputs[1];
|
||||
const auto &shape_tensor = tflite_subgraph->tensors[shape_tensor_index];
|
||||
if (shape_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "shape_tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
auto &buf_data = tflite_model->buffers[shape_tensor->buffer];
|
||||
if (buf_data == nullptr) {
|
||||
MS_LOG(ERROR) << "buf_data is null";
|
||||
return nullptr;
|
||||
}
|
||||
if (!buf_data->data.empty()) {
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->shape)) {
|
||||
MS_LOG(ERROR) << "get reshape -> shape failed";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
attr->format = schema::Format::Format_NHWC;
|
||||
attr->shape.resize(tflite_attr->new_shape.size());
|
||||
for (size_t i = 0; i < tflite_attr->new_shape.size(); ++i) {
|
||||
attr->shape[i] = tflite_attr->new_shape[i];
|
||||
}
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
primitive->value.type = schema::PrimitiveType_Reshape;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteReshapeParser("Reshape", new TfliteReshapeParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -23,8 +23,7 @@
|
|||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
class TfliteReshapeParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteReshapeParser() : TfliteNodeParser("Reshape") {}
|
||||
|
@ -32,8 +31,10 @@ class TfliteReshapeParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
|
||||
lite::PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESHAPE_PARSER_H
|
||||
|
|
|
@ -120,6 +120,89 @@ STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteResizeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ResizeT> attr = std::make_unique<schema::ResizeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->coordinateTransformMode = schema::CoordinateTransformMode_COMMON;
|
||||
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
|
||||
if (tflite_op_type == tflite::BuiltinOperator_RESIZE_BILINEAR) {
|
||||
MS_LOG(DEBUG) << "parse TfliteResizeBilinearParser";
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsResizeBilinearOptions();
|
||||
if (tfliteAttr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op ResizeBilinear attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
if (tfliteAttr->align_corners) {
|
||||
attr->alignCorners = tfliteAttr->align_corners;
|
||||
attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS;
|
||||
}
|
||||
if (tfliteAttr->half_pixel_centers) {
|
||||
attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON
|
||||
? schema::CoordinateTransformMode_TF_HALF_PIXEL
|
||||
: schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL);
|
||||
}
|
||||
attr->method = schema::ResizeMethod_LINEAR;
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR) {
|
||||
MS_LOG(DEBUG) << "parse TfliteResizeNearestNeighborParser";
|
||||
const auto &tfliteAttr = tflite_op->builtin_options.AsResizeNearestNeighborOptions();
|
||||
if (tfliteAttr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op ResizeNearestNeighbor attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
if (tfliteAttr->align_corners) {
|
||||
attr->alignCorners = tfliteAttr->align_corners;
|
||||
attr->coordinateTransformMode = schema::CoordinateTransformMode_ALIGN_CORNERS;
|
||||
}
|
||||
if (tfliteAttr->half_pixel_centers) {
|
||||
attr->coordinateTransformMode = (attr->coordinateTransformMode == schema::CoordinateTransformMode_COMMON
|
||||
? schema::CoordinateTransformMode_TF_HALF_PIXEL
|
||||
: schema::CoordinateTransformMode_ALIGN_CORNERS_WITH_HALF_PIEXL);
|
||||
}
|
||||
attr->method = schema::ResizeMethod_NEAREST;
|
||||
attr->nearestMode = schema::NearestMode_NORMAL;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "wrong resize type";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->format = schema::Format::Format_NHWC;
|
||||
attr->preserveAspectRatio = false;
|
||||
|
||||
auto tfliteResizeTensorIndex = tflite_op->inputs[1];
|
||||
const auto &shape_tensor = tflite_subgraph->tensors[tfliteResizeTensorIndex];
|
||||
if (shape_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "shape_tensor is null";
|
||||
return nullptr;
|
||||
}
|
||||
auto resizeTensorBufferIndex = shape_tensor->buffer;
|
||||
const auto &buff = tflite_model->buffers.at(resizeTensorBufferIndex);
|
||||
if (buff == nullptr) {
|
||||
MS_LOG(ERROR) << "buff_data is null";
|
||||
return nullptr;
|
||||
}
|
||||
auto buffData = reinterpret_cast<int32_t *>(buff->data.data());
|
||||
if (buffData != nullptr) {
|
||||
auto height = buffData[0];
|
||||
auto width = buffData[1];
|
||||
attr->newWidth = width;
|
||||
attr->newHeight = height;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Resize;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteResizeBilinearParser("ResizeBilinear", new TfliteResizeParser());
|
||||
TfliteNodeRegister g_tfliteResizeNearestNeighborParser("NearestNeighbor", new TfliteResizeParser());
|
||||
|
|
|
@ -23,8 +23,7 @@
|
|||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace mindspore::lite {
|
||||
class TfliteResizeParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteResizeParser() : TfliteNodeParser("node_name") {}
|
||||
|
@ -32,8 +31,9 @@ class TfliteResizeParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_RESIZE_PARSER_H
|
||||
|
|
|
@ -55,6 +55,30 @@ STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::un
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteReverseParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ReverseT> attr = std::make_unique<schema::ReverseT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->axis)) {
|
||||
MS_LOG(ERROR) << "get reverse -> axis failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Reverse;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteReverseParser("reverse", new TfliteReverseParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteReverseParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -62,6 +62,32 @@ STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteReverseSequenceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ReverseSequenceT> attr = std::make_unique<schema::ReverseSequenceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsReverseSequenceOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op reverse attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->seqAxis = tflite_attr->seq_dim;
|
||||
attr->batchAxis = tflite_attr->batch_dim;
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_ReverseSequence;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteReverseSequenceParser("ReverseSequence", new TfliteReverseSequenceParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteReverseSequenceParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -58,6 +58,29 @@ STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info,
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteScatterNdParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ScatterNDT> attr = std::make_unique<schema::ScatterNDT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsScatterNdOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op ScatterNd attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_ScatterND;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteScatterNdParser("ScatterNd", new TfliteScatterNdParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteScatterNdParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -50,6 +50,24 @@ STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteShapeParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ShapeT> attr = std::make_unique<schema::ShapeT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_Shape;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteShapeParser("Shape", new TfliteShapeParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteShapeParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -59,6 +59,33 @@ STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::u
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteSkipGramParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SkipGramT> attr = std::make_unique<schema::SkipGramT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsSkipGramOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
attr->includeAllGrams = tflite_attr->include_all_ngrams;
|
||||
attr->maxSkipSize = tflite_attr->max_skip_size;
|
||||
attr->ngramSize = tflite_attr->ngram_size;
|
||||
|
||||
primitive->value.type = schema::PrimitiveType_SkipGram;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteSkiGramParser("SKipGram", new TfliteSkipGramParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteSkipGramParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -66,6 +66,41 @@ STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq
|
|||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *TfliteSliceParser::ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto &tflite_subgraph = tflite_model->subgraphs.front();
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SliceT> attr = std::make_unique<schema::SliceT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
attr->format = schema::Format::Format_NHWC;
|
||||
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, attr->begin)) {
|
||||
MS_LOG(ERROR) << "get slice -> begin failed";
|
||||
return nullptr;
|
||||
}
|
||||
if (GetTfliteData(tflite_op->inputs[2], tflite_subgraph->tensors, tflite_model->buffers, attr->size)) {
|
||||
MS_LOG(ERROR) << "get slice -> size failed";
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<int> axes;
|
||||
axes.clear();
|
||||
for (size_t i = 0; i < attr->begin.size(); ++i) {
|
||||
axes.push_back(i);
|
||||
}
|
||||
attr->axes = axes;
|
||||
primitive->value.type = schema::PrimitiveType_Slice;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteSliceParser("Slice", new TfliteSliceParser());
|
||||
} // namespace lite
|
||||
|
|
|
@ -32,6 +32,8 @@ class TfliteSliceParser : public TfliteNodeParser {
|
|||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
PrimitiveC *ParseLitePrimitive(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue