forked from mindspore-Ecosystem/mindspore
enable GraphKernel for TransData
This commit is contained in:
parent
9fe10a19cd
commit
25505642ce
|
@ -584,6 +584,39 @@ class GraphSplitAscend(GraphSplitByPattern):
|
|||
fused.append(a)
|
||||
return fused, False
|
||||
|
||||
def _transdata_pattern_support(dom, a):
|
||||
transdata_op = dom.dom_op()
|
||||
|
||||
# Currently, if transdata has the pad, it is not used to fuse
|
||||
def _has_pad():
|
||||
res = False
|
||||
input_shape = transdata_op.inputs[0].shape
|
||||
output_shape = transdata_op.output.shape
|
||||
cube_size = 16
|
||||
for dim in input_shape[-2:]:
|
||||
if dim % cube_size != 0:
|
||||
res = True
|
||||
for dim in output_shape[-2:]:
|
||||
if dim % cube_size != 0:
|
||||
res = True
|
||||
return res
|
||||
has_pad = _has_pad()
|
||||
if has_pad:
|
||||
return False
|
||||
|
||||
if a.dom_op().prim == "MatMul" and len(dom.ops) == 1:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _transdata(dom):
|
||||
if dom.dom_op().prim != "TransData":
|
||||
return None
|
||||
fused = []
|
||||
for a, _ in dom.in_relations.items():
|
||||
if _transdata_pattern_support(dom, a) and a.check_acyclic(dom):
|
||||
fused.append(a)
|
||||
return fused, True
|
||||
|
||||
changed = True
|
||||
while changed:
|
||||
changed = self.fuse(_reshape)
|
||||
|
@ -594,6 +627,8 @@ class GraphSplitAscend(GraphSplitByPattern):
|
|||
changed = self.fuse(_broadcast_depth) or changed
|
||||
changed = self.fuse(_broadcast_width) or changed
|
||||
changed = self.fuse(_matmul_depth) or changed
|
||||
self.fuse(_transdata)
|
||||
|
||||
|
||||
def split(graph, target, flags):
|
||||
"""Split graph"""
|
||||
|
|
|
@ -186,6 +186,7 @@ class PrimLib:
|
|||
'Tile': Prim(BROADCAST),
|
||||
'BroadcastTo': Prim(BROADCAST),
|
||||
'MatMul': Prim(OPAQUE),
|
||||
'TransData': Prim(OPAQUE),
|
||||
}
|
||||
|
||||
default_primtive = Prim(UNKNOWN)
|
||||
|
|
|
@ -596,11 +596,11 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
|
|||
std::vector<PrimitivePtr> GetFusibleOpList() {
|
||||
#if ENABLE_D
|
||||
std::vector<PrimitivePtr> fusible_basic_ops = {
|
||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
|
||||
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
||||
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum};
|
||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
|
||||
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
||||
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::KPrimTransData};
|
||||
#elif ENABLE_GPU
|
||||
std::vector<PrimitivePtr> fusible_basic_ops = {
|
||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
|
||||
|
|
Loading…
Reference in New Issue