diff --git a/mindspore/_extends/graph_kernel/expanders/tile.py b/mindspore/_extends/graph_kernel/expanders/tile.py
index 7f1ae7c6ca8..258f25246d2 100644
--- a/mindspore/_extends/graph_kernel/expanders/tile.py
+++ b/mindspore/_extends/graph_kernel/expanders/tile.py
@@ -16,45 +16,6 @@
 from mindspore._extends.graph_kernel.model import model_builder as builder
 
 
-def _get_tile_output_shape(shape, multiples):
-    """compute output shape of tile"""
-
-    if multiples is None:
-        return shape
-    if not isinstance(shape, (list, tuple)):
-        raise TypeError("Input shape of Tile must be of type list or tuple")
-    if not isinstance(multiples, (list, tuple)):
-        raise TypeError("multiples of Tile must be of type list or tuple")
-
-    shape = list(shape)
-    multiples = list(multiples)
-    diff_len = len(multiples) - len(shape)
-    if diff_len < 0:
-        raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
-    if diff_len > 0:
-        for _ in range(diff_len):
-            shape.insert(0, 1)
-
-    shape_compatible = True
-    output_shape = []
-    input_reshape = []
-    output_reshape = []
-    for sh, mul in list(zip(shape, multiples)):
-        dim = sh * mul
-        output_shape.append(dim)
-        if sh == 1 or mul == 1:
-            input_reshape.append(sh)
-            output_reshape.append(dim)
-        else:
-            shape_compatible = False
-            input_reshape.append(1)
-            input_reshape.append(sh)
-            output_reshape.append(mul)
-            output_reshape.append(sh)
-
-    return output_shape, input_reshape, output_reshape, shape_compatible
-
-
 def expand_tile(expand_info):
     """Tile expander"""
 
@@ -65,7 +26,7 @@ def expand_tile(expand_info):
     for item in attrs:
         if 'multiples' in item:
             multiples = item['multiples']
-    output_shape, _, _, shape_compatible = _get_tile_output_shape(input_desc['shape'], multiples)
+    output_shape, _, _, shape_compatible = builder.get_tile_output_shape(input_desc['shape'], multiples)
     graph_builder = builder.GraphBuilder()
 
     # generate a graph.
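For reference, an illustrative sketch (not part of the patch) of what the relocated helper computes, with return values traced by hand from the function body above. The fourth element flags whether every axis satisfies sh == 1 or mul == 1; when it does, the expander can presumably lower Tile to a single BroadcastTo, and otherwise the two reshape lists describe how the incompatible axes get split:

from mindspore._extends.graph_kernel.model import model_builder as builder

# Compatible case: each axis has sh == 1 or mul == 1, so the output
# shape is reachable by broadcasting the (left-padded) input directly.
print(builder.get_tile_output_shape([1, 4], [3, 1]))
# ([3, 4], [1, 4], [3, 4], True)

# Incompatible case: an axis with sh > 1 and mul > 1 is split into
# (1, sh) on the input side and (mul, sh) on the output side.
print(builder.get_tile_output_shape([3, 4], [2, 2, 2]))
# ([2, 6, 8], [1, 1, 3, 1, 4], [2, 2, 3, 2, 4], False)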
diff --git a/mindspore/_extends/graph_kernel/model/model_builder.py b/mindspore/_extends/graph_kernel/model/model_builder.py
index ea6d3564c6c..2dd2a93fe52 100644
--- a/mindspore/_extends/graph_kernel/model/model_builder.py
+++ b/mindspore/_extends/graph_kernel/model/model_builder.py
@@ -18,6 +18,45 @@ import copy
 from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy
 
 
+def get_tile_output_shape(shape, multiples):
+    """compute output shape of tile"""
+
+    if multiples is None:
+        return shape
+    if not isinstance(shape, (list, tuple)):
+        raise TypeError("Input shape of Tile must be of type list or tuple")
+    if not isinstance(multiples, (list, tuple)):
+        raise TypeError("multiples of Tile must be of type list or tuple")
+
+    shape = list(shape)
+    multiples = list(multiples)
+    diff_len = len(multiples) - len(shape)
+    if diff_len < 0:
+        raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
+    if diff_len > 0:
+        for _ in range(diff_len):
+            shape.insert(0, 1)
+
+    shape_compatible = True
+    output_shape = []
+    input_reshape = []
+    output_reshape = []
+    for sh, mul in list(zip(shape, multiples)):
+        dim = sh * mul
+        output_shape.append(dim)
+        if sh == 1 or mul == 1:
+            input_reshape.append(sh)
+            output_reshape.append(dim)
+        else:
+            shape_compatible = False
+            input_reshape.append(1)
+            input_reshape.append(sh)
+            output_reshape.append(mul)
+            output_reshape.append(sh)
+
+    return output_shape, input_reshape, output_reshape, shape_compatible
+
+
 class OpInfer:
     """Op infer"""
     @staticmethod
@@ -74,6 +113,7 @@ class OpInfer:
         'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
         'Reshape': lambda inputs, attrs: attrs["shape"],
         'BroadcastTo': lambda inputs, attrs: attrs["shape"],
+        'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0],
     }
     infer_dtype_func = {
         # add special infer func here
diff --git a/tests/st/ops/graph_kernel/test_sqrt_grad.py b/tests/st/ops/graph_kernel/test_sqrt_grad.py
index a200ae7d974..65342920fd3 100644
--- a/tests/st/ops/graph_kernel/test_sqrt_grad.py
+++ b/tests/st/ops/graph_kernel/test_sqrt_grad.py
@@ -47,7 +47,13 @@ def test_sqrt_grad(shape_x, shape_dout, dtype):
     expect_np = expect.asnumpy().copy()
     output_np = output.asnumpy().copy()
 
-    assert np.allclose(expect_np, output_np, 0.0001, 0.0001)
+    rtol = 0.0001
+    atol = 0.0001
+    if dtype == np.float16:
+        rtol = 0.001
+        atol = 0.001
+
+    assert np.allclose(expect_np, output_np, rtol, atol)
 
 
 @pytest.mark.level0
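A note on the relaxed test tolerance (my reading; the patch itself gives no rationale): float16 has a 10-bit mantissa, so its machine epsilon is 2**-10, roughly 9.8e-4, which already exceeds the old rtol/atol of 1e-4; a single ulp of rounding difference between the two computations being compared could trip the assert. A minimal check:

import numpy as np

# A relative tolerance below the dtype's machine epsilon demands more
# precision than the format can represent.
print(np.finfo(np.float16).eps)  # 0.000977      -> old rtol=1e-4 was below one ulp
print(np.finfo(np.float32).eps)  # 1.1920929e-07 -> 1e-4 stays comfortable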