!40531 add vmaprule for MinimumGrad

Merge pull request !40531 from 吕浩宇/minimumgrad-vmap
i-robot 2022-08-26 01:47:31 +00:00, committed by Gitee
commit 9fc588394f
7 changed files with 232 additions and 25 deletions

View File

@@ -1811,7 +1811,7 @@ def get_stridedslice_vmap_rule(prim, axis_size):
_raise_value_error("vmap of `StridedSlice` not support `begin`, `end` or `strided` has batch dimension, "
"but got {}, {}, {}".format(begin_dim, end_dim, strided_dim))
-        # x_dim is not None, and other is None
+        # x_dim is not None, and the others are None
x = _bdim_at_front(x, x_dim, axis_size)
new_begin, new_end, new_strided = get_new_begin_end_strided(begin, end, strided)
result = batch_stridedslice(x, new_begin, new_end, new_strided)

View File

@@ -85,6 +85,41 @@ def _handle_broadcasting(x, x_shape, y_shape):
return F.reshape(x, broadcast_shape)
+@constexpr
+def _get_broadcasting_with_front_axis_additional_axis(x_shape, y_shape):
+    """Get the axes that are inserted after broadcasting.
+
+    Args:
+        x_shape (Tuple): The shape of x, whose batch axis is at the front.
+        y_shape (Tuple): The shape of y used for broadcasting, whose batch axis is also at the front.
+
+    Returns:
+        An empty tuple if there is no additional axis; otherwise, a tuple of the additional axes.
+    """
+    x_len = len(x_shape)
+    y_len = len(y_shape)
+    additional_axis = ()
+    if x_len >= y_len:
+        return additional_axis
+    # scalar case
+    if not x_len:
+        return tuple(range(0, y_len))
+    x_index = x_len - 1
+    for i in range(y_len - 1, 0, -1):
+        if x_index == 0:
+            additional_axis = (i,) + additional_axis
+        elif x_shape[x_index] == y_shape[i] or y_shape[i] == 1 or x_shape[x_index] == 1:
+            x_index -= 1
+        else:
+            additional_axis = (i,) + additional_axis
+    return additional_axis
@constexpr
def _raise_value_error(info, param=None):
"""Constexpr for raise_value_error."""

View File

@@ -557,8 +557,6 @@ def get_layernormgrad_vmap_rule(prim, axis_size):
x_rank = len(x_shape)
begin_params_axis += x_rank
batch_params_reduce_axes = tuple(range(1, begin_params_axis))
-    if not batch_params_reduce_axes:
-        return None
return batch_params_reduce_axes
@constexpr
@@ -590,19 +588,18 @@ def get_layernormgrad_vmap_rule(prim, axis_size):
dy_shape = F.shape(dy)
batch_params_reduce_axes = get_batch_params_reduce_axes(params_axis, dy_shape)
-        if batch_params_reduce_axes is None:
+        if not batch_params_reduce_axes:
d_beta = dy
else:
d_beta = F.reduce_sum(dy, batch_params_reduce_axes)
d_gamma_tmp = dy * (x - mean) / F.sqrt(var + eps)
-        if batch_params_reduce_axes is None:
+        if not batch_params_reduce_axes:
d_gamma = d_gamma_tmp
else:
d_gamma = F.reduce_sum(d_gamma_tmp, batch_params_reduce_axes)
gamma_shape = F.shape(gamma)
-        dy_shape = F.shape(dy)
gamma = _handle_broadcasting(gamma, gamma_shape, dy_shape)
dy = dy * gamma
gamma_logical_shape = get_logical_shape(gamma_shape)
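The removed `return None` above made the helper two-typed (None or tuple); it now always returns a tuple, and the call sites test emptiness with `not` instead of `is None`. A quick sketch (hypothetical values, not from the diff) of the two cases the falsiness check covers:

# begin_params_axis == 1: the empty tuple is falsy, so d_beta stays as dy.
batch_params_reduce_axes = tuple(range(1, 1))
assert not batch_params_reduce_axes
# begin_params_axis == 3: reduce d_beta over the leading non-param axes 1 and 2.
batch_params_reduce_axes = tuple(range(1, 3))
assert batch_params_reduce_axes == (1, 2)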

View File

@@ -25,13 +25,13 @@ from mindspore.ops.operations import math_ops
from mindspore.ops.operations import linalg_ops
from mindspore.ops.operations import _inner_ops
from mindspore.ops.operations import _grad_ops as G
-from ..primitive import Primitive
-from ..composite import _VmapGeneralRule
-from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_assign_vmap_rule, \
+from mindspore.ops.primitive import Primitive
+from mindspore.ops.composite import _VmapGeneralRule
+from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_assign_vmap_rule, \
     get_unop_vmap_rule, _raise_value_error, _bdim_at_front, _broadcast_by_axis, _handle_broadcasting, \
-    get_unary_grad_vmap_rule, _vmap_clone_prim, _bdim_at_any
-from ..operations.math_ops import (Bernoulli, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1,
-                                   BesselK1e, Median)
+    get_unary_grad_vmap_rule, _vmap_clone_prim, _bdim_at_any, _get_broadcasting_with_front_axis_additional_axis
+from mindspore.ops.operations.math_ops import Bernoulli, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselY0, \
+    BesselY1, BesselK1, BesselK1e, Median
@constexpr
@@ -885,6 +885,76 @@ def get_square_sum_all_vmap_rule(prim, axis_size):
return vmap_rule
+@vmap_rules_getters.register('MaximumGrad')
+@vmap_rules_getters.register('MinimumGrad')
+def get_broadcast_binary_op_grad_vmap_rule(prim, axis_size):
+    """VmapRule for grad of binary operations with broadcasting"""
+    broadcast_binary_op_grad_map = {
+        "MinimumGrad": G.MinimumGrad,
+        "MaximumGrad": G.MaximumGrad
+    }
+    if isinstance(prim, str):
+        prim = broadcast_binary_op_grad_map.get(prim)()
+
+    @constexpr
+    def get_longest_shape(x_shape, y_shape, g_shape):
+        x_rank = len(x_shape)
+        y_rank = len(y_shape)
+        g_rank = len(g_shape)
+        if x_rank > y_rank:
+            if x_rank > g_rank:
+                return x_shape
+        else:
+            if y_rank > g_rank:
+                return y_shape
+        return g_shape
+
+    def vmap_rule(x_bdim, y_bdim, grad_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, x_bdim, y_bdim, grad_bdim)
+        if is_all_none:
+            return result
+
+        x, x_dim = x_bdim
+        y, y_dim = y_bdim
+        g, g_dim = grad_bdim
+        x_shape = F.shape(x)
+        y_shape = F.shape(y)
+        g_shape = F.shape(g)
+        if x_dim == y_dim and x_dim == g_dim and \
+                x_shape == y_shape and x_shape == g_shape:
+            dx, dy = prim(x, y, g)
+            return (dx, x_dim), (dy, y_dim)
+
+        x = _bdim_at_front(x, x_dim, axis_size)
+        y = _bdim_at_front(y, y_dim, axis_size)
+        g = _bdim_at_front(g, g_dim, axis_size)
+        x_shape = F.shape(x)
+        y_shape = F.shape(y)
+        g_shape = F.shape(g)
+        longest_shape = get_longest_shape(x_shape, y_shape, g_shape)
+        x = _handle_broadcasting(x, x_shape, longest_shape)
+        y = _handle_broadcasting(y, y_shape, longest_shape)
+        g = _handle_broadcasting(g, g_shape, longest_shape)
+        x_axis_for_reduce = _get_broadcasting_with_front_axis_additional_axis(x_shape, longest_shape)
+        y_axis_for_reduce = _get_broadcasting_with_front_axis_additional_axis(y_shape, longest_shape)
+
+        dx, dy = prim(x, y, g)
+        if x_axis_for_reduce:
+            dx = F.reduce_sum(dx, x_axis_for_reduce)
+        if y_axis_for_reduce:
+            dy = F.reduce_sum(dy, y_axis_for_reduce)
+        return (dx, 0), (dy, 0)
+
+    return vmap_rule
get_assign_vmap_rule = vmap_rules_getters.register(P.AssignAdd)(get_assign_vmap_rule)
get_assign_vmap_rule = vmap_rules_getters.register(P.AssignSub)(get_assign_vmap_rule)
# Unary vmap
@@ -942,6 +1012,8 @@ get_unop_vmap_rule = vmap_rules_getters.register(BesselK1)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(BesselK1e)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.Trunc)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.PopulationCount)(get_unop_vmap_rule)
+get_unop_vmap_rule = vmap_rules_getters.register(P.Square)(get_unop_vmap_rule)
# UnaryGrad vmap
get_unary_grad_vmap_rule = vmap_rules_getters.register(G.InvGrad)(get_unary_grad_vmap_rule)
get_unary_grad_vmap_rule = vmap_rules_getters.register(G.LogitGrad)(get_unary_grad_vmap_rule)
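A minimal usage sketch of the MinimumGrad rule registered above (editorial, not part of this diff). It reuses the GradWrap/MinNetMe wrappers defined in the test file at the end of this PR; the shapes are chosen so that y is broadcast inside the batch and its gradient has to be summed back to the unbatched shape:

import numpy as np
from mindspore import Tensor
from mindspore.ops.functional import vmap

# net returns d(min(x, y))/dx and d(min(x, y))/dy for a given sens gradient.
net = GradWrap(MinNetMe())
batched = vmap(net, in_axes=(0, None, 0))
x = Tensor(np.random.rand(3, 2, 4).astype(np.float32))    # batched along axis 0
y = Tensor(np.random.rand(4).astype(np.float32))          # shared across the batch
sens = Tensor(np.random.rand(3, 2, 4).astype(np.float32))
dx, dy = batched(x, y, sens)
# dx keeps x's batched shape; dy is reduce-summed over the broadcast axis.
assert dx.shape == (3, 2, 4)
assert dy.shape == (3, 4)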

View File

@@ -1675,7 +1675,7 @@ def get_layernorm_vmap_rule(prim, axis_size):
norm_axis = process_attr_axis(prim.begin_norm_axis)
params_axis = process_attr_axis(prim.begin_params_axis)
-    batch_prim = P.LayerNorm(norm_axis, params_axis)
+    batch_prim = P.LayerNorm(norm_axis, params_axis, prim.epsilon)
@constexpr
def get_logical_shape(var_shape):
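The one-line change above forwards the original primitive's epsilon when rebuilding the batched LayerNorm; previously the rebuilt primitive silently fell back to the constructor default. A sketch of the difference (assuming the default epsilon of 1e-7 from the P.LayerNorm signature):

from mindspore.ops import operations as P

prim = P.LayerNorm(begin_norm_axis=1, begin_params_axis=1, epsilon=1e-3)
stale = P.LayerNorm(1, 1)                 # old behavior: epsilon resets to the default
fixed = P.LayerNorm(1, 1, prim.epsilon)   # fixed: epsilon follows the original prim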

View File

@@ -14,7 +14,7 @@
# ============================================================================
"""Vmap operations."""
-from ..._c_expression import VmapGeneralPreprocess_, VmapGeneralRulePyAdapter_
+from mindspore._c_expression import VmapGeneralPreprocess_, VmapGeneralRulePyAdapter_
class _VmapGeneralPreprocess(VmapGeneralPreprocess_):

View File

@@ -20,6 +20,7 @@ from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.ops import composite as C
from mindspore.ops.operations import Minimum
+from mindspore.ops.functional import vmap
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
grad = C.GradOperation(get_all=True, sens_param=True)
@@ -45,21 +46,21 @@ class GradWrap(Cell):
return gout
-def gen_data(inputA_np, inputB_np, grad_=None):
-    inputA_me = inputA_np
-    if isinstance(inputA_np, np.ndarray):
-        inputA_me = Tensor(inputA_me)
+def gen_data(x_np, y_np, grad_=None):
+    x_me = x_np
+    if isinstance(x_np, np.ndarray):
+        x_me = Tensor(x_me)
-    inputB_me = inputB_np
-    if isinstance(inputB_np, np.ndarray):
-        inputB_me = Tensor(inputB_np)
+    y_me = y_np
+    if isinstance(y_np, np.ndarray):
+        y_me = Tensor(y_np)
if grad_ is None:
grad_ = Tensor(grad_)
net_me = GradWrap(MinNetMe())
net_me.set_train()
-    output = net_me(inputA_me, inputB_me, Tensor(grad_))
+    output = net_me(x_me, y_me, Tensor(grad_))
return output
@@ -67,10 +68,10 @@ def gen_data(inputA_np, inputB_np, grad_=None):
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_min_tensor_grad_4d():
-    inputA_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
-    inputB_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
+    x_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
+    y_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
grad_ = np.random.randn(1, 3, 2, 2).astype(np.float32)
-    output = gen_data(inputA_np, inputB_np, grad_)
+    output = gen_data(x_np, y_np, grad_)
print(output[0].asnumpy())
print(output[1].asnumpy())
@@ -169,3 +170,105 @@ def test_min_tensor_grad_result():
error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(np.abs(output[0].asnumpy() - expect0) < error0)
assert np.all(np.abs(output[1].asnumpy() - expect1) < error1)
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_min_grad_vmap():
+    """
+    Feature: minimumgrad vmap
+    Description: test the minimumgrad vmap when in_axes=(0, 0, 0).
+    Expectation: match to np benchmark.
+    """
+    net = GradWrap(MinNetMe())
+    vmap_net = vmap(net, in_axes=(0, 0, 0))
+    np.random.seed(20)
+    x = Tensor(np.random.rand(2, 4))
+    y = Tensor(np.random.rand(2, 4))
+    sens = Tensor(np.random.rand(2, 4))
+    dx, dy = vmap_net(x, y, sens)
+    expect0 = np.array([[0.11669374, 0., 0., 0.],
+                        [0.85762554, 0.94977903, 0.56168687, 0.]])
+    expect1 = np.array([[0., 0.75128073, 0.23921822, 0.25480601],
+                        [0., 0., 0., 0.17878053]])
+    assert np.allclose(dx.asnumpy(), expect0, rtol=1e-6, atol=1e-4)
+    assert np.allclose(dy.asnumpy(), expect1, rtol=1e-6, atol=1e-4)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_min_grad_vmap_none():
+    """
+    Feature: minimumgrad vmap
+    Description: test the minimumgrad vmap when in_axes=(0, None, None).
+    Expectation: match to np benchmark.
+    """
+    net = GradWrap(MinNetMe())
+    vmap_net = vmap(net, in_axes=(0, None, None))
+    np.random.seed(20)
+    x = Tensor(np.random.rand(2, 4))
+    y = Tensor(np.random.rand(4))
+    sens = Tensor(np.random.rand(4))
+    dx, dy = vmap_net(x, y, sens)
+    expect0 = np.array([[0.78300363, 0., 0., 0.],
+                        [0.78300363, 0., 0., 0.03666431]])
+    expect1 = np.array([[0., 0.85032761, 0.77524489, 0.03666431],
+                        [0., 0.85032761, 0.77524489, 0.]])
+    assert np.allclose(dx.asnumpy(), expect0, rtol=1e-6, atol=1e-4)
+    assert np.allclose(dy.asnumpy(), expect1, rtol=1e-6, atol=1e-4)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_min_grad_vmap_scalar():
+    """
+    Feature: minimumgrad vmap
+    Description: test the minimumgrad vmap when x and y are scalars.
+    Expectation: match to np benchmark.
+    """
+    net = GradWrap(MinNetMe())
+    vmap_net = vmap(net, in_axes=(None, None, 0))
+    np.random.seed(20)
+    x = Tensor(1.0)
+    y = Tensor(2.0)
+    sens = Tensor(np.random.rand(2).astype(np.float32))
+    dx, dy = vmap_net(x, y, sens)
+    expect0 = np.array([0.5881308, 0.8977137])
+    expect1 = np.array([0., 0.])
+    assert np.allclose(dx.asnumpy(), expect0, rtol=1e-6, atol=1e-4)
+    assert np.allclose(dy.asnumpy(), expect1, rtol=1e-6, atol=1e-4)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_min_grad_vmap_special():
+    """
+    Feature: minimumgrad vmap
+    Description: test the minimumgrad vmap when x_rank is smaller than y_rank and y_rank is smaller than g_rank.
+    Expectation: match to np benchmark.
+    """
+    net = GradWrap(MinNetMe())
+    vmap_net = vmap(net, in_axes=(None, None, 0))
+    np.random.seed(20)
+    x = Tensor(np.random.rand(1).astype(np.float32))
+    y = Tensor(np.random.rand(1, 2).astype(np.float32))
+    sens = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
+    dx, dy = vmap_net(x, y, sens)
+    expect0 = np.array([[0.85172707], [1.0704385]])
+    expect1 = np.array([[[0., 0.]], [[0., 0.]]])
+    assert np.allclose(dx.asnumpy(), expect0, rtol=1e-6, atol=1e-4)
+    assert np.allclose(dy.asnumpy(), expect1, rtol=1e-6, atol=1e-4)