!49636 fix maskedfill bprop in dynamic rank

Merge pull request !49636 from r1chardf1d0/b2
This commit is contained in:
i-robot 2023-03-06 02:52:09 +00:00 committed by Gitee
commit fc63b105ce
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 6 additions and 1 deletions

View File

@ -142,7 +142,12 @@ def get_bprop_masked_select(self):
dinput = mul_op(dout, (1 - mask))
dvalue = mul_op(dout, mask)
dinput, dvalue = binop_grad_common(input_data, mask, dinput, dvalue)
dvalue = sum_op(dvalue)
# for dynamic rank, reduce axis should be calc
if F.is_sequence_shape_unknown(P.Shape()(dvalue)):
axis = P.Range()(Tensor(0), dyn_rank(dvalue), Tensor(1))
dvalue = sum_op(dvalue, axis)
else:
dvalue = sum_op(dvalue)
dinput = F.cast(dinput, F.dtype(input_data))
if is_instance_op(value, mstype.number):
dvalue = 0