add vmap rules for conv ops

zlq2020 2022-07-15 14:42:33 +08:00
parent 3412fa4366
commit 361685bb72
8 changed files with 644 additions and 2 deletions


@@ -17,7 +17,7 @@
from __future__ import absolute_import
from . import vmap_base, vmap_array_ops, vmap_grad_nn_ops, vmap_debug_ops, vmap_math_ops, vmap_nn_ops,\
vmap_image_ops, vmap_other_ops, vmap_sparse_ops, vmap_random_ops
vmap_image_ops, vmap_other_ops, vmap_sparse_ops, vmap_random_ops, vmap_convolution_ops
from .vmap_base import get_vmap_rule, vmap_monad_rule, _broadcast_by_axis, vmap_bind_all_none,\
vmap_unstack, vmap_general_output_process


@@ -26,6 +26,7 @@ from ..composite import _VmapGeneralPreprocess
from ..primitive import Primitive
from ..operations.random_ops import UniformCandidateSampler
from ...common import Tensor
from ..operations import nn_ops as nps
vmap_rules_getters = Registry()
vmap_rules = Registry()
@@ -394,4 +395,12 @@ _ops_vmap_clone_prim_dict = {"ApplyAdaMax": P.ApplyAdaMax,
"UniformCandidateSampler": UniformCandidateSampler,
"CdistGrad": G.CdistGrad,
"Cdist": P.Cdist,
"STFT": math_ops.STFT}
"STFT": math_ops.STFT,
"Conv2D": P.Conv2D,
"Conv3D": P.Conv3D,
"Conv2DTranspose": P.Conv2DTranspose,
"Conv2DBackpropInput": P.Conv2DBackpropInput,
"Conv3DTranspose": P.Conv3DTranspose,
"Conv3DBackpropInput": nps.Conv3DBackpropInput,
"Conv2DBackpropFilter": G.Conv2DBackpropFilter,
"Conv3DBackpropFilter": G.Conv3DBackpropFilter}


@@ -0,0 +1,442 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""convolution vmap impl"""
import numpy as np
import mindspore.numpy as mnp
from mindspore.ops import constexpr
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from ..operations import nn_ops as nps
from ..operations import _grad_ops as G
from ..primitive import Primitive
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
_vmap_update_prim_attr, _vmap_clone_prim
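
# The rules below vectorize convolution ops by folding the vmap batch dim into an existing
# dim of the inputs (N, the channel dim, or out_channel); when both inputs carry a batch dim,
# the batched convolution is expressed as a single grouped convolution with a scaled 'group'.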
@vmap_rules_getters.register(P.Conv2D)
@vmap_rules_getters.register(P.Conv3D)
def get_conv_vmap_rule(prim, axis_size):
"""Vmap rule for `Conv2D` and `Conv3D` operations."""
if isinstance(prim, str):
prim = Primitive(prim)
attr_list = [prim.name, prim.group, prim.data_format]
new_prim = _vmap_clone_prim(prim)
def vmap_rule(input_bdim, weight_bdim):
is_all_none, result = vmap_general_preprocess(prim, input_bdim, weight_bdim)
if is_all_none:
return result
return _conv_vmap_rule(new_prim, axis_size, input_bdim, weight_bdim, attr_list)
return vmap_rule
@vmap_rules_getters.register(P.Conv2DTranspose)
@vmap_rules_getters.register(P.Conv2DBackpropInput)
def get_conv2d_transpose_vmap_rule(prim, axis_size):
"""Vmap rule for `Conv2DTranspose` and `Conv2DBackpropInput` operations."""
if isinstance(prim, str):
prim = Primitive(prim)
attr_list = [prim.name, prim.group, prim.data_format]
new_prim = _vmap_clone_prim(prim)
def vmap_rule(dout_bdim, weight_bdim, input_size_bdim):
is_all_none, result = vmap_general_preprocess(prim, dout_bdim, weight_bdim, input_size_bdim)
if is_all_none:
return result
return _conv_transpose_vmap_rule(new_prim, axis_size, dout_bdim, \
weight_bdim, input_size_bdim, attr_list)
return vmap_rule
@vmap_rules_getters.register(P.Conv3DTranspose)
def get_conv3d_transpose_vmap_rule(prim, axis_size):
"""Vmap rule for `Conv3DTranspose` operation."""
if isinstance(prim, str):
prim = Primitive(prim)
attr_list = [prim.name, prim.group, prim.data_format]
new_prim = _vmap_clone_prim(prim)
def vmap_rule(dout_bdim, weight_bdim):
is_all_none, result = vmap_general_preprocess(prim, dout_bdim, weight_bdim)
if is_all_none:
return result
return _conv_transpose_vmap_rule(new_prim, axis_size, dout_bdim, weight_bdim, None, attr_list)
return vmap_rule
@vmap_rules_getters.register(nps.Conv3DBackpropInput)
def get_conv3d_backprop_input_vmap_rule(prim, axis_size):
"""Vmap rule for `Conv3DBackpropInput` operation."""
if isinstance(prim, str):
prim = Primitive(prim)
attr_list = [prim.name, prim.group, prim.data_format]
new_prim = _vmap_clone_prim(prim)
def vmap_rule(weight_bdim, dout_bdim, input_size_bdim):
is_all_none, result = vmap_general_preprocess(prim, weight_bdim, dout_bdim, input_size_bdim)
if is_all_none:
return result
return _conv_transpose_vmap_rule(new_prim, axis_size, dout_bdim, \
weight_bdim, input_size_bdim, attr_list)
return vmap_rule
@vmap_rules_getters.register(G.Conv2DBackpropFilter)
def get_conv2d_backprop_filter_vmap_rule(prim, axis_size):
"""Vmap rule for `Conv2DBackpropFilter` operation."""
if isinstance(prim, str):
prim = Primitive(prim)
attr_list = [prim.name, prim.group, prim.data_format]
new_prim = _vmap_clone_prim(prim)
def vmap_rule(dout_bdim, input_x_bdim, weight_size_bdim):
is_all_none, result = vmap_general_preprocess(prim, dout_bdim, input_x_bdim, weight_size_bdim)
if is_all_none:
return result
return _conv_backprop_filter_vmap_rule(new_prim, axis_size, dout_bdim, \
input_x_bdim, weight_size_bdim, attr_list)
return vmap_rule
@vmap_rules_getters.register(G.Conv3DBackpropFilter)
def get_conv3d_backprop_filter_vmap_rule(prim, axis_size):
"""Vmap rule for `Conv3DBackpropFilter` operation."""
if isinstance(prim, str):
prim = Primitive(prim)
attr_list = [prim.name, prim.group, prim.data_format]
new_prim = _vmap_clone_prim(prim)
def vmap_rule(input_x_bdim, dout_bdim, weight_size_bdim):
is_all_none, result = vmap_general_preprocess(prim, input_x_bdim, dout_bdim, weight_size_bdim)
if is_all_none:
return result
return _conv_backprop_filter_vmap_rule(new_prim, axis_size, dout_bdim, \
input_x_bdim, weight_size_bdim, attr_list)
return vmap_rule
@constexpr
def _get_reshape_src_dim(data_dim, cmp_dim):
"""Get source dim for reshape"""
if data_dim > cmp_dim:
expand_dim = cmp_dim
merge_dim = data_dim + 1
else:
expand_dim = cmp_dim + 1
merge_dim = data_dim
return expand_dim, merge_dim
@constexpr
def _get_merge_shape(src_dim, dst_dim, shape):
"""Get new shape for merging the src_dim and dst_dim. The dst_dim is the value after removing src_dim."""
new_shape = [shape[i] for i in range(len(shape)) if i != src_dim]
new_shape[dst_dim] *= shape[src_dim]
return tuple(new_shape)
def _reshape_merge_dims(src_dim, dst_dim, target):
"""Reshape target by merging the src_dim and dst_dim."""
shape = F.shape(target)
new_shape = _get_merge_shape(src_dim, dst_dim, shape)
new_target = mnp.moveaxis(target, src_dim, dst_dim)
output = F.reshape(new_target, new_shape)
return output, new_shape
@constexpr
def _get_expand_shape(src_dim, dst_size, shape, prim_name):
"""Get new shape for splitting src_dim into dst_size parts."""
dst_size2, remainder = np.divmod(shape[src_dim], dst_size)
if remainder != 0:
_raise_value_error("The remainder of {} / {} should be 0, "
"but got {} in {}.".format(shape[src_dim], dst_size, remainder, prim_name))
new_shape = list(shape)
new_shape[src_dim:(src_dim + 1)] = [dst_size, dst_size2]
return tuple(new_shape)
def _reshape_expand_dims(src_dim, dst_size, target, prim_name):
"""Reshape target by splitting src_dim into dst_size parts."""
shape = F.shape(target)
new_shape = _get_expand_shape(src_dim, dst_size, shape, prim_name)
return F.reshape(target, new_shape)
@constexpr
def _get_new_size_by_index(input_size, batch_size, index):
"""Get the new size of input_size by multiplying input_size[index] by batch_size."""
if input_size is None:
return None
new_size = list(input_size)
new_size[index] *= batch_size
return tuple(new_size)
@constexpr
def _update_group_attr(prim, groups, batch_size):
"""Set new value for 'group' attribute of the convolution primitive."""
group = groups * batch_size
_vmap_update_prim_attr(prim, 'group', group)
_vmap_update_prim_attr(prim, 'groups', group)
@constexpr
def _get_channel_index(data_format, prim_name):
"""Get channel index by data_format, only supports NHWC/NCHW/NCDHW now."""
index = 0
if data_format == "NHWC":
index = 3
elif data_format in ("NCHW", "NCDHW"):
index = 1
else:
_raise_value_error("'data_format' in {} should be NHWC/NCHW/NCDHW, "
"but got {}.".format(prim_name, data_format))
return index
def _conv_vmap_rule(prim, batch_size, input_bdim, weight_bdim, attr_list):
"""Vmap rule for Convolution operations, such as `Conv2D` and `Conv3D`."""
input_x, x_dim = input_bdim
weight, w_dim = weight_bdim
prim_name = attr_list[0]
groups = attr_list[1]
data_format = attr_list[2]
c_axis = _get_channel_index(data_format, prim_name)
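    # Three batching cases: if both x and weight carry the vmap dim, fold it into x's channel dim
    # and weight's out_channel dim and run one grouped convolution with 'group' scaled by the
    # batch size; if only x carries it, fold it into N; if only weight carries it, fold it into
    # out_channel (keeping each group's filters contiguous when group > 1) and split it back out
    # of the output channels.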
def _get_output_for_x_w_vmap():
new_input, _ = _reshape_merge_dims(x_dim, c_axis, input_x)
new_weight, new_w_shape = _reshape_merge_dims(w_dim, 0, weight)
_update_group_attr(prim, groups, batch_size)
_vmap_update_prim_attr(prim, 'out_channel', new_w_shape[0])
out = prim(new_input, new_weight)
out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
return (out, c_axis)
def _get_output_for_x_vmap():
new_input, _ = _reshape_merge_dims(x_dim, 0, input_x)
out = prim(new_input, weight)
out = _reshape_expand_dims(0, batch_size, out, prim_name)
return (out, 0)
def _get_output_for_w_vmap():
if groups > 1:
expand_dim, merge_dim = _get_reshape_src_dim(w_dim, 0)
new_weight = _reshape_expand_dims(expand_dim, groups, weight, prim_name)
new_weight, _ = _reshape_merge_dims(merge_dim, 1, new_weight)
new_weight, new_w_shape = _reshape_merge_dims(0, 0, new_weight)
_vmap_update_prim_attr(prim, 'out_channel', new_w_shape[0])
out = prim(input_x, new_weight)
out = _reshape_expand_dims(c_axis, groups, out, prim_name)
out = _reshape_expand_dims(c_axis + 1, batch_size, out, prim_name)
out, _ = _reshape_merge_dims(c_axis, c_axis + 1, out)
return (out, c_axis)
new_weight, new_w_shape = _reshape_merge_dims(w_dim, 0, weight)
_vmap_update_prim_attr(prim, 'out_channel', new_w_shape[0])
out = prim(input_x, new_weight)
out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
return (out, c_axis)
if x_dim is not None and w_dim is not None:
if prim_name == "Conv3D":
_raise_value_error("vmap in_axes of 'x' and 'weight in `{}` cannot be non-None at the same time,"
"but got {} and {}.".format(prim_name, x_dim, w_dim))
output = _get_output_for_x_w_vmap()
elif x_dim is not None:
output = _get_output_for_x_vmap()
else:
output = _get_output_for_w_vmap()
return output
def _conv_transpose_vmap_rule(prim, batch_size, dout_bdim, weight_bdim, input_size_bdim, attr_list):
"""
Vmap rule for transposed convolution operations, such as `Conv2DTranspose`,
`Conv2DBackpropInput`, `Conv3DTranspose` and `Conv3DBackpropInput`.
"""
prim_name = attr_list[0]
input_size = None
if input_size_bdim is not None:
input_size, input_size_dim = input_size_bdim
if input_size_dim is not None:
_raise_value_error("Vmap in_axes of 'input_size' in `{}` must be None, "
"but got {}.".format(prim_name, input_size_dim))
if not isinstance(input_size, tuple):
_raise_value_error("Unsupported vmap for dynamic shape of `{}` when "
"'input_size' is a tensor.".format(prim_name))
dout, dout_dim = dout_bdim
weight, w_dim = weight_bdim
groups = attr_list[1]
data_format = attr_list[2]
c_axis = _get_channel_index(data_format, prim_name)
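    # The wrapped primitives take (dout, weight[, input_size]) in different orders, so the helper
    # below dispatches on the primitive name. Batching mirrors the forward rule: fold the vmap dim
    # into the channel dims (scaling 'group') when both dout and weight are batched, into N when
    # only dout is batched, or into the weight's channel dim when only the weight is batched, then
    # split the batch dim back out of the output.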
def _get_conv_transpose_output(dout, weight, input_size):
out = None
if prim_name in ('Conv2DTranspose', 'Conv2DBackpropInput'):
out = prim(dout, weight, input_size)
elif prim_name == "Conv3DTranspose":
out = prim(dout, weight)
elif prim_name == "Conv3DBackpropInput":
out = prim(weight, dout, input_size)
else:
_raise_value_error("Unsupported the operation: `{}`.".format(prim_name))
return out
def _get_output_for_dout_weight_vmap():
_update_group_attr(prim, groups, batch_size)
new_dout, _ = _reshape_merge_dims(dout_dim, c_axis, dout)
new_weight, _ = _reshape_merge_dims(w_dim, 0, weight)
new_input_size = _get_new_size_by_index(input_size, batch_size, c_axis)
out = _get_conv_transpose_output(new_dout, new_weight, new_input_size)
out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
return (out, c_axis)
def _get_output_for_dout_vmap():
new_dout, _ = _reshape_merge_dims(dout_dim, 0, dout)
new_input_size = _get_new_size_by_index(input_size, batch_size, 0)
out = _get_conv_transpose_output(new_dout, weight, new_input_size)
out = _reshape_expand_dims(0, batch_size, out, prim_name)
return (out, 0)
def _get_output_for_weight_vmap():
new_weight, _ = _reshape_merge_dims(w_dim, c_axis, weight)
new_input_size = _get_new_size_by_index(input_size, batch_size, c_axis)
out = _get_conv_transpose_output(dout, new_weight, new_input_size)
if groups > 1:
out = _reshape_expand_dims(c_axis, groups, out, prim_name)
out = _reshape_expand_dims(c_axis + 1, batch_size, out, prim_name)
out, _ = _reshape_merge_dims(c_axis, c_axis + 1, out)
else:
out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
return (out, c_axis)
if dout_dim is not None and w_dim is not None:
if prim_name in ("Conv3DTranspose", "Conv3DBackpropInput"):
_raise_value_error("vmap in_axes of 'dout' and 'weight' in `{}` cannot be non-None at the same time,"
"but got {} and {}.".format(prim_name, dout_dim, w_dim))
output = _get_output_for_dout_weight_vmap()
elif dout_dim is not None:
output = _get_output_for_dout_vmap()
else:
output = _get_output_for_weight_vmap()
return output
def _conv_backprop_filter_vmap_rule(prim, batch_size, dout_bdim, input_bdim, weight_size_bdim, attr_list):
"""Vmap rule for `Conv2DBackpropFilter` and `Conv3DBackpropFilter` operations"""
dout, dout_dim = dout_bdim
input_x, x_dim = input_bdim
weight_size, w_size_dim = weight_size_bdim
prim_name = attr_list[0]
groups = attr_list[1]
data_format = attr_list[2]
c_axis = _get_channel_index(data_format, prim_name)
if w_size_dim is not None:
_raise_value_error("Vmap in_axes of 'weight_size' in `{}` must be None, "
"but got {}.".format(prim_name, w_size_dim))
if not isinstance(weight_size, tuple):
_raise_value_error("Unsupported vmap for dynamic shape of `{}` when "
"'weight_size' is a tensor.".format(prim_name))
def _get_conv_backprop_filter_output(dout, x, weight_size):
out = None
if prim_name == "Conv2DBackpropFilter":
out = prim(dout, x, weight_size)
elif prim_name == "Conv3DBackpropFilter":
out = prim(x, dout, weight_size)
else:
_raise_value_error("Unsupported the operation: `{}`.".format(prim_name))
return out
def _get_output_for_dout_x_vmap():
_update_group_attr(prim, groups, batch_size)
new_dout, _ = _reshape_merge_dims(dout_dim, c_axis, dout)
new_input, _ = _reshape_merge_dims(x_dim, c_axis, input_x)
new_w_size = _get_new_size_by_index(weight_size, batch_size, 0)
out = _get_conv_backprop_filter_output(new_dout, new_input, new_w_size)
out = _reshape_expand_dims(0, batch_size, out, prim_name)
return (out, 0)
def _get_output_for_x_vmap():
new_w_size = _get_new_size_by_index(weight_size, batch_size, c_axis)
if groups > 1:
expand_dim, merge_dim = _get_reshape_src_dim(x_dim, c_axis)
new_input = _reshape_expand_dims(expand_dim, groups, input_x, prim_name)
new_input, _ = _reshape_merge_dims(merge_dim, c_axis + 1, new_input)
new_input, _ = _reshape_merge_dims(c_axis, c_axis, new_input)
else:
new_input, _ = _reshape_merge_dims(x_dim, c_axis, input_x)
out = _get_conv_backprop_filter_output(dout, new_input, new_w_size)
out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
return (out, c_axis)
def _get_output_for_dout_vmap():
new_w_size = _get_new_size_by_index(weight_size, batch_size, 0)
if groups > 1:
expand_dim, merge_dim = _get_reshape_src_dim(dout_dim, c_axis)
new_dout = _reshape_expand_dims(expand_dim, groups, dout, prim_name)
new_dout, _ = _reshape_merge_dims(merge_dim, c_axis + 1, new_dout)
new_dout, _ = _reshape_merge_dims(c_axis, c_axis, new_dout)
out = _get_conv_backprop_filter_output(new_dout, input_x, new_w_size)
out = _reshape_expand_dims(0, groups, out, prim_name)
out = _reshape_expand_dims(1, batch_size, out, prim_name)
out, _ = _reshape_merge_dims(0, 1, out)
return (out, 0)
new_dout, _ = _reshape_merge_dims(dout_dim, c_axis, dout)
out = _get_conv_backprop_filter_output(new_dout, input_x, new_w_size)
out = _reshape_expand_dims(0, batch_size, out, prim_name)
return (out, 0)
if dout_dim is not None and x_dim is not None:
if prim_name == "Conv3DBackpropFilter":
_raise_value_error("vmap in_axes of 'dout' and 'x' in `{}` cannot be non-None at the same time,"
"but got {} and {}.".format(prim_name, dout_dim, x_dim))
output = _get_output_for_dout_x_vmap()
elif x_dim is not None:
output = _get_output_for_x_vmap()
else:
output = _get_output_for_dout_vmap()
return output


@@ -22,6 +22,7 @@ from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.functional import vmap
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
@@ -70,3 +71,37 @@ def test_conv2d_backprop_filter():
[-104, -211, -322],
[-102, -144, -248]]]]).astype(np.float32)
assert (abs(output.asnumpy() - expect) < np.ones(shape=[1, 1, 3, 3]) * 1.0e-4).all()
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_conv2d_backprop_filter_vmap():
"""
Feature: Conv2DBackpropFilter op
Description: Test vmap rule for Conv2DBackpropFilter op
    Expectation: The outputs match the expected values
"""
conv2d_filter = Conv2dFilter()
batch_out = Tensor(np.arange(1 * 2 * 1 * 4 * 4).reshape(1, 2, 1, 4, 4).astype(np.float32))
x = Tensor(np.arange(1 * 1 * 6 * 6).reshape(1, 1, 6, 6).astype(np.float32))
w = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
expected1 = np.array([[[[[1760., 1880., 2000.], [2480., 2600., 2720.], [3200., 3320., 3440.]]]],
[[[[4448., 4824., 5200.], [6704., 7080., 7456.], [8960., 9336., 9712.]]]]]
).astype(np.float32)
output1 = vmap(conv2d_filter, (1, None, None))(batch_out, x, w)
assert np.allclose(output1.asnumpy(), expected1, 0.0001, 0.0001)
dout = Tensor(np.arange(1 * 1 * 4 * 4).reshape(1, 1, 4, 4).astype(np.float32))
batch_x = Tensor(np.arange(2 * 1 * 1 * 6 * 6).reshape(2, 1, 1, 6, 6).astype(np.float32))
expected2 = np.array([[[[[1760., 1880., 2000.], [2480., 2600., 2720.], [3200., 3320., 3440.]]]],
[[[[6080., 6200., 6320.], [6800., 6920., 7040.], [7520., 7640., 7760.]]]]]
).astype(np.float32)
output2 = vmap(conv2d_filter, (None, 0, None))(dout, batch_x, w)
assert np.allclose(output2.asnumpy(), expected2, 0.0001, 0.0001)
expected3 = np.array([[[[[1760., 1880., 2000.], [2480., 2600., 2720.], [3200., 3320., 3440.]]]],
[[[[17984., 18360., 18736.], [20240., 20616., 20992.], [22496., 22872., 23248.]]]]]
).astype(np.float32)
output3 = vmap(conv2d_filter, (1, 0, None))(batch_out, batch_x, w)
assert np.allclose(output3.asnumpy(), expected3, 0.0001, 0.0001)


@@ -21,6 +21,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import operations as P
from mindspore.ops.functional import vmap
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
@@ -73,3 +74,50 @@ def test_conv2d_backprop_input():
[-3, -2, 0, -14, 3, 16]]]]).astype(np.float32)
assert (abs(output.asnumpy() - expect) < np.ones(shape=[1, 1, 6, 6]) * 1.0e-4).all()
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_conv2d_backprop_input_vmap():
"""
Feature: Conv2DBackpropInput op
Description: Test vmap rule for Conv2DBackpropInput op
    Expectation: The outputs match the expected values
"""
conv2d_input = Conv2dInput()
batch_dout = Tensor(np.arange(1 * 2 * 1 * 4 * 4).reshape(1, 2, 1, 4, 4).astype(np.float32))
x = Tensor(np.arange(1 * 1 * 3 * 3).reshape(1, 1, 3, 3).astype(np.float32))
w = Tensor(np.ones([1, 1, 6, 6]).astype(np.float32))
expected1 = np.array([[[[[0., 0., 1., 4., 7., 6.], [0., 7., 23., 38., 41., 29.],
[12., 45., 102., 138., 126., 81.], [48., 129., 246., 282., 234., 141.],
[84., 197., 341., 374., 287., 163.], [72., 162., 271., 292., 217., 120.]]]],
[[[[0., 16., 49., 52., 55., 38.], [48., 135., 263., 278., 233., 141.],
[156., 381., 678., 714., 558., 321.], [192., 465., 822., 858., 666., 381.],
[228., 517., 869., 902., 671., 371.], [168., 370., 607., 628., 457., 248.]]]]]
).astype(np.float32)
output1 = vmap(conv2d_input, (1, None, None))(batch_dout, x, w)
assert np.allclose(output1.asnumpy(), expected1, 0.0001, 0.0001)
dout = Tensor(np.arange(1 * 1 * 4 * 4).reshape(1, 1, 4, 4).astype(np.float32))
batch_x = Tensor(np.arange(2 * 1 * 1 * 3 * 3).reshape(2, 1, 1, 3, 3).astype(np.float32))
expected2 = np.array([[[[[0., 0., 1., 4., 7., 6.], [0., 7., 23., 38., 41., 29.],
[12., 45., 102., 138., 126., 81.], [48., 129., 246., 282., 234., 141.],
[84., 197., 341., 374., 287., 163.], [72., 162., 271., 292., 217., 120.]]]],
[[[[0., 9., 28., 58., 52., 33.], [36., 97., 185., 254., 203., 119.],
[120., 288., 507., 624., 477., 270.], [264., 588., 975., 1092., 801., 438.],
[264., 575., 935., 1022., 737., 397.], [180., 387., 622., 670., 478., 255.]]]]]
).astype(np.float32)
output2 = vmap(conv2d_input, (None, 0, None))(dout, batch_x, w)
assert np.allclose(output2.asnumpy(), expected2, 0.0001, 0.0001)
expected3 = np.array([[[[[0., 0., 1., 4., 7., 6.], [0., 7., 23., 38., 41., 29.],
[12., 45., 102., 138., 126., 81.], [48., 129., 246., 282., 234., 141.],
[84., 197., 341., 374., 287., 163.], [72., 162., 271., 292., 217., 120.]]]],
[[[[144., 313., 508., 538., 388., 209.], [372., 801., 1289., 1358., 971., 519.],
[696., 1488., 2379., 2496., 1773., 942.], [840., 1788., 2847., 2964., 2097., 1110.],
[696., 1471., 2327., 2414., 1697., 893.], [420., 883., 1390., 1438., 1006., 527.]]]]]
).astype(np.float32)
output3 = vmap(conv2d_input, (1, 0, None))(batch_dout, batch_x, w)
assert np.allclose(output3.asnumpy(), expected3, 0.0001, 0.0001)


@@ -23,6 +23,7 @@ from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.functional import vmap
class NetConv2d(nn.Cell):
@@ -265,3 +266,37 @@ def test_conv_NHWC():
conv2d = NetConvNHWC(w1, x1)
output = conv2d()
assert (output.asnumpy() == expected).all()
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_conv2d_vmap():
"""
Feature: Conv2D op
Description: Test vmap rule for Conv2D op
    Expectation: The outputs match the expected values
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
conv2d = NetConv2d()
batch_x = Tensor(np.arange(2 * 1 * 3 * 3 * 3).reshape(2, 1, 3, 3, 3).astype(np.float32))
w = Tensor(np.ones([2, 3, 1, 1]).astype(np.float32))
expected1 = np.array([[[[[27., 30., 33.], [36., 39., 42.], [45., 48., 51.]],
[[27., 30., 33.], [36., 39., 42.], [45., 48., 51.]]]],
[[[[108., 111., 114.], [117., 120., 123.], [126., 129., 132.]],
[[108., 111., 114.], [117., 120., 123.], [126., 129., 132.]]]]]).astype(np.float32)
output1 = vmap(conv2d, (0, None))(batch_x, w)
assert np.allclose(output1.asnumpy(), expected1, 0.0001, 0.0001)
x = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
batch_w = Tensor(np.ones([2, 2, 3, 1, 1]).astype(np.float32))
expected2 = np.array([[[[[27., 30., 33.], [36., 39., 42.], [45., 48., 51.]],
[[27., 30., 33.], [36., 39., 42.], [45., 48., 51.]]]],
[[[[27., 30., 33.], [36., 39., 42.], [45., 48., 51.]],
[[27., 30., 33.], [36., 39., 42.], [45., 48., 51.]]]]]).astype(np.float32)
output2 = vmap(conv2d, (None, 0))(x, batch_w)
assert np.allclose(output2.asnumpy(), expected2, 0.0001, 0.0001)
output3 = vmap(conv2d, (0, 0))(batch_x, batch_w)
assert np.allclose(output3.asnumpy(), expected1, 0.0001, 0.0001)


@@ -22,6 +22,7 @@ from mindspore import Tensor
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops.functional import vmap
class NetConv3d(nn.Cell):
@@ -154,3 +155,43 @@ def test_conv3d_grad():
output = grad_net(x, dy)
optimizer(output[1])
assert np.allclose(net.cv1.weight.asnumpy(), w_exp, atol=1.0e-4)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_conv3d_vmap():
"""
Feature: Conv3D op
Description: Test vmap rule for Conv3D op
    Expectation: The outputs match the expected values
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
conv3d = NetConv3d()
batch_x = Tensor(np.arange(2 * 1 * 1 * 3 * 3 * 3).reshape(2, 1, 1, 3, 3, 3).astype(np.float32))
w = Tensor(np.ones([4, 1, 2, 2, 2]).astype(np.float32))
expected1 = np.array([[[[[[52., 60.], [76., 84.]], [[124., 132.], [148., 156.]]],
[[[52., 60.], [76., 84.]], [[124., 132.], [148., 156.]]],
[[[52., 60.], [76., 84.]], [[124., 132.], [148., 156.]]],
[[[52., 60.], [76., 84.]], [[124., 132.], [148., 156.]]]]],
[[[[[268., 276.], [292., 300.]], [[340., 348.], [364., 372.]]],
[[[268., 276.], [292., 300.]], [[340., 348.], [364., 372.]]],
[[[268., 276.], [292., 300.]], [[340., 348.], [364., 372.]]],
[[[268., 276.], [292., 300.]], [[340., 348.], [364., 372.]]]]]]).astype(np.float32)
output1 = vmap(conv3d, (0, None))(batch_x, w)
assert np.allclose(output1.asnumpy(), expected1, 0.0001, 0.0001)
x = Tensor(np.arange(1 * 1 * 3 * 3 * 3).reshape(1, 1, 3, 3, 3).astype(np.float32))
batch_w = Tensor(np.arange(2 * 4 * 1 * 2 * 2 * 2).reshape(2, 4, 1, 2, 2, 2).astype(np.float32))
expected2 = np.array([[[[[[268., 296.], [352., 380.]], [[520., 548.], [604., 632.]]],
[[[684., 776.], [960., 1052.]], [[1512., 1604.], [1788., 1880.]]],
[[[1100., 1256.], [1568., 1724.]], [[2504., 2660.], [2972., 3128.]]],
[[[1516., 1736.], [2176., 2396.]], [[3496., 3716.], [4156., 4376.]]]]],
[[[[[1932., 2216.], [2784., 3068.]], [[4488., 4772.], [5340., 5624.]]],
[[[2348., 2696.], [3392., 3740.]], [[5480., 5828.], [6524., 6872.]]],
[[[2764., 3176.], [4000., 4412.]], [[6472., 6884.], [7708., 8120.]]],
[[[3180., 3656.], [4608., 5084.]], [[7464., 7940.], [8892., 9368.]]]]]]
).astype(np.float32)
output2 = vmap(conv3d, (None, 0))(x, batch_w)
assert np.allclose(output2.asnumpy(), expected2, 0.0001, 0.0001)


@@ -20,6 +20,7 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.functional import vmap
class NetConv3dTranspose(nn.Cell):
@@ -63,3 +64,34 @@ def test_conv3d_transpose():
conv3dtranspose = NetConv3dTranspose()
output = conv3dtranspose(x, w)
assert (output.asnumpy() == expect).all()
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_conv3d_transpose_vmap():
"""
Feature: Conv3DTranspose op
Description: Test vmap rule for Conv3DTranspose op
    Expectation: The outputs match the expected values
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
conv3d_trans = NetConv3dTranspose()
batch_dout = Tensor(np.arange(2 * 1 * 2 * 3 * 3 * 3).reshape(2, 1, 2, 3, 3, 3).astype(np.float32))
weight = Tensor(np.ones([2, 2, 2, 2, 2]).astype(np.float32))
expected1 = np.array([[[[[[320., 336.], [368., 384.]], [[464., 480.], [512., 528.]]],
[[[320., 336.], [368., 384.]], [[464., 480.], [512., 528.]]]]],
[[[[[1184., 1200.], [1232., 1248.]], [[1328., 1344.], [1376., 1392.]]],
[[[1184., 1200.], [1232., 1248.]], [[1328., 1344.], [1376., 1392.]]]]]]).astype(np.float32)
output1 = vmap(conv3d_trans, (0, None))(batch_dout, weight)
assert np.allclose(output1.asnumpy(), expected1, 0.0001, 0.0001)
dout = Tensor(np.arange(1 * 2 * 3 * 3 * 3).reshape(1, 2, 3, 3, 3).astype(np.float32))
batch_weight = Tensor(np.ones([2, 2, 2, 2, 2, 2]).astype(np.float32))
expected2 = np.array([[[[[[320., 336.], [368., 384.]], [[464., 480.], [512., 528.]]],
[[[320., 336.], [368., 384.]], [[464., 480.], [512., 528.]]]]],
[[[[[320., 336.], [368., 384.]], [[464., 480.], [512., 528.]]],
[[[320., 336.], [368., 384.]], [[464., 480.], [512., 528.]]]]]]).astype(np.float32)
output2 = vmap(conv3d_trans, (None, 0))(dout, batch_weight)
assert np.allclose(output2.asnumpy(), expected2, 0.0001, 0.0001)