forked from mindspore-Ecosystem/mindspore
!12438 Add backward op for IndexAdd GPU
From: @TFbunny Reviewed-by: @robingrosman Signed-off-by: @robingrosman
This commit is contained in:
commit
57f6c17933
|
@ -49,6 +49,38 @@ MS_REG_GPU_KERNEL_TWO(
|
|||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherV2GpuFwdKernel, half, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
GatherV2GpuFwdKernel, int, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
GatherV2GpuFwdKernel, int, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
|
||||
GatherV2GpuFwdKernel, int16_t, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
|
||||
GatherV2GpuFwdKernel, int16_t, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
|
||||
GatherV2GpuFwdKernel, int8_t, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
|
||||
GatherV2GpuFwdKernel, int8_t, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
|
||||
GatherV2GpuFwdKernel, uint8_t, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
|
||||
GatherV2GpuFwdKernel, uint8_t, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(Gather,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
|
|
@ -59,3 +59,21 @@ template void GatherV2<double, int>(double *input, int *indices, double *output,
|
|||
size_t output_dim2, size_t input_dim1, cudaStream_t stream);
|
||||
template void GatherV2<double, int64_t>(double *input, int64_t *indices, double *output, size_t output_dim0,
|
||||
size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
|
||||
template void GatherV2<int, int>(int *input, int *indices, int *output, size_t output_dim0, size_t output_dim1,
|
||||
size_t output_dim2, size_t input_dim1, cudaStream_t stream);
|
||||
template void GatherV2<int, int64_t>(int *input, int64_t *indices, int *output, size_t output_dim0, size_t output_dim1,
|
||||
size_t output_dim2, size_t input_dim1, cudaStream_t stream);
|
||||
template void GatherV2<int16_t, int>(int16_t *input, int *indices, int16_t *output, size_t output_dim0,
|
||||
size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
|
||||
template void GatherV2<int16_t, int64_t>(int16_t *input, int64_t *indices, int16_t *output, size_t output_dim0,
|
||||
size_t output_dim1, size_t output_dim2, size_t input_dim1,
|
||||
cudaStream_t stream);
|
||||
template void GatherV2<int8_t, int>(int8_t *input, int *indices, int8_t *output, size_t output_dim0, size_t output_dim1,
|
||||
size_t output_dim2, size_t input_dim1, cudaStream_t stream);
|
||||
template void GatherV2<int8_t, int64_t>(int8_t *input, int64_t *indices, int8_t *output, size_t output_dim0,
|
||||
size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
|
||||
template void GatherV2<uint8_t, int>(uint8_t *input, int *indices, uint8_t *output, size_t output_dim0,
|
||||
size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
|
||||
template void GatherV2<uint8_t, int64_t>(uint8_t *input, int64_t *indices, uint8_t *output, size_t output_dim0,
|
||||
size_t output_dim1, size_t output_dim2, size_t input_dim1,
|
||||
cudaStream_t stream);
|
||||
|
|
|
@ -1300,3 +1300,15 @@ def get_bprop_lin_space(self):
|
|||
return zeros_like(start), zeros_like(stop), zeros_like(num)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.IndexAdd)
|
||||
def get_bprop_index_add(self):
|
||||
"""Generate bprop for IndexAdd"""
|
||||
gather = P.Gather()
|
||||
_axis = self.axis
|
||||
|
||||
def bprop(input_x, indices, input_y, out, dout):
|
||||
return dout, zeros_like(indices), gather(dout, indices, _axis)
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -1178,3 +1178,183 @@ def test_gather1_float64():
|
|||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather1_int32():
|
||||
x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int32).reshape(2, 3, 4, 5))
|
||||
indices = Tensor(np.array([1, 3, 4], dtype='i4'))
|
||||
expect = np.array([[[[1., 3., 4.],
|
||||
[6., 8., 9.],
|
||||
[11., 13., 14.],
|
||||
[16., 18., 19.]],
|
||||
|
||||
[[21., 23., 24.],
|
||||
[26., 28., 29.],
|
||||
[31., 33., 34.],
|
||||
[36., 38., 39.]],
|
||||
|
||||
[[41., 43., 44.],
|
||||
[46., 48., 49.],
|
||||
[51., 53., 54.],
|
||||
[56., 58., 59.]]],
|
||||
|
||||
[[[61., 63., 64.],
|
||||
[66., 68., 69.],
|
||||
[71., 73., 74.],
|
||||
[76., 78., 79.]],
|
||||
|
||||
[[81., 83., 84.],
|
||||
[86., 88., 89.],
|
||||
[91., 93., 94.],
|
||||
[96., 98., 99.]],
|
||||
|
||||
[[101., 103., 104.],
|
||||
[106., 108., 109.],
|
||||
[111., 113., 114.],
|
||||
[116., 118., 119.]]]]).astype(np.int32)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
gather = GatherNet1()
|
||||
output = gather(x, indices)
|
||||
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather1_int16():
|
||||
x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int16).reshape(2, 3, 4, 5))
|
||||
indices = Tensor(np.array([1, 3, 4], dtype='i4'))
|
||||
expect = np.array([[[[1., 3., 4.],
|
||||
[6., 8., 9.],
|
||||
[11., 13., 14.],
|
||||
[16., 18., 19.]],
|
||||
|
||||
[[21., 23., 24.],
|
||||
[26., 28., 29.],
|
||||
[31., 33., 34.],
|
||||
[36., 38., 39.]],
|
||||
|
||||
[[41., 43., 44.],
|
||||
[46., 48., 49.],
|
||||
[51., 53., 54.],
|
||||
[56., 58., 59.]]],
|
||||
|
||||
[[[61., 63., 64.],
|
||||
[66., 68., 69.],
|
||||
[71., 73., 74.],
|
||||
[76., 78., 79.]],
|
||||
|
||||
[[81., 83., 84.],
|
||||
[86., 88., 89.],
|
||||
[91., 93., 94.],
|
||||
[96., 98., 99.]],
|
||||
|
||||
[[101., 103., 104.],
|
||||
[106., 108., 109.],
|
||||
[111., 113., 114.],
|
||||
[116., 118., 119.]]]]).astype(np.int16)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
gather = GatherNet1()
|
||||
output = gather(x, indices)
|
||||
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather1_int8():
|
||||
x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int8).reshape(2, 3, 4, 5))
|
||||
indices = Tensor(np.array([1, 3, 4], dtype='i4'))
|
||||
expect = np.array([[[[1., 3., 4.],
|
||||
[6., 8., 9.],
|
||||
[11., 13., 14.],
|
||||
[16., 18., 19.]],
|
||||
|
||||
[[21., 23., 24.],
|
||||
[26., 28., 29.],
|
||||
[31., 33., 34.],
|
||||
[36., 38., 39.]],
|
||||
|
||||
[[41., 43., 44.],
|
||||
[46., 48., 49.],
|
||||
[51., 53., 54.],
|
||||
[56., 58., 59.]]],
|
||||
|
||||
[[[61., 63., 64.],
|
||||
[66., 68., 69.],
|
||||
[71., 73., 74.],
|
||||
[76., 78., 79.]],
|
||||
|
||||
[[81., 83., 84.],
|
||||
[86., 88., 89.],
|
||||
[91., 93., 94.],
|
||||
[96., 98., 99.]],
|
||||
|
||||
[[101., 103., 104.],
|
||||
[106., 108., 109.],
|
||||
[111., 113., 114.],
|
||||
[116., 118., 119.]]]]).astype(np.int8)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
gather = GatherNet1()
|
||||
output = gather(x, indices)
|
||||
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather1_uint8():
|
||||
x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.uint8).reshape(2, 3, 4, 5))
|
||||
indices = Tensor(np.array([1, 3, 4], dtype='i4'))
|
||||
expect = np.array([[[[1., 3., 4.],
|
||||
[6., 8., 9.],
|
||||
[11., 13., 14.],
|
||||
[16., 18., 19.]],
|
||||
|
||||
[[21., 23., 24.],
|
||||
[26., 28., 29.],
|
||||
[31., 33., 34.],
|
||||
[36., 38., 39.]],
|
||||
|
||||
[[41., 43., 44.],
|
||||
[46., 48., 49.],
|
||||
[51., 53., 54.],
|
||||
[56., 58., 59.]]],
|
||||
|
||||
[[[61., 63., 64.],
|
||||
[66., 68., 69.],
|
||||
[71., 73., 74.],
|
||||
[76., 78., 79.]],
|
||||
|
||||
[[81., 83., 84.],
|
||||
[86., 88., 89.],
|
||||
[91., 93., 94.],
|
||||
[96., 98., 99.]],
|
||||
|
||||
[[101., 103., 104.],
|
||||
[106., 108., 109.],
|
||||
[111., 113., 114.],
|
||||
[116., 118., 119.]]]]).astype(np.uint8)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
gather = GatherNet1()
|
||||
output = gather(x, indices)
|
||||
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
|
|
@ -16,10 +16,12 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
|
||||
class NetIndexAdd(nn.Cell):
|
||||
|
@ -257,3 +259,110 @@ def test_index_add_invalid_inputs():
|
|||
net = NetIndexAdd(1)
|
||||
_ = net(Tensor(x), Tensor(idx), Tensor(y))
|
||||
assert "out of range" in str(info.value)
|
||||
|
||||
|
||||
class IndexAddGradNet(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(IndexAddGradNet, self).__init__()
|
||||
self.grad = C.GradOperation(get_all=True, sens_param=True)
|
||||
self.network = network
|
||||
|
||||
def construct(self, x, idx, y, dout):
|
||||
out = self.grad(self.network)(x, idx, y, dout)
|
||||
return out
|
||||
|
||||
|
||||
def index_add_grad_with_type(nptype):
|
||||
net = NetIndexAdd(1)
|
||||
grad_net = IndexAddGradNet(net)
|
||||
x = Tensor(np.arange(15).reshape(5, 3).astype(nptype))
|
||||
y = Tensor(np.arange(5).reshape(5, 1).astype(nptype))
|
||||
dout = Tensor(np.array([[63., 64., 65.],
|
||||
[66., 67., 68.],
|
||||
[69., 70., 71.],
|
||||
[72., 73., 74.],
|
||||
[75., 76., 77.]]).astype(nptype))
|
||||
index = Tensor(np.array([1]), dtype=mindspore.int32)
|
||||
xgrad, _, ygrad = grad_net(x, index, y, dout)
|
||||
expect_xgrad = np.array([[63., 64., 65.],
|
||||
[66., 67., 68.],
|
||||
[69., 70., 71.],
|
||||
[72., 73., 74.],
|
||||
[75., 76., 77.]]).astype(nptype)
|
||||
expect_ygrad = np.array([[64.],
|
||||
[67.],
|
||||
[70.],
|
||||
[73.],
|
||||
[76.]]).astype(nptype)
|
||||
np.testing.assert_array_equal(xgrad.asnumpy(), expect_xgrad)
|
||||
np.testing.assert_array_equal(ygrad.asnumpy(), expect_ygrad)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_index_add_grad_float64():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.float64)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.float64)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_index_add_grad_float32():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.float32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_index_add_grad_float16():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.float16)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.float16)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_index_add_grad_int32():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.int32)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.int32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_index_add_grad_int16():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.int16)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.int16)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_index_add_grad_int8():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.int8)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.int8)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_index_add_grad_uint8():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.uint8)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
index_add_grad_with_type(np.uint8)
|
||||
|
|
Loading…
Reference in New Issue