forked from mindspore-Ecosystem/mindspore
add tflite LogSoftmax BatchMatmul parser
This commit is contained in:
parent
b7392a2850
commit
8538ae1ebf
|
@ -17,6 +17,21 @@
|
|||
#include "nnacl/infer/slice_infer.h"
|
||||
#include "nnacl/infer/infer_register.h"
|
||||
|
||||
static bool CheckInputsDataType(const TensorC *const *inputs, size_t inputs_size) {
|
||||
// not support data_type of slice's begin and size is not int32
|
||||
if (inputs_size >= 2) {
|
||||
if (inputs[1]->data_type_ != kNumberTypeInt32) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (inputs_size == 3) {
|
||||
if (inputs[2]->data_type_ != kNumberTypeInt32) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int SliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter) {
|
||||
if (inputs_size < 1 || outputs_size != 1) {
|
||||
|
@ -26,6 +41,10 @@ int SliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **
|
|||
TensorC *output = outputs[0];
|
||||
SetDataTypeFormat(output, input);
|
||||
|
||||
if (!CheckInputsDataType(inputs, inputs_size)) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
|
||||
if (!InferFlag(inputs, inputs_size)) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
|
|
|
@ -37,13 +37,16 @@ OpParameter *PopulateSliceParameter(const void *prim) {
|
|||
|
||||
param->op_parameter_.type_ = primitive->value_type();
|
||||
auto axes = value->axes();
|
||||
if (axes == nullptr) {
|
||||
MS_LOG(ERROR) << "axes is nullptr";
|
||||
free(param);
|
||||
return nullptr;
|
||||
}
|
||||
for (size_t i = 0; i < axes->size(); ++i) {
|
||||
param->axis_[i] = axes->Get(i);
|
||||
// if begin is not const input, then axis can not be decided in converter
|
||||
if (axes != nullptr) {
|
||||
for (size_t i = 0; i < axes->size(); ++i) {
|
||||
param->axis_[i] = axes->Get(i);
|
||||
}
|
||||
} else {
|
||||
// use default axes
|
||||
for (int32_t i = 0; i < DIMENSION_8D; i++) {
|
||||
param->axis_[i] = i;
|
||||
}
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_batch_matmul_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/mat_mul.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *TfliteBatchMatmulParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto prim = std::make_unique<ops::MatMul>();
|
||||
|
||||
MS_ASSERT(tflite_op != nullptr);
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsBatchMatMulOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "get op BatchMatmul attr failed";
|
||||
return nullptr;
|
||||
}
|
||||
prim->set_transpose_a(tflite_attr->adj_x);
|
||||
prim->set_transpose_b(tflite_attr->adj_y);
|
||||
|
||||
return prim.release();
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteBatchMatmulParser(tflite::BuiltinOperator_BATCH_MATMUL, new TfliteBatchMatmulParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_MATMUL_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_MATMUL_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TfliteBatchMatmulParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteBatchMatmulParser() : TfliteNodeParser("BatchMatmul") {}
|
||||
|
||||
~TfliteBatchMatmulParser() override = default;
|
||||
|
||||
ops::PrimitiveC *Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_BATCH_MATMUL_PARSER_H
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_log_softmax_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/log_softmax.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *TfliteLogSoftmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto prim = std::make_unique<ops::LogSoftmax>();
|
||||
|
||||
MS_ASSERT(tflite_op != nullptr);
|
||||
prim->set_axis(-1);
|
||||
|
||||
return prim.release();
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteLogSoftmaxParser(tflite::BuiltinOperator_LOG_SOFTMAX, new TfliteLogSoftmaxParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LOG_SOFTMAX_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LOG_SOFTMAX_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TfliteLogSoftmaxParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteLogSoftmaxParser() : TfliteNodeParser("LogSoftmax") {}
|
||||
|
||||
~TfliteLogSoftmaxParser() override = default;
|
||||
|
||||
ops::PrimitiveC *Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LOG_SOFTMAX_PARSER_H
|
|
@ -33,7 +33,8 @@ ops::PrimitiveC *TfliteSliceParser::Parse(const std::unique_ptr<tflite::Operator
|
|||
return nullptr;
|
||||
}
|
||||
std::vector<int64_t> begin;
|
||||
if (GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, begin)) {
|
||||
auto ret = GetTfliteData(tflite_op->inputs[1], tflite_subgraph->tensors, tflite_model->buffers, begin);
|
||||
if (ret != RET_OK && ret != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "get slice -> begin failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -138,13 +138,16 @@ STATUS TfliteInputsAdjustPass::AdjustSlice(const AnfNodePtr &node, const FuncGra
|
|||
|
||||
auto begin_param_node = cnode->input(2)->cast<ParameterPtr>();
|
||||
auto size_param_node = cnode->input(3)->cast<ParameterPtr>();
|
||||
if (ReplaceInt64ParameterNode(graph, begin_param_node) == RET_OK &&
|
||||
ReplaceInt64ParameterNode(graph, size_param_node) == RET_OK) {
|
||||
return RET_OK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Adjust inputs for Slice failed";
|
||||
// slice's begin and size could be variable
|
||||
if (begin_param_node != nullptr && ReplaceInt64ParameterNode(graph, begin_param_node) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Adjust begin for Slice failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (size_param_node != nullptr && ReplaceInt64ParameterNode(graph, size_param_node) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Adjust size for Slice failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
bool TfliteInputsAdjustPass::Run(const FuncGraphPtr &graph) {
|
||||
|
|
Loading…
Reference in New Issue