add tflite LogSoftmax BatchMatmul parser

This commit is contained in:
hangangqiang 2021-05-07 15:16:45 +08:00
parent b7392a2850
commit 8538ae1ebf
8 changed files with 197 additions and 13 deletions

View File

@ -17,6 +17,21 @@
#include "nnacl/infer/slice_infer.h"
#include "nnacl/infer/infer_register.h"
static bool CheckInputsDataType(const TensorC *const *inputs, size_t inputs_size) {
// not support data_type of slice's begin and size is not int32
if (inputs_size >= 2) {
if (inputs[1]->data_type_ != kNumberTypeInt32) {
return false;
}
}
if (inputs_size == 3) {
if (inputs[2]->data_type_ != kNumberTypeInt32) {
return false;
}
}
return true;
}
int SliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
if (inputs_size < 1 || outputs_size != 1) {
@ -26,6 +41,10 @@ int SliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **
TensorC *output = outputs[0];
SetDataTypeFormat(output, input);
if (!CheckInputsDataType(inputs, inputs_size)) {
return NNACL_ERR;
}
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}

View File

@ -37,13 +37,16 @@ OpParameter *PopulateSliceParameter(const void *prim) {
param->op_parameter_.type_ = primitive->value_type();
auto axes = value->axes();
if (axes == nullptr) {
MS_LOG(ERROR) << "axes is nullptr";
free(param);
return nullptr;
}
for (size_t i = 0; i < axes->size(); ++i) {
param->axis_[i] = axes->Get(i);
// if begin is not const input, then axis can not be decided in converter
if (axes != nullptr) {
for (size_t i = 0; i < axes->size(); ++i) {
param->axis_[i] = axes->Get(i);
}
} else {
// use default axes
for (int32_t i = 0; i < DIMENSION_8D; i++) {
param->axis_[i] = i;
}
}
return reinterpret_cast<OpParameter *>(param);
}

View File

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_batch_matmul_parser.h"
#include <vector>
#include <memory>
#include "ops/mat_mul.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TfliteBatchMatmulParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto prim = std::make_unique<ops::MatMul>();
MS_ASSERT(tflite_op != nullptr);
const auto &tflite_attr = tflite_op->builtin_options.AsBatchMatMulOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op BatchMatmul attr failed";
return nullptr;
}
prim->set_transpose_a(tflite_attr->adj_x);
prim->set_transpose_b(tflite_attr->adj_y);
return prim.release();
}
TfliteNodeRegister g_tfliteBatchMatmulParser(tflite::BuiltinOperator_BATCH_MATMUL, new TfliteBatchMatmulParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_MATMUL_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_MATMUL_PARSER_H
#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteBatchMatmulParser : public TfliteNodeParser {
public:
TfliteBatchMatmulParser() : TfliteNodeParser("BatchMatmul") {}
~TfliteBatchMatmulParser() override = default;
ops::PrimitiveC *Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_MATMUL_PARSER_H

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tflite/tflite_log_softmax_parser.h"
#include <vector>
#include <memory>
#include "ops/log_softmax.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TfliteLogSoftmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto prim = std::make_unique<ops::LogSoftmax>();
MS_ASSERT(tflite_op != nullptr);
prim->set_axis(-1);
return prim.release();
}
TfliteNodeRegister g_tfliteLogSoftmaxParser(tflite::BuiltinOperator_LOG_SOFTMAX, new TfliteLogSoftmaxParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LOG_SOFTMAX_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LOG_SOFTMAX_PARSER_H
#include <memory>
#include <vector>
#include <map>
#include "tools/converter/parser/tflite/tflite_node_parser.h"
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
namespace mindspore {
namespace lite {
class TfliteLogSoftmaxParser : public TfliteNodeParser {
public:
TfliteLogSoftmaxParser() : TfliteNodeParser("LogSoftmax") {}
~TfliteLogSoftmaxParser() override = default;
ops::PrimitiveC *Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LOG_SOFTMAX_PARSER_H

View File

@ -33,7 +33,8 @@ ops::PrimitiveC *TfliteSliceParser::Parse(const std::unique_ptr<tflite::Operator
return nullptr;
}
std::vector<int64_t> begin;
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, begin)) {
auto ret = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, begin);
if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(ERROR) << "get slice -> begin failed";
return nullptr;
}

View File

@ -138,13 +138,16 @@ STATUS TfliteInputsAdjustPass::AdjustSlice(const AnfNodePtr &node, const FuncGra
auto begin_param_node = cnode->input(2)->cast<ParameterPtr>();
auto size_param_node = cnode->input(3)->cast<ParameterPtr>();
if (ReplaceInt64ParameterNode(graph, begin_param_node) == RET_OK &&
ReplaceInt64ParameterNode(graph, size_param_node) == RET_OK) {
return RET_OK;
} else {
MS_LOG(ERROR) << "Adjust inputs for Slice failed";
// slice's begin and size could be variable
if (begin_param_node != nullptr && ReplaceInt64ParameterNode(graph, begin_param_node) != RET_OK) {
MS_LOG(ERROR) << "Adjust begin for Slice failed";
return RET_ERROR;
}
if (size_param_node != nullptr && ReplaceInt64ParameterNode(graph, size_param_node) != RET_OK) {
MS_LOG(ERROR) << "Adjust size for Slice failed";
return RET_ERROR;
}
return RET_OK;
}
bool TfliteInputsAdjustPass::Run(const FuncGraphPtr &graph) {