forked from mindspore-Ecosystem/mindspore
recompute_fuse
This commit is contained in:
parent 11d6b435c2
commit a995bea507
@@ -16,7 +16,7 @@
 import os
 from functools import reduce
 from mindspore import log as logger
-from .model import PrimLib, Graph, Tensor
+from .model import PrimLib, Graph, Tensor, Operator
 from .model import DataFormat as DF


@@ -65,13 +65,16 @@ class GraphSplitByPattern:
 self.stitch_ops = set()
 self.stitch_atomic_ops = set()

-def __init__(self, init_op, is_output, unique_id, reach_tab):
+def __init__(self, init_op, is_output, unique_id, reach_tab, recompute_ops=None):
-self.pattern = PrimLib.iter_type(init_op)
+self.pattern = PrimLib.iter_type(init_op) if init_op is not None else PrimLib.UNKNOWN
-self.ops = [init_op]
+self.ops = [] if init_op is None else [init_op]
 self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
 self.out_relations = dict() # {area1: relation1, area2: relation2, ...}
 self.mode = None
 self.stitch_info = self.StitchInfo()
+self.recompute_ops = [] if recompute_ops is None else recompute_ops
+self.ori_op_map = {}
+self.is_recompute = False
 self.is_output = is_output
 self.output_excluded = set()
 if self.pattern == PrimLib.REDUCE:
@@ -143,6 +146,8 @@ class GraphSplitByPattern:
 r = rels.pop(area)
 _update_relation(rels, self, r)

+if area.is_recompute:
+self.cp_ops(area)
 if self.pattern >= area.pattern:
 self.ops.extend(area.ops)
 else:
@@ -161,7 +166,9 @@ class GraphSplitByPattern:
 if area.output_excluded:
 self.output_excluded.update(area.output_excluded)
 self.update_stitch_info(area.stitch_info)
-self.reach_tab.fuse(self.unique_id, area.unique_id)
+if not area.is_recompute:
+self.reach_tab.fuse(self.unique_id, area.unique_id)
+self.recompute_ops.extend(area.recompute_ops)

 def check_acyclic(self, to):
 """Check circle. It returns false if circle exists"""
@@ -180,25 +187,73 @@ class GraphSplitByPattern:
 return True
 return False

+def cp_ops(self, area):
+"""copy recompute_ops in area to ops, self is area's user"""
+tail_tensor = area.recompute_ops[-1].output
+#copy tensors, all copied are Tensor.PARA_NONE
+tensor_map = {}
+tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0]
+for op in area.recompute_ops:
+orig_tensor = op.output
+cp_tensor = Tensor(orig_tensor.name, orig_tensor.shape, orig_tensor.dtype, orig_tensor.data_format)
+tensor_map[orig_tensor] = cp_tensor
+#copy ops
+cp_ops = []
+for op in area.recompute_ops:
+cp_op = Operator(op.prim, [tensor_map[op.inputs[0]]], tensor_map[op.output], op.attrs)
+cp_op.all_inputs = cp_op.inputs
+cp_ops.append(cp_op)
+area.ori_op_map[cp_op] = op
+#connect copied ops
+for op in self.ops:
+if tail_tensor in op.inputs:
+op.inputs.remove(tail_tensor)
+op.inputs.append(tensor_map[tail_tensor])
+tail_tensor.to_ops.remove(op)
+tensor_map[tail_tensor].to_ops.append(op)
+#fill cp_ops in self.recompute_area
+cp_dom_op = None
+for cp, ori in area.ori_op_map.items():
+if ori == area.dom_op():
+cp_dom_op = cp
+area.ops.clear()
+area.ops.append(cp_dom_op)
+area.ops.extend([op for op in cp_ops if op != cp_dom_op])

 def __init__(self, graph, flags):
 self.graph = graph
 self.areas = []
 self.flags = flags
-self.reach_tab = self.ReachTable(len(graph.ops))
+self.enable_recompute = self.flags.get("enable_recompute_fusion", False)
-area_map = {}
+self.reach_tab = self.ReachTable(len(graph.ops) + 1 if self.enable_recompute else len(graph.ops))
+self.area_map = {}
 _, outputs = graph.deduce_parameters()
-idx = 0
+self.idx = 0
 for op in graph.ops:
 is_output = op.output in outputs
-a = self.Area(op, is_output, idx, self.reach_tab)
+a = self.Area(op, is_output, self.idx, self.reach_tab)
-idx += 1
+self.idx += 1
 self.set_default_mode(a)
 self.areas.append(a)
-area_map[op] = a
+self.set_area_map([op], a)
 for a in self.areas:
-a.link_input(area_map)
+a.link_input(self.area_map)
 for i in range(len(self.areas)-1, -1, -1):
 self.areas[i].link_output()
+if self.enable_recompute:
+self.recom_area = self.Area(None, False, self.idx, self.reach_tab)
+self.recom_area.is_recompute = True
+self.recom_pre = None
+self.recom_user = None
+self.recom_dom = None
+self.dom_user_r = PrimLib.UNKNOWN
+self.recom_res = False
+self.orig_op_map = {}
+
+def set_area_map(self, ops, area):
+"""update area_map after op fused to area"""
+for op in ops:
+self.area_map[op] = area
+
 def set_default_mode(self, area):
 area.mode = self.get_default_mode(area.ops[0])
@@ -234,11 +289,13 @@ class GraphSplitByPattern:
 if is_forward:
 for area in fuse_areas:
 dominant.fuse(area)
+self.set_area_map(area.ops, dominant)
 self.areas.remove(area)
 else:
 forward_area = dominant
 for area in fuse_areas:
 area.fuse(forward_area)
+self.set_area_map(forward_area.ops, area)
 self.areas.remove(forward_area)
 forward_area = area
 changed = True
@@ -246,16 +303,39 @@ class GraphSplitByPattern:
 else:
 return changed

-def to_subgraphs(self):
+def fuse_recom(self, selector):
-"""Transform op groups to subgraphs"""
+"""Fuse recompute area to its user"""
+for dominant in [self.recom_area, self.recom_user]:
+result = selector(dominant)
+if result is not None and result[0]:
+fuse_areas, _ = result
+fuse_areas = self.limit_area_size(dominant, fuse_areas)
+if not fuse_areas:
+continue
+if fuse_areas[0] in [self.recom_area, self.recom_user]:
+self.recom_user.fuse(self.recom_area)
+self.recom_res = True
+return True
+return False
+
+def index_op(self):
+"""index op by order, the copied op share id with original op, for topo-sort"""
 ids = {}
 for i, op in enumerate(self.graph.ops):
 ids[op] = i
+if hasattr(self, 'orig_op_map'):
+for k, v in self.orig_op_map.items():
+ids[k] = ids[v]
+return ids
+
+def to_subgraphs(self):
+"""Transform op groups to subgraphs"""
+ids = self.index_op()
 subgraphs = []
 graphmodes = []
 for i, area in enumerate(self.areas):
 area.ops.sort(key=lambda op: ids[op])
-subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info))
+subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info, area.recompute_ops))
 graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
 return subgraphs, graphmodes

@@ -274,13 +354,14 @@ class GraphSplitByPattern:
 with os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), 'w+') as f:
 f.write(subgraphs_str)

-def do_split(self):
+def pattern_fuse(self, select=None):
-"""Split graph by pattern"""
+"""fuse Areas by pattern repeatedly"""
-raise Exception("do_split() is not implemented in {}".format(self.__class__.__name__))
+raise Exception("pattern_fuse() is not implemented in {}".format(self.__class__.__name__))

 def split(self):
 """Split graph by pattern"""
-self.do_split()
+self.pattern_fuse()
+self.recompute_fuse()
 # The reshape should not be output node
 # Note: after this function, the input output relation is not maintained.
 self.split_output_reshapes()
@@ -316,6 +397,159 @@ class GraphSplitByPattern:
 if new_areas:
 self.areas += new_areas

+def set_recompute(self, dom_area, ops, user_area):
+"""set the recompute area and connect with other areas"""
+self.recom_area.recompute_ops.extend(ops)
+#recom_area: set dom_op and correct ops length
+patterns = [PrimLib.iter_type(op) for op in ops]
+self.recom_area.pattern = max(patterns)
+for i, pat in enumerate(patterns):
+if pat == self.recom_area.pattern:
+self.recom_area.ops = [ops[i]] * len(ops)
+break
+#disconnect dom_area and user_area
+self.dom_user_r = dom_area.out_relations[user_area]
+dom_area.out_relations.pop(user_area)
+user_area.in_relations.pop(dom_area)
+#connect recom_area and user_area
+user_area.in_relations[self.recom_area] = self.dom_user_r
+self.recom_area.out_relations[user_area] = self.dom_user_r
+#connect recom_pre and recom_area
+self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs[0].op else None
+if self.recom_pre is not None:
+self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre]
+self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre]
+#set related areas
+self.recom_user = user_area
+self.recom_dom = dom_area
+self.recom_res = False
+
+def clear_recompute(self):
+"""disconnect recom_area from other areas, and clear recom_area"""
+self.recom_area.out_relations.clear()
+self.recom_area.in_relations.clear()
+if not self.recom_res:
+self.recom_user.in_relations.pop(self.recom_area)
+self.recom_user.in_relations[self.recom_dom] = self.dom_user_r
+self.recom_dom.out_relations[self.recom_user] = self.dom_user_r
+if self.recom_pre:
+self.recom_pre.out_relations.pop(self.recom_area)
+self.recom_area.ops.clear()
+self.recom_area.recompute_ops.clear()
+self.orig_op_map.update(self.recom_area.ori_op_map)
+self.recom_area.ori_op_map.clear()
+
+
+def to_subgraph(self, dom):
+"""Transform area to subgraphs"""
+ids = self.index_op()
+dom_ops = list()
+dom_ops.extend(dom.ops)
+dom_ops.sort(key=lambda op: ids[op])
+subgraph = []
+subgraph = Graph('{}_area'.format(self.graph.name), dom_ops)
+return subgraph
+
+def find_cheap_regions(self, dom):
+"""extract all the cheap regions in dom area, toposort each region before return"""
+def _grow_region(region_ops, op, weight, inputs):
+"""include op to region_ops if region grow"""
+# region successfully ends at input
+if op.inputs[0] in inputs and len(op.inputs) == 1 and \
+PrimLib.iter_type(op) <= PrimLib.BROADCAST:
+region_ops.append(op)
+return False, None, weight
+#region fails to grow
+MAX_WEIGHT = 20
+if weight > MAX_WEIGHT or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST:
+return False, None, weight
+#region grows successfully
+weight = weight + 1
+region_ops.append(op)
+return True, op.inputs[0].op, weight
+
+def _find_cheap_regions(dom):
+sub = self.to_subgraph(dom)
+inputs, outputs = sub.deduce_parameters()
+cheap_regions = []
+for output in outputs:
+# tensor should have user other than user_area to be fused
+if output.para_type != Tensor.PARA_OUTPUT and len(output.to_ops) < 2:
+continue
+region_ops = []
+grow = True
+candidate_op = output.op
+weight = 1
+while grow:
+grow, candidate_op, weight = _grow_region(region_ops, candidate_op, weight, inputs)
+# region ends at input and not empty
+if region_ops and region_ops[-1].inputs[0] in inputs:
+region_ops.reverse()
+# tensor size should equal or becomes larger(cast up, broadcast)
+if region_ops[0].inputs[0].get_size() > region_ops[-1].output.get_size():
+continue
+cheap_regions.append(region_ops)
+return cheap_regions
+
+return _find_cheap_regions(dom)
+
+def select_user_area(self, tail_tensor):
+"""select the user area has only one edge to dom area"""
+def _get_edge_num(dom_area, user_area):
+"""get edge num between two areas"""
+dom_graph = self.to_subgraph(dom_area)
+_, dom_outputs = dom_graph.deduce_parameters()
+user_graph = self.to_subgraph(user_area)
+user_inputs, _ = user_graph.deduce_parameters()
+edge = [t for t in dom_outputs if t in user_inputs]
+return len(edge)
+
+def _select_user_area(tail_tensor):
+user_areas = []
+for user_op in tail_tensor.to_ops:
+user_area = self.area_map[user_op]
+if len(user_area.ops) == 1 and user_area.pattern == PrimLib.RESHAPE:
+continue
+edge_num = _get_edge_num(self.area_map[tail_tensor.op], user_area)
+if edge_num == 1 and not user_area in user_areas:
+user_areas.append(user_area)
+return user_areas
+
+return _select_user_area(tail_tensor)
+
+def recompute_fuse(self):
+"""find recompute regions and copy them out to new Areas"""
+def do_recompute_fuse():
+"""split the unfusing pattern by add recompute area"""
+recompute_suc = False
+orig_areas = []
+orig_areas.extend(self.areas)
+for dom in orig_areas:
+if dom not in self.areas or not dom.out_relations:
+continue
+cheap_regions = self.find_cheap_regions(dom)
+dom_changed = False
+for cheap_region in cheap_regions:
+user_areas = self.select_user_area(cheap_region[-1].output)
+if not user_areas:
+continue
+for user_area in user_areas:
+self.set_recompute(dom, cheap_region, user_area)
+self.pattern_fuse(self.fuse_recom)
+self.clear_recompute()
+if self.recom_res:
+recompute_suc = True
+#Copy region at most once for this dom
+dom_changed = True
+break
+if dom_changed:
+break
+return recompute_suc
+
+if self.enable_recompute:
+while do_recompute_fuse():
+self.pattern_fuse()
+
 use_poly_reduce = True


@@ -331,8 +565,8 @@ class GraphSplitGpu(GraphSplitByPattern):
 pattern = PrimLib.iter_type(op)
 return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE

-def do_split(self):
+def pattern_fuse(self, fuse_func=None):
-"""Split graph by pattern"""
+"""fuse Areas by pattern"""
 def _reshape(dom):
 if dom.pattern != PrimLib.RESHAPE:
 return None
@@ -551,21 +785,38 @@ class GraphSplitGpu(GraphSplitByPattern):
 fused.append(a)
 return fused, True

-enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)
+def _fuse_loop():
 changed = True
 while changed:
 changed = self.fuse(_reshape)
 changed = self.fuse(_elemwise_depth) or changed
 changed = self.fuse(_elemwise_width) or changed
 changed = self.fuse(_reduce_depth) or changed
 changed = self.fuse(_reduce_width) or changed
 changed = self.fuse(_broadcast_depth) or changed
 changed = self.fuse(_broadcast_width) or changed
+if use_poly_reduce:
+changed = self.fuse(_reduce_output) or changed
+if enable_stitch_fusion:
+changed = self.fuse(_reduce_stitch) or changed
+self.fuse(_transpose)
+
+def _fuse_once(fuse_func):
+if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \
+fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
+fuse_func(_broadcast_width):
+return
 if use_poly_reduce:
-changed = self.fuse(_reduce_output) or changed
+if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
-if enable_stitch_fusion:
+return
-changed = self.fuse(_reduce_stitch) or changed
+fuse_func(_transpose)
-self.fuse(_transpose)
+return
+
+enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)
+if fuse_func is None:
+_fuse_loop()
+else:
+_fuse_once(fuse_func)


 class GraphSplitAscend(GraphSplitByPattern):
@@ -580,8 +831,8 @@ class GraphSplitAscend(GraphSplitByPattern):
 return self.Area.MODE_COMPOSITE
 return self.Area.MODE_BASIC

-def do_split(self):
+def pattern_fuse(self, fuse_func=None):
-"""Split graph by pattern"""
+"""fuse Areas by pattern"""
 def _tensor_size(tensor):
 size = 1
 for i in tensor.shape:
|
||||||
fused.append(a)
|
fused.append(a)
|
||||||
return fused, False
|
return fused, False
|
||||||
|
|
||||||
|
def _reduce_output(dom):
|
||||||
|
if dom.pattern != PrimLib.REDUCE:
|
||||||
|
return None
|
||||||
|
op_attrs = dom.dom_op().attrs
|
||||||
|
if not op_attrs.get('reduce_output_fuse'):
|
||||||
|
return None
|
||||||
|
fused = []
|
||||||
|
for a, r in dom.out_relations.items():
|
||||||
|
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
|
||||||
|
dom.check_acyclic(a):
|
||||||
|
fused.append(a)
|
||||||
|
return fused, False
|
||||||
|
|
||||||
def _transdata_pattern_support(dom, a):
|
def _transdata_pattern_support(dom, a):
|
||||||
transdata_op = dom.dom_op()
|
transdata_op = dom.dom_op()
|
||||||
|
|
||||||
|
@@ -733,32 +997,31 @@ class GraphSplitAscend(GraphSplitByPattern):
 fused.append(a)
 return fused, True

-def _reduce_output(dom):
+def _fuse_loop():
-if dom.pattern != PrimLib.REDUCE:
+changed = True
-return None
+while changed:
-op_attrs = dom.dom_op().attrs
+changed = self.fuse(_reshape)
-if not op_attrs.get('reduce_output_fuse'):
+changed = self.fuse(_elemwise_depth) or changed
-return None
+changed = self.fuse(_elemwise_width) or changed
-fused = []
+changed = self.fuse(_reduce_depth) or changed
-for a, r in dom.out_relations.items():
+changed = self.fuse(_reduce_width) or changed
-if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
+changed = self.fuse(_broadcast_depth) or changed
-dom.check_acyclic(a):
+changed = self.fuse(_broadcast_width) or changed
-fused.append(a)
+changed = self.fuse(_matmul_depth) or changed
-return fused, False
+changed = self.fuse(_reduce_output) or changed
+self.fuse(_transdata)
+
-changed = True
+def _fuse_once(fuse_func):
-while changed:
+if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \
-changed = self.fuse(_reshape)
+fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
-changed = self.fuse(_elemwise_depth) or changed
+fuse_func(_broadcast_width) or fuse_func(_matmul_depth) or fuse_func(_reduce_output) or \
-changed = self.fuse(_elemwise_width) or changed
+fuse_func(_transdata):
-changed = self.fuse(_reduce_depth) or changed
+pass
-changed = self.fuse(_reduce_width) or changed
-changed = self.fuse(_broadcast_depth) or changed
-changed = self.fuse(_broadcast_width) or changed
-changed = self.fuse(_matmul_depth) or changed
-changed = self.fuse(_reduce_output) or changed
-self.fuse(_transdata)

+if fuse_func is None:
+_fuse_loop()
+else:
+_fuse_once(fuse_func)


 def split(graph, target, flags):
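Note (not part of the diff): the recompute fusion added above duplicates a "cheap" producer region so that one consumer area can fuse a private copy of it while other consumers keep the original. A minimal, self-contained Python sketch of that rewiring idea, using toy classes rather than the real Operator/Tensor model, is:

# Toy illustration of the recompute idea (hypothetical classes, not MindSpore internals).
class Op:
    def __init__(self, name, inputs):
        self.name = name
        self.inputs = inputs          # producing Ops
        self.users = []               # consuming Ops
        for i in inputs:
            i.users.append(self)

def recompute_for_user(cheap_op, user_op):
    """Give user_op its own copy of cheap_op so the two consumers no longer share it."""
    copy_op = Op(cheap_op.name + "_copy", list(cheap_op.inputs))
    # rewire only this user to the copy; all other users keep the original op
    user_op.inputs = [copy_op if i is cheap_op else i for i in user_op.inputs]
    cheap_op.users.remove(user_op)
    copy_op.users.append(user_op)
    return copy_op

# usage: x -> cast -> {sum1, sum2}; afterwards sum2 consumes its own copy of cast
x = Op("x", [])
cast = Op("cast", [x])
sum1 = Op("sum1", [cast])
sum2 = Op("sum2", [cast])
recompute_for_user(cast, sum2)
assert cast.users == [sum1] and sum2.inputs[0].name == "cast_copy"

In the real pass, cp_ops() performs the analogous rewiring on Operator/Tensor objects, while set_recompute()/clear_recompute() temporarily wire the copied region into the area graph and pattern_fuse(self.fuse_recom) decides whether the copy is worth keeping.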
@@ -320,8 +320,8 @@ class Operator:

 def __str__(self):
 args = ', '.join([str(t) for t in self.all_inputs])
-expr = "%s = %s.%s(%s)" % (
+expr = "%s = %s.%s(%s) id:%s" % (
-str(self.output), self.prim, self.output.dtype, args)
+str(self.output), self.prim, self.output.dtype, args, id(self))
 return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))

 def __repr__(self):
@@ -331,12 +331,13 @@ class Operator:
 class Graph:
 """Graph"""

-def __init__(self, name, ops, stitch_info=None):
+def __init__(self, name, ops, stitch_info=None, recompute_ops=None):
 self.name = name
 self.ops = ops # in topo order, can not use set
 self.inputs = []
 self.outputs = []
 self.stitch_info = stitch_info
+self.recompute_ops = recompute_ops

 def set_processor(self, processor):
 """Set processor"""
@@ -203,11 +203,13 @@ class CompositeGraph:
 desc['buffer_stitch'] = buffer_stitch
 return desc

-def dump(self, subgraph):
+def add_recompute_ops(self, subgraph, desc):
-"""Dump Graph to json"""
+if subgraph.recompute_ops:
-desc = {}
+desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops]
-inputs, outputs = subgraph.deduce_parameters()
+return desc
-graph_ops = set(subgraph.ops)
+
+def _pre_dump(self, outputs):
+"""restore name to before load"""
 inplace_assign = {} # y_name, output_name
 inplace_assign_z = None
 for op in self.desc['op_desc']:
@@ -217,6 +219,14 @@ class CompositeGraph:
 for t in outputs:
 if t.name not in inplace_assign:
 inplace_assign_z = t
+return inplace_assign, inplace_assign_z
+
+def dump(self, subgraph):
+"""Dump Graph to json"""
+desc = {}
+inputs, outputs = subgraph.deduce_parameters()
+graph_ops = set(subgraph.ops)
+inplace_assign, inplace_assign_z = self._pre_dump(outputs)
 for key in self.desc:
 if key == 'input_desc':
 desc[key] = [
@@ -251,7 +261,7 @@ class CompositeGraph:
 op_desc.append(inplace_desc)
 else:
 op = self.tensors[d['output_desc'][0]['tensor_name']].op
-if op in graph_ops:
+if op in graph_ops or op in subgraph.recompute_ops:
 op_desc.append(d)
 desc[key] = op_desc
 elif key == 'op':
@@ -260,6 +270,7 @@ class CompositeGraph:
 desc[key] = self.desc[key]

 desc = self.add_stitch_info(subgraph, desc)
+desc = self.add_recompute_ops(subgraph, desc)
 return desc

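Note (not part of the diff): assuming the rest of dump() is unchanged, the visible effect of add_recompute_ops() on the emitted kernel JSON is one extra key next to the existing ones. A hedged sketch of the resulting desc, written as a Python dict with made-up names and the per-op descriptions elided, is:

# Illustrative only: the key names come from this diff, the values are invented.
desc = {
    "op": "Fused_Cast_ReduceSum_fusion",   # hypothetical composite kernel name
    "input_desc": ["..."],                  # unchanged
    "op_desc": ["..."],                     # unchanged per-op descriptions
    "buffer_stitch": {"stitch_op": []},     # unchanged, present when stitching is used
    "recompute_ops": ["output_0_1"],        # new: output tensor names of ops to recompute
}

On the C++ side, GetRecomputeOps()/IsRecomputeOp() in the decoder added further below read exactly this key (kJsonKeyRecomputeOps) to decide which nodes to copy.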
@@ -433,68 +433,5 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_js
 auto kernel_json = nlohmann::json::parse(kernel_json_str);
 return DecodeFusedNodes(kernel_json);
 }
-
-StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) const {
-StitchInfo info;
-if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
-nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
-if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
-std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
-info.stitch_ops = stitch_ops;
-}
-if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
-std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
-info.stitch_atomic_ops = stitch_atomic_ops;
-}
-}
-return info;
-}
-
-void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info,
-const CNodePtr &node) const {
-std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
-if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
-std::string tensor_name = output_descs[0][kJsonKeyTensorName];
-if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
-AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
-MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope();
-}
-if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
-info.stitch_atomic_ops.end()) {
-AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
-MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope();
-}
-}
-
-bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
-const std::map<std::string, AnfNodePtr> &address_node_map,
-AnfNodePtrList *res_graphs) {
-MS_EXCEPTION_IF_NULL(res_graphs);
-MS_LOG(DEBUG) << "start decode, " << kernel_json;
-// decode cnodes in graph.
-std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
-if (op_node_descs.empty()) {
-MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json;
-return false;
-}
-StitchInfo info = GetStitchInfo(kernel_json);
-for (const auto &op_desc : op_node_descs) {
-if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
-MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
-return false;
-}
-
-std::string ptr_address = op_desc[kJsonKeyPtrAddress];
-if (address_node_map.count(ptr_address) == 0) {
-MS_LOG(ERROR) << "Decode failed, ptr_address not found in map.";
-return false;
-}
-auto node = address_node_map.at(ptr_address)->cast<CNodePtr>();
-SetStitchAttr(op_desc, info, node);
-res_graphs->push_back(node);
-}
-MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size();
-return true;
-}
 } // namespace kernel
 } // namespace mindspore
@@ -26,10 +26,6 @@

 namespace mindspore {
 namespace kernel {
-struct StitchInfo {
-std::vector<std::string> stitch_ops;
-std::vector<std::string> stitch_atomic_ops;
-};
 class AkgKernelJsonDecoder {
 public:
 AkgKernelJsonDecoder() { nodes_map_.clear(); }
@@ -37,15 +33,11 @@ class AkgKernelJsonDecoder {

 FuncGraphPtr DecodeFusedNodes(const nlohmann::json &kernel_json);
 FuncGraphPtr DecodeFusedNodes(const std::string &kernel_json_str);
-bool DecodeSplitNodes(const nlohmann::json &kernel_json, const std::map<std::string, AnfNodePtr> &address_node_map,
-AnfNodePtrList *res_graphs);

 private:
 ParameterPtr DecodeParameter(const nlohmann::json &parameter_json, const FuncGraphPtr &func_graph);
 CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor);
 AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph);
-StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) const;
-void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) const;
 std::map<std::string, AnfNodePtr> nodes_map_;
 };
 } // namespace kernel
@@ -54,6 +54,7 @@ constexpr auto kJsonKeyFusionType = "fusion_type";
 constexpr auto kJsonKeySubGraph = "sub_graph";
 constexpr auto kJsonKeyCoreNum = "core_num";
 constexpr auto kJsonKeyTypeInfo = "type_info";
+constexpr auto kJsonKeyRecomputeOps = "recompute_ops";
 constexpr auto kJsonKeyBufferStitch = "buffer_stitch";
 constexpr auto kJsonKeyStitchOp = "stitch_op";
 constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op";
@@ -117,7 +117,7 @@ PassManagerPtr GraphKernelOptimizer::Split() const {
 // which can avoid unnecessary input-output and get better performance.
 // preprocess for ShapeOpsSplitter
 pm->AddPass(std::make_shared<ExtendOutputForUpdateState>(), OptLevel_1);
-std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
+std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape};
 pm->AddPass(std::make_shared<ShapeOpsSplitter>(duplicated_ops), OptLevel_1);

 // Split kernel according to costmodel
@@ -32,6 +32,150 @@
 #include "utils/context/graph_kernel_flags.h"

 namespace mindspore {
+namespace kernel {
+namespace {
+StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) {
+StitchInfo info;
+if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
+nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
+if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
+std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
+info.stitch_ops = stitch_ops;
+}
+if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
+std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
+info.stitch_atomic_ops = stitch_atomic_ops;
+}
+}
+return info;
+}
+
+std::set<std::string> GetRecomputeOps(const nlohmann::json &kernel_json) {
+if (kernel_json.find(kJsonKeyRecomputeOps) != kernel_json.end()) {
+std::vector<std::string> recompute_ops = kernel_json[kJsonKeyRecomputeOps];
+return std::set<std::string>(recompute_ops.begin(), recompute_ops.end());
+}
+return std::set<std::string>();
+}
+
+bool IsRecomputeOp(const nlohmann::json &op_desc, const std::set<std::string> &recompute_ops) {
+std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
+if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) {
+return false;
+}
+std::string tensor_name = output_descs[0][kJsonKeyTensorName];
+if (recompute_ops.count(tensor_name)) {
+return true;
+}
+return false;
+}
+
+CNodePtr NewRecomputeNode(const AnfNodePtr &orig_node, std::map<AnfNodePtr, AnfNodePtr> *node_map) {
+auto func_graph = orig_node->func_graph();
+MS_EXCEPTION_IF_NULL(func_graph);
+auto cnode = orig_node->cast<CNodePtr>();
+MS_EXCEPTION_IF_NULL(cnode);
+TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info()));
+auto orig_inputs = cnode->inputs();
+std::vector<AnfNodePtr> inputs;
+for (auto inp : orig_inputs) {
+if (node_map->find(inp) == node_map->end()) {
+inputs.push_back(inp);
+continue;
+}
+inputs.push_back((*node_map)[inp]);
+}
+CNodePtr cp_node = func_graph->NewCNode(inputs);
+func_graph->AddNode(cp_node);
+cp_node->set_abstract(cnode->abstract());
+cp_node->set_forward(cnode->forward().first, cnode->forward().second);
+cp_node->set_inputs_value(cnode->inputs_value());
+ScopePtr scope = (orig_node->scope() != kDefaultScope) ? orig_node->scope() : kDefaultScope;
+cp_node->set_scope(scope);
+cp_node->set_kernel_info(cnode->kernel_info_ptr());
+(*node_map)[orig_node] = cp_node;
+return cp_node->cast<CNodePtr>();
+}
+
+void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) {
+std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
+if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
+std::string tensor_name = output_descs[0][kJsonKeyTensorName];
+if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
+AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
+MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope();
+}
+if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
+info.stitch_atomic_ops.end()) {
+AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
+MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope();
+}
+}
+
+// replace original region root op by its copy in this res_graphs
+void ConnectRecomputeOps(AnfNodePtrList *res_graphs, const AnfNodePtr &orig_region_root,
+const AnfNodePtr &cp_region_root) {
+for (auto &node : *res_graphs) {
+auto cnode = node->cast<CNodePtr>();
+auto inputs = cnode->inputs();
+for (size_t i = 1; i < inputs.size(); ++i) {
+if (inputs[i] != orig_region_root) continue;
+cnode->set_input(i, cp_region_root);
+}
+}
+}
+} // namespace
+
+bool SplitNodesDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
+const std::map<std::string, AnfNodePtr> &address_node_map,
+AnfNodePtrList *res_graphs) {
+MS_EXCEPTION_IF_NULL(res_graphs);
+MS_LOG(DEBUG) << "start decode, " << kernel_json;
+// decode cnodes in graph.
+std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
+if (op_node_descs.empty()) {
+MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json;
+return false;
+}
+StitchInfo info = GetStitchInfo(kernel_json);
+auto recompute_ops = GetRecomputeOps(kernel_json);
+// key_value: original_copied
+std::map<AnfNodePtr, AnfNodePtr> node_map;
+// nodes would be copied
+AnfNodePtrList orig_region_nodes;
+// nodes would not be copied
+AnfNodePtrList no_cp_nodes;
+for (const auto &op_desc : op_node_descs) {
+if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
+MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
+return false;
+}
+
+std::string ptr_address = op_desc[kJsonKeyPtrAddress];
+if (address_node_map.count(ptr_address) == 0) {
+MS_LOG(ERROR) << "Decode failed, ptr_address not found in map.";
+return false;
+}
+auto node = address_node_map.at(ptr_address)->cast<CNodePtr>();
+if (IsRecomputeOp(op_desc, recompute_ops)) {
+auto cp_node = NewRecomputeNode(node, &node_map);
+orig_region_nodes.push_back(node);
+SetStitchAttr(op_desc, info, cp_node);
+res_graphs->push_back(cp_node);
+continue;
+}
+SetStitchAttr(op_desc, info, node);
+res_graphs->push_back(node);
+no_cp_nodes.push_back(node);
+}
+for (auto orig_node : orig_region_nodes) {
+ConnectRecomputeOps(&no_cp_nodes, orig_node, node_map[orig_node]);
+}
+MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size();
+return true;
+}
+} // namespace kernel
+
 namespace opt {
 namespace {
 void TraverseFuncGraphFromCNode(const CNodePtr &cnode, const std::function<void(AnfNodePtr &)> &callback) {
@@ -620,7 +764,7 @@ class CostModelSplitSchemer : public SplitSchemer {
 split_plan_.clear();
 for (const auto &graph_desc : graph_descs) {
 AnfNodePtrList res_graph;
-if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
+if (!kernel::SplitNodesDecoder::DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
 MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
 return false;
 }
@@ -731,6 +875,7 @@ class CostModelSplitSchemer : public SplitSchemer {
 nlohmann::json flag_json;
 flag_json["dump_as_text"] = flags.dump_as_text;
 flag_json["enable_stitch_fusion"] = flags.enable_stitch_fusion;
+flag_json["enable_recompute_fusion"] = flags.enable_recompute_fusion;
 return flag_json.dump();
 }

@@ -15,11 +15,31 @@
 */
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_
+#include <map>
 #include <memory>
+#include <set>
+#include <string>
+#include <vector>
+#include <nlohmann/json.hpp>
 #include "ir/func_graph.h"
 #include "backend/optimizer/common/pass.h"

 namespace mindspore {
+namespace kernel {
+struct StitchInfo {
+std::vector<std::string> stitch_ops;
+std::vector<std::string> stitch_atomic_ops;
+};
+
+class SplitNodesDecoder {
+public:
+SplitNodesDecoder() {}
+~SplitNodesDecoder() = default;
+static bool DecodeSplitNodes(const nlohmann::json &kernel_json,
+const std::map<std::string, AnfNodePtr> &address_node_map, AnfNodePtrList *res_graphs);
+};
+} // namespace kernel
+
 namespace opt {
 class GraphKernelSplitter : public Pass {
 public:
@@ -181,6 +181,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
 // Boolean flags
 reg.AddFlag("dump_as_text", &dump_as_text);
 reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion, opt_level == OptLevel_3);
+reg.AddFlag("enable_recompute_fusion", &enable_recompute_fusion, opt_level == OptLevel_2);
 reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion, opt_level == OptLevel_3);

 // Integer flags
@@ -203,6 +204,7 @@ std::string GraphKernelFlags::DumpAllFlags() const {

 json["dump_as_text"] = dump_as_text;
 json["enable_stitch_fusion"] = enable_stitch_fusion;
+json["enable_recompute_fusion"] = enable_recompute_fusion;
 json["enable_parallel_fusion"] = enable_parallel_fusion;

 json["opt_level"] = opt_level;
@@ -67,6 +67,11 @@ class GraphKernelFlags {
 */
 bool enable_stitch_fusion;

+/**
+* Enable recompute fusion in graph kernel fusion strategy.
+*/
+bool enable_recompute_fusion{true};
+
 /**
 * Enable parallel fusion in graph kernel fusion strategy.
 *
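Note (not part of the diff): with the flag registered above, it should be controllable the same way as the existing graph kernel flags. A hedged usage sketch follows; the graph_kernel_flags context argument is assumed to accept this flag exactly like the other flags registered in RegisterFlags.

import mindspore.context as context

# Graph kernel must be enabled for any of these flags to matter.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_graph_kernel=True)
# Per this diff, enable_recompute_fusion defaults on when opt_level is 2;
# it can also be forced explicitly through the flags string.
context.set_context(graph_kernel_flags="--enable_recompute_fusion=true")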
@@ -0,0 +1,223 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.common import dtype as mstype
+from mindspore.nn import Cell
+import mindspore.ops.operations as P
+
+#{cast} would be recompute and fused
+class Net1(Cell):
+def __init__(self):
+super(Net1, self).__init__()
+self.cast = P.Cast()
+self.sum = P.ReduceSum(keep_dims=False)
+
+def construct(self, x):
+cast_res = self.cast(x, mstype.float32)
+sum1_res = self.sum(cast_res, (0,))
+sum2_res = self.sum(cast_res, (1,))
+return sum1_res, sum2_res
+
+#{sqrt} would be recompute on Ascend
+class Net2(Cell):
+def __init__(self):
+super(Net2, self).__init__()
+self.sqrt = P.Sqrt()
+self.sum = P.ReduceSum(keep_dims=True)
+self.add = P.Add()
+self.neg = P.Neg()
+
+def construct(self, x0, x1):
+sqrt_res = self.sqrt(x0)
+neg_res = self.neg(sqrt_res)
+add_res = self.add(x1, sqrt_res)
+sum_res = self.sum(add_res, (0,))
+return neg_res, sum_res
+
+#{sqrt} would be recompute
+class Net3(Cell):
+def __init__(self):
+super(Net3, self).__init__()
+self.sqrt = P.Sqrt()
+self.add = P.Add()
+self.neg = P.Neg()
+
+def construct(self, x0, x1):
+sqrt_res = self.sqrt(x0)
+neg_res = self.neg(sqrt_res)
+add_res = self.add(x1, sqrt_res)
+return neg_res, add_res
+
+#{sqrt neg} would be recompute
+class Net4(Cell):
+def __init__(self):
+super(Net4, self).__init__()
+self.sqrt = P.Sqrt()
+self.neg = P.Neg()
+self.sum = P.ReduceSum(keep_dims=False)
+
+def construct(self, x):
+sqrt_res = self.sqrt(x)
+neg_res = self.neg(sqrt_res)
+sum1_res = self.sum(neg_res, (0,))
+sum2_res = self.sum(neg_res, (1,))
+return sum1_res, sum2_res
+
+#{sqrt} would be recompute
+class Net5(Cell):
+def __init__(self):
+super(Net5, self).__init__()
+self.sqrt = P.Sqrt()
+self.add = P.Add()
+
+def construct(self, x0, x1, x2):
+sqrt_res = self.sqrt(x0)
+add1_res = self.add(sqrt_res, x1)
+add2_res = self.add(sqrt_res, x2)
+return add1_res, add2_res
+
+def test_basic1(net):
+def get_output(i0, net, enable_graph_kernel=False):
+context.set_context(enable_graph_kernel=enable_graph_kernel)
+net_obj = net()
+output = net_obj(i0)
+return output
+
+i0 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16))
+expect = get_output(i0, net, False)
+output = get_output(i0, net, True)
+expect0_np = expect[0].asnumpy().copy()
+output0_np = output[0].asnumpy().copy()
+expect1_np = expect[1].asnumpy().copy()
+output1_np = output[1].asnumpy().copy()
+assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3)
+assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3)
+
+
+def test_basic2(net):
+def get_output(i0, i1, net, enable_graph_kernel=False):
+context.set_context(enable_graph_kernel=enable_graph_kernel)
+net_obj = net()
+output = net_obj(i0, i1)
+return output
+
+i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float32))
+i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float32))
+expect = get_output(i0, i1, net, False)
+output = get_output(i0, i1, net, True)
+expect0_np = expect[0].asnumpy().copy()
+output0_np = output[0].asnumpy().copy()
+expect1_np = expect[1].asnumpy().copy()
+output1_np = output[1].asnumpy().copy()
+assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3)
+assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3)
+
+def test_basic3(net):
+def get_output(i0, i1, i2, net, enable_graph_kernel=False):
+context.set_context(enable_graph_kernel=enable_graph_kernel)
+net_obj = net()
+output = net_obj(i0, i1, i2)
+return output
+
+i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float16))
+i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16))
+i2 = Tensor(np.random.uniform(1, 2, [2048, 1024]).astype(np.float16))
+expect = get_output(i0, i1, i2, net, False)
+output = get_output(i0, i1, i2, net, True)
+expect0_np = expect[0].asnumpy().copy()
+output0_np = output[0].asnumpy().copy()
+expect1_np = expect[1].asnumpy().copy()
+output1_np = output[1].asnumpy().copy()
+assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3)
+assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_1():
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+test_basic1(Net1)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_2():
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+test_basic2(Net2)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_3():
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+test_basic2(Net3)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_4():
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+test_basic1(Net4)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gpu_5():
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+test_basic3(Net5)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_ascend_1():
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+test_basic1(Net1)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_ascend_2():
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+test_basic2(Net2)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_ascend_3():
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+test_basic2(Net3)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_ascend_4():
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+test_basic1(Net4)
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_ascend_5():
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+test_basic3(Net5)
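Note (not part of the diff): each case above builds its expectation by running the same net with enable_graph_kernel off and on and comparing the results with np.allclose. For a quick manual run of one case outside the pytest markers, something like the following works, assuming it is appended to the same test file so the definitions above are in scope:

if __name__ == "__main__":
    # Manual smoke test mirroring test_gpu_1; CI selects cases via the pytest markers instead.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_basic1(Net1)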