!6230 MSLITE: add TFLite hashtable_lookup, mirror_pad, and skip_gram parsers; adapt const folding to tensor data type

Merge pull request !6230 from 徐安越/master
mindspore-ci-bot 2020-09-15 15:42:54 +08:00 committed by Gitee
commit 4b84c6ee98
16 changed files with 312 additions and 52 deletions

View File

@@ -32,6 +32,7 @@ constexpr int RET_PARAM_INVALID = -3; /**< Invalid parameter.*/
constexpr int RET_NO_CHANGE = -4; /**< No change. */
constexpr int RET_SUCCESS_EXIT = -5; /**< No error but exit. */
constexpr int RET_MEMORY_FAILED = -6; /**< Fail to create memory. */
constexpr int RET_NOT_SUPPORT = -7; /**< Unsupported operation. */
/* Executor error code, range: [-101,-200] */
constexpr int RET_OUT_OF_TENSOR_RANGE = -101; /**< Failed to check range. */
@@ -53,6 +54,10 @@ constexpr int RET_FORMAT_ERR = -401; /**< Failed to checking tensor format. */
/* InferShape error code, range: [-501,-600] */
constexpr int RET_INFER_ERR = -501; /**< Failed to infer shape. */
constexpr int RET_INFER_INVALID = -502; /**< Invalid infer shape before runtime. */
/* User input param error code, range: [-601, -700] */
constexpr int RET_INPUT_PARAM_INVALID = -601; /**< Invalid input param by user. */
constexpr int RET_INPUT_PARAM_LACK = -602; /**< Missing input param from user. */
} // namespace lite
} // namespace mindspore
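
A minimal sketch of how a caller might branch on the new user-input codes versus internal failures; only RET_OK, RET_INPUT_PARAM_INVALID, and RET_INPUT_PARAM_LACK come from the block above, the rest is assumed:

#include <iostream>
#include "include/errorcode.h"  // assumed header path for the codes above

int ReportStatus(int status) {
  using namespace mindspore::lite;
  if (status == RET_INPUT_PARAM_INVALID || status == RET_INPUT_PARAM_LACK) {
    std::cerr << "bad or missing command-line input, status " << status << std::endl;
  } else if (status != RET_OK) {
    std::cerr << "internal converter failure, status " << status << std::endl;
  }
  return status;
}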

View File

@@ -203,6 +203,8 @@ union PrimitiveType {
LogGrad,
BatchToSpaceND,
LshProjection,
HashtableLookup,
SkipGram,
}
enum QuantType: int {

View File

@@ -948,3 +948,12 @@ table BlackBox {
table LshProjection {
type : LshProjectionType;
}
table HashtableLookup {
}
table SkipGram {
includeAllGrams : bool;
maxSkipSize : int;
ngramSize : int;
}
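
With flatc's object API, the SkipGram table above yields a mutable schema::SkipGramT that the new TFLite parser fills in; a sketch of the generated struct's shape (field names follow the schema, the exact flatc output is assumed):

struct SkipGramT {
  bool includeAllGrams = false;
  int32_t maxSkipSize = 0;
  int32_t ngramSize = 0;
};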

View File

@@ -109,15 +109,15 @@ int RunConverter(int argc, const char **argv) {
std::unique_ptr<converter::Flags> flags(new (std::nothrow) converter::Flags);
if (flags == nullptr) {
MS_LOG(ERROR) << "new flags error ";
std::cout << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << std::endl;
return RET_MEMORY_FAILED;
}
auto status = flags->Init(argc, argv);
if (status == RET_SUCCESS_EXIT) {
return status;
}
if (status != 0) {
MS_LOG(ERROR) << "converter::Flags Init failed: " << status;
std::cout << "CONVERTER::FLAGS INIT FAILED" << std::endl;
if (status != RET_OK) {
if (status != RET_SUCCESS_EXIT) {
MS_LOG(ERROR) << "converter::Flags Init failed: " << status;
std::cout << "CONVERTER::FLAGS INIT FAILED:" << status << std::endl;
}
return status;
}
// Load graph
@@ -148,13 +148,14 @@ int RunConverter(int argc, const char **argv) {
} break;
default: {
MS_LOG(ERROR) << "Unsupported fmkType: " << flags->fmk;
return 1;
std::cout << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << std::endl;
return RET_INPUT_PARAM_INVALID;
}
}
status = ReturnCode::GetSingleReturnCode()->GetReturnCode();
if (fb_graph == nullptr) {
MS_LOG(ERROR) << "Convert model return nullptr";
std::cout << "CONVERT RESULT: FAILED!" << std::endl;
std::cout << "CONVERT RESULT FAILED:" << status << std::endl;
return status;
}
@@ -164,14 +165,14 @@ int RunConverter(int argc, const char **argv) {
status = storage.Save(*fb_graph, flags->outputFile);
if (status != 0) {
MS_LOG(ERROR) << "Save graph failed";
std::cout << "SAVE GRAPH FAILED!" << std::endl;
return RET_ERROR;
std::cout << "SAVE GRAPH FAILED:" << status << std::endl;
return status;
}
delete fb_graph;
MS_LOG(INFO) << "CONVERT RESULT: SUCCESS!";
std::cout << "CONVERT RESULT: SUCCESS!" << std::endl;
return RET_OK;
std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl;
return status;
}
} // namespace lite
} // namespace mindspore

View File

@@ -55,7 +55,7 @@ int Flags::Init(int argc, const char **argv) {
if (err.IsSome()) {
std::cerr << err.Get();
std::cerr << this->Usage() << std::endl;
return 1;
return RET_INPUT_PARAM_INVALID;
}
if (this->help) {
@@ -64,21 +64,21 @@ int Flags::Init(int argc, const char **argv) {
}
if (this->modelFile.empty()) {
std::cerr << "INPUT MISSING: model file path is necessary";
return 1;
return RET_INPUT_PARAM_LACK;
}
if (this->outputFile.empty()) {
std::cerr << "INPUT MISSING: output file path is necessary";
return 1;
return RET_INPUT_PARAM_LACK;
}
if (this->outputFile.rfind('/') == this->outputFile.length() - 1) {
std::cerr << "INPUT ILLEGAL: outputFile must be a valid file path";
return 1;
return RET_INPUT_PARAM_INVALID;
}
if (this->fmkIn.empty()) {
std::cerr << "INPUT MISSING: fmk is necessary";
return 1;
return RET_INPUT_PARAM_LACK;
}
if (this->inputInferenceTypeIn == "FLOAT") {
this->inputInferenceType = TypeId::kNumberTypeFloat;
@@ -87,7 +87,7 @@ int Flags::Init(int argc, const char **argv) {
} else {
std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s, supported inputInferenceType: FLOAT | INT8",
this->inputInferenceTypeIn.c_str();
return 1;
return RET_INPUT_PARAM_INVALID;
}
if (this->inferenceTypeIn == "FLOAT") {
@@ -97,7 +97,7 @@ int Flags::Init(int argc, const char **argv) {
} else {
std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8",
this->inferenceTypeIn.c_str();
return 1;
return RET_INPUT_PARAM_INVALID;
}
if (this->fmkIn == "CAFFE") {
@@ -110,12 +110,12 @@ int Flags::Init(int argc, const char **argv) {
this->fmk = FmkType_ONNX;
} else {
std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS|ONNX";
return 1;
return RET_INPUT_PARAM_INVALID;
}
if (this->fmk != FmkType_CAFFE && !weightFile.empty()) {
std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag";
return 1;
return RET_INPUT_PARAM_INVALID;
}
if (this->quantTypeIn == "AwareTraining") {
this->quantType = QuantType_AwareTraining;
@@ -127,7 +127,7 @@ int Flags::Init(int argc, const char **argv) {
this->quantType = QuantType_QUANT_NONE;
} else {
std::cerr << "INPUT ILLEGAL: quantType must be AwareTraining|WeightQuant|PostTraining";
return 1;
return RET_INPUT_PARAM_INVALID;
}
@@ -137,9 +137,9 @@ int Flags::Init(int argc, const char **argv) {
this->trainModel = false;
} else {
std::cerr << "INPUT ILLEGAL: trainModel must be true|false ";
return 1;
return RET_INPUT_PARAM_INVALID;
}
return 0;
return RET_OK;
}
} // namespace converter
} // namespace lite
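
A sketch of how the new codes surface from Init; the flag spellings here are assumptions inferred from the member names above:

// Hypothetical driver: omitting --outputFile should now yield
// RET_INPUT_PARAM_LACK (-602) instead of the old generic 1.
const char *argv[] = {"converter_lite", "--fmk=TFLITE", "--modelFile=lenet.tflite"};
mindspore::lite::converter::Flags flags;
int status = flags.Init(3, argv);
// status == mindspore::lite::RET_INPUT_PARAM_LACK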

View File

@@ -176,6 +176,9 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
return RET_ERROR;
}
} else if (attr->group != 1) {
MS_LOG(ERROR) << "group conv hasn't supported";
return RET_NOT_SUPPORT;
} else {
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release();

View File

@@ -78,5 +78,6 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
}
OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser());
OnnxNodeRegistrar g_onnxLRNxParser("LRN", new OnnxLrnParser());
} // namespace lite
} // namespace mindspore

View File

@@ -42,19 +42,19 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &tflite_attr = tflite_op->builtin_options.AsExpandDimsOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
std::vector<int> dims;
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, dims)) {
MS_LOG(ERROR) << "get expand_dims -> dim failed";
return RET_ERROR;
}
attr->dim = -1;
MS_LOG(ERROR) << "The attr dim is folded by TFLite.";
return RET_ERROR;
attr->dim = dims[0];
op->primitive->value.type = schema::PrimitiveType_ExpandDims;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser());
} // namespace lite
} // namespace mindspore
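
The rewrite above reflects that TFLite stores the ExpandDims axis as a constant input tensor rather than in the builtin options, so the parser reads it back with GetTfliteData:

// Assumed TFLite layout for ExpandDims (from the kernel spec, not this diff):
//   inputs[0] = data tensor, inputs[1] = scalar axis tensor (constant buffer)
// GetTfliteData copies that buffer into `dims`, so dims[0] is the axis.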

View File

@@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h"
#include <vector>
#include <memory>
#include <map>
namespace mindspore {
namespace lite {
STATUS TfliteHashtableLookupParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op, std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteHashtableLookupParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::HashtableLookupT> attr = std::make_unique<schema::HashtableLookupT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_HashtableLookup;
op->primitive->value.value = attr.release();
for (size_t i = 0; i < tflite_op->inputs.size(); ++i) {
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
}
for (size_t i = 0; i < tflite_op->outputs.size(); ++i) {
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[i], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
}
return RET_OK;
}
TfliteNodeRegister g_tfliteHashtableLookupParser("HashtableLookup", new TfliteHashtableLookupParser());
} // namespace lite
} // namespace mindspore
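
For context, the TFLite HashtableLookup kernel takes three inputs and produces two outputs, which is why the parser above wires every input and output in a loop rather than a single tensor pair:

// TFLite HashtableLookup signature (kernel spec, not from this diff):
//   inputs[0] = lookup ids, inputs[1] = keys, inputs[2] = values
//   outputs[0] = gathered values, outputs[1] = hit mask (one byte per lookup)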

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_HASHTABLE_LOOKUP_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_HASHTABLE_LOOKUP_PARSER_H
#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteHashtableLookupParser : public TfliteNodeParser {
public:
TfliteHashtableLookupParser() : TfliteNodeParser("HashtableLookup") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op,
std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_HASHTABLE_LOOKUP_PARSER_H

View File

@@ -42,18 +42,43 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->paddingMode = schema::PaddingMode_CONSTANT;
attr->constantValue = 0.0f;
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->paddings)) {
MS_LOG(ERROR) << "get pad -> paddings failed";
return RET_ERROR;
std::vector<std::string> node_name_str;
Split(op->name, &node_name_str, "-");
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Pad") == 0) {
const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->paddingMode = schema::PaddingMode_CONSTANT;
attr->constantValue = 0.0f;
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->paddings)) {
MS_LOG(ERROR) << "get pad -> paddings failed";
return RET_ERROR;
}
} else if (std::strcmp(node_name, "MirrorPad") == 0) {
const auto &tflite_attr = tflite_op->builtin_options.AsMirrorPadOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
switch (tflite_attr->mode) {
case tflite::MirrorPadMode_REFLECT:
attr->paddingMode = schema::PaddingMode_REFLECT;
break;
case tflite::MirrorPadMode_SYMMETRIC:
attr->paddingMode = schema::PaddingMode_SYMMETRIC;
break;
default:
MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support";
return RET_INVALID_OP_ATTR;
}
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
} else {
MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported";
return RET_NOT_SUPPORT;
}
op->primitive->value.type = schema::PrimitiveType_Pad;
@@ -67,5 +92,6 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o
}
TfliteNodeRegister g_tflitePadParser("Pad", new TflitePadParser());
TfliteNodeRegister g_tfliteMirrorPadParser("MirrorPad", new TflitePadParser());
} // namespace lite
} // namespace mindspore
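
The dispatch above keys on the prefix of the node name (upstream naming such as "MirrorPad-7" is assumed); a minimal stand-in for the Split helper shows the expected tokenization:

#include <sstream>
#include <string>
#include <vector>

// Assumed behavior of Split: "MirrorPad-7" -> {"MirrorPad", "7"},
// so the first token names the op type checked above.
std::vector<std::string> SplitByDash(const std::string &s) {
  std::vector<std::string> out;
  std::stringstream ss(s);
  for (std::string tok; std::getline(ss, tok, '-');) out.push_back(tok);
  return out;
}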

View File

@@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_skip_gram_parser.h"
#include <vector>
#include <memory>
#include <map>
namespace mindspore {
namespace lite {
STATUS TfliteSkipGramParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::CNodeT *op, std::vector<int32_t> *tensors_id,
std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) {
MS_LOG(DEBUG) << "parse TfliteSkipGramParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::SkipGramT> attr = std::make_unique<schema::SkipGramT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &tflite_attr = tflite_op->builtin_options.AsSkipGramOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name << " attr failed";
return RET_NULL_PTR;
}
attr->includeAllGrams = tflite_attr->include_all_ngrams;
attr->maxSkipSize = tflite_attr->max_skip_size;
attr->ngramSize = tflite_attr->ngram_size;
op->primitive->value.type = schema::PrimitiveType_SkipGram;
op->primitive->value.value = attr.release();
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
tflite_tensors.size(), schema::Format::Format_NHWC);
return RET_OK;
}
TfliteNodeRegister g_tfliteSkipGramParser("SkipGram", new TfliteSkipGramParser());
} // namespace lite
} // namespace mindspore
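
As a semantic reminder (standard TFLite skip-gram behavior, not taken from this diff): with ngramSize = 2 and maxSkipSize = 1 over the tokens "a b c", the op emits the adjacent pairs "a b" and "b c" plus the skipped pair "a c"; includeAllGrams additionally keeps all lower-order grams up to ngramSize, e.g. the unigrams "a", "b", "c".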

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SKIP_GRAM_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SKIP_GRAM_PARSER_H
#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteSkipGramParser : public TfliteNodeParser {
public:
TfliteSkipGramParser() : TfliteNodeParser("SkipGram") {}
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op,
std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format,
std::map<int, int> *tensors_id_map) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SKIP_GRAM_PARSER_H

View File

@@ -57,7 +57,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
{tflite::BuiltinOperator_POW, "Pow"},
{tflite::BuiltinOperator_ARG_MIN, "Argmin"},
{tflite::BuiltinOperator_CEIL, "Ceil"},
// {tflite::BuiltinOperator_EXPAND_DIMS, "ExpandDims"},
{tflite::BuiltinOperator_EXPAND_DIMS, "ExpandDims"},
{tflite::BuiltinOperator_FILL, "Fill"},
{tflite::BuiltinOperator_DIV, "Div"},
{tflite::BuiltinOperator_FLOOR, "Floor"},
@@ -117,6 +117,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
{tflite::BuiltinOperator_UNIQUE, "Unique"},
{tflite::BuiltinOperator_UNPACK, "Unstack"},
{tflite::BuiltinOperator_CUSTOM, "Custom"},
{tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"},
};
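
A hedged sketch of how this map is presumably consumed downstream: the builtin operator code resolves to a name, and the name resolves to the registered parser (the registry accessor spelling is an assumption):

// Hypothetical dispatch; GetNodeParser is the assumed registry accessor.
std::string op_type = tfMsOpTypeMap[tflite::BuiltinOperator_MIRROR_PAD];  // "MirrorPad"
auto *parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
if (parser == nullptr) {
  MS_LOG(ERROR) << "unsupported op: " << op_type;
}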
std::map<tflite::ActivationFunctionType, schema::ActivationType> tfMsActivationFunctionMap{

View File

@@ -33,7 +33,7 @@ class ReturnCode {
statusCode = status;
}
}
STATUS GetReturnCode() {
STATUS GetReturnCode() const {
return statusCode;
}
private:
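
Marking the getter const matches the usage in RunConverter above: failures are recorded once through the singleton and read back at the end. Only GetSingleReturnCode and GetReturnCode appear in this diff; the setter name below is an assumption:

// Assumed usage pattern around the singleton:
//   on first failure:  ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
//   at the end:        status = ReturnCode::GetSingleReturnCode()->GetReturnCode();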

View File

@@ -85,8 +85,8 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
param_value->set_tensor_type(type_id);
param_value->set_format(tensor->GetFormat());
if (tensor->MutableData() != nullptr) {
auto size = tensor->ElementsNum();
auto tensor_data = new (std::nothrow) float[size];
auto size = tensor->Size();
auto tensor_data = new (std::nothrow) uint8_t[size];
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "tensor_data is nullptr";
return nullptr;
@@ -98,7 +98,7 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
return nullptr;
}
param_value->set_tensor_addr(tensor_data);
param_value->set_tensor_size(size * sizeof(float) / sizeof(uint8_t));
param_value->set_tensor_size(size);
}
parameter->set_default_param(param_value);
return parameter;
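
The switch from element count to byte size is the "adapt to data type" half of this change: the old code sized the buffer as ElementsNum() * sizeof(float), which over-allocates for narrow element types and truncates wider ones, while Size() is assumed to return ElementsNum() times the actual element width:

// Sizing comparison for a 100-element tensor (element widths assumed):
//   int8:    old = 100 * sizeof(float) = 400 bytes, new = 100 bytes
//   float32: old = new = 400 bytes
//   int64:   old = 400 bytes (truncated), new = 800 bytes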