From 1867baa88e9e6e29fc55e1596609b6d65f9f2026 Mon Sep 17 00:00:00 2001
From: lvhaoyu
Date: Tue, 16 Aug 2022 17:22:16 +0800
Subject: [PATCH] add vmap rule for MinimumGrad

---
 .../mindspore/ops/_vmap/vmap_array_ops.py     |   2 +-
 .../python/mindspore/ops/_vmap/vmap_base.py   |  35 +++++
 .../mindspore/ops/_vmap/vmap_grad_nn_ops.py   |   7 +-
 .../mindspore/ops/_vmap/vmap_math_ops.py      |  84 +++++++++++-
 .../python/mindspore/ops/_vmap/vmap_nn_ops.py |   2 +-
 .../mindspore/ops/composite/vmap_ops.py       |   2 +-
 tests/st/ops/cpu/test_minimum_grad_op.py      | 125 ++++++++++++++++--
 7 files changed, 232 insertions(+), 25 deletions(-)

diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py
index 3ba01f378c6..83c6173115b 100644
--- a/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py
+++ b/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py
@@ -1804,7 +1804,7 @@ def get_stridedslice_vmap_rule(prim, axis_size):
             _raise_value_error("vmap of `StridedSlice` not support `begin`, `end` or `strided` has batch dimension, "
                                "but got {}, {}, {}".format(begin_dim, end_dim, strided_dim))
 
-        # x_dim is not None, and other is None
+        # x_dim is not None, and the others are None
         x = _bdim_at_front(x, x_dim, axis_size)
         new_begin, new_end, new_strided = get_new_begin_end_strided(begin, end, strided)
         result = batch_stridedslice(x, new_begin, new_end, new_strided)
diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_base.py b/mindspore/python/mindspore/ops/_vmap/vmap_base.py
index 75756694e84..5440ab52bea 100644
--- a/mindspore/python/mindspore/ops/_vmap/vmap_base.py
+++ b/mindspore/python/mindspore/ops/_vmap/vmap_base.py
@@ -85,6 +85,41 @@ def _handle_broadcasting(x, x_shape, y_shape):
     return F.reshape(x, broadcast_shape)
 
 
+@constexpr
+def _get_broadcasting_with_front_axis_additional_axis(x_shape, y_shape):
+    """Get the axes that are inserted after broadcasting.
+    Args:
+        x_shape (Tuple): The shape of x, whose batch axis is at the front.
+        y_shape (Tuple): The shape of y used for broadcasting, whose batch axis is also at the front.
+
+    Returns:
+        An empty tuple if there is no additional axis; otherwise, a tuple containing the additional axes.
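+
+        Example (illustrative shapes, not from the original patch; both inputs
+        already have their batch axis at the front):
+            _get_broadcasting_with_front_axis_additional_axis((2, 4), (2, 3, 4))  # -> (1,)
+            _get_broadcasting_with_front_axis_additional_axis((), (2, 3))  # -> (0, 1)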
+ """ + x_len = len(x_shape) + y_len = len(y_shape) + + additional_axis = () + + if x_len >= y_len: + return additional_axis + + # scalar case + if not x_len: + return tuple(range(0, y_len)) + + x_index = x_len - 1 + + for i in range(y_len - 1, 0, -1): + if x_index == 0: + additional_axis = (i,) + additional_axis + elif x_shape[x_index] == y_shape[i] or y_shape[i] == 1 or x_shape[x_index] == 1: + x_index -= 1 + else: + additional_axis = (i,) + additional_axis + + return additional_axis + + @constexpr def _raise_value_error(info, param=None): """Constexpr for raise_value_error.""" diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py index a974c8e7718..49aaa300b05 100644 --- a/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py +++ b/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py @@ -557,8 +557,6 @@ def get_layernormgrad_vmap_rule(prim, axis_size): x_rank = len(x_shape) begin_params_axis += x_rank batch_params_reduce_axes = tuple(range(1, begin_params_axis)) - if not batch_params_reduce_axes: - return None return batch_params_reduce_axes @constexpr @@ -590,19 +588,18 @@ def get_layernormgrad_vmap_rule(prim, axis_size): dy_shape = F.shape(dy) batch_params_reduce_axes = get_batch_params_reduce_axes(params_axis, dy_shape) - if batch_params_reduce_axes is None: + if not batch_params_reduce_axes: d_beta = dy else: d_beta = F.reduce_sum(dy, batch_params_reduce_axes) d_gamma_tmp = dy * (x - mean) / F.sqrt(var + eps) - if batch_params_reduce_axes is None: + if not batch_params_reduce_axes: d_gamma = d_gamma_tmp else: d_gamma = F.reduce_sum(d_gamma_tmp, batch_params_reduce_axes) gamma_shape = F.shape(gamma) - dy_shape = F.shape(dy) gamma = _handle_broadcasting(gamma, gamma_shape, dy_shape) dy = dy * gamma gamma_logical_shape = get_logical_shape(gamma_shape) diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py index 0115762a4ad..9d8d23d216e 100644 --- a/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py +++ b/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py @@ -25,13 +25,13 @@ from mindspore.ops.operations import math_ops from mindspore.ops.operations import linalg_ops from mindspore.ops.operations import _inner_ops from mindspore.ops.operations import _grad_ops as G -from ..primitive import Primitive -from ..composite import _VmapGeneralRule -from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_assign_vmap_rule, \ +from mindspore.ops.primitive import Primitive +from mindspore.ops.composite import _VmapGeneralRule +from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_assign_vmap_rule, \ get_unop_vmap_rule, _raise_value_error, _bdim_at_front, _broadcast_by_axis, _handle_broadcasting, \ - get_unary_grad_vmap_rule, _vmap_clone_prim, _bdim_at_any -from ..operations.math_ops import (Bernoulli, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1, - BesselK1e, Median) + get_unary_grad_vmap_rule, _vmap_clone_prim, _bdim_at_any, _get_broadcasting_with_front_axis_additional_axis +from mindspore.ops.operations.math_ops import Bernoulli, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselY0, \ + BesselY1, BesselK1, BesselK1e, Median @constexpr @@ -885,6 +885,76 @@ def get_square_sum_all_vmap_rule(prim, axis_size): return vmap_rule +@vmap_rules_getters.register('MaximumGrad') +@vmap_rules_getters.register('MinimumGrad') +def get_broadcast_binary_op_grad_vmap_rule(prim, 
axis_size):
+    """VmapRule for the grad of binary operations with broadcasting."""
+    broadcast_binary_op_grad_map = {
+        "MinimumGrad": G.MinimumGrad,
+        "MaximumGrad": G.MaximumGrad
+    }
+
+    if isinstance(prim, str):
+        prim = broadcast_binary_op_grad_map.get(prim)()
+
+    @constexpr
+    def get_longest_shape(x_shape, y_shape, g_shape):
+        """Return the shape with the largest rank among x, y and the gradient."""
+        x_rank = len(x_shape)
+        y_rank = len(y_shape)
+        g_rank = len(g_shape)
+        if x_rank > y_rank:
+            if x_rank > g_rank:
+                return x_shape
+        else:
+            if y_rank > g_rank:
+                return y_shape
+        return g_shape
+
+    def vmap_rule(x_bdim, y_bdim, grad_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, x_bdim, y_bdim, grad_bdim)
+        if is_all_none:
+            return result
+
+        x, x_dim = x_bdim
+        y, y_dim = y_bdim
+        g, g_dim = grad_bdim
+
+        x_shape = F.shape(x)
+        y_shape = F.shape(y)
+        g_shape = F.shape(g)
+
+        # Fast path: all inputs share the same batch axis and shape, so no broadcasting is needed.
+        if x_dim == y_dim and x_dim == g_dim and \
+                x_shape == y_shape and x_shape == g_shape:
+            dx, dy = prim(x, y, g)
+            return (dx, x_dim), (dy, y_dim)
+
+        x = _bdim_at_front(x, x_dim, axis_size)
+        y = _bdim_at_front(y, y_dim, axis_size)
+        g = _bdim_at_front(g, g_dim, axis_size)
+
+        x_shape = F.shape(x)
+        y_shape = F.shape(y)
+        g_shape = F.shape(g)
+
+        longest_shape = get_longest_shape(x_shape, y_shape, g_shape)
+        x = _handle_broadcasting(x, x_shape, longest_shape)
+        y = _handle_broadcasting(y, y_shape, longest_shape)
+        g = _handle_broadcasting(g, g_shape, longest_shape)
+
+        x_axis_for_reduce = _get_broadcasting_with_front_axis_additional_axis(x_shape, longest_shape)
+        y_axis_for_reduce = _get_broadcasting_with_front_axis_additional_axis(y_shape, longest_shape)
+
+        dx, dy = prim(x, y, g)
+        # Sum the gradients over the axes inserted by broadcasting, so that dx and dy
+        # keep the shapes of the original (batched) inputs.
+        if x_axis_for_reduce:
+            dx = F.reduce_sum(dx, x_axis_for_reduce)
+
+        if y_axis_for_reduce:
+            dy = F.reduce_sum(dy, y_axis_for_reduce)
+
+        return (dx, 0), (dy, 0)
+    return vmap_rule
+
+
 get_assign_vmap_rule = vmap_rules_getters.register(P.AssignAdd)(get_assign_vmap_rule)
 get_assign_vmap_rule = vmap_rules_getters.register(P.AssignSub)(get_assign_vmap_rule)
 # Unary vmap
@@ -942,6 +1012,8 @@ get_unop_vmap_rule = vmap_rules_getters.register(BesselK1)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(BesselK1e)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.Trunc)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.PopulationCount)(get_unop_vmap_rule)
+get_unop_vmap_rule = vmap_rules_getters.register(P.Square)(get_unop_vmap_rule)
+
 # UnaryGrad vmap
 get_unary_grad_vmap_rule = vmap_rules_getters.register(G.InvGrad)(get_unary_grad_vmap_rule)
 get_unary_grad_vmap_rule = vmap_rules_getters.register(G.LogitGrad)(get_unary_grad_vmap_rule)
diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py
index 7a55f58f937..b514b1f0445 100644
--- a/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py
+++ b/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py
@@ -1595,7 +1595,7 @@ def get_layernorm_vmap_rule(prim, axis_size):
     norm_axis = process_attr_axis(prim.begin_norm_axis)
     params_axis = process_attr_axis(prim.begin_params_axis)
-    batch_prim = P.LayerNorm(norm_axis, params_axis)
+    batch_prim = P.LayerNorm(norm_axis, params_axis, prim.epsilon)
 
     @constexpr
     def get_logical_shape(var_shape):
diff --git a/mindspore/python/mindspore/ops/composite/vmap_ops.py b/mindspore/python/mindspore/ops/composite/vmap_ops.py
index 1fb80d79d89..176e4afc06f 100644
--- a/mindspore/python/mindspore/ops/composite/vmap_ops.py
+++ b/mindspore/python/mindspore/ops/composite/vmap_ops.py
@@ -14,7 +14,7 @@
 # 
============================================================================
 """Vmap operations."""
-from ..._c_expression import VmapGeneralPreprocess_, VmapGeneralRulePyAdapter_
+from mindspore._c_expression import VmapGeneralPreprocess_, VmapGeneralRulePyAdapter_
 
 
 class _VmapGeneralPreprocess(VmapGeneralPreprocess_):
diff --git a/tests/st/ops/cpu/test_minimum_grad_op.py b/tests/st/ops/cpu/test_minimum_grad_op.py
index d4731046f9f..0565740eb3b 100644
--- a/tests/st/ops/cpu/test_minimum_grad_op.py
+++ b/tests/st/ops/cpu/test_minimum_grad_op.py
@@ -20,6 +20,7 @@ from mindspore import Tensor
 from mindspore.nn import Cell
 from mindspore.ops import composite as C
 from mindspore.ops.operations import Minimum
+from mindspore.ops.functional import vmap
 
 context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
 grad = C.GradOperation(get_all=True, sens_param=True)
@@ -45,21 +46,21 @@ class GradWrap(Cell):
         return gout
 
 
-def gen_data(inputA_np, inputB_np, grad_=None):
-    inputA_me = inputA_np
-    if isinstance(inputA_np, np.ndarray):
-        inputA_me = Tensor(inputA_me)
+def gen_data(x_np, y_np, grad_=None):
+    x_me = x_np
+    if isinstance(x_np, np.ndarray):
+        x_me = Tensor(x_me)
 
-    inputB_me = inputB_np
-    if isinstance(inputB_np, np.ndarray):
-        inputB_me = Tensor(inputB_np)
+    y_me = y_np
+    if isinstance(y_np, np.ndarray):
+        y_me = Tensor(y_np)
 
     if grad_ is None:
         grad_ = Tensor(grad_)
 
     net_me = GradWrap(MinNetMe())
     net_me.set_train()
-    output = net_me(inputA_me, inputB_me, Tensor(grad_))
+    output = net_me(x_me, y_me, Tensor(grad_))
     return output
 
 
@@ -67,10 +68,10 @@
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
 def test_min_tensor_grad_4d():
-    inputA_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
-    inputB_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
+    x_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
+    y_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
     grad_ = np.random.randn(1, 3, 2, 2).astype(np.float32)
-    output = gen_data(inputA_np, inputB_np, grad_)
+    output = gen_data(x_np, y_np, grad_)
     print(output[0].asnumpy())
     print(output[1].asnumpy())
@@ -169,3 +170,105 @@ def test_min_tensor_grad_result():
     error1 = np.ones(shape=expect1.shape) * 1.0e-5
     assert np.all(np.abs(output[0].asnumpy() - expect0) < error0)
     assert np.all(np.abs(output[1].asnumpy() - expect1) < error1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_min_grad_vmap():
+    """
+    Feature: MinimumGrad vmap rule
+    Description: test the MinimumGrad vmap rule with in_axes=(0, 0, 0).
+    Expectation: match to np benchmark.
+    """
+    net = GradWrap(MinNetMe())
+    vmap_net = vmap(net, in_axes=(0, 0, 0))
+    np.random.seed(20)
+    x = Tensor(np.random.rand(2, 4))
+    y = Tensor(np.random.rand(2, 4))
+    sens = Tensor(np.random.rand(2, 4))
+    dx, dy = vmap_net(x, y, sens)
+
+    expect0 = np.array([[0.11669374, 0., 0., 0.],
+                        [0.85762554, 0.94977903, 0.56168687, 0.]])
+
+    expect1 = np.array([[0., 0.75128073, 0.23921822, 0.25480601],
+                        [0., 0., 0., 0.17878053]])
+
+    assert np.allclose(dx.asnumpy(), expect0, rtol=1e-6, atol=1e-4)
+    assert np.allclose(dy.asnumpy(), expect1, rtol=1e-6, atol=1e-4)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_min_grad_vmap_none():
+    """
+    Feature: MinimumGrad vmap rule
+    Description: test the MinimumGrad vmap rule with in_axes=(0, None, None).
+    Expectation: match to np benchmark.
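+    Note: assuming MinimumGrad routes tied elements to the first input, each
+    batch row b computes dx[b] = sens * (x[b] <= y) and dy[b] = sens * (x[b] > y);
+    y and sens are shared across the batch because their in_axes entries are None.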
+    """
+    net = GradWrap(MinNetMe())
+    vmap_net = vmap(net, in_axes=(0, None, None))
+    np.random.seed(20)
+    x = Tensor(np.random.rand(2, 4))
+    y = Tensor(np.random.rand(4))
+    sens = Tensor(np.random.rand(4))
+    dx, dy = vmap_net(x, y, sens)
+
+    expect0 = np.array([[0.78300363, 0., 0., 0.],
+                        [0.78300363, 0., 0., 0.03666431]])
+
+    expect1 = np.array([[0., 0.85032761, 0.77524489, 0.03666431],
+                        [0., 0.85032761, 0.77524489, 0.]])
+
+    assert np.allclose(dx.asnumpy(), expect0, rtol=1e-6, atol=1e-4)
+    assert np.allclose(dy.asnumpy(), expect1, rtol=1e-6, atol=1e-4)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_min_grad_vmap_scalar():
+    """
+    Feature: MinimumGrad vmap rule
+    Description: test the MinimumGrad vmap rule when x and y are scalars.
+    Expectation: match to np benchmark.
+    """
+    net = GradWrap(MinNetMe())
+    vmap_net = vmap(net, in_axes=(None, None, 0))
+    np.random.seed(20)
+    x = Tensor(1.0)
+    y = Tensor(2.0)
+    sens = Tensor(np.random.rand(2).astype(np.float32))
+    dx, dy = vmap_net(x, y, sens)
+
+    expect0 = np.array([0.5881308, 0.8977137])
+    expect1 = np.array([0., 0.])
+    assert np.allclose(dx.asnumpy(), expect0, rtol=1e-6, atol=1e-4)
+    assert np.allclose(dy.asnumpy(), expect1, rtol=1e-6, atol=1e-4)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_min_grad_vmap_special():
+    """
+    Feature: MinimumGrad vmap rule
+    Description: test the MinimumGrad vmap rule when x_rank < y_rank < g_rank.
+    Expectation: match to np benchmark.
+    """
+    net = GradWrap(MinNetMe())
+    vmap_net = vmap(net, in_axes=(None, None, 0))
+    np.random.seed(20)
+    x = Tensor(np.random.rand(1).astype(np.float32))
+    y = Tensor(np.random.rand(1, 2).astype(np.float32))
+    sens = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
+    dx, dy = vmap_net(x, y, sens)
+
+    expect0 = np.array([[0.85172707], [1.0704385]])
+    expect1 = np.array([[[0., 0.]], [[0., 0.]]])
+    assert np.allclose(dx.asnumpy(), expect0, rtol=1e-6, atol=1e-4)
+    assert np.allclose(dy.asnumpy(), expect1, rtol=1e-6, atol=1e-4)