!47618 remove grad constexpr

Merge pull request !47618 from hujiahui8/constexpr_1
i-robot 2023-01-13 03:03:39 +00:00 committed by Gitee
commit 3b73258f5a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 234 additions and 37 deletions

View File

@@ -400,12 +400,6 @@ def get_bprop_concat(self):
     return bprop


-@constexpr
-def _slice_grad_pad(begins, sizes, shapes):
-    pads = tuple((begin, shape - begin - size) for begin, size, shape in zip(begins, sizes, shapes))
-    return pads
-
-
 @bprop_getters.register(P.Slice)
 def get_bprop_slice(self):
     """Generate bprop for Slice"""

View File

@@ -15,7 +15,6 @@
 """Define the grad rules of math related operations."""
-from functools import reduce
 import numpy as np
 import mindspore as ms
 from mindspore import nn
@@ -864,9 +863,19 @@ def _split_shape_index(input_shape, axis):
     if isinstance(axis, int):
         axis = tuple([axis])
     reduction_indices = tuple([(i + rank) % rank for i in axis])
-    other_indices = tuple(set(range(rank)) - set(reduction_indices))
-    reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
-    other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
+    other_indices_list = []
+    for i in range(rank):
+        if i not in reduction_indices and i not in other_indices_list:
+            other_indices_list.append(i)
+    other_indices = tuple(other_indices_list)
+    reduced_list = [1] + [input_shape[i] for i in reduction_indices]
+    other_list = [1] + [input_shape[i] for i in other_indices]
+    reduced_num = 1
+    for i in reduced_list:
+        reduced_num = reduced_num * i
+    other_num = 1
+    for i in other_list:
+        other_num = other_num * i
     perm = reduction_indices + other_indices
     return tuple([reduced_num, other_num]), perm
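Aside: the pure-Python loops above reproduce what the removed functools.reduce expressions computed. A minimal standalone sketch of the same logic, using the sample values from the test_split_shape_index case added later in this PR (illustrative only, not part of the change):

# Illustrative sketch of the rewritten _split_shape_index logic.
def split_shape_index(input_shape, axis):
    rank = len(input_shape)
    if isinstance(axis, int):
        axis = (axis,)
    reduction_indices = tuple((i + rank) % rank for i in axis)
    other_indices = tuple(i for i in range(rank) if i not in reduction_indices)
    reduced_num = 1
    for i in reduction_indices:
        reduced_num *= input_shape[i]
    other_num = 1
    for i in other_indices:
        other_num *= input_shape[i]
    return (reduced_num, other_num), reduction_indices + other_indices

# Same case as the new test: axis 3 of (2, 3, 4, 4) is split off, the rest are flattened.
assert split_shape_index((2, 3, 4, 4), 3) == ((4, 24), (3, 0, 1, 2))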

View File

@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
"""Define the grad rules of neural network related operations.""" """Define the grad rules of neural network related operations."""
import numpy as np
from mindspore import context from mindspore import context
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
@@ -33,21 +32,23 @@ from mindspore.ops._utils.utils import range_op, get_1d_shape
 @constexpr
 def bias_add_gradgrad_helper(shape, bias_shape, data_format):
     """Helper function of BiasGradGrad to calculate expanded shape."""
-    shape = np.array(shape).astype(np.int)
-    bias_shape = np.array(bias_shape).astype(np.int)
+    new_shape = list(shape)
+    new_bias_shape = list(bias_shape)
+    ones_1 = []
+    ones_2 = []
+    for _ in new_shape[2:]:
+        ones_1.append(1)
+    for _ in new_shape[:-1]:
+        ones_2.append(1)
     if data_format == "NCHW":
-        expanded_shape = np.concatenate([
-            np.ones_like(shape[:1]),
-            bias_shape,
-            np.ones_like(shape[2:])
-        ], axis=0)
-        tile_mults = np.concatenate([shape[:1], [1], shape[2:]], axis=0)
+        expanded_shape = [1] + new_bias_shape + ones_1
+        tile_mults = [new_shape[0]] + [1] + new_shape[2:]
     else:
-        expanded_shape = np.concatenate([
-            np.ones_like(shape[:-1]),
-            bias_shape
-        ], axis=0)
-        tile_mults = np.concatenate([shape[:-1], [1]], axis=0)
+        expanded_shape = ones_2 + new_bias_shape
+        tile_mults = new_shape[:-1] + [1]
     return tuple(expanded_shape), tuple(tile_mults)
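Aside: for the NCHW branch, the rewritten helper builds the expanded bias shape and the tile multiples with plain list arithmetic instead of np.concatenate. A minimal sketch of that branch, checked against the values from test_bias_add_gradgrad_helper added later in this PR (illustrative only, not part of the change):

# Illustrative sketch of the NCHW branch of bias_add_gradgrad_helper.
def bias_add_gradgrad_shapes_nchw(shape, bias_shape):
    shape = list(shape)
    bias_shape = list(bias_shape)
    # bias is broadcast over batch and spatial dims, data is tiled over them
    expanded_shape = [1] + bias_shape + [1] * len(shape[2:])
    tile_mults = [shape[0]] + [1] + shape[2:]
    return tuple(expanded_shape), tuple(tile_mults)

assert bias_add_gradgrad_shapes_nchw((2, 3, 4, 4), (1, 3)) == ((1, 1, 3, 1, 1), (2, 1, 4, 4))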
@@ -694,7 +695,7 @@ def get_bprop_softmax(self):
     is_ascend = (device_target == "Ascend")

     def bprop(x, out, dout):
-        # dx = (dout - sum(dout * out)) * out
+        # dx can be expressed as (dout - sum(dout * out)) * out
         # This formula is correct only when the `axis` is the last dimension.
         # In order to support the scenario where the `axis` is other values,
         # we transpose the data of the `axis` dimension to the last dimension for calculation,
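Aside: the comment above relies on the softmax gradient identity dx = (dout - sum(dout * out)) * out along the softmax axis. A small NumPy check of that identity against the explicit softmax Jacobian, assuming the axis is the last dimension (illustrative only, not part of the change):

import numpy as np

x = np.random.randn(5)
dout = np.random.randn(5)
out = np.exp(x) / np.exp(x).sum()   # softmax(x)

# Explicit Jacobian of softmax: J[i, j] = out[i] * (delta_ij - out[j])
jac = np.diag(out) - np.outer(out, out)
dx_jacobian = jac.T @ dout          # vector-Jacobian product

# Formula from the bprop comment above
dx_formula = (dout - np.sum(dout * out)) * out

assert np.allclose(dx_jacobian, dx_formula)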
@@ -952,7 +953,7 @@ def get_bprop_top_kv2(self):
            ind_2d = reshape_op(indices, (-1, ind_lastdim))
            outerdim = shape_op(ind_2d)[0]
-           # [0, outterdim, 2*outerdim, ..., (k-1)*outerdim]
+           # range_flatten_index can be expressed as: [0, outterdim, 2*outerdim, ..., (k-1)*outerdim]
            indices_dtype = dtype(indices)
            range_flatten_index = range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)
@ -979,7 +980,7 @@ def get_bprop_top_kv2(self):
ind_2d = reshape_op(indices, create_tensor_by_element((-1, ind_lastdim))) ind_2d = reshape_op(indices, create_tensor_by_element((-1, ind_lastdim)))
outerdim = dyn_shape(ind_2d)[0] outerdim = dyn_shape(ind_2d)[0]
# [0, outterdim, 2*outerdim, ..., (k-1)*outerdim] # range_flatten_index can be expressed as: [0, outterdim, 2*outerdim, ..., (k-1)*outerdim]
range_flatten_index = P.Range()(cast(0, mstype.int64), outerdim * in_lastdim, in_lastdim) range_flatten_index = P.Range()(cast(0, mstype.int64), outerdim * in_lastdim, in_lastdim)
# expand_dims to (k, 1), then broadcast # expand_dims to (k, 1), then broadcast

View File

@@ -16,7 +16,6 @@
 """Define the grad rules of linalg related operations."""
 from __future__ import absolute_import

-import numpy as np
 import mindspore
 from mindspore.ops import Tensor
@@ -83,11 +82,6 @@ def _make_tensor(value, dtype):
     return Tensor(value, dtype)


-@constexpr
-def _make_zero_matrix(shape, dtype):
-    return Tensor(np.zeros(shape), dtype)
-
-
 def _matrix_diag(diagonal):
     """Do matrix diagnoal"""
     diagonal_shape = _shape(diagonal)

View File

@@ -1856,7 +1856,10 @@ def get_bprop_inplace_index_add(self):
 @constexpr
 def _fft_rank_offset(norm_shape, rank):
     """generate offset for fft with rank"""
-    return np.product(norm_shape[-rank:])
+    norm_shape_product = 1
+    for i in norm_shape[-rank:]:
+        norm_shape_product *= i
+    return norm_shape_product


 @constexpr
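Aside: the rewritten offset is just a running product over the trailing `rank` dimensions. A short standalone sketch with the sample values from test_fft_rank_offset added later in this PR (illustrative only, not part of the change):

# Illustrative sketch of the rewritten _fft_rank_offset logic.
def fft_rank_offset(norm_shape, rank):
    product = 1
    for dim in norm_shape[-rank:]:
        product *= dim
    return product

assert fft_rank_offset((1, 2, 3, 4, 5, 10), 3) == 4 * 5 * 10  # == 200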
@@ -1908,19 +1911,28 @@ def _get_last_dim_slice_shape(tensor_shape, index):
 @constexpr
 def _rfft_reshape(shape_a, shape_b):
     """generate rfft shape for reshape"""
-    return tuple([int(x) for x in np.concatenate((np.ones(len(shape_a) - 2), shape_b), 0)])
+    new_shape = list(shape_b)
+    for i in range(len(shape_a) - 2):
+        new_shape.insert(i, 1)
+    return tuple(new_shape)


 @constexpr
 def _rfft_tile_reshape(shape_a):
     """generate rfft shape for tile"""
-    return tuple(int(x) for x in np.concatenate((shape_a[:-2], [1, 1]), 0))
+    reshape_a = list(shape_a)
+    reshape_a[-2] = 1
+    reshape_a[-1] = 1
+    return tuple(reshape_a)


 @constexpr
 def _rfft_last_term_shape(shape_a, shape_b):
     """generate rfft shape for last term"""
-    return tuple(int(x) for x in np.concatenate((np.ones(len(shape_a)-1), shape_b), 0))
+    new_shape = list(shape_b)
+    for i in range(len(shape_a) - 1):
+        new_shape.insert(i, 1)
+    return tuple(new_shape)


 @constexpr
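Aside: the three rewritten rfft shape helpers pad or overwrite shapes with ones using plain list operations. Minimal standalone sketches, checked against the expected values from the new tests added later in this PR (illustrative only, not part of the change):

# Illustrative sketches of the rewritten rfft shape helpers.
def rfft_reshape(shape_a, shape_b):
    # pad shape_b with len(shape_a) - 2 leading ones
    return tuple([1] * (len(shape_a) - 2) + list(shape_b))

def rfft_tile_reshape(shape_a):
    # keep the leading dims, replace the last two with 1
    return tuple(list(shape_a[:-2]) + [1, 1])

def rfft_last_term_shape(shape_a, shape_b):
    # pad shape_b with len(shape_a) - 1 leading ones
    return tuple([1] * (len(shape_a) - 1) + list(shape_b))

assert rfft_reshape((1, 2, 3, 4, 5), (2, 5, 7, 8)) == (1, 1, 1, 2, 5, 7, 8)
assert rfft_tile_reshape((1, 2, 3, 4, 5)) == (1, 2, 3, 1, 1)
assert rfft_last_term_shape((1, 2, 3, 4, 5), (2, 5, 7, 8)) == (1, 1, 1, 1, 2, 5, 7, 8)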

View File

@@ -875,3 +875,190 @@ def test_constantpad1d():
     pad1d = ConstantPad1d(padding, value)
     output = pad1d(x)
     assert np.allclose(output.asnumpy(), expect)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_generate_inverse_index():
+    """
+    Feature: _generate_inverse_index func
+    Description: Verify the result of _generate_inverse_index
+    Expectation: success
+    """
+    from mindspore.ops._grad.grad_array_ops import _generate_inverse_index
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.func = _generate_inverse_index
+
+        def construct(self, x_shape, axis):
+            return self.func(x_shape, axis)
+
+    x_shape = (3, 2, 3)
+    axis = 2
+    net = Net()
+    assert net(x_shape, axis) == (1, 2, 0)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_split_shape_index():
+    """
+    Feature: _split_shape_index func
+    Description: Verify the result of _split_shape_index
+    Expectation: success
+    """
+    from mindspore.ops._grad.grad_math_ops import _split_shape_index
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.func = _split_shape_index
+
+        def construct(self, input_shape, axis):
+            return self.func(input_shape, axis)
+
+    input_shape = (2, 3, 4, 4)
+    axis = 3
+    net = Net()
+    out1, out2 = net(input_shape, axis)
+    assert out1 == (4, 24)
+    assert out2 == (3, 0, 1, 2)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_bias_add_gradgrad_helper():
+    """
+    Feature: bias_add_gradgrad_helper func
+    Description: Verify the result of bias_add_gradgrad_helper
+    Expectation: success
+    """
+    from mindspore.ops._grad.grad_nn_ops import bias_add_gradgrad_helper
+
+    class Net(nn.Cell):
+        def __init__(self, data_format):
+            super(Net, self).__init__()
+            self.func = bias_add_gradgrad_helper
+            self.data_format = data_format
+
+        def construct(self, shape, bias_shape):
+            return self.func(shape, bias_shape, self.data_format)
+
+    shape = (2, 3, 4, 4)
+    bias_shape = (1, 3)
+    data_format = "NCHW"
+    net = Net(data_format)
+    out1, out2 = net(shape, bias_shape)
+    assert out1 == (1, 1, 3, 1, 1)
+    assert out2 == (2, 1, 4, 4)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_fft_rank_offset():
+    """
+    Feature: _fft_rank_offset func
+    Description: Verify the result of _fft_rank_offset
+    Expectation: success
+    """
+    from mindspore.ops._grad_experimental.grad_math_ops import _fft_rank_offset
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.func = _fft_rank_offset
+
+        def construct(self, norm_shape, rank):
+            return self.func(norm_shape, rank)
+
+    norm_shape = (1, 2, 3, 4, 5, 10)
+    rank = 3
+    net = Net()
+    assert net(norm_shape, rank) == 200
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_rfft_reshape():
+    """
+    Feature: _rfft_reshape func
+    Description: Verify the result of _rfft_reshape
+    Expectation: success
+    """
+    from mindspore.ops._grad_experimental.grad_math_ops import _rfft_reshape
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.func = _rfft_reshape
+
+        def construct(self, shape_a, shape_b):
+            return self.func(shape_a, shape_b)
+
+    shape_a = (1, 2, 3, 4, 5)
+    shape_b = (2, 5, 7, 8)
+    net = Net()
+    assert net(shape_a, shape_b) == (1, 1, 1, 2, 5, 7, 8)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_rfft_tile_reshape():
+    """
+    Feature: _rfft_tile_reshape func
+    Description: Verify the result of _rfft_tile_reshape
+    Expectation: success
+    """
+    from mindspore.ops._grad_experimental.grad_math_ops import _rfft_tile_reshape
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.func = _rfft_tile_reshape
+
+        def construct(self, shape_a):
+            return self.func(shape_a)
+
+    shape_a = (1, 2, 3, 4, 5)
+    net = Net()
+    assert net(shape_a) == (1, 2, 3, 1, 1)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_rfft_last_term_shape():
+    """
+    Feature: _rfft_last_term_shape func
+    Description: Verify the result of _rfft_last_term_shape
+    Expectation: success
+    """
+    from mindspore.ops._grad_experimental.grad_math_ops import _rfft_last_term_shape
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.func = _rfft_last_term_shape
+
+        def construct(self, shape_a, shape_b):
+            return self.func(shape_a, shape_b)
+
+    shape_a = (1, 2, 3, 4, 5)
+    shape_b = (2, 5, 7, 8)
+    net = Net()
+    assert net(shape_a, shape_b) == (1, 1, 1, 1, 2, 5, 7, 8)