diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 7a35627e25e..6c245d7548c 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -45,6 +45,7 @@ #include "pre_activate/ascend/ir_fusion/mul_add_fusion.h" #include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h" #include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h" +#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" #include "pre_activate/ascend/format_type/insert_trans_op.h" #include "pre_activate/pass/getitem_tuple.h" #include "pre_activate/pass/optimize_dependence.h" @@ -113,6 +114,7 @@ void AscendDataLayout(const std::shared_ptr &kernel_graph) data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc new file mode 100644 index 00000000000..5e265f2cf19 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" +#include +#include "session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "operator/ops.h" + +namespace mindspore { +namespace opt { +const BaseRef RemoveReshapePair::DefinePattern() const { + const auto prim_reshape = std::make_shared(prim::kPrimReshape->name()); + VectorRef reshape({prim_reshape, input_varptr_}); + + return VectorRef({prim::kPrimReshape, reshape}); +} + +const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(reshape_op_1); + // If reshape operator used by more than one other operators, reshape operator cant not be deleted directly + auto users = manager->node_users()[reshape_op_1]; + if (users.size() > 1) { + return nullptr; + } + auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(reshape_op_2); + users = manager->node_users()[reshape_op_2]; + if (users.size() > 1) { + return nullptr; + } + auto input_node = reshape_op_2->input(1); + return input_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h new file mode 100644 index 00000000000..a284f4eaa95 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ + +#include +#include +#include "ir/anf.h" +#include "pre_activate/common/pattern_engine.h" +#include "pre_activate/common/helper.h" +#include "pre_activate/common/optimizer.h" + +namespace mindspore { +namespace opt { +class RemoveReshapePair : public PatternProcessPass { + public: + explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) { + input_varptr_ = std::make_shared(); + } + ~RemoveReshapePair() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr input_varptr_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_