diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc index 111f29de65..ad3093cbc3 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc @@ -39,20 +39,6 @@ namespace mindspore { namespace opt { namespace { -bool IsFusibleOp(const AnfNodePtr &node) { -#if ENABLE_D - const std::set graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", - "LambNextMV", "LambUpdateWithLR"}; - if (AnfAlgo::IsGraphKernel(node)) { - auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - if (fg_attr != nullptr) { - return graph_kernel_black_list.count(GetValue(fg_attr)) == 0; - } - } -#endif - return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node); -} - IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) { if (cur_node == node) { return FOLLOW; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/cast_matmul_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/cast_matmul_fusion.cc new file mode 100644 index 0000000000..7c1d9913dc --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/cast_matmul_fusion.cc @@ -0,0 +1,120 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/graph_kernel/cast_matmul_fusion.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" + +namespace mindspore { +namespace opt { +namespace { +// Check if leaf is used by root +bool HasPath(const AnfNodePtr &leaf, const AnfNodePtr &root, const FuncGraphManagerPtr &mng) { + MS_EXCEPTION_IF_NULL(mng); + bool result = false; + auto IncludeUser = [&result, &root](const AnfNodePtr &node) { + if (node == root) { + result = true; + } + return result ? EXCLUDE : FOLLOW; + }; + static_cast(DeepLinkedGraphSearch(leaf, IncludeUser)); + return result; +} +} // namespace + +/* MatMul supports fp32 bias, so remove the redundant cast if cast cannot fuse forword + * case1, cast only used by MatMul + * + * bias_fp32 = depend(bias_fp32, u) + * %0 = cast(bias_fp32, fp16) + * %1 = MatMul(A_fp16, B_fp16, %0) + * ------> + * bias_fp32 = depend(bias_fp32, u) + * %1 = MatMul(A_fp16, B_fp16, bias_fp32) + * + * case2, cast used by MatMul and UpdateStatus + * + * bias_fp32 = load(p, status) + * %0 = cast(bias_fp32, fp16) + * %1 = MatMul(A_fp16, B_fp16, %0) + * %2 = UpstateStatus(status, %0) + * ------> + * bias_fp32 = load(p, status) + * %1 = MatMul(A_fp16, B_fp16, bias_fp32) + * %2 = UpstateStatus(status, %1) + */ +bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto mng = func_graph->manager(); + if (mng == nullptr) { + mng = Manage(func_graph, true); + func_graph->set_manager(mng); + } + auto changed = false; + auto nodes = TopoSort(func_graph->get_return()); + for (auto node : nodes) { + if (!IsPrimitiveCNode(node, prim::kPrimMatMul)) { + continue; + } + auto cnode = node->cast(); + if (cnode->size() != 4) { + continue; + } + auto cast_node = cnode->input(3); + if (!IsPrimitiveCNode(cast_node, prim::kPrimCast)) { + continue; + } + auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_node, 0); + auto cast_output_type = AnfAlgo::GetOutputDeviceDataType(cast_node, 0); + if (cast_input_type != kNumberTypeFloat32 || cast_output_type != kNumberTypeFloat16) { + continue; + } + // Cast cannot fuse with its input + if (IsFusibleOp((cast_node->cast())->input(1))) { + continue; + } + + auto user_index_set = mng->node_users()[cast_node]; + // Case1 : Cast is only used by matmul + if (user_index_set.size() == 1) { + mng->Replace(cast_node, (cast_node->cast())->input(1)); + changed = true; + continue; + } + + // Case2 : Cast is used by matmul and Upstatus + if (user_index_set.size() > 2) { + continue; + } + for (auto user_index : user_index_set) { + // Exclude when UpdateStatus-> ... ->matmul path is found + if (IsPrimitiveCNode(user_index.first, prim::kPrimUpdateState) && !HasPath(user_index.first, node, mng)) { + auto update_state = (user_index.first)->cast(); + update_state->set_input(2, node); + cnode->set_input(4, (cast_node->cast())->input(1)); + mng->RemoveRoots(); + mng->KeepRoots({func_graph}); + changed = true; + } + } + } + + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/cast_matmul_fusion.h similarity index 64% rename from mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.h rename to mindspore/ccsrc/backend/optimizer/graph_kernel/cast_matmul_fusion.h index 6f607f1a26..15bc5b3ca4 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/cast_matmul_fusion.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_ #include #include @@ -24,13 +24,13 @@ namespace mindspore { namespace opt { -class OptimizeMatmul : public Pass { +class CastMatmulFusion : public Pass { public: - OptimizeMatmul() : Pass("optimize_matmul") {} - ~OptimizeMatmul() override = default; + CastMatmulFusion() : Pass("cast_matmul_fusion") {} + ~CastMatmulFusion() override = default; bool Run(const FuncGraphPtr &graph) override; }; -using OptimizeMatmulPtr = std::shared_ptr; +using OptimizeMatmulPtr = std::shared_ptr; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index edbb9ddfe0..9b6766346f 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -627,6 +627,20 @@ bool IsBasicFuseOp(const AnfNodePtr &node) { [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); } +bool IsFusibleOp(const AnfNodePtr &node) { +#if ENABLE_D + const std::set graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", + "LambNextMV", "LambUpdateWithLR"}; + if (AnfAlgo::IsGraphKernel(node)) { + auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + if (fg_attr != nullptr) { + return graph_kernel_black_list.count(GetValue(fg_attr)) == 0; + } + } +#endif + return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node); +} + void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h index 7baa5571af..a8a4f521f8 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h @@ -73,6 +73,7 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector GetFusibleOpList(); bool IsBasicFuseOp(const AnfNodePtr &node); +bool IsFusibleOp(const AnfNodePtr &node); void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); void InitDependPrior(const std::vector &todos, std::multimap> *depend_prior); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc index a3374303b2..3eda671792 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc @@ -30,7 +30,7 @@ #include "backend/optimizer/graph_kernel/tensor_promotion.h" #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h" #include "backend/optimizer/graph_kernel/graph_kernel_expander.h" -#include "backend/optimizer/graph_kernel/optimize_matmul.h" +#include "backend/optimizer/graph_kernel/cast_matmul_fusion.h" #include "backend/optimizer/graph_kernel/raise_reduction_precision.h" #include "backend/optimizer/graph_kernel/graph_kernel_cse.h" #include "backend/optimizer/graph_kernel/shape_ops_splitter.h" @@ -52,7 +52,7 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() { if (is_ascend) { // Remove redundant Cast(bias, fp16) for Matmul input - pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); // Reorder TransData-Cast to Cast-TransData pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.cc deleted file mode 100644 index 94106034d6..0000000000 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_matmul.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "backend/optimizer/graph_kernel/optimize_matmul.h" -#include -#include "backend/session/anf_runtime_algorithm.h" -#include "backend/kernel_compiler/common_utils.h" -#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" - -namespace mindspore { -namespace opt { -/* MatMul supports fp32 bias, so remove the redundant cast when cast only used by MatMul - * - * %0 = cast(bias_fp32, fp16) - * %1 = MatMul(A_fp16, B_fp16, %0) - * ------> - * %1 = MatMul(A_fp16, B_fp16, bias_fp32) - */ -bool OptimizeMatmul::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - auto mng = func_graph->manager(); - if (mng == nullptr) { - mng = Manage(func_graph, true); - func_graph->set_manager(mng); - } - auto changed = false; - auto nodes = TopoSort(func_graph->get_return()); - for (auto node : nodes) { - if (!IsPrimitiveCNode(node, prim::kPrimMatMul)) { - continue; - } - auto cnode = node->cast(); - if (cnode->size() != 4) { - continue; - } - auto cast_node = cnode->input(3); - if (!IsPrimitiveCNode(cast_node, prim::kPrimCast)) { - continue; - } - auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_node, 0); - auto cast_output_type = AnfAlgo::GetOutputDeviceDataType(cast_node, 0); - if (cast_input_type == kNumberTypeFloat32 && cast_output_type == kNumberTypeFloat16 && - mng->node_users()[cast_node].size() == 1) { - mng->Replace(cast_node, (cast_node->cast())->input(1)); - changed = true; - } - } - - return changed; -} -} // namespace opt -} // namespace mindspore diff --git a/tests/st/ops/graph_kernel/test_matmul_cast.py b/tests/st/ops/graph_kernel/test_cast_matmul_fusion.py similarity index 100% rename from tests/st/ops/graph_kernel/test_matmul_cast.py rename to tests/st/ops/graph_kernel/test_cast_matmul_fusion.py