forked from mindspore-Ecosystem/mindspore
!21603 Fix static-check warnings
Merge pull request !21603 from DeshiChen/0809_clean
This commit is contained in: commit ff07de80b4
@@ -24,6 +24,7 @@ from .expand_dims import ExpandDims
@VLD.check_attrs('is_training', 'momentum', 'epsilon')
class BatchNorm(Expander):
    """BatchNorm expander"""

    def _expand(self, graph_builder):
        # get op info
        input_x = self.inputs[0]

@@ -42,81 +43,8 @@ class BatchNorm(Expander):
            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})

        if self.attrs['is_training']:
            reduce_axis = ()
            shape_x = input_x.shape
            if input_x.data_format == DF.NHWC:
                reduce_axis = (0, 1, 2)
                num = shape_x[0] * shape_x[1] * shape_x[2]
            else:
                reduce_axis = (0, 2, 3)
                num = shape_x[0] * shape_x[2] * shape_x[3]
            num_rec = 1.0 / num
            num_rec_v = graph_builder.value(input_scale.dtype, num_rec)

            # compute mean value of input_x
            mean_sum = graph_builder.emit(
                'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
            mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v])

            # compute variance of input_x
            if input_x.data_format in (DF.DEFAULT, DF.NCHW):
                mean_muls_expand = graph_builder.emit(
                    'Reshape', [mean_muls], attrs={'shape': ExpandDims.infer_shape(mean_muls.shape, [-1, -1])})
            else:
                mean_muls_expand = mean_muls
            var_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
            var_mul = graph_builder.emit('Mul', [var_sub, var_sub])
            var_sum = graph_builder.emit('ReduceSum', [var_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
            var_mul = graph_builder.emit('Mul', [var_sum, num_rec_v])

            # y_sqrt_rec means 1 / sqrt(variance + epsilon), which is calculated in backward pass
            scalar_one = 1.0
            scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
            y_add = graph_builder.emit('Add', [var_mul, epsilon_v])
            y_sqrt = graph_builder.emit('Sqrt', [y_add])
            y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])

            # compute res_y
            tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
            if input_x.data_format in (DF.DEFAULT, DF.NCHW):
                y_sqrt_rec_expand = graph_builder.emit(
                    'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])})
            else:
                y_sqrt_rec_expand = y_sqrt_rec
            y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
            if input_x.data_format in (DF.DEFAULT, DF.NCHW):
                input_scale_expand = graph_builder.emit(
                    'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
            else:
                input_scale_expand = input_scale
            res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
            if input_x.data_format in (DF.DEFAULT, DF.NCHW):
                input_offset_expand = graph_builder.emit(
                    'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
            else:
                input_offset_expand = input_offset
            res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])

            # compute mean_res
            momentum_sub = scalar_one - self.attrs['momentum']
            momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
            new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
            momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
            current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
            updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
            mean_res = graph_builder.emit(
                'InplaceAssign', [input_mean, updated_moving_mean, updated_moving_mean], attrs={'fake_output': True})

            # variance_res is calculated by sample variance, and need to multiply by num / (num - 1)
            var_num = float(num) / (num - 1)
            var_num_v = graph_builder.value(input_scale.dtype, var_num)
            var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
            new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
            current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
            updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
            variance_res = graph_builder.emit(
                'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
                attrs={'fake_output': True})
            self.inputs[0] = input_x
            res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder)
            if input_x_new_type != input_x_ori_type:
                res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
            return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec

@@ -140,3 +68,88 @@ class BatchNorm(Expander):
        if input_x_new_type != input_x_ori_type:
            res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
        return res_y, var_add, var_add, var_add, var_add

    def _bn_train(self, graph_builder):
        """expand BatchNorm for training mode"""
        input_x = self.inputs[0]
        input_scale = self.inputs[1]
        input_offset = self.inputs[2]
        input_mean = self.inputs[3]
        input_variance = self.inputs[4]
        epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
        reduce_axis = ()
        shape_x = input_x.shape
        if input_x.data_format == DF.NHWC:
            reduce_axis = (0, 1, 2)
            num = shape_x[0] * shape_x[1] * shape_x[2]
        else:
            reduce_axis = (0, 2, 3)
            num = shape_x[0] * shape_x[2] * shape_x[3]
        num_rec = 1.0 / num
        num_rec_v = graph_builder.value(input_scale.dtype, num_rec)

        # compute mean value of input_x
        mean_sum = graph_builder.emit(
            'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
        mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v])

        # compute variance of input_x
        if input_x.data_format in (DF.DEFAULT, DF.NCHW):
            mean_muls_expand = graph_builder.emit(
                'Reshape', [mean_muls], attrs={'shape': ExpandDims.infer_shape(mean_muls.shape, [-1, -1])})
        else:
            mean_muls_expand = mean_muls
        var_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
        var_mul = graph_builder.emit('Mul', [var_sub, var_sub])
        var_sum = graph_builder.emit('ReduceSum', [var_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
        var_mul = graph_builder.emit('Mul', [var_sum, num_rec_v])

        # y_sqrt_rec means 1 / sqrt(variance + epsilon), which is calculated in backward pass
        scalar_one = 1.0
        scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
        y_add = graph_builder.emit('Add', [var_mul, epsilon_v])
        y_sqrt = graph_builder.emit('Sqrt', [y_add])
        y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])

        # compute res_y
        tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
        if input_x.data_format in (DF.DEFAULT, DF.NCHW):
            y_sqrt_rec_expand = graph_builder.emit(
                'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])})
        else:
            y_sqrt_rec_expand = y_sqrt_rec
        y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
        if input_x.data_format in (DF.DEFAULT, DF.NCHW):
            input_scale_expand = graph_builder.emit(
                'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
        else:
            input_scale_expand = input_scale
        res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
        if input_x.data_format in (DF.DEFAULT, DF.NCHW):
            input_offset_expand = graph_builder.emit(
                'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
        else:
            input_offset_expand = input_offset
        res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])

        # compute mean_res
        momentum_sub = scalar_one - self.attrs['momentum']
        momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
        new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
        momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
        current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
        updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
        mean_res = graph_builder.emit(
            'InplaceAssign', [input_mean, updated_moving_mean, updated_moving_mean], attrs={'fake_output': True})

        # variance_res is calculated by sample variance, and need to multiply by num / (num - 1)
        var_num = float(num) / (num - 1)
        var_num_v = graph_builder.value(input_scale.dtype, var_num)
        var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
        new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
        current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
        updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
        variance_res = graph_builder.emit(
            'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
            attrs={'fake_output': True})
        return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
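The two hunks above move the training-mode expansion into a dedicated _bn_train helper without changing the math. A minimal NumPy sketch of the computation the emitted ops perform, assuming NCHW layout (names are illustrative and not part of the commit):

import numpy as np

def batchnorm_train_reference(x, scale, offset, moving_mean, moving_var, momentum, eps):
    # Reduce over N, H, W for NCHW input, matching reduce_axis = (0, 2, 3).
    axis = (0, 2, 3)
    num = x.shape[0] * x.shape[2] * x.shape[3]
    mean = x.sum(axis=axis) / num                                        # mean_muls
    var = ((x - mean.reshape(1, -1, 1, 1)) ** 2).sum(axis=axis) / num    # biased variance (var_mul)
    inv_std = 1.0 / np.sqrt(var + eps)                                   # y_sqrt_rec, reused by the backward pass
    y = scale.reshape(1, -1, 1, 1) * (x - mean.reshape(1, -1, 1, 1)) * inv_std.reshape(1, -1, 1, 1) \
        + offset.reshape(1, -1, 1, 1)
    # Moving statistics: the variance is corrected by num / (num - 1) before the update.
    new_mean = (1.0 - momentum) * moving_mean + momentum * mean
    new_var = (1.0 - momentum) * moving_var + momentum * var * num / (num - 1)
    return y, new_mean, new_var, mean, inv_std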
@@ -48,24 +48,22 @@ class MatMul(Expander):
        if input_num < 2:
            raise GKException("matul inputs number should bigger than 1, but got {}.".format(input_num))

    @staticmethod
    def _trans_shape(shape):
        trans_shape = list(shape)
        trans_shape[-2] = shape[-1]
        trans_shape[-1] = shape[-2]
        return trans_shape

    def _expand(self, graph_builder):
        def transpose(shape):
            trans_shape = list(shape)
            trans_shape[-2] = shape[-1]
            trans_shape[-1] = shape[-2]
            return trans_shape
        if not self._optimize_to_mul():
            raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul")
        # Matmul is replaced by Mul([b m k], [b k n]) when k==1
        input_a = self.inputs[0]
        input_b = self.inputs[1]
        if self.transpose_a:
            shape_a_trans = self._trans_shape(self.shape_a)
            shape_a_trans = transpose(self.shape_a)
            input_a = graph_builder.emit('Reshape', [input_a], attrs={'shape': shape_a_trans})
        if self.transpose_b:
            shape_b_trans = self._trans_shape(self.shape_b)
            shape_b_trans = transpose(self.shape_b)
            input_b = graph_builder.emit('Reshape', [input_b], attrs={'shape': shape_b_trans})
        result = graph_builder.emit('Mul', [input_a, input_b])
        if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']:
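The "# Matmul is replaced by Mul([b m k], [b k n]) when k==1" comment is the whole point of this expander: when the contracted dimension is 1, matrix multiplication degenerates into a broadcasted elementwise product. A quick NumPy check (illustrative only):

import numpy as np

a = np.random.rand(2, 3, 1)   # [b, m, k] with k == 1
b = np.random.rand(2, 1, 4)   # [b, k, n] with k == 1
assert np.allclose(a @ b, a * b)   # broadcasting yields the same [b, m, n] result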
@@ -24,7 +24,7 @@ class MinimumGrad(Expander):
    def _check(self):
        if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
            raise GKException("both grad_x and grad_y are False.")
        return super()._check()
        return super(MinimumGrad, self)._check()

    def _expand(self, graph_builder):
        input_x, input_y, input_dout = self.inputs

@@ -34,7 +34,8 @@ class MinimumGrad(Expander):
        dx = graph_builder.emit('Mul', [le_result, input_dout])
        dy = graph_builder.emit('Sub', [input_dout, dx])

        # for minimumgrad op, output_shape should be equal to input_shape, but some elementwise operating may broadcast input_shape
        # for minimumgrad op, output_shape should be equal to input_shape,
        # but some elementwise operating may broadcast input_shape
        # then output_shape not equal to original input_shape, so need to reduce output to let them equal
        reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape)
        reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape)
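get_reduce_axis, referenced above but defined elsewhere in the expander base class, compensates for broadcasting: the gradient is summed back over the axes that were expanded so its shape matches the original input. A hedged NumPy illustration of that idea (reduce_to_shape is a hypothetical helper, not the project's API):

import numpy as np

def reduce_to_shape(grad, shape):
    """Sum `grad` over the axes that were introduced or broadcast from size 1."""
    # Sum away leading axes the original shape does not have.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    # Sum over axes where the original dimension is 1 but the gradient is larger.
    for axis, size in enumerate(shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

dout = np.ones((4, 3))
dx = reduce_to_shape(dout, (1, 3))   # -> shape (1, 3)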
@@ -13,7 +13,6 @@
# limitations under the License.
# ===========================================================================
"""Cost model splitter"""
import os
from functools import reduce as prod_reduce
from mindspore import log as logger
from .model import PrimLib, Graph, Tensor, Operator

@@ -39,20 +38,24 @@ class GraphSplitByPattern:
        def sync(self, x, y):
            """sync from y to x"""
            for i in self.alive:
                if self.map[y][i] and not self.map[x][i]:
                    self.map[x][i] = True
                self._link(self.map[y][i], x, i)

        def _link(self, cond, f, t):
            """link from `f` to `t`"""
            if cond:
                self.map[f][t] = True

        def fuse(self, x, y):
            """fuse y to x"""
            for i in self.alive:
                # i is the succeeding node of y, links the x's previous nodes to i
                if self.map[y][i] and not self.map[x][i]:
                    for pre in self.alive:
                        if self.map[pre][x] and not self.map[pre][i]:
                            self.map[pre][i] = True
                        self._link(self.map[pre][x], pre, i)
                # i is the previous node of y, link i to x's succeeding nodes
                if self.map[i][y] and not self.map[i][x]:
                    for suc in self.alive:
                        if self.map[x][suc] and not self.map[i][suc]:
                            self.map[i][suc] = True
                        self._link(self.map[x][suc], i, suc)
            self.alive.remove(y)

    class Area:
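sync and fuse above maintain a boolean reachability matrix over the alive areas, and the new _link helper is just a guarded assignment that collapses the repeated if-blocks flagged by the static check. A self-contained sketch of the same bookkeeping (a hypothetical ReachTable class, not the project's API):

class ReachTable:
    """Transitive reachability between graph areas, updated when two areas fuse."""

    def __init__(self, size):
        self.map = [[False] * size for _ in range(size)]
        self.alive = set(range(size))

    def link(self, frm, to):
        self.map[frm][to] = True

    def fuse(self, x, y):
        """Merge y into x, keeping every path that previously went through y."""
        for i in self.alive:
            if self.map[y][i] and not self.map[x][i]:
                # successors of y become reachable from everything that reached x
                for pre in self.alive:
                    if self.map[pre][x]:
                        self.map[pre][i] = True
            if self.map[i][y] and not self.map[i][x]:
                # predecessors of y now reach the successors of x
                for suc in self.alive:
                    if self.map[x][suc]:
                        self.map[i][suc] = True
        self.alive.remove(y)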
@@ -67,6 +70,10 @@ GraphSplitByPattern:
                self.stitch_ops = set()
                self.stitch_atomic_ops = set()

            def has_stitch_op(self):
                """check stitch_op exists"""
                return self.stitch_ops or self.stitch_atomic_ops

        def __init__(self, init_op, is_output, unique_id, reach_tab, recompute_ops=None):
            self.pattern = PrimLib.iter_type(init_op) if init_op is not None else PrimLib.UNKNOWN
            self.ops = [] if init_op is None else [init_op]
@@ -286,31 +293,35 @@ GraphSplitByPattern:

    def fuse(self, selector):
        """Fuse areas"""
        changed = False
        while True:
        def _fuse_area():
            for dominant in self.areas:
                result = selector(dominant)
                if result is not None and result[0]:
                    fuse_areas, is_forward = result
                    fuse_areas = self.limit_area_size(dominant, fuse_areas)
                    if not fuse_areas:
                        continue
                    if is_forward:
                        for area in fuse_areas:
                            dominant.fuse(area)
                            self.set_area_map(area.ops, dominant)
                            self.areas.remove(area)
                    else:
                        forward_area = dominant
                        for area in fuse_areas:
                            area.fuse(forward_area)
                            self.set_area_map(forward_area.ops, area)
                            self.areas.remove(forward_area)
                            forward_area = area
                    changed = True
                    break
            else:
                return changed
                if result is None or not result[0]:
                    continue
                fuse_areas, is_forward = result
                fuse_areas = self.limit_area_size(dominant, fuse_areas)
                if not fuse_areas:
                    continue
                if is_forward:
                    for area in fuse_areas:
                        dominant.fuse(area)
                        self.set_area_map(area.ops, dominant)
                        self.areas.remove(area)
                else:
                    forward_area = dominant
                    for area in fuse_areas:
                        area.fuse(forward_area)
                        self.set_area_map(forward_area.ops, area)
                        self.areas.remove(forward_area)
                        forward_area = area
                return True
            return False

        changed, do_again = False, True
        while do_again:
            do_again = _fuse_area()
            changed = changed or do_again
        return changed

    def fuse_recom(self, selector):
        """Fuse recompute area to its user"""
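The rewrite above replaces a while True loop with a for-else inside it by a helper that reports whether it fused anything, driven until nothing changes. The control-flow pattern, reduced to its core (illustrative only):

def run_until_stable(step):
    """Repeat `step` (a callable returning True when it changed something) until a fixed point."""
    changed, do_again = False, True
    while do_again:
        do_again = step()
        changed = changed or do_again
    return changed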
@@ -348,21 +359,6 @@ GraphSplitByPattern:
            graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
        return subgraphs, graphmodes

    def dump_subgraphs(self, subgraphs):
        """Dump subgraphs"""
        if os.environ.get("ENABLE_SUBGRAPHS", "off") == "on":
            subgraphs_str = "subgraphs:\nlen: " + str(len(subgraphs)) + "\n"
            for i, sub in enumerate(subgraphs):
                subgraphs_str += str("============") + str(i) + "\n"
                subgraphs_str += str(sub)
            dirname = 'subgraphs'
            if not os.path.exists(dirname):
                os.makedirs(dirname)
            graphname = self.graph.name
            filename = dirname + '/' + graphname + '.log'
            with os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), 'w+') as f:
                f.write(subgraphs_str)

    def pattern_fuse(self, fuse_func=None):
        """fuse Areas by pattern repeatedly"""
        del fuse_func
@@ -376,34 +372,38 @@ GraphSplitByPattern:
        # Note: after this function, the input output relation is not maintained.
        self.split_output_reshapes()
        subgraphs, graphmodes = self.to_subgraphs()
        self.dump_subgraphs(subgraphs)
        return subgraphs, graphmodes

    def split_output_reshapes(self):
        """Force split the output reshapes into other new """
        """Force split the output Reshapes into other new area"""
        def _remove_output_reshape(reshape_ops, other_ops):
            def _run():
                for op in reshape_ops:
                    if any([to_op in other_ops for to_op in op.output.to_ops]):
                        reshape_ops.remove(op)
                        other_ops.append(op)
                        return True
                return False
            while _run():
                pass

        new_areas = []
        for area in self.areas:
            out_reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
            remain_ops = [op for op in area.ops if op not in out_reshape_ops]
            if not remain_ops or not out_reshape_ops:
            reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
            other_ops = [op for op in area.ops if op not in reshape_ops]
            if not other_ops or not reshape_ops:
                continue
            changed = True
            while changed:
                changed = False
                for op in out_reshape_ops:
                    if any([to_op in remain_ops for to_op in op.output.to_ops]):
                        out_reshape_ops.remove(op)
                        remain_ops.append(op)
                        changed = True
                        break
            if out_reshape_ops:
                for op in out_reshape_ops:
                    a = self.Area(op, False, 0, self.reach_tab)
                    self.set_default_mode(a)
                    new_areas.append(a)
                area.ops = remain_ops
                if len(remain_ops) == 1:
                    self.set_default_mode(area)
            # remove the output reshape from "reshape_ops" and add it into "other_ops"
            _remove_output_reshape(reshape_ops, other_ops)
            if not reshape_ops:
                continue
            for op in reshape_ops:
                a = self.Area(op, False, 0, self.reach_tab)
                self.set_default_mode(a)
                new_areas.append(a)
            area.ops = other_ops
            if len(other_ops) == 1:
                self.set_default_mode(area)
        if new_areas:
            self.areas += new_areas
@@ -533,14 +533,7 @@ GraphSplitByPattern:
        """find recompute regions and copy them out to new Areas"""
        def do_recompute_fuse():
            """split the unfusing pattern by add recompute area"""
            recompute_suc = False
            orig_areas = []
            orig_areas.extend(self.areas)
            for dom in orig_areas:
                if dom not in self.areas or not dom.out_relations:
                    continue
                cheap_regions = self.find_cheap_regions(dom)
                dom_changed = False
            def recompute_cheap_region(dom):
                for cheap_region in cheap_regions:
                    user_areas = self.select_user_area(cheap_region[-1].output)
                    if not user_areas:

@@ -550,12 +543,17 @@ GraphSplitByPattern:
                        self.pattern_fuse(self.fuse_recom)
                        self.clear_recompute()
                        if self.recom_res:
                            recompute_suc = True
                            # Copy region at most once for this dom
                            dom_changed = True
                            break
                    if dom_changed:
                        break
                            return True
                return False
            recompute_suc = False
            orig_areas = []
            orig_areas.extend(self.areas)
            for dom in orig_areas:
                if dom not in self.areas or not dom.out_relations:
                    continue
                cheap_regions = self.find_cheap_regions(dom)
                if recompute_cheap_region(dom):
                    recompute_suc = True
            return recompute_suc

        if self.enable_recompute:

@@ -563,9 +561,6 @@ GraphSplitByPattern:
            self.pattern_fuse()


use_poly_reduce = True


class GraphSplitGpu(GraphSplitByPattern):
    """Graph splitter"""
    BORADCAST_FUSE_DEPTH = 20
@@ -616,7 +611,7 @@ class GraphSplitGpu(GraphSplitByPattern):
            return fused, True

        def _broadcast_pat_exclude(dom, a, r):
            if use_poly_reduce and a.pattern == PrimLib.REDUCE:
            if a.pattern == PrimLib.REDUCE:
                return dom.pattern > PrimLib.ELEMWISE or r > PrimLib.ELEMWISE
            return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST

@@ -641,34 +636,14 @@ class GraphSplitGpu(GraphSplitByPattern):
                fused.append(a)
            return fused, False

        def _check_reduce_exclude(dom):
            if use_poly_reduce:
                return False
            # exclude large all-reduce
            if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \
                    dom.ops[0].inputs[0].get_size() > 10000:
                return True

            # exclude multi output
            for a in dom.in_relations.keys():
                if len(a.out_relations) > 1:
                    return True
                if any([op.output.para_type == Tensor.PARA_OUTPUT for op in a.ops]):
                    return True
            return False

        def _reduce_pat_exclude(_, a, r):
            if len(a.ops) > self.REDUCE_FUSE_DEPTH:
                return True
            if use_poly_reduce:
                return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST
            return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE
            return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST

        def _reduce_depth(dom):
            if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
                return None
            if _check_reduce_exclude(dom):
                return None
            a, r = list(dom.in_relations.items())[0]
            if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
                    _is_atomic_add_available(dom):
@@ -681,8 +656,6 @@ class GraphSplitGpu(GraphSplitByPattern):
        def _reduce_width(dom):
            if dom.pattern != PrimLib.REDUCE:
                return None
            if _check_reduce_exclude(dom):
                return None
            fused = []
            for a, r in dom.in_relations.items():
                if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \

@@ -763,16 +736,16 @@ class GraphSplitGpu(GraphSplitByPattern):

        def _may_stitch(dom, a, r):
            if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
                if _reduce_nums(a.ops) < 2:
                    dom_outs = [op.output for op in dom.ops]
                    a_ins = [op_input for op in a.ops for op_input in op.inputs]
                    a_outs = [op.output for op in a.ops]
                    a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
                    stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
                    if _same_stitch_axis(stitch_tensors, a_final_outs):
                        for tensor in stitch_tensors:
                            if _tensor_size(tensor) >= 1024 * 1024:
                                return True
                if _reduce_nums(a.ops) >= 2:
                    return False
                dom_outs = [op.output for op in dom.ops]
                a_ins = [op_input for op in a.ops for op_input in op.inputs]
                a_outs = [op.output for op in a.ops]
                a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
                stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
                if not _same_stitch_axis(stitch_tensors, a_final_outs):
                    return False
                return any([_tensor_size(tensor) >= 1024 * 1024 for tensor in stitch_tensors])
            return False

        def _reduce_stitch(dom):

@@ -785,14 +758,15 @@ class GraphSplitGpu(GraphSplitByPattern):

            fused = []
            for a, r in dom.out_relations.items():
                if _may_stitch(dom, a, r):
                    if a.pattern == PrimLib.REDUCE:
                        if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
                            dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
                            fused.append(a)
                    elif a.pattern == PrimLib.BROADCAST:
                if not _may_stitch(dom, a, r):
                    continue
                if a.pattern == PrimLib.REDUCE:
                    if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
                        dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
                        fused.append(a)
                elif a.pattern == PrimLib.BROADCAST:
                    dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
                    fused.append(a)
            return fused, False

        def _transpose(dom):

@@ -825,10 +799,9 @@ class GraphSplitGpu(GraphSplitByPattern):
            changed = self.fuse(_broadcast_depth) or changed
            changed = self.fuse(_broadcast_width) or changed
            changed = self.fuse(_strided_slice) or changed
            if use_poly_reduce:
                changed = self.fuse(_reduce_output) or changed
                if enable_stitch_fusion:
                    changed = self.fuse(_reduce_stitch) or changed
            changed = self.fuse(_reduce_output) or changed
            if enable_stitch_fusion:
                changed = self.fuse(_reduce_stitch) or changed
            self.fuse(_transpose)

        def _fuse_once(fuse_func):

@@ -836,9 +809,8 @@ class GraphSplitGpu(GraphSplitByPattern):
                    fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
                    fuse_func(_broadcast_width):
                return
            if use_poly_reduce:
                if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
                    return
            if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
                return
            fuse_func(_transpose)
            return
@@ -422,14 +422,13 @@ class Graph:
            for t in op.inputs:
                if t not in inputs and t.op not in self.ops:
                    inputs.append(t)
            if op.output not in outputs:
                if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
                    outputs.append(op.output)
                else:
                    for d in op.output.to_ops:
                        if d not in self.ops:
                            outputs.append(op.output)
                            break
            if op.output in outputs:
                continue
            if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
                outputs.append(op.output)
                continue
            if any([succ not in self.ops for succ in op.output.to_ops]):
                outputs.append(op.output)
        if self.inputs:
            inputs = self.inputs
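The flattened conditions keep the original rule: a tensor is a graph input if its producer lies outside the op set, and a graph output if it is explicitly flagged as one, has no consumers, or has a consumer outside the op set. A compact sketch of that rule (illustrative names; the PARA_OUTPUT flag from the original is reduced to a comment):

def deduce_parameters(ops):
    """ops: objects with .inputs and .output; tensors carry .op (producer) and .to_ops (consumers)."""
    op_set = set(ops)
    inputs, outputs = [], []
    for op in ops:
        for t in op.inputs:
            if t not in inputs and t.op not in op_set:
                inputs.append(t)
        if op.output in outputs:
            continue
        # the original also treats tensors flagged as PARA_OUTPUT as graph outputs
        if not op.output.to_ops or any(succ not in op_set for succ in op.output.to_ops):
            outputs.append(op.output)
    return inputs, outputs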
@@ -81,7 +81,6 @@ class GraphBuilder:
        """Create a new Value"""
        if name in (None, ''):
            name = self._alloc_tensor_name()

        v = Value(name, dtype, value)
        return v
@@ -128,34 +127,14 @@ class CompositeGraph:

    def load(self, desc):
        """Load Graph from json"""
        def _attr_of(op, inputs, output):
            def _get_axis_while_none(input_shape, output_shape):
                red_axis = []
                if len(output_shape) == len(input_shape):
                    for i, s in enumerate(output_shape):
                        if s == 1 and input_shape[i] > 1:
                            red_axis.append(i)
                else:
                    red_axis = list(range(len(output_shape)))
                return red_axis

        def _attr_of(op):
            if not op['attr']:
                return dict()
            attr = {}
            if op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
                for a in op['attr']:
                    if a['name'] == 'axis':
                        red_axis, dim_size = [], len(inputs[0].shape)
                        if not a['value']:
                            red_axis = _get_axis_while_none(inputs[0].shape, output.shape)
                        else:
                            if isinstance(a['value'], int):
                                a['value'] = [a['value']]
                            for i in a['value']:
                                red_axis.append(i if i >= 0 else dim_size + i)
                        attr['reduce_axis'] = red_axis
                    if a['name'] == "reduce_output_fuse":
                        attr['reduce_output_fuse'] = a['value']
            elif op['attr']:
                for a in op['attr']:
            for a in op['attr']:
                if a['name'] == 'axis' and op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
                    attr['reduce_axis'] = a['value']
                else:
                    attr[a['name']] = a['value']
            return attr
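The removed code normalized the axis attribute of reduce ops inside _attr_of: negative axes are wrapped by the input rank, and an empty axis list is recovered by comparing input and output shapes (the _get_axis_while_none path). That normalization in isolation, as a hedged sketch:

def normalize_reduce_axis(axis, input_shape, output_shape):
    dim_size = len(input_shape)
    if not axis:
        # axis omitted: infer it from dimensions that collapsed to 1 (keep_dims case)
        if len(output_shape) == dim_size:
            return [i for i, s in enumerate(output_shape) if s == 1 and input_shape[i] > 1]
        return list(range(len(output_shape)))
    if isinstance(axis, int):
        axis = [axis]
    return [i if i >= 0 else dim_size + i for i in axis]

assert normalize_reduce_axis(-1, [8, 16, 32], [8, 16]) == [2]
assert normalize_reduce_axis([], [8, 16, 32], [8, 1, 32]) == [1]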
@@ -171,7 +150,6 @@ class CompositeGraph:
                'shape'], out_desc['data_type'], out_desc['format']
            self.tensors[name] = builder.tensor(
                shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
        cur_fusion = None
        for op in desc['op_desc']:
            inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
            out_desc = op['output_desc']

@@ -182,21 +160,12 @@ class CompositeGraph:
                inputs[1].para_type = Tensor.PARA_OUTPUT
                output = inputs[2]
                self.tensors[name] = output
            else:
                output = self.tensors.get(name, None)
                if not output:
                    output = builder.tensor(
                        shape, dtype, data_format, name=name)
                self.tensors[name] = output
            builder.op(op['name'], output, inputs,
                       attrs=_attr_of(op, inputs, output))
            if 'fusion' in op:
                if cur_fusion is None:
                    cur_fusion = output
                else:
                    cur_fusion.add_buddy(output)
                    if op['fusion'].endswith('_end'):
                        cur_fusion = None
                continue
            output = self.tensors.get(name, None)
            if not output:
                output = builder.tensor(shape, dtype, data_format, name=name)
            self.tensors[name] = output
            builder.op(op['name'], output, inputs, attrs=_attr_of(op))
        self.graph = builder.get()[0]
        self.desc = desc
@@ -234,43 +203,40 @@ class CompositeGraph:
        inputs, outputs = subgraph.deduce_parameters()
        graph_ops = set(subgraph.ops)
        inplace_assign, inplace_assign_z = self._pre_dump(outputs)
        for key in self.desc:

        def dump_output(t):
            if t.name in inplace_assign:
                z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
                return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign[t.name]}
            return {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}

        def dump_op_desc(d):
            if d['name'] == 'InplaceAssign':
                y = d['input_desc'][1][0]['tensor_name']
                if self.tensors[y].op in graph_ops:
                    z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (self.tensors[y], True)
                    inplace_desc = copy.deepcopy(d)
                    inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
                    z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
                    z_desc['shape'] = z.shape
                    z_desc['data_type'] = z.dtype
                    z_desc['tensor_name'] = z.name
                    out_desc['shape'] = z.shape
                    out_desc['data_type'] = z.dtype
                    return inplace_desc
            op = self.tensors[d['output_desc'][0]['tensor_name']].op
            if op in graph_ops or op in subgraph.recompute_ops:
                return d
            return None

        for key in self.desc.keys():
            if key == 'input_desc':
                desc[key] = [
                    [{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
                desc[key] = [[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
            elif key == 'output_desc':
                out_desc = []
                for t in outputs:
                    if t.name in inplace_assign:
                        z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
                        out_desc.append(
                            {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign[t.name]})
                    else:
                        out_desc.append(
                            {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name})
                desc[key] = out_desc
                desc[key] = list(map(dump_output, outputs))
            elif key == 'op_desc':
                op_desc = []
                for d in self.desc[key]:
                    if d['name'] == 'InplaceAssign':
                        y = d['input_desc'][1][0]['tensor_name']
                        if self.tensors[y].op in graph_ops:
                            z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (
                                self.tensors[y], True)
                            inplace_desc = copy.deepcopy(d)
                            inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
                            z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
                            z_desc['shape'] = z.shape
                            z_desc['data_type'] = z.dtype
                            z_desc['tensor_name'] = z.name
                            out_desc['shape'] = z.shape
                            out_desc['data_type'] = z.dtype
                            op_desc.append(inplace_desc)
                    else:
                        op = self.tensors[d['output_desc'][0]['tensor_name']].op
                        if op in graph_ops or op in subgraph.recompute_ops:
                            op_desc.append(d)
                desc[key] = op_desc
                op_desc = map(dump_op_desc, self.desc[key])
                desc[key] = [d for d in op_desc if d is not None]
            elif key == 'op':
                desc[key] = subgraph.name
            else:
@@ -109,11 +109,12 @@ class _Elemwise(OpInfer):
        out_shape = [1] * dim_size
        for i in range(dim_size):
            for align_shape in align_shapes:
                if align_shape[i] > 1:
                    if out_shape[i] == 1:
                        out_shape[i] = align_shape[i]
                    if out_shape[i] != align_shape[i]:
                        raise GKException("shape broadcast failed!")
                if align_shape[i] == 1:
                    continue
                if out_shape[i] == 1:
                    out_shape[i] = align_shape[i]
                elif out_shape[i] != align_shape[i]:
                    raise GKException("shape broadcast failed!")
        return out_shape

    @staticmethod
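The rewritten loop is standard broadcast-shape inference over shapes already padded to a common rank: a dimension may stay 1, adopt the first size greater than 1, or must match it. The same rule as a standalone sketch (with a plain ValueError standing in for GKException):

def broadcast_shape(*aligned_shapes):
    """All shapes are assumed padded with leading 1s to the same rank."""
    dim_size = len(aligned_shapes[0])
    out_shape = [1] * dim_size
    for i in range(dim_size):
        for shape in aligned_shapes:
            if shape[i] == 1:
                continue
            if out_shape[i] == 1:
                out_shape[i] = shape[i]
            elif out_shape[i] != shape[i]:
                raise ValueError("shape broadcast failed!")
    return out_shape

assert broadcast_shape([1, 3, 1], [2, 1, 4], [1, 3, 4]) == [2, 3, 4]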
@@ -57,11 +57,11 @@ def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode):
        return
    utils.create_dir(utils.GRAPH_KERNEL_DUMP_PATH)
    filename = os.path.join(utils.GRAPH_KERNEL_DUMP_PATH, "graph_kernel_split_mode.txt")
    with open(filename, "a+") as f:
    with os.fdopen(os.open(filename, os.O_WRONLY | os.O_CREAT), "a+") as f:
        f.write("********** main graph: {} **********\n".format(graph_desc.name))
        f.write("input json:\n{}\n".format(graph_json))
        f.write("graph desc:\n{}\n".format(str(graph_desc)))
        if len(subgraphs) > 1:
        if len(subgraphs) > 1 or subgraphs[0].stitch_info.has_stitch_op():
            for i, g in enumerate(subgraphs):
                f.write("-------- subgraph {}, mode: {} --------\n".format(i, graph_mode[i]))
                f.write("{}\n".format(str(g)))
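Replacing open() with os.fdopen(os.open(...)) lets the dump file be created with explicit flags and, if desired, a restrictive permission mode instead of the platform defaults. A hedged sketch of that pattern; the 0o600 mode and the helper name are illustrative choices, not taken from the commit:

import os

def append_private(path, text):
    # O_APPEND keeps append semantics; 0o600 restricts a newly created file to its owner.
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_APPEND, 0o600)
    with os.fdopen(fd, "w") as f:
        f.write(text)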
@@ -26,3 +26,5 @@ def create_dir(pathname):
        os.mkdir(pathname)
    except OSError:
        pass
    finally:
        pass
@@ -69,6 +69,7 @@ bool AxisNormalizer::Process(const FuncGraphPtr &func_graph) const {
    }
    if (diff) {
      changed = true;
      std::sort(axis_vec.begin(), axis_vec.end());
      SetNodeAttrSafely(kAttrAxis, MakeValue(axis_vec), node);
    }
  }
@@ -63,7 +63,7 @@ bool DoFuse(const FuncGraphPtr &func_graph) {
      if (cnode->size() != 4) {
        continue;
      }
      auto cast_node = cnode->input(3);
      auto cast_node = cnode->inputs().back();  // bias node
      if (!IsPrimitiveCNode(cast_node, prim::kPrimCast)) {
        continue;
      }
@@ -239,13 +239,7 @@ bool InsertPadUnpad(const FuncGraphPtr &func_graph) {
    if (!AnfAlgo::CheckPrimitiveType(n, prim::kPrimMatMul)) continue;
    auto mm_cnode = n->cast<CNodePtr>();
    vec pad_shape_a, pad_shape_b, tail_shape_a, tail_shape_b, tail_shape_unpad, unpad_shape;
    bool pad_K, pad_M, pad_N;
    pad_shape_a.clear();
    pad_shape_b.clear();
    tail_shape_a.clear();
    tail_shape_b.clear();
    tail_shape_unpad.clear();
    unpad_shape.clear();
    bool pad_K{false}, pad_M{false}, pad_N{false};
    std::tie(pad_K, pad_M, pad_N) =
      NeedPad(mm_cnode, &pad_shape_a, &pad_shape_b, &unpad_shape, &tail_shape_a, &tail_shape_b, &tail_shape_unpad);
    if (!pad_K && !pad_M && !pad_N) continue;
@@ -37,7 +37,7 @@ const BaseRef SplitAssign::DefinePattern() const {

bool CanSplit(const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimAssign); }

AnfNodePtr ProcessNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int input_idx) {
AnfNodePtr ProcessNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

@@ -46,16 +46,14 @@ AnfNodePtr ProcessNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, i
  AbstractBasePtr original_abstract = cnode->abstract()->Clone();
  auto original_inputs = cnode->inputs();

  int input_node_size = cnode->size() - 1;
  // Create depend node
  AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[input_idx],
                                  original_inputs[input_node_size]};
  AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[input_idx], original_inputs.back()};
  auto depend_cnode = func_graph->NewCNode(depend_inputs);
  depend_cnode->set_abstract(original_inputs[input_idx]->abstract());
  depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
  // Create new node, delete U from inputs.
  AnfNodePtrList new_inputs = {cnode->input(0)};
  for (int i = 1; i < input_node_size; i++) {
  for (size_t i = 1; i + 1 < cnode->size(); i++) {
    if (i == input_idx) {
      new_inputs.push_back(depend_cnode);
    } else {

@@ -77,19 +75,11 @@ const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfN
AnfNodePtr OpUMonadExpander::Run(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  bool has_umonad = false;
  for (unsigned int i = 1; i < cnode->size(); i++) {
    if (HasAbstractUMonad(cnode->input(i))) {
      has_umonad = true;
      break;
    }
  }
  if (has_umonad) {
  // assume the UMonad node is the last input
  if (cnode->size() > 1 && HasAbstractUMonad(cnode->inputs().back())) {
    auto new_node = ProcessNode(node->func_graph(), node, input_idx_);
    return DefaultExpander::Run(new_node);
  }

  return DefaultExpander::Run(node);
}
}  // namespace opt
@@ -30,12 +30,12 @@ class SplitAssign : public PatternProcessPass {

class OpUMonadExpander : public DefaultExpander {
 public:
  explicit OpUMonadExpander(int input_idx) : input_idx_(input_idx) {}
  explicit OpUMonadExpander(size_t input_idx) : input_idx_(input_idx) {}
  ~OpUMonadExpander() = default;
  AnfNodePtr Run(const AnfNodePtr &node) override;

 private:
  int input_idx_;
  size_t input_idx_;
};
}  // namespace opt
}  // namespace mindspore