forked from mindspore-Ecosystem/mindspore
!49636 fix maskedfill bprop in dynamic rank
Merge pull request !49636 from r1chardf1d0/b2
This commit is contained in:
commit
fc63b105ce
|
@ -142,7 +142,12 @@ def get_bprop_masked_select(self):
|
||||||
dinput = mul_op(dout, (1 - mask))
|
dinput = mul_op(dout, (1 - mask))
|
||||||
dvalue = mul_op(dout, mask)
|
dvalue = mul_op(dout, mask)
|
||||||
dinput, dvalue = binop_grad_common(input_data, mask, dinput, dvalue)
|
dinput, dvalue = binop_grad_common(input_data, mask, dinput, dvalue)
|
||||||
dvalue = sum_op(dvalue)
|
# for dynamic rank, reduce axis should be calc
|
||||||
|
if F.is_sequence_shape_unknown(P.Shape()(dvalue)):
|
||||||
|
axis = P.Range()(Tensor(0), dyn_rank(dvalue), Tensor(1))
|
||||||
|
dvalue = sum_op(dvalue, axis)
|
||||||
|
else:
|
||||||
|
dvalue = sum_op(dvalue)
|
||||||
dinput = F.cast(dinput, F.dtype(input_data))
|
dinput = F.cast(dinput, F.dtype(input_data))
|
||||||
if is_instance_op(value, mstype.number):
|
if is_instance_op(value, mstype.number):
|
||||||
dvalue = 0
|
dvalue = 0
|
||||||
|
|
Loading…
Reference in New Issue