!16766 Add recompute fuse

Merge pull request !16766 from lingyunli63/recompute_fuse
This commit is contained in:
i-robot 2021-06-23 07:54:12 +00:00 committed by Gitee
commit 0c360ea2d6
12 changed files with 743 additions and 143 deletions

View File

@ -16,7 +16,7 @@
import os
from functools import reduce
from mindspore import log as logger
from .model import PrimLib, Graph, Tensor
from .model import PrimLib, Graph, Tensor, Operator
from .model import DataFormat as DF
@ -65,13 +65,16 @@ class GraphSplitByPattern:
self.stitch_ops = set()
self.stitch_atomic_ops = set()
def __init__(self, init_op, is_output, unique_id, reach_tab):
self.pattern = PrimLib.iter_type(init_op)
self.ops = [init_op]
def __init__(self, init_op, is_output, unique_id, reach_tab, recompute_ops=None):
self.pattern = PrimLib.iter_type(init_op) if init_op is not None else PrimLib.UNKNOWN
self.ops = [] if init_op is None else [init_op]
self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
self.out_relations = dict() # {area1: relation1, area2: relation2, ...}
self.mode = None
self.stitch_info = self.StitchInfo()
self.recompute_ops = [] if recompute_ops is None else recompute_ops
self.ori_op_map = {}
self.is_recompute = False
self.is_output = is_output
self.output_excluded = set()
if self.pattern == PrimLib.REDUCE:
@ -143,6 +146,8 @@ class GraphSplitByPattern:
r = rels.pop(area)
_update_relation(rels, self, r)
if area.is_recompute:
if self.pattern >= area.pattern:
@ -161,7 +166,9 @@ class GraphSplitByPattern:
if area.output_excluded:
self.reach_tab.fuse(self.unique_id, area.unique_id)
if not area.is_recompute:
self.reach_tab.fuse(self.unique_id, area.unique_id)
def check_acyclic(self, to):
"""Check circle. It returns false if circle exists"""
@ -180,25 +187,73 @@ class GraphSplitByPattern:
return True
return False
def cp_ops(self, area):
"""copy recompute_ops in area to ops, self is area's user"""
tail_tensor = area.recompute_ops[-1].output
#copy tensors, all copied are Tensor.PARA_NONE
tensor_map = {}
tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0]
for op in area.recompute_ops:
orig_tensor = op.output
cp_tensor = Tensor(orig_tensor.name, orig_tensor.shape, orig_tensor.dtype, orig_tensor.data_format)
tensor_map[orig_tensor] = cp_tensor
#copy ops
cp_ops = []
for op in area.recompute_ops:
cp_op = Operator(op.prim, [tensor_map[op.inputs[0]]], tensor_map[op.output], op.attrs)
cp_op.all_inputs = cp_op.inputs
area.ori_op_map[cp_op] = op
#connect copied ops
for op in self.ops:
if tail_tensor in op.inputs:
#fill cp_ops in self.recompute_area
cp_dom_op = None
for cp, ori in area.ori_op_map.items():
if ori == area.dom_op():
cp_dom_op = cp
area.ops.extend([op for op in cp_ops if op != cp_dom_op])
def __init__(self, graph, flags):
self.graph = graph
self.areas = []
self.flags = flags
self.reach_tab = self.ReachTable(len(graph.ops))
area_map = {}
self.enable_recompute = self.flags.get("enable_recompute_fusion", False)
self.reach_tab = self.ReachTable(len(graph.ops) + 1 if self.enable_recompute else len(graph.ops))
self.area_map = {}
_, outputs = graph.deduce_parameters()
idx = 0
self.idx = 0
for op in graph.ops:
is_output = op.output in outputs
a = self.Area(op, is_output, idx, self.reach_tab)
idx += 1
a = self.Area(op, is_output, self.idx, self.reach_tab)
self.idx += 1
area_map[op] = a
self.set_area_map([op], a)
for a in self.areas:
for i in range(len(self.areas)-1, -1, -1):
if self.enable_recompute:
self.recom_area = self.Area(None, False, self.idx, self.reach_tab)
self.recom_area.is_recompute = True
self.recom_pre = None
self.recom_user = None
self.recom_dom = None
self.dom_user_r = PrimLib.UNKNOWN
self.recom_res = False
self.orig_op_map = {}
def set_area_map(self, ops, area):
    """Record `area` as the owning area of every op in `ops`.

    Called after ops have been fused into `area`; existing entries for the
    same ops in self.area_map are overwritten.
    """
    self.area_map.update(dict.fromkeys(ops, area))
def set_default_mode(self, area):
    """Set area.mode to the default mode reported for the area's first op."""
    default_mode = self.get_default_mode(area.ops[0])
    area.mode = default_mode
@ -234,11 +289,13 @@ class GraphSplitByPattern:
if is_forward:
for area in fuse_areas:
self.set_area_map(area.ops, dominant)
forward_area = dominant
for area in fuse_areas:
self.set_area_map(forward_area.ops, area)
forward_area = area
changed = True
@ -246,16 +303,39 @@ class GraphSplitByPattern:
return changed
def to_subgraphs(self):
"""Transform op groups to subgraphs"""
def fuse_recom(self, selector):
"""Fuse recompute area to its user"""
for dominant in [self.recom_area, self.recom_user]:
result = selector(dominant)
if result is not None and result[0]:
fuse_areas, _ = result
fuse_areas = self.limit_area_size(dominant, fuse_areas)
if not fuse_areas:
if fuse_areas[0] in [self.recom_area, self.recom_user]:
self.recom_res = True
return True
return False
def index_op(self):
    """Map every op to its position in self.graph.ops, for topo-sorting.

    When self.orig_op_map exists, each copied op (key) is given the same
    index as the original op (value) it was copied from.
    """
    order = {op: pos for pos, op in enumerate(self.graph.ops)}
    if not hasattr(self, 'orig_op_map'):
        return order
    for copied, original in self.orig_op_map.items():
        order[copied] = order[original]
    return order
def to_subgraphs(self):
"""Transform op groups to subgraphs"""
ids = self.index_op()
subgraphs = []
graphmodes = []
for i, area in enumerate(self.areas):
area.ops.sort(key=lambda op: ids[op])
subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info))
subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info, area.recompute_ops))
graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
return subgraphs, graphmodes
@ -274,13 +354,14 @@ class GraphSplitByPattern:
with os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), 'w+') as f:
def do_split(self):
    """Split graph by pattern (stub: always raises, naming the concrete class)."""
    cls_name = self.__class__.__name__
    raise Exception("do_split() is not implemented in {}".format(cls_name))
def pattern_fuse(self, select=None):
    """Fuse Areas by pattern repeatedly (stub: always raises, naming the concrete class)."""
    cls_name = self.__class__.__name__
    raise Exception("pattern_fuse() is not implemented in {}".format(cls_name))
def split(self):
"""Split graph by pattern"""
# The reshape should not be output node
# Note: after this function, the input output relation is not maintained.
@ -316,6 +397,159 @@ class GraphSplitByPattern:
if new_areas:
self.areas += new_areas
def set_recompute(self, dom_area, ops, user_area):
"""set the recompute area and connect with other areas"""
#recom_area: set dom_op and correct ops length
patterns = [PrimLib.iter_type(op) for op in ops]
self.recom_area.pattern = max(patterns)
for i, pat in enumerate(patterns):
if pat == self.recom_area.pattern:
self.recom_area.ops = [ops[i]] * len(ops)
#disconnect dom_area and user_area
self.dom_user_r = dom_area.out_relations[user_area]
#connect recom_area and user_area
user_area.in_relations[self.recom_area] = self.dom_user_r
self.recom_area.out_relations[user_area] = self.dom_user_r
#connect recom_pre and recom_area
self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs[0].op else None
if self.recom_pre is not None:
self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre]
self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre]
#set related areas
self.recom_user = user_area
self.recom_dom = dom_area
self.recom_res = False
def clear_recompute(self):
"""disconnect recom_area from other areas, and clear recom_area"""
if not self.recom_res:
self.recom_user.in_relations[self.recom_dom] = self.dom_user_r
self.recom_dom.out_relations[self.recom_user] = self.dom_user_r
if self.recom_pre:
def to_subgraph(self, dom):
"""Transform area to subgraphs"""
ids = self.index_op()
dom_ops = list()
dom_ops.sort(key=lambda op: ids[op])
subgraph = []
subgraph = Graph('{}_area'.format(self.graph.name), dom_ops)
return subgraph
def find_cheap_regions(self, dom):
"""extract all the cheap regions in dom area, toposort each region before return"""
def _grow_region(region_ops, op, weight, inputs):
"""include op to region_ops if region grow"""
# region successfully ends at input
if op.inputs[0] in inputs and len(op.inputs) == 1 and \
PrimLib.iter_type(op) <= PrimLib.BROADCAST:
return False, None, weight
#region fails to grow
if weight > MAX_WEIGHT or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST:
return False, None, weight
#region grows successfully
weight = weight + 1
return True, op.inputs[0].op, weight
def _find_cheap_regions(dom):
sub = self.to_subgraph(dom)
inputs, outputs = sub.deduce_parameters()
cheap_regions = []
for output in outputs:
# tensor should have user other than user_area to be fused
if output.para_type != Tensor.PARA_OUTPUT and len(output.to_ops) < 2:
region_ops = []
grow = True
candidate_op = output.op
weight = 1
while grow:
grow, candidate_op, weight = _grow_region(region_ops, candidate_op, weight, inputs)
# region ends at input and not empty
if region_ops and region_ops[-1].inputs[0] in inputs:
# tensor size should equal or becomes larger(cast up, broadcast)
if region_ops[0].inputs[0].get_size() > region_ops[-1].output.get_size():
return cheap_regions
return _find_cheap_regions(dom)
def select_user_area(self, tail_tensor):
"""select the user area has only one edge to dom area"""
def _get_edge_num(dom_area, user_area):
"""get edge num between two areas"""
dom_graph = self.to_subgraph(dom_area)
_, dom_outputs = dom_graph.deduce_parameters()
user_graph = self.to_subgraph(user_area)
user_inputs, _ = user_graph.deduce_parameters()
edge = [t for t in dom_outputs if t in user_inputs]
return len(edge)
def _select_user_area(tail_tensor):
user_areas = []
for user_op in tail_tensor.to_ops:
user_area = self.area_map[user_op]
if len(user_area.ops) == 1 and user_area.pattern == PrimLib.RESHAPE:
edge_num = _get_edge_num(self.area_map[tail_tensor.op], user_area)
if edge_num == 1 and not user_area in user_areas:
return user_areas
return _select_user_area(tail_tensor)
def recompute_fuse(self):
"""find recompute regions and copy them out to new Areas"""
def do_recompute_fuse():
"""split the unfusing pattern by add recompute area"""
recompute_suc = False
orig_areas = []
for dom in orig_areas:
if dom not in self.areas or not dom.out_relations:
cheap_regions = self.find_cheap_regions(dom)
dom_changed = False
for cheap_region in cheap_regions:
user_areas = self.select_user_area(cheap_region[-1].output)
if not user_areas:
for user_area in user_areas:
self.set_recompute(dom, cheap_region, user_area)
if self.recom_res:
recompute_suc = True
#Copy region at most once for this dom
dom_changed = True
if dom_changed:
return recompute_suc
if self.enable_recompute:
while do_recompute_fuse():
use_poly_reduce = True
@ -331,8 +565,8 @@ class GraphSplitGpu(GraphSplitByPattern):
pattern = PrimLib.iter_type(op)
return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE
def do_split(self):
"""Split graph by pattern"""
def pattern_fuse(self, fuse_func=None):
"""fuse Areas by pattern"""
def _reshape(dom):
if dom.pattern != PrimLib.RESHAPE:
return None
@ -551,21 +785,38 @@ class GraphSplitGpu(GraphSplitByPattern):
return fused, True
enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)
changed = True
while changed:
changed = self.fuse(_reshape)
changed = self.fuse(_elemwise_depth) or changed
changed = self.fuse(_elemwise_width) or changed
changed = self.fuse(_reduce_depth) or changed
changed = self.fuse(_reduce_width) or changed
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
def _fuse_loop():
changed = True
while changed:
changed = self.fuse(_reshape)
changed = self.fuse(_elemwise_depth) or changed
changed = self.fuse(_elemwise_width) or changed
changed = self.fuse(_reduce_depth) or changed
changed = self.fuse(_reduce_width) or changed
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
if use_poly_reduce:
changed = self.fuse(_reduce_output) or changed
if enable_stitch_fusion:
changed = self.fuse(_reduce_stitch) or changed
def _fuse_once(fuse_func):
if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
if use_poly_reduce:
changed = self.fuse(_reduce_output) or changed
if enable_stitch_fusion:
changed = self.fuse(_reduce_stitch) or changed
if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)
if fuse_func is None:
class GraphSplitAscend(GraphSplitByPattern):
@ -580,8 +831,8 @@ class GraphSplitAscend(GraphSplitByPattern):
return self.Area.MODE_COMPOSITE
return self.Area.MODE_BASIC
def do_split(self):
"""Split graph by pattern"""
def pattern_fuse(self, fuse_func=None):
"""fuse Areas by pattern"""
def _tensor_size(tensor):
size = 1
for i in tensor.shape:
@ -685,6 +936,19 @@ class GraphSplitAscend(GraphSplitByPattern):
return fused, False
def _reduce_output(dom):
if dom.pattern != PrimLib.REDUCE:
return None
op_attrs = dom.dom_op().attrs
if not op_attrs.get('reduce_output_fuse'):
return None
fused = []
for a, r in dom.out_relations.items():
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
return fused, False
def _transdata_pattern_support(dom, a):
transdata_op = dom.dom_op()
@ -733,32 +997,31 @@ class GraphSplitAscend(GraphSplitByPattern):
return fused, True
def _reduce_output(dom):
if dom.pattern != PrimLib.REDUCE:
return None
op_attrs = dom.dom_op().attrs
if not op_attrs.get('reduce_output_fuse'):
return None
fused = []
for a, r in dom.out_relations.items():
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
return fused, False
def _fuse_loop():
changed = True
while changed:
changed = self.fuse(_reshape)
changed = self.fuse(_elemwise_depth) or changed
changed = self.fuse(_elemwise_width) or changed
changed = self.fuse(_reduce_depth) or changed
changed = self.fuse(_reduce_width) or changed
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
changed = self.fuse(_matmul_depth) or changed
changed = self.fuse(_reduce_output) or changed
changed = True
while changed:
changed = self.fuse(_reshape)
changed = self.fuse(_elemwise_depth) or changed
changed = self.fuse(_elemwise_width) or changed
changed = self.fuse(_reduce_depth) or changed
changed = self.fuse(_reduce_width) or changed
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
changed = self.fuse(_matmul_depth) or changed
changed = self.fuse(_reduce_output) or changed
def _fuse_once(fuse_func):
if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
fuse_func(_broadcast_width) or fuse_func(_matmul_depth) or fuse_func(_reduce_output) or \
if fuse_func is None:
def split(graph, target, flags):

View File

@ -320,8 +320,8 @@ class Operator:
def __str__(self):
args = ', '.join([str(t) for t in self.all_inputs])
expr = "%s = %s.%s(%s)" % (
str(self.output), self.prim, self.output.dtype, args)
expr = "%s = %s.%s(%s) id:%s" % (
str(self.output), self.prim, self.output.dtype, args, id(self))
return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))
def __repr__(self):
@ -331,12 +331,13 @@ class Operator:
class Graph:
def __init__(self, name, ops, stitch_info=None):
def __init__(self, name, ops, stitch_info=None, recompute_ops=None):
self.name = name
self.ops = ops # in topo order, can not use set
self.inputs = []
self.outputs = []
self.stitch_info = stitch_info
self.recompute_ops = recompute_ops
def set_processor(self, processor):
"""Set processor"""

View File

@ -203,11 +203,13 @@ class CompositeGraph:
desc['buffer_stitch'] = buffer_stitch
return desc
def dump(self, subgraph):
"""Dump Graph to json"""
desc = {}
inputs, outputs = subgraph.deduce_parameters()
graph_ops = set(subgraph.ops)
def add_recompute_ops(self, subgraph, desc):
    """Add the output tensor names of the subgraph's recompute ops to desc.

    desc['recompute_ops'] is only written when subgraph.recompute_ops is
    non-empty; desc is returned in either case.
    """
    recompute_ops = subgraph.recompute_ops
    if not recompute_ops:
        return desc
    desc['recompute_ops'] = [item.output.name for item in recompute_ops]
    return desc
def _pre_dump(self, outputs):
"""restore name to before load"""
inplace_assign = {} # y_name, output_name
inplace_assign_z = None
for op in self.desc['op_desc']:
@ -217,6 +219,14 @@ class CompositeGraph:
for t in outputs:
if t.name not in inplace_assign:
inplace_assign_z = t
return inplace_assign, inplace_assign_z
def dump(self, subgraph):
"""Dump Graph to json"""
desc = {}
inputs, outputs = subgraph.deduce_parameters()
graph_ops = set(subgraph.ops)
inplace_assign, inplace_assign_z = self._pre_dump(outputs)
for key in self.desc:
if key == 'input_desc':
desc[key] = [
@ -251,7 +261,7 @@ class CompositeGraph:
op = self.tensors[d['output_desc'][0]['tensor_name']].op
if op in graph_ops:
if op in graph_ops or op in subgraph.recompute_ops:
desc[key] = op_desc
elif key == 'op':
@ -260,6 +270,7 @@ class CompositeGraph:
desc[key] = self.desc[key]
desc = self.add_stitch_info(subgraph, desc)
desc = self.add_recompute_ops(subgraph, desc)
return desc

View File

@ -433,68 +433,5 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_js
auto kernel_json = nlohmann::json::parse(kernel_json_str);
return DecodeFusedNodes(kernel_json);
StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) const {
StitchInfo info;
if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
info.stitch_ops = stitch_ops;
if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
info.stitch_atomic_ops = stitch_atomic_ops;
return info;
void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info,
const CNodePtr &node) const {
std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
std::string tensor_name = output_descs[0][kJsonKeyTensorName];
if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope();
if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
info.stitch_atomic_ops.end()) {
AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope();
bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
const std::map<std::string, AnfNodePtr> &address_node_map,
AnfNodePtrList *res_graphs) {
MS_LOG(DEBUG) << "start decode, " << kernel_json;
// decode cnodes in graph.
std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
if (op_node_descs.empty()) {
MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json;
return false;
StitchInfo info = GetStitchInfo(kernel_json);
for (const auto &op_desc : op_node_descs) {
if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
return false;
std::string ptr_address = op_desc[kJsonKeyPtrAddress];
if (address_node_map.count(ptr_address) == 0) {
MS_LOG(ERROR) << "Decode failed, ptr_address not found in map.";
return false;
auto node = address_node_map.at(ptr_address)->cast<CNodePtr>();
SetStitchAttr(op_desc, info, node);
MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size();
return true;
} // namespace kernel
} // namespace mindspore

View File

@ -26,10 +26,6 @@
namespace mindspore {
namespace kernel {
struct StitchInfo {
std::vector<std::string> stitch_ops;
std::vector<std::string> stitch_atomic_ops;
class AkgKernelJsonDecoder {
AkgKernelJsonDecoder() { nodes_map_.clear(); }
@ -37,15 +33,11 @@ class AkgKernelJsonDecoder {
FuncGraphPtr DecodeFusedNodes(const nlohmann::json &kernel_json);
FuncGraphPtr DecodeFusedNodes(const std::string &kernel_json_str);
bool DecodeSplitNodes(const nlohmann::json &kernel_json, const std::map<std::string, AnfNodePtr> &address_node_map,
AnfNodePtrList *res_graphs);
ParameterPtr DecodeParameter(const nlohmann::json &parameter_json, const FuncGraphPtr &func_graph);
CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor);
AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph);
StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) const;
void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) const;
std::map<std::string, AnfNodePtr> nodes_map_;
} // namespace kernel

View File

@ -54,6 +54,7 @@ constexpr auto kJsonKeyFusionType = "fusion_type";
constexpr auto kJsonKeySubGraph = "sub_graph";
constexpr auto kJsonKeyCoreNum = "core_num";
constexpr auto kJsonKeyTypeInfo = "type_info";
constexpr auto kJsonKeyRecomputeOps = "recompute_ops";
constexpr auto kJsonKeyBufferStitch = "buffer_stitch";
constexpr auto kJsonKeyStitchOp = "stitch_op";
constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op";

View File

@ -117,7 +117,7 @@ PassManagerPtr GraphKernelOptimizer::Split() const {
// which can avoid unnecessary input-output and get better performance.
// preprocess for ShapeOpsSplitter
pm->AddPass(std::make_shared<ExtendOutputForUpdateState>(), OptLevel_1);
std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape};
pm->AddPass(std::make_shared<ShapeOpsSplitter>(duplicated_ops), OptLevel_1);
// Split kernel according to costmodel

View File

@ -32,6 +32,150 @@
#include "utils/context/graph_kernel_flags.h"
namespace mindspore {
namespace kernel {
namespace {
StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) {
StitchInfo info;
if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
info.stitch_ops = stitch_ops;
if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
info.stitch_atomic_ops = stitch_atomic_ops;
return info;
std::set<std::string> GetRecomputeOps(const nlohmann::json &kernel_json) {
if (kernel_json.find(kJsonKeyRecomputeOps) != kernel_json.end()) {
std::vector<std::string> recompute_ops = kernel_json[kJsonKeyRecomputeOps];
return std::set<std::string>(recompute_ops.begin(), recompute_ops.end());
return std::set<std::string>();
bool IsRecomputeOp(const nlohmann::json &op_desc, const std::set<std::string> &recompute_ops) {
std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) {
return false;
std::string tensor_name = output_descs[0][kJsonKeyTensorName];
if (recompute_ops.count(tensor_name)) {
return true;
return false;
CNodePtr NewRecomputeNode(const AnfNodePtr &orig_node, std::map<AnfNodePtr, AnfNodePtr> *node_map) {
auto func_graph = orig_node->func_graph();
auto cnode = orig_node->cast<CNodePtr>();
TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info()));
auto orig_inputs = cnode->inputs();
std::vector<AnfNodePtr> inputs;
for (auto inp : orig_inputs) {
if (node_map->find(inp) == node_map->end()) {
CNodePtr cp_node = func_graph->NewCNode(inputs);
cp_node->set_forward(cnode->forward().first, cnode->forward().second);
ScopePtr scope = (orig_node->scope() != kDefaultScope) ? orig_node->scope() : kDefaultScope;
(*node_map)[orig_node] = cp_node;
return cp_node->cast<CNodePtr>();
void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) {
std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
std::string tensor_name = output_descs[0][kJsonKeyTensorName];
if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope();
if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
info.stitch_atomic_ops.end()) {
AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope();
// replace original region root op by its copy in this res_graphs
void ConnectRecomputeOps(AnfNodePtrList *res_graphs, const AnfNodePtr &orig_region_root,
const AnfNodePtr &cp_region_root) {
for (auto &node : *res_graphs) {
auto cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
if (inputs[i] != orig_region_root) continue;
cnode->set_input(i, cp_region_root);
} // namespace
bool SplitNodesDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
const std::map<std::string, AnfNodePtr> &address_node_map,
AnfNodePtrList *res_graphs) {
MS_LOG(DEBUG) << "start decode, " << kernel_json;
// decode cnodes in graph.
std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
if (op_node_descs.empty()) {
MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json;
return false;
StitchInfo info = GetStitchInfo(kernel_json);
auto recompute_ops = GetRecomputeOps(kernel_json);
// key_value: original_copied
std::map<AnfNodePtr, AnfNodePtr> node_map;
// nodes would be copied
AnfNodePtrList orig_region_nodes;
// nodes would not be copied
AnfNodePtrList no_cp_nodes;
for (const auto &op_desc : op_node_descs) {
if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
return false;
std::string ptr_address = op_desc[kJsonKeyPtrAddress];
if (address_node_map.count(ptr_address) == 0) {
MS_LOG(ERROR) << "Decode failed, ptr_address not found in map.";
return false;
auto node = address_node_map.at(ptr_address)->cast<CNodePtr>();
if (IsRecomputeOp(op_desc, recompute_ops)) {
auto cp_node = NewRecomputeNode(node, &node_map);
SetStitchAttr(op_desc, info, cp_node);
SetStitchAttr(op_desc, info, node);
for (auto orig_node : orig_region_nodes) {
ConnectRecomputeOps(&no_cp_nodes, orig_node, node_map[orig_node]);
MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size();
return true;
} // namespace kernel
namespace opt {
namespace {
void TraverseFuncGraphFromCNode(const CNodePtr &cnode, const std::function<void(AnfNodePtr &)> &callback) {
@ -620,7 +764,7 @@ class CostModelSplitSchemer : public SplitSchemer {
for (const auto &graph_desc : graph_descs) {
AnfNodePtrList res_graph;
if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
if (!kernel::SplitNodesDecoder::DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
return false;
@ -731,6 +875,7 @@ class CostModelSplitSchemer : public SplitSchemer {
nlohmann::json flag_json;
flag_json["dump_as_text"] = flags.dump_as_text;
flag_json["enable_stitch_fusion"] = flags.enable_stitch_fusion;
flag_json["enable_recompute_fusion"] = flags.enable_recompute_fusion;
return flag_json.dump();

View File

@ -15,11 +15,31 @@
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>
#include "ir/func_graph.h"
#include "backend/optimizer/common/pass.h"
namespace mindspore {
namespace kernel {
struct StitchInfo {
std::vector<std::string> stitch_ops;
std::vector<std::string> stitch_atomic_ops;
class SplitNodesDecoder {
SplitNodesDecoder() {}
~SplitNodesDecoder() = default;
static bool DecodeSplitNodes(const nlohmann::json &kernel_json,
const std::map<std::string, AnfNodePtr> &address_node_map, AnfNodePtrList *res_graphs);
} // namespace kernel
namespace opt {
class GraphKernelSplitter : public Pass {

View File

@ -181,6 +181,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
// Boolean flags
reg.AddFlag("dump_as_text", &dump_as_text);
reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion, opt_level == OptLevel_3);
reg.AddFlag("enable_recompute_fusion", &enable_recompute_fusion, opt_level == OptLevel_2);
reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion, opt_level == OptLevel_3);
// Integer flags
@ -203,6 +204,7 @@ std::string GraphKernelFlags::DumpAllFlags() const {
json["dump_as_text"] = dump_as_text;
json["enable_stitch_fusion"] = enable_stitch_fusion;
json["enable_recompute_fusion"] = enable_recompute_fusion;
json["enable_parallel_fusion"] = enable_parallel_fusion;
json["opt_level"] = opt_level;

View File

@ -67,6 +67,11 @@ class GraphKernelFlags {
bool enable_stitch_fusion;
* Enable recompute fusion in graph kernel fusion strategy.
bool enable_recompute_fusion{true};
* Enable parallel fusion in graph kernel fusion strategy.

View File

@ -0,0 +1,223 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn import Cell
import mindspore.ops.operations as P
# {cast} would be recomputed and fused
class Net1(Cell):
    """Cast the input to float32, then reduce-sum the cast result over axis 0 and over axis 1."""

    def __init__(self):
        super().__init__()
        self.cast = P.Cast()
        self.sum = P.ReduceSum(keep_dims=False)

    def construct(self, x):
        as_fp32 = self.cast(x, mstype.float32)
        reduced_axis0 = self.sum(as_fp32, (0,))
        reduced_axis1 = self.sum(as_fp32, (1,))
        return reduced_axis0, reduced_axis1
# {sqrt} would be recomputed on Ascend
class Net2(Cell):
    """Feed sqrt(x0) to both a Neg and an Add(x1, .) followed by a keep-dims ReduceSum over axis 0."""

    def __init__(self):
        super().__init__()
        self.sqrt = P.Sqrt()
        self.sum = P.ReduceSum(keep_dims=True)
        self.add = P.Add()
        self.neg = P.Neg()

    def construct(self, x0, x1):
        root = self.sqrt(x0)
        negated = self.neg(root)
        # The sqrt result has a second user here, besides the Neg above.
        summed = self.sum(self.add(x1, root), (0,))
        return negated, summed
# {sqrt} would be recomputed
class Net3(Cell):
    """Feed sqrt(x0) to both a Neg and an Add(x1, .), returning both results."""

    def __init__(self):
        super().__init__()
        self.sqrt = P.Sqrt()
        self.add = P.Add()
        self.neg = P.Neg()

    def construct(self, x0, x1):
        root = self.sqrt(x0)
        negated = self.neg(root)
        shifted = self.add(x1, root)
        return negated, shifted
# {sqrt, neg} would be recomputed
class Net4(Cell):
    """Compute neg(sqrt(x)) and reduce-sum that result over axis 0 and over axis 1."""

    def __init__(self):
        super().__init__()
        self.sqrt = P.Sqrt()
        self.neg = P.Neg()
        self.sum = P.ReduceSum(keep_dims=False)

    def construct(self, x):
        negated_root = self.neg(self.sqrt(x))
        reduced_axis0 = self.sum(negated_root, (0,))
        reduced_axis1 = self.sum(negated_root, (1,))
        return reduced_axis0, reduced_axis1
# {sqrt} would be recomputed
class Net5(Cell):
    """Add sqrt(x0) to x1 and, separately, to x2, returning both sums."""

    def __init__(self):
        super().__init__()
        self.sqrt = P.Sqrt()
        self.add = P.Add()

    def construct(self, x0, x1, x2):
        root = self.sqrt(x0)
        first_sum = self.add(root, x1)
        second_sum = self.add(root, x2)
        return first_sum, second_sum
def test_basic1(net):
    """Run a one-input, two-output net with graph kernel off and on, and compare the results.

    Args:
        net: Cell class (not instance) taking a single tensor and returning two tensors.
    """
    def get_output(i0, net, enable_graph_kernel=False):
        # Apply the flag before building the net; otherwise both runs use the
        # same configuration and the comparison below proves nothing.
        context.set_context(enable_graph_kernel=enable_graph_kernel)
        net_obj = net()
        output = net_obj(i0)
        return output

    i0 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16))
    expect = get_output(i0, net, False)
    output = get_output(i0, net, True)
    expect0_np = expect[0].asnumpy().copy()
    output0_np = output[0].asnumpy().copy()
    expect1_np = expect[1].asnumpy().copy()
    output1_np = output[1].asnumpy().copy()
    assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3)
    assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3)
def test_basic2(net):
    """Run a two-input, two-output net with graph kernel off and on, and compare the results.

    Args:
        net: Cell class (not instance) taking two tensors and returning two tensors.
    """
    def get_output(i0, i1, net, enable_graph_kernel=False):
        # Apply the flag before building the net; otherwise both runs use the
        # same configuration and the comparison below proves nothing.
        context.set_context(enable_graph_kernel=enable_graph_kernel)
        net_obj = net()
        output = net_obj(i0, i1)
        return output

    i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float32))
    i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float32))
    expect = get_output(i0, i1, net, False)
    output = get_output(i0, i1, net, True)
    expect0_np = expect[0].asnumpy().copy()
    output0_np = output[0].asnumpy().copy()
    expect1_np = expect[1].asnumpy().copy()
    output1_np = output[1].asnumpy().copy()
    assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3)
    assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3)
def test_basic3(net):
    """Run a three-input, two-output net with graph kernel off and on, and compare the results.

    Args:
        net: Cell class (not instance) taking three tensors and returning two tensors.
    """
    def get_output(i0, i1, i2, net, enable_graph_kernel=False):
        # Apply the flag before building the net; otherwise both runs use the
        # same configuration and the comparison below proves nothing.
        context.set_context(enable_graph_kernel=enable_graph_kernel)
        net_obj = net()
        output = net_obj(i0, i1, i2)
        return output

    i0 = Tensor(np.random.uniform(1, 2, [1, 1024]).astype(np.float16))
    i1 = Tensor(np.random.uniform(1, 2, [1024, 1024]).astype(np.float16))
    i2 = Tensor(np.random.uniform(1, 2, [2048, 1024]).astype(np.float16))
    expect = get_output(i0, i1, i2, net, False)
    output = get_output(i0, i1, i2, net, True)
    expect0_np = expect[0].asnumpy().copy()
    output0_np = output[0].asnumpy().copy()
    expect1_np = expect[1].asnumpy().copy()
    output1_np = output[1].asnumpy().copy()
    assert np.allclose(expect0_np, output0_np, 1.e-3, 1.e-3)
    assert np.allclose(expect1_np, output1_np, 1.e-3, 1.e-3)
def test_gpu_1():
    """Check Net1 on GPU in graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # Net1 takes one input, so it is driven by the one-input helper;
    # without this call the test would pass without running anything.
    test_basic1(Net1)
def test_gpu_2():
    """Check Net2 on GPU in graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # Net2 takes two inputs, so it is driven by the two-input helper;
    # without this call the test would pass without running anything.
    test_basic2(Net2)
def test_gpu_3():
    """Check Net3 on GPU in graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # Net3 takes two inputs, so it is driven by the two-input helper;
    # without this call the test would pass without running anything.
    test_basic2(Net3)
def test_gpu_4():
    """Check Net4 on GPU in graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # Net4 takes one input, so it is driven by the one-input helper;
    # without this call the test would pass without running anything.
    test_basic1(Net4)
def test_gpu_5():
    """Check Net5 on GPU in graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # Net5 takes three inputs, so it is driven by the three-input helper;
    # without this call the test would pass without running anything.
    test_basic3(Net5)
def test_ascend_1():
    """Check Net1 on Ascend in graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # Net1 takes one input, so it is driven by the one-input helper;
    # without this call the test would pass without running anything.
    test_basic1(Net1)
def test_ascend_2():
    """Check Net2 on Ascend in graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # Net2 takes two inputs, so it is driven by the two-input helper;
    # without this call the test would pass without running anything.
    test_basic2(Net2)
def test_ascend_3():
    """Check Net3 on Ascend in graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # Net3 takes two inputs, so it is driven by the two-input helper;
    # without this call the test would pass without running anything.
    test_basic2(Net3)
def test_ascend_4():
    """Check Net4 on Ascend in graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # Net4 takes one input, so it is driven by the one-input helper;
    # without this call the test would pass without running anything.
    test_basic1(Net4)
def test_ascend_5():
    """Check Net5 on Ascend in graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # Net5 takes three inputs, so it is driven by the three-input helper;
    # without this call the test would pass without running anything.
    test_basic3(Net5)