!10642 Fix-bug-of-gather-drop-negatives-without-default-parameter

From: @joylvliang
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2020-12-27 14:17:36 +08:00 committed by Gitee
commit 858f2a5c9c
2 changed files with 4 additions and 4 deletions

View File

@ -485,8 +485,7 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru
{"opt_prepare", PrepareGroup},
{"cconv", CconvPass}};
std::vector<PassItem> kPynativePasses = {{"opt_grad_epilogue", OptPassGradEpilogueGroup},
{"opt_a", OptPassAGroup},
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
{"opt_b", OptPassBGroup},
{"cconv", CconvPass},
{"transform_top", TransformTopGraphPass},

View File

@ -779,7 +779,8 @@ def get_bprop_unsorted_segment_sum(self):
"""Generate bprop for UnsortedSegmentSum"""
def bprop(x, segment_ids, num_segments, out, dout):
return _gather_drop_negatives(dout, segment_ids)[0], zeros_like(segment_ids), zeros_like(num_segments)
return _gather_drop_negatives(dout, segment_ids, None, None)[0], zeros_like(segment_ids), \
zeros_like(num_segments)
return bprop
@ -827,7 +828,7 @@ def get_bprop_unsorted_segment_prod(self):
gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0)
prod_divided_by_x = gathered_prod / x
partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x)
gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices)
gathered_grad, _, _ = _gather_drop_negatives(grad, segment_ids, zero_clipped_indices, None)
dx = gathered_grad * partial_derivative
return dx, zeros_like(segment_ids), zeros_like(num_segments)