Fix precision problem
This commit is contained in:
parent
e1e8f1d429
commit
297f075dca
2
akg
2
akg
|
@ -1 +1 @@
|
|||
Subproject commit f8f4e60bf3c435cec41cbe48fe24901277ef9556
|
||||
Subproject commit 72b359ad457ed8f4f254c8a3bd2bde88967202fb
|
|
@ -33,7 +33,7 @@ class GraphSplitByPattern:
|
|||
self.out_relations = dict() # {area1: relation1, area2: relation2, ...}
|
||||
self.mode = self.MODE_BASIC
|
||||
if self.pattern == PrimLib.TRANSFORM or self.pattern == PrimLib.BROADCAST or \
|
||||
(use_poly_reduce and self.pattern == PrimLib.REDUCE):
|
||||
(use_poly_reduce and self.pattern == PrimLib.REDUCE):
|
||||
self.mode = self.MODE_COMPOSITE
|
||||
if init_op.prim == "AddN":
|
||||
self.mode = self.MODE_COMPOSITE
|
||||
|
@ -283,6 +283,9 @@ class GraphSplitByPattern:
|
|||
if _check_reduce_exclude(dom):
|
||||
return None
|
||||
a, r = list(dom.in_relations.items())[0]
|
||||
if a.is_output and len(a.ops) >= 10 and _is_atomic_add_available(dom):
|
||||
# to evade the precision problem in akg
|
||||
return None
|
||||
if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
|
||||
return None
|
||||
return [a], True
|
||||
|
@ -292,6 +295,8 @@ class GraphSplitByPattern:
|
|||
return None
|
||||
if _check_reduce_exclude(dom):
|
||||
return None
|
||||
if len(dom.ops) == 1:
|
||||
return None
|
||||
fused = []
|
||||
for a, r in dom.in_relations.items():
|
||||
if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom):
|
||||
|
@ -304,16 +309,17 @@ class GraphSplitByPattern:
|
|||
size *= i
|
||||
return size
|
||||
|
||||
def _is_atomic_add_available(dom):
|
||||
if any(["Reduce" in x.prim for x in dom.ops[1:]]):
|
||||
return False
|
||||
op = dom.ops[0]
|
||||
reduce_axis = op.attrs["reduce_axis"]
|
||||
if len(op.inputs[0].shape) - 1 in reduce_axis:
|
||||
reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis])
|
||||
return reduce_size >= 1024
|
||||
return True
|
||||
|
||||
def _reduce_output(dom):
|
||||
def _is_atomic_add_available(dom):
|
||||
if any(["Reduce" in x.prim for x in dom.ops[1:]]):
|
||||
return False
|
||||
op = dom.ops[0]
|
||||
reduce_axis = op.attrs["reduce_axis"]
|
||||
if len(op.inputs[0].shape) - 1 in reduce_axis:
|
||||
reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis])
|
||||
return reduce_size >= 1024
|
||||
return True
|
||||
if dom.pattern != PrimLib.REDUCE:
|
||||
return None
|
||||
if _is_atomic_add_available(dom):
|
||||
|
|
Loading…
Reference in New Issue