fix maximum_grad and minimum_grad bug where output_shape is not equal to input_shape

This commit is contained in:
zengzitao 2021-06-16 10:54:46 +08:00
parent e38dc88d9c
commit 8064de7931
6 changed files with 150 additions and 48 deletions
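In short: when the two inputs broadcast against each other (for example x of shape [2, 3] against y of shape [1]), dx and dy come out in the broadcast shape and must be reduce-summed back to the matching input shape. A minimal numpy sketch of the idea, using the shapes from the updated tests below (numpy stands in for the graph builder here):

import numpy as np

x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
y = np.random.normal(0, 1, [1]).astype(np.float32)
dout = np.maximum(x, y).astype(np.float32)   # broadcast shape [2, 3]

dx = dout * (x >= y)                         # [2, 3], already x.shape
dy = dout - dx                               # [2, 3], but y.shape is [1]
# The fix: reduce the broadcast axes, then reshape back to the input shape.
dy = dy.sum(axis=(0, 1)).reshape(y.shape)    # back to [1]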

View File

@@ -113,20 +113,23 @@ class ExpanderInfoValidator:
@staticmethod
def check_all_formats_same(kls):
"""Check that all formats are the same"""
def _check_format(obj):
inp_formats = [inp['format'] for inp in obj.inputs]
if all([fmt == inp_formats[0] for fmt in inp_formats[1:]]):
return
raise GKException("[check_all_formats_same] unmatched formats ({}) for op {}".format(
','.join(inp_formats), obj.name))
# Ensure the no-argument case can still return a class
def _check(*args):
def _check_format(obj):
inp_formats = [inp['format'] for inp in obj.inputs]
if all([fmt == inp_formats[0] for fmt in inp_formats[1:]]):
return
raise GKException("[check_all_formats_same] unmatched formats ({}) for op {}".format(
','.join(inp_formats), obj.name))
def wrapper(*args, **kargs):
if not issubclass(kls, Expander):
raise Exception("{} should be subclass of Expander.".format(kls.__name__))
ExpanderInfoValidator._add_check_function(kls, _check_format)
return kls(*args, **kargs)
def wrapper(cls):
if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
ExpanderInfoValidator._add_check_function(cls, _check_format)
return cls
return wrapper
return wrapper
return _check()(kls)
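A hedged usage sketch of the rewritten decorator (MyExpander is hypothetical; VLD is the alias imported in maximum_grad.py below, where the real application appears). Applied without parentheses, check_all_formats_same receives the class itself, registers _check_format on it, and returns the class unchanged:

@VLD.check_all_formats_same
class MyExpander(Expander):  # hypothetical subclass
    def _expand(self, graph_builder):
        ...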
@staticmethod
def check_attrs(*args):
@@ -144,6 +147,7 @@ class ExpanderInfoValidator:
return wrapper
def to_frac_z_axis(ori_shape, ori_axis):
"""
convert the reduce axis on the original shape to the corresponding axis on the fractal_Nz shape
@@ -166,10 +170,10 @@ def to_frac_z_axis(ori_shape, ori_axis):
for i in range(axis_count):
axis_index = (frac_z_axis[i] + shape_len) % shape_len
if axis_index == axis_negative_1:
if frac_z_axis[i] > shape_len - 2: # akg:[2,3] [1,4] tbe:[2,4] [1,3]
if frac_z_axis[i] > shape_len - 2: # akg:[2,3] [1,4] tbe:[2,4] [1,3]
frac_z_axis[i] = axis_index - 1
frac_z_axis.append(axis_index + 2)
else: # no case covers this branch yet
else: # no case covers this branch yet
frac_z_axis[i] = axis_index - 1
frac_z_axis.append(axis_index + 2)
elif axis_index == axis_negative_2:
@@ -179,6 +183,7 @@ def to_frac_z_axis(ori_shape, ori_axis):
frac_z_axis[i] = axis_index
return frac_z_axis
def infer_shape_from_fractalNz(fractal):
"get original shape from fractalNz shape"
shape = []
@@ -192,6 +197,7 @@ def infer_shape_from_fractalNz(fractal):
shape.append(n)
return shape
def get_reduced_ori_shape(shape, axis):
"get shape after reduced which is based on original shape"
reduced_ori_shape = []

View File

@@ -15,6 +15,7 @@
"""generate json desc for maximum_grad"""
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD
from .minimum_grad import MinimumGrad
@VLD.check_all_formats_same
@@ -28,11 +29,30 @@ class MaximumGrad(Expander):
def _expand(self, graph_builder):
input_x, input_y, input_dout = self.inputs
ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y])
ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
dx = graph_builder.emit('Mul', [ge_result, input_dout])
dy = graph_builder.emit('Sub', [input_dout, dx])
reduce_axis_x = MinimumGrad.get_reduce_axis(input_x.shape, dx.shape)
reduce_axis_y = MinimumGrad.get_reduce_axis(input_y.shape, dy.shape)
if reduce_axis_x:
dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False})
if dx_reduce.shape != input_x.shape:
dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape})
else:
dx_out = dx_reduce
else:
dx_out = dx
if reduce_axis_y:
dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False})
if dy_reduce.shape != input_y.shape:
dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape})
else:
dy_out = dy_reduce
else:
dy_out = dy
# always output both results, regardless of which of grad_x and grad_y is needed
return dx, dy
return dx_out, dy_out

View File

@@ -34,5 +34,43 @@ class MinimumGrad(Expander):
dx = graph_builder.emit('Mul', [le_result, input_dout])
dy = graph_builder.emit('Sub', [input_dout, dx])
# For the minimum_grad op, output_shape should equal input_shape, but elementwise ops may broadcast the inputs,
# so output_shape can differ from the original input_shape; the output is then reduced to make them equal.
reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape)
reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape)
if reduce_axis_x:
dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False})
if dx_reduce.shape != input_x.shape:
dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape})
else:
dx_out = dx_reduce
else:
dx_out = dx
if reduce_axis_y:
dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False})
if dy_reduce.shape != input_y.shape:
dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape})
else:
dy_out = dy_reduce
else:
dy_out = dy
# always output both results, regardless of which of grad_x and grad_y is needed
return dx, dy
return dx_out, dy_out
@staticmethod
def get_reduce_axis(original_shape, broadcast_shape):
"""compute reduce axis for final output_shape"""
if len(original_shape) > len(broadcast_shape):
raise ValueError("original_shape size need to less equal than broadcast_shape")
tmp_shape = [1] * (len(broadcast_shape) - len(original_shape)) + original_shape
reduce_axis = []
for idx, _ in enumerate(tmp_shape):
if tmp_shape[idx] != broadcast_shape[idx]:
if tmp_shape[idx] == 1:
reduce_axis.append(idx)
else:
raise ValueError("broadcast dismatch %s vs %s" % (tmp_shape[idx], broadcast_shape[idx]))
return reduce_axis
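A hedged sanity check of get_reduce_axis under the shapes exercised by the updated tests (plain Python, outside the graph builder):

# y of shape [1] broadcast against dy of shape [2, 3]: reduce both axes.
assert MinimumGrad.get_reduce_axis([1], [2, 3]) == [0, 1]
# x of shape [2, 3] already matches dx: nothing to reduce.
assert MinimumGrad.get_reduce_axis([2, 3], [2, 3]) == []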

View File

@@ -16,6 +16,7 @@
import copy
import sys
from functools import reduce
from .model import GraphKernelUnsupportedException as GKException
from .model import PrimLib, DataFormat as DF
@@ -125,15 +126,15 @@ class _Elemwise(OpInfer):
if default_shape[-1] % 16 != 0:
raise GKException("should be multiplies of 16")
return shape
#(32, 1) -> (1, 2, 16, 1)
# (32, 1) -> (1, 2, 16, 1)
if len(default_shape) == 2 and default_shape[1] == 1:
shape = [1, default_shape[0] // 16, 16, 1]
if default_shape[0] % 16 != 0:
raise GKException("should be multiples of 16")
return shape
#(32, 48) -> (3, 2, 16, 16)
# (32, 48) -> (3, 2, 16, 16)
shape = [default_shape[1] // 16, default_shape[0] // 16, 16, 16]
if default_shape[0] % 16 != 0 or defautl_shape[1] % 16 != 0:
if default_shape[0] % 16 != 0 or default_shape[1] % 16 != 0:
raise GKException("should be multiples of 16")
return shape
@@ -146,11 +147,11 @@ class _Elemwise(OpInfer):
return self._broadcast_shape([input.shape for input in self.inputs])
# in case formats are fractal_nz, default_format/NHWC/NCHW (optional)
is_default_frac_nz = [input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW, DF.FRAC_NZ) \
for input in self.inputs]
is_default_frac_nz = [input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW, DF.FRAC_NZ)
for input in self.inputs]
if all(is_default_frac_nz):
nz_shapes = [self._to_nz(input.shape) if input.data_format != DF.FRAC_NZ else input.shape \
for input in self.inputs]
nz_shapes = [self._to_nz(input.shape) if input.data_format != DF.FRAC_NZ else input.shape
for input in self.inputs]
return self._broadcast_shape(nz_shapes)
raise GKException("Only support default and fractal_nz")
@@ -213,6 +214,12 @@ class _Reshape(OpInfer):
class Reshape(_Reshape):
def _check_shape(self):
size_before_reshape = reduce(lambda x, y: x * y, self.inputs[0].shape)
size_after_reshape = reduce(lambda x, y: x * y, self.attrs["shape"])
if size_before_reshape != size_after_reshape:
raise GKException("The shape product before and after reshaping should be equal")
def _infer_shape(self):
return self.attrs["shape"]
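A hedged worked example of the product check above (hypothetical shapes, not from the diff): reshaping [2, 3, 4] to [6, 4] passes because both multiply to 24, while [2, 3, 4] to [5, 5] would raise GKException.

from functools import reduce
# 2 * 3 * 4 == 24 and 6 * 4 == 24, so this reshape is legal.
assert reduce(lambda a, b: a * b, [2, 3, 4]) == reduce(lambda a, b: a * b, [6, 4])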
@@ -305,6 +312,7 @@ def conv_had_pad(pad_list, pad_mode):
class Conv2D(OpInfer):
"""Conv2D infer"""
def _infer_type(self):
if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
return self.attrs["dst_type"]
@@ -343,6 +351,7 @@ class Conv2D(OpInfer):
class MatMul(OpInfer):
"""MatMul infer"""
def _infer_type(self):
if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
return self.attrs["dst_type"]
@@ -365,6 +374,7 @@ class MatMul(OpInfer):
class PadAkg(OpInfer):
"""PadAkg infer"""
def _infer_shape(self):
shape = list(self.inputs[0].shape)
n = len(shape)
@@ -379,6 +389,7 @@ class PadAkg(OpInfer):
class UnPadAkg(OpInfer):
"""UnPadAkg infer"""
def _infer_shape(self):
shape = list(self.inputs[0].shape)
n = len(shape)

View File

@@ -30,24 +30,38 @@ class MaxmumGradNet(Cell):
return self.maximum_grad(x, y, dy)
def test_maximum_grad():
def gen_data():
np.random.seed(0)
input_x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
input_y = np.random.normal(0, 1, [2, 3]).astype(np.float32)
input_dout = np.maximum(input_x, input_y).astype(np.float32)
input_x_np = np.random.normal(0, 1, [2, 3]).astype(np.float32)
input_y_np = np.random.normal(0, 1, [1]).astype(np.float32)
input_dout_np = np.maximum(input_x_np, input_y_np).astype(np.float32)
input_x = Tensor(input_x_np)
input_y = Tensor(input_y_np)
input_dout = Tensor(input_dout_np)
return input_x, input_y, input_dout
def get_maximum_grad_output(input_x, input_y, input_dout, enable_graph_kernel=False):
context.set_context(enable_graph_kernel=enable_graph_kernel)
net = MaxmumGradNet()
result = net(Tensor(input_x), Tensor(input_y), Tensor(input_dout))
dx = input_dout * (input_x >= input_y)
dy = input_dout - dx
assert np.allclose(result[0].asnumpy(), dx, rtol=1.e-4, atol=1.e-8, equal_nan=True)
assert np.allclose(result[1].asnumpy(), dy, rtol=1.e-4, atol=1.e-8, equal_nan=True)
result = net(input_x, input_y, input_dout)
return result[0].asnumpy(), result[1].asnumpy()
def test_maximum_grad():
input_x, input_y, input_dout = gen_data()
result_off = get_maximum_grad_output(input_x, input_y, input_dout, False)
result_on = get_maximum_grad_output(input_x, input_y, input_dout, True)
assert np.allclose(result_on[0], result_off[0], rtol=1.e-4, atol=1.e-8, equal_nan=True)
assert np.allclose(result_on[1], result_off[1], rtol=1.e-4, atol=1.e-8, equal_nan=True)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_maximum_grad_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_maximum_grad()
@@ -56,5 +70,5 @@ def test_maximum_grad_gpu():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_maximum_grad_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_maximum_grad()

View File

@@ -30,24 +30,37 @@ class MinmumGradNet(Cell):
return self.minimum_grad(x, y, dy)
def test_minimum_grad():
def gen_data():
np.random.seed(0)
input_x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
input_y = np.random.normal(0, 1, [2, 3]).astype(np.float32)
input_dout = np.minimum(input_x, input_y).astype(np.float32)
input_x_np = np.random.normal(0, 1, [2, 3]).astype(np.float32)
input_y_np = np.random.normal(0, 1, [1]).astype(np.float32)
input_dout_np = np.minimum(input_x_np, input_y_np).astype(np.float32)
input_x = Tensor(input_x_np)
input_y = Tensor(input_y_np)
input_dout = Tensor(input_dout_np)
return input_x, input_y, input_dout
def get_minimum_grad_output(input_x, input_y, input_dout, enable_graph_kernel=False):
context.set_context(enable_graph_kernel=enable_graph_kernel)
net = MinmumGradNet()
result = net(Tensor(input_x), Tensor(input_y), Tensor(input_dout))
dx = input_dout * (input_x <= input_y)
dy = input_dout - dx
assert np.allclose(result[0].asnumpy(), dx, rtol=1.e-4, atol=1.e-8, equal_nan=True)
assert np.allclose(result[1].asnumpy(), dy, rtol=1.e-4, atol=1.e-8, equal_nan=True)
result = net(input_x, input_y, input_dout)
return result[0].asnumpy(), result[1].asnumpy()
def test_minimum_grad():
input_x, input_y, input_dout = gen_data()
result_off = get_minimum_grad_output(input_x, input_y, input_dout, False)
result_on = get_minimum_grad_output(input_x, input_y, input_dout, True)
assert np.allclose(result_on[0], result_off[0], rtol=1.e-4, atol=1.e-8, equal_nan=True)
assert np.allclose(result_on[1], result_off[1], rtol=1.e-4, atol=1.e-8, equal_nan=True)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
def test_minimum_grad_gpu():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_minimum_grad()
@@ -55,6 +68,6 @@ def test_basic_gpu():
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_basic_ascend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
def test_minimum_grad_ascend():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_minimum_grad()