fix TensorCopySlices bprop

This commit is contained in:
caifubi 2021-07-06 15:52:43 +08:00
parent 3a195af6c0
commit bd70512806
2 changed files with 32 additions and 2 deletions

View File

@ -18,6 +18,7 @@
from .._grad.grad_base import bprop_getters
from ..operations import _inner_ops as inner
from .. import functional as F
from ..composite.multitype_ops.zeros_like_impl import zeros_like
@bprop_getters.register(inner.TensorCopySlices)
def get_bprop_tensor_copy_slices(self):
@ -25,8 +26,8 @@ def get_bprop_tensor_copy_slices(self):
tensor_copy_slices = inner.TensorCopySlices()
def bprop(x, update, begin, end, stride, out, dout):
x_grad = tensor_copy_slices(dout, F.zeros_like(update), begin, end, stride)
x_grad = tensor_copy_slices(dout, zeros_like(update), begin, end, stride)
update_grad = F.strided_slice(dout, begin, end, stride)
return x_grad, update_grad, F.zeros_like(begin), F.zeros_like(end), F.zeros_like(stride)
return x_grad, update_grad, zeros_like(begin), zeros_like(end), zeros_like(stride)
return bprop

View File

@ -18,7 +18,11 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Parameter
import mindspore.common.dtype as mstype
from mindspore.ops.operations import _inner_ops
from mindspore import ops as P
from mindspore.ops import composite as C
class Net(nn.Cell):
def __init__(self):
@ -28,6 +32,16 @@ class Net(nn.Cell):
def construct(self, input_x, update, begin, end, strides):
return self.copy_slices(input_x, update, begin, end, strides)
class GradNet(nn.Cell):
def __init__(self):
super(GradNet, self).__init__()
self.input_x = Parameter(Tensor([[2, 3, 4, 5], [2, 3, 4, 5]], mstype.float32))
self.relu = P.ReLU()
def construct(self, input_y):
self.input_x[0:1] /= input_y
return self.relu(self.input_x)
def convert_begin_end_strides_to_slice(begin, end, strides):
result = []
for x, y, z in zip(begin, end, strides):
@ -60,6 +74,21 @@ def test_tensor_copy_slices():
test_tensor_copy_slices_net_many_dtype((10, 10, 10), (5, 10), (0, 5, 0), (1, 10, 10), (1, 1, 1,), support_dtype)
test_tensor_copy_slices_net_many_dtype((10, 10, 10), (5, 10), (9, 5,), (10, 10,), (1, 1,), support_dtype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tensor_copy_slices_bprop():
input_y = Tensor(2.0, mstype.float32)
net = GradNet()
net.set_grad()
grad_net = C.GradOperation(get_all=True, sens_param=True)(net)
output = net(input_y)
grad_output = grad_net(input_y, output)
assert np.allclose(grad_output[0].asnumpy(), np.array([-6.75]))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training