forked from mindspore-Ecosystem/mindspore
Speed up BERT performance on Ascend for graph kernel
parent 6339d5a1df
commit 94dda0c7c7
@@ -24,6 +24,7 @@ class GraphSplitByPattern:
     """Graph splitter"""
     class ReachTable:
         """Reachable table"""

         def __init__(self, size):
             self.map = []
             self.alive = set(range(size))
@@ -61,6 +62,7 @@ class GraphSplitByPattern:

     class StitchInfo:
         """StitchInfo"""

         def __init__(self):
             self.stitch_ops = set()
             self.stitch_atomic_ops = set()
@@ -190,7 +192,7 @@ class GraphSplitByPattern:
     def cp_ops(self, area):
         """copy recompute_ops in area to ops, self is area's user"""
         tail_tensor = area.recompute_ops[-1].output
-        #copy tensors, all copied are Tensor.PARA_NONE
+        # copy tensors, all copied are Tensor.PARA_NONE
         tensor_map = {}
         if area.recompute_ops[0].inputs:
             tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0]
@@ -198,7 +200,7 @@ class GraphSplitByPattern:
             orig_tensor = op.output
             cp_tensor = Tensor(orig_tensor.name, orig_tensor.shape, orig_tensor.dtype, orig_tensor.data_format)
             tensor_map[orig_tensor] = cp_tensor
-        #copy ops
+        # copy ops
         cp_ops = []
         for op in area.recompute_ops:
             inputs = [tensor_map[op.inputs[0]]] if op.inputs else []
@@ -206,14 +208,14 @@ class GraphSplitByPattern:
             cp_op.all_inputs = cp_op.inputs
             cp_ops.append(cp_op)
             area.ori_op_map[cp_op] = op
-        #connect copied ops
+        # connect copied ops
         for op in self.ops:
             if tail_tensor in op.inputs:
                 op.inputs.remove(tail_tensor)
                 op.inputs.append(tensor_map[tail_tensor])
                 tail_tensor.to_ops.remove(op)
                 tensor_map[tail_tensor].to_ops.append(op)
-        #fill cp_ops in self.recompute_area
+        # fill cp_ops in self.recompute_area
         cp_dom_op = None
         for cp, ori in area.ori_op_map.items():
             if ori == area.dom_op():
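The cp_ops hunks above follow a copy-and-reconnect pattern: recompute ops are cloned through a tensor_map, then every consumer of the original tail tensor is rewired to the cloned tensor. A minimal standalone sketch of the rewiring step, using simplified stand-in classes rather than MindSpore's Op/Tensor:

class _Tensor:
    def __init__(self, name):
        self.name, self.to_ops = name, []

class _Op:
    def __init__(self, inputs):
        self.inputs = inputs

def reconnect(consumers, tail, tensor_map):
    # Same moves as the hunk: swap the input reference and fix the
    # to_ops back-pointers on both the original and the copied tensor.
    for op in consumers:
        if tail in op.inputs:
            op.inputs.remove(tail)
            op.inputs.append(tensor_map[tail])
            tail.to_ops.remove(op)
            tensor_map[tail].to_ops.append(op)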
@@ -402,26 +404,26 @@ class GraphSplitByPattern:
     def set_recompute(self, dom_area, ops, user_area):
         """set the recompute area and connect with other areas"""
         self.recom_area.recompute_ops.extend(ops)
-        #recom_area: set dom_op and correct ops length
+        # recom_area: set dom_op and correct ops length
         patterns = [PrimLib.iter_type(op) for op in ops]
         self.recom_area.pattern = max(patterns)
         for i, pat in enumerate(patterns):
             if pat == self.recom_area.pattern:
                 self.recom_area.ops = [ops[i]] * len(ops)
                 break
-        #disconnect dom_area and user_area
+        # disconnect dom_area and user_area
         self.dom_user_r = dom_area.out_relations[user_area]
         dom_area.out_relations.pop(user_area)
         user_area.in_relations.pop(dom_area)
-        #connect recom_area and user_area
+        # connect recom_area and user_area
         user_area.in_relations[self.recom_area] = self.dom_user_r
         self.recom_area.out_relations[user_area] = self.dom_user_r
-        #connect recom_pre and recom_area
+        # connect recom_pre and recom_area
         self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs and ops[0].inputs[0].op else None
         if self.recom_pre is not None:
             self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre]
             self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre]
-        #set related areas
+        # set related areas
         self.recom_user = user_area
         self.recom_dom = dom_area
         self.recom_res = False
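set_recompute stores every area edge twice, once in the producer's out_relations and once in the consumer's in_relations, so disconnecting and reconnecting must always be done in pairs. A toy illustration of that double-entry update, with plain dicts standing in for Area objects (names are illustrative):

dom_out = {"user": "relation"}   # existing dom_area -> user_area edge
user_in = {"dom": "relation"}
recom_out, recom_in = {}, {}

r = dom_out.pop("user")          # disconnect dom_area and user_area
user_in.pop("dom")
recom_out["user"] = r            # reconnect through the recompute area
user_in["recom"] = r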
@@ -441,7 +443,6 @@ class GraphSplitByPattern:
         self.orig_op_map.update(self.recom_area.ori_op_map)
         self.recom_area.ori_op_map.clear()

-
     def to_subgraph(self, dom):
         """Transform area to subgraphs"""
         ids = self.index_op()
@@ -461,14 +462,14 @@ class GraphSplitByPattern:
                 region_ops.append(op)
                 return False, None, weight, True
             if op.inputs[0] in inputs and len(op.inputs) == 1 and \
-                PrimLib.iter_type(op) <= PrimLib.BROADCAST:
+                    PrimLib.iter_type(op) <= PrimLib.BROADCAST:
                 region_ops.append(op)
                 return False, None, weight, True
-            #region fails to grow
+            # region fails to grow
             MAX_WEIGHT = 20
             if weight > MAX_WEIGHT or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST:
                 return False, None, weight, False
-            #region grows successfully
+            # region grows successfully
             weight = weight + 1
             region_ops.append(op)
             return True, op.inputs[0].op, weight, False
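Reading the return statements above, the growth helper appears to yield a 4-tuple of the form (keep_growing, next_op, weight, reached_input). A hypothetical driver loop under that assumption; check stands in for the unnamed helper and every name here is invented for illustration:

def grow_region(check, start_op):
    op, weight = start_op, 1
    while op is not None:
        keep_growing, next_op, weight, reached_input = check(op, weight)
        if not keep_growing:
            # True: the region stopped cleanly at its inputs;
            # False: growth failed (too heavy, multi-input, or past BROADCAST).
            return reached_input
        op = next_op  # follow the single producer upward
    return True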
@@ -544,7 +545,7 @@ class GraphSplitByPattern:
                     self.clear_recompute()
                     if self.recom_res:
                         recompute_suc = True
-                        #Copy region at most once for this dom
+                        # Copy region at most once for this dom
                         dom_changed = True
                         break
             if dom_changed:
@@ -555,6 +556,7 @@ class GraphSplitByPattern:
         while do_recompute_fuse():
             self.pattern_fuse()


         use_poly_reduce = True
@@ -808,8 +810,8 @@ class GraphSplitGpu(GraphSplitByPattern):

         def _fuse_once(fuse_func):
             if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \
-                fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
-                fuse_func(_broadcast_width):
+                    fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
+                    fuse_func(_broadcast_width):
                 return
             if use_poly_reduce:
                 if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
@@ -830,8 +832,17 @@ class GraphSplitAscend(GraphSplitByPattern):
     REDUCE_FUSE_DEPTH = 10

     def get_default_mode(self, op):
-        if op.prim == "MatMul" or op.prim == "BatchMatMul":
-            return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" else self.Area.MODE_BASIC
+        """Get default mode for op"""
+        def _dtype_same(tensors):
+            dtype = tensors[0].dtype
+            for tensor_ in tensors:
+                if tensor_.dtype != dtype:
+                    return False
+            return True
+
+        if op.prim == "MatMul":
+            if op.inputs[0].dtype == "float16" and not _dtype_same(op.inputs):
+                return self.Area.MODE_COMPOSITE
         if op.prim in ("Tile", "BroadcastTo", "ExpandDims"):
             return self.Area.MODE_COMPOSITE
         return self.Area.MODE_BASIC
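The new _dtype_same helper changes the Ascend MatMul decision: a float16 MatMul is only forced into composite mode when its inputs disagree on dtype. A standalone rendering of the check with a tiny stand-in tensor type (illustrative, not MindSpore code):

from collections import namedtuple

FakeTensor = namedtuple("FakeTensor", "dtype")

def _dtype_same(tensors):
    dtype = tensors[0].dtype
    return all(t.dtype == dtype for t in tensors)

print(_dtype_same([FakeTensor("float16"), FakeTensor("float16")]))  # True  -> MODE_BASIC path
print(_dtype_same([FakeTensor("float16"), FakeTensor("float32")]))  # False -> MODE_COMPOSITE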
@@ -911,7 +922,7 @@ class GraphSplitAscend(GraphSplitByPattern):
             if len(a.ops) > self.REDUCE_FUSE_DEPTH:
                 return True
             if r == PrimLib.BROADCAST and _likely_multicore(dom) and \
-                (dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH):
+                    (dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH):
                 return True
             return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE
@@ -937,7 +948,10 @@ class GraphSplitAscend(GraphSplitByPattern):
                 return None
             fused = []
             for a, _ in dom.out_relations.items():
-                if a.pattern == PrimLib.ELEMWISE and a.check_acyclic(dom):
+                if (((a.dom_op().prim == "AddN" or a.dom_op().prim == "Add" or a.dom_op().prim == "Cast")
+                     and dom.dom_op().prim == "MatMul")
+                    or (a.pattern == PrimLib.ELEMWISE and dom.dom_op().prim == "BatchMatMul")) \
+                        and a.check_acyclic(dom):
                     fused.append(a)
             return fused, False
@@ -1018,9 +1032,9 @@ class GraphSplitAscend(GraphSplitByPattern):

         def _fuse_once(fuse_func):
             if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \
-                fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
-                fuse_func(_broadcast_width) or fuse_func(_matmul_depth) or fuse_func(_reduce_output) or \
-                fuse_func(_transdata):
+                    fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
+                    fuse_func(_broadcast_width) or fuse_func(_matmul_depth) or fuse_func(_reduce_output) or \
+                    fuse_func(_transdata):
                 pass

         if fuse_func is None:
@@ -454,9 +454,7 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
   ub_fusion_pm->AddPass(std::make_shared<ConvSingleInFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseEltwiseFusionPass>(fusion_id_allocator));
-  if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
-    ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator));
-  }
+  ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<ConvDoubleInFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<SegmentEltwiseFusionPass>(fusion_id_allocator));
@@ -21,6 +21,7 @@
 #include "debug/anf_ir_dump.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "base/core_ops.h"
+#include "utils/context/graph_kernel_flags.h"
 #include "utils/ms_context.h"
 #include "backend/optimizer/common/fusion_id_allocator.h"
@@ -56,6 +57,14 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
       continue;
     }
     auto cnode = node->cast<CNodePtr>();
+    if (context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
+      if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
+          AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
+          AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimAddN)) {
+        continue;
+      }
+    }
+
     MS_EXCEPTION_IF_NULL(cnode);
     if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
         AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
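Taken together with the pass-manager hunk above, the effect is that MatmulEltwiseFusionPass now always runs, but with graph kernel enabled it declines elemwise AddN candidates so the graph kernel pipeline can fuse them instead. A pseudo-Python restatement of the added guard; the helper names are invented for illustration and the authoritative logic is the C++ above:

def should_skip(node, graph_kernel_enabled, kernel_type, fusion_type, is_addn):
    # Leave elemwise AddN TBE nodes to graph kernel when it is enabled.
    return (graph_kernel_enabled
            and kernel_type(node) == "TBE_KERNEL"
            and fusion_type(node) == "ELEMWISE"
            and is_addn(node))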
@@ -40,63 +40,63 @@ namespace {
 std::vector<PrimitivePtr> GetClusterableOpList() {
   std::vector<PrimitivePtr> clusterable_ops = {
     prim::kPrimAbs,
-    prim::kPrimRound,
-    prim::kPrimNeg,
-    prim::kPrimExp,
     prim::kPrimAdd,
     prim::kPrimCast,
-    prim::kPrimMul,
-    prim::kPrimMinimum,
-    prim::kPrimMaximum,
+    prim::kPrimEqual,
+    prim::kPrimExp,
+    prim::kPrimInplaceAssign,
     prim::kPrimLog,
+    prim::kPrimMaximum,
+    prim::kPrimMinimum,
+    prim::kPrimMul,
+    prim::kPrimNeg,
     prim::kPrimPow,
-    prim::kPrimSub,
+    prim::kPrimRealDiv,
+    prim::kPrimReciprocal,
+    prim::kPrimReduceSum,
+    prim::kPrimReshape,
+    prim::kPrimRound,
     prim::kPrimRsqrt,
     prim::kPrimSqrt,
-    prim::kPrimReciprocal,
+    prim::kPrimSub,
     prim::kPrimTanh,
-    prim::kPrimReshape,
     prim::kPrimTranspose,
-    prim::kPrimRealDiv,
-    prim::kPrimReduceSum,
-    prim::kPrimEqual,
-    prim::kPrimAssign,
-    prim::kPrimInplaceAssign,
 #if ENABLE_D
     prim::kPrimMatMul,
+    prim::KPrimTransData,
+    prim::kPrimBatchMatMul,
 #elif ENABLE_GPU
-    prim::kPrimSin,
-    prim::kPrimCos,
-    prim::kPrimAsin,
     prim::kPrimACos,
-    prim::kPrimSign,
-    prim::kPrimReduceMax,
-    prim::kPrimReduceMin,
-    prim::kPrimGreater,
-    prim::kPrimLess,
-    prim::kPrimGreaterEqual,
-    prim::kPrimLessEqual,
-    prim::kPrimSelect,
+    prim::kPrimAcosh,
+    prim::kPrimAsin,
+    prim::kPrimAsinh,
+    prim::kPrimAssign,
     prim::kPrimAtan,
     prim::kPrimAtan2,
-    prim::kPrimExpm1,
-    prim::kPrimAsinh,
-    prim::kPrimAcosh,
+    prim::kPrimCos,
     prim::kPrimDiv,
-    prim::kPrimFloorDiv,
-    prim::kPrimMod,
-    prim::kPrimFloor,
-    prim::kPrimFloorMod,
     prim::kPrimErf,
-    prim::kPrimNotEqual,
+    prim::kPrimExpm1,
+    prim::kPrimFloor,
+    prim::kPrimFloorDiv,
+    prim::kPrimFloorMod,
+    prim::kPrimGreater,
+    prim::kPrimGreaterEqual,
+    prim::kPrimIsFinite,
+    prim::kPrimIsInf,
+    prim::kPrimIsNan,
+    prim::kPrimLess,
+    prim::kPrimLessEqual,
     prim::kPrimLogicalAnd,
     prim::kPrimLogicalOr,
     prim::kPrimLogicalNot,
-    prim::kPrimIsNan,
-    prim::kPrimIsInf,
-    prim::kPrimIsFinite,
+    prim::kPrimMod,
+    prim::kPrimNotEqual,
+    prim::kPrimReduceMax,
+    prim::kPrimReduceMin,
+    prim::kPrimSelect,
+    prim::kPrimSign,
+    prim::kPrimSin,
 #endif
   };
   const auto &flags = context::GraphKernelFlags::GetInstance();
@@ -46,44 +46,43 @@ constexpr size_t kLambWeightInputIdx = 4;
 std::vector<PrimitivePtr> GetExpandOps() {
   std::vector<PrimitivePtr> expand_ops = {
     prim::kPrimAddN,
-    prim::kPrimSquare,
-    prim::kPrimGeLUGrad,
     prim::kPrimAssignAdd,
-    prim::kPrimLayerNorm,
-    prim::kPrimLayerNormGrad,
-    prim::kPrimExpandDims,
-    prim::kPrimBiasAddGrad,
-    prim::kPrimGeLU,
-    prim::kPrimSoftmax,
-    prim::kPrimLogSoftmax,
-    prim::kPrimLogSoftmaxGrad,
-    prim::kPrimTile,
-    prim::kPrimMatMul,
-    prim::kPrimBatchMatMul,
+    prim::kPrimErfc,
+    prim::kPrimExpandDims,
+    prim::kPrimGeLU,
+    prim::kPrimGeLUGrad,
+    prim::kPrimSquare,
+    prim::kPrimTile,
 #if ENABLE_D
-    prim::kPrimSqrtGrad,
-    prim::kPrimClipByNormNoDivSum,
+    prim::kLambApplyOptimizerAssign,
+    prim::kLambApplyWeightAssign,
+    prim::kPrimClipByNormNoDivSum,
+    prim::kPrimSqrtGrad,
+    prim::kSoftmaxGradExt,
+    prim::kSquareSumV1,
+    prim::kFusedMulAdd,
 #elif ENABLE_GPU
+    prim::kPrimBatchMatMul,
     prim::kPrimBiasAdd,
-    prim::kPrimFusedAdam,
-    prim::kPrimFusedAdamWeightDecay,
-    prim::kPrimReduceMean,
-    prim::kPrimMaximumGrad,
-    prim::kPrimMinimumGrad,
+    prim::kPrimBiasAddGrad,
     prim::kPrimDropout,
     prim::kPrimDropoutGrad,
+    prim::kPrimFusedAdam,
+    prim::kPrimFusedAdamWeightDecay,
+    prim::kPrimMaximumGrad,
+    prim::kPrimMinimumGrad,
+    prim::kPrimLayerNorm,
+    prim::kPrimLayerNormGrad,
+    prim::kPrimLogSoftmax,
+    prim::kPrimLogSoftmaxGrad,
+    prim::kPrimMatMul,
+    prim::kPrimReduceMean,
     prim::kPrimRelu,
     prim::kPrimReluGrad,
     prim::kPrimSigmoid,
     prim::kPrimSigmoidGrad,
     prim::kPrimSigmoidCrossEntropyWithLogits,
     prim::kPrimSigmoidCrossEntropyWithLogitsGrad,
+    prim::kPrimSoftmax,
     prim::kPrimSoftmaxCrossEntropyWithLogits,
     prim::kPrimSquaredDifference,
     prim::kPrimSqueeze,
@@ -61,7 +61,7 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() const {
   pm->AddPass(std::make_shared<CommonSubexpressionElimination>("cse1"), OptLevel_1);

   // Change Assign(p, a, U) to Assign(Depend(p, U), a)
-  pm->AddPass(std::make_shared<SplitAssign>(), OptLevel_1);
+  pm->AddPass(std::make_shared<SplitAssign>(), OptLevel_1, is_gpu);

   // Spread the MakeTuple input of UpdateState
   pm->AddPass(std::make_shared<SpreadUpdateState>(), OptLevel_1);
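The comment above states the SplitAssign rewrite: the state token U is moved off the Assign and attached to its first input through a Depend. A hedged sketch of that rule on a toy tuple-based IR; the node shapes are invented for illustration, while the real pass operates on ANF CNodes:

def split_assign(node):
    if node[0] == "Assign" and len(node) == 4:  # ("Assign", param, value, umonad)
        _, param, value, umonad = node
        return ("Assign", ("Depend", param, umonad), value)
    return node

assert split_assign(("Assign", "p", "a", "U")) == ("Assign", ("Depend", "p", "U"), "a")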
@@ -62,6 +62,9 @@ def _set_bert_all_reduce_split():
         context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421])
     else:
         context.set_auto_parallel_context(all_reduce_fusion_config=[38, 93, 148, 203, 258, 313, 368, 397])
+    if device_target == 'Ascend' and enable_graph_kernel and device_num == 8:
+        context.set_auto_parallel_context(all_reduce_fusion_config=[
+            0, 1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 35, 40, 50, 70, 93, 148, 203, 258, 313, 368, 397])


 def _get_optimizer(args_opt, network):
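all_reduce_fusion_config lists the gradient indices at which the AllReduce fusion buffer is split into groups; the new Ascend graph-kernel branch uses many small leading groups, presumably so gradient communication starts earlier and overlaps the shorter fused backward kernels. Illustrative arithmetic only, reading the commit's split points as group sizes:

split = [0, 1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 35, 40, 50, 70, 93, 148, 203, 258, 313, 368, 397]
sizes = [hi - lo for lo, hi in zip([0] + split[:-1], split)]
print(sizes)  # [0, 1, 1, 1, 1, 1, 5, 5, 5, 5, 5, 5, 5, 10, 20, 23, 55, 55, 55, 55, 55, 29]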