forked from mindspore-Ecosystem/mindspore
!8776 [lite] add tflite custom op and adjust identity_remove pass
From: @xu_anyue Reviewed-by: @hangangqiang,@hangangqiang,@zhanghaibo5 Signed-off-by: @hangangqiang
This commit is contained in:
commit
4d00c6bf26
|
@ -64,11 +64,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
|||
// for now - trainning is not supporting fuse operations
|
||||
if (config != nullptr && !config->trainModel) {
|
||||
// remove quantdtype when awaretraining
|
||||
if (config->fmk == lite::converter::FmkType_ONNX) {
|
||||
auto remove_identity_pass = std::make_shared<opt::RemoveIdentityOpPass>();
|
||||
remove_identity_pass->SetFmkType(config->fmk);
|
||||
pm->AddPass(remove_identity_pass);
|
||||
}
|
||||
pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
|
||||
pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
|
||||
|
|
|
@ -181,6 +181,33 @@ STATUS TfliteCustomParser::FftImag(const std::vector<uint8_t> &custom_attr, sche
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteCustomParser::Identity(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
|
||||
std::unique_ptr<schema::IdentityT> attr = std::make_unique<schema::IdentityT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive->value.type = schema::PrimitiveType_Identity;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteCustomParser::BatchMatMul(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
|
||||
std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->broadcast = false;
|
||||
attr->transposeA = false;
|
||||
attr->transposeB = false;
|
||||
op->primitive->value.type = schema::PrimitiveType_MatMul;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
|
@ -216,6 +243,10 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
|
|||
status = FftReal(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "FlexImag") {
|
||||
status = FftImag(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "FlexIdentityN" || custom_type == "FlexIdentity") {
|
||||
status = Identity(custom_attr, op, tflite_op);
|
||||
} else if (custom_type == "FlexBatchMatMul") {
|
||||
status = BatchMatMul(custom_attr, op, tflite_op);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "the custom op hasn't been supported now";
|
||||
status = RET_NOT_FIND_OP;
|
||||
|
|
|
@ -60,6 +60,12 @@ class TfliteCustomParser : public TfliteNodeParser {
|
|||
|
||||
STATUS FftImag(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op);
|
||||
|
||||
STATUS Identity(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op);
|
||||
|
||||
STATUS BatchMatMul(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
|
||||
const std::unique_ptr<tflite::OperatorT> &tflite_op);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/tflite/tflite_matmul_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS TfliteMatMulParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "parse TfliteMatMulParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::MatMulT> attr = std::make_unique<schema::MatMulT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsBatchMatMulOptions();
|
||||
attr->transposeA = tflite_attr->adj_x;
|
||||
attr->transposeB = tflite_attr->adj_y;
|
||||
attr->broadcast = false;
|
||||
op->primitive->value.type = schema::PrimitiveType_MatMul;
|
||||
op->primitive->value.value = attr.release();
|
||||
|
||||
for (size_t i = 0; i < tflite_op->inputs.size(); i++) {
|
||||
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
}
|
||||
AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_subgraph->tensors.size(), schema::Format::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
TfliteNodeRegister g_tfliteMatMulParser("MatMul", new TfliteMatMulParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MATMUL_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_MATMUL_PARSER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TfliteMatMulParser : public TfliteNodeParser {
|
||||
public:
|
||||
TfliteMatMulParser() : TfliteNodeParser("MatMul") {}
|
||||
|
||||
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_SLICE_PARSER_H
|
|
@ -109,10 +109,11 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
|
|||
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
|
||||
continue;
|
||||
}
|
||||
if (status == RET_OK) {
|
||||
status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, tflite_subgraph, op.get());
|
||||
if (status != RET_OK) {
|
||||
if (status == RET_NOT_FIND_OP) {
|
||||
if (status == RET_OK || op_type == "Custom") {
|
||||
int status_node = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, tflite_subgraph, op.get());
|
||||
status = (status == RET_OK ? status_node : status);
|
||||
if (status_node != RET_OK) {
|
||||
if (status_node == RET_NOT_FIND_OP) {
|
||||
op_type =
|
||||
(op_type != "Custom" ? op_type : (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code);
|
||||
NoSupportOp::GetInstance()->InsertOp(op_type);
|
||||
|
@ -121,6 +122,9 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
|
|||
}
|
||||
continue;
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
continue;
|
||||
}
|
||||
sub_graph->nodes.emplace_back(op.release());
|
||||
opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get();
|
||||
tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get();
|
||||
|
|
|
@ -14,34 +14,96 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/optimizer/graph/identity_remove_pass.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "mindspore/lite/include/errorcode.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) {
|
||||
if (this->fmk_type != lite::converter::FmkType_ONNX) {
|
||||
MS_LOG(INFO) << "The framework type of model should be onnx.";
|
||||
return RET_OK;
|
||||
int RemoveIdentityOpPass::ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
||||
if (!utils::isa<CNodePtr>(anf_node)) {
|
||||
MS_LOG(DEBUG) << "anf node is node a cnode.";
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto type = opt::GetCNodeType(anf_node);
|
||||
if (type != schema::PrimitiveType_Identity) {
|
||||
MS_LOG(DEBUG) << "anf node is not a identity node.";
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto identity_cnode = anf_node->cast<CNodePtr>();
|
||||
if (identity_cnode->inputs().size() != lite::kDoubleNum) {
|
||||
MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
|
||||
remove_cnode_.insert(anf_node);
|
||||
return lite::RET_NO_CHANGE;
|
||||
} else {
|
||||
bool replace_succ = manager->Replace(anf_node, identity_cnode->input(1));
|
||||
if (!replace_succ) {
|
||||
MS_LOG(ERROR) << "replace identity failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
||||
if (!utils::isa<CNodePtr>(anf_node)) {
|
||||
MS_LOG(DEBUG) << "anf node is node a cnode.";
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto type = opt::GetCNodeType(anf_node);
|
||||
if (type != schema::PrimitiveType_TupleGetItem) {
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
if (cnode->inputs().size() != 3) {
|
||||
MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->inputs().size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
type = opt::GetCNodeType(cnode->input(1));
|
||||
if (type != schema::PrimitiveType_Identity) {
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
auto get_item_input_cnode = cnode->input(1)->cast<CNodePtr>();
|
||||
auto index_vnode = cnode->input(2);
|
||||
if (!utils::isa<ValueNode>(index_vnode)) {
|
||||
MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
int index = lite::CastToInt(index_vnode->cast<ValueNodePtr>()->value(), false).front();
|
||||
int input_cnode_inputs_size = get_item_input_cnode->inputs().size();
|
||||
if ((index + 1) >= input_cnode_inputs_size) {
|
||||
MS_LOG(ERROR) << "value node index is out of range.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
bool replace_succ = manager->Replace(anf_node, get_item_input_cnode->input(index + 1));
|
||||
if (!replace_succ) {
|
||||
MS_LOG(ERROR) << "replace identity failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto manager = func_graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
int status = RET_OK;
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto type = opt::GetCNodeType(node);
|
||||
if (type != schema::PrimitiveType_Identity) {
|
||||
continue;
|
||||
if (type == schema::PrimitiveType_Identity) {
|
||||
status = ReplaceIdentity(node, manager);
|
||||
} else if (type == schema::PrimitiveType_TupleGetItem) {
|
||||
status = ReplaceTupleGetItem(node, manager);
|
||||
}
|
||||
auto identity_cnode = node->cast<CNodePtr>();
|
||||
if (identity_cnode->inputs().size() != lite::kDoubleNum) {
|
||||
MS_LOG(ERROR) << "The `node input is a single tensor";
|
||||
return RET_ERROR;
|
||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "remove identity pass is failed.";
|
||||
return false;
|
||||
}
|
||||
manager->Replace(node, identity_cnode->input(1));
|
||||
}
|
||||
for (auto &node : remove_cnode_) {
|
||||
func_graph->DropNode(node);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -17,8 +17,10 @@
|
|||
#ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
namespace mindspore::opt {
|
||||
|
@ -26,11 +28,12 @@ class RemoveIdentityOpPass : public Pass {
|
|||
public:
|
||||
RemoveIdentityOpPass() : Pass("remove_identity_pass") {}
|
||||
~RemoveIdentityOpPass() override = default;
|
||||
void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; }
|
||||
int ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
||||
int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
FmkType fmk_type = lite::converter::FmkType_ONNX;
|
||||
std::set<AnfNodePtr> remove_cnode_;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_
|
||||
|
|
Loading…
Reference in New Issue