forked from mindspore-Ecosystem/mindspore
!40531 add vmaprule for MinimumGrad
Merge pull request !40531 from 吕浩宇/minimumgrad-vmap
commit 9fc588394f

@@ -1811,7 +1811,7 @@ def get_stridedslice_vmap_rule(prim, axis_size):
             _raise_value_error("vmap of `StridedSlice` not support `begin`, `end` or `strided` has batch dimension, "
                                "but got {}, {}, {}".format(begin_dim, end_dim, strided_dim))
 
-        # x_dim is not None, and other is None
+        # x_dim is not None, and the others are None
         x = _bdim_at_front(x, x_dim, axis_size)
         new_begin, new_end, new_strided = get_new_begin_end_strided(begin, end, strided)
         result = batch_stridedslice(x, new_begin, new_end, new_strided)

@@ -85,6 +85,41 @@ def _handle_broadcasting(x, x_shape, y_shape):
     return F.reshape(x, broadcast_shape)
 
 
+@constexpr
+def _get_broadcasting_with_front_axis_additional_axis(x_shape, y_shape):
+    """ Get the axes that are inserted after broadcasting.
+    Args:
+        x_shape (Tuple): The shape of x.
+        y_shape (Tuple): The shape of y used for broadcasting, whose batch axis is also at the front.
+
+    Returns:
+        An empty tuple if there is no additional axis; otherwise, a tuple of the additional axes.
+    """
+    x_len = len(x_shape)
+    y_len = len(y_shape)
+
+    additional_axis = ()
+
+    if x_len >= y_len:
+        return additional_axis
+
+    # scalar case
+    if not x_len:
+        return tuple(range(0, y_len))
+
+    x_index = x_len - 1
+
+    for i in range(y_len - 1, 0, -1):
+        if x_index == 0:
+            additional_axis = (i,) + additional_axis
+        elif x_shape[x_index] == y_shape[i] or y_shape[i] == 1 or x_shape[x_index] == 1:
+            x_index -= 1
+        else:
+            additional_axis = (i,) + additional_axis
+
+    return additional_axis
+
+
 @constexpr
 def _raise_value_error(info, param=None):
     """Constexpr for raise_value_error."""

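For readers checking the new helper's reduction logic, here is a minimal standalone sketch in plain Python (not the MindSpore internal itself; the example shapes are assumptions chosen for illustration) that mirrors the loop above and shows which axes it reports:

# Standalone sketch (plain Python, not the MindSpore helper itself): mirrors the
# logic of _get_broadcasting_with_front_axis_additional_axis from the hunk above.
def additional_axes(x_shape, y_shape):
    x_len, y_len = len(x_shape), len(y_shape)
    if x_len >= y_len:
        return ()
    if not x_len:                       # scalar: every axis of y is an inserted axis
        return tuple(range(y_len))
    axes = ()
    x_index = x_len - 1
    # Walk y's axes from the right; any axis that x cannot account for (past x's
    # leading batch axis) must have been inserted by broadcasting.
    for i in range(y_len - 1, 0, -1):
        if x_index == 0:
            axes = (i,) + axes
        elif x_shape[x_index] == y_shape[i] or y_shape[i] == 1 or x_shape[x_index] == 1:
            x_index -= 1
        else:
            axes = (i,) + axes
    return axes

# Assumed shapes: a batched (8, 4) broadcast against (8, 2, 4) inserts axis 1.
assert additional_axes((8, 4), (8, 2, 4)) == (1,)
# A batched scalar (8,) against (8, 2, 4): axes 1 and 2 are new.
assert additional_axes((8,), (8, 2, 4)) == (1, 2)

These are exactly the axes that the MinimumGrad/MaximumGrad vmap rule further below passes to F.reduce_sum so that dx and dy shrink back to the batched shapes of x and y.
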
@@ -557,8 +557,6 @@ def get_layernormgrad_vmap_rule(prim, axis_size):
         x_rank = len(x_shape)
         begin_params_axis += x_rank
         batch_params_reduce_axes = tuple(range(1, begin_params_axis))
-        if not batch_params_reduce_axes:
-            return None
         return batch_params_reduce_axes
 
     @constexpr

@@ -590,19 +588,18 @@ def get_layernormgrad_vmap_rule(prim, axis_size):
         dy_shape = F.shape(dy)
         batch_params_reduce_axes = get_batch_params_reduce_axes(params_axis, dy_shape)
 
-        if batch_params_reduce_axes is None:
+        if not batch_params_reduce_axes:
             d_beta = dy
         else:
             d_beta = F.reduce_sum(dy, batch_params_reduce_axes)
 
         d_gamma_tmp = dy * (x - mean) / F.sqrt(var + eps)
-        if batch_params_reduce_axes is None:
+        if not batch_params_reduce_axes:
             d_gamma = d_gamma_tmp
         else:
             d_gamma = F.reduce_sum(d_gamma_tmp, batch_params_reduce_axes)
 
         gamma_shape = F.shape(gamma)
-        dy_shape = F.shape(dy)
         gamma = _handle_broadcasting(gamma, gamma_shape, dy_shape)
         dy = dy * gamma
         gamma_logical_shape = get_logical_shape(gamma_shape)

@@ -25,13 +25,13 @@ from mindspore.ops.operations import math_ops
 from mindspore.ops.operations import linalg_ops
 from mindspore.ops.operations import _inner_ops
 from mindspore.ops.operations import _grad_ops as G
-from ..primitive import Primitive
-from ..composite import _VmapGeneralRule
-from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_assign_vmap_rule, \
+from mindspore.ops.primitive import Primitive
+from mindspore.ops.composite import _VmapGeneralRule
+from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_assign_vmap_rule, \
     get_unop_vmap_rule, _raise_value_error, _bdim_at_front, _broadcast_by_axis, _handle_broadcasting, \
-    get_unary_grad_vmap_rule, _vmap_clone_prim, _bdim_at_any
-from ..operations.math_ops import (Bernoulli, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1,
-                                   BesselK1e, Median)
+    get_unary_grad_vmap_rule, _vmap_clone_prim, _bdim_at_any, _get_broadcasting_with_front_axis_additional_axis
+from mindspore.ops.operations.math_ops import Bernoulli, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselY0, \
+    BesselY1, BesselK1, BesselK1e, Median
 
 
 @constexpr

@@ -885,6 +885,76 @@ def get_square_sum_all_vmap_rule(prim, axis_size):
     return vmap_rule
 
 
+@vmap_rules_getters.register('MaximumGrad')
+@vmap_rules_getters.register('MinimumGrad')
+def get_broadcast_binary_op_grad_vmap_rule(prim, axis_size):
+    """VmapRule for grad of binary operations with broadcasting"""
+    broadcast_binary_op_grad_map = {
+        "MinimumGrad": G.MinimumGrad,
+        "MaximumGrad": G.MaximumGrad
+    }
+
+    if isinstance(prim, str):
+        prim = broadcast_binary_op_grad_map.get(prim)()
+
+    @constexpr
+    def get_longest_shape(x_shape, y_shape, g_shape):
+        x_rank = len(x_shape)
+        y_rank = len(y_shape)
+        g_rank = len(g_shape)
+        if x_rank > y_rank:
+            if x_rank > g_rank:
+                return x_shape
+        else:
+            if y_rank > g_rank:
+                return y_shape
+        return g_shape
+
+    def vmap_rule(x_bdim, y_bdim, grad_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, x_bdim, y_bdim, grad_bdim)
+        if is_all_none:
+            return result
+
+        x, x_dim = x_bdim
+        y, y_dim = y_bdim
+        g, g_dim = grad_bdim
+
+        x_shape = F.shape(x)
+        y_shape = F.shape(y)
+        g_shape = F.shape(g)
+
+        if x_dim == y_dim and x_dim == g_dim and \
+                x_shape == y_shape and x_shape == g_shape:
+            dx, dy = prim(x, y, g)
+            return (dx, x_dim), (dy, y_dim)
+
+        x = _bdim_at_front(x, x_dim, axis_size)
+        y = _bdim_at_front(y, y_dim, axis_size)
+        g = _bdim_at_front(g, g_dim, axis_size)
+
+        x_shape = F.shape(x)
+        y_shape = F.shape(y)
+        g_shape = F.shape(g)
+
+        longest_shape = get_longest_shape(x_shape, y_shape, g_shape)
+        x = _handle_broadcasting(x, x_shape, longest_shape)
+        y = _handle_broadcasting(y, y_shape, longest_shape)
+        g = _handle_broadcasting(g, g_shape, longest_shape)
+
+        x_axis_for_reduce = _get_broadcasting_with_front_axis_additional_axis(x_shape, longest_shape)
+        y_axis_for_reduce = _get_broadcasting_with_front_axis_additional_axis(y_shape, longest_shape)
+
+        dx, dy = prim(x, y, g)
+        if x_axis_for_reduce:
+            dx = F.reduce_sum(dx, x_axis_for_reduce)
+
+        if y_axis_for_reduce:
+            dy = F.reduce_sum(dy, y_axis_for_reduce)
+
+        return (dx, 0), (dy, 0)
+    return vmap_rule
+
+
 get_assign_vmap_rule = vmap_rules_getters.register(P.AssignAdd)(get_assign_vmap_rule)
 get_assign_vmap_rule = vmap_rules_getters.register(P.AssignSub)(get_assign_vmap_rule)
 # Unary vmap

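The reduce-sum at the end of the new rule is the usual broadcast-gradient correction. A NumPy-only sketch of why summing over the inserted axes restores the original shapes (the shapes and the tie-breaking rule `x <= y` are assumptions for illustration, not the MindSpore kernel):

import numpy as np

x = np.random.rand(3, 1, 4)      # x is broadcast along axis 1
y = np.random.rand(3, 2, 4)
g = np.random.rand(3, 2, 4)      # incoming gradient, already in the broadcast shape

mask = np.broadcast_to(x, y.shape) <= y      # where the minimum is taken from x
dx_full = np.where(mask, g, 0.0)             # routed gradient, still in the broadcast shape
dy_full = np.where(mask, 0.0, g)

dx = dx_full.sum(axis=1, keepdims=True)      # sum over the axis x was broadcast on
dy = dy_full                                 # y already has the full shape

assert dx.shape == x.shape and dy.shape == y.shape
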
@@ -942,6 +1012,8 @@ get_unop_vmap_rule = vmap_rules_getters.register(BesselK1)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(BesselK1e)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.Trunc)(get_unop_vmap_rule)
 get_unop_vmap_rule = vmap_rules_getters.register(P.PopulationCount)(get_unop_vmap_rule)
+get_unop_vmap_rule = vmap_rules_getters.register(P.Square)(get_unop_vmap_rule)
+
 # UnaryGrad vmap
 get_unary_grad_vmap_rule = vmap_rules_getters.register(G.InvGrad)(get_unary_grad_vmap_rule)
 get_unary_grad_vmap_rule = vmap_rules_getters.register(G.LogitGrad)(get_unary_grad_vmap_rule)

@@ -1675,7 +1675,7 @@ def get_layernorm_vmap_rule(prim, axis_size):
 
     norm_axis = process_attr_axis(prim.begin_norm_axis)
     params_axis = process_attr_axis(prim.begin_params_axis)
-    batch_prim = P.LayerNorm(norm_axis, params_axis)
+    batch_prim = P.LayerNorm(norm_axis, params_axis, prim.epsilon)
 
     @constexpr
     def get_logical_shape(var_shape):

@@ -14,7 +14,7 @@
 # ============================================================================
 
 """Vmap operations."""
-from ..._c_expression import VmapGeneralPreprocess_, VmapGeneralRulePyAdapter_
+from mindspore._c_expression import VmapGeneralPreprocess_, VmapGeneralRulePyAdapter_
 
 
 class _VmapGeneralPreprocess(VmapGeneralPreprocess_):

@@ -20,6 +20,7 @@ from mindspore import Tensor
 from mindspore.nn import Cell
 from mindspore.ops import composite as C
 from mindspore.ops.operations import Minimum
+from mindspore.ops.functional import vmap
 
 context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
 grad = C.GradOperation(get_all=True, sens_param=True)

@@ -45,21 +46,21 @@ class GradWrap(Cell):
         return gout
 
 
-def gen_data(inputA_np, inputB_np, grad_=None):
-    inputA_me = inputA_np
-    if isinstance(inputA_np, np.ndarray):
-        inputA_me = Tensor(inputA_me)
+def gen_data(x_np, y_np, grad_=None):
+    x_me = x_np
+    if isinstance(x_np, np.ndarray):
+        x_me = Tensor(x_me)
 
-    inputB_me = inputB_np
-    if isinstance(inputB_np, np.ndarray):
-        inputB_me = Tensor(inputB_np)
+    y_me = y_np
+    if isinstance(y_np, np.ndarray):
+        y_me = Tensor(y_np)
 
     if grad_ is None:
         grad_ = Tensor(grad_)
 
     net_me = GradWrap(MinNetMe())
     net_me.set_train()
-    output = net_me(inputA_me, inputB_me, Tensor(grad_))
+    output = net_me(x_me, y_me, Tensor(grad_))
     return output
 
 
@@ -67,10 +68,10 @@ def gen_data(inputA_np, inputB_np, grad_=None):
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
 def test_min_tensor_grad_4d():
-    inputA_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
-    inputB_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
+    x_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
+    y_np = np.random.randn(1, 3, 2, 2).astype(np.float32)
     grad_ = np.random.randn(1, 3, 2, 2).astype(np.float32)
-    output = gen_data(inputA_np, inputB_np, grad_)
+    output = gen_data(x_np, y_np, grad_)
     print(output[0].asnumpy())
     print(output[1].asnumpy())
 
@@ -169,3 +170,105 @@ def test_min_tensor_grad_result():
     error1 = np.ones(shape=expect1.shape) * 1.0e-5
     assert np.all(np.abs(output[0].asnumpy() - expect0) < error0)
     assert np.all(np.abs(output[1].asnumpy() - expect1) < error1)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_min_grad_vmap():
+    """
+    Feature: minimumgrad vmap
+    Description: test the minimumgrad vmap when in_axes=(0, 0, 0).
+    Expectation: match the np benchmark.
+    """
+    net = GradWrap(MinNetMe())
+    vmap_net = vmap(net, in_axes=(0, 0, 0))
+    np.random.seed(20)
+    x = Tensor(np.random.rand(2, 4))
+    y = Tensor(np.random.rand(2, 4))
+    sens = Tensor(np.random.rand(2, 4))
+    dx, dy = vmap_net(x, y, sens)
+
+    expect0 = np.array([[0.11669374, 0., 0., 0.],
+                        [0.85762554, 0.94977903, 0.56168687, 0.]])
+
+    expect1 = np.array([[0., 0.75128073, 0.23921822, 0.25480601],
+                        [0., 0., 0., 0.17878053]])
+
+    assert np.allclose(dx.asnumpy(), expect0, rtol=1e-6, atol=1e-4)
+    assert np.allclose(dy.asnumpy(), expect1, rtol=1e-6, atol=1e-4)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_min_grad_vmap_none():
+    """
+    Feature: minimumgrad vmap
+    Description: test the minimumgrad vmap when in_axes=(0, None, None).
+    Expectation: match the np benchmark.
+    """
+    net = GradWrap(MinNetMe())
+    vmap_net = vmap(net, in_axes=(0, None, None))
+    np.random.seed(20)
+    x = Tensor(np.random.rand(2, 4))
+    y = Tensor(np.random.rand(4))
+    sens = Tensor(np.random.rand(4))
+    dx, dy = vmap_net(x, y, sens)
+
+
+    expect0 = np.array([[0.78300363, 0., 0., 0.],
+                        [0.78300363, 0., 0., 0.03666431]])
+
+    expect1 = np.array([[0., 0.85032761, 0.77524489, 0.03666431],
+                        [0., 0.85032761, 0.77524489, 0.]])
+
+    assert np.allclose(dx.asnumpy(), expect0, rtol=1e-6, atol=1e-4)
+    assert np.allclose(dy.asnumpy(), expect1, rtol=1e-6, atol=1e-4)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_min_grad_vmap_scalar():
+    """
+    Feature: minimumgrad vmap
+    Description: test the minimumgrad vmap when x and y are scalars.
+    Expectation: match the np benchmark.
+    """
+    net = GradWrap(MinNetMe())
+    vmap_net = vmap(net, in_axes=(None, None, 0))
+    np.random.seed(20)
+    x = Tensor(1.0)
+    y = Tensor(2.0)
+    sens = Tensor(np.random.rand(2).astype(np.float32))
+    dx, dy = vmap_net(x, y, sens)
+
+
+    expect0 = np.array([0.5881308, 0.8977137])
+    expect1 = np.array([0., 0.])
+    assert np.allclose(dx.asnumpy(), expect0, rtol=1e-6, atol=1e-4)
+    assert np.allclose(dy.asnumpy(), expect1, rtol=1e-6, atol=1e-4)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_min_grad_vmap_special():
+    """
+    Feature: minimumgrad vmap
+    Description: test the minimumgrad vmap when x_rank is smaller than y_rank and y_rank is smaller than g_rank.
+    Expectation: match the np benchmark.
+    """
+    net = GradWrap(MinNetMe())
+    vmap_net = vmap(net, in_axes=(None, None, 0))
+    np.random.seed(20)
+    x = Tensor(np.random.rand(1).astype(np.float32))
+    y = Tensor(np.random.rand(1, 2).astype(np.float32))
+    sens = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
+    dx, dy = vmap_net(x, y, sens)
+
+    expect0 = np.array([[0.85172707], [1.0704385]])
+    expect1 = np.array([[[0., 0.]], [[0., 0.]]])
+    assert np.allclose(dx.asnumpy(), expect0, rtol=1e-6, atol=1e-4)
+    assert np.allclose(dy.asnumpy(), expect1, rtol=1e-6, atol=1e-4)