forked from mindspore-Ecosystem/mindspore
fix TensorCopySlices bprop
This commit is contained in:
parent
3a195af6c0
commit
bd70512806
|
@ -18,6 +18,7 @@
|
|||
from .._grad.grad_base import bprop_getters
|
||||
from ..operations import _inner_ops as inner
|
||||
from .. import functional as F
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
|
||||
@bprop_getters.register(inner.TensorCopySlices)
|
||||
def get_bprop_tensor_copy_slices(self):
|
||||
|
@ -25,8 +26,8 @@ def get_bprop_tensor_copy_slices(self):
|
|||
tensor_copy_slices = inner.TensorCopySlices()
|
||||
|
||||
def bprop(x, update, begin, end, stride, out, dout):
    """Backward function for TensorCopySlices.

    Args:
        x: Forward input tensor (unused here; only its gradient is produced).
        update: Forward update tensor; only used to build a same-shaped zero tensor.
        begin, end, stride: Slice specification forwarded unchanged to the
            slice operators.
        out: Forward output (unused).
        dout: Gradient flowing in from the forward output.

    Returns:
        A 5-tuple ``(x_grad, update_grad, begin_grad, end_grad, stride_grad)``
        where the last three are zeros built from the slice arguments.
    """
    # Gradient w.r.t. x: dout with zeros written over the slice region
    # selected by (begin, end, stride).
    x_grad = tensor_copy_slices(dout, zeros_like(update), begin, end, stride)
    # Gradient w.r.t. update: the slice of dout selected by the same spec.
    update_grad = F.strided_slice(dout, begin, end, stride)
    # The multitype `zeros_like` is used for every zero value here,
    # including the begin/end/stride slice arguments.
    return x_grad, update_grad, zeros_like(begin), zeros_like(end), zeros_like(stride)
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -18,7 +18,11 @@ import pytest
|
|||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import Parameter
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops.operations import _inner_ops
|
||||
from mindspore import ops as P
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -28,6 +32,16 @@ class Net(nn.Cell):
|
|||
def construct(self, input_x, update, begin, end, strides):
|
||||
return self.copy_slices(input_x, update, begin, end, strides)
|
||||
|
||||
class GradNet(nn.Cell):
    """Cell that divides the first row of a 2x4 float32 parameter in place and applies ReLU."""

    def __init__(self):
        super().__init__()
        initial_value = Tensor([[2, 3, 4, 5], [2, 3, 4, 5]], mstype.float32)
        self.input_x = Parameter(initial_value)
        self.relu = P.ReLU()

    def construct(self, input_y):
        # In-place slice update of the parameter; kept in augmented-assignment
        # form because this is the operation the bprop test targets.
        self.input_x[0:1] /= input_y
        activated = self.relu(self.input_x)
        return activated
|
||||
|
||||
def convert_begin_end_strides_to_slice(begin, end, strides):
|
||||
result = []
|
||||
for x, y, z in zip(begin, end, strides):
|
||||
|
@ -60,6 +74,21 @@ def test_tensor_copy_slices():
|
|||
test_tensor_copy_slices_net_many_dtype((10, 10, 10), (5, 10), (0, 5, 0), (1, 10, 10), (1, 1, 1,), support_dtype)
|
||||
test_tensor_copy_slices_net_many_dtype((10, 10, 10), (5, 10), (9, 5,), (10, 10,), (1, 1,), support_dtype)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_tensor_copy_slices_bprop():
    """Check the gradient of GradNet with respect to its scalar input.

    The forward output is passed as the extra argument that
    ``sens_param=True`` requires, and the first returned gradient is
    compared against -6.75.
    """
    divisor = Tensor(2.0, mstype.float32)
    net = GradNet()
    net.set_grad()
    grad_op = C.GradOperation(get_all=True, sens_param=True)
    grad_net = grad_op(net)

    forward_out = net(divisor)
    grads = grad_net(divisor, forward_out)
    expected = np.array([-6.75])
    assert np.allclose(grads[0].asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
|
Loading…
Reference in New Issue