forked from mindspore-Ecosystem/mindspore

fix maximum_grad and minimum_grad input_shape not equal to output_shape bug

parent e38dc88d9c · commit 8064de7931
@@ -113,20 +113,23 @@ class ExpanderInfoValidator:
     @staticmethod
     def check_all_formats_same(kls):
         """Check that all formats are the same"""
-        def _check_format(obj):
-            inp_formats = [inp['format'] for inp in obj.inputs]
-            if all([fmt == inp_formats[0] for fmt in inp_formats[1:]]):
-                return
-            raise GKException("[check_all_formats_same] unmatched formats ({}) for op {}".format(
-                ','.join(inp_formats), obj.name))
+        # Ensure no args case can return a class
+        def _check(*args):
+            def _check_format(obj):
+                inp_formats = [inp['format'] for inp in obj.inputs]
+                if all([fmt == inp_formats[0] for fmt in inp_formats[1:]]):
+                    return
+                raise GKException("[check_all_formats_same] unmatched formats ({}) for op {}".format(
+                    ','.join(inp_formats), obj.name))

-        def wrapper(*args, **kargs):
-            if not issubclass(kls, Expander):
-                raise Exception("{} should be subclass of Expander.".format(kls.__name__))
-            ExpanderInfoValidator._add_check_function(kls, _check_format)
-            return kls(*args, **kargs)
+            def wrapper(cls):
+                if not issubclass(cls, Expander):
+                    raise Exception("{} should be subclass of Expander.".format(cls.__name__))
+                ExpanderInfoValidator._add_check_function(cls, _check_format)
+                return cls

-        return wrapper
+            return wrapper
+        return _check()(kls)

     @staticmethod
     def check_attrs(*args):
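For orientation, a minimal usage sketch of the reworked decorator (the class name here is hypothetical; the real call sites are the MaximumGrad/MinimumGrad expanders below). After this change, check_all_formats_same registers the format check and hands back the decorated class itself, rather than wrapping it in a constructor function:

@VLD.check_all_formats_same
class SomeExpander(Expander):          # hypothetical expander subclass
    def _expand(self, graph_builder):
        ...

assert isinstance(SomeExpander, type)  # the decorator now returns the class, not a wrapper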
@@ -144,6 +147,7 @@ class ExpanderInfoValidator:
         return wrapper

+
 def to_frac_z_axis(ori_shape, ori_axis):
     """
     judge the format is fractal NZ

@@ -166,10 +170,10 @@ def to_frac_z_axis(ori_shape, ori_axis):
     for i in range(axis_count):
         axis_index = (frac_z_axis[i] + shape_len) % shape_len
         if axis_index == axis_negative_1:
-            if frac_z_axis[i] > shape_len - 2: # akg:[2,3] [1,4] tbe:[2,4] [1,3]
+            if frac_z_axis[i] > shape_len - 2:  # akg:[2,3] [1,4] tbe:[2,4] [1,3]
                 frac_z_axis[i] = axis_index - 1
                 frac_z_axis.append(axis_index + 2)
-            else: # no case cover this branch now
+            else:  # no case cover this branch now
                 frac_z_axis[i] = axis_index - 1
                 frac_z_axis.append(axis_index + 2)
         elif axis_index == axis_negative_2:

@@ -179,6 +183,7 @@ def to_frac_z_axis(ori_shape, ori_axis):
         frac_z_axis[i] = axis_index
     return frac_z_axis

+
 def infer_shape_from_fractalNz(fractal):
     "get original shape from fractalNz shape"
     shape = []

@@ -192,6 +197,7 @@ def infer_shape_from_fractalNz(fractal):
     shape.append(n)
     return shape

+
 def get_reduced_ori_shape(shape, axis):
     "get shape after reduced which is based on original shape"
     reduced_ori_shape = []
@@ -15,6 +15,7 @@
 """generate json desc for maximum_grad"""
 from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
 from ._utils import Expander, ExpanderInfoValidator as VLD
+from .minimum_grad import MinimumGrad


 @VLD.check_all_formats_same

@@ -28,11 +29,30 @@ class MaximumGrad(Expander):

     def _expand(self, graph_builder):
         input_x, input_y, input_dout = self.inputs

         ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y])
         ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
         dx = graph_builder.emit('Mul', [ge_result, input_dout])
         dy = graph_builder.emit('Sub', [input_dout, dx])

+        reduce_axis_x = MinimumGrad.get_reduce_axis(input_x.shape, dx.shape)
+        reduce_axis_y = MinimumGrad.get_reduce_axis(input_y.shape, dy.shape)
+        if reduce_axis_x:
+            dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False})
+            if dx_reduce.shape != input_x.shape:
+                dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape})
+            else:
+                dx_out = dx_reduce
+        else:
+            dx_out = dx
+
+        if reduce_axis_y:
+            dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False})
+            if dy_reduce.shape != input_y.shape:
+                dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape})
+            else:
+                dy_out = dy_reduce
+        else:
+            dy_out = dy
+
         # output two results, regardless of grad_x and grad_y
-        return dx, dy
+        return dx_out, dy_out
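As a plain-NumPy restatement of what the expander above emits (GreaterEqual → Cast → Mul → Sub, before the new ReduceSum/Reshape step), offered only as an illustrative sketch:

import numpy as np

def maximum_grad_reference(x, y, dout):
    # GreaterEqual + Cast: 1.0 where x >= y, else 0.0, so ties send the gradient to x
    ge = (x >= y).astype(dout.dtype)
    dx = ge * dout      # Mul
    dy = dout - dx      # Sub
    return dx, dy       # still in the broadcast shape; reduction back to the input shapes follows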
@@ -34,5 +34,43 @@ class MinimumGrad(Expander):
         dx = graph_builder.emit('Mul', [le_result, input_dout])
         dy = graph_builder.emit('Sub', [input_dout, dx])

+        # for minimumgrad op, output_shape should be equal to input_shape, but some elementwise operating may broadcast input_shape
+        # then output_shape not equal to original input_shape, so need to reduce output to let them equal
+        reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape)
+        reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape)
+        if reduce_axis_x:
+            dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False})
+            if dx_reduce.shape != input_x.shape:
+                dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape})
+            else:
+                dx_out = dx_reduce
+        else:
+            dx_out = dx
+
+        if reduce_axis_y:
+            dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False})
+            if dy_reduce.shape != input_y.shape:
+                dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape})
+            else:
+                dy_out = dy_reduce
+        else:
+            dy_out = dy
+
         # output two results, regardless of grad_x and grad_y
-        return dx, dy
+        return dx_out, dy_out
+
+    @staticmethod
+    def get_reduce_axis(original_shape, broadcast_shape):
+        """compute reduce axis for final output_shape"""
+        if len(original_shape) > len(broadcast_shape):
+            raise ValueError("original_shape size need to less equal than broadcast_shape")
+
+        tmp_shape = [1] * (len(broadcast_shape) - len(original_shape)) + original_shape
+        reduce_axis = []
+        for idx, _ in enumerate(tmp_shape):
+            if tmp_shape[idx] != broadcast_shape[idx]:
+                if tmp_shape[idx] == 1:
+                    reduce_axis.append(idx)
+                else:
+                    raise ValueError("broadcast dismatch %s vs %s" % (tmp_shape[idx], broadcast_shape[idx]))
+        return reduce_axis
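To see how the new get_reduce_axis drives the ReduceSum/Reshape pair above, here is a standalone sketch (plain Python/NumPy, not part of the diff) using the same broadcast case as the reworked tests, where input_y has shape [1]:

import numpy as np

def get_reduce_axis(original_shape, broadcast_shape):
    # simplified mirror of MinimumGrad.get_reduce_axis (error handling omitted):
    # left-pad with 1s, then every axis where the padded original differs from
    # the broadcast shape must be summed out
    tmp_shape = [1] * (len(broadcast_shape) - len(original_shape)) + list(original_shape)
    return [idx for idx, dim in enumerate(tmp_shape) if dim != broadcast_shape[idx]]

x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
y = np.random.normal(0, 1, [1]).astype(np.float32)
dout = np.minimum(x, y).astype(np.float32)

dx = dout * (x <= y)                    # broadcast result, shape (2, 3) == input_x.shape
dy = dout - dx                          # shape (2, 3), but input_y.shape is (1,)

axes = get_reduce_axis(y.shape, dy.shape)                  # -> [0, 1]
dy_out = np.reshape(dy.sum(axis=tuple(axes)), y.shape)     # ReduceSum + Reshape back to (1,)
print(axes, dy_out.shape)                                  # [0, 1] (1,)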
@@ -16,6 +16,7 @@
 import copy
 import sys
 from functools import reduce
 from .model import GraphKernelUnsupportedException as GKException
 from .model import PrimLib, DataFormat as DF


@@ -125,15 +126,15 @@ class _Elemwise(OpInfer):
         if default_shape[-1] % 16 != 0:
             raise GKException("should be multiplies of 16")
         return shape
-        #(32, 1) -> (1, 2, 16, 1)
+        # (32, 1) -> (1, 2, 16, 1)
         if len(default_shape) == 2 and default_shape[1] == 1:
             shape = [1, default_shape[0] // 16, 16, 1]
             if default_shape[0] % 16 != 0:
                 raise GKException("should be multiples of 16")
             return shape
-        #(32, 48) -> (3, 2, 16, 16)
+        # (32, 48) -> (3, 2, 16, 16)
         shape = [default_shape[1] // 16, default_shape[0] // 16, 16, 16]
-        if default_shape[0] % 16 != 0 or defautl_shape[1] % 16 != 0:
+        if default_shape[0] % 16 != 0 or default_shape[1] % 16 != 0:
             raise GKException("should be multiples of 16")
         return shape

@@ -146,11 +147,11 @@ class _Elemwise(OpInfer):
             return self._broadcast_shape([input.shape for input in self.inputs])

         # in case formats are fractal_nz, default_fromat/NHWC/HCHW(optional)
-        is_default_frac_nz = [input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW, DF.FRAC_NZ) \
-                              for input in self.inputs]
+        is_default_frac_nz = [input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW, DF.FRAC_NZ)
+                              for input in self.inputs]
         if all(is_default_frac_nz):
-            nz_shapes = [self._to_nz(input.shape) if input.data_format != DF.FRAC_NZ else input.shape \
-                         for input in self.inputs]
+            nz_shapes = [self._to_nz(input.shape) if input.data_format != DF.FRAC_NZ else input.shape
+                         for input in self.inputs]
             return self._broadcast_shape(nz_shapes)

         raise GKException("Only support default and fractal_nz")

@@ -213,6 +214,12 @@ class _Reshape(OpInfer):


 class Reshape(_Reshape):
+    def _check_shape(self):
+        size_before_reshape = reduce(lambda x, y: x * y, self.inputs[0].shape)
+        size_after_reshape = reduce(lambda x, y: x * y, self.attrs["shape"])
+        if size_before_reshape != size_after_reshape:
+            raise GKException("The shape product before and after reshaping should be equal")
+
     def _infer_shape(self):
         return self.attrs["shape"]

@@ -305,6 +312,7 @@ def conv_had_pad(pad_list, pad_mode):

 class Conv2D(OpInfer):
+    """Conv2D infer"""

     def _infer_type(self):
         if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
             return self.attrs["dst_type"]

@@ -343,6 +351,7 @@ class Conv2D(OpInfer):

 class MatMul(OpInfer):
+    """MatMul infer"""

     def _infer_type(self):
         if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
             return self.attrs["dst_type"]

@@ -365,6 +374,7 @@ class MatMul(OpInfer):

 class PadAkg(OpInfer):
+    """PadAkg infer"""

     def _infer_shape(self):
         shape = list(self.inputs[0].shape)
         n = len(shape)

@@ -379,6 +389,7 @@ class PadAkg(OpInfer):

 class UnPadAkg(OpInfer):
+    """UnPadAkg infer"""

     def _infer_shape(self):
         shape = list(self.inputs[0].shape)
         n = len(shape)
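For reference, the two 2-D branches of _Elemwise._to_nz visible in the hunk above map a default-format shape to its FRAC_NZ layout as sketched below (only the branches shown in the diff; the method's other branches are omitted):

def to_nz_2d(default_shape):
    m, n = default_shape
    if n == 1:                              # (32, 1) -> (1, 2, 16, 1)
        return [1, m // 16, 16, 1]
    return [n // 16, m // 16, 16, 16]       # (32, 48) -> (3, 2, 16, 16)

print(to_nz_2d([32, 1]))    # [1, 2, 16, 1]
print(to_nz_2d([32, 48]))   # [3, 2, 16, 16]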
@@ -30,24 +30,38 @@ class MaxmumGradNet(Cell):
         return self.maximum_grad(x, y, dy)


-def test_maximum_grad():
+def gen_data():
     np.random.seed(0)
-    input_x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
-    input_y = np.random.normal(0, 1, [2, 3]).astype(np.float32)
-    input_dout = np.maximum(input_x, input_y).astype(np.float32)
+    input_x_np = np.random.normal(0, 1, [2, 3]).astype(np.float32)
+    input_y_np = np.random.normal(0, 1, [1]).astype(np.float32)
+    input_dout_np = np.maximum(input_x_np, input_y_np).astype(np.float32)
+    input_x = Tensor(input_x_np)
+    input_y = Tensor(input_y_np)
+    input_dout = Tensor(input_dout_np)
+    return input_x, input_y, input_dout
+
+
+def get_maximum_grad_output(input_x, input_y, input_dout, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
     net = MaxmumGradNet()
-    result = net(Tensor(input_x), Tensor(input_y), Tensor(input_dout))
-    dx = input_dout * (input_x >= input_y)
-    dy = input_dout - dx
-    assert np.allclose(result[0].asnumpy(), dx, rtol=1.e-4, atol=1.e-8, equal_nan=True)
-    assert np.allclose(result[1].asnumpy(), dy, rtol=1.e-4, atol=1.e-8, equal_nan=True)
+    result = net(input_x, input_y, input_dout)
+    return result[0].asnumpy(), result[1].asnumpy()
+
+
+def test_maximum_grad():
+    input_x, input_y, input_dout = gen_data()
+    result_off = get_maximum_grad_output(input_x, input_y, input_dout, False)
+    result_on = get_maximum_grad_output(input_x, input_y, input_dout, True)
+    assert np.allclose(result_on[0], result_off[0], rtol=1.e-4, atol=1.e-8, equal_nan=True)
+    assert np.allclose(result_on[1], result_off[1], rtol=1.e-4, atol=1.e-8, equal_nan=True)


 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_maximum_grad_gpu():
-    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
     test_maximum_grad()

@@ -56,5 +70,5 @@ def test_maximum_grad_gpu():
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard
 def test_maximum_grad_ascend():
-    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     test_maximum_grad()
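For context on the reworked test data: with input_y of shape [1] the forward op broadcasts against input_x of shape [2, 3], so the backward outputs the test compares must come back in the original input shapes. A hedged NumPy sanity check of the expected shapes (not part of the test file, which only compares graph-kernel on vs. off):

import numpy as np

np.random.seed(0)
x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
y = np.random.normal(0, 1, [1]).astype(np.float32)
dout = np.maximum(x, y).astype(np.float32)

dx = dout * (x >= y)                          # shape (2, 3), matches input_x
dy = np.reshape((dout - dx).sum(), y.shape)   # reduced back to shape (1,), matches input_y
print(dx.shape, dy.shape)                     # (2, 3) (1,)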
@@ -30,24 +30,37 @@ class MinmumGradNet(Cell):
         return self.minimum_grad(x, y, dy)


-def test_minimum_grad():
+def gen_data():
     np.random.seed(0)
-    input_x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
-    input_y = np.random.normal(0, 1, [2, 3]).astype(np.float32)
-    input_dout = np.minimum(input_x, input_y).astype(np.float32)
+    input_x_np = np.random.normal(0, 1, [2, 3]).astype(np.float32)
+    input_y_np = np.random.normal(0, 1, [1]).astype(np.float32)
+    input_dout_np = np.minimum(input_x_np, input_y_np).astype(np.float32)
+    input_x = Tensor(input_x_np)
+    input_y = Tensor(input_y_np)
+    input_dout = Tensor(input_dout_np)
+    return input_x, input_y, input_dout
+
+
+def get_minimum_grad_output(input_x, input_y, input_dout, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
     net = MinmumGradNet()
-    result = net(Tensor(input_x), Tensor(input_y), Tensor(input_dout))
-    dx = input_dout * (input_x <= input_y)
-    dy = input_dout - dx
-    assert np.allclose(result[0].asnumpy(), dx, rtol=1.e-4, atol=1.e-8, equal_nan=True)
-    assert np.allclose(result[1].asnumpy(), dy, rtol=1.e-4, atol=1.e-8, equal_nan=True)
+    result = net(input_x, input_y, input_dout)
+    return result[0].asnumpy(), result[1].asnumpy()
+
+
+def test_minimum_grad():
+    input_x, input_y, input_dout = gen_data()
+    result_off = get_minimum_grad_output(input_x, input_y, input_dout, False)
+    result_on = get_minimum_grad_output(input_x, input_y, input_dout, True)
+    assert np.allclose(result_on[0], result_off[0], rtol=1.e-4, atol=1.e-8, equal_nan=True)
+    assert np.allclose(result_on[1], result_off[1], rtol=1.e-4, atol=1.e-8, equal_nan=True)


 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_basic_gpu():
-    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
+def test_minimum_grad_gpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
     test_minimum_grad()

@@ -55,6 +68,6 @@ def test_basic_gpu():
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard
-def test_basic_ascend():
-    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
+def test_minimum_grad_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     test_minimum_grad()