!40280 add reduce std operation vmap
Merge pull request !40280 from tan-wei-cheng-3260/develop-twc-master
This commit is contained in:
commit
9f5361e2dd
|
@ -446,7 +446,6 @@ def _get_reduce_out_dim(keep_dims, x_dim, x_ndim, batch_axis):
|
|||
@vmap_rules_getters.register(P.ReduceMin)
|
||||
@vmap_rules_getters.register(P.ReduceMean)
|
||||
@vmap_rules_getters.register(P.ReduceProd)
|
||||
@vmap_rules_getters.register(math_ops.ReduceStd)
|
||||
def get_reducer_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for reduce operations, such as `ReduceSum`."""
|
||||
keep_dims = prim.keep_dims
|
||||
|
@ -622,6 +621,29 @@ def get_svd_vmap_rule(prim, axis_size):
|
|||
return vmap_rule
|
||||
|
||||
|
||||
@vmap_rules_getters.register(math_ops.ReduceStd)
|
||||
def get_reducer_std_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for reduce operations, such as `ReduceStd`."""
|
||||
axis = prim.axis
|
||||
keep_dims = prim.keep_dims
|
||||
unbiased = prim.unbiased
|
||||
|
||||
def vmap_rule(x_bdim):
|
||||
is_all_none, result = vmap_general_preprocess(prim, x_bdim)
|
||||
if is_all_none:
|
||||
return result
|
||||
x, x_dim = x_bdim
|
||||
x_ndim = F.rank(x)
|
||||
# LpNorm is a reduction class op, so just reuse the common function.
|
||||
batch_axis = _get_reduce_batch_axis(axis, x_dim, x_ndim)
|
||||
reduce_std = math_ops.ReduceStd(batch_axis, unbiased=unbiased, keep_dims=keep_dims)
|
||||
out_std, out_mean = reduce_std(x)
|
||||
out_dim = _get_reduce_out_dim(keep_dims, x_dim, x_ndim, batch_axis)
|
||||
return (out_std, out_dim), (out_mean, out_dim)
|
||||
|
||||
return vmap_rule
|
||||
|
||||
|
||||
@vmap_rules_getters.register(math_ops.LpNorm)
|
||||
def get_lp_norm_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for 'LpNorm' operation."""
|
||||
|
|
Loading…
Reference in New Issue