From 94dda0c7c75e5f6a9a7278c734bd650b7ab05ca0 Mon Sep 17 00:00:00 2001 From: tronzhang Date: Sun, 11 Jul 2021 16:12:37 +0800 Subject: [PATCH] Speed up BERT performance on Ascend for graph kernel
--- .../graph_kernel/model/graph_split.py | 60 ++++++++++------ .../ascend/ascend_backend_optimization.cc | 4 +- .../matmul_eltwise_fusion_pass.cc | 9 +++ .../graph_kernel/graph_kernel_cluster.cc | 72 +++++++++---------- .../graph_kernel/graph_kernel_expander.cc | 41 ++++++----- .../graph_kernel/graph_kernel_optimization.cc | 2 +- model_zoo/official/nlp/bert/run_pretrain.py | 3 + 7 files changed, 107 insertions(+), 84 deletions(-)
diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py
index ad12de9c16a..46865422012 100644
--- a/mindspore/_extends/graph_kernel/model/graph_split.py
+++ b/mindspore/_extends/graph_kernel/model/graph_split.py
@@ -24,6 +24,7 @@ class GraphSplitByPattern: """Graph splitter""" class ReachTable: """Reachable table""" + def __init__(self, size): self.map = [] self.alive = set(range(size))
@@ -61,6 +62,7 @@ class GraphSplitByPattern: class StitchInfo: """StitchInfo""" + def __init__(self): self.stitch_ops = set() self.stitch_atomic_ops = set()
@@ -190,7 +192,7 @@ class GraphSplitByPattern: def cp_ops(self, area): """copy recompute_ops in area to ops, self is area's user""" tail_tensor = area.recompute_ops[-1].output - #copy tensors, all copied are Tensor.PARA_NONE + # copy tensors, all copied are Tensor.PARA_NONE tensor_map = {} if area.recompute_ops[0].inputs: tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0]
@@ -198,7 +200,7 @@ class GraphSplitByPattern: orig_tensor = op.output cp_tensor = Tensor(orig_tensor.name, orig_tensor.shape, orig_tensor.dtype, orig_tensor.data_format) tensor_map[orig_tensor] = cp_tensor - #copy ops + # copy ops cp_ops = [] for op in area.recompute_ops: inputs = [tensor_map[op.inputs[0]]] if op.inputs else []
@@ -206,14 +208,14 @@ class GraphSplitByPattern: cp_op.all_inputs = cp_op.inputs cp_ops.append(cp_op) area.ori_op_map[cp_op] = op - #connect copied ops + # connect copied ops for op in self.ops: if tail_tensor in op.inputs: op.inputs.remove(tail_tensor) op.inputs.append(tensor_map[tail_tensor]) tail_tensor.to_ops.remove(op) tensor_map[tail_tensor].to_ops.append(op) - #fill cp_ops in self.recompute_area + # fill cp_ops in self.recompute_area cp_dom_op = None for cp, ori in area.ori_op_map.items(): if ori == area.dom_op():
@@ -402,26 +404,26 @@ class GraphSplitByPattern: def set_recompute(self, dom_area, ops, user_area): """set the recompute area and connect with other areas""" self.recom_area.recompute_ops.extend(ops) - #recom_area: set dom_op and correct ops length + # recom_area: set dom_op and correct ops length patterns = [PrimLib.iter_type(op) for op in ops] self.recom_area.pattern = max(patterns) for i, pat in enumerate(patterns): if pat == self.recom_area.pattern: self.recom_area.ops = [ops[i]] * len(ops) break - #disconnect dom_area and user_area + # disconnect dom_area and user_area self.dom_user_r = dom_area.out_relations[user_area] dom_area.out_relations.pop(user_area) user_area.in_relations.pop(dom_area) - #connect recom_area and user_area + # connect recom_area and user_area user_area.in_relations[self.recom_area] = self.dom_user_r self.recom_area.out_relations[user_area] = self.dom_user_r - #connect recom_pre and recom_area + # connect recom_pre and recom_area self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs and 
ops[0].inputs[0].op else None if self.recom_pre is not None: self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre] self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre] - #set related areas + # set related areas self.recom_user = user_area self.recom_dom = dom_area self.recom_res = False
@@ -441,7 +443,6 @@ class GraphSplitByPattern: self.orig_op_map.update(self.recom_area.ori_op_map) self.recom_area.ori_op_map.clear() - def to_subgraph(self, dom): """Transform area to subgraphs""" ids = self.index_op()
@@ -461,14 +462,14 @@ class GraphSplitByPattern: region_ops.append(op) return False, None, weight, True if op.inputs[0] in inputs and len(op.inputs) == 1 and \ - PrimLib.iter_type(op) <= PrimLib.BROADCAST: + PrimLib.iter_type(op) <= PrimLib.BROADCAST: region_ops.append(op) return False, None, weight, True - #region fails to grow + # region fails to grow MAX_WEIGHT = 20 if weight > MAX_WEIGHT or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST: return False, None, weight, False - #region grows successfully + # region grows successfully weight = weight + 1 region_ops.append(op) return True, op.inputs[0].op, weight, False
@@ -544,7 +545,7 @@ class GraphSplitByPattern: self.clear_recompute() if self.recom_res: recompute_suc = True - #Copy region at most once for this dom + # Copy region at most once for this dom dom_changed = True break if dom_changed:
@@ -555,6 +556,7 @@ class GraphSplitByPattern: while do_recompute_fuse(): self.pattern_fuse() + use_poly_reduce = True
@@ -808,8 +810,8 @@ class GraphSplitGpu(GraphSplitByPattern): def _fuse_once(fuse_func): if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \ - fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \ - fuse_func(_broadcast_width): + fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \ + fuse_func(_broadcast_width): return if use_poly_reduce: if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
@@ -830,8 +832,17 @@ class GraphSplitAscend(GraphSplitByPattern): REDUCE_FUSE_DEPTH = 10 def get_default_mode(self, op): - if op.prim == "MatMul" or op.prim == "BatchMatMul": - return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" else self.Area.MODE_BASIC + """Get default mode for op""" + def _dtype_same(tensors): + dtype = tensors[0].dtype + for tensor_ in tensors: + if tensor_.dtype != dtype: + return False + return True + + if op.prim == "MatMul": + if op.inputs[0].dtype == "float16" and not _dtype_same(op.inputs): + return self.Area.MODE_COMPOSITE if op.prim in ("Tile", "BroadcastTo", "ExpandDims"): return self.Area.MODE_COMPOSITE return self.Area.MODE_BASIC
@@ -911,7 +922,7 @@ class GraphSplitAscend(GraphSplitByPattern): if len(a.ops) > self.REDUCE_FUSE_DEPTH: return True if r == PrimLib.BROADCAST and _likely_multicore(dom) and \ - (dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH): + (dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH): return True return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE
@@ -937,7 +948,10 @@ class GraphSplitAscend(GraphSplitByPattern): return None fused = [] for a, _ in dom.out_relations.items(): - if a.pattern == PrimLib.ELEMWISE and a.check_acyclic(dom): + if (((a.dom_op().prim == "AddN" or a.dom_op().prim == "Add" or a.dom_op().prim == "Cast") + and dom.dom_op().prim == "MatMul") + or (a.pattern == PrimLib.ELEMWISE and dom.dom_op().prim == "BatchMatMul")) \ 
+ and a.check_acyclic(dom): fused.append(a) return fused, False @@ -1018,9 +1032,9 @@ class GraphSplitAscend(GraphSplitByPattern): def _fuse_once(fuse_func): if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \ - fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \ - fuse_func(_broadcast_width) or fuse_func(_matmul_depth) or fuse_func(_reduce_output) or \ - fuse_func(_transdata): + fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \ + fuse_func(_broadcast_width) or fuse_func(_matmul_depth) or fuse_func(_reduce_output) or \ + fuse_func(_transdata): pass if fuse_func is None: diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index 484d80b3f77..6b457d2bf53 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -454,9 +454,7 @@ void AscendBackendUBFusionOptimization(const std::shared_ptrAddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - } + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc index c72b71792b8..792937bf7d3 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc @@ -21,6 +21,7 @@ #include "debug/anf_ir_dump.h" #include "backend/session/anf_runtime_algorithm.h" #include "base/core_ops.h" +#include "utils/context/graph_kernel_flags.h" #include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" @@ -56,6 +57,14 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap continue; } auto cnode = node->cast(); + if (context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && + AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimAddN)) { + continue; + } + } + MS_EXCEPTION_IF_NULL(cnode); if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc index d594752b11a..0204c9b6669 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_cluster.cc @@ -40,63 +40,63 @@ namespace { std::vector GetClusterableOpList() { std::vector clusterable_ops = { prim::kPrimAbs, - prim::kPrimRound, - prim::kPrimNeg, - prim::kPrimExp, prim::kPrimAdd, prim::kPrimCast, - prim::kPrimMul, - prim::kPrimMinimum, - prim::kPrimMaximum, + 
prim::kPrimEqual, + prim::kPrimExp, + prim::kPrimInplaceAssign, prim::kPrimLog, + prim::kPrimMaximum, + prim::kPrimMinimum, + prim::kPrimMul, + prim::kPrimNeg, prim::kPrimPow, - prim::kPrimSub, + prim::kPrimRealDiv, + prim::kPrimReciprocal, + prim::kPrimReduceSum, + prim::kPrimReshape, + prim::kPrimRound, prim::kPrimRsqrt, prim::kPrimSqrt, - prim::kPrimReciprocal, + prim::kPrimSub, prim::kPrimTanh, - prim::kPrimReshape, prim::kPrimTranspose, - prim::kPrimRealDiv, - prim::kPrimReduceSum, - prim::kPrimEqual, - prim::kPrimAssign, - prim::kPrimInplaceAssign, #if ENABLE_D prim::kPrimMatMul, prim::KPrimTransData, prim::kPrimBatchMatMul, #elif ENABLE_GPU - prim::kPrimSin, - prim::kPrimCos, - prim::kPrimAsin, prim::kPrimACos, - prim::kPrimSign, - prim::kPrimReduceMax, - prim::kPrimReduceMin, - prim::kPrimGreater, - prim::kPrimLess, - prim::kPrimGreaterEqual, - prim::kPrimLessEqual, - prim::kPrimSelect, + prim::kPrimAcosh, + prim::kPrimAsin, + prim::kPrimAsinh, + prim::kPrimAssign, prim::kPrimAtan, prim::kPrimAtan2, - prim::kPrimExpm1, - prim::kPrimAsinh, - prim::kPrimAcosh, + prim::kPrimCos, prim::kPrimDiv, - prim::kPrimFloorDiv, - prim::kPrimMod, - prim::kPrimFloor, - prim::kPrimFloorMod, prim::kPrimErf, - prim::kPrimNotEqual, + prim::kPrimExpm1, + prim::kPrimFloor, + prim::kPrimFloorDiv, + prim::kPrimFloorMod, + prim::kPrimGreater, + prim::kPrimGreaterEqual, + prim::kPrimIsFinite, + prim::kPrimIsInf, + prim::kPrimIsNan, + prim::kPrimLess, + prim::kPrimLessEqual, prim::kPrimLogicalAnd, prim::kPrimLogicalOr, prim::kPrimLogicalNot, - prim::kPrimIsNan, - prim::kPrimIsInf, - prim::kPrimIsFinite, + prim::kPrimMod, + prim::kPrimNotEqual, + prim::kPrimReduceMax, + prim::kPrimReduceMin, + prim::kPrimSelect, + prim::kPrimSign, + prim::kPrimSin, #endif }; const auto &flags = context::GraphKernelFlags::GetInstance(); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc index d7eac8960da..e2e41bbe687 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc @@ -46,44 +46,43 @@ constexpr size_t kLambWeightInputIdx = 4; std::vector GetExpandOps() { std::vector expand_ops = { prim::kPrimAddN, - prim::kPrimSquare, - prim::kPrimGeLUGrad, prim::kPrimAssignAdd, - prim::kPrimLayerNorm, - prim::kPrimLayerNormGrad, - prim::kPrimExpandDims, - prim::kPrimBiasAddGrad, - prim::kPrimGeLU, - prim::kPrimSoftmax, - prim::kPrimLogSoftmax, - prim::kPrimLogSoftmaxGrad, - prim::kPrimTile, - prim::kPrimMatMul, - prim::kPrimBatchMatMul, prim::kPrimErfc, + prim::kPrimExpandDims, + prim::kPrimGeLU, + prim::kPrimGeLUGrad, + prim::kPrimSquare, + prim::kPrimTile, #if ENABLE_D - prim::kPrimSqrtGrad, - prim::kPrimClipByNormNoDivSum, prim::kLambApplyOptimizerAssign, prim::kLambApplyWeightAssign, + prim::kPrimClipByNormNoDivSum, + prim::kPrimSqrtGrad, prim::kSoftmaxGradExt, - prim::kSquareSumV1, prim::kFusedMulAdd, #elif ENABLE_GPU + prim::kPrimBatchMatMul, prim::kPrimBiasAdd, - prim::kPrimFusedAdam, - prim::kPrimFusedAdamWeightDecay, - prim::kPrimReduceMean, - prim::kPrimMaximumGrad, - prim::kPrimMinimumGrad, + prim::kPrimBiasAddGrad, prim::kPrimDropout, prim::kPrimDropoutGrad, + prim::kPrimFusedAdam, + prim::kPrimFusedAdamWeightDecay, + prim::kPrimMaximumGrad, + prim::kPrimMinimumGrad, + prim::kPrimLayerNorm, + prim::kPrimLayerNormGrad, + prim::kPrimLogSoftmax, + prim::kPrimLogSoftmaxGrad, + prim::kPrimMatMul, + 
prim::kPrimReduceMean, prim::kPrimRelu, prim::kPrimReluGrad, prim::kPrimSigmoid, prim::kPrimSigmoidGrad, prim::kPrimSigmoidCrossEntropyWithLogits, prim::kPrimSigmoidCrossEntropyWithLogitsGrad, + prim::kPrimSoftmax, prim::kPrimSoftmaxCrossEntropyWithLogits, prim::kPrimSquaredDifference, prim::kPrimSqueeze, diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc index 5a1e37bcc4e..9ace0cb9a6b 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc @@ -61,7 +61,7 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() const { pm->AddPass(std::make_shared("cse1"), OptLevel_1); // Change Assign(p, a, U) to Assign(Depend(p, U), a) - pm->AddPass(std::make_shared(), OptLevel_1); + pm->AddPass(std::make_shared(), OptLevel_1, is_gpu); // Spread the MakeTuple input of UpdateState pm->AddPass(std::make_shared(), OptLevel_1); diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py index 6e1e0d9018d..104a05cec6e 100644 --- a/model_zoo/official/nlp/bert/run_pretrain.py +++ b/model_zoo/official/nlp/bert/run_pretrain.py @@ -62,6 +62,9 @@ def _set_bert_all_reduce_split(): context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421]) else: context.set_auto_parallel_context(all_reduce_fusion_config=[38, 93, 148, 203, 258, 313, 368, 397]) + if device_target == 'Ascend' and enable_graph_kernel and device_num == 8: + context.set_auto_parallel_context(all_reduce_fusion_config=[ + 0, 1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 35, 40, 50, 70, 93, 148, 203, 258, 313, 368, 397]) def _get_optimizer(args_opt, network):
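
Note (not part of the patch): the run_pretrain.py hunk above only takes effect when graph kernel fusion is enabled and training runs on 8 Ascend devices. Below is a minimal Python sketch of the context setup that would exercise the new all_reduce_fusion_config branch. It assumes the usual MindSpore distributed environment (rank table, device/rank IDs) is already exported; the flag names simply mirror the script's existing --enable_graph_kernel style arguments and are illustrative.

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_group_size

# Graph kernel fusion requires graph mode; enable it on Ascend so the
# clustering/expansion passes changed in this patch are active.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                    enable_graph_kernel=True)

# Data-parallel setup; init() assumes the HCCL environment
# (RANK_TABLE_FILE, DEVICE_ID, RANK_ID) is already configured.
init()
device_num = get_group_size()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True, device_num=device_num)

if device_num == 8:
    # Same split points the patch adds for Ascend + graph kernel + 8 devices.
    context.set_auto_parallel_context(all_reduce_fusion_config=[
        0, 1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 35, 40, 50,
        70, 93, 148, 203, 258, 313, 368, 397])

The finer-grained early split points reproduce the values added by the patch; the intent appears to be starting the first all-reduce operations while later graph-kernel-fused backward ops are still executing, but the exact indices should be treated as tuned values from the patch rather than something derived here.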