forked from mindspore-Ecosystem/mindspore
!10009 Add infer shape function for tile in graph kernel
From: @looop5 Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou
This commit is contained in:
commit
c6f3b1826c
|
@ -16,45 +16,6 @@
|
|||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
|
||||
|
||||
def _get_tile_output_shape(shape, multiples):
|
||||
"""compute output shape of tile"""
|
||||
|
||||
if multiples is None:
|
||||
return shape
|
||||
if not isinstance(shape, (list, tuple)):
|
||||
raise TypeError("Input shape of Tile must be of type list or tuple")
|
||||
if not isinstance(multiples, (list, tuple)):
|
||||
raise TypeError("multiples of Tile must be of type list or tuple")
|
||||
|
||||
shape = list(shape)
|
||||
multiples = list(multiples)
|
||||
diff_len = len(multiples) - len(shape)
|
||||
if diff_len < 0:
|
||||
raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
|
||||
if diff_len > 0:
|
||||
for _ in range(diff_len):
|
||||
shape.insert(0, 1)
|
||||
|
||||
shape_compatible = True
|
||||
output_shape = []
|
||||
input_reshape = []
|
||||
output_reshape = []
|
||||
for sh, mul in list(zip(shape, multiples)):
|
||||
dim = sh * mul
|
||||
output_shape.append(dim)
|
||||
if sh == 1 or mul == 1:
|
||||
input_reshape.append(sh)
|
||||
output_reshape.append(dim)
|
||||
else:
|
||||
shape_compatible = False
|
||||
input_reshape.append(1)
|
||||
input_reshape.append(sh)
|
||||
output_reshape.append(mul)
|
||||
output_reshape.append(sh)
|
||||
|
||||
return output_shape, input_reshape, output_reshape, shape_compatible
|
||||
|
||||
|
||||
def expand_tile(expand_info):
|
||||
"""Tile expander"""
|
||||
|
||||
|
@ -65,7 +26,7 @@ def expand_tile(expand_info):
|
|||
for item in attrs:
|
||||
if 'multiples' in item:
|
||||
multiples = item['multiples']
|
||||
output_shape, _, _, shape_compatible = _get_tile_output_shape(input_desc['shape'], multiples)
|
||||
output_shape, _, _, shape_compatible = builder.get_tile_output_shape(input_desc['shape'], multiples)
|
||||
graph_builder = builder.GraphBuilder()
|
||||
|
||||
# generate a graph.
|
||||
|
|
|
@ -18,6 +18,45 @@ import copy
|
|||
from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy
|
||||
|
||||
|
||||
def get_tile_output_shape(shape, multiples):
|
||||
"""compute output shape of tile"""
|
||||
|
||||
if multiples is None:
|
||||
return shape
|
||||
if not isinstance(shape, (list, tuple)):
|
||||
raise TypeError("Input shape of Tile must be of type list or tuple")
|
||||
if not isinstance(multiples, (list, tuple)):
|
||||
raise TypeError("multiples of Tile must be of type list or tuple")
|
||||
|
||||
shape = list(shape)
|
||||
multiples = list(multiples)
|
||||
diff_len = len(multiples) - len(shape)
|
||||
if diff_len < 0:
|
||||
raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
|
||||
if diff_len > 0:
|
||||
for _ in range(diff_len):
|
||||
shape.insert(0, 1)
|
||||
|
||||
shape_compatible = True
|
||||
output_shape = []
|
||||
input_reshape = []
|
||||
output_reshape = []
|
||||
for sh, mul in list(zip(shape, multiples)):
|
||||
dim = sh * mul
|
||||
output_shape.append(dim)
|
||||
if sh == 1 or mul == 1:
|
||||
input_reshape.append(sh)
|
||||
output_reshape.append(dim)
|
||||
else:
|
||||
shape_compatible = False
|
||||
input_reshape.append(1)
|
||||
input_reshape.append(sh)
|
||||
output_reshape.append(mul)
|
||||
output_reshape.append(sh)
|
||||
|
||||
return output_shape, input_reshape, output_reshape, shape_compatible
|
||||
|
||||
|
||||
class OpInfer:
|
||||
"""Op infer"""
|
||||
@staticmethod
|
||||
|
@ -74,6 +113,7 @@ class OpInfer:
|
|||
'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
|
||||
'Reshape': lambda inputs, attrs: attrs["shape"],
|
||||
'BroadcastTo': lambda inputs, attrs: attrs["shape"],
|
||||
'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0],
|
||||
}
|
||||
infer_dtype_func = {
|
||||
# add special infer func here
|
||||
|
|
|
@ -47,7 +47,13 @@ def test_sqrt_grad(shape_x, shape_dout, dtype):
|
|||
expect_np = expect.asnumpy().copy()
|
||||
output_np = output.asnumpy().copy()
|
||||
|
||||
assert np.allclose(expect_np, output_np, 0.0001, 0.0001)
|
||||
rtol = 0.0001
|
||||
atol = 0.0001
|
||||
if dtype == np.float16:
|
||||
rtol = 0.001
|
||||
atol = 0.001
|
||||
|
||||
assert np.allclose(expect_np, output_np, rtol, atol)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
Loading…
Reference in New Issue