From a995bea5072b8627f3786eb99a033467cac44922 Mon Sep 17 00:00:00 2001 From: lingyunli63 Date: Mon, 24 May 2021 09:31:43 +0800 Subject: [PATCH] recompute_fuse --- .../graph_kernel/model/graph_split.py | 385 +++++++++++++++--- .../_extends/graph_kernel/model/model.py | 7 +- .../graph_kernel/model/model_builder.py | 23 +- .../akg/akg_kernel_json_decoder.cc | 63 --- .../akg/akg_kernel_json_decoder.h | 8 - .../akg/akg_kernel_json_generator.h | 1 + .../graph_kernel/graph_kernel_optimization.cc | 2 +- .../graph_kernel/graph_kernel_splitter.cc | 147 ++++++- .../graph_kernel/graph_kernel_splitter.h | 20 + .../ccsrc/utils/context/graph_kernel_flags.cc | 2 + .../ccsrc/utils/context/graph_kernel_flags.h | 5 + tests/st/ops/graph_kernel/test_recompute.py | 223 ++++++++++ 12 files changed, 743 insertions(+), 143 deletions(-) create mode 100644 tests/st/ops/graph_kernel/test_recompute.py diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 84ad289e409..28c0aa09dde 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -16,7 +16,7 @@ import os from functools import reduce from mindspore import log as logger -from .model import PrimLib, Graph, Tensor +from .model import PrimLib, Graph, Tensor, Operator from .model import DataFormat as DF @@ -65,13 +65,16 @@ class GraphSplitByPattern: self.stitch_ops = set() self.stitch_atomic_ops = set() - def __init__(self, init_op, is_output, unique_id, reach_tab): - self.pattern = PrimLib.iter_type(init_op) - self.ops = [init_op] + def __init__(self, init_op, is_output, unique_id, reach_tab, recompute_ops=None): + self.pattern = PrimLib.iter_type(init_op) if init_op is not None else PrimLib.UNKNOWN + self.ops = [] if init_op is None else [init_op] self.in_relations = dict() # {area1: relation1, area2: relation2, ...} self.out_relations = dict() # {area1: relation1, area2: relation2, ...} self.mode = None self.stitch_info = self.StitchInfo() + self.recompute_ops = [] if recompute_ops is None else recompute_ops + self.ori_op_map = {} + self.is_recompute = False self.is_output = is_output self.output_excluded = set() if self.pattern == PrimLib.REDUCE: @@ -143,6 +146,8 @@ class GraphSplitByPattern: r = rels.pop(area) _update_relation(rels, self, r) + if area.is_recompute: + self.cp_ops(area) if self.pattern >= area.pattern: self.ops.extend(area.ops) else: @@ -161,7 +166,9 @@ class GraphSplitByPattern: if area.output_excluded: self.output_excluded.update(area.output_excluded) self.update_stitch_info(area.stitch_info) - self.reach_tab.fuse(self.unique_id, area.unique_id) + if not area.is_recompute: + self.reach_tab.fuse(self.unique_id, area.unique_id) + self.recompute_ops.extend(area.recompute_ops) def check_acyclic(self, to): """Check circle. 
It returns false if circle exists""" @@ -180,25 +187,73 @@ class GraphSplitByPattern: return True return False + def cp_ops(self, area): + """copy recompute_ops in area to ops, self is area's user""" + tail_tensor = area.recompute_ops[-1].output + #copy tensors, all copied are Tensor.PARA_NONE + tensor_map = {} + tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0] + for op in area.recompute_ops: + orig_tensor = op.output + cp_tensor = Tensor(orig_tensor.name, orig_tensor.shape, orig_tensor.dtype, orig_tensor.data_format) + tensor_map[orig_tensor] = cp_tensor + #copy ops + cp_ops = [] + for op in area.recompute_ops: + cp_op = Operator(op.prim, [tensor_map[op.inputs[0]]], tensor_map[op.output], op.attrs) + cp_op.all_inputs = cp_op.inputs + cp_ops.append(cp_op) + area.ori_op_map[cp_op] = op + #connect copied ops + for op in self.ops: + if tail_tensor in op.inputs: + op.inputs.remove(tail_tensor) + op.inputs.append(tensor_map[tail_tensor]) + tail_tensor.to_ops.remove(op) + tensor_map[tail_tensor].to_ops.append(op) + #fill cp_ops in self.recompute_area + cp_dom_op = None + for cp, ori in area.ori_op_map.items(): + if ori == area.dom_op(): + cp_dom_op = cp + area.ops.clear() + area.ops.append(cp_dom_op) + area.ops.extend([op for op in cp_ops if op != cp_dom_op]) + def __init__(self, graph, flags): self.graph = graph self.areas = [] self.flags = flags - self.reach_tab = self.ReachTable(len(graph.ops)) - area_map = {} + self.enable_recompute = self.flags.get("enable_recompute_fusion", False) + self.reach_tab = self.ReachTable(len(graph.ops) + 1 if self.enable_recompute else len(graph.ops)) + self.area_map = {} _, outputs = graph.deduce_parameters() - idx = 0 + self.idx = 0 for op in graph.ops: is_output = op.output in outputs - a = self.Area(op, is_output, idx, self.reach_tab) - idx += 1 + a = self.Area(op, is_output, self.idx, self.reach_tab) + self.idx += 1 self.set_default_mode(a) self.areas.append(a) - area_map[op] = a + self.set_area_map([op], a) for a in self.areas: - a.link_input(area_map) + a.link_input(self.area_map) for i in range(len(self.areas)-1, -1, -1): self.areas[i].link_output() + if self.enable_recompute: + self.recom_area = self.Area(None, False, self.idx, self.reach_tab) + self.recom_area.is_recompute = True + self.recom_pre = None + self.recom_user = None + self.recom_dom = None + self.dom_user_r = PrimLib.UNKNOWN + self.recom_res = False + self.orig_op_map = {} + + def set_area_map(self, ops, area): + """update area_map after op fused to area""" + for op in ops: + self.area_map[op] = area def set_default_mode(self, area): area.mode = self.get_default_mode(area.ops[0]) @@ -234,11 +289,13 @@ class GraphSplitByPattern: if is_forward: for area in fuse_areas: dominant.fuse(area) + self.set_area_map(area.ops, dominant) self.areas.remove(area) else: forward_area = dominant for area in fuse_areas: area.fuse(forward_area) + self.set_area_map(forward_area.ops, area) self.areas.remove(forward_area) forward_area = area changed = True @@ -246,16 +303,39 @@ class GraphSplitByPattern: else: return changed - def to_subgraphs(self): - """Transform op groups to subgraphs""" + def fuse_recom(self, selector): + """Fuse recompute area to its user""" + for dominant in [self.recom_area, self.recom_user]: + result = selector(dominant) + if result is not None and result[0]: + fuse_areas, _ = result + fuse_areas = self.limit_area_size(dominant, fuse_areas) + if not fuse_areas: + continue + if fuse_areas[0] in [self.recom_area, self.recom_user]: + 
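                    # the selector paired the recompute area with its user:
                    # commit the fuse and flag success, so clear_recompute()
                    # keeps the rewired graph instead of restoring the old
                    # dom->user edge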
self.recom_user.fuse(self.recom_area) + self.recom_res = True + return True + return False + + def index_op(self): + """index op by order, the copied op share id with original op, for topo-sort""" ids = {} for i, op in enumerate(self.graph.ops): ids[op] = i + if hasattr(self, 'orig_op_map'): + for k, v in self.orig_op_map.items(): + ids[k] = ids[v] + return ids + + def to_subgraphs(self): + """Transform op groups to subgraphs""" + ids = self.index_op() subgraphs = [] graphmodes = [] for i, area in enumerate(self.areas): area.ops.sort(key=lambda op: ids[op]) - subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info)) + subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info, area.recompute_ops)) graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite") return subgraphs, graphmodes @@ -274,13 +354,14 @@ class GraphSplitByPattern: with os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), 'w+') as f: f.write(subgraphs_str) - def do_split(self): - """Split graph by pattern""" - raise Exception("do_split() is not implemented in {}".format(self.__class__.__name__)) + def pattern_fuse(self, select=None): + """fuse Areas by pattern repeatedly""" + raise Exception("pattern_fuse() is not implemented in {}".format(self.__class__.__name__)) def split(self): """Split graph by pattern""" - self.do_split() + self.pattern_fuse() + self.recompute_fuse() # The reshape should not be output node # Note: after this function, the input output relation is not maintained. self.split_output_reshapes() @@ -316,6 +397,159 @@ class GraphSplitByPattern: if new_areas: self.areas += new_areas + def set_recompute(self, dom_area, ops, user_area): + """set the recompute area and connect with other areas""" + self.recom_area.recompute_ops.extend(ops) + #recom_area: set dom_op and correct ops length + patterns = [PrimLib.iter_type(op) for op in ops] + self.recom_area.pattern = max(patterns) + for i, pat in enumerate(patterns): + if pat == self.recom_area.pattern: + self.recom_area.ops = [ops[i]] * len(ops) + break + #disconnect dom_area and user_area + self.dom_user_r = dom_area.out_relations[user_area] + dom_area.out_relations.pop(user_area) + user_area.in_relations.pop(dom_area) + #connect recom_area and user_area + user_area.in_relations[self.recom_area] = self.dom_user_r + self.recom_area.out_relations[user_area] = self.dom_user_r + #connect recom_pre and recom_area + self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs[0].op else None + if self.recom_pre is not None: + self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre] + self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre] + #set related areas + self.recom_user = user_area + self.recom_dom = dom_area + self.recom_res = False + + def clear_recompute(self): + """disconnect recom_area from other areas, and clear recom_area""" + self.recom_area.out_relations.clear() + self.recom_area.in_relations.clear() + if not self.recom_res: + self.recom_user.in_relations.pop(self.recom_area) + self.recom_user.in_relations[self.recom_dom] = self.dom_user_r + self.recom_dom.out_relations[self.recom_user] = self.dom_user_r + if self.recom_pre: + self.recom_pre.out_relations.pop(self.recom_area) + self.recom_area.ops.clear() + self.recom_area.recompute_ops.clear() + self.orig_op_map.update(self.recom_area.ori_op_map) + self.recom_area.ori_op_map.clear() + + + def to_subgraph(self, dom): + """Transform area to subgraphs""" + ids 
= self.index_op() + dom_ops = list() + dom_ops.extend(dom.ops) + dom_ops.sort(key=lambda op: ids[op]) + subgraph = [] + subgraph = Graph('{}_area'.format(self.graph.name), dom_ops) + return subgraph + + def find_cheap_regions(self, dom): + """extract all the cheap regions in dom area, toposort each region before return""" + def _grow_region(region_ops, op, weight, inputs): + """include op to region_ops if region grow""" + # region successfully ends at input + if op.inputs[0] in inputs and len(op.inputs) == 1 and \ + PrimLib.iter_type(op) <= PrimLib.BROADCAST: + region_ops.append(op) + return False, None, weight + #region fails to grow + MAX_WEIGHT = 20 + if weight > MAX_WEIGHT or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST: + return False, None, weight + #region grows successfully + weight = weight + 1 + region_ops.append(op) + return True, op.inputs[0].op, weight + + def _find_cheap_regions(dom): + sub = self.to_subgraph(dom) + inputs, outputs = sub.deduce_parameters() + cheap_regions = [] + for output in outputs: + # tensor should have user other than user_area to be fused + if output.para_type != Tensor.PARA_OUTPUT and len(output.to_ops) < 2: + continue + region_ops = [] + grow = True + candidate_op = output.op + weight = 1 + while grow: + grow, candidate_op, weight = _grow_region(region_ops, candidate_op, weight, inputs) + # region ends at input and not empty + if region_ops and region_ops[-1].inputs[0] in inputs: + region_ops.reverse() + # tensor size should equal or becomes larger(cast up, broadcast) + if region_ops[0].inputs[0].get_size() > region_ops[-1].output.get_size(): + continue + cheap_regions.append(region_ops) + return cheap_regions + + return _find_cheap_regions(dom) + + def select_user_area(self, tail_tensor): + """select the user area has only one edge to dom area""" + def _get_edge_num(dom_area, user_area): + """get edge num between two areas""" + dom_graph = self.to_subgraph(dom_area) + _, dom_outputs = dom_graph.deduce_parameters() + user_graph = self.to_subgraph(user_area) + user_inputs, _ = user_graph.deduce_parameters() + edge = [t for t in dom_outputs if t in user_inputs] + return len(edge) + + def _select_user_area(tail_tensor): + user_areas = [] + for user_op in tail_tensor.to_ops: + user_area = self.area_map[user_op] + if len(user_area.ops) == 1 and user_area.pattern == PrimLib.RESHAPE: + continue + edge_num = _get_edge_num(self.area_map[tail_tensor.op], user_area) + if edge_num == 1 and not user_area in user_areas: + user_areas.append(user_area) + return user_areas + + return _select_user_area(tail_tensor) + + def recompute_fuse(self): + """find recompute regions and copy them out to new Areas""" + def do_recompute_fuse(): + """split the unfusing pattern by add recompute area""" + recompute_suc = False + orig_areas = [] + orig_areas.extend(self.areas) + for dom in orig_areas: + if dom not in self.areas or not dom.out_relations: + continue + cheap_regions = self.find_cheap_regions(dom) + dom_changed = False + for cheap_region in cheap_regions: + user_areas = self.select_user_area(cheap_region[-1].output) + if not user_areas: + continue + for user_area in user_areas: + self.set_recompute(dom, cheap_region, user_area) + self.pattern_fuse(self.fuse_recom) + self.clear_recompute() + if self.recom_res: + recompute_suc = True + #Copy region at most once for this dom + dom_changed = True + break + if dom_changed: + break + return recompute_suc + + if self.enable_recompute: + while do_recompute_fuse(): + self.pattern_fuse() + use_poly_reduce = True 
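The two-phase driver above is the heart of this patch: pattern_fuse() runs ordinary fusion to a fixpoint, then recompute_fuse() repeatedly extracts a "cheap region" (a chain of single-input ops no costlier than BROADCAST that starts at a subgraph input), splices a copy of it in front of one qualifying user area via set_recompute(), and retries fusion for just that pair. Below is a minimal, self-contained sketch of the region-growing rule; Op, grow_cheap_region and the iter-type constants are illustrative stand-ins for model.Operator and PrimLib, not part of the patch.

ELEMWISE, BROADCAST, REDUCE = 1, 2, 3   # toy stand-ins for PrimLib iter types
MAX_WEIGHT = 20                         # same bound as _grow_region above

class Op:
    """Toy operator: inputs are other Ops, or strings naming graph parameters."""
    def __init__(self, name, iter_type, inputs):
        self.name, self.iter_type, self.inputs = name, iter_type, inputs

def grow_cheap_region(tail_op):
    """Walk producers from tail_op while each op has exactly one input and is
    no more expensive than BROADCAST; succeed only if the walk reaches a graph
    parameter within MAX_WEIGHT ops, mirroring _grow_region's exit conditions."""
    region, op, weight = [], tail_op, 0
    while isinstance(op, Op):
        if len(op.inputs) != 1 or op.iter_type > BROADCAST or weight >= MAX_WEIGHT:
            return None                 # region fails to grow
        region.append(op)
        weight += 1
        op = op.inputs[0]               # step to the producer
    region.reverse()                    # topo order: closest to the input first
    return region

# Net1's pattern: one Cast feeding two ReduceSum users -> a one-op cheap region
cast = Op("cast", ELEMWISE, ["x"])
assert [o.name for o in grow_cheap_region(cast)] == ["cast"]
assert grow_cheap_region(Op("sum", REDUCE, [cast])) is None  # REDUCE too costly

The real _grow_region additionally rejects a finished region whose input tensor is larger than its output, so only size-preserving or size-expanding chains (casts up, broadcasts) get recomputed.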
@@ -331,8 +565,8 @@ class GraphSplitGpu(GraphSplitByPattern): pattern = PrimLib.iter_type(op) return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE - def do_split(self): - """Split graph by pattern""" + def pattern_fuse(self, fuse_func=None): + """fuse Areas by pattern""" def _reshape(dom): if dom.pattern != PrimLib.RESHAPE: return None @@ -551,21 +785,38 @@ class GraphSplitGpu(GraphSplitByPattern): fused.append(a) return fused, True - enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False) - changed = True - while changed: - changed = self.fuse(_reshape) - changed = self.fuse(_elemwise_depth) or changed - changed = self.fuse(_elemwise_width) or changed - changed = self.fuse(_reduce_depth) or changed - changed = self.fuse(_reduce_width) or changed - changed = self.fuse(_broadcast_depth) or changed - changed = self.fuse(_broadcast_width) or changed + def _fuse_loop(): + changed = True + while changed: + changed = self.fuse(_reshape) + changed = self.fuse(_elemwise_depth) or changed + changed = self.fuse(_elemwise_width) or changed + changed = self.fuse(_reduce_depth) or changed + changed = self.fuse(_reduce_width) or changed + changed = self.fuse(_broadcast_depth) or changed + changed = self.fuse(_broadcast_width) or changed + if use_poly_reduce: + changed = self.fuse(_reduce_output) or changed + if enable_stitch_fusion: + changed = self.fuse(_reduce_stitch) or changed + self.fuse(_transpose) + + def _fuse_once(fuse_func): + if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \ + fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \ + fuse_func(_broadcast_width): + return if use_poly_reduce: - changed = self.fuse(_reduce_output) or changed - if enable_stitch_fusion: - changed = self.fuse(_reduce_stitch) or changed - self.fuse(_transpose) + if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)): + return + fuse_func(_transpose) + return + + enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False) + if fuse_func is None: + _fuse_loop() + else: + _fuse_once(fuse_func) class GraphSplitAscend(GraphSplitByPattern): @@ -580,8 +831,8 @@ class GraphSplitAscend(GraphSplitByPattern): return self.Area.MODE_COMPOSITE return self.Area.MODE_BASIC - def do_split(self): - """Split graph by pattern""" + def pattern_fuse(self, fuse_func=None): + """fuse Areas by pattern""" def _tensor_size(tensor): size = 1 for i in tensor.shape: @@ -685,6 +936,19 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, False + def _reduce_output(dom): + if dom.pattern != PrimLib.REDUCE: + return None + op_attrs = dom.dom_op().attrs + if not op_attrs.get('reduce_output_fuse'): + return None + fused = [] + for a, r in dom.out_relations.items(): + if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ + dom.check_acyclic(a): + fused.append(a) + return fused, False + def _transdata_pattern_support(dom, a): transdata_op = dom.dom_op() @@ -733,32 +997,31 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, True - def _reduce_output(dom): - if dom.pattern != PrimLib.REDUCE: - return None - op_attrs = dom.dom_op().attrs - if not op_attrs.get('reduce_output_fuse'): - return None - fused = [] - for a, r in dom.out_relations.items(): - if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ - dom.check_acyclic(a): - fused.append(a) - return fused, False + def _fuse_loop(): + changed = True + while changed: + 
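                # sweep every selector until a full pass changes nothing:
                # each successful fuse can unlock the next selector, so a
                # single pass is not enough to reach the fixpoint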
changed = self.fuse(_reshape) + changed = self.fuse(_elemwise_depth) or changed + changed = self.fuse(_elemwise_width) or changed + changed = self.fuse(_reduce_depth) or changed + changed = self.fuse(_reduce_width) or changed + changed = self.fuse(_broadcast_depth) or changed + changed = self.fuse(_broadcast_width) or changed + changed = self.fuse(_matmul_depth) or changed + changed = self.fuse(_reduce_output) or changed + self.fuse(_transdata) - changed = True - while changed: - changed = self.fuse(_reshape) - changed = self.fuse(_elemwise_depth) or changed - changed = self.fuse(_elemwise_width) or changed - changed = self.fuse(_reduce_depth) or changed - changed = self.fuse(_reduce_width) or changed - changed = self.fuse(_broadcast_depth) or changed - changed = self.fuse(_broadcast_width) or changed - changed = self.fuse(_matmul_depth) or changed - changed = self.fuse(_reduce_output) or changed - self.fuse(_transdata) + def _fuse_once(fuse_func): + if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \ + fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \ + fuse_func(_broadcast_width) or fuse_func(_matmul_depth) or fuse_func(_reduce_output) or \ + fuse_func(_transdata): + pass + if fuse_func is None: + _fuse_loop() + else: + _fuse_once(fuse_func) def split(graph, target, flags): diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py index 4790cef17fe..9eb573a8cdd 100644 --- a/mindspore/_extends/graph_kernel/model/model.py +++ b/mindspore/_extends/graph_kernel/model/model.py @@ -320,8 +320,8 @@ class Operator: def __str__(self): args = ', '.join([str(t) for t in self.all_inputs]) - expr = "%s = %s.%s(%s)" % ( - str(self.output), self.prim, self.output.dtype, args) + expr = "%s = %s.%s(%s) id:%s" % ( + str(self.output), self.prim, self.output.dtype, args, id(self)) return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs)) def __repr__(self): @@ -331,12 +331,13 @@ class Operator: class Graph: """Graph""" - def __init__(self, name, ops, stitch_info=None): + def __init__(self, name, ops, stitch_info=None, recompute_ops=None): self.name = name self.ops = ops # in topo order, can not use set self.inputs = [] self.outputs = [] self.stitch_info = stitch_info + self.recompute_ops = recompute_ops def set_processor(self, processor): """Set processor""" diff --git a/mindspore/_extends/graph_kernel/model/model_builder.py b/mindspore/_extends/graph_kernel/model/model_builder.py index 0f0914cf893..68c6b0f7cf5 100644 --- a/mindspore/_extends/graph_kernel/model/model_builder.py +++ b/mindspore/_extends/graph_kernel/model/model_builder.py @@ -203,11 +203,13 @@ class CompositeGraph: desc['buffer_stitch'] = buffer_stitch return desc - def dump(self, subgraph): - """Dump Graph to json""" - desc = {} - inputs, outputs = subgraph.deduce_parameters() - graph_ops = set(subgraph.ops) + def add_recompute_ops(self, subgraph, desc): + if subgraph.recompute_ops: + desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops] + return desc + + def _pre_dump(self, outputs): + """restore name to before load""" inplace_assign = {} # y_name, output_name inplace_assign_z = None for op in self.desc['op_desc']: @@ -217,6 +219,14 @@ class CompositeGraph: for t in outputs: if t.name not in inplace_assign: inplace_assign_z = t + return inplace_assign, inplace_assign_z + + def dump(self, subgraph): + """Dump Graph to json""" + desc = {} + inputs, outputs = 
subgraph.deduce_parameters() + graph_ops = set(subgraph.ops) + inplace_assign, inplace_assign_z = self._pre_dump(outputs) for key in self.desc: if key == 'input_desc': desc[key] = [ @@ -251,7 +261,7 @@ class CompositeGraph: op_desc.append(inplace_desc) else: op = self.tensors[d['output_desc'][0]['tensor_name']].op - if op in graph_ops: + if op in graph_ops or op in subgraph.recompute_ops: op_desc.append(d) desc[key] = op_desc elif key == 'op': @@ -260,6 +270,7 @@ class CompositeGraph: desc[key] = self.desc[key] desc = self.add_stitch_info(subgraph, desc) + desc = self.add_recompute_ops(subgraph, desc) return desc diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc index 64973db201a..a58e4eba20e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc @@ -433,68 +433,5 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_js auto kernel_json = nlohmann::json::parse(kernel_json_str); return DecodeFusedNodes(kernel_json); } - -StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) const { - StitchInfo info; - if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) { - nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch]; - if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) { - std::vector stitch_ops = buffer_stitch[kJsonKeyStitchOp]; - info.stitch_ops = stitch_ops; - } - if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) { - std::vector stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp]; - info.stitch_atomic_ops = stitch_atomic_ops; - } - } - return info; -} - -void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, - const CNodePtr &node) const { - std::vector output_descs = op_desc[kJsonKeyOutputDesc]; - if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return; - std::string tensor_name = output_descs[0][kJsonKeyTensorName]; - if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) { - AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node); - MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope(); - } - if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) != - info.stitch_atomic_ops.end()) { - AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node); - MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope(); - } -} - -bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json, - const std::map &address_node_map, - AnfNodePtrList *res_graphs) { - MS_EXCEPTION_IF_NULL(res_graphs); - MS_LOG(DEBUG) << "start decode, " << kernel_json; - // decode cnodes in graph. - std::vector op_node_descs = kernel_json[kJsonKeyOpDesc]; - if (op_node_descs.empty()) { - MS_LOG(ERROR) << "Error decode, no cnodes for graph." 
<< kernel_json; - return false; - } - StitchInfo info = GetStitchInfo(kernel_json); - for (const auto &op_desc : op_node_descs) { - if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) { - MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc; - return false; - } - - std::string ptr_address = op_desc[kJsonKeyPtrAddress]; - if (address_node_map.count(ptr_address) == 0) { - MS_LOG(ERROR) << "Decode failed, ptr_address not found in map."; - return false; - } - auto node = address_node_map.at(ptr_address)->cast(); - SetStitchAttr(op_desc, info, node); - res_graphs->push_back(node); - } - MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size(); - return true; -} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h index 585e946eb6e..98578c67e64 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.h @@ -26,10 +26,6 @@ namespace mindspore { namespace kernel { -struct StitchInfo { - std::vector stitch_ops; - std::vector stitch_atomic_ops; -}; class AkgKernelJsonDecoder { public: AkgKernelJsonDecoder() { nodes_map_.clear(); } @@ -37,15 +33,11 @@ class AkgKernelJsonDecoder { FuncGraphPtr DecodeFusedNodes(const nlohmann::json &kernel_json); FuncGraphPtr DecodeFusedNodes(const std::string &kernel_json_str); - bool DecodeSplitNodes(const nlohmann::json &kernel_json, const std::map &address_node_map, - AnfNodePtrList *res_graphs); private: ParameterPtr DecodeParameter(const nlohmann::json ¶meter_json, const FuncGraphPtr &func_graph); CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor); AnfNodePtr DecodeOutput(const std::vector &output_descs, const FuncGraphPtr &func_graph); - StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) const; - void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) const; std::map nodes_map_; }; } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h index ce8463aa90c..193ac54dcb0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_generator.h @@ -54,6 +54,7 @@ constexpr auto kJsonKeyFusionType = "fusion_type"; constexpr auto kJsonKeySubGraph = "sub_graph"; constexpr auto kJsonKeyCoreNum = "core_num"; constexpr auto kJsonKeyTypeInfo = "type_info"; +constexpr auto kJsonKeyRecomputeOps = "recompute_ops"; constexpr auto kJsonKeyBufferStitch = "buffer_stitch"; constexpr auto kJsonKeyStitchOp = "stitch_op"; constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op"; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc index 050e65b6784..cbe161a4935 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc @@ -117,7 +117,7 @@ PassManagerPtr GraphKernelOptimizer::Split() const { // which can avoid unnecessary input-output and get better performance. 
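  // note: with recompute fusion in the cost model, only Reshape still needs
  // duplicating here; Cast/ExpandDims duplication is presumably subsumed by
  // the new cheap-region recompute (an inference from this patch, not stated)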
// preprocess for ShapeOpsSplitter pm->AddPass(std::make_shared(), OptLevel_1); - std::vector duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast}; + std::vector duplicated_ops = {prim::kPrimReshape}; pm->AddPass(std::make_shared(duplicated_ops), OptLevel_1); // Split kernel according to costmodel diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc index 8b2c8031223..455cdc3b6ef 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc @@ -32,6 +32,150 @@ #include "utils/context/graph_kernel_flags.h" namespace mindspore { +namespace kernel { +namespace { +StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) { + StitchInfo info; + if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) { + nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch]; + if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) { + std::vector stitch_ops = buffer_stitch[kJsonKeyStitchOp]; + info.stitch_ops = stitch_ops; + } + if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) { + std::vector stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp]; + info.stitch_atomic_ops = stitch_atomic_ops; + } + } + return info; +} + +std::set GetRecomputeOps(const nlohmann::json &kernel_json) { + if (kernel_json.find(kJsonKeyRecomputeOps) != kernel_json.end()) { + std::vector recompute_ops = kernel_json[kJsonKeyRecomputeOps]; + return std::set(recompute_ops.begin(), recompute_ops.end()); + } + return std::set(); +} + +bool IsRecomputeOp(const nlohmann::json &op_desc, const std::set &recompute_ops) { + std::vector output_descs = op_desc[kJsonKeyOutputDesc]; + if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) { + return false; + } + std::string tensor_name = output_descs[0][kJsonKeyTensorName]; + if (recompute_ops.count(tensor_name)) { + return true; + } + return false; +} + +CNodePtr NewRecomputeNode(const AnfNodePtr &orig_node, std::map *node_map) { + auto func_graph = orig_node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + auto cnode = orig_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + TraceGuard guard(std::make_shared(cnode->debug_info())); + auto orig_inputs = cnode->inputs(); + std::vector inputs; + for (auto inp : orig_inputs) { + if (node_map->find(inp) == node_map->end()) { + inputs.push_back(inp); + continue; + } + inputs.push_back((*node_map)[inp]); + } + CNodePtr cp_node = func_graph->NewCNode(inputs); + func_graph->AddNode(cp_node); + cp_node->set_abstract(cnode->abstract()); + cp_node->set_forward(cnode->forward().first, cnode->forward().second); + cp_node->set_inputs_value(cnode->inputs_value()); + ScopePtr scope = (orig_node->scope() != kDefaultScope) ? 
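                    // (copies keep the original node's scope; the fallback
                    //  branch only fires when that scope is already default)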
orig_node->scope() : kDefaultScope; + cp_node->set_scope(scope); + cp_node->set_kernel_info(cnode->kernel_info_ptr()); + (*node_map)[orig_node] = cp_node; + return cp_node->cast(); +} + +void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) { + std::vector output_descs = op_desc[kJsonKeyOutputDesc]; + if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return; + std::string tensor_name = output_descs[0][kJsonKeyTensorName]; + if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) { + AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node); + MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope(); + } + if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) != + info.stitch_atomic_ops.end()) { + AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node); + MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope(); + } +} + +// replace original region root op by its copy in this res_graphs +void ConnectRecomputeOps(AnfNodePtrList *res_graphs, const AnfNodePtr &orig_region_root, + const AnfNodePtr &cp_region_root) { + for (auto &node : *res_graphs) { + auto cnode = node->cast(); + auto inputs = cnode->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + if (inputs[i] != orig_region_root) continue; + cnode->set_input(i, cp_region_root); + } + } +} +} // namespace + +bool SplitNodesDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json, + const std::map &address_node_map, + AnfNodePtrList *res_graphs) { + MS_EXCEPTION_IF_NULL(res_graphs); + MS_LOG(DEBUG) << "start decode, " << kernel_json; + // decode cnodes in graph. + std::vector op_node_descs = kernel_json[kJsonKeyOpDesc]; + if (op_node_descs.empty()) { + MS_LOG(ERROR) << "Error decode, no cnodes for graph." 
<< kernel_json; + return false; + } + StitchInfo info = GetStitchInfo(kernel_json); + auto recompute_ops = GetRecomputeOps(kernel_json); + // key_value: original_copied + std::map node_map; + // nodes would be copied + AnfNodePtrList orig_region_nodes; + // nodes would not be copied + AnfNodePtrList no_cp_nodes; + for (const auto &op_desc : op_node_descs) { + if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) { + MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc; + return false; + } + + std::string ptr_address = op_desc[kJsonKeyPtrAddress]; + if (address_node_map.count(ptr_address) == 0) { + MS_LOG(ERROR) << "Decode failed, ptr_address not found in map."; + return false; + } + auto node = address_node_map.at(ptr_address)->cast(); + if (IsRecomputeOp(op_desc, recompute_ops)) { + auto cp_node = NewRecomputeNode(node, &node_map); + orig_region_nodes.push_back(node); + SetStitchAttr(op_desc, info, cp_node); + res_graphs->push_back(cp_node); + continue; + } + SetStitchAttr(op_desc, info, node); + res_graphs->push_back(node); + no_cp_nodes.push_back(node); + } + for (auto orig_node : orig_region_nodes) { + ConnectRecomputeOps(&no_cp_nodes, orig_node, node_map[orig_node]); + } + MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size(); + return true; +} +} // namespace kernel + namespace opt { namespace { void TraverseFuncGraphFromCNode(const CNodePtr &cnode, const std::function &callback) { @@ -620,7 +764,7 @@ class CostModelSplitSchemer : public SplitSchemer { split_plan_.clear(); for (const auto &graph_desc : graph_descs) { AnfNodePtrList res_graph; - if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) { + if (!kernel::SplitNodesDecoder::DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) { MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc; return false; } @@ -731,6 +875,7 @@ class CostModelSplitSchemer : public SplitSchemer { nlohmann::json flag_json; flag_json["dump_as_text"] = flags.dump_as_text; flag_json["enable_stitch_fusion"] = flags.enable_stitch_fusion; + flag_json["enable_recompute_fusion"] = flags.enable_recompute_fusion; return flag_json.dump(); } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.h index cab14e3376e..ebdfddaeee4 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.h @@ -15,11 +15,31 @@ */ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_ +#include #include +#include +#include +#include +#include #include "ir/func_graph.h" #include "backend/optimizer/common/pass.h" namespace mindspore { +namespace kernel { +struct StitchInfo { + std::vector stitch_ops; + std::vector stitch_atomic_ops; +}; + +class SplitNodesDecoder { + public: + SplitNodesDecoder() {} + ~SplitNodesDecoder() = default; + static bool DecodeSplitNodes(const nlohmann::json &kernel_json, + const std::map &address_node_map, AnfNodePtrList *res_graphs); +}; +} // namespace kernel + namespace opt { class GraphKernelSplitter : public Pass { public: diff --git a/mindspore/ccsrc/utils/context/graph_kernel_flags.cc b/mindspore/ccsrc/utils/context/graph_kernel_flags.cc index fd925f92df2..e6efa0ca9e6 100644 --- 
a/mindspore/ccsrc/utils/context/graph_kernel_flags.cc +++ b/mindspore/ccsrc/utils/context/graph_kernel_flags.cc @@ -181,6 +181,7 @@ void GraphKernelFlags::RegisterFlags(std::map *flag_ma // Boolean flags reg.AddFlag("dump_as_text", &dump_as_text); reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion, opt_level == OptLevel_3); + reg.AddFlag("enable_recompute_fusion", &enable_recompute_fusion, opt_level == OptLevel_2); reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion, opt_level == OptLevel_3); // Integer flags @@ -203,6 +204,7 @@ std::string GraphKernelFlags::DumpAllFlags() const { json["dump_as_text"] = dump_as_text; json["enable_stitch_fusion"] = enable_stitch_fusion; + json["enable_recompute_fusion"] = enable_recompute_fusion; json["enable_parallel_fusion"] = enable_parallel_fusion; json["opt_level"] = opt_level; diff --git a/mindspore/ccsrc/utils/context/graph_kernel_flags.h b/mindspore/ccsrc/utils/context/graph_kernel_flags.h index 25a0ab82e3b..b94a263549a 100644 --- a/mindspore/ccsrc/utils/context/graph_kernel_flags.h +++ b/mindspore/ccsrc/utils/context/graph_kernel_flags.h @@ -67,6 +67,11 @@ class GraphKernelFlags { */ bool enable_stitch_fusion; + /** + * Enable recompute fusion in graph kernel fusion strategy. + */ + bool enable_recompute_fusion{true}; + /** * Enable parallel fusion in graph kernel fusion strategy. * diff --git a/tests/st/ops/graph_kernel/test_recompute.py b/tests/st/ops/graph_kernel/test_recompute.py new file mode 100644 index 00000000000..762ee93213d --- /dev/null +++ b/tests/st/ops/graph_kernel/test_recompute.py @@ -0,0 +1,223 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +import pytest +import mindspore.context as context +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.nn import Cell +import mindspore.ops.operations as P + +#{cast} would be recompute and fused +class Net1(Cell): + def __init__(self): + super(Net1, self).__init__() + self.cast = P.Cast() + self.sum = P.ReduceSum(keep_dims=False) + + def construct(self, x): + cast_res = self.cast(x, mstype.float32) + sum1_res = self.sum(cast_res, (0,)) + sum2_res = self.sum(cast_res, (1,)) + return sum1_res, sum2_res + +#{sqrt} would be recompute on Ascend +class Net2(Cell): + def __init__(self): + super(Net2, self).__init__() + self.sqrt = P.Sqrt() + self.sum = P.ReduceSum(keep_dims=True) + self.add = P.Add() + self.neg = P.Neg() + + def construct(self, x0, x1): + sqrt_res = self.sqrt(x0) + neg_res = self.neg(sqrt_res) + add_res = self.add(x1, sqrt_res) + sum_res = self.sum(add_res, (0,)) + return neg_res, sum_res + +#{sqrt} would be recompute +class Net3(Cell): + def __init__(self): + super(Net3, self).__init__() + self.sqrt = P.Sqrt() + self.add = P.Add() + self.neg = P.Neg() + + def construct(self, x0, x1): + sqrt_res = self.sqrt(x0) + neg_res = self.neg(sqrt_res) + add_res = self.add(x1, sqrt_res) + return neg_res, add_res + +#{sqrt neg} would be recompute +class Net4(Cell): + def __init__(self): + super(Net4, self).__init__() + self.sqrt = P.Sqrt() + self.neg = P.Neg() + self.sum = P.ReduceSum(keep_dims=False) + + def construct(self, x): + sqrt_res = self.sqrt(x) + neg_res = self.neg(sqrt_res) + sum1_res = self.sum(neg_res, (0,)) + sum2_res = self.sum(neg_res, (1,)) + return sum1_res, sum2_res + +#{sqrt} would be recompute +class Net5(Cell): + def __init__(self): + super(Net5, self).__init__() + self.sqrt = P.Sqrt() + self.add = P.Add() + + def construct(self, x0, x1, x2): + sqrt_res = self.sqrt(x0) + add1_res = self.add(sqrt_res, x1) + add2_res = self.add(sqrt_res, x2) + return add1_res, add2_res + +def test_basic1(net): + def get_output(i0, net, enable_graph_kernel=False): + context.set_context(enable_graph_kernel=enable_graph_kernel) + net_obj = net() + output = net_obj(i0) + return output + + i0 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16)) + expect = get_output(i0, net, False) + output = get_output(i0, net, True) + expect0_np = expect[0].asnumpy().copy() + output0_np = output[0].asnumpy().copy() + expect1_np = expect[1].asnumpy().copy() + output1_np = output[1].asnumpy().copy() + assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3) + assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3) + + +def test_basic2(net): + def get_output(i0, i1, net, enable_graph_kernel=False): + context.set_context(enable_graph_kernel=enable_graph_kernel) + net_obj = net() + output = net_obj(i0, i1) + return output + + i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float32)) + i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float32)) + expect = get_output(i0, i1, net, False) + output = get_output(i0, i1, net, True) + expect0_np = expect[0].asnumpy().copy() + output0_np = output[0].asnumpy().copy() + expect1_np = expect[1].asnumpy().copy() + output1_np = output[1].asnumpy().copy() + assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3) + assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3) + +def test_basic3(net): + def get_output(i0, i1, i2, net, enable_graph_kernel=False): + 
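        # run the same net with graph kernel off/on so the two results can be
        # compared elementwise below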
context.set_context(enable_graph_kernel=enable_graph_kernel) + net_obj = net() + output = net_obj(i0, i1, i2) + return output + + i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float16)) + i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16)) + i2 = Tensor(np.random.uniform(1, 2, [2048, 1024]).astype(np.float16)) + expect = get_output(i0, i1, i2, net, False) + output = get_output(i0, i1, i2, net, True) + expect0_np = expect[0].asnumpy().copy() + output0_np = output[0].asnumpy().copy() + expect1_np = expect[1].asnumpy().copy() + output1_np = output[1].asnumpy().copy() + assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3) + assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gpu_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_basic1(Net1) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gpu_2(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_basic2(Net2) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gpu_3(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_basic2(Net3) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gpu_4(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_basic1(Net4) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gpu_5(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_basic3(Net5) + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ascend_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + test_basic1(Net1) + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ascend_2(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + test_basic2(Net2) + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ascend_3(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + test_basic2(Net3) + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ascend_4(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + test_basic1(Net4) + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ascend_5(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + test_basic3(Net5)
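
As a usage note for reviewers: the new flag rides on the existing GraphKernelFlags plumbing, and per RegisterFlags above it defaults to on at opt_level 2. Below is a minimal sketch of opting in explicitly from a script; the "--name=value" flag-string syntax and the graph_kernel_flags context key are assumptions based on how the sibling flags (e.g. enable_stitch_fusion) are exercised, not something this patch adds.

import mindspore.context as context

# assumption: boolean graph-kernel flags are passed as "--name=value" through
# the graph_kernel_flags context string, like the existing stitch-fusion flag
context.set_context(mode=context.GRAPH_MODE, device_target="GPU",
                    enable_graph_kernel=True,
                    graph_kernel_flags="--enable_recompute_fusion=true")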