enable GraphKernel for TransData

This commit is contained in:
hanhuifeng2020 2021-04-01 09:44:14 +08:00
parent 9fe10a19cd
commit 25505642ce
3 changed files with 41 additions and 5 deletions

View File

@ -584,6 +584,39 @@ class GraphSplitAscend(GraphSplitByPattern):
fused.append(a)
return fused, False
def _transdata_pattern_support(dom, a):
transdata_op = dom.dom_op()
# Currently, if transdata has the pad, it is not used to fuse
def _has_pad():
res = False
input_shape = transdata_op.inputs[0].shape
output_shape = transdata_op.output.shape
cube_size = 16
for dim in input_shape[-2:]:
if dim % cube_size != 0:
res = True
for dim in output_shape[-2:]:
if dim % cube_size != 0:
res = True
return res
has_pad = _has_pad()
if has_pad:
return False
if a.dom_op().prim == "MatMul" and len(dom.ops) == 1:
return True
return False
def _transdata(dom):
if dom.dom_op().prim != "TransData":
return None
fused = []
for a, _ in dom.in_relations.items():
if _transdata_pattern_support(dom, a) and a.check_acyclic(dom):
fused.append(a)
return fused, True
changed = True
while changed:
changed = self.fuse(_reshape)
@ -594,6 +627,8 @@ class GraphSplitAscend(GraphSplitByPattern):
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
changed = self.fuse(_matmul_depth) or changed
self.fuse(_transdata)
def split(graph, target, flags):
"""Split graph"""

View File

@ -186,6 +186,7 @@ class PrimLib:
'Tile': Prim(BROADCAST),
'BroadcastTo': Prim(BROADCAST),
'MatMul': Prim(OPAQUE),
'TransData': Prim(OPAQUE),
}
default_primtive = Prim(UNKNOWN)

View File

@ -596,11 +596,11 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
std::vector<PrimitivePtr> GetFusibleOpList() {
#if ENABLE_D
std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum};
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::KPrimTransData};
#elif ENABLE_GPU
std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,