!14553 [GraphKernel] refine cast matmul fusion

From: @lingyunli63
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
mindspore-ci-bot 2021-04-06 11:07:55 +08:00 committed by Gitee
commit f324a9a760
8 changed files with 144 additions and 87 deletions

View File

@ -39,20 +39,6 @@
namespace mindspore {
namespace opt {
namespace {
bool IsFusibleOp(const AnfNodePtr &node) {
#if ENABLE_D
  const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
                                                         "LambNextMV", "LambUpdateWithLR"};
  if (AnfAlgo::IsGraphKernel(node)) {
    auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
    if (fg_attr != nullptr) {
      return graph_kernel_black_list.count(GetValue<std::string>(fg_attr)) == 0;
    }
  }
#endif
  return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node);
}

IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
  if (cur_node == node) {
    return FOLLOW;

View File

@ -0,0 +1,120 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/cast_matmul_fusion.h"
#include <tuple>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
namespace mindspore {
namespace opt {
namespace {
// Check if leaf is used by root
bool HasPath(const AnfNodePtr &leaf, const AnfNodePtr &root, const FuncGraphManagerPtr &mng) {
  MS_EXCEPTION_IF_NULL(mng);
  bool result = false;
  auto IncludeUser = [&result, &root](const AnfNodePtr &node) {
    if (node == root) {
      result = true;
    }
    return result ? EXCLUDE : FOLLOW;
  };
  static_cast<void>(DeepLinkedGraphSearch(leaf, IncludeUser));
  return result;
}
} // namespace
/* MatMul supports fp32 bias, so remove the redundant cast if the cast cannot be fused forward.
* In both cases the cast is removed only when its input op is not a fusible op.
*
* case 1: cast is only used by MatMul
*
* bias_fp32 = depend(bias_fp32, u)
* %0 = cast(bias_fp32, fp16)
* %1 = MatMul(A_fp16, B_fp16, %0)
* ------>
* bias_fp32 = depend(bias_fp32, u)
* %1 = MatMul(A_fp16, B_fp16, bias_fp32)
*
* case 2: cast is used by MatMul and UpdateState
*
* bias_fp32 = load(p, status)
* %0 = cast(bias_fp32, fp16)
* %1 = MatMul(A_fp16, B_fp16, %0)
* %2 = UpdateState(status, %0)
* ------>
* bias_fp32 = load(p, status)
* %1 = MatMul(A_fp16, B_fp16, bias_fp32)
* %2 = UpdateState(status, %1)
*/
bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto mng = func_graph->manager();
  if (mng == nullptr) {
    mng = Manage(func_graph, true);
    func_graph->set_manager(mng);
  }
  auto changed = false;
  auto nodes = TopoSort(func_graph->get_return());
  for (auto node : nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimMatMul)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode->size() != 4) {
      continue;
    }
    auto cast_node = cnode->input(3);
    if (!IsPrimitiveCNode(cast_node, prim::kPrimCast)) {
      continue;
    }
    auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_node, 0);
    auto cast_output_type = AnfAlgo::GetOutputDeviceDataType(cast_node, 0);
    if (cast_input_type != kNumberTypeFloat32 || cast_output_type != kNumberTypeFloat16) {
      continue;
    }
    // Cast cannot fuse with its input
    if (IsFusibleOp((cast_node->cast<CNodePtr>())->input(1))) {
      continue;
    }
    auto user_index_set = mng->node_users()[cast_node];
    // Case 1: Cast is only used by MatMul
    if (user_index_set.size() == 1) {
      mng->Replace(cast_node, (cast_node->cast<CNodePtr>())->input(1));
      changed = true;
      continue;
    }
    // Case 2: Cast is used by MatMul and UpdateState
    if (user_index_set.size() > 2) {
      continue;
    }
    for (auto user_index : user_index_set) {
      // Exclude the case where an UpdateState -> ... -> MatMul path already exists
      if (IsPrimitiveCNode(user_index.first, prim::kPrimUpdateState) && !HasPath(user_index.first, node, mng)) {
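        // Retarget the UpdateState from the cast to the MatMul so the side-effect ordering
        // it carries is preserved once the cast is removed (case 2 in the comment above).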
        auto update_state = (user_index.first)->cast<CNodePtr>();
        update_state->set_input(2, node);
        cnode->set_input(3, (cast_node->cast<CNodePtr>())->input(1));
        mng->RemoveRoots();
        mng->KeepRoots({func_graph});
        changed = true;
      }
    }
  }
  return changed;
}
} // namespace opt
} // namespace mindspore
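
For readers following the case-2 logic above: an UpdateState user of the cast may only be retargeted when it does not already reach the MatMul through some other path, which is what HasPath checks. Below is a minimal standalone sketch of that early-exit reachability test, written against a toy adjacency-list graph rather than the AnfNode IR; the names (UserGraph, Reachable) are illustrative only and are not MindSpore APIs.

#include <unordered_map>
#include <unordered_set>
#include <vector>

// Toy user graph: node id -> ids of the nodes that consume it.
using UserGraph = std::unordered_map<int, std::vector<int>>;

// True if `root` is reachable from `leaf`; stops as soon as it is found,
// mirroring the EXCLUDE early exit inside HasPath.
bool Reachable(const UserGraph &users, int leaf, int root) {
  std::unordered_set<int> visited;
  std::vector<int> stack{leaf};
  while (!stack.empty()) {
    int cur = stack.back();
    stack.pop_back();
    if (cur == root) {
      return true;
    }
    if (!visited.insert(cur).second) {
      continue;  // already expanded
    }
    auto it = users.find(cur);
    if (it == users.end()) {
      continue;
    }
    for (int next : it->second) {
      stack.push_back(next);
    }
  }
  return false;
}

In the pass itself the same idea runs over AnfNode users via DeepLinkedGraphSearch; only when no such path exists is the UpdateState rewired from the cast to the MatMul.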

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_
#include <map>
#include <memory>
@ -24,13 +24,13 @@
namespace mindspore {
namespace opt {
class OptimizeMatmul : public Pass {
class CastMatmulFusion : public Pass {
 public:
  OptimizeMatmul() : Pass("optimize_matmul") {}
  ~OptimizeMatmul() override = default;
  CastMatmulFusion() : Pass("cast_matmul_fusion") {}
  ~CastMatmulFusion() override = default;
  bool Run(const FuncGraphPtr &graph) override;
};
using OptimizeMatmulPtr = std::shared_ptr<OptimizeMatmul>;
using OptimizeMatmulPtr = std::shared_ptr<CastMatmulFusion>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_

View File

@ -627,6 +627,20 @@ bool IsBasicFuseOp(const AnfNodePtr &node) {
                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}

bool IsFusibleOp(const AnfNodePtr &node) {
#if ENABLE_D
  const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
                                                         "LambNextMV", "LambUpdateWithLR"};
  if (AnfAlgo::IsGraphKernel(node)) {
    auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
    if (fg_attr != nullptr) {
      return graph_kernel_black_list.count(GetValue<std::string>(fg_attr)) == 0;
    }
  }
#endif
  return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node);
}

void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

View File

@ -73,6 +73,7 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = "");
std::vector<PrimitivePtr> GetFusibleOpList();
bool IsBasicFuseOp(const AnfNodePtr &node);
bool IsFusibleOp(const AnfNodePtr &node);
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
                     std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior);

View File

@ -30,7 +30,7 @@
#include "backend/optimizer/graph_kernel/tensor_promotion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
#include "backend/optimizer/graph_kernel/optimize_matmul.h"
#include "backend/optimizer/graph_kernel/cast_matmul_fusion.h"
#include "backend/optimizer/graph_kernel/raise_reduction_precision.h"
#include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
@ -52,7 +52,7 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() {
  if (is_ascend) {
    // Remove redundant Cast(bias, fp16) for Matmul input
    pm->AddPass(std::make_shared<OptimizeMatmul>());
    pm->AddPass(std::make_shared<CastMatmulFusion>());
    // Reorder TransData-Cast to Cast-TransData
    pm->AddPass(std::make_shared<ReorderOps>());

View File

@ -1,64 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/optimize_matmul.h"
#include <tuple>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
namespace mindspore {
namespace opt {
/* MatMul supports fp32 bias, so remove the redundant cast when the cast is only used by MatMul
*
* %0 = cast(bias_fp32, fp16)
* %1 = MatMul(A_fp16, B_fp16, %0)
* ------>
* %1 = MatMul(A_fp16, B_fp16, bias_fp32)
*/
bool OptimizeMatmul::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto mng = func_graph->manager();
  if (mng == nullptr) {
    mng = Manage(func_graph, true);
    func_graph->set_manager(mng);
  }
  auto changed = false;
  auto nodes = TopoSort(func_graph->get_return());
  for (auto node : nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimMatMul)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode->size() != 4) {
      continue;
    }
    auto cast_node = cnode->input(3);
    if (!IsPrimitiveCNode(cast_node, prim::kPrimCast)) {
      continue;
    }
    auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_node, 0);
    auto cast_output_type = AnfAlgo::GetOutputDeviceDataType(cast_node, 0);
    if (cast_input_type == kNumberTypeFloat32 && cast_output_type == kNumberTypeFloat16 &&
        mng->node_users()[cast_node].size() == 1) {
      mng->Replace(cast_node, (cast_node->cast<CNodePtr>())->input(1));
      changed = true;
    }
  }
  return changed;
}
} // namespace opt
} // namespace mindspore