diff --git a/mindspore/core/ops/clip_by_norm_no_div_dum.cc b/mindspore/core/ops/clip_by_norm_no_div_sum.cc similarity index 100% rename from mindspore/core/ops/clip_by_norm_no_div_dum.cc rename to mindspore/core/ops/clip_by_norm_no_div_sum.cc diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py index 29911c22779..891a90dd948 100644 --- a/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py +++ b/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py @@ -494,3 +494,48 @@ def get_instance_norm_grad_rule(prim, axis_size): return (output_x, 0), (updated_moving_mean, 0), (updated_moving_variance, 0) return vmap_rule + + +@vmap_rules_getters.register(G.MirrorPadGrad) +def get_mirror_pad_grad_grad_vmap_rule(prim, axis_size): + """VmapRule for `MirrorPadGrad` operation.""" + input_max_dim = 4 + + def vmap_rule(*params_bdim): + is_all_none, result = vmap_general_preprocess(prim, params_bdim) + if is_all_none: + return result + if len(params_bdim) < 2: + _raise_value_error("The input params in `{}` must >= 2, but got {}.".format(prim.name, len(params_bdim))) + input_x, input_x_dim = params_bdim[0] + paddings, paddings_dim = params_bdim[1] + + out = None + x = _bdim_at_front(input_x, input_x_dim, axis_size) + if paddings_dim is not None: + _raise_value_error( + "The source axis of `paddings` in `{}` must be None, but got {}.".format(prim.name, paddings_dim)) + pad_dim = F.shape(paddings)[0] + x_ndim = F.rank(x) + + if pad_dim == x_ndim and x_ndim <= input_max_dim: + out = prim(x, paddings) + elif x_ndim > input_max_dim: + # reshape to 4 dims + x_shape = F.shape(x) + diff_dim = x_ndim - input_max_dim + first_shape = 1 + for i in range(diff_dim + 1): + first_shape *= x_shape[i] + input_shape = (first_shape,) + x_shape[(-input_max_dim + 1):] + x = F.reshape(x, input_shape) + out = prim(x, paddings) + out_shape = F.shape(out) + real_out_shape = x_shape[:diff_dim + 1] + out_shape[1:] + out = F.reshape(out, real_out_shape) + else: + _raise_value_error("The dim of `input_x` in `{}` must be bigger than {}, " + "but got {}.".format(prim.name, pad_dim, x_ndim)) + return (out, 0) + + return vmap_rule