!6230 MSLITE add tflite hashtable_lookup,mirror_pad,skipGram parser;adjust const fold adapt to datatype
Merge pull request !6230 from 徐安越/master
This commit is contained in:
commit
4b84c6ee98
|
@ -32,6 +32,7 @@ constexpr int RET_PARAM_INVALID = -3; /**< Invalid parameter.*/
|
|||
constexpr int RET_NO_CHANGE = -4; /**< No change. */
|
||||
constexpr int RET_SUCCESS_EXIT = -5; /**< No error but exit. */
|
||||
constexpr int RET_MEMORY_FAILED = -6; /**< Fail to create memory. */
|
||||
constexpr int RET_NOT_SUPPORT = -7; /**< Fail to support. */
|
||||
|
||||
/* Executor error code, range: [-101,-200] */
|
||||
constexpr int RET_OUT_OF_TENSOR_RANGE = -101; /**< Failed to check range. */
|
||||
|
@ -53,6 +54,10 @@ constexpr int RET_FORMAT_ERR = -401; /**< Failed to checking tensor format. */
|
|||
/* InferShape error code, range: [-501,-600] */
|
||||
constexpr int RET_INFER_ERR = -501; /**< Failed to infer shape. */
|
||||
constexpr int RET_INFER_INVALID = -502; /**< Invalid infer shape before runtime. */
|
||||
|
||||
/* User input param error code, range: [-601, 700]*/
|
||||
constexpr int RET_INPUT_PARAM_INVALID = -601; /**< Invalid input param by user. */
|
||||
constexpr int RET_INPUT_PARAM_LACK = -602; /**< LACK input param by user. */
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -203,6 +203,8 @@ union PrimitiveType {
|
|||
LogGrad,
|
||||
BatchToSpaceND,
|
||||
LshProjection,
|
||||
HashtableLookup,
|
||||
SkipGram,
|
||||
}
|
||||
|
||||
enum QuantType: int {
|
||||
|
|
|
@ -948,3 +948,12 @@ table BlackBox {
|
|||
table LshProjection {
|
||||
type : LshProjectionType;
|
||||
}
|
||||
|
||||
table HashtableLookup {
|
||||
}
|
||||
|
||||
table SkipGram {
|
||||
includeAllGrams : bool;
|
||||
maxSkipSize : int;
|
||||
ngramSize : int;
|
||||
}
|
||||
|
|
|
@ -109,15 +109,15 @@ int RunConverter(int argc, const char **argv) {
|
|||
std::unique_ptr<converter::Flags> flags(new (std::nothrow) converter::Flags);
|
||||
if (flags == nullptr) {
|
||||
MS_LOG(ERROR) << "new flags error ";
|
||||
std::cout << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << std::endl;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
auto status = flags->Init(argc, argv);
|
||||
if (status == RET_SUCCESS_EXIT) {
|
||||
return status;
|
||||
}
|
||||
if (status != 0) {
|
||||
MS_LOG(ERROR) << "converter::Flags Init failed: " << status;
|
||||
std::cout << "CONVERTER::FLAGS INIT FAILED" << std::endl;
|
||||
if (status != RET_OK) {
|
||||
if (status != RET_SUCCESS_EXIT) {
|
||||
MS_LOG(ERROR) << "converter::Flags Init failed: " << status;
|
||||
std::cout << "CONVERTER::FLAGS INIT FAILED:" << status << std::endl;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
// Load graph
|
||||
|
@ -148,13 +148,14 @@ int RunConverter(int argc, const char **argv) {
|
|||
} break;
|
||||
default: {
|
||||
MS_LOG(ERROR) << "Unsupported fmkType: " << flags->fmk;
|
||||
return 1;
|
||||
std::cout << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << std::endl;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
status = ReturnCode::GetSingleReturnCode()->GetReturnCode();
|
||||
if (fb_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert model return nullptr";
|
||||
std::cout << "CONVERT RESULT: FAILED!" << std::endl;
|
||||
std::cout << "CONVERT RESULT FAILED:" << status << std::endl;
|
||||
return status;
|
||||
}
|
||||
|
||||
|
@ -164,14 +165,14 @@ int RunConverter(int argc, const char **argv) {
|
|||
status = storage.Save(*fb_graph, flags->outputFile);
|
||||
if (status != 0) {
|
||||
MS_LOG(ERROR) << "Save graph failed";
|
||||
std::cout << "SAVE GRAPH FAILED!" << std::endl;
|
||||
return RET_ERROR;
|
||||
std::cout << "SAVE GRAPH FAILED:" << status << std::endl;
|
||||
return status;
|
||||
}
|
||||
|
||||
delete fb_graph;
|
||||
MS_LOG(INFO) << "CONVERT RESULT: SUCCESS!";
|
||||
std::cout << "CONVERT RESULT: SUCCESS!" << std::endl;
|
||||
return RET_OK;
|
||||
std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl;
|
||||
return status;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -55,7 +55,7 @@ int Flags::Init(int argc, const char **argv) {
|
|||
if (err.IsSome()) {
|
||||
std::cerr << err.Get();
|
||||
std::cerr << this->Usage() << std::endl;
|
||||
return 1;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (this->help) {
|
||||
|
@ -64,21 +64,21 @@ int Flags::Init(int argc, const char **argv) {
|
|||
}
|
||||
if (this->modelFile.empty()) {
|
||||
std::cerr << "INPUT MISSING: model file path is necessary";
|
||||
return 1;
|
||||
return RET_INPUT_PARAM_LACK;
|
||||
}
|
||||
if (this->outputFile.empty()) {
|
||||
std::cerr << "INPUT MISSING: output file path is necessary";
|
||||
return 1;
|
||||
return RET_INPUT_PARAM_LACK;
|
||||
}
|
||||
|
||||
if (this->outputFile.rfind('/') == this->outputFile.length() - 1) {
|
||||
std::cerr << "INPUT ILLEGAL: outputFile must be a valid file path";
|
||||
return 1;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (this->fmkIn.empty()) {
|
||||
std::cerr << "INPUT MISSING: fmk is necessary";
|
||||
return 1;
|
||||
return RET_INPUT_PARAM_LACK;
|
||||
}
|
||||
if (this->inputInferenceTypeIn == "FLOAT") {
|
||||
this->inputInferenceType = TypeId::kNumberTypeFloat;
|
||||
|
@ -87,7 +87,7 @@ int Flags::Init(int argc, const char **argv) {
|
|||
} else {
|
||||
std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s, supported inputInferenceType: FLOAT | INT8",
|
||||
this->inputInferenceTypeIn.c_str();
|
||||
return 1;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (this->inferenceTypeIn == "FLOAT") {
|
||||
|
@ -97,7 +97,7 @@ int Flags::Init(int argc, const char **argv) {
|
|||
} else {
|
||||
std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8",
|
||||
this->inferenceTypeIn.c_str();
|
||||
return 1;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (this->fmkIn == "CAFFE") {
|
||||
|
@ -110,12 +110,12 @@ int Flags::Init(int argc, const char **argv) {
|
|||
this->fmk = FmkType_ONNX;
|
||||
} else {
|
||||
std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS|ONNX";
|
||||
return 1;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (this->fmk != FmkType_CAFFE && !weightFile.empty()) {
|
||||
std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag";
|
||||
return 1;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (this->quantTypeIn == "AwareTraining") {
|
||||
this->quantType = QuantType_AwareTraining;
|
||||
|
@ -127,7 +127,7 @@ int Flags::Init(int argc, const char **argv) {
|
|||
this->quantType = QuantType_QUANT_NONE;
|
||||
} else {
|
||||
std::cerr << "INPUT ILLEGAL: quantType must be AwareTraining|WeightQuant|PostTraining";
|
||||
return 1;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
|
||||
|
@ -137,9 +137,9 @@ int Flags::Init(int argc, const char **argv) {
|
|||
this->trainModel = false;
|
||||
} else {
|
||||
std::cerr << "INPUT ILLEGAL: trainModel must be true|false ";
|
||||
return 1;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
return 0;
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace converter
|
||||
} // namespace lite
|
||||
|
|
|
@ -176,6 +176,9 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (attr->group != 1) {
|
||||
MS_LOG(ERROR) << "group conv hasn't supported";
|
||||
return RET_NOT_SUPPORT;
|
||||
} else {
|
||||
op->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
|
|
@ -78,5 +78,6 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
|
|||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxLrnxParser("Lrn", new OnnxLrnParser());
|
||||
OnnxNodeRegistrar g_onnxLRNxParser("LRN", new OnnxLrnParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -42,19 +42,19 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
|
|||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsExpandDimsOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
std::vector<int> dims;
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, dims)) {
|
||||
MS_LOG(ERROR) << "get expand_dims -> dim failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
attr->dim = -1;
|
||||
|
||||
MS_LOG(ERROR) << "The attr dim is folded by TFLite.";
|
||||
return RET_ERROR;
|
||||
attr->dim = dims[0];
|
||||
op->primitive->value.type = schema::PrimitiveType_ExpandDims;
|
||||
op->primitive->value.value = attr.release();
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
|
||||
tflite_tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
|
||||
tflite_tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_hashtable_lookup_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteHashtableLookupParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op, std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteHashtableLookupParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::HashtableLookupT> attr = std::make_unique<schema::HashtableLookupT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_HashtableLookup;
|
||||
op->primitive->value.value = attr.release();
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); ++i) {
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(),
|
||||
tflite_tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
for (size_t i = 0; i < tflite_op->outputs.size(); ++i) {
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[i], tensors_id->size(),
|
||||
tflite_tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteHashtableLookupParser("HashtableLookup", new TfliteHashtableLookupParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_HASHTABLE_LOOKUP_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_HASHTABLE_LOOKUP_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TfliteHashtableLookupParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteHashtableLookupParser() : TfliteNodeParser("HashtableLookup") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_HASHTABLE_LOOKUP_PARSER_H
|
|
@ -42,18 +42,43 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o
|
|||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
attr->paddingMode = schema::PaddingMode_CONSTANT;
|
||||
attr->constantValue = 0.0f;
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->paddings)) {
|
||||
MS_LOG(ERROR) << "get pad -> paddings failed";
|
||||
return RET_ERROR;
|
||||
std::vector<std::string> node_name_str;
|
||||
Split(op->name, &node_name_str, "-");
|
||||
const char *node_name = node_name_str.data()->c_str();
|
||||
if (std::strcmp(node_name, "Pad") == 0) {
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsPadOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->paddingMode = schema::PaddingMode_CONSTANT;
|
||||
attr->constantValue = 0.0f;
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->paddings)) {
|
||||
MS_LOG(ERROR) << "get pad -> paddings failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (std::strcmp(node_name, "MirrorPad") == 0) {
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsMirrorPadOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
switch (tflite_attr->mode) {
|
||||
case tflite::MirrorPadMode_REFLECT:
|
||||
attr->paddingMode = schema::PaddingMode_REFLECT;
|
||||
break;
|
||||
case tflite::MirrorPadMode_SYMMETRIC:
|
||||
attr->paddingMode = schema::PaddingMode_SYMMETRIC;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support";
|
||||
return RET_INVALID_OP_ATTR;
|
||||
}
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(),
|
||||
tflite_tensors.size(), schema::Format::Format_NHWC);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported";
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Pad;
|
||||
|
@ -67,5 +92,6 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o
|
|||
}
|
||||
|
||||
TfliteNodeRegister g_tflitePadParser("Pad", new TflitePadParser());
|
||||
TfliteNodeRegister g_tfliteMirorPadParser("MirrorPad", new TflitePadParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_skip_gram_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteSkipGramParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
schema::CNodeT *op, std::vector<int32_t> *tensors_id,
|
||||
std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) {
|
||||
MS_LOG(DEBUG) << "parse TfliteSkipGramParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::SkipGramT> attr = std::make_unique<schema::SkipGramT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsSkipGramOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op: " << op->name << " attr failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->includeAllGrams = tflite_attr->include_all_ngrams;
|
||||
attr->maxSkipSize = tflite_attr->max_skip_size;
|
||||
attr->ngramSize = tflite_attr->ngram_size;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_SkipGram;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
|
||||
tflite_tensors.size(), schema::Format::Format_NHWC);
|
||||
AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(),
|
||||
tflite_tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_TfliteSkiGramParser("SKipGram", new TfliteSkipGramParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SKIP_GRAM_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SKIP_GRAM_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TfliteSkipGramParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteSkipGramParser() : TfliteNodeParser("SkipGram") {}
|
||||
|
||||
STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op,
|
||||
std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format,
|
||||
std::map<int, int> *tensors_id_map) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SKIP_GRAM_PARSER_H
|
|
@ -57,7 +57,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
|
|||
{tflite::BuiltinOperator_POW, "Pow"},
|
||||
{tflite::BuiltinOperator_ARG_MIN, "Argmin"},
|
||||
{tflite::BuiltinOperator_CEIL, "Ceil"},
|
||||
// {tflite::BuiltinOperator_EXPAND_DIMS, "ExpandDims"},
|
||||
{tflite::BuiltinOperator_EXPAND_DIMS, "ExpandDims"},
|
||||
{tflite::BuiltinOperator_FILL, "Fill"},
|
||||
{tflite::BuiltinOperator_DIV, "Div"},
|
||||
{tflite::BuiltinOperator_FLOOR, "flOOR"},
|
||||
|
@ -117,6 +117,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
|
|||
{tflite::BuiltinOperator_UNIQUE, "Unique"},
|
||||
{tflite::BuiltinOperator_UNPACK, "Unstack"},
|
||||
{tflite::BuiltinOperator_CUSTOM, "Custom"},
|
||||
{tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"},
|
||||
};
|
||||
|
||||
std::map<tflite::ActivationFunctionType, schema::ActivationType> tfMsActivationFunctionMap{
|
||||
|
|
|
@ -33,7 +33,7 @@ class ReturnCode {
|
|||
statusCode = status;
|
||||
}
|
||||
}
|
||||
STATUS GetReturnCode() {
|
||||
STATUS GetReturnCode() const {
|
||||
return statusCode;
|
||||
}
|
||||
private:
|
||||
|
|
|
@ -85,8 +85,8 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
|
|||
param_value->set_tensor_type(type_id);
|
||||
param_value->set_format(tensor->GetFormat());
|
||||
if (tensor->MutableData() != nullptr) {
|
||||
auto size = tensor->ElementsNum();
|
||||
auto tensor_data = new (std::nothrow) float[size];
|
||||
auto size = tensor->Size();
|
||||
auto tensor_data = new (std::nothrow) uint8_t[size];
|
||||
if (tensor_data == nullptr) {
|
||||
MS_LOG(ERROR) << "tensor_data is nullptr";
|
||||
return nullptr;
|
||||
|
@ -98,7 +98,7 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
|
|||
return nullptr;
|
||||
}
|
||||
param_value->set_tensor_addr(tensor_data);
|
||||
param_value->set_tensor_size(size * sizeof(float) / sizeof(uint8_t));
|
||||
param_value->set_tensor_size(size);
|
||||
}
|
||||
parameter->set_default_param(param_value);
|
||||
return parameter;
|
||||
|
|
Loading…
Reference in New Issue