forked from mindspore-Ecosystem/mindspore
add avgpool grad vmap
This commit is contained in:
parent
06b3764761
commit
0168650200
@@ -21,7 +21,7 @@ from mindspore.ops.operations import _grad_ops as G
 from mindspore.ops import functional as F
 from mindspore.ops import constexpr
 from ..primitive import Primitive
-from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error
+from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, _bdim_at_front
 
 
 @vmap_rules_getters.register(G.NLLLossGrad)
@@ -96,3 +96,32 @@ def get_nll_loss_grad_vmap_rule(prim, axis_size):
         return (output, out_dim)
 
     return vmap_rule
+
+
+@vmap_rules_getters.register(G.AvgPoolGrad)
+def get_avg_pool_grad_vmap_rule(prim, axis_size):
+    """VmapRule for `AvgPoolGrad`."""
+    chw_reverse_index = -3
+
+    def vmap_rule(x_bdim, y_bdim, dy_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, x_bdim, y_bdim, dy_bdim)
+        if is_all_none:
+            return result
+
+        x, x_dim = x_bdim
+        y, y_dim = y_bdim
+        dy, dy_dim = dy_bdim
+        x = _bdim_at_front(x, x_dim, axis_size)
+        y = _bdim_at_front(y, y_dim, axis_size)
+        dy = _bdim_at_front(dy, dy_dim, axis_size)
+        x_shape = F.shape(x)
+        y_shape = F.shape(y)
+        dy_shape = F.shape(dy)
+        x = F.reshape(x, (-1,) + x_shape[chw_reverse_index:])
+        y = F.reshape(y, (-1,) + y_shape[chw_reverse_index:])
+        dy = F.reshape(dy, (-1,) + dy_shape[chw_reverse_index:])
+        out = prim(x, y, dy)
+        out = F.reshape(out, x_shape)
+        return (out, 0)
+
+    return vmap_rule
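Note: the new rule uses the standard fold/unfold trick for kernels that only accept 4-D NCHW input: `_bdim_at_front` moves the vmap axis to the front, the reshape folds it into the batch axis N, the primitive runs once on the flattened batch, and the result is reshaped back. A minimal standalone sketch of that round trip (NumPy only; the 2x2 mean pool is an illustrative stand-in for the real kernel, not part of the patch):

import numpy as np

# Fold an extra vmap axis B into the batch axis N so a kernel that only
# understands (N, C, H, W) can process (B, N, C, H, W).
def fold_run_unfold(x, kernel_4d):
    x_shape = x.shape                                  # (B, N, C, H, W)
    x = x.reshape((-1,) + x_shape[-3:])                # (B*N, C, H, W)
    out = kernel_4d(x)                                 # (B*N, C, H', W')
    return out.reshape(x_shape[:-3] + out.shape[-3:])  # (B, N, C, H', W')

# 2x2 mean pooling with stride 2, used here as a stand-in 4-D kernel.
def avg_pool_2x2(x):
    n, c, h, w = x.shape
    return x.reshape(n, c, h // 2, 2, w // 2, 2).mean(axis=(3, 5))

x = np.random.randn(4, 2, 3, 6, 6).astype(np.float32)
out = fold_run_unfold(x, avg_pool_2x2)
assert out.shape == (4, 2, 3, 3, 3)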
@@ -21,8 +21,7 @@ from mindspore.ops import functional as F
 from mindspore.ops import constexpr
 from ..primitive import Primitive
 from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_unop_vmap_rule, \
-    _bdim_at_front, _bdim_at_back, \
-    get_unary_grad_vmap_rule, _raise_value_error
+    _bdim_at_front, _bdim_at_back, get_unary_grad_vmap_rule
 
 
 @vmap_rules_getters.register(P.ApplyProximalAdagrad)
@@ -291,21 +290,17 @@ def get_adaptive_avgpool2d_vmap_rule(prim, axis_size):
     chw_reverse_index = -3
     hw_reverse_index = -2
 
-    def vmap_rule(input_bdim, output_size_bdimm):
-        is_all_none, result = vmap_general_preprocess(prim, input_bdim, output_size_bdimm)
+    def vmap_rule(input_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, input_bdim)
         if is_all_none:
             return result
 
         input_x, x_dim = input_bdim
-        output_size, output_size_dim = output_size_bdimm
         input_x = _bdim_at_front(input_x, x_dim, axis_size)
-        if output_size_dim is not None:
-            _raise_value_error("The source axis of `output_size` in `{}` must be None, "
-                               "but got {}.".format(prim.name, output_size_dim))
         x_shape = F.shape(input_x)
-        input_shape = (x_shape[0] * x_shape[1],) + x_shape[chw_reverse_index:]
+        input_shape = (-1,) + x_shape[chw_reverse_index:]
         input_x = F.reshape(input_x, input_shape)
-        out = prim(output_size)(input_x)
+        out = prim(input_x)
         out_shape = F.shape(out)
         real_out_shape = x_shape[:hw_reverse_index] + out_shape[hw_reverse_index:]
         out = F.reshape(out, real_out_shape)
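Two simplifications land in this hunk: `output_size` is no longer passed as a runtime input (the primitive instance already carries it, so the rule calls `prim(input_x)` directly), and the explicit `x_shape[0] * x_shape[1]` batch product is replaced by `-1`, which lets reshape infer the folded size and, presumably, also handles inputs with more than two leading batch axes as nested vmap produces. The two reshape spellings agree in the two-axis case:

import numpy as np

x = np.zeros((4, 2, 3, 9, 9))
a = x.reshape((x.shape[0] * x.shape[1],) + x.shape[-3:])  # explicit batch product
b = x.reshape((-1,) + x.shape[-3:])                       # inferred, same shape
assert a.shape == b.shape == (8, 3, 9, 9)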
@@ -314,6 +309,30 @@ def get_adaptive_avgpool2d_vmap_rule(prim, axis_size):
     return vmap_rule
 
 
+@vmap_rules_getters.register(P.AvgPool)
+def get_avgpool_vmap_rule(prim, axis_size):
+    """VmapRule for `AvgPool`."""
+    chw_reverse_index = -3
+
+    def vmap_rule(x_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, x_bdim)
+        if is_all_none:
+            return result
+
+        x, x_dim = x_bdim
+        x = _bdim_at_front(x, x_dim, axis_size)
+        x_shape = F.shape(x)
+        input_shape = (-1,) + x_shape[chw_reverse_index:]
+        x = F.reshape(x, input_shape)
+        out = prim(x)
+        out_shape = F.shape(out)
+        real_out_shape = x_shape[:chw_reverse_index] + out_shape[chw_reverse_index:]
+        out = F.reshape(out, real_out_shape)
+        return (out, 0)
+
+    return vmap_rule
+
+
 @vmap_rules_getters.register(P.KLDivLoss)
 def get_kl_div_loss_vmap_rule(prim, axis_size):
     """VmapRule for `KLDivLoss` operation."""
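Both new rules lean on `vmap_general_preprocess`, which short-circuits when none of the inputs actually carries a mapped axis. The real helper lives in `vmap_base.py`; the simplified sketch below is an assumption about its contract for illustration, not the actual implementation:

def general_preprocess_sketch(prim, *bdims):
    # Each bdim is a (value, mapped_axis) pair; mapped_axis is None when
    # the value is not batched under the current vmap level.
    if any(axis is not None for _, axis in bdims):
        return False, None                      # a batched input exists: run the real rule
    out = prim(*(value for value, _ in bdims))  # nothing batched: call the op once
    return True, (out, None)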
@@ -23,6 +23,7 @@ from mindspore import Tensor, ops
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _grad_ops as G
 from mindspore.common.api import ms_function
+from mindspore.ops.functional import vmap
 
 context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
 
@@ -173,3 +174,20 @@ def test_net_graph_mode_fp64():
     dx = grad_net(Tensor(x, mindspore.float64), output)
     assert dx.asnumpy().shape == x.shape
     assert (dx.asnumpy() == expect_dx).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_adaptive_avgpool_vmap():
+    """
+    Feature: test vmap function.
+    Description: test adaptive avgpool op vmap.
+    Expectation: expect correct result.
+    """
+    in_axes = -1
+    x = Tensor(np.random.randn(1, 2, 9, 9, 1, 2).astype(np.float32))
+    net = Net((3, 5))
+    nest_vmap = vmap(vmap(net, in_axes=in_axes, out_axes=0), in_axes=in_axes, out_axes=0)
+    out = nest_vmap(x)
+    assert out.shape == (2, 1, 1, 2, 3, 5)
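The expected shape can be read off by peeling the mapped axes: with `in_axes=-1`, the outer vmap maps the trailing axis of size 2 and the inner vmap the next one of size 1, and each per-call slice of shape (1, 2, 9, 9) pools to (1, 2, 3, 5), assuming `Net((3, 5))` wraps `AdaptiveAvgPool2D` as elsewhere in this file. Pure-Python bookkeeping of that reasoning:

# Shape bookkeeping for the nested vmap above (no MindSpore needed).
x_shape = (1, 2, 9, 9, 1, 2)
outer = x_shape[-1]                  # 2: outer vmap peels the last axis
inner = x_shape[-2]                  # 1: inner vmap peels the next one
per_call_out = x_shape[:2] + (3, 5)  # (1, 2, 9, 9) -> (1, 2, 3, 5)
assert (outer, inner) + per_call_out == (2, 1, 1, 2, 3, 5)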
@@ -22,6 +22,7 @@ import mindspore.context as context
 import mindspore.nn as nn
 import mindspore.ops.operations as P
 from mindspore.ops import composite as C
+from mindspore.ops.functional import vmap
 
 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
 
@@ -304,3 +305,39 @@ def test_avgpool3d_4():
                               [79.62501, 112.125, 113.625, 81.625],
                               [125.6875, 176.9375, 179.1875, 128.6875]]]]]).astype(np.float32)
     assert np.allclose(actual_grad[0].asnumpy(), expect_grad)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_avgpool_vmap():
+    """
+    Feature: test vmap function.
+    Description: test avgpool op vmap.
+    Expectation: expect correct result.
+    """
+    in_axes = -1
+    x = Tensor(np.random.randn(1, 1, 6, 6, 3, 6).astype(np.float32))
+    net = AvgPool(dim=2, kernel_size=2, strides=2, pad_mode="VALID")
+    nest_vmap = vmap(vmap(net, in_axes=in_axes, out_axes=0), in_axes=in_axes, out_axes=0)
+    out = nest_vmap(x)
+    assert out.shape == (6, 3, 1, 1, 3, 3)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_avgpool_grad_vmap():
+    """
+    Feature: test vmap function.
+    Description: test avgpoolgrad op vmap.
+    Expectation: expect correct result.
+    """
+    in_axes = -1
+    x = Tensor(np.random.randn(1, 1, 6, 6, 3, 6).astype(np.float32))
+    sens = Tensor(np.random.randn(1, 1, 3, 3, 3, 6).astype(np.float32))
+    avgpool = AvgPool(dim=2, kernel_size=2, strides=2, pad_mode="VALID")
+    net = AvgPoolGrad(avgpool)
+    nest_vmap = vmap(vmap(net, in_axes=in_axes, out_axes=0), in_axes=in_axes, out_axes=0)
+    out = nest_vmap(x, sens)
+    assert out[0].shape == (6, 3, 1, 1, 6, 6)
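`AvgPool` and `AvgPoolGrad` here are helper Cells defined earlier in this test file and not shown in the hunk. For orientation, grad-wrapper Cells in these st tests usually follow the pattern below; this is a hypothetical reconstruction under that assumption, not the file's actual code:

import mindspore.nn as nn
from mindspore.ops import composite as C

class AvgPoolGradSketch(nn.Cell):
    """Hypothetical grad wrapper: feeds `sens` as the upstream gradient."""
    def __init__(self, network):
        super(AvgPoolGradSketch, self).__init__()
        self.network = network
        # get_all=True returns a tuple of input gradients (hence out[0] in
        # the test); sens_param=True accepts an explicit sensitivity input.
        self.grad = C.GradOperation(get_all=True, sens_param=True)

    def construct(self, x, sens):
        return self.grad(self.network)(x, sens)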