!13309 add functionalize_cond & tf_bidirection_gru_cf_fusion

From: @wangzhe128
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-18 09:02:27 +08:00 committed by Gitee
commit bc7db669cb
38 changed files with 1279 additions and 215 deletions

View File

@ -240,13 +240,15 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc
${LITE_DIR}/tools/optimizer/graph/tflite_inputs_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/update_conv2d_param_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/redundant_op_remove_pass.cc
@ -258,6 +260,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/if_pass.cc
${LITE_DIR}/tools/optimizer/graph/functionalize_control_op_pass.cc
${LITE_DIR}/tools/optimizer/graph/functionalize_while.cc
${LITE_DIR}/tools/optimizer/graph/functionalize_cond.cc
${LITE_DIR}/tools/optimizer/graph/inputs_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/primitive_adjust_pass.cc
)

View File

@ -61,3 +61,4 @@ ml_noya_tts_melgan.pb 1;16,16,80
ml_video_edit_oneclick_adaptis.pb 3
# Q_hand_0812.pb is not suitable for float16. Out of float16 range.
Q_hand_0812.pb
tacotron_encoder_stf.pb 5;1:1,62:1,62:1,62:1,62

View File

@ -50,13 +50,15 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/conv_conv_fusion.cc
../optimizer/fusion/tflite_lstm_cell_fusion.cc
../optimizer/fusion/tf_lstm_cell_fusion.cc
../optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
../optimizer/fusion/tf_bidirection_gru_fusion.cc
../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc
../optimizer/graph/clip_convert_activation_pass.cc
../optimizer/graph/group_depthwise_op_convert_pass.cc
../optimizer/graph/tflite_inputs_adjust_pass.cc
../optimizer/graph/update_conv2d_param_pass.cc
../optimizer/graph/unused_node_remove_pass.cc
../optimizer/graph/unused_cast_node_remove_pass.cc
../optimizer/graph/unused_transpose_node_remove_pass.cc
../optimizer/graph/redundant_op_remove_pass.cc
@ -68,6 +70,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/if_pass.cc
../optimizer/graph/functionalize_control_op_pass.cc
../optimizer/graph/functionalize_while.cc
../optimizer/graph/functionalize_cond.cc
../optimizer/graph/inputs_adjust_pass.cc
../optimizer/graph/primitive_adjust_pass.cc
)

View File

@ -18,6 +18,8 @@
#include <memory>
#include <string>
#include "src/common/log_adapter.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/core/ir/primitive.h"
#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
#include "tools/optimizer/fusion/conv_activation_fusion.h"
#include "tools/optimizer/fusion/conv_tuple_activation_fusion.h"
@ -31,7 +33,8 @@
#include "tools/optimizer/fusion/conv_conv_fusion.h"
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
#include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
#include "tools/optimizer/graph/primitive_adjust_pass.h"
#include "tools/optimizer/graph/mindir_adjust_pass.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
@ -42,6 +45,7 @@
#include "tools/optimizer/graph/tflite_inputs_adjust_pass.h"
#include "tools/optimizer/graph/onnx_inputs_adjust_pass.h"
#include "tools/optimizer/graph/update_conv2d_param_pass.h"
#include "tools/optimizer/graph/unused_node_remove_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
#include "tools/optimizer/graph/infershape_pass.h"
@ -81,7 +85,7 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti
fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
fusion_pm->AddPass(std::make_shared<opt::BiDirectionTfGruCellFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
}
if (config->fmk == lite::converter::FmkType_MS) {
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
@ -225,6 +229,23 @@ int AnfTransform::RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter
return RET_OK;
}
int AnfTransform::RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config) {
MS_ASSERT(old_graph != nullptr);
auto asylic_optimizer = std::make_shared<opt::GraphOptimizer>();
auto asylic_pm = std::make_shared<opt::PassManager>("asylic pass manager", false);
// fuse tf1.x bidirection_gru into GRU, must be placed here because graph is cyclic
asylic_pm->AddPass(std::make_shared<opt::TfBidirectionGruCfFusion>());
// remove remaining cyclic nodes
asylic_pm->AddPass(std::make_shared<opt::UnusedNodeRemovePass>());
asylic_optimizer->AddPassManager(asylic_pm);
if (!asylic_optimizer->Optimize(old_graph)) {
MS_LOG(ERROR) << "gru cf fusion pass failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
return RET_OK;
}
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config,
const FuncGraphPtr &new_graph) {
// quant
@ -266,7 +287,13 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
return old_graph;
}
auto status = RunAdjustPass(old_graph, config);
auto status = RunPrecedingPass(old_graph, *config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run Preceding pass failed.";
return nullptr;
}
status = RunAdjustPass(old_graph, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run Adjust pass failed.";
return nullptr;

View File

@ -50,6 +50,8 @@ class AnfTransform {
static int AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
static int RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config);
static int RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static int RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

View File

@ -1,37 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
using mindspore::ops::PrimitiveC;
namespace mindspore {
namespace lite {
constexpr auto kNameIf = "If";
class If : public PrimitiveC {
public:
If() : PrimitiveC(kNameIf) {}
~If() = default;
MS_DECLARE_PARENT(If, PrimitiveC);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_

View File

@ -1,39 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_
#include <vector>
#include <set>
#include <cmath>
#include "ops/primitive_c.h"
using mindspore::ops::PrimitiveC;
namespace mindspore {
namespace lite {
constexpr auto kNameLoopCond = "LoopCond";
class LoopCond : public PrimitiveC {
public:
LoopCond() : PrimitiveC(kNameLoopCond) {}
~LoopCond() = default;
MS_DECLARE_PARENT(LoopCond, PrimitiveC);
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_

View File

@ -17,16 +17,31 @@
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
#include "schema/inner/model_generated.h"
#include "ops/primitive_c.h"
using mindspore::ops::PrimitiveC;
namespace mindspore {
namespace lite {
#define ADD_CONVERTER_ONLY_OP(name) \
constexpr auto kName##name = #name; \
class name : public PrimitiveC { \
public: \
name() : PrimitiveC(kName##name) {} \
~name() = default; \
MS_DECLARE_PARENT(name, PrimitiveC); \
};
enum ConverterPrimitiveType {
ConverterPrimitiveType_Enter = schema::PrimitiveType_MAX + 1,
ConverterPrimitiveType_LoopCond,
ConverterPrimitiveType_NextIteration,
ConverterPrimitiveType_Exit,
};
ADD_CONVERTER_ONLY_OP(Enter);
ADD_CONVERTER_ONLY_OP(Exit);
ADD_CONVERTER_ONLY_OP(If);
ADD_CONVERTER_ONLY_OP(LoopCond);
ADD_CONVERTER_ONLY_OP(NextIteration);
ADD_CONVERTER_ONLY_OP(TensorArrayGatherV3);
ADD_CONVERTER_ONLY_OP(TensorArrayReadV3);
ADD_CONVERTER_ONLY_OP(TensorArrayScatterV3);
ADD_CONVERTER_ONLY_OP(TensorArraySizeV3);
ADD_CONVERTER_ONLY_OP(TensorArrayV3);
ADD_CONVERTER_ONLY_OP(TensorArrayWriteV3);
} // namespace lite
} // namespace mindspore

View File

@ -17,7 +17,7 @@
#include "tools/converter/parser/onnx/onnx_if_parser.h"
#include <memory>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include "tools/converter/ops/if.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {

View File

@ -19,7 +19,7 @@
#include <vector>
#include "tools/converter/parser/tf/tf_enter_parser.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/enter.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {

View File

@ -18,7 +18,7 @@
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/exit.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {

View File

@ -19,7 +19,7 @@
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/if.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {

View File

@ -18,7 +18,7 @@
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/loop_cond.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {

View File

@ -28,7 +28,7 @@ ops::PrimitiveC *TFMergeParser::Parse(const tensorflow::NodeDef &tf_op,
std::vector<std::string> *inputs, int *output_size) {
auto prim = std::make_unique<ops::Merge>();
*output_size = tf_op.input_size();
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}

View File

@ -18,7 +18,7 @@
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/next_iteration.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {

View File

@ -28,7 +28,7 @@ ops::PrimitiveC *TFSwitchParser::Parse(const tensorflow::NodeDef &tf_op,
std::vector<std::string> *inputs, int *output_size) {
auto prim = std::make_unique<ops::Switch>();
*output_size = tf_op.input_size();
*output_size = 2;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}

View File

@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_tensor_array_gather_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TFTensorArrayGatherParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF TensorArrayGatherParser";
if (inputs == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "inputs or output_size is nullptr";
return nullptr;
}
auto prim = std::make_unique<TensorArrayGatherV3>();
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr";
return nullptr;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return prim.release();
}
TFNodeRegistrar g_tfTensorArrayGatherParser("TensorArrayGatherV3", new TFTensorArrayGatherParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFTensorArrayGatherParser : public TFNodeParser {
public:
TFTensorArrayGatherParser() = default;
~TFTensorArrayGatherParser() override = default;
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_

View File

@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_tensor_array_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TFTensorArrayParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF TensorArrayParser";
if (inputs == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "inputs or output_size is nullptr";
return nullptr;
}
auto prim = std::make_unique<TensorArrayV3>();
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr";
return nullptr;
}
*output_size = 2;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return prim.release();
}
TFNodeRegistrar g_tfTensorArrayParser("TensorArrayV3", new TFTensorArrayParser());
} // namespace lite
} // namespace mindspore

View File

@ -13,27 +13,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include <set>
#include <cmath>
#include "ops/primitive_c.h"
using mindspore::ops::PrimitiveC;
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
constexpr auto kNameEnter = "Enter";
class Enter : public PrimitiveC {
class TFTensorArrayParser : public TFNodeParser {
public:
Enter() : PrimitiveC(kNameEnter) {}
~Enter() = default;
MS_DECLARE_PARENT(Enter, PrimitiveC);
TFTensorArrayParser() = default;
~TFTensorArrayParser() override = default;
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_tensor_array_read_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TFTensorArrayReadParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF TensorArrayReadParser";
if (inputs == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "inputs or output_size is nullptr";
return nullptr;
}
auto prim = std::make_unique<TensorArrayReadV3>();
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr";
return nullptr;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return prim.release();
}
TFNodeRegistrar g_tfTensorArrayReadParser("TensorArrayReadV3", new TFTensorArrayReadParser());
} // namespace lite
} // namespace mindspore

View File

@ -13,27 +13,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include <set>
#include <cmath>
#include "ops/primitive_c.h"
using mindspore::ops::PrimitiveC;
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
constexpr auto kNameExit = "Exit";
class Exit : public PrimitiveC {
class TFTensorArrayReadParser : public TFNodeParser {
public:
Exit() : PrimitiveC(kNameExit) {}
~Exit() = default;
MS_DECLARE_PARENT(Exit, PrimitiveC);
TFTensorArrayReadParser() = default;
~TFTensorArrayReadParser() override = default;
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_tensor_array_scatter_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TFTensorArrayScatterParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF TensorArrayScatterParser";
if (inputs == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "inputs or output_size is nullptr";
return nullptr;
}
auto prim = std::make_unique<TensorArrayScatterV3>();
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr";
return nullptr;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return prim.release();
}
TFNodeRegistrar g_tfTensorArrayScatterParser("TensorArrayScatterV3", new TFTensorArrayScatterParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFTensorArrayScatterParser : public TFNodeParser {
public:
TFTensorArrayScatterParser() = default;
~TFTensorArrayScatterParser() override = default;
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_tensor_array_size_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TFTensorArraySizeParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF TensorArraySizeParser";
if (inputs == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "inputs or output_size is nullptr";
return nullptr;
}
auto prim = std::make_unique<TensorArraySizeV3>();
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr";
return nullptr;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return prim.release();
}
TFNodeRegistrar g_tfTensorArraySizeParser("TensorArraySizeV3", new TFTensorArraySizeParser());
} // namespace lite
} // namespace mindspore

View File

@ -13,27 +13,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_NEXT_ITERATION_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_NEXT_ITERATION_H_
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SIZE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SIZE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include <set>
#include <cmath>
#include "ops/primitive_c.h"
using mindspore::ops::PrimitiveC;
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
constexpr auto kNameNextIteration = "NextIteration";
class NextIteration : public PrimitiveC {
class TFTensorArraySizeParser : public TFNodeParser {
public:
NextIteration() : PrimitiveC(kNameNextIteration) {}
~NextIteration() = default;
MS_DECLARE_PARENT(NextIteration, PrimitiveC);
TFTensorArraySizeParser() = default;
~TFTensorArraySizeParser() override = default;
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_NEXT_ITERATION_H_
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SIZE_PARSER_H_

View File

@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_tensor_array_write_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TFTensorArrayWriteParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(DEBUG) << "TF TensorArrayWriteParser";
if (inputs == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "inputs or output_size is nullptr";
return nullptr;
}
auto prim = std::make_unique<TensorArrayWriteV3>();
if (prim == nullptr) {
MS_LOG(ERROR) << "prim is nullptr";
return nullptr;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return prim.release();
}
TFNodeRegistrar g_tfTensorArrayWriteParser("TensorArrayWriteV3", new TFTensorArrayWriteParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_WRITE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_WRITE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFTensorArrayWriteParser : public TFNodeParser {
public:
TFTensorArrayWriteParser() = default;
~TFTensorArrayWriteParser() override = default;
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_WRITE_PARSER_H_

View File

@ -0,0 +1,214 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
#include <memory>
#include <set>
#include <functional>
#include "src/common/utils.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "tools/converter/ops/ops_def.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kNumFwVars = 4;
constexpr size_t kNumBwVars = 4;
const auto &p1 = std::placeholders::_1;
BaseRef GetPrim(const PrimitivePtr &prim) { return std::make_shared<CondVar>(std::bind(IsOpType, p1, prim)); }
BaseRef GetPrim(const std::string &prim_name) {
auto prim = std::make_shared<Primitive>(prim_name);
return GetPrim(prim);
}
} // namespace
TfBidirectionGruCfFusion::TfBidirectionGruCfFusion(const std::string &name, bool multi_graph)
: TfBidirectionGruFusion(kNumFwVars, kNumBwVars, name, multi_graph) {
/*
* vars for fw/bw input
* fw:
* 0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias
* bw:
* 0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias
*/
}
BaseRef TfBidirectionGruCfFusion::DefineGruCellPattern(const BaseRef &in_ta_read, const BaseRef &switch3_true,
const std::vector<VarPtr> &vars) const {
auto concat = VectorRef({GetPrim(prim::kPrimConcat), in_ta_read, switch3_true});
auto matmul_enter = VectorRef({GetPrim(lite::kNameEnter), vars[0]}); // gate_kernel
auto matmul = VectorRef({GetPrim(prim::kPrimMatMul), concat, matmul_enter});
auto bias_enter = VectorRef({GetPrim(lite::kNameEnter), vars[1]}); // cand_bias
auto bias = VectorRef({GetPrim(prim::kPrimBiasAdd), matmul, bias_enter});
auto sigmoid = VectorRef({GetPrim(prim::kPrimActivation), bias});
auto split = VectorRef({GetPrim(prim::kPrimSplit), sigmoid});
auto rt = VectorRef({GetPrim(prim::kPrimTupleGetItem), split, std::make_shared<Var>()});
auto zt = VectorRef({GetPrim(prim::kPrimTupleGetItem), split, std::make_shared<Var>()});
auto mul = VectorRef({GetPrim(prim::kPrimMulFusion), rt, switch3_true});
auto concat1 = VectorRef({GetPrim(prim::kPrimConcat), in_ta_read, mul});
auto matmul1_enter = VectorRef({GetPrim(lite::kNameEnter), vars[2]}); // cand_kernel
auto matmul1 = VectorRef({GetPrim(prim::kPrimMatMul), concat1, matmul1_enter});
auto bias1_enter = VectorRef({GetPrim(lite::kNameEnter), vars[3]}); // cand_bias
auto bias1 = VectorRef({GetPrim(prim::kPrimBiasAdd), matmul1, bias1_enter});
auto tanh = VectorRef({GetPrim(prim::kPrimActivation), bias1});
auto sub = VectorRef({GetPrim(prim::kPrimSubFusion), std::make_shared<CondVar>(IsParameterNode), zt});
auto mul2 = VectorRef({GetPrim(prim::kPrimMulFusion), sub, tanh});
auto mul1 = VectorRef({GetPrim(prim::kPrimMulFusion), zt, switch3_true});
auto add = VectorRef({GetPrim(prim::kPrimAddFusion), mul1, mul2});
return add;
}
const BaseRef TfBidirectionGruCfFusion::DefineBidirectionRnnPattern(const BaseRef &input,
const std::vector<VarPtr> &vars,
const VarPtr &init_state) const {
// in order to match cyclic graph, some node in cycle is represented by SeqVar
auto fw_shape1 = VectorRef({GetPrim(prim::kPrimShape), input});
auto strided_slice = VectorRef({GetPrim(prim::kPrimStridedSlice), fw_shape1, std::make_shared<SeqVar>()});
auto fw_max = VectorRef({GetPrim(prim::kPrimReduceFusion), input_length_, std::make_shared<Var>()});
auto fw_maximum = VectorRef({GetPrim(prim::kPrimMaximum), std::make_shared<CondVar>(IsParameterNode), fw_max});
auto fw_minimum = VectorRef({GetPrim(prim::kPrimMinimum), strided_slice, fw_maximum});
auto fw_less1_enter = VectorRef({GetPrim(lite::kNameEnter), fw_minimum});
// SeqVar:counter_merge1
auto fw_less1 = VectorRef({GetPrim(prim::kPrimLess), std::make_shared<SeqVar>(), fw_less1_enter});
// SeqVar:fw_merge,loop_cond
auto fw_switch = VectorRef({GetPrim(prim::kPrimSwitch), std::make_shared<SeqVar>()});
auto fw_switch_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), fw_switch, std::make_shared<Var>()}); // identity
auto fw_add = VectorRef({GetPrim(prim::kPrimAddFusion), fw_switch_true, std::make_shared<CondVar>(IsParameterNode)});
auto fw_next_iter = VectorRef({GetPrim(lite::kNameNextIteration), fw_add});
auto fw_merge_enter = VectorRef({GetPrim(lite::kNameEnter), std::make_shared<CondVar>(IsParameterNode)});
auto fw_merge = VectorRef({GetPrim(prim::kPrimMerge), fw_merge_enter, fw_next_iter});
auto fw_less_enter = VectorRef({GetPrim(lite::kNameEnter), strided_slice});
auto fw_less = VectorRef({GetPrim(prim::kPrimLess), fw_merge, fw_less_enter});
auto fw_logical_and = VectorRef({GetPrim(prim::kPrimLogicalAnd), fw_less, fw_less1});
// SeqVar:fw_logical_and
auto loop_cond = VectorRef({GetPrim(lite::kNameLoopCond), fw_logical_and});
auto fw_shape = VectorRef({GetPrim(prim::kPrimShape), input});
auto fw_unstack_strided_slice = VectorRef({GetPrim(prim::kPrimStridedSlice), fw_shape, std::make_shared<SeqVar>()});
auto fw_unstack_range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<CondVar>(IsParameterNode),
fw_unstack_strided_slice, std::make_shared<CondVar>(IsParameterNode)});
// SeqVar:switch1_true
auto counter_add =
VectorRef({GetPrim(prim::kPrimAddFusion), std::make_shared<SeqVar>(), std::make_shared<CondVar>(IsParameterNode)});
auto counter_zero = VectorRef({GetPrim(lite::kNameEnter), std::make_shared<CondVar>(IsParameterNode)});
auto counter_next_iter = VectorRef({GetPrim(lite::kNameNextIteration), counter_add});
auto counter_merge1 = VectorRef({GetPrim(prim::kPrimMerge), counter_zero, counter_next_iter});
auto counter_switch1 = VectorRef({GetPrim(prim::kPrimSwitch), counter_merge1, loop_cond});
auto switch1_true =
VectorRef({GetPrim(prim::kPrimTupleGetItem), counter_switch1, std::make_shared<Var>()}); // identity1
auto in_ta = VectorRef({GetPrim(lite::kNameTensorArrayV3), strided_slice});
auto in_ta_handle = VectorRef({GetPrim(prim::kPrimTupleGetItem), in_ta, std::make_shared<Var>()});
auto in_ta_flow = VectorRef({GetPrim(prim::kPrimTupleGetItem), in_ta, std::make_shared<Var>()});
auto fw_unstack_ta_scatter =
VectorRef({GetPrim(lite::kNameTensorArrayScatterV3), in_ta_handle, fw_unstack_range, input, in_ta_flow});
auto in_ta_enter1 = VectorRef({GetPrim(lite::kNameEnter), fw_unstack_ta_scatter});
auto in_ta_enter = VectorRef({GetPrim(lite::kNameEnter), in_ta_handle});
auto in_ta_read = VectorRef({GetPrim(lite::kNameTensorArrayReadV3), in_ta_enter, switch1_true, in_ta_enter1});
auto greater_equal_enter = VectorRef({GetPrim(lite::kNameEnter), input_length_});
auto greater_equal = VectorRef({GetPrim(prim::kPrimGreaterEqual), switch1_true, greater_equal_enter});
auto select1 = VectorRef({GetPrim(prim::kPrimSelect), greater_equal, std::make_shared<SeqVar>()}); // select h
auto next_iteration3 = VectorRef({GetPrim(lite::kNameNextIteration), select1});
auto enter3 = VectorRef({GetPrim(lite::kNameEnter), init_state});
auto merge3 = VectorRef({GetPrim(prim::kPrimMerge), enter3, next_iteration3});
auto switch3 = VectorRef({GetPrim(prim::kPrimSwitch), merge3, loop_cond});
auto switch3_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), switch3, std::make_shared<Var>()}); // identity3
auto rnn_cell_out = DefineGruCellPattern(in_ta_read, switch3_true, vars);
auto out_ta = VectorRef({GetPrim(lite::kNameTensorArrayV3), strided_slice});
auto out_ta_handle = VectorRef({GetPrim(prim::kPrimTupleGetItem), out_ta, std::make_shared<Var>()});
auto out_ta_flow = VectorRef({GetPrim(prim::kPrimTupleGetItem), out_ta, std::make_shared<Var>()});
auto out_ta_enter = VectorRef({GetPrim(lite::kNameEnter), out_ta_handle});
auto switch2_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), std::make_shared<SeqVar>()}); // cycle
auto concat1 = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared<SeqVar>()});
auto zeros1 = VectorRef({GetPrim(prim::kPrimFill), std::make_shared<CondVar>(IsParameterNode), concat1});
auto select_enter = VectorRef({GetPrim(lite::kNameEnter), zeros1});
auto select = VectorRef({GetPrim(prim::kPrimSelect), greater_equal, select_enter, rnn_cell_out}); // select x
auto ta_write = VectorRef({GetPrim(lite::kNameTensorArrayWriteV3), out_ta_enter, switch1_true, select, switch2_true});
auto enter2 = VectorRef({GetPrim(lite::kNameEnter), out_ta_flow});
auto next_iter2 = VectorRef({GetPrim(lite::kNameNextIteration), ta_write});
auto merge2 = VectorRef({GetPrim(prim::kPrimMerge), enter2, next_iter2});
auto switch2 = VectorRef({GetPrim(prim::kPrimSwitch), merge2, loop_cond});
auto switch2_false = VectorRef({GetPrim(prim::kPrimTupleGetItem), switch2, std::make_shared<Var>()});
auto exit2 = VectorRef({GetPrim(lite::kNameExit), switch2_false});
auto ta_size = VectorRef({GetPrim(lite::kNameTensorArraySizeV3), out_ta_handle, exit2});
auto range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<Var>(), ta_size, std::make_shared<Var>()});
auto tensor_array_gather = VectorRef({GetPrim(lite::kNameTensorArrayGatherV3), out_ta_handle, range, exit2});
auto range1 = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<SeqVar>()});
auto concat2 = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared<CondVar>(IsParameterNode), range1});
auto fw_out_trans = VectorRef({GetPrim(prim::kPrimTranspose), tensor_array_gather, concat2});
return fw_out_trans;
}
const BaseRef TfBidirectionGruCfFusion::DefinePattern() const {
const auto fw_out_trans = DefineBidirectionRnnPattern(transpose_input_, fw_vars_, fw_init_state_);
auto bw_reverse_in = VectorRef({GetPrim(prim::kPrimReverseSequence), input_, input_length_});
auto bw_range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<SeqVar>()});
auto bw_concat = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared<CondVar>(IsParameterNode), bw_range});
auto bw_transpose = VectorRef({GetPrim(prim::kPrimTranspose), bw_reverse_in, bw_concat});
auto bw_out_trans = DefineBidirectionRnnPattern(bw_transpose, bw_vars_, bw_init_state_);
auto bw_reverse_out = VectorRef({GetPrim(prim::kPrimReverseSequence), bw_out_trans, input_length_});
auto concat = VectorRef({GetPrim(prim::kPrimConcat), fw_out_trans, bw_reverse_out});
return concat;
}
const AnfNodePtr TfBidirectionGruCfFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
const EquivPtr &equiv) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(concat_node != nullptr);
MS_LOG(DEBUG) << "bidirection tf gru fusion pass";
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(concat_node) != lite::RET_OK) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]);
MS_ASSERT(transpose_input != nullptr);
const std::string gru_name = "gru_" + concat_node->fullname_with_scope();
auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, gru_name, 0);
if (gru_node == nullptr) {
return nullptr;
}
if (TfliteLstmCellFusion::SetAbstractTuple(gru_node, 2) != RET_OK) {
return nullptr;
}
auto get_item_node = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, gru_node, 0);
if (get_item_node == nullptr) {
return nullptr;
}
auto output_node = GetPostProcessNode(func_graph, get_item_node, gru_node->fullname_with_scope());
MS_LOG(INFO) << "gru node:" << gru_node->fullname_with_scope() << " fusion success";
return output_node;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_CF_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_CF_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
#include "schema/inner/model_generated.h"
#include "src/param_value_lite.h"
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
#include "include/errorcode.h"
namespace mindspore {
namespace opt {
// fuse tf 1.x bidirection_gru into MSLITE GRU
class TfBidirectionGruCfFusion : public TfBidirectionGruFusion {
public:
explicit TfBidirectionGruCfFusion(const std::string &name = "tf_bidirection_gru_cf_fusion", bool multi_graph = true);
~TfBidirectionGruCfFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
BaseRef DefineGruCellPattern(const BaseRef &in_ta_read, const BaseRef &switch3_true,
const std::vector<VarPtr> &vars) const;
const BaseRef DefineBidirectionRnnPattern(const BaseRef &input, const std::vector<VarPtr> &vars,
const VarPtr &init_state) const;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_CF_FUSION_H_

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
#include <memory>
#include <functional>
#include "ops/concat.h"
@ -24,32 +24,21 @@
#include "ops/transpose.h"
#include "src/common/utils.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kWhileUniqInputsLength = 6;
constexpr size_t kCondNodesNum = 12;
constexpr size_t kCondCNodesNum = 4;
constexpr size_t kBodyNodesNum = 69;
constexpr size_t kBodyCNodesNum = 25;
const auto &p1 = std::placeholders::_1;
bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
if (utils::isa<AnfNodePtr>(n)) {
auto anf_node = utils::cast<AnfNodePtr>(n);
return CheckPrimitiveType(anf_node, prim);
}
return false;
}
} // namespace
BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, bool multigraph)
: PatternProcessPass(name, multigraph) {
TfBidirectionGruFusion::TfBidirectionGruFusion(int num_fw_vars, int num_bw_vars, const std::string &name,
bool multi_graph)
: PatternProcessPass(name, multi_graph) {
/*
* vars for while input
* fw_while_inputs:
@ -57,8 +46,10 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name,
* bw_while_inputs:
* 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
*/
for (size_t i = 0; i < kWhileUniqInputsLength; ++i) {
for (int i = 0; i < num_fw_vars; ++i) {
fw_vars_.emplace_back(std::make_shared<Var>());
}
for (int i = 0; i < num_bw_vars; ++i) {
bw_vars_.emplace_back(std::make_shared<Var>());
}
input_ = std::make_shared<Var>();
@ -68,7 +59,7 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name,
bw_init_state_ = std::make_shared<Var>();
}
const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
const BaseRef TfBidirectionGruFusion::DefinePattern() const {
// forward
auto fw_reduce = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)),
input_length_, std::make_shared<CondVar>(IsParameterNode)});
@ -134,7 +125,7 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
return concat;
}
AnfNodePtr BiDirectionTfGruCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
AnfNodePtr TfBidirectionGruFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
auto is_parameter1 = std::make_shared<CondVar>(IsParameterNode);
auto is_parameter2 = std::make_shared<CondVar>(IsParameterNode);
auto is_parameter3 = std::make_shared<CondVar>(IsParameterNode);
@ -152,7 +143,7 @@ AnfNodePtr BiDirectionTfGruCellFusion::GetCondGraphPattern(const PrimitiveVarMap
return pattern;
}
AnfNodePtr BiDirectionTfGruCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
AnfNodePtr TfBidirectionGruFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
std::vector<CondVarPtr> placeholders;
for (int i = 0; i < 13; ++i) {
placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode));
@ -206,7 +197,7 @@ AnfNodePtr BiDirectionTfGruCellFusion::GetBodyGraphPattern(const PrimitiveVarMap
return pattern;
}
ParamValueLitePtr BiDirectionTfGruCellFusion::GetDefaultParamValue(const AnfNodePtr &parameter_anf) const {
ParamValueLitePtr TfBidirectionGruFusion::GetDefaultParamValue(const AnfNodePtr &parameter_anf) const {
MS_ASSERT(parameter_anf != nullptr);
if (!utils::isa<ParameterPtr>(parameter_anf)) {
MS_LOG(DEBUG) << "parameter_anf is not ParameterPtr";
@ -221,9 +212,9 @@ ParamValueLitePtr BiDirectionTfGruCellFusion::GetDefaultParamValue(const AnfNode
return param_value;
}
STATUS BiDirectionTfGruCellFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf,
const AnfNodePtr &bw_cand_kernel_anf, int *input_size,
int *hidden_size) const {
STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf,
const AnfNodePtr &bw_cand_kernel_anf, int *input_size,
int *hidden_size) const {
MS_ASSERT(fw_cand_kernel != nullptr);
MS_ASSERT(bw_cand_kernel != nullptr);
MS_ASSERT(input_size != nullptr);
@ -256,9 +247,9 @@ STATUS BiDirectionTfGruCellFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_ca
return RET_OK;
}
ParameterPtr BiDirectionTfGruCellFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name,
const std::vector<int> &shape, const TypeId type,
void **tensor_data) const {
ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name,
const std::vector<int> &shape, const TypeId type,
void **tensor_data) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(tensor_data != nullptr);
auto parameter = func_graph->add_parameter();
@ -300,9 +291,8 @@ ParameterPtr BiDirectionTfGruCellFusion::AddDefaultParameter(const FuncGraphPtr
return parameter;
}
void BiDirectionTfGruCellFusion::CopyFlattenMatData(const float *mat, const int R, const int C, const int r0,
const int r1, const int c0, const int c1, float *data,
bool t) const {
void TfBidirectionGruFusion::CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, const int r1,
const int c0, const int c1, float *data, bool t) const {
MS_ASSERT(mat != nullptr);
MS_ASSERT(data != nullptr);
MS_ASSERT(0 <= r0 && r0 < r1 && r1 <= R);
@ -320,9 +310,9 @@ void BiDirectionTfGruCellFusion::CopyFlattenMatData(const float *mat, const int
}
}
STATUS BiDirectionTfGruCellFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight,
const int input_size, const int hidden_size,
float *gate_tensor_data, float *recu_tensor_data) const {
STATUS TfBidirectionGruFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight,
const int input_size, const int hidden_size, float *gate_tensor_data,
float *recu_tensor_data) const {
MS_ASSERT(gate_weight != nullptr);
MS_ASSERT(cand_weight != nullptr);
MS_ASSERT(gate_tensor_data != nullptr);
@ -375,8 +365,8 @@ STATUS BiDirectionTfGruCellFusion::ConvertWeightData(const AnfNodePtr &gate_weig
return RET_OK;
}
STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias,
const int hidden_size, float *tensor_data) const {
STATUS TfBidirectionGruFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias,
const int hidden_size, float *tensor_data) const {
MS_ASSERT(bias != nullptr);
MS_ASSERT(tensor_data != nullptr);
std::vector<int> gate_shape{hidden_size * 2};
@ -407,10 +397,9 @@ STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias,
return RET_OK;
}
CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph,
const AnfNodePtr &fw_init_state,
const AnfNodePtr &bw_init_state,
const std::string base_name) const {
CNodePtr TfBidirectionGruFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state,
const AnfNodePtr &bw_init_state,
const std::string base_name) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(fw_init_state != nullptr);
MS_ASSERT(bw_init_state != nullptr);
@ -424,35 +413,32 @@ CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &f
return new_node;
}
CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
const EquivPtr &equiv, const EquivPtr &fw_body_equiv,
const EquivPtr &bw_body_equiv,
const std::string &base_name) const {
CNodePtr TfBidirectionGruFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
const EquivPtr &equiv, const std::string &base_name,
int var_offset) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(input != nullptr);
MS_ASSERT(equiv != nullptr);
MS_ASSERT(fw_body_equiv != nullptr);
MS_ASSERT(bw_body_equiv != nullptr);
auto gru_prim = std::make_shared<ops::GRU>();
gru_prim->set_bidirectional(true);
auto value_node = NewValueNode(gru_prim);
auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[2]]);
auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset]]);
MS_ASSERT(fw_gate_kernel != nullptr);
auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[3]]);
auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 1]]);
MS_ASSERT(fw_gate_bias != nullptr);
auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[4]]);
auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 2]]);
MS_ASSERT(fw_cand_kernel != nullptr);
auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[5]]);
auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 3]]);
MS_ASSERT(fw_cand_bias != nullptr);
auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[2]]);
auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset]]);
MS_ASSERT(bw_gate_kernel != nullptr);
auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[3]]);
auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 1]]);
MS_ASSERT(bw_gate_bias != nullptr);
auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[4]]);
auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 2]]);
MS_ASSERT(bw_cand_kernel != nullptr);
auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[5]]);
auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 3]]);
MS_ASSERT(bw_cand_bias != nullptr);
auto fw_init_state = utils::cast<AnfNodePtr>((*equiv)[fw_init_state_]);
@ -522,8 +508,8 @@ CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr
return new_node;
}
CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
const std::string base_name) const {
CNodePtr TfBidirectionGruFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
const std::string base_name) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(gru_output != nullptr);
auto split_prim = std::make_shared<ops::Split>();
@ -571,8 +557,8 @@ CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func
return transpose_new_node;
}
const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
const EquivPtr &equiv) const {
const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
const EquivPtr &equiv) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(concat_node != nullptr);
MS_LOG(DEBUG) << "bidirection tf gru fusion pass";
@ -628,7 +614,7 @@ const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_gr
}
const std::string gru_name = "gru_" + concat_node->fullname_with_scope();
auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, fw_body_equiv, bw_body_equiv, gru_name);
auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, gru_name, 2);
if (gru_node == nullptr) {
return nullptr;
}

View File

@ -13,12 +13,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "schema/inner/model_generated.h"
#include "src/param_value_lite.h"
#include "backend/optimizer/common/optimizer.h"
@ -27,22 +28,26 @@
namespace mindspore {
namespace opt {
class BiDirectionTfGruCellFusion : public PatternProcessPass {
constexpr size_t kWhileUniqInputsLength = 6;
// fuse tf 2.x bidirection_gru into MSLITE GRU
class TfBidirectionGruFusion : public PatternProcessPass {
public:
explicit BiDirectionTfGruCellFusion(const std::string &name = "bidirection_tf_gru_cell_fusion",
bool multigraph = true);
~BiDirectionTfGruCellFusion() override = default;
explicit TfBidirectionGruFusion(int num_fw_vars = kWhileUniqInputsLength, int num_bw_vars = kWhileUniqInputsLength,
const std::string &name = "tf_bidirection_gru_fusion", bool multi_graph = true);
~TfBidirectionGruFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
protected:
virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;
CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv,
const std::string &base_name, int var_offset) const;
CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
const std::string base_name) const;
private:
AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;
CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv,
const EquivPtr &fw_body_equiv, const EquivPtr &bw_body_equiv,
const std::string &base_name) const;
ParamValueLitePtr GetDefaultParamValue(const AnfNodePtr &parameter_anf) const;
lite::STATUS GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, const AnfNodePtr &bw_cand_kernel_anf,
int *input_size, int *hidden_size) const;
@ -56,10 +61,8 @@ class BiDirectionTfGruCellFusion : public PatternProcessPass {
const int c1, float *data, bool t = false) const;
CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state,
const AnfNodePtr &bw_init_state, const std::string base_name) const;
CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
const std::string base_name) const;
private:
protected:
std::vector<VarPtr> fw_vars_;
std::vector<VarPtr> bw_vars_;
VarPtr input_;
@ -68,7 +71,16 @@ class BiDirectionTfGruCellFusion : public PatternProcessPass {
VarPtr fw_init_state_;
VarPtr bw_init_state_;
};
inline bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
inline bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
if (utils::isa<AnfNodePtr>(n)) {
auto anf_node = utils::cast<AnfNodePtr>(n);
return CheckPrimitiveType(anf_node, prim);
}
return false;
}
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_

View File

@ -0,0 +1,249 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/functionalize_cond.h"
#include <algorithm>
#include <memory>
#include <deque>
#include <unordered_set>
#include "include/errorcode.h"
#include "ops/make_tuple.h"
#include "tools/converter/ops/ops_def.h"
#include "ops/return.h"
namespace mindspore::opt {
STATUS FunctionalizeCond::GetSwitchBranchType(const CNodePtr &switch_cnode, BranchType *branch_type) {
MS_ASSERT(switch_cnode != nullptr);
MS_ASSERT(branch_type != nullptr);
auto manager = fg_->manager();
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr";
return RET_ERROR;
}
auto node_users = manager->node_users()[switch_cnode];
if (node_users.size() != 1) { // only one output of switch is referenced in cond
MS_LOG(ERROR) << "switch's node users is not correct";
return RET_ERROR;
}
auto node_user = node_users.front();
auto tuple_get_item = node_user.first;
if (!utils::isa<CNodePtr>(tuple_get_item) || !CheckPrimitiveType(tuple_get_item, prim::kPrimTupleGetItem)) {
MS_LOG(ERROR) << "switch's node user is not TupleGetItem";
return RET_ERROR;
}
auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item);
auto idx = GetTupleGetItemOutIndex(tuple_get_item_cnode);
if (idx == 0) {
*branch_type = kElseBranch;
} else if (idx == 1) {
*branch_type = kThenBranch;
} else {
MS_LOG(ERROR) << "wrong tuple_get_item index";
return RET_ERROR;
}
return RET_OK;
}
STATUS FunctionalizeCond::BranchSubGraphAddNodes(const FuncGraphPtr &graph, const AnfNodePtr &root_node,
BranchType branch_type) {
std::deque<AnfNodePtr> q;
std::unordered_set<AnfNodePtr> vis;
q.push_back(root_node);
while (!q.empty()) {
auto node = q.front();
q.pop_front();
vis.insert(node);
if (FunctionalizeControlOpPass::IsSwitch(node)) {
auto cnode = utils::cast<CNodePtr>(node);
BranchType this_type;
if (GetSwitchBranchType(cnode, &this_type) != RET_OK || this_type != branch_type) {
MS_LOG(ERROR) << "switch node in branch " << branch_type << " is not correct";
return RET_ERROR;
}
continue;
}
if (utils::isa<ParameterPtr>(node)) {
graph->add_parameter(node->cast<ParameterPtr>());
} else {
graph->AddNode(node);
}
node->set_func_graph(graph);
if (utils::isa<CNodePtr>(node)) {
auto cnode = utils::cast<CNodePtr>(node);
for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto inputi = cnode->input(i);
if (vis.find(inputi) == vis.end()) {
q.push_back(cnode->input(i));
}
}
}
}
return RET_OK;
}
int FunctionalizeCond::PosInInputNodes(const CNodePtr &node) {
auto index = std::find(input_nodes_.begin(), input_nodes_.end(), node);
if (index == input_nodes_.end()) {
input_nodes_.push_back(node);
return input_nodes_.size() - 1;
}
return index - input_nodes_.begin();
}
STATUS FunctionalizeCond::IdentifySubgraphInput(const FuncGraphPtr &graph, std::string graph_name) {
std::vector<AnfNodePtr> nodes_need_drop{};
for (auto &cnode : graph->GetOrderedCnodes()) {
for (auto &input_node : cnode->inputs()) {
if (FunctionalizeControlOpPass::IsSwitch(input_node)) {
auto switch_node = input_node->cast<CNodePtr>();
auto switch_input = utils::cast<CNodePtr>(switch_node->input(1));
auto pos = PosInInputNodes(switch_input);
nodes_need_drop.push_back(cnode);
pred_nodes_.push_back(switch_node->input(2));
// set parameter
auto parameter = graph->add_parameter();
parameter->set_abstract(cnode->abstract());
// hardcode for subgraph input name
parameter->set_name(graph_name + "_input_" + std::to_string(pos) + "_parameter");
// replace switch
auto manager = fg_->manager();
auto node_users = manager->node_users()[cnode];
for (auto &node_user : node_users) {
if (graph->nodes().contains(node_user.first)) {
manager->SetEdge(node_user.first, node_user.second, parameter);
}
}
}
}
}
return RET_OK;
}
FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type) {
auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, mindspore::lite::converter::FmkType_TF);
if (graph == nullptr) {
MS_LOG(ERROR) << "new graph Partial Node return nullptr";
return nullptr;
}
graph->set_manager(fg_->manager());
auto status = BranchSubGraphAddNodes(graph, node, branch_type);
if (status != RET_OK) {
return nullptr;
}
if (!CheckPrimitiveType(node, prim::kPrimSwitch)) { // graph is not empty
auto return_prim_ptr = std::make_shared<ops::Return>();
if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return nullptr;
}
auto value_node = NewValueNode(return_prim_ptr);
std::vector<AnfNodePtr> op_inputs{value_node, node}; // If subgraph only has one output tensor
auto return_cnode = graph->NewCNode(op_inputs);
return_cnode->set_fullname_with_scope(name + "-return");
return_cnode->set_func_graph(graph);
graph->set_return(return_cnode);
graph->output()->cast<CNodePtr>()->set_fullname_with_scope(name + "_output_0_cnode");
}
return graph;
}
CNodePtr FunctionalizeCond::CreateNewIf(const FuncGraphPtr &else_branch, const FuncGraphPtr &then_branch) {
MS_ASSERT(else_branch != nullptr);
MS_ASSERT(then_branch != nullptr);
auto if_primc = std::make_shared<mindspore::lite::If>();
if (if_primc == nullptr) {
MS_LOG(ERROR) << "new if_primitive failed";
return nullptr;
}
auto if_value_node = NewValueNode(if_primc);
if (if_value_node == nullptr) {
return nullptr;
}
auto then_value_node = NewValueNode(then_branch);
auto else_value_node = NewValueNode(else_branch);
std::vector<AnfNodePtr> if_op_inputs = {if_value_node, then_value_node, else_value_node, pred_node_};
std::copy(input_nodes_.begin(), input_nodes_.end(), std::back_inserter(if_op_inputs));
return fg_->NewCNode(if_op_inputs);
}
STATUS FunctionalizeCond::VerifyPredictNode() {
if (pred_nodes_.empty()) {
return RET_ERROR;
}
for (size_t i = 1; i < pred_nodes_.size(); ++i) {
if (pred_nodes_[i] != pred_nodes_[0]) {
return RET_ERROR;
}
}
if (!utils::isa<CNodePtr>(pred_nodes_[0])) {
return RET_ERROR;
}
pred_node_ = utils::cast<CNodePtr>(pred_nodes_[0]);
return RET_OK;
}
STATUS FunctionalizeCond::Process() {
if (fg_ == nullptr || merge_node_ == nullptr || merge_node_->inputs().size() != 3) {
MS_LOG(ERROR) << "fg or merge is not correct";
return RET_ERROR;
}
auto else_branch_name = merge_node_->fullname_with_scope() + "-partial-if-else";
auto then_branch_name = merge_node_->fullname_with_scope() + "-partial-then-else";
auto else_branch = CreateBranchGraph(merge_node_->input(1), else_branch_name, kElseBranch);
if (else_branch == nullptr) {
MS_LOG(ERROR) << "create else branch failed";
return RET_ERROR;
}
auto then_branch = CreateBranchGraph(merge_node_->input(2), then_branch_name, kThenBranch);
if (then_branch == nullptr) {
MS_LOG(ERROR) << "create then branch failed";
return RET_ERROR;
}
auto status = IdentifySubgraphInput(else_branch, else_branch_name);
if (status != RET_OK) {
return status;
}
status = IdentifySubgraphInput(then_branch, then_branch_name);
if (status != RET_OK) {
return status;
}
status = VerifyPredictNode();
if (status != RET_OK) {
return status;
}
auto if_node = CreateNewIf(else_branch, then_branch);
if (if_node == nullptr) {
MS_LOG(ERROR) << "create if node error";
return RET_ERROR;
}
if_node->set_abstract(merge_node_->abstract()->Clone());
auto manager = fg_->manager();
auto node_users = manager->node_users()[merge_node_];
for (auto &node_user : node_users) {
manager->SetEdge(node_user.first, node_user.second, if_node);
}
return RET_OK;
}
} // namespace mindspore::opt

View File

@ -0,0 +1,58 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_FUNCTIONALIZE_COND_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_FUNCTIONALIZE_COND_H_
#include <string>
#include <set>
#include <vector>
#include <map>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/graph/functionalize_control_op_pass.h"
using mindspore::lite::converter::FmkType;
namespace mindspore::opt {
typedef enum { kThenBranch = 0, kElseBranch = 1 } BranchType;
// Functionalize all the switch-merge nodes of a loop-free graph into single switch node.
// Precondition: While loops must have been functionalized.
class FunctionalizeCond {
public:
FunctionalizeCond(FuncGraphPtr fg, CNodePtr merge_node) : fg_(fg), merge_node_(merge_node) {}
STATUS Process();
private:
STATUS GetSwitchBranchType(const CNodePtr &switch_cnode, BranchType *branch_type);
STATUS BranchSubGraphAddNodes(const FuncGraphPtr &graph, const AnfNodePtr &root_node, BranchType branch_type);
FuncGraphPtr CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type);
int PosInInputNodes(const CNodePtr &node);
STATUS IdentifySubgraphInput(const FuncGraphPtr &graph, std::string graph_name);
CNodePtr CreateNewIf(const FuncGraphPtr &else_branch, const FuncGraphPtr &then_branch);
STATUS VerifyPredictNode();
FuncGraphPtr fg_ = nullptr;
CNodePtr merge_node_ = nullptr;
CNodePtr pred_node_ = nullptr;
std::vector<CNodePtr> input_nodes_{};
std::vector<AnfNodePtr> pred_nodes_{};
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_FUNCTIONALIZE_COND_H_

View File

@ -18,6 +18,7 @@
#include <algorithm>
#include <deque>
#include "tools/optimizer/graph/functionalize_while.h"
#include "tools/optimizer/graph/functionalize_cond.h"
#include "include/errorcode.h"
namespace mindspore::opt {
@ -100,6 +101,25 @@ STATUS FunctionalizeControlOpPass::BuildWhileSubgraph(const FuncGraphPtr &func_g
return ret;
}
STATUS FunctionalizeControlOpPass::BuildIfSubgraph(const FuncGraphPtr &func_graph) {
int ret = RET_OK;
auto nodes = func_graph->nodes();
for (auto &node : nodes) {
if (!IsMerge(node)) {
continue;
}
auto cnode = utils::cast<CNodePtr>(node);
FunctionalizeCond fc(func_graph, cnode);
ret = fc.Process();
if (ret != RET_OK) {
MS_LOG(ERROR) << "run functionalize cond failed, ret: " << ret;
return ret;
}
}
return ret;
}
bool FunctionalizeControlOpPass::Run(const FuncGraphPtr &func_graph) {
// use name to find the frame
InitNodeClusters(func_graph);
@ -107,6 +127,10 @@ bool FunctionalizeControlOpPass::Run(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "build while subgraph failed.";
return false;
}
if (BuildIfSubgraph(func_graph) != RET_OK) {
MS_LOG(ERROR) << "build while subgraph failed.";
return false;
}
return true;
}

View File

@ -23,10 +23,7 @@
#include <memory>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/ops/enter.h"
#include "tools/converter/ops/exit.h"
#include "tools/converter/ops/loop_cond.h"
#include "tools/converter/ops/next_iteration.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/optimizer/common/gllo_utils.h"
using mindspore::lite::converter::FmkType;
@ -70,6 +67,7 @@ class FunctionalizeControlOpPass : public Pass {
protected:
STATUS BuildWhileSubgraph(const FuncGraphPtr &func_graph);
STATUS BuildIfSubgraph(const FuncGraphPtr &func_graph);
std::vector<std::pair<std::string, std::vector<AnfNodePtr>>> node_clusters_{};
std::vector<CNodePtr> loop_cond_nodes_{};
};

View File

@ -0,0 +1,67 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/unused_node_remove_pass.h"
#include <deque>
#include <unordered_set>
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt {
STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto return_node = func_graph->get_return();
if (return_node == nullptr) {
return RET_OK;
}
std::unordered_set<AnfNodePtr> vis;
std::deque<AnfNodePtr> q;
q.push_back(return_node);
while (!q.empty()) {
auto node = q.front();
vis.insert(node);
q.pop_front();
if (utils::isa<CNodePtr>(node)) {
auto cnode = utils::cast<CNodePtr>(node);
for (auto &input : cnode->inputs()) {
if (vis.find(input) == vis.end()) {
q.push_back(input);
}
}
}
if (utils::isa<FuncGraphPtr>(node)) {
auto sub_graph = utils::cast<FuncGraphPtr>(node);
auto status = ProcessGraph(sub_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "process sub graph failed";
return RET_ERROR;
}
}
}
auto nodes = func_graph->nodes();
for (auto &node : nodes) {
if (vis.find(node) == vis.end()) {
func_graph->DropNode(node);
}
}
return RET_OK;
}
bool UnusedNodeRemovePass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto status = ProcessGraph(func_graph);
return status == RET_OK;
}
} // namespace mindspore::opt

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H_
#include <string>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
#include "mindspore/lite/include/errorcode.h"
using mindspore::lite::STATUS;
namespace mindspore::opt {
class UnusedNodeRemovePass : public Pass {
public:
UnusedNodeRemovePass() : Pass("remove_unused_node_pass") {}
~UnusedNodeRemovePass() override = default;
bool Run(const FuncGraphPtr &graph) override;
private:
STATUS ProcessGraph(const FuncGraphPtr &func_graph);
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H_