!40143 [MS][LITE][parallel predict]mirror pad grad vmap
Merge pull request !40143 from yefeng/op_vmap_mirror_pad_grad
This commit is contained in:
commit
1f6eab5596
|
@ -494,3 +494,48 @@ def get_instance_norm_grad_rule(prim, axis_size):
|
|||
return (output_x, 0), (updated_moving_mean, 0), (updated_moving_variance, 0)
|
||||
|
||||
return vmap_rule
|
||||
|
||||
|
||||
@vmap_rules_getters.register(G.MirrorPadGrad)
|
||||
def get_mirror_pad_grad_grad_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for `MirrorPadGrad` operation."""
|
||||
input_max_dim = 4
|
||||
|
||||
def vmap_rule(*params_bdim):
|
||||
is_all_none, result = vmap_general_preprocess(prim, params_bdim)
|
||||
if is_all_none:
|
||||
return result
|
||||
if len(params_bdim) < 2:
|
||||
_raise_value_error("The input params in `{}` must >= 2, but got {}.".format(prim.name, len(params_bdim)))
|
||||
input_x, input_x_dim = params_bdim[0]
|
||||
paddings, paddings_dim = params_bdim[1]
|
||||
|
||||
out = None
|
||||
x = _bdim_at_front(input_x, input_x_dim, axis_size)
|
||||
if paddings_dim is not None:
|
||||
_raise_value_error(
|
||||
"The source axis of `paddings` in `{}` must be None, but got {}.".format(prim.name, paddings_dim))
|
||||
pad_dim = F.shape(paddings)[0]
|
||||
x_ndim = F.rank(x)
|
||||
|
||||
if pad_dim == x_ndim and x_ndim <= input_max_dim:
|
||||
out = prim(x, paddings)
|
||||
elif x_ndim > input_max_dim:
|
||||
# reshape to 4 dims
|
||||
x_shape = F.shape(x)
|
||||
diff_dim = x_ndim - input_max_dim
|
||||
first_shape = 1
|
||||
for i in range(diff_dim + 1):
|
||||
first_shape *= x_shape[i]
|
||||
input_shape = (first_shape,) + x_shape[(-input_max_dim + 1):]
|
||||
x = F.reshape(x, input_shape)
|
||||
out = prim(x, paddings)
|
||||
out_shape = F.shape(out)
|
||||
real_out_shape = x_shape[:diff_dim + 1] + out_shape[1:]
|
||||
out = F.reshape(out, real_out_shape)
|
||||
else:
|
||||
_raise_value_error("The dim of `input_x` in `{}` must be bigger than {}, "
|
||||
"but got {}.".format(prim.name, pad_dim, x_ndim))
|
||||
return (out, 0)
|
||||
|
||||
return vmap_rule
|
||||
|
|
Loading…
Reference in New Issue