!10009 Add infer shape function for tile in graph kernel

From: @looop5
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2020-12-22 09:50:32 +08:00 committed by Gitee
commit c6f3b1826c
3 changed files with 48 additions and 41 deletions

View File

@ -16,45 +16,6 @@
from mindspore._extends.graph_kernel.model import model_builder as builder
def _get_tile_output_shape(shape, multiples):
"""compute output shape of tile"""
if multiples is None:
return shape
if not isinstance(shape, (list, tuple)):
raise TypeError("Input shape of Tile must be of type list or tuple")
if not isinstance(multiples, (list, tuple)):
raise TypeError("multiples of Tile must be of type list or tuple")
shape = list(shape)
multiples = list(multiples)
diff_len = len(multiples) - len(shape)
if diff_len < 0:
raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
if diff_len > 0:
for _ in range(diff_len):
shape.insert(0, 1)
shape_compatible = True
output_shape = []
input_reshape = []
output_reshape = []
for sh, mul in list(zip(shape, multiples)):
dim = sh * mul
output_shape.append(dim)
if sh == 1 or mul == 1:
input_reshape.append(sh)
output_reshape.append(dim)
else:
shape_compatible = False
input_reshape.append(1)
input_reshape.append(sh)
output_reshape.append(mul)
output_reshape.append(sh)
return output_shape, input_reshape, output_reshape, shape_compatible
def expand_tile(expand_info):
"""Tile expander"""
@ -65,7 +26,7 @@ def expand_tile(expand_info):
for item in attrs:
if 'multiples' in item:
multiples = item['multiples']
output_shape, _, _, shape_compatible = _get_tile_output_shape(input_desc['shape'], multiples)
output_shape, _, _, shape_compatible = builder.get_tile_output_shape(input_desc['shape'], multiples)
graph_builder = builder.GraphBuilder()
# generate a graph.

View File

@ -18,6 +18,45 @@ import copy
from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy
def get_tile_output_shape(shape, multiples):
"""compute output shape of tile"""
if multiples is None:
return shape
if not isinstance(shape, (list, tuple)):
raise TypeError("Input shape of Tile must be of type list or tuple")
if not isinstance(multiples, (list, tuple)):
raise TypeError("multiples of Tile must be of type list or tuple")
shape = list(shape)
multiples = list(multiples)
diff_len = len(multiples) - len(shape)
if diff_len < 0:
raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
if diff_len > 0:
for _ in range(diff_len):
shape.insert(0, 1)
shape_compatible = True
output_shape = []
input_reshape = []
output_reshape = []
for sh, mul in list(zip(shape, multiples)):
dim = sh * mul
output_shape.append(dim)
if sh == 1 or mul == 1:
input_reshape.append(sh)
output_reshape.append(dim)
else:
shape_compatible = False
input_reshape.append(1)
input_reshape.append(sh)
output_reshape.append(mul)
output_reshape.append(sh)
return output_shape, input_reshape, output_reshape, shape_compatible
class OpInfer:
"""Op infer"""
@staticmethod
@ -74,6 +113,7 @@ class OpInfer:
'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
'Reshape': lambda inputs, attrs: attrs["shape"],
'BroadcastTo': lambda inputs, attrs: attrs["shape"],
'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0],
}
infer_dtype_func = {
# add special infer func here

View File

@ -47,7 +47,13 @@ def test_sqrt_grad(shape_x, shape_dout, dtype):
expect_np = expect.asnumpy().copy()
output_np = output.asnumpy().copy()
assert np.allclose(expect_np, output_np, 0.0001, 0.0001)
rtol = 0.0001
atol = 0.0001
if dtype == np.float16:
rtol = 0.001
atol = 0.001
assert np.allclose(expect_np, output_np, rtol, atol)
@pytest.mark.level0