!40143 [MS][LITE][parallel predict]mirror pad grad vmap

Merge pull request !40143 from yefeng/op_vmap_mirror_pad_grad
This commit is contained in:
i-robot 2022-08-10 07:10:51 +00:00 committed by Gitee
commit 1f6eab5596
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 45 additions and 0 deletions

View File

@ -494,3 +494,48 @@ def get_instance_norm_grad_rule(prim, axis_size):
return (output_x, 0), (updated_moving_mean, 0), (updated_moving_variance, 0)
return vmap_rule
@vmap_rules_getters.register(G.MirrorPadGrad)
def get_mirror_pad_grad_grad_vmap_rule(prim, axis_size):
"""VmapRule for `MirrorPadGrad` operation."""
input_max_dim = 4
def vmap_rule(*params_bdim):
is_all_none, result = vmap_general_preprocess(prim, params_bdim)
if is_all_none:
return result
if len(params_bdim) < 2:
_raise_value_error("The input params in `{}` must >= 2, but got {}.".format(prim.name, len(params_bdim)))
input_x, input_x_dim = params_bdim[0]
paddings, paddings_dim = params_bdim[1]
out = None
x = _bdim_at_front(input_x, input_x_dim, axis_size)
if paddings_dim is not None:
_raise_value_error(
"The source axis of `paddings` in `{}` must be None, but got {}.".format(prim.name, paddings_dim))
pad_dim = F.shape(paddings)[0]
x_ndim = F.rank(x)
if pad_dim == x_ndim and x_ndim <= input_max_dim:
out = prim(x, paddings)
elif x_ndim > input_max_dim:
# reshape to 4 dims
x_shape = F.shape(x)
diff_dim = x_ndim - input_max_dim
first_shape = 1
for i in range(diff_dim + 1):
first_shape *= x_shape[i]
input_shape = (first_shape,) + x_shape[(-input_max_dim + 1):]
x = F.reshape(x, input_shape)
out = prim(x, paddings)
out_shape = F.shape(out)
real_out_shape = x_shape[:diff_dim + 1] + out_shape[1:]
out = F.reshape(out, real_out_shape)
else:
_raise_value_error("The dim of `input_x` in `{}` must be bigger than {}, "
"but got {}.".format(prim.name, pad_dim, x_ndim))
return (out, 0)
return vmap_rule