forked from mindspore-Ecosystem/mindspore
!19330 Fix bug of TensorCopySlices bprop
Merge pull request !19330 from caifubi/master-tensor-copy-slices
This commit is contained in:
commit
191b1a4fba
|
@ -25,7 +25,7 @@ def get_bprop_tensor_copy_slices(self):
|
||||||
tensor_copy_slices = inner.TensorCopySlices()
|
tensor_copy_slices = inner.TensorCopySlices()
|
||||||
|
|
||||||
def bprop(x, update, begin, end, stride, out, dout):
|
def bprop(x, update, begin, end, stride, out, dout):
|
||||||
x_grad = tensor_copy_slices(dout, F.zeros_like(update))
|
x_grad = tensor_copy_slices(dout, F.zeros_like(update), begin, end, stride)
|
||||||
update_grad = F.strided_slice(dout, begin, end, stride)
|
update_grad = F.strided_slice(dout, begin, end, stride)
|
||||||
return x_grad, update_grad, F.zeros_like(begin), F.zeros_like(end), F.zeros_like(stride)
|
return x_grad, update_grad, F.zeros_like(begin), F.zeros_like(end), F.zeros_like(stride)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue