forked from mindspore-Ecosystem/mindspore
!13309 add functionalize_cond & tf_bidirection_gru_cf_fusion
From: @wangzhe128 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
bc7db669cb
|
@ -240,13 +240,15 @@ if(ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/tflite_inputs_adjust_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/update_conv2d_param_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/unused_node_remove_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/redundant_op_remove_pass.cc
|
||||
|
@ -258,6 +260,7 @@ if(ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/graph/if_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/functionalize_control_op_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/functionalize_while.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/functionalize_cond.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/inputs_adjust_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/primitive_adjust_pass.cc
|
||||
)
|
||||
|
|
|
@ -61,3 +61,4 @@ ml_noya_tts_melgan.pb 1;16,16,80
|
|||
ml_video_edit_oneclick_adaptis.pb 3
|
||||
# Q_hand_0812.pb is not suitable for float16. Out of float16 range.
|
||||
Q_hand_0812.pb
|
||||
tacotron_encoder_stf.pb 5;1:1,62:1,62:1,62:1,62
|
||||
|
|
|
@ -50,13 +50,15 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/fusion/conv_conv_fusion.cc
|
||||
../optimizer/fusion/tflite_lstm_cell_fusion.cc
|
||||
../optimizer/fusion/tf_lstm_cell_fusion.cc
|
||||
../optimizer/fusion/bidirection_tf_gru_cell_fusion.cc
|
||||
../optimizer/fusion/tf_bidirection_gru_fusion.cc
|
||||
../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
|
||||
../optimizer/graph/weight_format_transform_pass.cc
|
||||
../optimizer/graph/weight_format_hardcode_pass.cc
|
||||
../optimizer/graph/clip_convert_activation_pass.cc
|
||||
../optimizer/graph/group_depthwise_op_convert_pass.cc
|
||||
../optimizer/graph/tflite_inputs_adjust_pass.cc
|
||||
../optimizer/graph/update_conv2d_param_pass.cc
|
||||
../optimizer/graph/unused_node_remove_pass.cc
|
||||
../optimizer/graph/unused_cast_node_remove_pass.cc
|
||||
../optimizer/graph/unused_transpose_node_remove_pass.cc
|
||||
../optimizer/graph/redundant_op_remove_pass.cc
|
||||
|
@ -68,6 +70,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/graph/if_pass.cc
|
||||
../optimizer/graph/functionalize_control_op_pass.cc
|
||||
../optimizer/graph/functionalize_while.cc
|
||||
../optimizer/graph/functionalize_cond.cc
|
||||
../optimizer/graph/inputs_adjust_pass.cc
|
||||
../optimizer/graph/primitive_adjust_pass.cc
|
||||
)
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "mindspore/core/ir/primitive.h"
|
||||
#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
|
||||
#include "tools/optimizer/fusion/conv_activation_fusion.h"
|
||||
#include "tools/optimizer/fusion/conv_tuple_activation_fusion.h"
|
||||
|
@ -31,7 +33,8 @@
|
|||
#include "tools/optimizer/fusion/conv_conv_fusion.h"
|
||||
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
|
||||
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
|
||||
#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
|
||||
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
|
||||
#include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
|
||||
#include "tools/optimizer/graph/primitive_adjust_pass.h"
|
||||
#include "tools/optimizer/graph/mindir_adjust_pass.h"
|
||||
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
|
||||
|
@ -42,6 +45,7 @@
|
|||
#include "tools/optimizer/graph/tflite_inputs_adjust_pass.h"
|
||||
#include "tools/optimizer/graph/onnx_inputs_adjust_pass.h"
|
||||
#include "tools/optimizer/graph/update_conv2d_param_pass.h"
|
||||
#include "tools/optimizer/graph/unused_node_remove_pass.h"
|
||||
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
|
||||
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
|
||||
#include "tools/optimizer/graph/infershape_pass.h"
|
||||
|
@ -81,7 +85,7 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti
|
|||
fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::BiDirectionTfGruCellFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
|
||||
}
|
||||
if (config->fmk == lite::converter::FmkType_MS) {
|
||||
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
|
||||
|
@ -225,6 +229,23 @@ int AnfTransform::RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int AnfTransform::RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config) {
|
||||
MS_ASSERT(old_graph != nullptr);
|
||||
auto asylic_optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto asylic_pm = std::make_shared<opt::PassManager>("asylic pass manager", false);
|
||||
// fuse tf1.x bidirection_gru into GRU, must be placed here because graph is cyclic
|
||||
asylic_pm->AddPass(std::make_shared<opt::TfBidirectionGruCfFusion>());
|
||||
// remove remaining cyclic nodes
|
||||
asylic_pm->AddPass(std::make_shared<opt::UnusedNodeRemovePass>());
|
||||
asylic_optimizer->AddPassManager(asylic_pm);
|
||||
if (!asylic_optimizer->Optimize(old_graph)) {
|
||||
MS_LOG(ERROR) << "gru cf fusion pass failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config,
|
||||
const FuncGraphPtr &new_graph) {
|
||||
// quant
|
||||
|
@ -266,7 +287,13 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
|
|||
return old_graph;
|
||||
}
|
||||
|
||||
auto status = RunAdjustPass(old_graph, config);
|
||||
auto status = RunPrecedingPass(old_graph, *config);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run Preceding pass failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = RunAdjustPass(old_graph, config);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run Adjust pass failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -50,6 +50,8 @@ class AnfTransform {
|
|||
|
||||
static int AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
|
||||
|
||||
static int RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config);
|
||||
|
||||
static int RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
static int RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
|
|
@ -1,37 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
|
||||
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
|
||||
using mindspore::ops::PrimitiveC;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
constexpr auto kNameIf = "If";
|
||||
class If : public PrimitiveC {
|
||||
public:
|
||||
If() : PrimitiveC(kNameIf) {}
|
||||
~If() = default;
|
||||
MS_DECLARE_PARENT(If, PrimitiveC);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
|
|
@ -1,39 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_
|
||||
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ops/primitive_c.h"
|
||||
|
||||
using mindspore::ops::PrimitiveC;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
constexpr auto kNameLoopCond = "LoopCond";
|
||||
class LoopCond : public PrimitiveC {
|
||||
public:
|
||||
LoopCond() : PrimitiveC(kNameLoopCond) {}
|
||||
~LoopCond() = default;
|
||||
MS_DECLARE_PARENT(LoopCond, PrimitiveC);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_LOOP_COND_H_
|
|
@ -17,16 +17,31 @@
|
|||
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
|
||||
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "ops/primitive_c.h"
|
||||
using mindspore::ops::PrimitiveC;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#define ADD_CONVERTER_ONLY_OP(name) \
|
||||
constexpr auto kName##name = #name; \
|
||||
class name : public PrimitiveC { \
|
||||
public: \
|
||||
name() : PrimitiveC(kName##name) {} \
|
||||
~name() = default; \
|
||||
MS_DECLARE_PARENT(name, PrimitiveC); \
|
||||
};
|
||||
|
||||
enum ConverterPrimitiveType {
|
||||
ConverterPrimitiveType_Enter = schema::PrimitiveType_MAX + 1,
|
||||
ConverterPrimitiveType_LoopCond,
|
||||
ConverterPrimitiveType_NextIteration,
|
||||
ConverterPrimitiveType_Exit,
|
||||
};
|
||||
ADD_CONVERTER_ONLY_OP(Enter);
|
||||
ADD_CONVERTER_ONLY_OP(Exit);
|
||||
ADD_CONVERTER_ONLY_OP(If);
|
||||
ADD_CONVERTER_ONLY_OP(LoopCond);
|
||||
ADD_CONVERTER_ONLY_OP(NextIteration);
|
||||
ADD_CONVERTER_ONLY_OP(TensorArrayGatherV3);
|
||||
ADD_CONVERTER_ONLY_OP(TensorArrayReadV3);
|
||||
ADD_CONVERTER_ONLY_OP(TensorArrayScatterV3);
|
||||
ADD_CONVERTER_ONLY_OP(TensorArraySizeV3);
|
||||
ADD_CONVERTER_ONLY_OP(TensorArrayV3);
|
||||
ADD_CONVERTER_ONLY_OP(TensorArrayWriteV3);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#include "tools/converter/parser/onnx/onnx_if_parser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
#include "tools/converter/ops/if.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_enter_parser.h"
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "tools/converter/ops/enter.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "tools/converter/ops/exit.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "tools/converter/ops/if.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "tools/converter/ops/loop_cond.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -28,7 +28,7 @@ ops::PrimitiveC *TFMergeParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
std::vector<std::string> *inputs, int *output_size) {
|
||||
auto prim = std::make_unique<ops::Merge>();
|
||||
|
||||
*output_size = tf_op.input_size();
|
||||
*output_size = 1;
|
||||
for (int i = 0; i < tf_op.input_size(); i++) {
|
||||
inputs->emplace_back(tf_op.input(i));
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "tools/converter/ops/next_iteration.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -28,7 +28,7 @@ ops::PrimitiveC *TFSwitchParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
std::vector<std::string> *inputs, int *output_size) {
|
||||
auto prim = std::make_unique<ops::Switch>();
|
||||
|
||||
*output_size = tf_op.input_size();
|
||||
*output_size = 2;
|
||||
for (int i = 0; i < tf_op.input_size(); i++) {
|
||||
inputs->emplace_back(tf_op.input(i));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_tensor_array_gather_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *TFTensorArrayGatherParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(DEBUG) << "TF TensorArrayGatherParser";
|
||||
if (inputs == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "inputs or output_size is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
auto prim = std::make_unique<TensorArrayGatherV3>();
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "prim is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
*output_size = 1;
|
||||
for (int i = 0; i < tf_op.input_size(); i++) {
|
||||
inputs->emplace_back(tf_op.input(i));
|
||||
}
|
||||
return prim.release();
|
||||
}
|
||||
TFNodeRegistrar g_tfTensorArrayGatherParser("TensorArrayGatherV3", new TFTensorArrayGatherParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFTensorArrayGatherParser : public TFNodeParser {
|
||||
public:
|
||||
TFTensorArrayGatherParser() = default;
|
||||
~TFTensorArrayGatherParser() override = default;
|
||||
|
||||
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_GATHER_PARSER_H_
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_tensor_array_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *TFTensorArrayParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(DEBUG) << "TF TensorArrayParser";
|
||||
if (inputs == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "inputs or output_size is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
auto prim = std::make_unique<TensorArrayV3>();
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "prim is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
*output_size = 2;
|
||||
for (int i = 0; i < tf_op.input_size(); i++) {
|
||||
inputs->emplace_back(tf_op.input(i));
|
||||
}
|
||||
|
||||
return prim.release();
|
||||
}
|
||||
TFNodeRegistrar g_tfTensorArrayParser("TensorArrayV3", new TFTensorArrayParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -13,27 +13,25 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_
|
||||
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ops/primitive_c.h"
|
||||
|
||||
using mindspore::ops::PrimitiveC;
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
constexpr auto kNameEnter = "Enter";
|
||||
class Enter : public PrimitiveC {
|
||||
class TFTensorArrayParser : public TFNodeParser {
|
||||
public:
|
||||
Enter() : PrimitiveC(kNameEnter) {}
|
||||
~Enter() = default;
|
||||
MS_DECLARE_PARENT(Enter, PrimitiveC);
|
||||
TFTensorArrayParser() = default;
|
||||
~TFTensorArrayParser() override = default;
|
||||
|
||||
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ENTER_H_
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_PARSER_H_
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_tensor_array_read_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *TFTensorArrayReadParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(DEBUG) << "TF TensorArrayReadParser";
|
||||
if (inputs == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "inputs or output_size is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
auto prim = std::make_unique<TensorArrayReadV3>();
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "prim is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
for (int i = 0; i < tf_op.input_size(); i++) {
|
||||
inputs->emplace_back(tf_op.input(i));
|
||||
}
|
||||
return prim.release();
|
||||
}
|
||||
TFNodeRegistrar g_tfTensorArrayReadParser("TensorArrayReadV3", new TFTensorArrayReadParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -13,27 +13,25 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_
|
||||
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ops/primitive_c.h"
|
||||
|
||||
using mindspore::ops::PrimitiveC;
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
constexpr auto kNameExit = "Exit";
|
||||
class Exit : public PrimitiveC {
|
||||
class TFTensorArrayReadParser : public TFNodeParser {
|
||||
public:
|
||||
Exit() : PrimitiveC(kNameExit) {}
|
||||
~Exit() = default;
|
||||
MS_DECLARE_PARENT(Exit, PrimitiveC);
|
||||
TFTensorArrayReadParser() = default;
|
||||
~TFTensorArrayReadParser() override = default;
|
||||
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_EXIT_H_
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_READ_PARSER_H_
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_tensor_array_scatter_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *TFTensorArrayScatterParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(DEBUG) << "TF TensorArrayScatterParser";
|
||||
if (inputs == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "inputs or output_size is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
auto prim = std::make_unique<TensorArrayScatterV3>();
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "prim is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
for (int i = 0; i < tf_op.input_size(); i++) {
|
||||
inputs->emplace_back(tf_op.input(i));
|
||||
}
|
||||
return prim.release();
|
||||
}
|
||||
TFNodeRegistrar g_tfTensorArrayScatterParser("TensorArrayScatterV3", new TFTensorArrayScatterParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFTensorArrayScatterParser : public TFNodeParser {
|
||||
public:
|
||||
TFTensorArrayScatterParser() = default;
|
||||
~TFTensorArrayScatterParser() override = default;
|
||||
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SCATTER_PARSER_H_
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_tensor_array_size_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *TFTensorArraySizeParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(DEBUG) << "TF TensorArraySizeParser";
|
||||
if (inputs == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "inputs or output_size is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
auto prim = std::make_unique<TensorArraySizeV3>();
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "prim is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
*output_size = 1;
|
||||
for (int i = 0; i < tf_op.input_size(); i++) {
|
||||
inputs->emplace_back(tf_op.input(i));
|
||||
}
|
||||
|
||||
return prim.release();
|
||||
}
|
||||
TFNodeRegistrar g_tfTensorArraySizeParser("TensorArraySizeV3", new TFTensorArraySizeParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -13,27 +13,25 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_NEXT_ITERATION_H_
|
||||
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_NEXT_ITERATION_H_
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SIZE_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SIZE_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <cmath>
|
||||
#include "ops/primitive_c.h"
|
||||
|
||||
using mindspore::ops::PrimitiveC;
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
constexpr auto kNameNextIteration = "NextIteration";
|
||||
class NextIteration : public PrimitiveC {
|
||||
class TFTensorArraySizeParser : public TFNodeParser {
|
||||
public:
|
||||
NextIteration() : PrimitiveC(kNameNextIteration) {}
|
||||
~NextIteration() = default;
|
||||
MS_DECLARE_PARENT(NextIteration, PrimitiveC);
|
||||
TFTensorArraySizeParser() = default;
|
||||
~TFTensorArraySizeParser() override = default;
|
||||
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_NEXT_ITERATION_H_
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_SIZE_PARSER_H_
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_tensor_array_write_parser.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *TFTensorArrayWriteParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(DEBUG) << "TF TensorArrayWriteParser";
|
||||
if (inputs == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "inputs or output_size is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
auto prim = std::make_unique<TensorArrayWriteV3>();
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "prim is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
*output_size = 1;
|
||||
for (int i = 0; i < tf_op.input_size(); i++) {
|
||||
inputs->emplace_back(tf_op.input(i));
|
||||
}
|
||||
|
||||
return prim.release();
|
||||
}
|
||||
TFNodeRegistrar g_tfTensorArrayWriteParser("TensorArrayWriteV3", new TFTensorArrayWriteParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_WRITE_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_WRITE_PARSER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFTensorArrayWriteParser : public TFNodeParser {
|
||||
public:
|
||||
TFTensorArrayWriteParser() = default;
|
||||
~TFTensorArrayWriteParser() override = default;
|
||||
|
||||
ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_ARRAY_WRITE_PARSER_H_
|
|
@ -0,0 +1,214 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <functional>
|
||||
#include "src/common/utils.h"
|
||||
#include "utils/utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kNumFwVars = 4;
|
||||
constexpr size_t kNumBwVars = 4;
|
||||
const auto &p1 = std::placeholders::_1;
|
||||
BaseRef GetPrim(const PrimitivePtr &prim) { return std::make_shared<CondVar>(std::bind(IsOpType, p1, prim)); }
|
||||
|
||||
BaseRef GetPrim(const std::string &prim_name) {
|
||||
auto prim = std::make_shared<Primitive>(prim_name);
|
||||
return GetPrim(prim);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TfBidirectionGruCfFusion::TfBidirectionGruCfFusion(const std::string &name, bool multi_graph)
|
||||
: TfBidirectionGruFusion(kNumFwVars, kNumBwVars, name, multi_graph) {
|
||||
/*
|
||||
* vars for fw/bw input
|
||||
* fw:
|
||||
* 0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias
|
||||
* bw:
|
||||
* 0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias
|
||||
*/
|
||||
}
|
||||
|
||||
BaseRef TfBidirectionGruCfFusion::DefineGruCellPattern(const BaseRef &in_ta_read, const BaseRef &switch3_true,
|
||||
const std::vector<VarPtr> &vars) const {
|
||||
auto concat = VectorRef({GetPrim(prim::kPrimConcat), in_ta_read, switch3_true});
|
||||
auto matmul_enter = VectorRef({GetPrim(lite::kNameEnter), vars[0]}); // gate_kernel
|
||||
auto matmul = VectorRef({GetPrim(prim::kPrimMatMul), concat, matmul_enter});
|
||||
auto bias_enter = VectorRef({GetPrim(lite::kNameEnter), vars[1]}); // cand_bias
|
||||
auto bias = VectorRef({GetPrim(prim::kPrimBiasAdd), matmul, bias_enter});
|
||||
auto sigmoid = VectorRef({GetPrim(prim::kPrimActivation), bias});
|
||||
auto split = VectorRef({GetPrim(prim::kPrimSplit), sigmoid});
|
||||
auto rt = VectorRef({GetPrim(prim::kPrimTupleGetItem), split, std::make_shared<Var>()});
|
||||
auto zt = VectorRef({GetPrim(prim::kPrimTupleGetItem), split, std::make_shared<Var>()});
|
||||
auto mul = VectorRef({GetPrim(prim::kPrimMulFusion), rt, switch3_true});
|
||||
auto concat1 = VectorRef({GetPrim(prim::kPrimConcat), in_ta_read, mul});
|
||||
auto matmul1_enter = VectorRef({GetPrim(lite::kNameEnter), vars[2]}); // cand_kernel
|
||||
auto matmul1 = VectorRef({GetPrim(prim::kPrimMatMul), concat1, matmul1_enter});
|
||||
auto bias1_enter = VectorRef({GetPrim(lite::kNameEnter), vars[3]}); // cand_bias
|
||||
auto bias1 = VectorRef({GetPrim(prim::kPrimBiasAdd), matmul1, bias1_enter});
|
||||
auto tanh = VectorRef({GetPrim(prim::kPrimActivation), bias1});
|
||||
auto sub = VectorRef({GetPrim(prim::kPrimSubFusion), std::make_shared<CondVar>(IsParameterNode), zt});
|
||||
auto mul2 = VectorRef({GetPrim(prim::kPrimMulFusion), sub, tanh});
|
||||
auto mul1 = VectorRef({GetPrim(prim::kPrimMulFusion), zt, switch3_true});
|
||||
auto add = VectorRef({GetPrim(prim::kPrimAddFusion), mul1, mul2});
|
||||
return add;
|
||||
}
|
||||
|
||||
const BaseRef TfBidirectionGruCfFusion::DefineBidirectionRnnPattern(const BaseRef &input,
|
||||
const std::vector<VarPtr> &vars,
|
||||
const VarPtr &init_state) const {
|
||||
// in order to match cyclic graph, some node in cycle is represented by SeqVar
|
||||
auto fw_shape1 = VectorRef({GetPrim(prim::kPrimShape), input});
|
||||
auto strided_slice = VectorRef({GetPrim(prim::kPrimStridedSlice), fw_shape1, std::make_shared<SeqVar>()});
|
||||
auto fw_max = VectorRef({GetPrim(prim::kPrimReduceFusion), input_length_, std::make_shared<Var>()});
|
||||
auto fw_maximum = VectorRef({GetPrim(prim::kPrimMaximum), std::make_shared<CondVar>(IsParameterNode), fw_max});
|
||||
auto fw_minimum = VectorRef({GetPrim(prim::kPrimMinimum), strided_slice, fw_maximum});
|
||||
auto fw_less1_enter = VectorRef({GetPrim(lite::kNameEnter), fw_minimum});
|
||||
// SeqVar:counter_merge1
|
||||
auto fw_less1 = VectorRef({GetPrim(prim::kPrimLess), std::make_shared<SeqVar>(), fw_less1_enter});
|
||||
|
||||
// SeqVar:fw_merge,loop_cond
|
||||
auto fw_switch = VectorRef({GetPrim(prim::kPrimSwitch), std::make_shared<SeqVar>()});
|
||||
auto fw_switch_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), fw_switch, std::make_shared<Var>()}); // identity
|
||||
auto fw_add = VectorRef({GetPrim(prim::kPrimAddFusion), fw_switch_true, std::make_shared<CondVar>(IsParameterNode)});
|
||||
auto fw_next_iter = VectorRef({GetPrim(lite::kNameNextIteration), fw_add});
|
||||
auto fw_merge_enter = VectorRef({GetPrim(lite::kNameEnter), std::make_shared<CondVar>(IsParameterNode)});
|
||||
auto fw_merge = VectorRef({GetPrim(prim::kPrimMerge), fw_merge_enter, fw_next_iter});
|
||||
auto fw_less_enter = VectorRef({GetPrim(lite::kNameEnter), strided_slice});
|
||||
auto fw_less = VectorRef({GetPrim(prim::kPrimLess), fw_merge, fw_less_enter});
|
||||
|
||||
auto fw_logical_and = VectorRef({GetPrim(prim::kPrimLogicalAnd), fw_less, fw_less1});
|
||||
// SeqVar:fw_logical_and
|
||||
auto loop_cond = VectorRef({GetPrim(lite::kNameLoopCond), fw_logical_and});
|
||||
|
||||
auto fw_shape = VectorRef({GetPrim(prim::kPrimShape), input});
|
||||
auto fw_unstack_strided_slice = VectorRef({GetPrim(prim::kPrimStridedSlice), fw_shape, std::make_shared<SeqVar>()});
|
||||
auto fw_unstack_range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<CondVar>(IsParameterNode),
|
||||
fw_unstack_strided_slice, std::make_shared<CondVar>(IsParameterNode)});
|
||||
|
||||
// SeqVar:switch1_true
|
||||
auto counter_add =
|
||||
VectorRef({GetPrim(prim::kPrimAddFusion), std::make_shared<SeqVar>(), std::make_shared<CondVar>(IsParameterNode)});
|
||||
auto counter_zero = VectorRef({GetPrim(lite::kNameEnter), std::make_shared<CondVar>(IsParameterNode)});
|
||||
auto counter_next_iter = VectorRef({GetPrim(lite::kNameNextIteration), counter_add});
|
||||
auto counter_merge1 = VectorRef({GetPrim(prim::kPrimMerge), counter_zero, counter_next_iter});
|
||||
auto counter_switch1 = VectorRef({GetPrim(prim::kPrimSwitch), counter_merge1, loop_cond});
|
||||
auto switch1_true =
|
||||
VectorRef({GetPrim(prim::kPrimTupleGetItem), counter_switch1, std::make_shared<Var>()}); // identity1
|
||||
|
||||
auto in_ta = VectorRef({GetPrim(lite::kNameTensorArrayV3), strided_slice});
|
||||
auto in_ta_handle = VectorRef({GetPrim(prim::kPrimTupleGetItem), in_ta, std::make_shared<Var>()});
|
||||
auto in_ta_flow = VectorRef({GetPrim(prim::kPrimTupleGetItem), in_ta, std::make_shared<Var>()});
|
||||
auto fw_unstack_ta_scatter =
|
||||
VectorRef({GetPrim(lite::kNameTensorArrayScatterV3), in_ta_handle, fw_unstack_range, input, in_ta_flow});
|
||||
auto in_ta_enter1 = VectorRef({GetPrim(lite::kNameEnter), fw_unstack_ta_scatter});
|
||||
auto in_ta_enter = VectorRef({GetPrim(lite::kNameEnter), in_ta_handle});
|
||||
auto in_ta_read = VectorRef({GetPrim(lite::kNameTensorArrayReadV3), in_ta_enter, switch1_true, in_ta_enter1});
|
||||
|
||||
auto greater_equal_enter = VectorRef({GetPrim(lite::kNameEnter), input_length_});
|
||||
auto greater_equal = VectorRef({GetPrim(prim::kPrimGreaterEqual), switch1_true, greater_equal_enter});
|
||||
auto select1 = VectorRef({GetPrim(prim::kPrimSelect), greater_equal, std::make_shared<SeqVar>()}); // select h
|
||||
|
||||
auto next_iteration3 = VectorRef({GetPrim(lite::kNameNextIteration), select1});
|
||||
auto enter3 = VectorRef({GetPrim(lite::kNameEnter), init_state});
|
||||
auto merge3 = VectorRef({GetPrim(prim::kPrimMerge), enter3, next_iteration3});
|
||||
auto switch3 = VectorRef({GetPrim(prim::kPrimSwitch), merge3, loop_cond});
|
||||
auto switch3_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), switch3, std::make_shared<Var>()}); // identity3
|
||||
|
||||
auto rnn_cell_out = DefineGruCellPattern(in_ta_read, switch3_true, vars);
|
||||
|
||||
auto out_ta = VectorRef({GetPrim(lite::kNameTensorArrayV3), strided_slice});
|
||||
auto out_ta_handle = VectorRef({GetPrim(prim::kPrimTupleGetItem), out_ta, std::make_shared<Var>()});
|
||||
auto out_ta_flow = VectorRef({GetPrim(prim::kPrimTupleGetItem), out_ta, std::make_shared<Var>()});
|
||||
auto out_ta_enter = VectorRef({GetPrim(lite::kNameEnter), out_ta_handle});
|
||||
|
||||
auto switch2_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), std::make_shared<SeqVar>()}); // cycle
|
||||
|
||||
auto concat1 = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared<SeqVar>()});
|
||||
auto zeros1 = VectorRef({GetPrim(prim::kPrimFill), std::make_shared<CondVar>(IsParameterNode), concat1});
|
||||
auto select_enter = VectorRef({GetPrim(lite::kNameEnter), zeros1});
|
||||
auto select = VectorRef({GetPrim(prim::kPrimSelect), greater_equal, select_enter, rnn_cell_out}); // select x
|
||||
auto ta_write = VectorRef({GetPrim(lite::kNameTensorArrayWriteV3), out_ta_enter, switch1_true, select, switch2_true});
|
||||
|
||||
auto enter2 = VectorRef({GetPrim(lite::kNameEnter), out_ta_flow});
|
||||
auto next_iter2 = VectorRef({GetPrim(lite::kNameNextIteration), ta_write});
|
||||
auto merge2 = VectorRef({GetPrim(prim::kPrimMerge), enter2, next_iter2});
|
||||
auto switch2 = VectorRef({GetPrim(prim::kPrimSwitch), merge2, loop_cond});
|
||||
auto switch2_false = VectorRef({GetPrim(prim::kPrimTupleGetItem), switch2, std::make_shared<Var>()});
|
||||
|
||||
auto exit2 = VectorRef({GetPrim(lite::kNameExit), switch2_false});
|
||||
auto ta_size = VectorRef({GetPrim(lite::kNameTensorArraySizeV3), out_ta_handle, exit2});
|
||||
auto range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<Var>(), ta_size, std::make_shared<Var>()});
|
||||
auto tensor_array_gather = VectorRef({GetPrim(lite::kNameTensorArrayGatherV3), out_ta_handle, range, exit2});
|
||||
auto range1 = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<SeqVar>()});
|
||||
auto concat2 = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared<CondVar>(IsParameterNode), range1});
|
||||
auto fw_out_trans = VectorRef({GetPrim(prim::kPrimTranspose), tensor_array_gather, concat2});
|
||||
return fw_out_trans;
|
||||
}
|
||||
|
||||
const BaseRef TfBidirectionGruCfFusion::DefinePattern() const {
|
||||
const auto fw_out_trans = DefineBidirectionRnnPattern(transpose_input_, fw_vars_, fw_init_state_);
|
||||
|
||||
auto bw_reverse_in = VectorRef({GetPrim(prim::kPrimReverseSequence), input_, input_length_});
|
||||
auto bw_range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<SeqVar>()});
|
||||
auto bw_concat = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared<CondVar>(IsParameterNode), bw_range});
|
||||
auto bw_transpose = VectorRef({GetPrim(prim::kPrimTranspose), bw_reverse_in, bw_concat});
|
||||
auto bw_out_trans = DefineBidirectionRnnPattern(bw_transpose, bw_vars_, bw_init_state_);
|
||||
auto bw_reverse_out = VectorRef({GetPrim(prim::kPrimReverseSequence), bw_out_trans, input_length_});
|
||||
auto concat = VectorRef({GetPrim(prim::kPrimConcat), fw_out_trans, bw_reverse_out});
|
||||
return concat;
|
||||
}
|
||||
|
||||
const AnfNodePtr TfBidirectionGruCfFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(concat_node != nullptr);
|
||||
MS_LOG(DEBUG) << "bidirection tf gru fusion pass";
|
||||
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(concat_node) != lite::RET_OK) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]);
|
||||
MS_ASSERT(transpose_input != nullptr);
|
||||
|
||||
const std::string gru_name = "gru_" + concat_node->fullname_with_scope();
|
||||
auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, gru_name, 0);
|
||||
if (gru_node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (TfliteLstmCellFusion::SetAbstractTuple(gru_node, 2) != RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto get_item_node = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, gru_node, 0);
|
||||
if (get_item_node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto output_node = GetPostProcessNode(func_graph, get_item_node, gru_node->fullname_with_scope());
|
||||
MS_LOG(INFO) << "gru node:" << gru_node->fullname_with_scope() << " fusion success";
|
||||
return output_node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_CF_FUSION_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_CF_FUSION_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "utils/utils.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// fuse tf 1.x bidirection_gru into MSLITE GRU
|
||||
class TfBidirectionGruCfFusion : public TfBidirectionGruFusion {
|
||||
public:
|
||||
explicit TfBidirectionGruCfFusion(const std::string &name = "tf_bidirection_gru_cf_fusion", bool multi_graph = true);
|
||||
~TfBidirectionGruCfFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
BaseRef DefineGruCellPattern(const BaseRef &in_ta_read, const BaseRef &switch3_true,
|
||||
const std::vector<VarPtr> &vars) const;
|
||||
const BaseRef DefineBidirectionRnnPattern(const BaseRef &input, const std::vector<VarPtr> &vars,
|
||||
const VarPtr &init_state) const;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_CF_FUSION_H_
|
|
@ -13,7 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h"
|
||||
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include "ops/concat.h"
|
||||
|
@ -24,32 +24,21 @@
|
|||
#include "ops/transpose.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "utils/utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "securec/include/securec.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kWhileUniqInputsLength = 6;
|
||||
constexpr size_t kCondNodesNum = 12;
|
||||
constexpr size_t kCondCNodesNum = 4;
|
||||
constexpr size_t kBodyNodesNum = 69;
|
||||
constexpr size_t kBodyCNodesNum = 25;
|
||||
const auto &p1 = std::placeholders::_1;
|
||||
|
||||
bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
|
||||
|
||||
bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
|
||||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
auto anf_node = utils::cast<AnfNodePtr>(n);
|
||||
return CheckPrimitiveType(anf_node, prim);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, bool multigraph)
|
||||
: PatternProcessPass(name, multigraph) {
|
||||
TfBidirectionGruFusion::TfBidirectionGruFusion(int num_fw_vars, int num_bw_vars, const std::string &name,
|
||||
bool multi_graph)
|
||||
: PatternProcessPass(name, multi_graph) {
|
||||
/*
|
||||
* vars for while input
|
||||
* fw_while_inputs:
|
||||
|
@ -57,8 +46,10 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name,
|
|||
* bw_while_inputs:
|
||||
* 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
|
||||
*/
|
||||
for (size_t i = 0; i < kWhileUniqInputsLength; ++i) {
|
||||
for (int i = 0; i < num_fw_vars; ++i) {
|
||||
fw_vars_.emplace_back(std::make_shared<Var>());
|
||||
}
|
||||
for (int i = 0; i < num_bw_vars; ++i) {
|
||||
bw_vars_.emplace_back(std::make_shared<Var>());
|
||||
}
|
||||
input_ = std::make_shared<Var>();
|
||||
|
@ -68,7 +59,7 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name,
|
|||
bw_init_state_ = std::make_shared<Var>();
|
||||
}
|
||||
|
||||
const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
|
||||
const BaseRef TfBidirectionGruFusion::DefinePattern() const {
|
||||
// forward
|
||||
auto fw_reduce = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)),
|
||||
input_length_, std::make_shared<CondVar>(IsParameterNode)});
|
||||
|
@ -134,7 +125,7 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
|
|||
return concat;
|
||||
}
|
||||
|
||||
AnfNodePtr BiDirectionTfGruCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
|
||||
AnfNodePtr TfBidirectionGruFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
|
||||
auto is_parameter1 = std::make_shared<CondVar>(IsParameterNode);
|
||||
auto is_parameter2 = std::make_shared<CondVar>(IsParameterNode);
|
||||
auto is_parameter3 = std::make_shared<CondVar>(IsParameterNode);
|
||||
|
@ -152,7 +143,7 @@ AnfNodePtr BiDirectionTfGruCellFusion::GetCondGraphPattern(const PrimitiveVarMap
|
|||
return pattern;
|
||||
}
|
||||
|
||||
AnfNodePtr BiDirectionTfGruCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
|
||||
AnfNodePtr TfBidirectionGruFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
|
||||
std::vector<CondVarPtr> placeholders;
|
||||
for (int i = 0; i < 13; ++i) {
|
||||
placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode));
|
||||
|
@ -206,7 +197,7 @@ AnfNodePtr BiDirectionTfGruCellFusion::GetBodyGraphPattern(const PrimitiveVarMap
|
|||
return pattern;
|
||||
}
|
||||
|
||||
ParamValueLitePtr BiDirectionTfGruCellFusion::GetDefaultParamValue(const AnfNodePtr ¶meter_anf) const {
|
||||
ParamValueLitePtr TfBidirectionGruFusion::GetDefaultParamValue(const AnfNodePtr ¶meter_anf) const {
|
||||
MS_ASSERT(parameter_anf != nullptr);
|
||||
if (!utils::isa<ParameterPtr>(parameter_anf)) {
|
||||
MS_LOG(DEBUG) << "parameter_anf is not ParameterPtr";
|
||||
|
@ -221,9 +212,9 @@ ParamValueLitePtr BiDirectionTfGruCellFusion::GetDefaultParamValue(const AnfNode
|
|||
return param_value;
|
||||
}
|
||||
|
||||
STATUS BiDirectionTfGruCellFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf,
|
||||
const AnfNodePtr &bw_cand_kernel_anf, int *input_size,
|
||||
int *hidden_size) const {
|
||||
STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf,
|
||||
const AnfNodePtr &bw_cand_kernel_anf, int *input_size,
|
||||
int *hidden_size) const {
|
||||
MS_ASSERT(fw_cand_kernel != nullptr);
|
||||
MS_ASSERT(bw_cand_kernel != nullptr);
|
||||
MS_ASSERT(input_size != nullptr);
|
||||
|
@ -256,9 +247,9 @@ STATUS BiDirectionTfGruCellFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_ca
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
ParameterPtr BiDirectionTfGruCellFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name,
|
||||
const std::vector<int> &shape, const TypeId type,
|
||||
void **tensor_data) const {
|
||||
ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name,
|
||||
const std::vector<int> &shape, const TypeId type,
|
||||
void **tensor_data) const {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(tensor_data != nullptr);
|
||||
auto parameter = func_graph->add_parameter();
|
||||
|
@ -300,9 +291,8 @@ ParameterPtr BiDirectionTfGruCellFusion::AddDefaultParameter(const FuncGraphPtr
|
|||
return parameter;
|
||||
}
|
||||
|
||||
void BiDirectionTfGruCellFusion::CopyFlattenMatData(const float *mat, const int R, const int C, const int r0,
|
||||
const int r1, const int c0, const int c1, float *data,
|
||||
bool t) const {
|
||||
void TfBidirectionGruFusion::CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, const int r1,
|
||||
const int c0, const int c1, float *data, bool t) const {
|
||||
MS_ASSERT(mat != nullptr);
|
||||
MS_ASSERT(data != nullptr);
|
||||
MS_ASSERT(0 <= r0 && r0 < r1 && r1 <= R);
|
||||
|
@ -320,9 +310,9 @@ void BiDirectionTfGruCellFusion::CopyFlattenMatData(const float *mat, const int
|
|||
}
|
||||
}
|
||||
|
||||
STATUS BiDirectionTfGruCellFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight,
|
||||
const int input_size, const int hidden_size,
|
||||
float *gate_tensor_data, float *recu_tensor_data) const {
|
||||
STATUS TfBidirectionGruFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight,
|
||||
const int input_size, const int hidden_size, float *gate_tensor_data,
|
||||
float *recu_tensor_data) const {
|
||||
MS_ASSERT(gate_weight != nullptr);
|
||||
MS_ASSERT(cand_weight != nullptr);
|
||||
MS_ASSERT(gate_tensor_data != nullptr);
|
||||
|
@ -375,8 +365,8 @@ STATUS BiDirectionTfGruCellFusion::ConvertWeightData(const AnfNodePtr &gate_weig
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias,
|
||||
const int hidden_size, float *tensor_data) const {
|
||||
STATUS TfBidirectionGruFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias,
|
||||
const int hidden_size, float *tensor_data) const {
|
||||
MS_ASSERT(bias != nullptr);
|
||||
MS_ASSERT(tensor_data != nullptr);
|
||||
std::vector<int> gate_shape{hidden_size * 2};
|
||||
|
@ -407,10 +397,9 @@ STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph,
|
||||
const AnfNodePtr &fw_init_state,
|
||||
const AnfNodePtr &bw_init_state,
|
||||
const std::string base_name) const {
|
||||
CNodePtr TfBidirectionGruFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state,
|
||||
const AnfNodePtr &bw_init_state,
|
||||
const std::string base_name) const {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(fw_init_state != nullptr);
|
||||
MS_ASSERT(bw_init_state != nullptr);
|
||||
|
@ -424,35 +413,32 @@ CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &f
|
|||
return new_node;
|
||||
}
|
||||
|
||||
CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
||||
const EquivPtr &equiv, const EquivPtr &fw_body_equiv,
|
||||
const EquivPtr &bw_body_equiv,
|
||||
const std::string &base_name) const {
|
||||
CNodePtr TfBidirectionGruFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
||||
const EquivPtr &equiv, const std::string &base_name,
|
||||
int var_offset) const {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(input != nullptr);
|
||||
MS_ASSERT(equiv != nullptr);
|
||||
MS_ASSERT(fw_body_equiv != nullptr);
|
||||
MS_ASSERT(bw_body_equiv != nullptr);
|
||||
auto gru_prim = std::make_shared<ops::GRU>();
|
||||
gru_prim->set_bidirectional(true);
|
||||
auto value_node = NewValueNode(gru_prim);
|
||||
|
||||
auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[2]]);
|
||||
auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset]]);
|
||||
MS_ASSERT(fw_gate_kernel != nullptr);
|
||||
auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[3]]);
|
||||
auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 1]]);
|
||||
MS_ASSERT(fw_gate_bias != nullptr);
|
||||
auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[4]]);
|
||||
auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 2]]);
|
||||
MS_ASSERT(fw_cand_kernel != nullptr);
|
||||
auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[5]]);
|
||||
auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 3]]);
|
||||
MS_ASSERT(fw_cand_bias != nullptr);
|
||||
|
||||
auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[2]]);
|
||||
auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset]]);
|
||||
MS_ASSERT(bw_gate_kernel != nullptr);
|
||||
auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[3]]);
|
||||
auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 1]]);
|
||||
MS_ASSERT(bw_gate_bias != nullptr);
|
||||
auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[4]]);
|
||||
auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 2]]);
|
||||
MS_ASSERT(bw_cand_kernel != nullptr);
|
||||
auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[5]]);
|
||||
auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 3]]);
|
||||
MS_ASSERT(bw_cand_bias != nullptr);
|
||||
|
||||
auto fw_init_state = utils::cast<AnfNodePtr>((*equiv)[fw_init_state_]);
|
||||
|
@ -522,8 +508,8 @@ CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr
|
|||
return new_node;
|
||||
}
|
||||
|
||||
CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
|
||||
const std::string base_name) const {
|
||||
CNodePtr TfBidirectionGruFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
|
||||
const std::string base_name) const {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(gru_output != nullptr);
|
||||
auto split_prim = std::make_shared<ops::Split>();
|
||||
|
@ -571,8 +557,8 @@ CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func
|
|||
return transpose_new_node;
|
||||
}
|
||||
|
||||
const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
|
||||
const EquivPtr &equiv) const {
|
||||
const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(concat_node != nullptr);
|
||||
MS_LOG(DEBUG) << "bidirection tf gru fusion pass";
|
||||
|
@ -628,7 +614,7 @@ const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_gr
|
|||
}
|
||||
|
||||
const std::string gru_name = "gru_" + concat_node->fullname_with_scope();
|
||||
auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, fw_body_equiv, bw_body_equiv, gru_name);
|
||||
auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, gru_name, 2);
|
||||
if (gru_node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
|
@ -13,12 +13,13 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
@ -27,22 +28,26 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class BiDirectionTfGruCellFusion : public PatternProcessPass {
|
||||
constexpr size_t kWhileUniqInputsLength = 6;
|
||||
// fuse tf 2.x bidirection_gru into MSLITE GRU
|
||||
class TfBidirectionGruFusion : public PatternProcessPass {
|
||||
public:
|
||||
explicit BiDirectionTfGruCellFusion(const std::string &name = "bidirection_tf_gru_cell_fusion",
|
||||
bool multigraph = true);
|
||||
~BiDirectionTfGruCellFusion() override = default;
|
||||
explicit TfBidirectionGruFusion(int num_fw_vars = kWhileUniqInputsLength, int num_bw_vars = kWhileUniqInputsLength,
|
||||
const std::string &name = "tf_bidirection_gru_fusion", bool multi_graph = true);
|
||||
~TfBidirectionGruFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
protected:
|
||||
virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;
|
||||
CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv,
|
||||
const std::string &base_name, int var_offset) const;
|
||||
CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
|
||||
const std::string base_name) const;
|
||||
|
||||
private:
|
||||
AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;
|
||||
CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv,
|
||||
const EquivPtr &fw_body_equiv, const EquivPtr &bw_body_equiv,
|
||||
const std::string &base_name) const;
|
||||
|
||||
ParamValueLitePtr GetDefaultParamValue(const AnfNodePtr ¶meter_anf) const;
|
||||
lite::STATUS GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, const AnfNodePtr &bw_cand_kernel_anf,
|
||||
int *input_size, int *hidden_size) const;
|
||||
|
@ -56,10 +61,8 @@ class BiDirectionTfGruCellFusion : public PatternProcessPass {
|
|||
const int c1, float *data, bool t = false) const;
|
||||
CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state,
|
||||
const AnfNodePtr &bw_init_state, const std::string base_name) const;
|
||||
CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
|
||||
const std::string base_name) const;
|
||||
|
||||
private:
|
||||
protected:
|
||||
std::vector<VarPtr> fw_vars_;
|
||||
std::vector<VarPtr> bw_vars_;
|
||||
VarPtr input_;
|
||||
|
@ -68,7 +71,16 @@ class BiDirectionTfGruCellFusion : public PatternProcessPass {
|
|||
VarPtr fw_init_state_;
|
||||
VarPtr bw_init_state_;
|
||||
};
|
||||
inline bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
|
||||
|
||||
inline bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
|
||||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
auto anf_node = utils::cast<AnfNodePtr>(n);
|
||||
return CheckPrimitiveType(anf_node, prim);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BIDIRECTION_TF_GRU_CELL_FUSION_H_
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_
|
|
@ -0,0 +1,249 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/graph/functionalize_cond.h"
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <deque>
|
||||
#include <unordered_set>
|
||||
#include "include/errorcode.h"
|
||||
#include "ops/make_tuple.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "ops/return.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
|
||||
STATUS FunctionalizeCond::GetSwitchBranchType(const CNodePtr &switch_cnode, BranchType *branch_type) {
|
||||
MS_ASSERT(switch_cnode != nullptr);
|
||||
MS_ASSERT(branch_type != nullptr);
|
||||
auto manager = fg_->manager();
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(ERROR) << "manager is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto node_users = manager->node_users()[switch_cnode];
|
||||
if (node_users.size() != 1) { // only one output of switch is referenced in cond
|
||||
MS_LOG(ERROR) << "switch's node users is not correct";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto node_user = node_users.front();
|
||||
auto tuple_get_item = node_user.first;
|
||||
if (!utils::isa<CNodePtr>(tuple_get_item) || !CheckPrimitiveType(tuple_get_item, prim::kPrimTupleGetItem)) {
|
||||
MS_LOG(ERROR) << "switch's node user is not TupleGetItem";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item);
|
||||
auto idx = GetTupleGetItemOutIndex(tuple_get_item_cnode);
|
||||
if (idx == 0) {
|
||||
*branch_type = kElseBranch;
|
||||
} else if (idx == 1) {
|
||||
*branch_type = kThenBranch;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "wrong tuple_get_item index";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS FunctionalizeCond::BranchSubGraphAddNodes(const FuncGraphPtr &graph, const AnfNodePtr &root_node,
|
||||
BranchType branch_type) {
|
||||
std::deque<AnfNodePtr> q;
|
||||
std::unordered_set<AnfNodePtr> vis;
|
||||
q.push_back(root_node);
|
||||
while (!q.empty()) {
|
||||
auto node = q.front();
|
||||
q.pop_front();
|
||||
vis.insert(node);
|
||||
if (FunctionalizeControlOpPass::IsSwitch(node)) {
|
||||
auto cnode = utils::cast<CNodePtr>(node);
|
||||
BranchType this_type;
|
||||
if (GetSwitchBranchType(cnode, &this_type) != RET_OK || this_type != branch_type) {
|
||||
MS_LOG(ERROR) << "switch node in branch " << branch_type << " is not correct";
|
||||
return RET_ERROR;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (utils::isa<ParameterPtr>(node)) {
|
||||
graph->add_parameter(node->cast<ParameterPtr>());
|
||||
} else {
|
||||
graph->AddNode(node);
|
||||
}
|
||||
node->set_func_graph(graph);
|
||||
if (utils::isa<CNodePtr>(node)) {
|
||||
auto cnode = utils::cast<CNodePtr>(node);
|
||||
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
||||
auto inputi = cnode->input(i);
|
||||
if (vis.find(inputi) == vis.end()) {
|
||||
q.push_back(cnode->input(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int FunctionalizeCond::PosInInputNodes(const CNodePtr &node) {
|
||||
auto index = std::find(input_nodes_.begin(), input_nodes_.end(), node);
|
||||
if (index == input_nodes_.end()) {
|
||||
input_nodes_.push_back(node);
|
||||
return input_nodes_.size() - 1;
|
||||
}
|
||||
return index - input_nodes_.begin();
|
||||
}
|
||||
|
||||
STATUS FunctionalizeCond::IdentifySubgraphInput(const FuncGraphPtr &graph, std::string graph_name) {
|
||||
std::vector<AnfNodePtr> nodes_need_drop{};
|
||||
for (auto &cnode : graph->GetOrderedCnodes()) {
|
||||
for (auto &input_node : cnode->inputs()) {
|
||||
if (FunctionalizeControlOpPass::IsSwitch(input_node)) {
|
||||
auto switch_node = input_node->cast<CNodePtr>();
|
||||
auto switch_input = utils::cast<CNodePtr>(switch_node->input(1));
|
||||
auto pos = PosInInputNodes(switch_input);
|
||||
nodes_need_drop.push_back(cnode);
|
||||
pred_nodes_.push_back(switch_node->input(2));
|
||||
// set parameter
|
||||
auto parameter = graph->add_parameter();
|
||||
parameter->set_abstract(cnode->abstract());
|
||||
// hardcode for subgraph input name
|
||||
parameter->set_name(graph_name + "_input_" + std::to_string(pos) + "_parameter");
|
||||
|
||||
// replace switch
|
||||
auto manager = fg_->manager();
|
||||
auto node_users = manager->node_users()[cnode];
|
||||
for (auto &node_user : node_users) {
|
||||
if (graph->nodes().contains(node_user.first)) {
|
||||
manager->SetEdge(node_user.first, node_user.second, parameter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type) {
|
||||
auto graph = FunctionalizeControlOpPass::NewFuncGraph(name, mindspore::lite::converter::FmkType_TF);
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "new graph Partial Node return nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
graph->set_manager(fg_->manager());
|
||||
auto status = BranchSubGraphAddNodes(graph, node, branch_type);
|
||||
if (status != RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!CheckPrimitiveType(node, prim::kPrimSwitch)) { // graph is not empty
|
||||
auto return_prim_ptr = std::make_shared<ops::Return>();
|
||||
if (return_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
auto value_node = NewValueNode(return_prim_ptr);
|
||||
std::vector<AnfNodePtr> op_inputs{value_node, node}; // If subgraph only has one output tensor
|
||||
auto return_cnode = graph->NewCNode(op_inputs);
|
||||
return_cnode->set_fullname_with_scope(name + "-return");
|
||||
return_cnode->set_func_graph(graph);
|
||||
graph->set_return(return_cnode);
|
||||
graph->output()->cast<CNodePtr>()->set_fullname_with_scope(name + "_output_0_cnode");
|
||||
}
|
||||
return graph;
|
||||
}
|
||||
|
||||
CNodePtr FunctionalizeCond::CreateNewIf(const FuncGraphPtr &else_branch, const FuncGraphPtr &then_branch) {
|
||||
MS_ASSERT(else_branch != nullptr);
|
||||
MS_ASSERT(then_branch != nullptr);
|
||||
|
||||
auto if_primc = std::make_shared<mindspore::lite::If>();
|
||||
if (if_primc == nullptr) {
|
||||
MS_LOG(ERROR) << "new if_primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto if_value_node = NewValueNode(if_primc);
|
||||
if (if_value_node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto then_value_node = NewValueNode(then_branch);
|
||||
auto else_value_node = NewValueNode(else_branch);
|
||||
std::vector<AnfNodePtr> if_op_inputs = {if_value_node, then_value_node, else_value_node, pred_node_};
|
||||
std::copy(input_nodes_.begin(), input_nodes_.end(), std::back_inserter(if_op_inputs));
|
||||
return fg_->NewCNode(if_op_inputs);
|
||||
}
|
||||
|
||||
STATUS FunctionalizeCond::VerifyPredictNode() {
|
||||
if (pred_nodes_.empty()) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (size_t i = 1; i < pred_nodes_.size(); ++i) {
|
||||
if (pred_nodes_[i] != pred_nodes_[0]) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
if (!utils::isa<CNodePtr>(pred_nodes_[0])) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
pred_node_ = utils::cast<CNodePtr>(pred_nodes_[0]);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS FunctionalizeCond::Process() {
|
||||
if (fg_ == nullptr || merge_node_ == nullptr || merge_node_->inputs().size() != 3) {
|
||||
MS_LOG(ERROR) << "fg or merge is not correct";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto else_branch_name = merge_node_->fullname_with_scope() + "-partial-if-else";
|
||||
auto then_branch_name = merge_node_->fullname_with_scope() + "-partial-then-else";
|
||||
|
||||
auto else_branch = CreateBranchGraph(merge_node_->input(1), else_branch_name, kElseBranch);
|
||||
if (else_branch == nullptr) {
|
||||
MS_LOG(ERROR) << "create else branch failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto then_branch = CreateBranchGraph(merge_node_->input(2), then_branch_name, kThenBranch);
|
||||
if (then_branch == nullptr) {
|
||||
MS_LOG(ERROR) << "create then branch failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto status = IdentifySubgraphInput(else_branch, else_branch_name);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
status = IdentifySubgraphInput(then_branch, then_branch_name);
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = VerifyPredictNode();
|
||||
if (status != RET_OK) {
|
||||
return status;
|
||||
}
|
||||
|
||||
auto if_node = CreateNewIf(else_branch, then_branch);
|
||||
if (if_node == nullptr) {
|
||||
MS_LOG(ERROR) << "create if node error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if_node->set_abstract(merge_node_->abstract()->Clone());
|
||||
auto manager = fg_->manager();
|
||||
auto node_users = manager->node_users()[merge_node_];
|
||||
for (auto &node_user : node_users) {
|
||||
manager->SetEdge(node_user.first, node_user.second, if_node);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::opt
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_FUNCTIONALIZE_COND_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_FUNCTIONALIZE_COND_H_
|
||||
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/optimizer/graph/functionalize_control_op_pass.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
namespace mindspore::opt {
|
||||
|
||||
typedef enum { kThenBranch = 0, kElseBranch = 1 } BranchType;
|
||||
|
||||
// Functionalize all the switch-merge nodes of a loop-free graph into single switch node.
|
||||
// Precondition: While loops must have been functionalized.
|
||||
class FunctionalizeCond {
|
||||
public:
|
||||
FunctionalizeCond(FuncGraphPtr fg, CNodePtr merge_node) : fg_(fg), merge_node_(merge_node) {}
|
||||
|
||||
STATUS Process();
|
||||
|
||||
private:
|
||||
STATUS GetSwitchBranchType(const CNodePtr &switch_cnode, BranchType *branch_type);
|
||||
STATUS BranchSubGraphAddNodes(const FuncGraphPtr &graph, const AnfNodePtr &root_node, BranchType branch_type);
|
||||
FuncGraphPtr CreateBranchGraph(const AnfNodePtr &node, std::string name, BranchType branch_type);
|
||||
int PosInInputNodes(const CNodePtr &node);
|
||||
STATUS IdentifySubgraphInput(const FuncGraphPtr &graph, std::string graph_name);
|
||||
CNodePtr CreateNewIf(const FuncGraphPtr &else_branch, const FuncGraphPtr &then_branch);
|
||||
STATUS VerifyPredictNode();
|
||||
|
||||
FuncGraphPtr fg_ = nullptr;
|
||||
CNodePtr merge_node_ = nullptr;
|
||||
CNodePtr pred_node_ = nullptr;
|
||||
std::vector<CNodePtr> input_nodes_{};
|
||||
std::vector<AnfNodePtr> pred_nodes_{};
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_FUNCTIONALIZE_COND_H_
|
|
@ -18,6 +18,7 @@
|
|||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include "tools/optimizer/graph/functionalize_while.h"
|
||||
#include "tools/optimizer/graph/functionalize_cond.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
|
@ -100,6 +101,25 @@ STATUS FunctionalizeControlOpPass::BuildWhileSubgraph(const FuncGraphPtr &func_g
|
|||
return ret;
|
||||
}
|
||||
|
||||
STATUS FunctionalizeControlOpPass::BuildIfSubgraph(const FuncGraphPtr &func_graph) {
|
||||
int ret = RET_OK;
|
||||
auto nodes = func_graph->nodes();
|
||||
for (auto &node : nodes) {
|
||||
if (!IsMerge(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = utils::cast<CNodePtr>(node);
|
||||
FunctionalizeCond fc(func_graph, cnode);
|
||||
ret = fc.Process();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "run functionalize cond failed, ret: " << ret;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool FunctionalizeControlOpPass::Run(const FuncGraphPtr &func_graph) {
|
||||
// use name to find the frame
|
||||
InitNodeClusters(func_graph);
|
||||
|
@ -107,6 +127,10 @@ bool FunctionalizeControlOpPass::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(ERROR) << "build while subgraph failed.";
|
||||
return false;
|
||||
}
|
||||
if (BuildIfSubgraph(func_graph) != RET_OK) {
|
||||
MS_LOG(ERROR) << "build while subgraph failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -23,10 +23,7 @@
|
|||
#include <memory>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/ops/enter.h"
|
||||
#include "tools/converter/ops/exit.h"
|
||||
#include "tools/converter/ops/loop_cond.h"
|
||||
#include "tools/converter/ops/next_iteration.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
|
@ -70,6 +67,7 @@ class FunctionalizeControlOpPass : public Pass {
|
|||
|
||||
protected:
|
||||
STATUS BuildWhileSubgraph(const FuncGraphPtr &func_graph);
|
||||
STATUS BuildIfSubgraph(const FuncGraphPtr &func_graph);
|
||||
std::vector<std::pair<std::string, std::vector<AnfNodePtr>>> node_clusters_{};
|
||||
std::vector<CNodePtr> loop_cond_nodes_{};
|
||||
};
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/optimizer/graph/unused_node_remove_pass.h"
|
||||
#include <deque>
|
||||
#include <unordered_set>
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
|
||||
STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto return_node = func_graph->get_return();
|
||||
if (return_node == nullptr) {
|
||||
return RET_OK;
|
||||
}
|
||||
std::unordered_set<AnfNodePtr> vis;
|
||||
std::deque<AnfNodePtr> q;
|
||||
q.push_back(return_node);
|
||||
while (!q.empty()) {
|
||||
auto node = q.front();
|
||||
vis.insert(node);
|
||||
q.pop_front();
|
||||
if (utils::isa<CNodePtr>(node)) {
|
||||
auto cnode = utils::cast<CNodePtr>(node);
|
||||
for (auto &input : cnode->inputs()) {
|
||||
if (vis.find(input) == vis.end()) {
|
||||
q.push_back(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (utils::isa<FuncGraphPtr>(node)) {
|
||||
auto sub_graph = utils::cast<FuncGraphPtr>(node);
|
||||
auto status = ProcessGraph(sub_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "process sub graph failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto nodes = func_graph->nodes();
|
||||
for (auto &node : nodes) {
|
||||
if (vis.find(node) == vis.end()) {
|
||||
func_graph->DropNode(node);
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
bool UnusedNodeRemovePass::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto status = ProcessGraph(func_graph);
|
||||
return status == RET_OK;
|
||||
}
|
||||
} // namespace mindspore::opt
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H_
|
||||
|
||||
#include <string>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "mindspore/lite/include/errorcode.h"
|
||||
|
||||
using mindspore::lite::STATUS;
|
||||
namespace mindspore::opt {
|
||||
class UnusedNodeRemovePass : public Pass {
|
||||
public:
|
||||
UnusedNodeRemovePass() : Pass("remove_unused_node_pass") {}
|
||||
~UnusedNodeRemovePass() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
STATUS ProcessGraph(const FuncGraphPtr &func_graph);
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_UNUSED_NODE_REMOVE_PASS_H_
|
Loading…
Reference in New Issue