forked from mindspore-Ecosystem/mindspore
!16766 Add recompute fuse
Merge pull request !16766 from lingyunli63/recompute_fuse
commit 0c360ea2d6
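This change adds recompute fusion to the graph kernel cost model: a cheap region (single-input ops no more expensive than BROADCAST, at most 20 ops) whose tail tensor feeds more than one consumer is copied into a dedicated recompute Area and fused into its user area, so the user kernel recomputes the region instead of reading it across kernel boundaries. The feature is gated by a new enable_recompute_fusion flag, the recompute ops are carried in the kernel JSON under recompute_ops, and the split-nodes decoder duplicates the corresponding CNodes. GPU and Ascend tests are added.

A minimal usage sketch, not part of this commit: the flag name comes from GraphKernelFlags::RegisterFlags below, and passing it through the graph_kernel_flags context option is an assumption.

    # Hypothetical sketch: turn on graph kernel fusion plus the new recompute flag.
    import mindspore.context as context

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU",
                        enable_graph_kernel=True,
                        graph_kernel_flags="--enable_recompute_fusion=true")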
@@ -16,7 +16,7 @@
import os
from functools import reduce
from mindspore import log as logger
from .model import PrimLib, Graph, Tensor
from .model import PrimLib, Graph, Tensor, Operator
from .model import DataFormat as DF

@@ -65,13 +65,16 @@ class GraphSplitByPattern:
self.stitch_ops = set()
self.stitch_atomic_ops = set()

def __init__(self, init_op, is_output, unique_id, reach_tab):
self.pattern = PrimLib.iter_type(init_op)
self.ops = [init_op]
def __init__(self, init_op, is_output, unique_id, reach_tab, recompute_ops=None):
self.pattern = PrimLib.iter_type(init_op) if init_op is not None else PrimLib.UNKNOWN
self.ops = [] if init_op is None else [init_op]
self.in_relations = dict()  # {area1: relation1, area2: relation2, ...}
self.out_relations = dict()  # {area1: relation1, area2: relation2, ...}
self.mode = None
self.stitch_info = self.StitchInfo()
self.recompute_ops = [] if recompute_ops is None else recompute_ops
self.ori_op_map = {}
self.is_recompute = False
self.is_output = is_output
self.output_excluded = set()
if self.pattern == PrimLib.REDUCE:

@@ -143,6 +146,8 @@ class GraphSplitByPattern:
r = rels.pop(area)
_update_relation(rels, self, r)

if area.is_recompute:
self.cp_ops(area)
if self.pattern >= area.pattern:
self.ops.extend(area.ops)
else:

@@ -161,7 +166,9 @@ class GraphSplitByPattern:
if area.output_excluded:
self.output_excluded.update(area.output_excluded)
self.update_stitch_info(area.stitch_info)
self.reach_tab.fuse(self.unique_id, area.unique_id)
if not area.is_recompute:
self.reach_tab.fuse(self.unique_id, area.unique_id)
self.recompute_ops.extend(area.recompute_ops)

def check_acyclic(self, to):
"""Check circle. It returns false if circle exists"""

@@ -180,25 +187,73 @@ class GraphSplitByPattern:
return True
return False

def cp_ops(self, area):
"""copy recompute_ops in area to ops, self is area's user"""
tail_tensor = area.recompute_ops[-1].output
#copy tensors, all copied are Tensor.PARA_NONE
tensor_map = {}
tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0]
for op in area.recompute_ops:
orig_tensor = op.output
cp_tensor = Tensor(orig_tensor.name, orig_tensor.shape, orig_tensor.dtype, orig_tensor.data_format)
tensor_map[orig_tensor] = cp_tensor
#copy ops
cp_ops = []
for op in area.recompute_ops:
cp_op = Operator(op.prim, [tensor_map[op.inputs[0]]], tensor_map[op.output], op.attrs)
cp_op.all_inputs = cp_op.inputs
cp_ops.append(cp_op)
area.ori_op_map[cp_op] = op
#connect copied ops
for op in self.ops:
if tail_tensor in op.inputs:
op.inputs.remove(tail_tensor)
op.inputs.append(tensor_map[tail_tensor])
tail_tensor.to_ops.remove(op)
tensor_map[tail_tensor].to_ops.append(op)
#fill cp_ops in self.recompute_area
cp_dom_op = None
for cp, ori in area.ori_op_map.items():
if ori == area.dom_op():
cp_dom_op = cp
area.ops.clear()
area.ops.append(cp_dom_op)
area.ops.extend([op for op in cp_ops if op != cp_dom_op])

def __init__(self, graph, flags):
self.graph = graph
self.areas = []
self.flags = flags
self.reach_tab = self.ReachTable(len(graph.ops))
area_map = {}
self.enable_recompute = self.flags.get("enable_recompute_fusion", False)
self.reach_tab = self.ReachTable(len(graph.ops) + 1 if self.enable_recompute else len(graph.ops))
self.area_map = {}
_, outputs = graph.deduce_parameters()
idx = 0
self.idx = 0
for op in graph.ops:
is_output = op.output in outputs
a = self.Area(op, is_output, idx, self.reach_tab)
idx += 1
a = self.Area(op, is_output, self.idx, self.reach_tab)
self.idx += 1
self.set_default_mode(a)
self.areas.append(a)
area_map[op] = a
self.set_area_map([op], a)
for a in self.areas:
a.link_input(area_map)
a.link_input(self.area_map)
for i in range(len(self.areas)-1, -1, -1):
self.areas[i].link_output()
if self.enable_recompute:
self.recom_area = self.Area(None, False, self.idx, self.reach_tab)
self.recom_area.is_recompute = True
self.recom_pre = None
self.recom_user = None
self.recom_dom = None
self.dom_user_r = PrimLib.UNKNOWN
self.recom_res = False
self.orig_op_map = {}

def set_area_map(self, ops, area):
"""update area_map after op fused to area"""
for op in ops:
self.area_map[op] = area

def set_default_mode(self, area):
area.mode = self.get_default_mode(area.ops[0])

@@ -234,11 +289,13 @@ class GraphSplitByPattern:
if is_forward:
for area in fuse_areas:
dominant.fuse(area)
self.set_area_map(area.ops, dominant)
self.areas.remove(area)
else:
forward_area = dominant
for area in fuse_areas:
area.fuse(forward_area)
self.set_area_map(forward_area.ops, area)
self.areas.remove(forward_area)
forward_area = area
changed = True

@@ -246,16 +303,39 @@ class GraphSplitByPattern:
else:
return changed

def to_subgraphs(self):
"""Transform op groups to subgraphs"""
def fuse_recom(self, selector):
"""Fuse recompute area to its user"""
for dominant in [self.recom_area, self.recom_user]:
result = selector(dominant)
if result is not None and result[0]:
fuse_areas, _ = result
fuse_areas = self.limit_area_size(dominant, fuse_areas)
if not fuse_areas:
continue
if fuse_areas[0] in [self.recom_area, self.recom_user]:
self.recom_user.fuse(self.recom_area)
self.recom_res = True
return True
return False

def index_op(self):
"""index op by order, the copied op share id with original op, for topo-sort"""
ids = {}
for i, op in enumerate(self.graph.ops):
ids[op] = i
if hasattr(self, 'orig_op_map'):
for k, v in self.orig_op_map.items():
ids[k] = ids[v]
return ids

def to_subgraphs(self):
"""Transform op groups to subgraphs"""
ids = self.index_op()
subgraphs = []
graphmodes = []
for i, area in enumerate(self.areas):
area.ops.sort(key=lambda op: ids[op])
subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info))
subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info, area.recompute_ops))
graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
return subgraphs, graphmodes

@@ -274,13 +354,14 @@ class GraphSplitByPattern:
with os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), 'w+') as f:
f.write(subgraphs_str)

def do_split(self):
"""Split graph by pattern"""
raise Exception("do_split() is not implemented in {}".format(self.__class__.__name__))
def pattern_fuse(self, select=None):
"""fuse Areas by pattern repeatedly"""
raise Exception("pattern_fuse() is not implemented in {}".format(self.__class__.__name__))

def split(self):
"""Split graph by pattern"""
self.do_split()
self.pattern_fuse()
self.recompute_fuse()
# The reshape should not be output node
# Note: after this function, the input output relation is not maintained.
self.split_output_reshapes()

@@ -316,6 +397,159 @@ class GraphSplitByPattern:
if new_areas:
self.areas += new_areas

def set_recompute(self, dom_area, ops, user_area):
"""set the recompute area and connect with other areas"""
self.recom_area.recompute_ops.extend(ops)
#recom_area: set dom_op and correct ops length
patterns = [PrimLib.iter_type(op) for op in ops]
self.recom_area.pattern = max(patterns)
for i, pat in enumerate(patterns):
if pat == self.recom_area.pattern:
self.recom_area.ops = [ops[i]] * len(ops)
break
#disconnect dom_area and user_area
self.dom_user_r = dom_area.out_relations[user_area]
dom_area.out_relations.pop(user_area)
user_area.in_relations.pop(dom_area)
#connect recom_area and user_area
user_area.in_relations[self.recom_area] = self.dom_user_r
self.recom_area.out_relations[user_area] = self.dom_user_r
#connect recom_pre and recom_area
self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs[0].op else None
if self.recom_pre is not None:
self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre]
self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre]
#set related areas
self.recom_user = user_area
self.recom_dom = dom_area
self.recom_res = False

def clear_recompute(self):
"""disconnect recom_area from other areas, and clear recom_area"""
self.recom_area.out_relations.clear()
self.recom_area.in_relations.clear()
if not self.recom_res:
self.recom_user.in_relations.pop(self.recom_area)
self.recom_user.in_relations[self.recom_dom] = self.dom_user_r
self.recom_dom.out_relations[self.recom_user] = self.dom_user_r
if self.recom_pre:
self.recom_pre.out_relations.pop(self.recom_area)
self.recom_area.ops.clear()
self.recom_area.recompute_ops.clear()
self.orig_op_map.update(self.recom_area.ori_op_map)
self.recom_area.ori_op_map.clear()

def to_subgraph(self, dom):
"""Transform area to subgraphs"""
ids = self.index_op()
dom_ops = list()
dom_ops.extend(dom.ops)
dom_ops.sort(key=lambda op: ids[op])
subgraph = []
subgraph = Graph('{}_area'.format(self.graph.name), dom_ops)
return subgraph

def find_cheap_regions(self, dom):
"""extract all the cheap regions in dom area, toposort each region before return"""
def _grow_region(region_ops, op, weight, inputs):
"""include op to region_ops if region grow"""
# region successfully ends at input
if op.inputs[0] in inputs and len(op.inputs) == 1 and \
PrimLib.iter_type(op) <= PrimLib.BROADCAST:
region_ops.append(op)
return False, None, weight
#region fails to grow
MAX_WEIGHT = 20
if weight > MAX_WEIGHT or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST:
return False, None, weight
#region grows successfully
weight = weight + 1
region_ops.append(op)
return True, op.inputs[0].op, weight

def _find_cheap_regions(dom):
sub = self.to_subgraph(dom)
inputs, outputs = sub.deduce_parameters()
cheap_regions = []
for output in outputs:
# tensor should have user other than user_area to be fused
if output.para_type != Tensor.PARA_OUTPUT and len(output.to_ops) < 2:
continue
region_ops = []
grow = True
candidate_op = output.op
weight = 1
while grow:
grow, candidate_op, weight = _grow_region(region_ops, candidate_op, weight, inputs)
# region ends at input and not empty
if region_ops and region_ops[-1].inputs[0] in inputs:
region_ops.reverse()
# tensor size should equal or becomes larger(cast up, broadcast)
if region_ops[0].inputs[0].get_size() > region_ops[-1].output.get_size():
continue
cheap_regions.append(region_ops)
return cheap_regions

return _find_cheap_regions(dom)

def select_user_area(self, tail_tensor):
"""select the user area has only one edge to dom area"""
def _get_edge_num(dom_area, user_area):
"""get edge num between two areas"""
dom_graph = self.to_subgraph(dom_area)
_, dom_outputs = dom_graph.deduce_parameters()
user_graph = self.to_subgraph(user_area)
user_inputs, _ = user_graph.deduce_parameters()
edge = [t for t in dom_outputs if t in user_inputs]
return len(edge)

def _select_user_area(tail_tensor):
user_areas = []
for user_op in tail_tensor.to_ops:
user_area = self.area_map[user_op]
if len(user_area.ops) == 1 and user_area.pattern == PrimLib.RESHAPE:
continue
edge_num = _get_edge_num(self.area_map[tail_tensor.op], user_area)
if edge_num == 1 and not user_area in user_areas:
user_areas.append(user_area)
return user_areas

return _select_user_area(tail_tensor)

def recompute_fuse(self):
"""find recompute regions and copy them out to new Areas"""
def do_recompute_fuse():
"""split the unfusing pattern by add recompute area"""
recompute_suc = False
orig_areas = []
orig_areas.extend(self.areas)
for dom in orig_areas:
if dom not in self.areas or not dom.out_relations:
continue
cheap_regions = self.find_cheap_regions(dom)
dom_changed = False
for cheap_region in cheap_regions:
user_areas = self.select_user_area(cheap_region[-1].output)
if not user_areas:
continue
for user_area in user_areas:
self.set_recompute(dom, cheap_region, user_area)
self.pattern_fuse(self.fuse_recom)
self.clear_recompute()
if self.recom_res:
recompute_suc = True
#Copy region at most once for this dom
dom_changed = True
break
if dom_changed:
break
return recompute_suc

if self.enable_recompute:
while do_recompute_fuse():
self.pattern_fuse()

use_poly_reduce = True

@@ -331,8 +565,8 @@ class GraphSplitGpu(GraphSplitByPattern):
pattern = PrimLib.iter_type(op)
return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE

def do_split(self):
"""Split graph by pattern"""
def pattern_fuse(self, fuse_func=None):
"""fuse Areas by pattern"""
def _reshape(dom):
if dom.pattern != PrimLib.RESHAPE:
return None

@@ -551,21 +785,38 @@ class GraphSplitGpu(GraphSplitByPattern):
fused.append(a)
return fused, True

enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)
changed = True
while changed:
changed = self.fuse(_reshape)
changed = self.fuse(_elemwise_depth) or changed
changed = self.fuse(_elemwise_width) or changed
changed = self.fuse(_reduce_depth) or changed
changed = self.fuse(_reduce_width) or changed
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
def _fuse_loop():
changed = True
while changed:
changed = self.fuse(_reshape)
changed = self.fuse(_elemwise_depth) or changed
changed = self.fuse(_elemwise_width) or changed
changed = self.fuse(_reduce_depth) or changed
changed = self.fuse(_reduce_width) or changed
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
if use_poly_reduce:
changed = self.fuse(_reduce_output) or changed
if enable_stitch_fusion:
changed = self.fuse(_reduce_stitch) or changed
self.fuse(_transpose)

def _fuse_once(fuse_func):
if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
fuse_func(_broadcast_width):
return
if use_poly_reduce:
changed = self.fuse(_reduce_output) or changed
if enable_stitch_fusion:
changed = self.fuse(_reduce_stitch) or changed
self.fuse(_transpose)
if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
return
fuse_func(_transpose)
return

enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)
if fuse_func is None:
_fuse_loop()
else:
_fuse_once(fuse_func)

class GraphSplitAscend(GraphSplitByPattern):

@@ -580,8 +831,8 @@ class GraphSplitAscend(GraphSplitByPattern):
return self.Area.MODE_COMPOSITE
return self.Area.MODE_BASIC

def do_split(self):
"""Split graph by pattern"""
def pattern_fuse(self, fuse_func=None):
"""fuse Areas by pattern"""
def _tensor_size(tensor):
size = 1
for i in tensor.shape:

@@ -685,6 +936,19 @@ class GraphSplitAscend(GraphSplitByPattern):
fused.append(a)
return fused, False

def _reduce_output(dom):
if dom.pattern != PrimLib.REDUCE:
return None
op_attrs = dom.dom_op().attrs
if not op_attrs.get('reduce_output_fuse'):
return None
fused = []
for a, r in dom.out_relations.items():
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
dom.check_acyclic(a):
fused.append(a)
return fused, False

def _transdata_pattern_support(dom, a):
transdata_op = dom.dom_op()

@@ -733,32 +997,31 @@ class GraphSplitAscend(GraphSplitByPattern):
fused.append(a)
return fused, True

def _reduce_output(dom):
if dom.pattern != PrimLib.REDUCE:
return None
op_attrs = dom.dom_op().attrs
if not op_attrs.get('reduce_output_fuse'):
return None
fused = []
for a, r in dom.out_relations.items():
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
dom.check_acyclic(a):
fused.append(a)
return fused, False
def _fuse_loop():
changed = True
while changed:
changed = self.fuse(_reshape)
changed = self.fuse(_elemwise_depth) or changed
changed = self.fuse(_elemwise_width) or changed
changed = self.fuse(_reduce_depth) or changed
changed = self.fuse(_reduce_width) or changed
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
changed = self.fuse(_matmul_depth) or changed
changed = self.fuse(_reduce_output) or changed
self.fuse(_transdata)

changed = True
while changed:
changed = self.fuse(_reshape)
changed = self.fuse(_elemwise_depth) or changed
changed = self.fuse(_elemwise_width) or changed
changed = self.fuse(_reduce_depth) or changed
changed = self.fuse(_reduce_width) or changed
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
changed = self.fuse(_matmul_depth) or changed
changed = self.fuse(_reduce_output) or changed
self.fuse(_transdata)
def _fuse_once(fuse_func):
if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
fuse_func(_broadcast_width) or fuse_func(_matmul_depth) or fuse_func(_reduce_output) or \
fuse_func(_transdata):
pass

if fuse_func is None:
_fuse_loop()
else:
_fuse_once(fuse_func)

def split(graph, target, flags):
|
|
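The splitter above grows a cheap region backward from a tensor that has more than one user and then copies it into the recompute Area. A minimal, self-contained sketch of that growth rule, with plain dicts standing in for the model's Operator/Tensor objects (names here are illustrative, not MindSpore API):

    # Walk upward from the region tail through ops that have a single input and
    # are no more expensive than a broadcast; stop at a subgraph input and cap
    # the region at MAX_WEIGHT ops, mirroring _grow_region above.
    # Assumptions: op["inputs"] holds tensor names, graph_inputs is a set of
    # names, producer_of maps a name to its producing op, iter_type ranks an op.
    MAX_WEIGHT = 20
    BROADCAST = 2  # stand-in for PrimLib.BROADCAST

    def grow_cheap_region(tail_op, producer_of, iter_type, graph_inputs):
        """Return the ops of a cheap region ending at tail_op, in topological order."""
        region, op, weight = [], tail_op, 1
        while op is not None and weight <= MAX_WEIGHT:
            if len(op["inputs"]) != 1 or iter_type(op) > BROADCAST:
                return []  # op is not cheap enough, give up on this region
            region.append(op)
            src = op["inputs"][0]
            if src in graph_inputs:
                return list(reversed(region))  # region successfully ends at an input
            op = producer_of.get(src)
            weight += 1
        return []  # weight cap reached without hitting a subgraph input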
@@ -320,8 +320,8 @@ class Operator:

def __str__(self):
args = ', '.join([str(t) for t in self.all_inputs])
expr = "%s = %s.%s(%s)" % (
str(self.output), self.prim, self.output.dtype, args)
expr = "%s = %s.%s(%s) id:%s" % (
str(self.output), self.prim, self.output.dtype, args, id(self))
return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))

def __repr__(self):

@@ -331,12 +331,13 @@ class Operator:
class Graph:
"""Graph"""

def __init__(self, name, ops, stitch_info=None):
def __init__(self, name, ops, stitch_info=None, recompute_ops=None):
self.name = name
self.ops = ops  # in topo order, can not use set
self.inputs = []
self.outputs = []
self.stitch_info = stitch_info
self.recompute_ops = recompute_ops

def set_processor(self, processor):
"""Set processor"""

@@ -203,11 +203,13 @@ class CompositeGraph:
desc['buffer_stitch'] = buffer_stitch
return desc

def dump(self, subgraph):
"""Dump Graph to json"""
desc = {}
inputs, outputs = subgraph.deduce_parameters()
graph_ops = set(subgraph.ops)
def add_recompute_ops(self, subgraph, desc):
if subgraph.recompute_ops:
desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops]
return desc

def _pre_dump(self, outputs):
"""restore name to before load"""
inplace_assign = {}  # y_name, output_name
inplace_assign_z = None
for op in self.desc['op_desc']:

@@ -217,6 +219,14 @@ class CompositeGraph:
for t in outputs:
if t.name not in inplace_assign:
inplace_assign_z = t
return inplace_assign, inplace_assign_z

def dump(self, subgraph):
"""Dump Graph to json"""
desc = {}
inputs, outputs = subgraph.deduce_parameters()
graph_ops = set(subgraph.ops)
inplace_assign, inplace_assign_z = self._pre_dump(outputs)
for key in self.desc:
if key == 'input_desc':
desc[key] = [

@@ -251,7 +261,7 @@ class CompositeGraph:
op_desc.append(inplace_desc)
else:
op = self.tensors[d['output_desc'][0]['tensor_name']].op
if op in graph_ops:
if op in graph_ops or op in subgraph.recompute_ops:
op_desc.append(d)
desc[key] = op_desc
elif key == 'op':

@@ -260,6 +270,7 @@ class CompositeGraph:
desc[key] = self.desc[key]

desc = self.add_stitch_info(subgraph, desc)
desc = self.add_recompute_ops(subgraph, desc)
return desc
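add_recompute_ops above records the recompute region in the dumped kernel description under a 'recompute_ops' key, listing each op by the name of its output tensor (the same key is read as kJsonKeyRecomputeOps on the C++ side). A sketch of the resulting shape, with illustrative names only:

    # Hypothetical dumped description: 'recompute_ops' holds output tensor names
    # so the decoder can tell which op_desc entries belong to the copied region.
    desc = {
        "op": "Fused_Cast_ReduceSum",      # illustrative kernel name
        "op_desc": [],                     # per-op descriptions elided here
        "recompute_ops": ["output_0_0"],   # illustrative tensor name
    }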
@@ -433,68 +433,5 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_js
auto kernel_json = nlohmann::json::parse(kernel_json_str);
return DecodeFusedNodes(kernel_json);
}

StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) const {
StitchInfo info;
if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
info.stitch_ops = stitch_ops;
}
if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
info.stitch_atomic_ops = stitch_atomic_ops;
}
}
return info;
}

void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info,
const CNodePtr &node) const {
std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
std::string tensor_name = output_descs[0][kJsonKeyTensorName];
if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope();
}
if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
info.stitch_atomic_ops.end()) {
AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope();
}
}

bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
const std::map<std::string, AnfNodePtr> &address_node_map,
AnfNodePtrList *res_graphs) {
MS_EXCEPTION_IF_NULL(res_graphs);
MS_LOG(DEBUG) << "start decode, " << kernel_json;
// decode cnodes in graph.
std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
if (op_node_descs.empty()) {
MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json;
return false;
}
StitchInfo info = GetStitchInfo(kernel_json);
for (const auto &op_desc : op_node_descs) {
if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
return false;
}

std::string ptr_address = op_desc[kJsonKeyPtrAddress];
if (address_node_map.count(ptr_address) == 0) {
MS_LOG(ERROR) << "Decode failed, ptr_address not found in map.";
return false;
}
auto node = address_node_map.at(ptr_address)->cast<CNodePtr>();
SetStitchAttr(op_desc, info, node);
res_graphs->push_back(node);
}
MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size();
return true;
}
} // namespace kernel
} // namespace mindspore
@@ -26,10 +26,6 @@

namespace mindspore {
namespace kernel {
struct StitchInfo {
std::vector<std::string> stitch_ops;
std::vector<std::string> stitch_atomic_ops;
};
class AkgKernelJsonDecoder {
public:
AkgKernelJsonDecoder() { nodes_map_.clear(); }

@@ -37,15 +33,11 @@ class AkgKernelJsonDecoder {

FuncGraphPtr DecodeFusedNodes(const nlohmann::json &kernel_json);
FuncGraphPtr DecodeFusedNodes(const std::string &kernel_json_str);
bool DecodeSplitNodes(const nlohmann::json &kernel_json, const std::map<std::string, AnfNodePtr> &address_node_map,
AnfNodePtrList *res_graphs);

private:
ParameterPtr DecodeParameter(const nlohmann::json &parameter_json, const FuncGraphPtr &func_graph);
CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor);
AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph);
StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) const;
void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) const;
std::map<std::string, AnfNodePtr> nodes_map_;
};
} // namespace kernel
@@ -54,6 +54,7 @@ constexpr auto kJsonKeyFusionType = "fusion_type";
constexpr auto kJsonKeySubGraph = "sub_graph";
constexpr auto kJsonKeyCoreNum = "core_num";
constexpr auto kJsonKeyTypeInfo = "type_info";
constexpr auto kJsonKeyRecomputeOps = "recompute_ops";
constexpr auto kJsonKeyBufferStitch = "buffer_stitch";
constexpr auto kJsonKeyStitchOp = "stitch_op";
constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op";
@@ -117,7 +117,7 @@ PassManagerPtr GraphKernelOptimizer::Split() const {
// which can avoid unnecessary input-output and get better performance.
// preprocess for ShapeOpsSplitter
pm->AddPass(std::make_shared<ExtendOutputForUpdateState>(), OptLevel_1);
std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape};
pm->AddPass(std::make_shared<ShapeOpsSplitter>(duplicated_ops), OptLevel_1);

// Split kernel according to costmodel
@@ -32,6 +32,150 @@
#include "utils/context/graph_kernel_flags.h"

namespace mindspore {
namespace kernel {
namespace {
StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) {
StitchInfo info;
if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
info.stitch_ops = stitch_ops;
}
if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
info.stitch_atomic_ops = stitch_atomic_ops;
}
}
return info;
}

std::set<std::string> GetRecomputeOps(const nlohmann::json &kernel_json) {
if (kernel_json.find(kJsonKeyRecomputeOps) != kernel_json.end()) {
std::vector<std::string> recompute_ops = kernel_json[kJsonKeyRecomputeOps];
return std::set<std::string>(recompute_ops.begin(), recompute_ops.end());
}
return std::set<std::string>();
}

bool IsRecomputeOp(const nlohmann::json &op_desc, const std::set<std::string> &recompute_ops) {
std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) {
return false;
}
std::string tensor_name = output_descs[0][kJsonKeyTensorName];
if (recompute_ops.count(tensor_name)) {
return true;
}
return false;
}

CNodePtr NewRecomputeNode(const AnfNodePtr &orig_node, std::map<AnfNodePtr, AnfNodePtr> *node_map) {
auto func_graph = orig_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto cnode = orig_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info()));
auto orig_inputs = cnode->inputs();
std::vector<AnfNodePtr> inputs;
for (auto inp : orig_inputs) {
if (node_map->find(inp) == node_map->end()) {
inputs.push_back(inp);
continue;
}
inputs.push_back((*node_map)[inp]);
}
CNodePtr cp_node = func_graph->NewCNode(inputs);
func_graph->AddNode(cp_node);
cp_node->set_abstract(cnode->abstract());
cp_node->set_forward(cnode->forward().first, cnode->forward().second);
cp_node->set_inputs_value(cnode->inputs_value());
ScopePtr scope = (orig_node->scope() != kDefaultScope) ? orig_node->scope() : kDefaultScope;
cp_node->set_scope(scope);
cp_node->set_kernel_info(cnode->kernel_info_ptr());
(*node_map)[orig_node] = cp_node;
return cp_node->cast<CNodePtr>();
}

void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) {
std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
std::string tensor_name = output_descs[0][kJsonKeyTensorName];
if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope();
}
if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
info.stitch_atomic_ops.end()) {
AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope();
}
}

// replace original region root op by its copy in this res_graphs
void ConnectRecomputeOps(AnfNodePtrList *res_graphs, const AnfNodePtr &orig_region_root,
const AnfNodePtr &cp_region_root) {
for (auto &node : *res_graphs) {
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
if (inputs[i] != orig_region_root) continue;
cnode->set_input(i, cp_region_root);
}
}
}
} // namespace

bool SplitNodesDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
const std::map<std::string, AnfNodePtr> &address_node_map,
AnfNodePtrList *res_graphs) {
MS_EXCEPTION_IF_NULL(res_graphs);
MS_LOG(DEBUG) << "start decode, " << kernel_json;
// decode cnodes in graph.
std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
if (op_node_descs.empty()) {
MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json;
return false;
}
StitchInfo info = GetStitchInfo(kernel_json);
auto recompute_ops = GetRecomputeOps(kernel_json);
// key_value: original_copied
std::map<AnfNodePtr, AnfNodePtr> node_map;
// nodes would be copied
AnfNodePtrList orig_region_nodes;
// nodes would not be copied
AnfNodePtrList no_cp_nodes;
for (const auto &op_desc : op_node_descs) {
if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
return false;
}

std::string ptr_address = op_desc[kJsonKeyPtrAddress];
if (address_node_map.count(ptr_address) == 0) {
MS_LOG(ERROR) << "Decode failed, ptr_address not found in map.";
return false;
}
auto node = address_node_map.at(ptr_address)->cast<CNodePtr>();
if (IsRecomputeOp(op_desc, recompute_ops)) {
auto cp_node = NewRecomputeNode(node, &node_map);
orig_region_nodes.push_back(node);
SetStitchAttr(op_desc, info, cp_node);
res_graphs->push_back(cp_node);
continue;
}
SetStitchAttr(op_desc, info, node);
res_graphs->push_back(node);
no_cp_nodes.push_back(node);
}
for (auto orig_node : orig_region_nodes) {
ConnectRecomputeOps(&no_cp_nodes, orig_node, node_map[orig_node]);
}
MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size();
return true;
}
} // namespace kernel

namespace opt {
namespace {
void TraverseFuncGraphFromCNode(const CNodePtr &cnode, const std::function<void(AnfNodePtr &)> &callback) {

@@ -620,7 +764,7 @@ class CostModelSplitSchemer : public SplitSchemer {
split_plan_.clear();
for (const auto &graph_desc : graph_descs) {
AnfNodePtrList res_graph;
if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
if (!kernel::SplitNodesDecoder::DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
return false;
}

@@ -731,6 +875,7 @@ class CostModelSplitSchemer : public SplitSchemer {
nlohmann::json flag_json;
flag_json["dump_as_text"] = flags.dump_as_text;
flag_json["enable_stitch_fusion"] = flags.enable_stitch_fusion;
flag_json["enable_recompute_fusion"] = flags.enable_recompute_fusion;
return flag_json.dump();
}
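SplitNodesDecoder above copies each recompute op into a fresh CNode and then redirects consumers inside the same split group from the original region root to that copy. A small Python sketch of the rewiring step, with plain dicts standing in for CNodes (illustrative only, not the C++ API):

    # Mirrors ConnectRecomputeOps: any consumer that reads the original region
    # root is switched to read the copied root instead.
    def connect_recompute_ops(group_nodes, orig_root, cp_root):
        for node in group_nodes:
            node["inputs"] = [cp_root if inp is orig_root else inp
                              for inp in node["inputs"]]

    root = {"name": "sqrt0", "inputs": []}
    root_copy = dict(root, name="sqrt0_copy")
    consumer = {"name": "add0", "inputs": [root]}
    connect_recompute_ops([consumer], root, root_copy)
    assert consumer["inputs"][0] is root_copy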
@@ -15,11 +15,31 @@
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>
#include "ir/func_graph.h"
#include "backend/optimizer/common/pass.h"

namespace mindspore {
namespace kernel {
struct StitchInfo {
std::vector<std::string> stitch_ops;
std::vector<std::string> stitch_atomic_ops;
};

class SplitNodesDecoder {
public:
SplitNodesDecoder() {}
~SplitNodesDecoder() = default;
static bool DecodeSplitNodes(const nlohmann::json &kernel_json,
const std::map<std::string, AnfNodePtr> &address_node_map, AnfNodePtrList *res_graphs);
};
} // namespace kernel

namespace opt {
class GraphKernelSplitter : public Pass {
public:
@@ -181,6 +181,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
// Boolean flags
reg.AddFlag("dump_as_text", &dump_as_text);
reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion, opt_level == OptLevel_3);
reg.AddFlag("enable_recompute_fusion", &enable_recompute_fusion, opt_level == OptLevel_2);
reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion, opt_level == OptLevel_3);

// Integer flags

@@ -203,6 +204,7 @@ std::string GraphKernelFlags::DumpAllFlags() const {

json["dump_as_text"] = dump_as_text;
json["enable_stitch_fusion"] = enable_stitch_fusion;
json["enable_recompute_fusion"] = enable_recompute_fusion;
json["enable_parallel_fusion"] = enable_parallel_fusion;

json["opt_level"] = opt_level;
@@ -67,6 +67,11 @@ class GraphKernelFlags {
*/
bool enable_stitch_fusion;

/**
* Enable recompute fusion in graph kernel fusion strategy.
*/
bool enable_recompute_fusion{true};

/**
* Enable parallel fusion in graph kernel fusion strategy.
*
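The new flag reaches the Python cost model as a JSON string: CostModelSplitSchemer serializes GraphKernelFlags (see graph_kernel_splitter.cc above) and GraphSplitByPattern.__init__ reads it with flags.get("enable_recompute_fusion", False). A small round-trip sketch using only keys shown in this diff:

    # The C++ side dumps the flags to JSON; the Python splitter loads them as a dict.
    import json

    flag_json = json.dumps({
        "dump_as_text": False,
        "enable_stitch_fusion": False,
        "enable_recompute_fusion": True,
    })
    flags = json.loads(flag_json)
    enable_recompute = flags.get("enable_recompute_fusion", False)
    assert enable_recompute is True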
@@ -0,0 +1,223 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import Cell
import mindspore.ops.operations as P

#{cast} would be recompute and fused
class Net1(Cell):
    def __init__(self):
        super(Net1, self).__init__()
        self.cast = P.Cast()
        self.sum = P.ReduceSum(keep_dims=False)

    def construct(self, x):
        cast_res = self.cast(x, mstype.float32)
        sum1_res = self.sum(cast_res, (0,))
        sum2_res = self.sum(cast_res, (1,))
        return sum1_res, sum2_res

#{sqrt} would be recompute on Ascend
class Net2(Cell):
    def __init__(self):
        super(Net2, self).__init__()
        self.sqrt = P.Sqrt()
        self.sum = P.ReduceSum(keep_dims=True)
        self.add = P.Add()
        self.neg = P.Neg()

    def construct(self, x0, x1):
        sqrt_res = self.sqrt(x0)
        neg_res = self.neg(sqrt_res)
        add_res = self.add(x1, sqrt_res)
        sum_res = self.sum(add_res, (0,))
        return neg_res, sum_res

#{sqrt} would be recompute
class Net3(Cell):
    def __init__(self):
        super(Net3, self).__init__()
        self.sqrt = P.Sqrt()
        self.add = P.Add()
        self.neg = P.Neg()

    def construct(self, x0, x1):
        sqrt_res = self.sqrt(x0)
        neg_res = self.neg(sqrt_res)
        add_res = self.add(x1, sqrt_res)
        return neg_res, add_res

#{sqrt neg} would be recompute
class Net4(Cell):
    def __init__(self):
        super(Net4, self).__init__()
        self.sqrt = P.Sqrt()
        self.neg = P.Neg()
        self.sum = P.ReduceSum(keep_dims=False)

    def construct(self, x):
        sqrt_res = self.sqrt(x)
        neg_res = self.neg(sqrt_res)
        sum1_res = self.sum(neg_res, (0,))
        sum2_res = self.sum(neg_res, (1,))
        return sum1_res, sum2_res

#{sqrt} would be recompute
class Net5(Cell):
    def __init__(self):
        super(Net5, self).__init__()
        self.sqrt = P.Sqrt()
        self.add = P.Add()

    def construct(self, x0, x1, x2):
        sqrt_res = self.sqrt(x0)
        add1_res = self.add(sqrt_res, x1)
        add2_res = self.add(sqrt_res, x2)
        return add1_res, add2_res

def test_basic1(net):
    def get_output(i0, net, enable_graph_kernel=False):
        context.set_context(enable_graph_kernel=enable_graph_kernel)
        net_obj = net()
        output = net_obj(i0)
        return output

    i0 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16))
    expect = get_output(i0, net, False)
    output = get_output(i0, net, True)
    expect0_np = expect[0].asnumpy().copy()
    output0_np = output[0].asnumpy().copy()
    expect1_np = expect[1].asnumpy().copy()
    output1_np = output[1].asnumpy().copy()
    assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3)
    assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3)

def test_basic2(net):
    def get_output(i0, i1, net, enable_graph_kernel=False):
        context.set_context(enable_graph_kernel=enable_graph_kernel)
        net_obj = net()
        output = net_obj(i0, i1)
        return output

    i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float32))
    i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float32))
    expect = get_output(i0, i1, net, False)
    output = get_output(i0, i1, net, True)
    expect0_np = expect[0].asnumpy().copy()
    output0_np = output[0].asnumpy().copy()
    expect1_np = expect[1].asnumpy().copy()
    output1_np = output[1].asnumpy().copy()
    assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3)
    assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3)

def test_basic3(net):
    def get_output(i0, i1, i2, net, enable_graph_kernel=False):
        context.set_context(enable_graph_kernel=enable_graph_kernel)
        net_obj = net()
        output = net_obj(i0, i1, i2)
        return output

    i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float16))
    i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16))
    i2 = Tensor(np.random.uniform(1, 2, [2048, 1024]).astype(np.float16))
    expect = get_output(i0, i1, i2, net, False)
    output = get_output(i0, i1, i2, net, True)
    expect0_np = expect[0].asnumpy().copy()
    output0_np = output[0].asnumpy().copy()
    expect1_np = expect[1].asnumpy().copy()
    output1_np = output[1].asnumpy().copy()
    assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3)
    assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_1():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_basic1(Net1)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_2():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_basic2(Net2)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_3():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_basic2(Net3)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_4():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_basic1(Net4)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gpu_5():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_basic3(Net5)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_1():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_basic1(Net1)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_2():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_basic2(Net2)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_3():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_basic2(Net3)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_4():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_basic1(Net4)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_5():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_basic3(Net5)