!48583 [MS][LITE][master] support reduce prod axis0

Merge pull request !48583 from Greatpan/reduce_prod_axi0_bugfix
This commit is contained in:
i-robot 2023-02-09 06:07:40 +00:00 committed by Gitee
commit f85bbf7d63
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 2 additions and 0 deletions

View File

@ -920,6 +920,8 @@ def get_bprop_reduceprod(self):
"""Grad definition for `Product` operation."""
# Expand dout to full input shape
input_shape = shape_op(x)
if input_shape == ():
return Tensor(1, x.dtype), zeros_like(axis)
if F.is_sequence_value_unknown(input_shape):
input_shape = dyn_shape_op(x)
input_shape = P.Cast()(input_shape, ms.int64)