!21603 Fix static-check warnings

Merge pull request !21603 from DeshiChen/0809_clean
i-robot 2021-08-16 06:38:45 +00:00 committed by Gitee
commit ff07de80b4
14 changed files with 272 additions and 335 deletions

View File

@ -24,6 +24,7 @@ from .expand_dims import ExpandDims
@VLD.check_attrs('is_training', 'momentum', 'epsilon')
class BatchNorm(Expander):
"""BatchNorm expander"""
def _expand(self, graph_builder):
# get op info
input_x = self.inputs[0]
@ -42,81 +43,8 @@ class BatchNorm(Expander):
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})
if self.attrs['is_training']:
reduce_axis = ()
shape_x = input_x.shape
if input_x.data_format == DF.NHWC:
reduce_axis = (0, 1, 2)
num = shape_x[0] * shape_x[1] * shape_x[2]
else:
reduce_axis = (0, 2, 3)
num = shape_x[0] * shape_x[2] * shape_x[3]
num_rec = 1.0 / num
num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
# compute mean value of input_x
mean_sum = graph_builder.emit(
'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v])
# compute variance of input_x
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
mean_muls_expand = graph_builder.emit(
'Reshape', [mean_muls], attrs={'shape': ExpandDims.infer_shape(mean_muls.shape, [-1, -1])})
else:
mean_muls_expand = mean_muls
var_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
var_mul = graph_builder.emit('Mul', [var_sub, var_sub])
var_sum = graph_builder.emit('ReduceSum', [var_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
var_mul = graph_builder.emit('Mul', [var_sum, num_rec_v])
# y_sqrt_rec means 1 / sqrt(variance + epsilon), which is calculated in backward pass
scalar_one = 1.0
scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
y_add = graph_builder.emit('Add', [var_mul, epsilon_v])
y_sqrt = graph_builder.emit('Sqrt', [y_add])
y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])
# compute res_y
tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
y_sqrt_rec_expand = graph_builder.emit(
'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])})
else:
y_sqrt_rec_expand = y_sqrt_rec
y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
input_scale_expand = graph_builder.emit(
'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
else:
input_scale_expand = input_scale
res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
input_offset_expand = graph_builder.emit(
'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
else:
input_offset_expand = input_offset
res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])
# compute mean_res
momentum_sub = scalar_one - self.attrs['momentum']
momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
mean_res = graph_builder.emit(
'InplaceAssign', [input_mean, updated_moving_mean, updated_moving_mean], attrs={'fake_output': True})
# variance_res is calculated by sample variance, and need to multiply by num / (num - 1)
var_num = float(num) / (num - 1)
var_num_v = graph_builder.value(input_scale.dtype, var_num)
var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
variance_res = graph_builder.emit(
'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
attrs={'fake_output': True})
self.inputs[0] = input_x
res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder)
if input_x_new_type != input_x_ori_type:
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
@ -140,3 +68,88 @@ class BatchNorm(Expander):
if input_x_new_type != input_x_ori_type:
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
return res_y, var_add, var_add, var_add, var_add
def _bn_train(self, graph_builder):
"""expand BatchNorm for training mode"""
input_x = self.inputs[0]
input_scale = self.inputs[1]
input_offset = self.inputs[2]
input_mean = self.inputs[3]
input_variance = self.inputs[4]
epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
reduce_axis = ()
shape_x = input_x.shape
if input_x.data_format == DF.NHWC:
reduce_axis = (0, 1, 2)
num = shape_x[0] * shape_x[1] * shape_x[2]
else:
reduce_axis = (0, 2, 3)
num = shape_x[0] * shape_x[2] * shape_x[3]
num_rec = 1.0 / num
num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
# compute mean value of input_x
mean_sum = graph_builder.emit(
'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v])
# compute variance of input_x
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
mean_muls_expand = graph_builder.emit(
'Reshape', [mean_muls], attrs={'shape': ExpandDims.infer_shape(mean_muls.shape, [-1, -1])})
else:
mean_muls_expand = mean_muls
var_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
var_mul = graph_builder.emit('Mul', [var_sub, var_sub])
var_sum = graph_builder.emit('ReduceSum', [var_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
var_mul = graph_builder.emit('Mul', [var_sum, num_rec_v])
# y_sqrt_rec means 1 / sqrt(variance + epsilon), which is needed by the backward pass
scalar_one = 1.0
scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
y_add = graph_builder.emit('Add', [var_mul, epsilon_v])
y_sqrt = graph_builder.emit('Sqrt', [y_add])
y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])
# compute res_y
tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
y_sqrt_rec_expand = graph_builder.emit(
'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])})
else:
y_sqrt_rec_expand = y_sqrt_rec
y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
input_scale_expand = graph_builder.emit(
'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
else:
input_scale_expand = input_scale
res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
input_offset_expand = graph_builder.emit(
'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
else:
input_offset_expand = input_offset
res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])
# compute mean_res
momentum_sub = scalar_one - self.attrs['momentum']
momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
mean_res = graph_builder.emit(
'InplaceAssign', [input_mean, updated_moving_mean, updated_moving_mean], attrs={'fake_output': True})
# variance_res uses the unbiased sample variance, so the biased variance is multiplied by num / (num - 1)
var_num = float(num) / (num - 1)
var_num_v = graph_builder.value(input_scale.dtype, var_num)
var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
variance_res = graph_builder.emit(
'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
attrs={'fake_output': True})
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
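
For reference, the graph emitted by `_bn_train` matches standard training-mode batch normalization over the N/H/W axes plus the moving-statistics update; a minimal NumPy sketch of the same arithmetic (NCHW layout assumed, names here are illustrative, not from the diff):

import numpy as np

def bn_train_reference(x, scale, offset, moving_mean, moving_var, momentum, eps):
    # x: (N, C, H, W); per-channel statistics over axes (0, 2, 3)
    num = x.shape[0] * x.shape[2] * x.shape[3]
    mean = x.sum(axis=(0, 2, 3)) / num                  # mean_muls
    xc = x - mean.reshape(1, -1, 1, 1)
    var = (xc * xc).sum(axis=(0, 2, 3)) / num           # biased (sample) variance
    inv_std = 1.0 / np.sqrt(var + eps)                  # y_sqrt_rec, needed by the backward pass
    y = (scale.reshape(1, -1, 1, 1) * xc * inv_std.reshape(1, -1, 1, 1)
         + offset.reshape(1, -1, 1, 1))
    # moving statistics: the running variance uses the unbiased estimate, hence num / (num - 1)
    new_mean = (1.0 - momentum) * moving_mean + momentum * mean
    new_var = (1.0 - momentum) * moving_var + momentum * var * num / (num - 1)
    return y, new_mean, new_var, mean, inv_std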

View File

@ -48,24 +48,22 @@ class MatMul(Expander):
if input_num < 2:
raise GKException("matul inputs number should bigger than 1, but got {}.".format(input_num))
@staticmethod
def _trans_shape(shape):
trans_shape = list(shape)
trans_shape[-2] = shape[-1]
trans_shape[-1] = shape[-2]
return trans_shape
def _expand(self, graph_builder):
def transpose(shape):
trans_shape = list(shape)
trans_shape[-2] = shape[-1]
trans_shape[-1] = shape[-2]
return trans_shape
if not self._optimize_to_mul():
raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul")
# Matmul is replaced by Mul([b m k], [b k n]) when k==1
input_a = self.inputs[0]
input_b = self.inputs[1]
if self.transpose_a:
shape_a_trans = self._trans_shape(self.shape_a)
shape_a_trans = transpose(self.shape_a)
input_a = graph_builder.emit('Reshape', [input_a], attrs={'shape': shape_a_trans})
if self.transpose_b:
shape_b_trans = self._trans_shape(self.shape_b)
shape_b_trans = transpose(self.shape_b)
input_b = graph_builder.emit('Reshape', [input_b], attrs={'shape': shape_b_trans})
result = graph_builder.emit('Mul', [input_a, input_b])
if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']:
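
The rewrite is sound because a (Batch)MatMul whose contracted dimension k equals 1 degenerates to a broadcast multiply; a quick NumPy check (shapes chosen for illustration):

import numpy as np

a = np.random.rand(4, 3, 1)   # [b, m, k] with k == 1
b = np.random.rand(4, 1, 5)   # [b, k, n]
assert np.allclose(np.matmul(a, b), a * b)   # broadcast Mul yields the same [b, m, n] result

This is also why transpose_a/transpose_b can be handled with a Reshape here: swapping in a dimension of size 1 only changes the shape, not the element order.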

View File

@ -24,7 +24,7 @@ class MinimumGrad(Expander):
def _check(self):
if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
raise GKException("both grad_x and grad_y are False.")
return super()._check()
return super(MinimumGrad, self)._check()
def _expand(self, graph_builder):
input_x, input_y, input_dout = self.inputs
@ -34,7 +34,8 @@ class MinimumGrad(Expander):
dx = graph_builder.emit('Mul', [le_result, input_dout])
dy = graph_builder.emit('Sub', [input_dout, dx])
# for minimumgrad op, output_shape should be equal to input_shape, but some elementwise operating may broadcast input_shape
# for minimumgrad op, output_shape should be equal to input_shape,
# but some elementwise operations may broadcast input_shape,
# making output_shape differ from the original input_shape, so the outputs need to be reduced back to match
reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape)
reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape)
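
get_reduce_axis recovers the operand shapes after broadcasting; the helper itself is not shown in this diff, but the idea is to sum the gradient over every axis the broadcast expanded. A hypothetical sketch:

def get_reduce_axis_sketch(input_shape, grad_shape):
    # hypothetical helper: axes along which grad_shape was broadcast beyond input_shape
    offset = len(grad_shape) - len(input_shape)
    axis = list(range(offset))                      # leading dims added by broadcasting
    for i, dim in enumerate(input_shape):
        if dim == 1 and grad_shape[i + offset] > 1:
            axis.append(i + offset)
    return axis

# e.g. input_x.shape == (3, 1, 5) and dx.shape == (3, 4, 5) -> reduce dx over axis [1]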

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ===========================================================================
"""Cost model splitter"""
import os
from functools import reduce as prod_reduce
from mindspore import log as logger
from .model import PrimLib, Graph, Tensor, Operator
@ -39,20 +38,24 @@ class GraphSplitByPattern:
def sync(self, x, y):
"""sync from y to x"""
for i in self.alive:
if self.map[y][i] and not self.map[x][i]:
self.map[x][i] = True
self._link(self.map[y][i], x, i)
def _link(self, cond, f, t):
"""link from `f` to `t`"""
if cond:
self.map[f][t] = True
def fuse(self, x, y):
"""fuse y to x"""
for i in self.alive:
# i is a successor of y: link x's predecessors to i
if self.map[y][i] and not self.map[x][i]:
for pre in self.alive:
if self.map[pre][x] and not self.map[pre][i]:
self.map[pre][i] = True
self._link(self.map[pre][x], pre, i)
# i is a predecessor of y: link i to x's successors
if self.map[i][y] and not self.map[i][x]:
for suc in self.alive:
if self.map[x][suc] and not self.map[i][suc]:
self.map[i][suc] = True
self._link(self.map[x][suc], i, suc)
self.alive.remove(y)
class Area:
@ -67,6 +70,10 @@ class GraphSplitByPattern:
self.stitch_ops = set()
self.stitch_atomic_ops = set()
def has_stitch_op(self):
"""check stitch_op exists"""
return self.stitch_ops or self.stitch_atomic_ops
def __init__(self, init_op, is_output, unique_id, reach_tab, recompute_ops=None):
self.pattern = PrimLib.iter_type(init_op) if init_op is not None else PrimLib.UNKNOWN
self.ops = [] if init_op is None else [init_op]
@ -286,31 +293,35 @@ class GraphSplitByPattern:
def fuse(self, selector):
"""Fuse areas"""
changed = False
while True:
def _fuse_area():
for dominant in self.areas:
result = selector(dominant)
if result is not None and result[0]:
fuse_areas, is_forward = result
fuse_areas = self.limit_area_size(dominant, fuse_areas)
if not fuse_areas:
continue
if is_forward:
for area in fuse_areas:
dominant.fuse(area)
self.set_area_map(area.ops, dominant)
self.areas.remove(area)
else:
forward_area = dominant
for area in fuse_areas:
area.fuse(forward_area)
self.set_area_map(forward_area.ops, area)
self.areas.remove(forward_area)
forward_area = area
changed = True
break
else:
return changed
if result is None or not result[0]:
continue
fuse_areas, is_forward = result
fuse_areas = self.limit_area_size(dominant, fuse_areas)
if not fuse_areas:
continue
if is_forward:
for area in fuse_areas:
dominant.fuse(area)
self.set_area_map(area.ops, dominant)
self.areas.remove(area)
else:
forward_area = dominant
for area in fuse_areas:
area.fuse(forward_area)
self.set_area_map(forward_area.ops, area)
self.areas.remove(forward_area)
forward_area = area
return True
return False
changed, do_again = False, True
while do_again:
do_again = _fuse_area()
changed = changed or do_again
return changed
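
The restructured loop keeps the original fixed-point semantics: run one fusion pass at a time until a pass reports no change, while remembering whether any pass changed something at all. The driver pattern in isolation:

def run_to_fixed_point(single_pass):
    """Call single_pass() until it returns False; report whether anything changed at all."""
    changed, do_again = False, True
    while do_again:
        do_again = single_pass()
        changed = changed or do_again
    return changed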
def fuse_recom(self, selector):
"""Fuse recompute area to its user"""
@ -348,21 +359,6 @@ class GraphSplitByPattern:
graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
return subgraphs, graphmodes
def dump_subgraphs(self, subgraphs):
"""Dump subgraphs"""
if os.environ.get("ENABLE_SUBGRAPHS", "off") == "on":
subgraphs_str = "subgraphs:\nlen: " + str(len(subgraphs)) + "\n"
for i, sub in enumerate(subgraphs):
subgraphs_str += str("============") + str(i) + "\n"
subgraphs_str += str(sub)
dirname = 'subgraphs'
if not os.path.exists(dirname):
os.makedirs(dirname)
graphname = self.graph.name
filename = dirname + '/' + graphname + '.log'
with os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), 'w+') as f:
f.write(subgraphs_str)
def pattern_fuse(self, fuse_func=None):
"""fuse Areas by pattern repeatedly"""
del fuse_func
@ -376,34 +372,38 @@ class GraphSplitByPattern:
# Note: after this function, the input output relation is not maintained.
self.split_output_reshapes()
subgraphs, graphmodes = self.to_subgraphs()
self.dump_subgraphs(subgraphs)
return subgraphs, graphmodes
def split_output_reshapes(self):
"""Force split the output reshapes into other new """
"""Force split the output Reshapes into other new area"""
def _remove_output_reshape(reshape_ops, other_ops):
def _run():
for op in reshape_ops:
if any([to_op in other_ops for to_op in op.output.to_ops]):
reshape_ops.remove(op)
other_ops.append(op)
return True
return False
while _run():
pass
new_areas = []
for area in self.areas:
out_reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
remain_ops = [op for op in area.ops if op not in out_reshape_ops]
if not remain_ops or not out_reshape_ops:
reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
other_ops = [op for op in area.ops if op not in reshape_ops]
if not other_ops or not reshape_ops:
continue
changed = True
while changed:
changed = False
for op in out_reshape_ops:
if any([to_op in remain_ops for to_op in op.output.to_ops]):
out_reshape_ops.remove(op)
remain_ops.append(op)
changed = True
break
if out_reshape_ops:
for op in out_reshape_ops:
a = self.Area(op, False, 0, self.reach_tab)
self.set_default_mode(a)
new_areas.append(a)
area.ops = remain_ops
if len(remain_ops) == 1:
self.set_default_mode(area)
# remove the output reshape from "reshape_ops" and add it into "other_ops"
_remove_output_reshape(reshape_ops, other_ops)
if not reshape_ops:
continue
for op in reshape_ops:
a = self.Area(op, False, 0, self.reach_tab)
self.set_default_mode(a)
new_areas.append(a)
area.ops = other_ops
if len(other_ops) == 1:
self.set_default_mode(area)
if new_areas:
self.areas += new_areas
@ -533,14 +533,7 @@ class GraphSplitByPattern:
"""find recompute regions and copy them out to new Areas"""
def do_recompute_fuse():
"""split the unfusing pattern by add recompute area"""
recompute_suc = False
orig_areas = []
orig_areas.extend(self.areas)
for dom in orig_areas:
if dom not in self.areas or not dom.out_relations:
continue
cheap_regions = self.find_cheap_regions(dom)
dom_changed = False
def recompute_cheap_region(dom):
for cheap_region in cheap_regions:
user_areas = self.select_user_area(cheap_region[-1].output)
if not user_areas:
@ -550,12 +543,17 @@ class GraphSplitByPattern:
self.pattern_fuse(self.fuse_recom)
self.clear_recompute()
if self.recom_res:
recompute_suc = True
# Copy region at most once for this dom
dom_changed = True
break
if dom_changed:
break
return True
return False
recompute_suc = False
orig_areas = []
orig_areas.extend(self.areas)
for dom in orig_areas:
if dom not in self.areas or not dom.out_relations:
continue
cheap_regions = self.find_cheap_regions(dom)
if recompute_cheap_region(dom):
recompute_suc = True
return recompute_suc
if self.enable_recompute:
@ -563,9 +561,6 @@ class GraphSplitByPattern:
self.pattern_fuse()
use_poly_reduce = True
class GraphSplitGpu(GraphSplitByPattern):
"""Graph splitter"""
BORADCAST_FUSE_DEPTH = 20
@ -616,7 +611,7 @@ class GraphSplitGpu(GraphSplitByPattern):
return fused, True
def _broadcast_pat_exclude(dom, a, r):
if use_poly_reduce and a.pattern == PrimLib.REDUCE:
if a.pattern == PrimLib.REDUCE:
return dom.pattern > PrimLib.ELEMWISE or r > PrimLib.ELEMWISE
return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST
@ -641,34 +636,14 @@ class GraphSplitGpu(GraphSplitByPattern):
fused.append(a)
return fused, False
def _check_reduce_exclude(dom):
if use_poly_reduce:
return False
# exclude large all-reduce
if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \
dom.ops[0].inputs[0].get_size() > 10000:
return True
# exclude multi output
for a in dom.in_relations.keys():
if len(a.out_relations) > 1:
return True
if any([op.output.para_type == Tensor.PARA_OUTPUT for op in a.ops]):
return True
return False
def _reduce_pat_exclude(_, a, r):
if len(a.ops) > self.REDUCE_FUSE_DEPTH:
return True
if use_poly_reduce:
return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST
return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE
return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST
def _reduce_depth(dom):
if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
return None
if _check_reduce_exclude(dom):
return None
a, r = list(dom.in_relations.items())[0]
if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
_is_atomic_add_available(dom):
@ -681,8 +656,6 @@ class GraphSplitGpu(GraphSplitByPattern):
def _reduce_width(dom):
if dom.pattern != PrimLib.REDUCE:
return None
if _check_reduce_exclude(dom):
return None
fused = []
for a, r in dom.in_relations.items():
if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
@ -763,16 +736,16 @@ class GraphSplitGpu(GraphSplitByPattern):
def _may_stitch(dom, a, r):
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
if _reduce_nums(a.ops) < 2:
dom_outs = [op.output for op in dom.ops]
a_ins = [op_input for op in a.ops for op_input in op.inputs]
a_outs = [op.output for op in a.ops]
a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
if _same_stitch_axis(stitch_tensors, a_final_outs):
for tensor in stitch_tensors:
if _tensor_size(tensor) >= 1024 * 1024:
return True
if _reduce_nums(a.ops) >= 2:
return False
dom_outs = [op.output for op in dom.ops]
a_ins = [op_input for op in a.ops for op_input in op.inputs]
a_outs = [op.output for op in a.ops]
a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
if not _same_stitch_axis(stitch_tensors, a_final_outs):
return False
return any([_tensor_size(tensor) >= 1024 * 1024 for tensor in stitch_tensors])
return False
def _reduce_stitch(dom):
@ -785,14 +758,15 @@ class GraphSplitGpu(GraphSplitByPattern):
fused = []
for a, r in dom.out_relations.items():
if _may_stitch(dom, a, r):
if a.pattern == PrimLib.REDUCE:
if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
fused.append(a)
elif a.pattern == PrimLib.BROADCAST:
if not _may_stitch(dom, a, r):
continue
if a.pattern == PrimLib.REDUCE:
if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
fused.append(a)
elif a.pattern == PrimLib.BROADCAST:
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
fused.append(a)
return fused, False
def _transpose(dom):
@ -825,10 +799,9 @@ class GraphSplitGpu(GraphSplitByPattern):
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
changed = self.fuse(_strided_slice) or changed
if use_poly_reduce:
changed = self.fuse(_reduce_output) or changed
if enable_stitch_fusion:
changed = self.fuse(_reduce_stitch) or changed
changed = self.fuse(_reduce_output) or changed
if enable_stitch_fusion:
changed = self.fuse(_reduce_stitch) or changed
self.fuse(_transpose)
def _fuse_once(fuse_func):
@ -836,9 +809,8 @@ class GraphSplitGpu(GraphSplitByPattern):
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
fuse_func(_broadcast_width):
return
if use_poly_reduce:
if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
return
if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
return
fuse_func(_transpose)
return

View File

@ -422,14 +422,13 @@ class Graph:
for t in op.inputs:
if t not in inputs and t.op not in self.ops:
inputs.append(t)
if op.output not in outputs:
if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
outputs.append(op.output)
else:
for d in op.output.to_ops:
if d not in self.ops:
outputs.append(op.output)
break
if op.output in outputs:
continue
if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
outputs.append(op.output)
continue
if any([succ not in self.ops for succ in op.output.to_ops]):
outputs.append(op.output)
if self.inputs:
inputs = self.inputs

View File

@ -81,7 +81,6 @@ class GraphBuilder:
"""Create a new Value"""
if name in (None, ''):
name = self._alloc_tensor_name()
v = Value(name, dtype, value)
return v
@ -128,34 +127,14 @@ class CompositeGraph:
def load(self, desc):
"""Load Graph from json"""
def _attr_of(op, inputs, output):
def _get_axis_while_none(input_shape, output_shape):
red_axis = []
if len(output_shape) == len(input_shape):
for i, s in enumerate(output_shape):
if s == 1 and input_shape[i] > 1:
red_axis.append(i)
else:
red_axis = list(range(len(output_shape)))
return red_axis
def _attr_of(op):
if not op['attr']:
return dict()
attr = {}
if op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
for a in op['attr']:
if a['name'] == 'axis':
red_axis, dim_size = [], len(inputs[0].shape)
if not a['value']:
red_axis = _get_axis_while_none(inputs[0].shape, output.shape)
else:
if isinstance(a['value'], int):
a['value'] = [a['value']]
for i in a['value']:
red_axis.append(i if i >= 0 else dim_size + i)
attr['reduce_axis'] = red_axis
if a['name'] == "reduce_output_fuse":
attr['reduce_output_fuse'] = a['value']
elif op['attr']:
for a in op['attr']:
for a in op['attr']:
if a['name'] == 'axis' and op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
attr['reduce_axis'] = a['value']
else:
attr[a['name']] = a['value']
return attr
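
After the simplification, _attr_of only renames 'axis' to 'reduce_axis' for the reduce ops and copies every other attribute through unchanged; for example (hypothetical descriptor):

op_desc = {'name': 'ReduceSum',
           'attr': [{'name': 'axis', 'value': [0, 2, 3]},
                    {'name': 'keep_dims', 'value': False}]}
# _attr_of(op_desc) -> {'reduce_axis': [0, 2, 3], 'keep_dims': False}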
@ -171,7 +150,6 @@ class CompositeGraph:
'shape'], out_desc['data_type'], out_desc['format']
self.tensors[name] = builder.tensor(
shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
cur_fusion = None
for op in desc['op_desc']:
inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
out_desc = op['output_desc']
@ -182,21 +160,12 @@ class CompositeGraph:
inputs[1].para_type = Tensor.PARA_OUTPUT
output = inputs[2]
self.tensors[name] = output
else:
output = self.tensors.get(name, None)
if not output:
output = builder.tensor(
shape, dtype, data_format, name=name)
self.tensors[name] = output
builder.op(op['name'], output, inputs,
attrs=_attr_of(op, inputs, output))
if 'fusion' in op:
if cur_fusion is None:
cur_fusion = output
else:
cur_fusion.add_buddy(output)
if op['fusion'].endswith('_end'):
cur_fusion = None
continue
output = self.tensors.get(name, None)
if not output:
output = builder.tensor(shape, dtype, data_format, name=name)
self.tensors[name] = output
builder.op(op['name'], output, inputs, attrs=_attr_of(op))
self.graph = builder.get()[0]
self.desc = desc
@ -234,43 +203,40 @@ class CompositeGraph:
inputs, outputs = subgraph.deduce_parameters()
graph_ops = set(subgraph.ops)
inplace_assign, inplace_assign_z = self._pre_dump(outputs)
for key in self.desc:
def dump_output(t):
if t.name in inplace_assign:
z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign[t.name]}
return {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}
def dump_op_desc(d):
if d['name'] == 'InplaceAssign':
y = d['input_desc'][1][0]['tensor_name']
if self.tensors[y].op in graph_ops:
z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (self.tensors[y], True)
inplace_desc = copy.deepcopy(d)
inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
z_desc['shape'] = z.shape
z_desc['data_type'] = z.dtype
z_desc['tensor_name'] = z.name
out_desc['shape'] = z.shape
out_desc['data_type'] = z.dtype
return inplace_desc
op = self.tensors[d['output_desc'][0]['tensor_name']].op
if op in graph_ops or op in subgraph.recompute_ops:
return d
return None
for key in self.desc.keys():
if key == 'input_desc':
desc[key] = [
[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
desc[key] = [[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
elif key == 'output_desc':
out_desc = []
for t in outputs:
if t.name in inplace_assign:
z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
out_desc.append(
{'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign[t.name]})
else:
out_desc.append(
{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name})
desc[key] = out_desc
desc[key] = list(map(dump_output, outputs))
elif key == 'op_desc':
op_desc = []
for d in self.desc[key]:
if d['name'] == 'InplaceAssign':
y = d['input_desc'][1][0]['tensor_name']
if self.tensors[y].op in graph_ops:
z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (
self.tensors[y], True)
inplace_desc = copy.deepcopy(d)
inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
z_desc['shape'] = z.shape
z_desc['data_type'] = z.dtype
z_desc['tensor_name'] = z.name
out_desc['shape'] = z.shape
out_desc['data_type'] = z.dtype
op_desc.append(inplace_desc)
else:
op = self.tensors[d['output_desc'][0]['tensor_name']].op
if op in graph_ops or op in subgraph.recompute_ops:
op_desc.append(d)
desc[key] = op_desc
op_desc = map(dump_op_desc, self.desc[key])
desc[key] = [d for d in op_desc if d is not None]
elif key == 'op':
desc[key] = subgraph.name
else:

View File

@ -109,11 +109,12 @@ class _Elemwise(OpInfer):
out_shape = [1] * dim_size
for i in range(dim_size):
for align_shape in align_shapes:
if align_shape[i] > 1:
if out_shape[i] == 1:
out_shape[i] = align_shape[i]
if out_shape[i] != align_shape[i]:
raise GKException("shape broadcast failed!")
if align_shape[i] == 1:
continue
if out_shape[i] == 1:
out_shape[i] = align_shape[i]
elif out_shape[i] != align_shape[i]:
raise GKException("shape broadcast failed!")
return out_shape
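
The refactored loop is the usual broadcast rule applied to shapes that are already padded to a common rank; extracted as a standalone sketch with a small usage example:

def broadcast_shape_sketch(align_shapes):
    # align_shapes are assumed to already have the same rank
    dim_size = len(align_shapes[0])
    out_shape = [1] * dim_size
    for i in range(dim_size):
        for shape in align_shapes:
            if shape[i] == 1:
                continue
            if out_shape[i] == 1:
                out_shape[i] = shape[i]
            elif out_shape[i] != shape[i]:
                raise ValueError("shape broadcast failed!")
    return out_shape

# broadcast_shape_sketch([[1, 4, 1, 5], [1, 1, 3, 5]]) == [1, 4, 3, 5]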
@staticmethod

View File

@ -57,11 +57,11 @@ def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode):
return
utils.create_dir(utils.GRAPH_KERNEL_DUMP_PATH)
filename = os.path.join(utils.GRAPH_KERNEL_DUMP_PATH, "graph_kernel_split_mode.txt")
with open(filename, "a+") as f:
with os.fdopen(os.open(filename, os.O_WRONLY | os.O_CREAT), "a+") as f:
f.write("********** main graph: {} **********\n".format(graph_desc.name))
f.write("input json:\n{}\n".format(graph_json))
f.write("graph desc:\n{}\n".format(str(graph_desc)))
if len(subgraphs) > 1:
if len(subgraphs) > 1 or subgraphs[0].stitch_info.has_stitch_op():
for i, g in enumerate(subgraphs):
f.write("-------- subgraph {}, mode: {} --------\n".format(i, graph_mode[i]))
f.write("{}\n".format(str(g)))

View File

@ -26,3 +26,5 @@ def create_dir(pathname):
os.mkdir(pathname)
except OSError:
pass
finally:
pass

View File

@ -69,6 +69,7 @@ bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) const {
}
if (diff) {
changed = true;
std::sort(axis_vec.begin(), axis_vec.end());
SetNodeAttrSafely(kAttrAxis, MakeValue(axis_vec), node);
}
}

View File

@ -63,7 +63,7 @@ bool DoFuse(const FuncGraphPtr &func_graph) {
if (cnode->size() != 4) {
continue;
}
auto cast_node = cnode->input(3);
auto cast_node = cnode->inputs().back(); // bias node
if (!IsPrimitiveCNode(cast_node, prim::kPrimCast)) {
continue;
}

View File

@ -239,13 +239,7 @@ bool InsertPadUnpad(const FuncGraphPtr &func_graph) {
if (!AnfAlgo::CheckPrimitiveType(n, prim::kPrimMatMul)) continue;
auto mm_cnode = n->cast<CNodePtr>();
vec pad_shape_a, pad_shape_b, tail_shape_a, tail_shape_b, tail_shape_unpad, unpad_shape;
bool pad_K, pad_M, pad_N;
pad_shape_a.clear();
pad_shape_b.clear();
tail_shape_a.clear();
tail_shape_b.clear();
tail_shape_unpad.clear();
unpad_shape.clear();
bool pad_K{false}, pad_M{false}, pad_N{false};
std::tie(pad_K, pad_M, pad_N) =
NeedPad(mm_cnode, &pad_shape_a, &pad_shape_b, &unpad_shape, &tail_shape_a, &tail_shape_b, &tail_shape_unpad);
if (!pad_K && !pad_M && !pad_N) continue;

View File

@ -37,7 +37,7 @@ const BaseRef SplitAssign::DefinePattern() const {
bool CanSplit(const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimAssign); }
AnfNodePtr ProcessNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int input_idx) {
AnfNodePtr ProcessNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(node);
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
@ -46,16 +46,14 @@ AnfNodePtr ProcessNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, i
AbstractBasePtr original_abstract = cnode->abstract()->Clone();
auto original_inputs = cnode->inputs();
int input_node_size = cnode->size() - 1;
// Create depend node
AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[input_idx],
original_inputs[input_node_size]};
AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[input_idx], original_inputs.back()};
auto depend_cnode = func_graph->NewCNode(depend_inputs);
depend_cnode->set_abstract(original_inputs[input_idx]->abstract());
depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
// Create new node, delete U from inputs.
AnfNodePtrList new_inputs = {cnode->input(0)};
for (int i = 1; i < input_node_size; i++) {
for (size_t i = 1; i + 1 < cnode->size(); i++) {
if (i == input_idx) {
new_inputs.push_back(depend_cnode);
} else {
@ -77,19 +75,11 @@ const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfN
AnfNodePtr OpUMonadExpander::Run(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
bool has_umonad = false;
for (unsigned int i = 1; i < cnode->size(); i++) {
if (HasAbstractUMonad(cnode->input(i))) {
has_umonad = true;
break;
}
}
if (has_umonad) {
// assume the UMonad node is the last input
if (cnode->size() > 1 && HasAbstractUMonad(cnode->inputs().back())) {
auto new_node = ProcessNode(node->func_graph(), node, input_idx_);
return DefaultExpander::Run(new_node);
}
return DefaultExpander::Run(node);
}
} // namespace opt

View File

@ -30,12 +30,12 @@ class SplitAssign : public PatternProcessPass {
class OpUMonadExpander : public DefaultExpander {
public:
explicit OpUMonadExpander(int input_idx) : input_idx_(input_idx) {}
explicit OpUMonadExpander(size_t input_idx) : input_idx_(input_idx) {}
~OpUMonadExpander() = default;
AnfNodePtr Run(const AnfNodePtr &node) override;
private:
int input_idx_;
size_t input_idx_;
};
} // namespace opt
} // namespace mindspore