forked from mindspore-Ecosystem/mindspore
!47618 remove grad constexpr
Merge pull request !47618 from hujiahui8/constexpr_1
commit 3b73258f5a
@@ -400,12 +400,6 @@ def get_bprop_concat(self):
     return bprop


-@constexpr
-def _slice_grad_pad(begins, sizes, shapes):
-    pads = tuple((begin, shape - begin - size) for begin, size, shape in zip(begins, sizes, shapes))
-    return pads
-
-
 @bprop_getters.register(P.Slice)
 def get_bprop_slice(self):
     """Generate bprop for Slice"""
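For reference, a minimal standalone sketch (not part of this diff) of the padding rule the deleted _slice_grad_pad helper implemented: for every dimension it pairs the slice begin with the number of elements left after the slice.

def slice_grad_pad(begins, sizes, shapes):
    # (left pad, right pad) per dimension: `begin` elements before the slice,
    # `shape - begin - size` elements after it
    return tuple((begin, shape - begin - size)
                 for begin, size, shape in zip(begins, sizes, shapes))

# slicing 2 rows starting at row 1 out of a (4, 3) array leaves 1 padded row below
assert slice_grad_pad((1, 0), (2, 3), (4, 3)) == ((1, 1), (0, 0))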
@@ -15,7 +15,6 @@

 """Define the grad rules of math related operations."""

-from functools import reduce
 import numpy as np
 import mindspore as ms
 from mindspore import nn
@@ -864,9 +863,19 @@ def _split_shape_index(input_shape, axis):
     if isinstance(axis, int):
         axis = tuple([axis])
     reduction_indices = tuple([(i + rank) % rank for i in axis])
-    other_indices = tuple(set(range(rank)) - set(reduction_indices))
-    reduced_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in reduction_indices])
-    other_num = reduce(lambda x, y: x * y, [1] + [input_shape[i] for i in other_indices])
+    other_indices_list = []
+    for i in range(rank):
+        if i not in reduction_indices and i not in other_indices_list:
+            other_indices_list.append(i)
+    other_indices = tuple(other_indices_list)
+    reduced_list = [1] + [input_shape[i] for i in reduction_indices]
+    other_list = [1] + [input_shape[i] for i in other_indices]
+    reduced_num = 1
+    for i in reduced_list:
+        reduced_num = reduced_num * i
+    other_num = 1
+    for i in other_list:
+        other_num = other_num * i
     perm = reduction_indices + other_indices
     return tuple([reduced_num, other_num]), perm
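For illustration, a minimal standalone sketch (not part of this diff) of the rewritten _split_shape_index logic; the input and the expected values ((4, 24) and (3, 0, 1, 2)) are taken from the new test added at the end of this change.

def split_shape_index(input_shape, axis):
    # mirrors the loop-based rewrite above, without the @constexpr decorator
    rank = len(input_shape)
    if isinstance(axis, int):
        axis = (axis,)
    reduction_indices = tuple((i + rank) % rank for i in axis)
    other_indices = tuple(i for i in range(rank) if i not in reduction_indices)
    reduced_num = 1
    for i in reduction_indices:
        reduced_num *= input_shape[i]
    other_num = 1
    for i in other_indices:
        other_num *= input_shape[i]
    return (reduced_num, other_num), reduction_indices + other_indices

assert split_shape_index((2, 3, 4, 4), 3) == ((4, 24), (3, 0, 1, 2))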
@@ -14,7 +14,6 @@
 # ============================================================================

 """Define the grad rules of neural network related operations."""
-import numpy as np
 from mindspore import context
 from mindspore.common import dtype as mstype
 from mindspore.common.tensor import Tensor
@@ -33,21 +32,23 @@ from mindspore.ops._utils.utils import range_op, get_1d_shape
 @constexpr
 def bias_add_gradgrad_helper(shape, bias_shape, data_format):
     """Helper function of BiasGradGrad to calculate expanded shape."""
-    shape = np.array(shape).astype(np.int)
-    bias_shape = np.array(bias_shape).astype(np.int)
+    new_shape = list(shape)
+    new_bias_shape = list(bias_shape)

+    ones_1 = []
+    ones_2 = []
+    for _ in new_shape[2:]:
+        ones_1.append(1)
+
+    for _ in new_shape[:-1]:
+        ones_2.append(1)
+
     if data_format == "NCHW":
-        expanded_shape = np.concatenate([
-            np.ones_like(shape[:1]),
-            bias_shape,
-            np.ones_like(shape[2:])
-        ], axis=0)
-        tile_mults = np.concatenate([shape[:1], [1], shape[2:]], axis=0)
+        expanded_shape = [1] + new_bias_shape + ones_1
+        tile_mults = [new_shape[0]] + [1] + new_shape[2:]
     else:
-        expanded_shape = np.concatenate([
-            np.ones_like(shape[:-1]),
-            bias_shape
-        ], axis=0)
-        tile_mults = np.concatenate([shape[:-1], [1]], axis=0)
+        expanded_shape = ones_2 + new_bias_shape
+        tile_mults = new_shape[:-1] + [1]
     return tuple(expanded_shape), tuple(tile_mults)
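A quick standalone check (not part of this diff) of what the rewritten helper returns for the NCHW case; the input and expected tuples match the new test added at the end of this change.

def bias_add_gradgrad_helper_sketch(shape, bias_shape, data_format):
    # same list arithmetic as the rewritten helper above, without @constexpr
    shape, bias_shape = list(shape), list(bias_shape)
    if data_format == "NCHW":
        expanded_shape = [1] + bias_shape + [1] * len(shape[2:])
        tile_mults = [shape[0]] + [1] + shape[2:]
    else:
        expanded_shape = [1] * len(shape[:-1]) + bias_shape
        tile_mults = shape[:-1] + [1]
    return tuple(expanded_shape), tuple(tile_mults)

assert bias_add_gradgrad_helper_sketch((2, 3, 4, 4), (1, 3), "NCHW") == ((1, 1, 3, 1, 1), (2, 1, 4, 4))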
@@ -694,7 +695,7 @@ def get_bprop_softmax(self):
     is_ascend = (device_target == "Ascend")

     def bprop(x, out, dout):
-        # dx = (dout - sum(dout * out)) * out
+        # dx can be expressed as (dout - sum(dout * out)) * out
         # This formula is correct only when the `axis` is the last dimension.
         # In order to support the scenario where the `axis` is other values,
         # we transpose the data of the `axis` dimension to the last dimension for calculation,
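The reworded comment above states the closed form of the softmax gradient. As a sanity check (not part of this diff), a small NumPy sketch comparing that formula against an explicit Jacobian product for a 1-D input, where the axis is the last dimension:

import numpy as np

x = np.array([0.3, -1.2, 2.0, 0.5])
out = np.exp(x) / np.exp(x).sum()          # softmax output
dout = np.array([0.1, -0.4, 0.25, 0.05])   # incoming gradient

# closed form used in the bprop: dx = (dout - sum(dout * out)) * out
dx_formula = (dout - np.sum(dout * out)) * out

# reference: softmax Jacobian J[i, j] = out[i] * (delta_ij - out[j])
jacobian = np.diag(out) - np.outer(out, out)
dx_reference = jacobian.T @ dout

assert np.allclose(dx_formula, dx_reference)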
@@ -952,7 +953,7 @@ def get_bprop_top_kv2(self):
         ind_2d = reshape_op(indices, (-1, ind_lastdim))
         outerdim = shape_op(ind_2d)[0]

-        # [0, outterdim, 2*outerdim, ..., (k-1)*outerdim]
+        # range_flatten_index can be expressed as: [0, outterdim, 2*outerdim, ..., (k-1)*outerdim]
         indices_dtype = dtype(indices)
         range_flatten_index = range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)
@@ -979,7 +980,7 @@ def get_bprop_top_kv2(self):
         ind_2d = reshape_op(indices, create_tensor_by_element((-1, ind_lastdim)))
         outerdim = dyn_shape(ind_2d)[0]

-        # [0, outterdim, 2*outerdim, ..., (k-1)*outerdim]
+        # range_flatten_index can be expressed as: [0, outterdim, 2*outerdim, ..., (k-1)*outerdim]
         range_flatten_index = P.Range()(cast(0, mstype.int64), outerdim * in_lastdim, in_lastdim)

         # expand_dims to (k, 1), then broadcast
@@ -16,7 +16,6 @@
 """Define the grad rules of linalg related operations."""
 from __future__ import absolute_import

-import numpy as np
 import mindspore

 from mindspore.ops import Tensor
@@ -83,11 +82,6 @@ def _make_tensor(value, dtype):
    return Tensor(value, dtype)


-@constexpr
-def _make_zero_matrix(shape, dtype):
-    return Tensor(np.zeros(shape), dtype)
-
-
 def _matrix_diag(diagonal):
     """Do matrix diagnoal"""
     diagonal_shape = _shape(diagonal)
@@ -1856,7 +1856,10 @@ def get_bprop_inplace_index_add(self):
 @constexpr
 def _fft_rank_offset(norm_shape, rank):
     """generate offset for fft with rank"""
-    return np.product(norm_shape[-rank:])
+    norm_shape_product = 1
+    for i in norm_shape[-rank:]:
+        norm_shape_product *= i
+    return norm_shape_product


 @constexpr
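A quick equivalence check (not part of this diff) between the old NumPy expression and the loop-based rewrite; the input and the expected value 200 are taken from the new test added at the end of this change.

import numpy as np

def fft_rank_offset_sketch(norm_shape, rank):
    # loop-based product over the last `rank` dimensions, as in the rewrite above
    product = 1
    for i in norm_shape[-rank:]:
        product *= i
    return product

norm_shape, rank = (1, 2, 3, 4, 5, 10), 3
assert fft_rank_offset_sketch(norm_shape, rank) == int(np.prod(norm_shape[-rank:])) == 200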
@@ -1908,19 +1911,28 @@ def _get_last_dim_slice_shape(tensor_shape, index):
 @constexpr
 def _rfft_reshape(shape_a, shape_b):
     """generate rfft shape for reshape"""
-    return tuple([int(x) for x in np.concatenate((np.ones(len(shape_a) - 2), shape_b), 0)])
+    new_shape = list(shape_b)
+    for i in range(len(shape_a) - 2):
+        new_shape.insert(i, 1)
+    return tuple(new_shape)


 @constexpr
 def _rfft_tile_reshape(shape_a):
     """generate rfft shape for tile"""
-    return tuple(int(x) for x in np.concatenate((shape_a[:-2], [1, 1]), 0))
+    reshape_a = list(shape_a)
+    reshape_a[-2] = 1
+    reshape_a[-1] = 1
+    return tuple(reshape_a)


 @constexpr
 def _rfft_last_term_shape(shape_a, shape_b):
     """generate rfft shape for last term"""
-    return tuple(int(x) for x in np.concatenate((np.ones(len(shape_a)-1), shape_b), 0))
+    new_shape = list(shape_b)
+    for i in range(len(shape_a) - 1):
+        new_shape.insert(i, 1)
+    return tuple(new_shape)


 @constexpr
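For reference, a standalone check (not part of this diff) of what the three rewritten shape helpers compute; the inputs and expected tuples mirror the new tests added at the end of this change.

shape_a, shape_b = (1, 2, 3, 4, 5), (2, 5, 7, 8)

# _rfft_reshape: prepend len(shape_a) - 2 leading ones to shape_b
assert tuple([1] * (len(shape_a) - 2) + list(shape_b)) == (1, 1, 1, 2, 5, 7, 8)

# _rfft_tile_reshape: keep shape_a but set the last two dimensions to 1
assert tuple(list(shape_a[:-2]) + [1, 1]) == (1, 2, 3, 1, 1)

# _rfft_last_term_shape: prepend len(shape_a) - 1 leading ones to shape_b
assert tuple([1] * (len(shape_a) - 1) + list(shape_b)) == (1, 1, 1, 1, 2, 5, 7, 8)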
Binary file not shown.
@@ -875,3 +875,190 @@ def test_constantpad1d():
     pad1d = ConstantPad1d(padding, value)
     output = pad1d(x)
     assert np.allclose(output.asnumpy(), expect)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_generate_inverse_index():
+    """
+    Feature: _generate_inverse_index func
+    Description: Verify the result of _generate_inverse_index
+    Expectation: success
+    """
+    from mindspore.ops._grad.grad_array_ops import _generate_inverse_index
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.func = _generate_inverse_index
+
+        def construct(self, x_shape, axis):
+            return self.func(x_shape, axis)
+    x_shape = (3, 2, 3)
+    axis = 2
+    net = Net()
+    assert net(x_shape, axis) == (1, 2, 0)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_split_shape_index():
+    """
+    Feature: _split_shape_index func
+    Description: Verify the result of _split_shape_index
+    Expectation: success
+    """
+    from mindspore.ops._grad.grad_math_ops import _split_shape_index
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.func = _split_shape_index
+
+        def construct(self, input_shape, axis):
+            return self.func(input_shape, axis)
+    input_shape = (2, 3, 4, 4)
+    axis = 3
+    net = Net()
+    out1, out2 = net(input_shape, axis)
+    assert out1 == (4, 24)
+    assert out2 == (3, 0, 1, 2)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_bias_add_gradgrad_helper():
+    """
+    Feature: bias_add_gradgrad_helper func
+    Description: Verify the result of bias_add_gradgrad_helper
+    Expectation: success
+    """
+    from mindspore.ops._grad.grad_nn_ops import bias_add_gradgrad_helper
+
+    class Net(nn.Cell):
+        def __init__(self, data_format):
+            super(Net, self).__init__()
+            self.func = bias_add_gradgrad_helper
+            self.data_format = data_format
+
+        def construct(self, shape, bias_shape):
+            return self.func(shape, bias_shape, self.data_format)
+    shape = (2, 3, 4, 4)
+    bias_shape = (1, 3)
+    data_format = "NCHW"
+    net = Net(data_format)
+    out1, out2 = net(shape, bias_shape)
+    assert out1 == (1, 1, 3, 1, 1)
+    assert out2 == (2, 1, 4, 4)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_fft_rank_offset():
+    """
+    Feature: _fft_rank_offset func
+    Description: Verify the result of _fft_rank_offset
+    Expectation: success
+    """
+    from mindspore.ops._grad_experimental.grad_math_ops import _fft_rank_offset
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.func = _fft_rank_offset
+
+        def construct(self, norm_shape, rank):
+            return self.func(norm_shape, rank)
+    norm_shape = (1, 2, 3, 4, 5, 10)
+    rank = 3
+    net = Net()
+    assert net(norm_shape, rank) == 200
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_rfft_reshape():
+    """
+    Feature: _rfft_reshape func
+    Description: Verify the result of _rfft_reshape
+    Expectation: success
+    """
+    from mindspore.ops._grad_experimental.grad_math_ops import _rfft_reshape
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.func = _rfft_reshape
+
+        def construct(self, shape_a, shape_b):
+            return self.func(shape_a, shape_b)
+    shape_a = (1, 2, 3, 4, 5)
+    shape_b = (2, 5, 7, 8)
+    net = Net()
+    assert net(shape_a, shape_b) == (1, 1, 1, 2, 5, 7, 8)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_rfft_tile_reshape():
+    """
+    Feature: _rfft_tile_reshape func
+    Description: Verify the result of _rfft_tile_reshape
+    Expectation: success
+    """
+    from mindspore.ops._grad_experimental.grad_math_ops import _rfft_tile_reshape
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.func = _rfft_tile_reshape
+
+        def construct(self, shape_a):
+            return self.func(shape_a)
+    shape_a = (1, 2, 3, 4, 5)
+    net = Net()
+    assert net(shape_a) == (1, 2, 3, 1, 1)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_rfft_last_term_shape():
+    """
+    Feature: _rfft_last_term_shape func
+    Description: Verify the result of _rfft_last_term_shape
+    Expectation: success
+    """
+    from mindspore.ops._grad_experimental.grad_math_ops import _rfft_last_term_shape
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.func = _rfft_last_term_shape
+
+        def construct(self, shape_a, shape_b):
+            return self.func(shape_a, shape_b)
+    shape_a = (1, 2, 3, 4, 5)
+    shape_b = (2, 5, 7, 8)
+    net = Net()
+    assert net(shape_a, shape_b) == (1, 1, 1, 1, 2, 5, 7, 8)