Add vmap rules for AvgPool and AvgPoolGrad, with GPU tests

This commit is contained in:
fan-jibin 2022-05-18 21:02:04 +08:00
parent 06b3764761
commit 0168650200
4 changed files with 114 additions and 11 deletions

View File

@ -21,7 +21,7 @@ from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import functional as F
from mindspore.ops import constexpr
from ..primitive import Primitive
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, _bdim_at_front
@vmap_rules_getters.register(G.NLLLossGrad)
@ -96,3 +96,32 @@ def get_nll_loss_grad_vmap_rule(prim, axis_size):
return (output, out_dim)
return vmap_rule
@vmap_rules_getters.register(G.AvgPoolGrad)
def get_avg_pool_grad_vmap_rule(prim, axis_size):
    """VmapRule for `AvgPoolGrad`."""
    # AvgPoolGrad works on 4-D (N, C, H, W) inputs; the last three axes
    # are the per-sample payload, everything before them can be folded into N.
    chw_reverse_index = -3

    def vmap_rule(x_bdim, y_bdim, dy_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, y_bdim, dy_bdim)
        if is_all_none:
            return result

        folded = []
        shapes = []
        # Move each batch dim to the front, then merge it into N so the
        # non-batched kernel can process all vmapped samples in one call.
        for value, dim in (x_bdim, y_bdim, dy_bdim):
            value = _bdim_at_front(value, dim, axis_size)
            shape = F.shape(value)
            shapes.append(shape)
            folded.append(F.reshape(value, (-1,) + shape[chw_reverse_index:]))

        grad = prim(folded[0], folded[1], folded[2])
        # The gradient matches x's folded shape; unfold back to (batch, ...).
        grad = F.reshape(grad, shapes[0])
        return (grad, 0)

    return vmap_rule

View File

@ -21,8 +21,7 @@ from mindspore.ops import functional as F
from mindspore.ops import constexpr
from ..primitive import Primitive
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_unop_vmap_rule, \
_bdim_at_front, _bdim_at_back, \
get_unary_grad_vmap_rule, _raise_value_error
_bdim_at_front, _bdim_at_back, get_unary_grad_vmap_rule
@vmap_rules_getters.register(P.ApplyProximalAdagrad)
@ -291,21 +290,17 @@ def get_adaptive_avgpool2d_vmap_rule(prim, axis_size):
chw_reverse_index = -3
hw_reverse_index = -2
def vmap_rule(input_bdim, output_size_bdimm):
is_all_none, result = vmap_general_preprocess(prim, input_bdim, output_size_bdimm)
def vmap_rule(input_bdim):
is_all_none, result = vmap_general_preprocess(prim, input_bdim)
if is_all_none:
return result
input_x, x_dim = input_bdim
output_size, output_size_dim = output_size_bdimm
input_x = _bdim_at_front(input_x, x_dim, axis_size)
if output_size_dim is not None:
_raise_value_error("The source axis of `output_size` in `{}` must be None, "
"but got {}.".format(prim.name, output_size_dim))
x_shape = F.shape(input_x)
input_shape = (x_shape[0] * x_shape[1],) + x_shape[chw_reverse_index:]
input_shape = (-1,) + x_shape[chw_reverse_index:]
input_x = F.reshape(input_x, input_shape)
out = prim(output_size)(input_x)
out = prim(input_x)
out_shape = F.shape(out)
real_out_shape = x_shape[:hw_reverse_index] + out_shape[hw_reverse_index:]
out = F.reshape(out, real_out_shape)
@ -314,6 +309,30 @@ def get_adaptive_avgpool2d_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register(P.AvgPool)
def get_avgpool_vmap_rule(prim, axis_size):
    """VmapRule for `AvgPool`."""
    # The operator consumes (N, C, H, W); axes before the last three are
    # extra batch dims introduced by vmap and get folded into N.
    chw_reverse_index = -3

    def vmap_rule(x_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim)
        if is_all_none:
            return result

        value, dim = x_bdim
        value = _bdim_at_front(value, dim, axis_size)
        orig_shape = F.shape(value)
        # Fold: (B, N, C, H, W) -> (B*N, C, H, W) before pooling.
        pooled = prim(F.reshape(value, (-1,) + orig_shape[chw_reverse_index:]))
        pooled_shape = F.shape(pooled)
        # Unfold: leading dims come from the input, pooled C/H/W from the output.
        pooled = F.reshape(pooled, orig_shape[:chw_reverse_index] + pooled_shape[chw_reverse_index:])
        return (pooled, 0)

    return vmap_rule
@vmap_rules_getters.register(P.KLDivLoss)
def get_kl_div_loss_vmap_rule(prim, axis_size):
"""VmapRule for `KLDivLoss` operation."""

View File

@ -23,6 +23,7 @@ from mindspore import Tensor, ops
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.common.api import ms_function
from mindspore.ops.functional import vmap
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
@ -173,3 +174,20 @@ def test_net_graph_mode_fp64():
dx = grad_net(Tensor(x, mindspore.float64), output)
assert dx.asnumpy().shape == x.shape
assert (dx.asnumpy() == expect_dx).all
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_adaptive_avgpool_vmap():
    """
    Feature: test vmap function.
    Description: test adaptive avgpool op vmap.
    Expectation: expect correct result.
    """
    x = Tensor(np.random.randn(1, 2, 9, 9, 1, 2).astype(np.float32))
    net = Net((3, 5))
    # Two nested vmap levels, each mapping over the trailing input axis.
    double_vmap = vmap(vmap(net, in_axes=-1, out_axes=0), in_axes=-1, out_axes=0)
    result = double_vmap(x)
    assert result.shape == (2, 1, 1, 2, 3, 5)

View File

@ -22,6 +22,7 @@ import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.ops import composite as C
from mindspore.ops.functional import vmap
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@ -304,3 +305,39 @@ def test_avgpool3d_4():
[79.62501, 112.125, 113.625, 81.625],
[125.6875, 176.9375, 179.1875, 128.6875]]]]]).astype(np.float32)
assert np.allclose(actual_grad[0].asnumpy(), expect_grad)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_avgpool_vmap():
    """
    Feature: test vmap function.
    Description: test avgpool op vmap.
    Expectation: expect correct result.
    """
    x = Tensor(np.random.randn(1, 1, 6, 6, 3, 6).astype(np.float32))
    pool_net = AvgPool(dim=2, kernel_size=2, strides=2, pad_mode="VALID")
    # Nest two vmap levels over the last two input axes; both stack at axis 0.
    double_vmap = vmap(vmap(pool_net, in_axes=-1, out_axes=0), in_axes=-1, out_axes=0)
    result = double_vmap(x)
    assert result.shape == (6, 3, 1, 1, 3, 3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_avgpool_grad_vmap():
    """
    Feature: test vmap function.
    Description: test avgpoolgrad op vmap.
    Expectation: expect correct result.
    """
    x = Tensor(np.random.randn(1, 1, 6, 6, 3, 6).astype(np.float32))
    sens = Tensor(np.random.randn(1, 1, 3, 3, 3, 6).astype(np.float32))
    grad_net = AvgPoolGrad(AvgPool(dim=2, kernel_size=2, strides=2, pad_mode="VALID"))
    # Grad net takes (x, sens); both are mapped over their trailing axis twice.
    double_vmap = vmap(vmap(grad_net, in_axes=-1, out_axes=0), in_axes=-1, out_axes=0)
    grads = double_vmap(x, sens)
    assert grads[0].shape == (6, 3, 1, 1, 6, 6)