speed up bert performance in ascend for graph kernel

This commit is contained in:
tronzhang 2021-07-11 16:12:37 +08:00
parent 6339d5a1df
commit 94dda0c7c7
7 changed files with 107 additions and 84 deletions

View File

@ -24,6 +24,7 @@ class GraphSplitByPattern:
"""Graph splitter"""
class ReachTable:
"""Reachable table"""
def __init__(self, size):
self.map = []
self.alive = set(range(size))
@ -61,6 +62,7 @@ class GraphSplitByPattern:
class StitchInfo:
"""StitchInfo"""
def __init__(self):
self.stitch_ops = set()
self.stitch_atomic_ops = set()
@ -190,7 +192,7 @@ class GraphSplitByPattern:
def cp_ops(self, area):
"""copy recompute_ops in area to ops, self is area's user"""
tail_tensor = area.recompute_ops[-1].output
#copy tensors, all copied are Tensor.PARA_NONE
# copy tensors, all copied are Tensor.PARA_NONE
tensor_map = {}
if area.recompute_ops[0].inputs:
tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0]
@ -198,7 +200,7 @@ class GraphSplitByPattern:
orig_tensor = op.output
cp_tensor = Tensor(orig_tensor.name, orig_tensor.shape, orig_tensor.dtype, orig_tensor.data_format)
tensor_map[orig_tensor] = cp_tensor
#copy ops
# copy ops
cp_ops = []
for op in area.recompute_ops:
inputs = [tensor_map[op.inputs[0]]] if op.inputs else []
@ -206,14 +208,14 @@ class GraphSplitByPattern:
cp_op.all_inputs = cp_op.inputs
cp_ops.append(cp_op)
area.ori_op_map[cp_op] = op
#connect copied ops
# connect copied ops
for op in self.ops:
if tail_tensor in op.inputs:
op.inputs.remove(tail_tensor)
op.inputs.append(tensor_map[tail_tensor])
tail_tensor.to_ops.remove(op)
tensor_map[tail_tensor].to_ops.append(op)
#fill cp_ops in self.recompute_area
# fill cp_ops in self.recompute_area
cp_dom_op = None
for cp, ori in area.ori_op_map.items():
if ori == area.dom_op():
@ -402,26 +404,26 @@ class GraphSplitByPattern:
def set_recompute(self, dom_area, ops, user_area):
"""set the recompute area and connect with other areas"""
self.recom_area.recompute_ops.extend(ops)
#recom_area: set dom_op and correct ops length
# recom_area: set dom_op and correct ops length
patterns = [PrimLib.iter_type(op) for op in ops]
self.recom_area.pattern = max(patterns)
for i, pat in enumerate(patterns):
if pat == self.recom_area.pattern:
self.recom_area.ops = [ops[i]] * len(ops)
break
#disconnect dom_area and user_area
# disconnect dom_area and user_area
self.dom_user_r = dom_area.out_relations[user_area]
dom_area.out_relations.pop(user_area)
user_area.in_relations.pop(dom_area)
#connect recom_area and user_area
# connect recom_area and user_area
user_area.in_relations[self.recom_area] = self.dom_user_r
self.recom_area.out_relations[user_area] = self.dom_user_r
#connect recom_pre and recom_area
# connect recom_pre and recom_area
self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs and ops[0].inputs[0].op else None
if self.recom_pre is not None:
self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre]
self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre]
#set related areas
# set related areas
self.recom_user = user_area
self.recom_dom = dom_area
self.recom_res = False
@ -441,7 +443,6 @@ class GraphSplitByPattern:
self.orig_op_map.update(self.recom_area.ori_op_map)
self.recom_area.ori_op_map.clear()
def to_subgraph(self, dom):
"""Transform area to subgraphs"""
ids = self.index_op()
@ -461,14 +462,14 @@ class GraphSplitByPattern:
region_ops.append(op)
return False, None, weight, True
if op.inputs[0] in inputs and len(op.inputs) == 1 and \
PrimLib.iter_type(op) <= PrimLib.BROADCAST:
PrimLib.iter_type(op) <= PrimLib.BROADCAST:
region_ops.append(op)
return False, None, weight, True
#region fails to grow
# region fails to grow
MAX_WEIGHT = 20
if weight > MAX_WEIGHT or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST:
return False, None, weight, False
#region grows successfully
# region grows successfully
weight = weight + 1
region_ops.append(op)
return True, op.inputs[0].op, weight, False
@ -544,7 +545,7 @@ class GraphSplitByPattern:
self.clear_recompute()
if self.recom_res:
recompute_suc = True
#Copy region at most once for this dom
# Copy region at most once for this dom
dom_changed = True
break
if dom_changed:
@ -555,6 +556,7 @@ class GraphSplitByPattern:
while do_recompute_fuse():
self.pattern_fuse()
use_poly_reduce = True
@ -808,8 +810,8 @@ class GraphSplitGpu(GraphSplitByPattern):
def _fuse_once(fuse_func):
if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
fuse_func(_broadcast_width):
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
fuse_func(_broadcast_width):
return
if use_poly_reduce:
if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
@ -830,8 +832,17 @@ class GraphSplitAscend(GraphSplitByPattern):
REDUCE_FUSE_DEPTH = 10
def get_default_mode(self, op):
if op.prim == "MatMul" or op.prim == "BatchMatMul":
return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" else self.Area.MODE_BASIC
"""Get efault mode for op"""
def _dtype_same(tensors):
dtype = tensors[0].dtype
for tensor_ in tensors:
if tensor_.dtype != dtype:
return False
return True
if op.prim == "MatMul":
if op.inputs[0].dtype == "float16" and not _dtype_same(op.inputs):
return self.Area.MODE_COMPOSITE
if op.prim in ("Tile", "BroadcastTo", "ExpandDims"):
return self.Area.MODE_COMPOSITE
return self.Area.MODE_BASIC
@ -911,7 +922,7 @@ class GraphSplitAscend(GraphSplitByPattern):
if len(a.ops) > self.REDUCE_FUSE_DEPTH:
return True
if r == PrimLib.BROADCAST and _likely_multicore(dom) and \
(dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH):
(dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH):
return True
return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE
@ -937,7 +948,10 @@ class GraphSplitAscend(GraphSplitByPattern):
return None
fused = []
for a, _ in dom.out_relations.items():
if a.pattern == PrimLib.ELEMWISE and a.check_acyclic(dom):
if (((a.dom_op().prim == "AddN" or a.dom_op().prim == "Add" or a.dom_op().prim == "Cast")
and dom.dom_op().prim == "MatMul")
or (a.pattern == PrimLib.ELEMWISE and dom.dom_op().prim == "BatchMatMul")) \
and a.check_acyclic(dom):
fused.append(a)
return fused, False
@ -1018,9 +1032,9 @@ class GraphSplitAscend(GraphSplitByPattern):
def _fuse_once(fuse_func):
if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
fuse_func(_broadcast_width) or fuse_func(_matmul_depth) or fuse_func(_reduce_output) or \
fuse_func(_transdata):
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
fuse_func(_broadcast_width) or fuse_func(_matmul_depth) or fuse_func(_reduce_output) or \
fuse_func(_transdata):
pass
if fuse_func is None:

View File

@ -454,9 +454,7 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
ub_fusion_pm->AddPass(std::make_shared<ConvSingleInFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseEltwiseFusionPass>(fusion_id_allocator));
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator));
}
ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<ConvDoubleInFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<SegmentEltwiseFusionPass>(fusion_id_allocator));

View File

@ -21,6 +21,7 @@
#include "debug/anf_ir_dump.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "base/core_ops.h"
#include "utils/context/graph_kernel_flags.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/fusion_id_allocator.h"
@ -56,6 +57,14 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
continue;
}
auto cnode = node->cast<CNodePtr>();
if (context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimAddN)) {
continue;
}
}
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&

View File

@ -40,63 +40,63 @@ namespace {
std::vector<PrimitivePtr> GetClusterableOpList() {
std::vector<PrimitivePtr> clusterable_ops = {
prim::kPrimAbs,
prim::kPrimRound,
prim::kPrimNeg,
prim::kPrimExp,
prim::kPrimAdd,
prim::kPrimCast,
prim::kPrimMul,
prim::kPrimMinimum,
prim::kPrimMaximum,
prim::kPrimEqual,
prim::kPrimExp,
prim::kPrimInplaceAssign,
prim::kPrimLog,
prim::kPrimMaximum,
prim::kPrimMinimum,
prim::kPrimMul,
prim::kPrimNeg,
prim::kPrimPow,
prim::kPrimSub,
prim::kPrimRealDiv,
prim::kPrimReciprocal,
prim::kPrimReduceSum,
prim::kPrimReshape,
prim::kPrimRound,
prim::kPrimRsqrt,
prim::kPrimSqrt,
prim::kPrimReciprocal,
prim::kPrimSub,
prim::kPrimTanh,
prim::kPrimReshape,
prim::kPrimTranspose,
prim::kPrimRealDiv,
prim::kPrimReduceSum,
prim::kPrimEqual,
prim::kPrimAssign,
prim::kPrimInplaceAssign,
#if ENABLE_D
prim::kPrimMatMul,
prim::KPrimTransData,
prim::kPrimBatchMatMul,
#elif ENABLE_GPU
prim::kPrimSin,
prim::kPrimCos,
prim::kPrimAsin,
prim::kPrimACos,
prim::kPrimSign,
prim::kPrimReduceMax,
prim::kPrimReduceMin,
prim::kPrimGreater,
prim::kPrimLess,
prim::kPrimGreaterEqual,
prim::kPrimLessEqual,
prim::kPrimSelect,
prim::kPrimAcosh,
prim::kPrimAsin,
prim::kPrimAsinh,
prim::kPrimAssign,
prim::kPrimAtan,
prim::kPrimAtan2,
prim::kPrimExpm1,
prim::kPrimAsinh,
prim::kPrimAcosh,
prim::kPrimCos,
prim::kPrimDiv,
prim::kPrimFloorDiv,
prim::kPrimMod,
prim::kPrimFloor,
prim::kPrimFloorMod,
prim::kPrimErf,
prim::kPrimNotEqual,
prim::kPrimExpm1,
prim::kPrimFloor,
prim::kPrimFloorDiv,
prim::kPrimFloorMod,
prim::kPrimGreater,
prim::kPrimGreaterEqual,
prim::kPrimIsFinite,
prim::kPrimIsInf,
prim::kPrimIsNan,
prim::kPrimLess,
prim::kPrimLessEqual,
prim::kPrimLogicalAnd,
prim::kPrimLogicalOr,
prim::kPrimLogicalNot,
prim::kPrimIsNan,
prim::kPrimIsInf,
prim::kPrimIsFinite,
prim::kPrimMod,
prim::kPrimNotEqual,
prim::kPrimReduceMax,
prim::kPrimReduceMin,
prim::kPrimSelect,
prim::kPrimSign,
prim::kPrimSin,
#endif
};
const auto &flags = context::GraphKernelFlags::GetInstance();

View File

@ -46,44 +46,43 @@ constexpr size_t kLambWeightInputIdx = 4;
std::vector<PrimitivePtr> GetExpandOps() {
std::vector<PrimitivePtr> expand_ops = {
prim::kPrimAddN,
prim::kPrimSquare,
prim::kPrimGeLUGrad,
prim::kPrimAssignAdd,
prim::kPrimLayerNorm,
prim::kPrimLayerNormGrad,
prim::kPrimExpandDims,
prim::kPrimBiasAddGrad,
prim::kPrimGeLU,
prim::kPrimSoftmax,
prim::kPrimLogSoftmax,
prim::kPrimLogSoftmaxGrad,
prim::kPrimTile,
prim::kPrimMatMul,
prim::kPrimBatchMatMul,
prim::kPrimErfc,
prim::kPrimExpandDims,
prim::kPrimGeLU,
prim::kPrimGeLUGrad,
prim::kPrimSquare,
prim::kPrimTile,
#if ENABLE_D
prim::kPrimSqrtGrad,
prim::kPrimClipByNormNoDivSum,
prim::kLambApplyOptimizerAssign,
prim::kLambApplyWeightAssign,
prim::kPrimClipByNormNoDivSum,
prim::kPrimSqrtGrad,
prim::kSoftmaxGradExt,
prim::kSquareSumV1,
prim::kFusedMulAdd,
#elif ENABLE_GPU
prim::kPrimBatchMatMul,
prim::kPrimBiasAdd,
prim::kPrimFusedAdam,
prim::kPrimFusedAdamWeightDecay,
prim::kPrimReduceMean,
prim::kPrimMaximumGrad,
prim::kPrimMinimumGrad,
prim::kPrimBiasAddGrad,
prim::kPrimDropout,
prim::kPrimDropoutGrad,
prim::kPrimFusedAdam,
prim::kPrimFusedAdamWeightDecay,
prim::kPrimMaximumGrad,
prim::kPrimMinimumGrad,
prim::kPrimLayerNorm,
prim::kPrimLayerNormGrad,
prim::kPrimLogSoftmax,
prim::kPrimLogSoftmaxGrad,
prim::kPrimMatMul,
prim::kPrimReduceMean,
prim::kPrimRelu,
prim::kPrimReluGrad,
prim::kPrimSigmoid,
prim::kPrimSigmoidGrad,
prim::kPrimSigmoidCrossEntropyWithLogits,
prim::kPrimSigmoidCrossEntropyWithLogitsGrad,
prim::kPrimSoftmax,
prim::kPrimSoftmaxCrossEntropyWithLogits,
prim::kPrimSquaredDifference,
prim::kPrimSqueeze,

View File

@ -61,7 +61,7 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() const {
pm->AddPass(std::make_shared<CommonSubexpressionElimination>("cse1"), OptLevel_1);
// Change Assign(p, a, U) to Assign(Depend(p, U), a)
pm->AddPass(std::make_shared<SplitAssign>(), OptLevel_1);
pm->AddPass(std::make_shared<SplitAssign>(), OptLevel_1, is_gpu);
// Spread the MakeTuple input of UpdateState
pm->AddPass(std::make_shared<SpreadUpdateState>(), OptLevel_1);

View File

@ -62,6 +62,9 @@ def _set_bert_all_reduce_split():
context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421])
else:
context.set_auto_parallel_context(all_reduce_fusion_config=[38, 93, 148, 203, 258, 313, 368, 397])
if device_target == 'Ascend' and enable_graph_kernel and device_num == 8:
context.set_auto_parallel_context(all_reduce_fusion_config=[
0, 1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 35, 40, 50, 70, 93, 148, 203, 258, 313, 368, 397])
def _get_optimizer(args_opt, network):