!49419 Fixed SparseReorder bprop with graph mode

Merge pull request !49419 from Bokai Li/master
This commit is contained in:
i-robot 2023-03-08 07:44:19 +00:00 committed by Gitee
commit f4678db84d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 2 additions and 2 deletions

View File

@ -133,7 +133,7 @@ def get_bprop_sparse_softmax(self):
default_values = _create_tensor(0, values.dtype) default_values = _create_tensor(0, values.dtype)
out_dout = mul(out, dout) out_dout = mul(out, dout)
sp_product = sparse_to_dense(indices, shape, out_dout, default_values) sp_product = sparse_to_dense(indices, shape, out_dout, default_values)
sum_reduced = -reduce_sum(sp_product, -1) sum_reduced = -1 * reduce_sum(sp_product, -1)
sp_sum = sparse_dense_cwise_add(indices, dout, shape, sum_reduced) sp_sum = sparse_dense_cwise_add(indices, dout, shape, sum_reduced)
grad_x = mul(sp_sum, out) grad_x = mul(sp_sum, out)
return zeros_like(indices), grad_x, zeros_like(shape) return zeros_like(indices), grad_x, zeros_like(shape)
@ -387,7 +387,7 @@ def get_bprop_sparse_reorder(self):
def bprop(indices, values, shape, out, dout): def bprop(indices, values, shape, out, dout):
num_entries = F.shape(indices)[0] num_entries = F.shape(indices)[0]
start = Tensor(0, dtype=mstype.int32) start = Tensor(0, dtype=mstype.int32)
limit = Tensor(num_entries, dtype=mstype.int32) limit = P.Cast()(num_entries, mstype.int32)
delta = Tensor(1, dtype=mstype.int32) delta = Tensor(1, dtype=mstype.int32)
entry_indices = range_op(start, limit, delta) entry_indices = range_op(start, limit, delta)
output = sparse_reorder_op(indices, entry_indices, shape) output = sparse_reorder_op(indices, entry_indices, shape)