forked from mindspore-Ecosystem/mindspore
cast_matmul_fusion: remove the redundant Cast before MatMul when the Cast cannot fuse forward
parent ae91575346
commit 56390330ac
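In short, the commit drops a Cast that converts an fp32 bias to fp16 for MatMul, but only when that Cast cannot fuse forward with its own producer. The following is a sketch for orientation only, not code from this commit: CanDropBiasCast is a hypothetical helper name, while IsFusibleOp and the CNode input layout (input 3 is the bias) are taken from the diff below, and the MindSpore headers used there are assumed to be available.

// Sketch only, not part of this commit.
bool CanDropBiasCast(const CNodePtr &matmul, const AnfNodePtr &cast) {
  // MatMul must carry a bias: inputs are [prim, A, B, bias], so size() == 4,
  // and the bias input must be the Cast under inspection.
  if (matmul->size() != 4 || matmul->input(3) != cast || !IsPrimitiveCNode(cast, prim::kPrimCast)) {
    return false;
  }
  // Only an fp32 -> fp16 Cast is redundant, because MatMul accepts an fp32 bias
  // (the device-type checks are elided here; see the pass below for the full version).
  // Remove the Cast only when it cannot fuse with its own input.
  return !IsFusibleOp(cast->cast<CNodePtr>()->input(1));
}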
akg
@@ -1 +1 @@
-Subproject commit 5f7635e9a70e2d9fc34bf9499b6dcf5bf208c8e8
+Subproject commit e7a391c51e66975d46bacf6425ae8f27e1675f85
@@ -39,20 +39,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-bool IsFusibleOp(const AnfNodePtr &node) {
-#if ENABLE_D
-  const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
-                                                         "LambNextMV", "LambUpdateWithLR"};
-  if (AnfAlgo::IsGraphKernel(node)) {
-    auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
-    if (fg_attr != nullptr) {
-      return graph_kernel_black_list.count(GetValue<std::string>(fg_attr)) == 0;
-    }
-  }
-#endif
-  return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node);
-}
-
 IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
   if (cur_node == node) {
     return FOLLOW;
@@ -0,0 +1,120 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/graph_kernel/cast_matmul_fusion.h"
+#include <tuple>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+// Check if leaf is used by root
+bool HasPath(const AnfNodePtr &leaf, const AnfNodePtr &root, const FuncGraphManagerPtr &mng) {
+  MS_EXCEPTION_IF_NULL(mng);
+  bool result = false;
+  auto IncludeUser = [&result, &root](const AnfNodePtr &node) {
+    if (node == root) {
+      result = true;
+    }
+    return result ? EXCLUDE : FOLLOW;
+  };
+  static_cast<void>(DeepLinkedGraphSearch(leaf, IncludeUser));
+  return result;
+}
+}  // namespace
+
+/* MatMul supports an fp32 bias, so remove the redundant Cast if the Cast cannot fuse forward.
+ * Case 1: the Cast is only used by MatMul
+ *
+ * bias_fp32 = Depend(bias_fp32, u)
+ * %0 = Cast(bias_fp32, fp16)
+ * %1 = MatMul(A_fp16, B_fp16, %0)
+ * ------>
+ * bias_fp32 = Depend(bias_fp32, u)
+ * %1 = MatMul(A_fp16, B_fp16, bias_fp32)
+ *
+ * Case 2: the Cast is used by MatMul and UpdateState
+ *
+ * bias_fp32 = Load(p, status)
+ * %0 = Cast(bias_fp32, fp16)
+ * %1 = MatMul(A_fp16, B_fp16, %0)
+ * %2 = UpdateState(status, %0)
+ * ------>
+ * bias_fp32 = Load(p, status)
+ * %1 = MatMul(A_fp16, B_fp16, bias_fp32)
+ * %2 = UpdateState(status, %1)
+ */
+bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto mng = func_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(func_graph, true);
+    func_graph->set_manager(mng);
+  }
+  auto changed = false;
+  auto nodes = TopoSort(func_graph->get_return());
+  for (auto node : nodes) {
+    if (!IsPrimitiveCNode(node, prim::kPrimMatMul)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode->size() != 4) {
+      continue;
+    }
+    auto cast_node = cnode->input(3);
+    if (!IsPrimitiveCNode(cast_node, prim::kPrimCast)) {
+      continue;
+    }
+    auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_node, 0);
+    auto cast_output_type = AnfAlgo::GetOutputDeviceDataType(cast_node, 0);
+    if (cast_input_type != kNumberTypeFloat32 || cast_output_type != kNumberTypeFloat16) {
+      continue;
+    }
+    // Skip if the Cast can fuse with its own input
+    if (IsFusibleOp((cast_node->cast<CNodePtr>())->input(1))) {
+      continue;
+    }
+
+    auto user_index_set = mng->node_users()[cast_node];
+    // Case 1: the Cast is only used by MatMul
+    if (user_index_set.size() == 1) {
+      mng->Replace(cast_node, (cast_node->cast<CNodePtr>())->input(1));
+      changed = true;
+      continue;
+    }
+
+    // Case 2: the Cast is used by MatMul and UpdateState
+    if (user_index_set.size() > 2) {
+      continue;
+    }
+    for (auto user_index : user_index_set) {
+      // Skip when an UpdateState -> ... -> MatMul path exists
+      if (IsPrimitiveCNode(user_index.first, prim::kPrimUpdateState) && !HasPath(user_index.first, node, mng)) {
+        auto update_state = (user_index.first)->cast<CNodePtr>();
+        update_state->set_input(2, node);
+        cnode->set_input(4, (cast_node->cast<CNodePtr>())->input(1));
+        mng->RemoveRoots();
+        mng->KeepRoots({func_graph});
+        changed = true;
+      }
+    }
+  }
+
+  return changed;
+}
+}  // namespace opt
+}  // namespace mindspore
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_
 
 #include <map>
 #include <memory>
@@ -24,13 +24,13 @@
 
 namespace mindspore {
 namespace opt {
-class OptimizeMatmul : public Pass {
+class CastMatmulFusion : public Pass {
  public:
-  OptimizeMatmul() : Pass("optimize_matmul") {}
-  ~OptimizeMatmul() override = default;
+  CastMatmulFusion() : Pass("cast_matmul_fusion") {}
+  ~CastMatmulFusion() override = default;
   bool Run(const FuncGraphPtr &graph) override;
 };
-using OptimizeMatmulPtr = std::shared_ptr<OptimizeMatmul>;
+using OptimizeMatmulPtr = std::shared_ptr<CastMatmulFusion>;
 }  // namespace opt
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_OPTIMIZE_MATMUL_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CAST_MATMUL_FUSION_H_
@@ -627,6 +627,20 @@ bool IsBasicFuseOp(const AnfNodePtr &node) {
                      [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
 }
 
+bool IsFusibleOp(const AnfNodePtr &node) {
+#if ENABLE_D
+  const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
+                                                         "LambNextMV", "LambUpdateWithLR"};
+  if (AnfAlgo::IsGraphKernel(node)) {
+    auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
+    if (fg_attr != nullptr) {
+      return graph_kernel_black_list.count(GetValue<std::string>(fg_attr)) == 0;
+    }
+  }
+#endif
+  return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node);
+}
+
 void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
@@ -73,6 +73,7 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
 std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = "");
 std::vector<PrimitivePtr> GetFusibleOpList();
 bool IsBasicFuseOp(const AnfNodePtr &node);
+bool IsFusibleOp(const AnfNodePtr &node);
 void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
 void InitDependPrior(const std::vector<AnfNodePtr> &todos,
                      std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior);
@@ -30,7 +30,7 @@
 #include "backend/optimizer/graph_kernel/tensor_promotion.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
-#include "backend/optimizer/graph_kernel/optimize_matmul.h"
+#include "backend/optimizer/graph_kernel/cast_matmul_fusion.h"
 #include "backend/optimizer/graph_kernel/raise_reduction_precision.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
 #include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
@@ -52,7 +52,7 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() {
 
   if (is_ascend) {
     // Remove redundant Cast(bias, fp16) for Matmul input
-    pm->AddPass(std::make_shared<OptimizeMatmul>());
+    pm->AddPass(std::make_shared<CastMatmulFusion>());
 
     // Reorder TransData-Cast to Cast-TransData
     pm->AddPass(std::make_shared<ReorderOps>());
@@ -1,64 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "backend/optimizer/graph_kernel/optimize_matmul.h"
-#include <tuple>
-#include "backend/session/anf_runtime_algorithm.h"
-#include "backend/kernel_compiler/common_utils.h"
-#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
-
-namespace mindspore {
-namespace opt {
-/* MatMul supports fp32 bias, so remove the redundant cast when cast only used by MatMul
- *
- * %0 = cast(bias_fp32, fp16)
- * %1 = MatMul(A_fp16, B_fp16, %0)
- * ------>
- * %1 = MatMul(A_fp16, B_fp16, bias_fp32)
- */
-bool OptimizeMatmul::Run(const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  auto mng = func_graph->manager();
-  if (mng == nullptr) {
-    mng = Manage(func_graph, true);
-    func_graph->set_manager(mng);
-  }
-  auto changed = false;
-  auto nodes = TopoSort(func_graph->get_return());
-  for (auto node : nodes) {
-    if (!IsPrimitiveCNode(node, prim::kPrimMatMul)) {
-      continue;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    if (cnode->size() != 4) {
-      continue;
-    }
-    auto cast_node = cnode->input(3);
-    if (!IsPrimitiveCNode(cast_node, prim::kPrimCast)) {
-      continue;
-    }
-    auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_node, 0);
-    auto cast_output_type = AnfAlgo::GetOutputDeviceDataType(cast_node, 0);
-    if (cast_input_type == kNumberTypeFloat32 && cast_output_type == kNumberTypeFloat16 &&
-        mng->node_users()[cast_node].size() == 1) {
-      mng->Replace(cast_node, (cast_node->cast<CNodePtr>())->input(1));
-      changed = true;
-    }
-  }
-
-  return changed;
-}
-}  // namespace opt
-}  // namespace mindspore
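For orientation only (not part of the commit): CastMatmulFusion follows the graph-kernel Pass contract declared in cast_matmul_fusion.h above, so a caller could also run it directly instead of through the pass manager. In this sketch, func_graph is an assumed, already-built and managed FuncGraphPtr.

// Hypothetical driver sketch; assumes <memory> and the MindSpore headers from the diff above.
auto fusion_pass = std::make_shared<mindspore::opt::CastMatmulFusion>();
bool changed = fusion_pass->Run(func_graph);  // true if at least one redundant Cast was removed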