!47618 remove grad constexpr

Merge pull request !47618 from hujiahui8/constexpr_1
i-robot 2023-01-13 03:03:39 +00:00 committed by Gitee
commit 3b73258f5a
7 changed files with 234 additions and 37 deletions


@@ -400,12 +400,6 @@ def get_bprop_concat(self):
return bprop
@constexpr
def _slice_grad_pad(begins, sizes, shapes):
pads = tuple((begin, shape - begin - size) for begin, size, shape in zip(begins, sizes, shapes))
return pads
@bprop_getters.register(P.Slice)
def get_bprop_slice(self):
"""Generate bprop for Slice"""


@@ -15,7 +15,6 @@
"""Define the grad rules of math related operations."""
from functools import reduce
import numpy as np
import mindspore as ms
from mindspore import nn
@@ -864,9 +863,19 @@ def _split_shape_index(input_shape, axis):
if isinstance(axis, int):
axis = tuple([axis])
reduction_indices = tuple([(i + rank) % rank for i in axis])
other_indices = tuple(set(range(rank)) - set(reduction_indices))
reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
other_indices_list = []
for i in range(rank):
if i not in reduction_indices and i not in other_indices_list:
other_indices_list.append(i)
other_indices = tuple(other_indices_list)
reduced_list = [1] + [input_shape[i] for i in reduction_indices]
other_list = [1] + [input_shape[i] for i in other_indices]
reduced_num = 1
for i in reduced_list:
reduced_num = reduced_num * i
other_num = 1
for i in other_list:
other_num = other_num * i
perm = reduction_indices + other_indices
return tuple([reduced_num, other_num]), perm
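A plain-Python check (inputs invented here, mirroring the test added below) that the loop-based code above produces the same result as the reduce-based lines it replaces:

    input_shape, axis = (2, 3, 4, 4), 3
    rank = len(input_shape)
    reduction_indices = tuple((i + rank) % rank for i in (axis,))                 # (3,)
    other_indices = tuple(i for i in range(rank) if i not in reduction_indices)   # (0, 1, 2)
    reduced_num, other_num = 1, 1
    for i in reduction_indices:
        reduced_num *= input_shape[i]                                             # 4
    for i in other_indices:
        other_num *= input_shape[i]                                               # 2 * 3 * 4 = 24
    # ((reduced_num, other_num), reduction_indices + other_indices) == ((4, 24), (3, 0, 1, 2))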


@@ -14,7 +14,6 @@
# ============================================================================
"""Define the grad rules of neural network related operations."""
import numpy as np
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
@@ -33,21 +32,23 @@ from mindspore.ops._utils.utils import range_op, get_1d_shape
@constexpr
def bias_add_gradgrad_helper(shape, bias_shape, data_format):
"""Helper function of BiasGradGrad to calculate expanded shape."""
shape = np.array(shape).astype(np.int)
bias_shape = np.array(bias_shape).astype(np.int)
new_shape = list(shape)
new_bias_shape = list(bias_shape)
ones_1 = []
ones_2 = []
for _ in new_shape[2:]:
ones_1.append(1)
for _ in new_shape[:-1]:
ones_2.append(1)
if data_format == "NCHW":
expanded_shape = np.concatenate([
np.ones_like(shape[:1]),
bias_shape,
np.ones_like(shape[2:])
], axis=0)
tile_mults = np.concatenate([shape[:1], [1], shape[2:]], axis=0)
expanded_shape = [1] + new_bias_shape + ones_1
tile_mults = [new_shape[0]] + [1] + new_shape[2:]
else:
expanded_shape = np.concatenate([
np.ones_like(shape[:-1]),
bias_shape
], axis=0)
tile_mults = np.concatenate([shape[:-1], [1]], axis=0)
expanded_shape = ones_2 + new_bias_shape
tile_mults = new_shape[:-1] + [1]
return tuple(expanded_shape), tuple(tile_mults)
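A rough worked example for the NCHW branch, with shapes chosen here for illustration (not taken from the patch):

    shape, bias_shape = (2, 3, 4, 4), (3,)              # hypothetical NCHW input and 1-D bias
    ones_1 = [1] * len(shape[2:])                       # one 1 per spatial dim -> [1, 1]
    expanded_shape = [1] + list(bias_shape) + ones_1    # (1, 3, 1, 1): bias broadcast over N, H, W
    tile_mults = [shape[0], 1] + list(shape[2:])        # (2, 1, 4, 4): tile factors back to the input shape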
@@ -694,7 +695,7 @@ def get_bprop_softmax(self):
is_ascend = (device_target == "Ascend")
def bprop(x, out, dout):
# dx = (dout - sum(dout * out)) * out
# dx can be expressed as (dout - sum(dout * out)) * out
# This formula is correct only when the `axis` is the last dimension.
# In order to support the scenario where the `axis` is other values,
# we transpose the data of the `axis` dimension to the last dimension for calculation,
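A small NumPy sketch (illustrative only, not the operator implementation) of the formula this comment describes, with the softmax axis moved to the last position and back:

    import numpy as np

    def softmax_grad_sketch(out, dout, axis):
        out_t = np.moveaxis(out, axis, -1)
        dout_t = np.moveaxis(dout, axis, -1)
        dx_t = (dout_t - np.sum(dout_t * out_t, axis=-1, keepdims=True)) * out_t
        return np.moveaxis(dx_t, -1, axis)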
@@ -952,7 +953,7 @@ def get_bprop_top_kv2(self):
ind_2d = reshape_op(indices, (-1, ind_lastdim))
outerdim = shape_op(ind_2d)[0]
# [0, outerdim, 2*outerdim, ..., (k-1)*outerdim]
# range_flatten_index can be expressed as: [0, outerdim, 2*outerdim, ..., (k-1)*outerdim]
indices_dtype = dtype(indices)
range_flatten_index = range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)
@@ -979,7 +980,7 @@ def get_bprop_top_kv2(self):
ind_2d = reshape_op(indices, create_tensor_by_element((-1, ind_lastdim)))
outerdim = dyn_shape(ind_2d)[0]
# [0, outerdim, 2*outerdim, ..., (k-1)*outerdim]
# range_flatten_index can be expressed as: [0, outerdim, 2*outerdim, ..., (k-1)*outerdim]
range_flatten_index = P.Range()(cast(0, mstype.int64), outerdim * in_lastdim, in_lastdim)
# expand_dims to (k, 1), then broadcast


@@ -16,7 +16,6 @@
"""Define the grad rules of linalg related operations."""
from __future__ import absolute_import
import numpy as np
import mindspore
from mindspore.ops import Tensor
@@ -83,11 +82,6 @@ def _make_tensor(value, dtype):
return Tensor(value, dtype)
@constexpr
def _make_zero_matrix(shape, dtype):
return Tensor(np.zeros(shape), dtype)
def _matrix_diag(diagonal):
"""Do matrix diagnoal"""
diagonal_shape = _shape(diagonal)


@@ -1856,7 +1856,10 @@ def get_bprop_inplace_index_add(self):
@constexpr
def _fft_rank_offset(norm_shape, rank):
"""generate offset for fft with rank"""
return np.product(norm_shape[-rank:])
norm_shape_product = 1
for i in norm_shape[-rank:]:
norm_shape_product *= i
return norm_shape_product
@constexpr
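The rewritten _fft_rank_offset simply takes the product of the last `rank` entries of norm_shape in a plain loop instead of np.product. For example (values invented here, matching the test added below): norm_shape = (1, 2, 3, 4, 5, 10) with rank = 3 gives 4 * 5 * 10 = 200.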
@@ -1908,19 +1911,28 @@ def _get_last_dim_slice_shape(tensor_shape, index):
@constexpr
def _rfft_reshape(shape_a, shape_b):
"""generate rfft shape for reshape"""
return tuple([int(x) for x in np.concatenate((np.ones(len(shape_a) - 2), shape_b), 0)])
new_shape = list(shape_b)
for i in range(len(shape_a) - 2):
new_shape.insert(i, 1)
return tuple(new_shape)
@constexpr
def _rfft_tile_reshape(shape_a):
"""generate rfft shape for tile"""
return tuple(int(x) for x in np.concatenate((shape_a[:-2], [1, 1]), 0))
reshape_a = list(shape_a)
reshape_a[-2] = 1
reshape_a[-1] = 1
return tuple(reshape_a)
@constexpr
def _rfft_last_term_shape(shape_a, shape_b):
"""generate rfft shape for last term"""
return tuple(int(x) for x in np.concatenate((np.ones(len(shape_a)-1), shape_b), 0))
new_shape = list(shape_b)
for i in range(len(shape_a) - 1):
new_shape.insert(i, 1)
return tuple(new_shape)
@constexpr
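For orientation, what the three reshape helpers return for example inputs (invented here, matching the tests added below): with shape_a = (1, 2, 3, 4, 5) and shape_b = (2, 5, 7, 8), _rfft_reshape prepends len(shape_a) - 2 = 3 ones to shape_b, giving (1, 1, 1, 2, 5, 7, 8); _rfft_tile_reshape keeps shape_a but sets the last two dimensions to 1, giving (1, 2, 3, 1, 1); and _rfft_last_term_shape prepends len(shape_a) - 1 = 4 ones to shape_b, giving (1, 1, 1, 1, 2, 5, 7, 8).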


@@ -875,3 +875,190 @@ def test_constantpad1d():
pad1d = ConstantPad1d(padding, value)
output = pad1d(x)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_generate_inverse_index():
"""
Feature: _generate_inverse_index func
Description: Verify the result of _generate_inverse_index
Expectation: success
"""
from mindspore.ops._grad.grad_array_ops import _generate_inverse_index
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = _generate_inverse_index
def construct(self, x_shape, axis):
return self.func(x_shape, axis)
x_shape = (3, 2, 3)
axis = 2
net = Net()
assert net(x_shape, axis) == (1, 2, 0)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_split_shape_index():
"""
Feature: _split_shape_index func
Description: Verify the result of _split_shape_index
Expectation: success
"""
from mindspore.ops._grad.grad_math_ops import _split_shape_index
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = _split_shape_index
def construct(self, input_shape, axis):
return self.func(input_shape, axis)
input_shape = (2, 3, 4, 4)
axis = 3
net = Net()
out1, out2 = net(input_shape, axis)
assert out1 == (4, 24)
assert out2 == (3, 0, 1, 2)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bias_add_gradgrad_helper():
"""
Feature: bias_add_gradgrad_helper func
Description: Verify the result of bias_add_gradgrad_helper
Expectation: success
"""
from mindspore.ops._grad.grad_nn_ops import bias_add_gradgrad_helper
class Net(nn.Cell):
def __init__(self, data_format):
super(Net, self).__init__()
self.func = bias_add_gradgrad_helper
self.data_format = data_format
def construct(self, shape, bias_shape):
return self.func(shape, bias_shape, self.data_format)
shape = (2, 3, 4, 4)
bias_shape = (1, 3)
data_format = "NCHW"
net = Net(data_format)
out1, out2 = net(shape, bias_shape)
assert out1 == (1, 1, 3, 1, 1)
assert out2 == (2, 1, 4, 4)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fft_rank_offset():
"""
Feature: _fft_rank_offset func
Description: Verify the result of _fft_rank_offset
Expectation: success
"""
from mindspore.ops._grad_experimental.grad_math_ops import _fft_rank_offset
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = _fft_rank_offset
def construct(self, norm_shape, rank):
return self.func(norm_shape, rank)
norm_shape = (1, 2, 3, 4, 5, 10)
rank = 3
net = Net()
assert net(norm_shape, rank) == 200
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_rfft_reshape():
"""
Feature: _rfft_reshape func
Description: Verify the result of _rfft_reshape
Expectation: success
"""
from mindspore.ops._grad_experimental.grad_math_ops import _rfft_reshape
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = _rfft_reshape
def construct(self, shape_a, shape_b):
return self.func(shape_a, shape_b)
shape_a = (1, 2, 3, 4, 5)
shape_b = (2, 5, 7, 8)
net = Net()
assert net(shape_a, shape_b) == (1, 1, 1, 2, 5, 7, 8)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_rfft_tile_reshape():
"""
Feature: _rfft_tile_reshape func
Description: Verify the result of _rfft_tile_reshape
Expectation: success
"""
from mindspore.ops._grad_experimental.grad_math_ops import _rfft_tile_reshape
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = _rfft_tile_reshape
def construct(self, shape_a):
return self.func(shape_a)
shape_a = (1, 2, 3, 4, 5)
net = Net()
assert net(shape_a) == (1, 2, 3, 1, 1)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_rfft_last_term_shape():
"""
Feature: _rfft_last_term_shape func
Description: Verify the result of _rfft_last_term_shape
Expectation: success
"""
from mindspore.ops._grad_experimental.grad_math_ops import _rfft_last_term_shape
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.func = _rfft_last_term_shape
def construct(self, shape_a, shape_b):
return self.func(shape_a, shape_b)
shape_a = (1, 2, 3, 4, 5)
shape_b = (2, 5, 7, 8)
net = Net()
assert net(shape_a, shape_b) == (1, 1, 1, 1, 2, 5, 7, 8)