!40280 add reduce std operation vmap

Merge pull request !40280 from tan-wei-cheng-3260/develop-twc-master
This commit is contained in:
i-robot 2022-08-16 08:23:50 +00:00 committed by Gitee
commit 9f5361e2dd
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 23 additions and 1 deletions

View File

@ -446,7 +446,6 @@ def _get_reduce_out_dim(keep_dims, x_dim, x_ndim, batch_axis):
@vmap_rules_getters.register(P.ReduceMin)
@vmap_rules_getters.register(P.ReduceMean)
@vmap_rules_getters.register(P.ReduceProd)
@vmap_rules_getters.register(math_ops.ReduceStd)
def get_reducer_vmap_rule(prim, axis_size):
"""VmapRule for reduce operations, such as `ReduceSum`."""
keep_dims = prim.keep_dims
@ -622,6 +621,29 @@ def get_svd_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register(math_ops.ReduceStd)
def get_reducer_std_vmap_rule(prim, axis_size):
"""VmapRule for reduce operations, such as `ReduceStd`."""
axis = prim.axis
keep_dims = prim.keep_dims
unbiased = prim.unbiased
def vmap_rule(x_bdim):
is_all_none, result = vmap_general_preprocess(prim, x_bdim)
if is_all_none:
return result
x, x_dim = x_bdim
x_ndim = F.rank(x)
# LpNorm is a reduction class op, so just reuse the common function.
batch_axis = _get_reduce_batch_axis(axis, x_dim, x_ndim)
reduce_std = math_ops.ReduceStd(batch_axis, unbiased=unbiased, keep_dims=keep_dims)
out_std, out_mean = reduce_std(x)
out_dim = _get_reduce_out_dim(keep_dims, x_dim, x_ndim, batch_axis)
return (out_std, out_dim), (out_mean, out_dim)
return vmap_rule
@vmap_rules_getters.register(math_ops.LpNorm)
def get_lp_norm_vmap_rule(prim, axis_size):
"""VmapRule for 'LpNorm' operation."""