forked from mindspore-Ecosystem/mindspore
!19330 Fix bug of TensorCopySlices bprop
Merge pull request !19330 from caifubi/master-tensor-copy-slices
This commit is contained in:
commit
191b1a4fba
|
@ -25,7 +25,7 @@ def get_bprop_tensor_copy_slices(self):
|
|||
tensor_copy_slices = inner.TensorCopySlices()
|
||||
|
||||
def bprop(x, update, begin, end, stride, out, dout):
|
||||
x_grad = tensor_copy_slices(dout, F.zeros_like(update))
|
||||
x_grad = tensor_copy_slices(dout, F.zeros_like(update), begin, end, stride)
|
||||
update_grad = F.strided_slice(dout, begin, end, stride)
|
||||
return x_grad, update_grad, F.zeros_like(begin), F.zeros_like(end), F.zeros_like(stride)
|
||||
|
||||
|
|
Loading…
Reference in New Issue