add vmap rules for conv ops
This commit is contained in: parent 3412fa4366 · commit 361685bb72
@@ -17,7 +17,7 @@
 from __future__ import absolute_import

 from . import vmap_base, vmap_array_ops, vmap_grad_nn_ops, vmap_debug_ops, vmap_math_ops, vmap_nn_ops,\
-    vmap_image_ops, vmap_other_ops, vmap_sparse_ops, vmap_random_ops
+    vmap_image_ops, vmap_other_ops, vmap_sparse_ops, vmap_random_ops, vmap_convolution_ops
 from .vmap_base import get_vmap_rule, vmap_monad_rule, _broadcast_by_axis, vmap_bind_all_none,\
     vmap_unstack, vmap_general_output_process

@@ -26,6 +26,7 @@ from ..composite import _VmapGeneralPreprocess
 from ..primitive import Primitive
 from ..operations.random_ops import UniformCandidateSampler
 from ...common import Tensor
+from ..operations import nn_ops as nps

 vmap_rules_getters = Registry()
 vmap_rules = Registry()
@@ -394,4 +395,12 @@ _ops_vmap_clone_prim_dict = {"ApplyAdaMax": P.ApplyAdaMax,
                              "UniformCandidateSampler": UniformCandidateSampler,
                              "CdistGrad": G.CdistGrad,
                              "Cdist": P.Cdist,
-                             "STFT": math_ops.STFT}
+                             "STFT": math_ops.STFT,
+                             "Conv2D": P.Conv2D,
+                             "Conv3D": P.Conv3D,
+                             "Conv2DTranspose": P.Conv2DTranspose,
+                             "Conv2DBackpropInput": P.Conv2DBackpropInput,
+                             "Conv3DTranspose": P.Conv3DTranspose,
+                             "Conv3DBackpropInput": nps.Conv3DBackpropInput,
+                             "Conv2DBackpropFilter": G.Conv2DBackpropFilter,
+                             "Conv3DBackpropFilter": G.Conv3DBackpropFilter}
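Why these entries matter: the convolution vmap rules added below mutate primitive attributes such as 'group' and 'out_channel' in place, so each rule works on a clone of the user's primitive rather than the original. Presumably `_vmap_clone_prim` uses this name-to-class table to rebuild a fresh instance; a hedged sketch of such a lookup (the helper name `_clone_by_name` and the use of `init_attrs` are assumptions, not taken from this diff):

def _clone_by_name(prim):
    # Hypothetical; the real _vmap_clone_prim lives in vmap_base and may differ.
    # Rebuild a fresh primitive of the same class from its constructor attributes,
    # so the vmap rule can retune 'group'/'out_channel' without touching the original.
    return _ops_vmap_clone_prim_dict[prim.name](**prim.init_attrs)
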
@@ -0,0 +1,442 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""convolution vmap impl"""
+
+import numpy as np
+import mindspore.numpy as mnp
+from mindspore.ops import constexpr
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from ..operations import nn_ops as nps
+from ..operations import _grad_ops as G
+from ..primitive import Primitive
+from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
+    _vmap_update_prim_attr, _vmap_clone_prim
+
+
+@vmap_rules_getters.register(P.Conv2D)
+@vmap_rules_getters.register(P.Conv3D)
+def get_conv_vmap_rule(prim, axis_size):
+    """Vmap rule for `Conv2D` and `Conv3D` operations."""
+    if isinstance(prim, str):
+        prim = Primitive(prim)
+
+    attr_list = [prim.name, prim.group, prim.data_format]
+    new_prim = _vmap_clone_prim(prim)
+
+    def vmap_rule(input_bdim, weight_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, input_bdim, weight_bdim)
+        if is_all_none:
+            return result
+        return _conv_vmap_rule(new_prim, axis_size, input_bdim, weight_bdim, attr_list)
+
+    return vmap_rule
+
+
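The registered getter above is what `vmap` dispatches to when a batch axis reaches a convolution. A minimal usage sketch (the shapes and the `out_channel`/`kernel_size` values are illustrative assumptions, not from this diff):

import numpy as np
import mindspore as ms
from mindspore.ops import operations as P
from mindspore.ops.functional import vmap

conv2d = P.Conv2D(out_channel=8, kernel_size=3)
batch_x = ms.Tensor(np.ones((4, 2, 3, 8, 8), np.float32))  # vmap axis B=4 at dim 0
w = ms.Tensor(np.ones((8, 3, 3, 3), np.float32))
out = vmap(conv2d, (0, None))(batch_x, w)  # (4, 2, 8, 6, 6); resolves to get_conv_vmap_rule
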
+@vmap_rules_getters.register(P.Conv2DTranspose)
+@vmap_rules_getters.register(P.Conv2DBackpropInput)
+def get_conv2d_transpose_vmap_rule(prim, axis_size):
+    """Vmap rule for `Conv2DTranspose` and `Conv2DBackpropInput` operations."""
+    if isinstance(prim, str):
+        prim = Primitive(prim)
+
+    attr_list = [prim.name, prim.group, prim.data_format]
+    new_prim = _vmap_clone_prim(prim)
+
+    def vmap_rule(dout_bdim, weight_bdim, input_size_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, dout_bdim, weight_bdim, input_size_bdim)
+        if is_all_none:
+            return result
+        return _conv_transpose_vmap_rule(new_prim, axis_size, dout_bdim,
+                                         weight_bdim, input_size_bdim, attr_list)
+
+    return vmap_rule
+
+
+@vmap_rules_getters.register(P.Conv3DTranspose)
+def get_conv3d_transpose_vmap_rule(prim, axis_size):
+    """Vmap rule for `Conv3DTranspose` operation."""
+    if isinstance(prim, str):
+        prim = Primitive(prim)
+
+    attr_list = [prim.name, prim.group, prim.data_format]
+    new_prim = _vmap_clone_prim(prim)
+
+    def vmap_rule(dout_bdim, weight_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, dout_bdim, weight_bdim)
+        if is_all_none:
+            return result
+        return _conv_transpose_vmap_rule(new_prim, axis_size, dout_bdim, weight_bdim, None, attr_list)
+
+    return vmap_rule
+
+
+@vmap_rules_getters.register(nps.Conv3DBackpropInput)
+def get_conv3d_backprop_input_vmap_rule(prim, axis_size):
+    """Vmap rule for `Conv3DBackpropInput` operation."""
+    if isinstance(prim, str):
+        prim = Primitive(prim)
+
+    attr_list = [prim.name, prim.group, prim.data_format]
+    new_prim = _vmap_clone_prim(prim)
+
+    def vmap_rule(weight_bdim, dout_bdim, input_size_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, weight_bdim, dout_bdim, input_size_bdim)
+        if is_all_none:
+            return result
+        return _conv_transpose_vmap_rule(new_prim, axis_size, dout_bdim,
+                                         weight_bdim, input_size_bdim, attr_list)
+
+    return vmap_rule
+
+
+@vmap_rules_getters.register(G.Conv2DBackpropFilter)
+def get_conv2d_backprop_filter_vmap_rule(prim, axis_size):
+    """Vmap rule for `Conv2DBackpropFilter` operation."""
+    if isinstance(prim, str):
+        prim = Primitive(prim)
+
+    attr_list = [prim.name, prim.group, prim.data_format]
+    new_prim = _vmap_clone_prim(prim)
+
+    def vmap_rule(dout_bdim, input_x_bdim, weight_size_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, dout_bdim, input_x_bdim, weight_size_bdim)
+        if is_all_none:
+            return result
+        return _conv_backprop_filter_vmap_rule(new_prim, axis_size, dout_bdim,
+                                               input_x_bdim, weight_size_bdim, attr_list)
+
+    return vmap_rule
+
+
+@vmap_rules_getters.register(G.Conv3DBackpropFilter)
+def get_conv3d_backprop_filter_vmap_rule(prim, axis_size):
+    """Vmap rule for `Conv3DBackpropFilter` operation."""
+    if isinstance(prim, str):
+        prim = Primitive(prim)
+
+    attr_list = [prim.name, prim.group, prim.data_format]
+    new_prim = _vmap_clone_prim(prim)
+
+    def vmap_rule(input_x_bdim, dout_bdim, weight_size_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, input_x_bdim, dout_bdim, weight_size_bdim)
+        if is_all_none:
+            return result
+        return _conv_backprop_filter_vmap_rule(new_prim, axis_size, dout_bdim,
+                                               input_x_bdim, weight_size_bdim, attr_list)
+
+    return vmap_rule
+
+
+@constexpr
+def _get_reshape_src_dim(data_dim, cmp_dim):
+    """Get the expand and merge source dims for a grouped reshape."""
+    if data_dim > cmp_dim:
+        expand_dim = cmp_dim
+        merge_dim = data_dim + 1
+    else:
+        expand_dim = cmp_dim + 1
+        merge_dim = data_dim
+    return expand_dim, merge_dim
+
+
+@constexpr
+def _get_merge_shape(src_dim, dst_dim, shape):
+    """Get the new shape that merges src_dim into dst_dim; dst_dim indexes the shape left after src_dim is removed."""
+    new_shape = [shape[i] for i in range(len(shape)) if i != src_dim]
+    new_shape[dst_dim] *= shape[src_dim]
+    return tuple(new_shape)
+
+
+def _reshape_merge_dims(src_dim, dst_dim, target):
+    """Reshape target by merging src_dim into dst_dim."""
+    shape = F.shape(target)
+    new_shape = _get_merge_shape(src_dim, dst_dim, shape)
+    new_target = mnp.moveaxis(target, src_dim, dst_dim)
+    output = F.reshape(new_target, new_shape)
+    return output, new_shape
+
+
+@constexpr
+def _get_expand_shape(src_dim, dst_size, shape, prim_name):
+    """Get the new shape that splits src_dim into dst_size parts."""
+    dst_size2, remainder = np.divmod(shape[src_dim], dst_size)
+    if remainder != 0:
+        _raise_value_error("The remainder of {} / {} should be 0, "
+                           "but got {} in {}.".format(shape[src_dim], dst_size, remainder, prim_name))
+    new_shape = list(shape)
+    new_shape[src_dim:(src_dim + 1)] = [dst_size, dst_size2]
+    return tuple(new_shape)
+
+
+def _reshape_expand_dims(src_dim, dst_size, target, prim_name):
+    """Reshape target by splitting src_dim into dst_size parts."""
+    shape = F.shape(target)
+    new_shape = _get_expand_shape(src_dim, dst_size, shape, prim_name)
+    return F.reshape(target, new_shape)
+
+
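To make the two reshape helpers concrete, here is the same merge/expand arithmetic in plain NumPy (shapes are illustrative only):

import numpy as np

t = np.zeros((4, 2, 3, 5, 5))  # (B, N, C, H, W): vmap axis B at dim 0
# _reshape_merge_dims(0, 1, t): dst_dim indexes the shape after B is removed, so 1 is the C axis.
merged = np.moveaxis(t, 0, 1).reshape(2, 4 * 3, 5, 5)  # (N, B*C, H, W)
# _reshape_expand_dims(1, 4, merged): split B back out of the merged channel axis.
restored = merged.reshape(2, 4, 3, 5, 5)
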
+@constexpr
+def _get_new_size_by_index(input_size, batch_size, index):
+    """Get the new size of input_size by multiplying input_size[index] by batch_size."""
+    if input_size is None:
+        return None
+    new_size = list(input_size)
+    new_size[index] *= batch_size
+    return tuple(new_size)
+
+
+@constexpr
+def _update_group_attr(prim, groups, batch_size):
+    """Set a new value for the 'group' attribute of the convolution primitive."""
+    group = groups * batch_size
+    _vmap_update_prim_attr(prim, 'group', group)
+    _vmap_update_prim_attr(prim, 'groups', group)
+
+
+@constexpr
+def _get_channel_index(data_format, prim_name):
+    """Get the channel index from data_format; only NHWC/NCHW/NCDHW are supported for now."""
+    index = 0
+    if data_format == "NHWC":
+        index = 3
+    elif data_format in ("NCHW", "NCDHW"):
+        index = 1
+    else:
+        _raise_value_error("'data_format' in {} should be NHWC/NCHW/NCDHW, "
+                           "but got {}.".format(prim_name, data_format))
+    return index
+
+
+def _conv_vmap_rule(prim, batch_size, input_bdim, weight_bdim, attr_list):
+    """Vmap rule for convolution operations, such as `Conv2D` and `Conv3D`."""
+    input_x, x_dim = input_bdim
+    weight, w_dim = weight_bdim
+    prim_name = attr_list[0]
+    groups = attr_list[1]
+    data_format = attr_list[2]
+    c_axis = _get_channel_index(data_format, prim_name)
+
+    def _get_output_for_x_w_vmap():
+        # Both `x` and `weight` are batched: merge the vmap axis of `x` into channels
+        # and that of `weight` into output channels, then run one grouped convolution.
+        new_input, _ = _reshape_merge_dims(x_dim, c_axis, input_x)
+        new_weight, new_w_shape = _reshape_merge_dims(w_dim, 0, weight)
+
+        _update_group_attr(prim, groups, batch_size)
+        _vmap_update_prim_attr(prim, 'out_channel', new_w_shape[0])
+        out = prim(new_input, new_weight)
+        out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
+        return (out, c_axis)
+
+    def _get_output_for_x_vmap():
+        # Only `x` is batched: fold the vmap axis into the N axis.
+        new_input, _ = _reshape_merge_dims(x_dim, 0, input_x)
+        out = prim(new_input, weight)
+        out = _reshape_expand_dims(0, batch_size, out, prim_name)
+        return (out, 0)
+
+    def _get_output_for_w_vmap():
+        # Only `weight` is batched: fold the vmap axis into the output channels.
+        if groups > 1:
+            expand_dim, merge_dim = _get_reshape_src_dim(w_dim, 0)
+            new_weight = _reshape_expand_dims(expand_dim, groups, weight, prim_name)
+            new_weight, _ = _reshape_merge_dims(merge_dim, 1, new_weight)
+            new_weight, new_w_shape = _reshape_merge_dims(0, 0, new_weight)
+
+            _vmap_update_prim_attr(prim, 'out_channel', new_w_shape[0])
+            out = prim(input_x, new_weight)
+
+            out = _reshape_expand_dims(c_axis, groups, out, prim_name)
+            out = _reshape_expand_dims(c_axis + 1, batch_size, out, prim_name)
+            out, _ = _reshape_merge_dims(c_axis, c_axis + 1, out)
+            return (out, c_axis)
+
+        new_weight, new_w_shape = _reshape_merge_dims(w_dim, 0, weight)
+        _vmap_update_prim_attr(prim, 'out_channel', new_w_shape[0])
+        out = prim(input_x, new_weight)
+        out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
+        return (out, c_axis)
+
+    if x_dim is not None and w_dim is not None:
+        if prim_name == "Conv3D":
+            _raise_value_error("vmap in_axes of 'x' and 'weight' in `{}` cannot be non-None at the same time, "
+                               "but got {} and {}.".format(prim_name, x_dim, w_dim))
+        output = _get_output_for_x_w_vmap()
+    elif x_dim is not None:
+        output = _get_output_for_x_vmap()
+    else:
+        output = _get_output_for_w_vmap()
+    return output
+
+
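The `x`-and-`weight` branch above leans on grouped convolution: once `group` is multiplied by the vmap batch size, each batch element convolves only with its own slice of channels. A standalone sketch of that layout (all sizes here are illustrative assumptions):

import numpy as np
import mindspore as ms
from mindspore.ops import operations as P

vmap_b, n, cin, cout = 3, 2, 4, 8
x = ms.Tensor(np.random.randn(n, vmap_b * cin, 5, 5).astype(np.float32))     # vmap axis merged into channels
w = ms.Tensor(np.random.randn(vmap_b * cout, cin, 3, 3).astype(np.float32))  # vmap axis merged into out channels
conv = P.Conv2D(out_channel=vmap_b * cout, kernel_size=3, group=vmap_b)
y = conv(x, w)  # (n, vmap_b * cout, 3, 3); splitting axis 1 by vmap_b recovers the vmap axis
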
+def _conv_transpose_vmap_rule(prim, batch_size, dout_bdim, weight_bdim, input_size_bdim, attr_list):
+    """
+    Vmap rule for transposed convolution operations, such as `Conv2DTranspose`,
+    `Conv2DBackpropInput`, `Conv3DTranspose` and `Conv3DBackpropInput`.
+    """
+    prim_name = attr_list[0]
+    input_size = None
+    if input_size_bdim is not None:
+        input_size, input_size_dim = input_size_bdim
+        if input_size_dim is not None:
+            _raise_value_error("Vmap in_axes of 'input_size' in `{}` must be None, "
+                               "but got {}.".format(prim_name, input_size_dim))
+        if not isinstance(input_size, tuple):
+            _raise_value_error("Unsupported vmap for dynamic shape of `{}` when "
+                               "'input_size' is a tensor.".format(prim_name))
+
+    dout, dout_dim = dout_bdim
+    weight, w_dim = weight_bdim
+
+    groups = attr_list[1]
+    data_format = attr_list[2]
+    c_axis = _get_channel_index(data_format, prim_name)
+
+    def _get_conv_transpose_output(dout, weight, input_size):
+        # The registered primitives take their inputs in different orders.
+        out = None
+        if prim_name in ('Conv2DTranspose', 'Conv2DBackpropInput'):
+            out = prim(dout, weight, input_size)
+        elif prim_name == "Conv3DTranspose":
+            out = prim(dout, weight)
+        elif prim_name == "Conv3DBackpropInput":
+            out = prim(weight, dout, input_size)
+        else:
+            _raise_value_error("Unsupported operation: `{}`.".format(prim_name))
+        return out
+
+    def _get_output_for_dout_weight_vmap():
+        _update_group_attr(prim, groups, batch_size)
+        new_dout, _ = _reshape_merge_dims(dout_dim, c_axis, dout)
+        new_weight, _ = _reshape_merge_dims(w_dim, 0, weight)
+        new_input_size = _get_new_size_by_index(input_size, batch_size, c_axis)
+
+        out = _get_conv_transpose_output(new_dout, new_weight, new_input_size)
+        out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
+        return (out, c_axis)
+
+    def _get_output_for_dout_vmap():
+        new_dout, _ = _reshape_merge_dims(dout_dim, 0, dout)
+        new_input_size = _get_new_size_by_index(input_size, batch_size, 0)
+
+        out = _get_conv_transpose_output(new_dout, weight, new_input_size)
+        out = _reshape_expand_dims(0, batch_size, out, prim_name)
+        return (out, 0)
+
+    def _get_output_for_weight_vmap():
+        new_weight, _ = _reshape_merge_dims(w_dim, c_axis, weight)
+        new_input_size = _get_new_size_by_index(input_size, batch_size, c_axis)
+
+        out = _get_conv_transpose_output(dout, new_weight, new_input_size)
+
+        if groups > 1:
+            out = _reshape_expand_dims(c_axis, groups, out, prim_name)
+            out = _reshape_expand_dims(c_axis + 1, batch_size, out, prim_name)
+            out, _ = _reshape_merge_dims(c_axis, c_axis + 1, out)
+        else:
+            out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
+        return (out, c_axis)
+
+    if dout_dim is not None and w_dim is not None:
+        if prim_name in ("Conv3DTranspose", "Conv3DBackpropInput"):
+            _raise_value_error("vmap in_axes of 'dout' and 'weight' in `{}` cannot be non-None at the same time, "
+                               "but got {} and {}.".format(prim_name, dout_dim, w_dim))
+        output = _get_output_for_dout_weight_vmap()
+    elif dout_dim is not None:
+        output = _get_output_for_dout_vmap()
+    else:
+        output = _get_output_for_weight_vmap()
+    return output
+
+
+def _conv_backprop_filter_vmap_rule(prim, batch_size, dout_bdim, input_bdim, weight_size_bdim, attr_list):
+    """Vmap rule for `Conv2DBackpropFilter` and `Conv3DBackpropFilter` operations."""
+    dout, dout_dim = dout_bdim
+    input_x, x_dim = input_bdim
+    weight_size, w_size_dim = weight_size_bdim
+
+    prim_name = attr_list[0]
+    groups = attr_list[1]
+    data_format = attr_list[2]
+    c_axis = _get_channel_index(data_format, prim_name)
+
+    if w_size_dim is not None:
+        _raise_value_error("Vmap in_axes of 'weight_size' in `{}` must be None, "
+                           "but got {}.".format(prim_name, w_size_dim))
+    if not isinstance(weight_size, tuple):
+        _raise_value_error("Unsupported vmap for dynamic shape of `{}` when "
+                           "'weight_size' is a tensor.".format(prim_name))
+
+    def _get_conv_backprop_filter_output(dout, x, weight_size):
+        # The two registered primitives take (dout, x) in opposite orders.
+        out = None
+        if prim_name == "Conv2DBackpropFilter":
+            out = prim(dout, x, weight_size)
+        elif prim_name == "Conv3DBackpropFilter":
+            out = prim(x, dout, weight_size)
+        else:
+            _raise_value_error("Unsupported operation: `{}`.".format(prim_name))
+        return out
+
+    def _get_output_for_dout_x_vmap():
+        _update_group_attr(prim, groups, batch_size)
+
+        new_dout, _ = _reshape_merge_dims(dout_dim, c_axis, dout)
+        new_input, _ = _reshape_merge_dims(x_dim, c_axis, input_x)
+        new_w_size = _get_new_size_by_index(weight_size, batch_size, 0)
+
+        out = _get_conv_backprop_filter_output(new_dout, new_input, new_w_size)
+        out = _reshape_expand_dims(0, batch_size, out, prim_name)
+        return (out, 0)
+
+    def _get_output_for_x_vmap():
+        new_w_size = _get_new_size_by_index(weight_size, batch_size, c_axis)
+        if groups > 1:
+            expand_dim, merge_dim = _get_reshape_src_dim(x_dim, c_axis)
+            new_input = _reshape_expand_dims(expand_dim, groups, input_x, prim_name)
+            new_input, _ = _reshape_merge_dims(merge_dim, c_axis + 1, new_input)
+            new_input, _ = _reshape_merge_dims(c_axis, c_axis, new_input)
+        else:
+            new_input, _ = _reshape_merge_dims(x_dim, c_axis, input_x)
+
+        out = _get_conv_backprop_filter_output(dout, new_input, new_w_size)
+        out = _reshape_expand_dims(c_axis, batch_size, out, prim_name)
+        return (out, c_axis)
+
+    def _get_output_for_dout_vmap():
+        new_w_size = _get_new_size_by_index(weight_size, batch_size, 0)
+        if groups > 1:
+            expand_dim, merge_dim = _get_reshape_src_dim(dout_dim, c_axis)
+            new_dout = _reshape_expand_dims(expand_dim, groups, dout, prim_name)
+            new_dout, _ = _reshape_merge_dims(merge_dim, c_axis + 1, new_dout)
+            new_dout, _ = _reshape_merge_dims(c_axis, c_axis, new_dout)
+
+            out = _get_conv_backprop_filter_output(new_dout, input_x, new_w_size)
+            out = _reshape_expand_dims(0, groups, out, prim_name)
+            out = _reshape_expand_dims(1, batch_size, out, prim_name)
+            out, _ = _reshape_merge_dims(0, 1, out)
+            return (out, 0)
+
+        new_dout, _ = _reshape_merge_dims(dout_dim, c_axis, dout)
+        out = _get_conv_backprop_filter_output(new_dout, input_x, new_w_size)
+        out = _reshape_expand_dims(0, batch_size, out, prim_name)
+        return (out, 0)
+
+    if dout_dim is not None and x_dim is not None:
+        if prim_name == "Conv3DBackpropFilter":
+            _raise_value_error("vmap in_axes of 'dout' and 'x' in `{}` cannot be non-None at the same time, "
+                               "but got {} and {}.".format(prim_name, dout_dim, x_dim))
+        output = _get_output_for_dout_x_vmap()
+    elif x_dim is not None:
+        output = _get_output_for_x_vmap()
+    else:
+        output = _get_output_for_dout_vmap()
+    return output
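A useful sanity check on any of these rules is that vmap must agree with a per-sample Python loop; a minimal sketch (assumes a MindSpore build where the Conv2D rule above is in effect; sizes are illustrative):

import numpy as np
import mindspore as ms
from mindspore.ops import operations as P
from mindspore.ops.functional import vmap

conv2d = P.Conv2D(out_channel=2, kernel_size=1)
batch_x = ms.Tensor(np.random.randn(3, 1, 4, 2, 2).astype(np.float32))  # vmap axis at dim 0
w = ms.Tensor(np.ones([2, 4, 1, 1]).astype(np.float32))

vmapped = vmap(conv2d, (0, None))(batch_x, w).asnumpy()
looped = np.stack([conv2d(batch_x[i], w).asnumpy() for i in range(3)])  # per-sample reference
assert np.allclose(vmapped, looped, 1e-4, 1e-4)
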
@@ -22,6 +22,7 @@ from mindspore import Tensor
 from mindspore.common.api import ms_function
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _grad_ops as G
+from mindspore.ops.functional import vmap

 context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

@@ -70,3 +71,37 @@ def test_conv2d_backprop_filter():
                             [-104, -211, -322],
                             [-102, -144, -248]]]]).astype(np.float32)
     assert (abs(output.asnumpy() - expect) < np.ones(shape=[1, 1, 3, 3]) * 1.0e-4).all()
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_conv2d_backprop_filter_vmap():
+    """
+    Feature: Conv2DBackpropFilter op
+    Description: Test vmap rule for Conv2DBackpropFilter op
+    Expectation: The output matches the expected value
+    """
+    conv2d_filter = Conv2dFilter()
+    batch_out = Tensor(np.arange(1 * 2 * 1 * 4 * 4).reshape(1, 2, 1, 4, 4).astype(np.float32))
+    x = Tensor(np.arange(1 * 1 * 6 * 6).reshape(1, 1, 6, 6).astype(np.float32))
+    w = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
+    expected1 = np.array([[[[[1760., 1880., 2000.], [2480., 2600., 2720.], [3200., 3320., 3440.]]]],
+                          [[[[4448., 4824., 5200.], [6704., 7080., 7456.], [8960., 9336., 9712.]]]]]
+                         ).astype(np.float32)
+    output1 = vmap(conv2d_filter, (1, None, None))(batch_out, x, w)
+    assert np.allclose(output1.asnumpy(), expected1, 0.0001, 0.0001)
+
+    dout = Tensor(np.arange(1 * 1 * 4 * 4).reshape(1, 1, 4, 4).astype(np.float32))
+    batch_x = Tensor(np.arange(2 * 1 * 1 * 6 * 6).reshape(2, 1, 1, 6, 6).astype(np.float32))
+    expected2 = np.array([[[[[1760., 1880., 2000.], [2480., 2600., 2720.], [3200., 3320., 3440.]]]],
+                          [[[[6080., 6200., 6320.], [6800., 6920., 7040.], [7520., 7640., 7760.]]]]]
+                         ).astype(np.float32)
+    output2 = vmap(conv2d_filter, (None, 0, None))(dout, batch_x, w)
+    assert np.allclose(output2.asnumpy(), expected2, 0.0001, 0.0001)
+
+    expected3 = np.array([[[[[1760., 1880., 2000.], [2480., 2600., 2720.], [3200., 3320., 3440.]]]],
+                          [[[[17984., 18360., 18736.], [20240., 20616., 20992.], [22496., 22872., 23248.]]]]]
+                         ).astype(np.float32)
+    output3 = vmap(conv2d_filter, (1, 0, None))(batch_out, batch_x, w)
+    assert np.allclose(output3.asnumpy(), expected3, 0.0001, 0.0001)

@@ -21,6 +21,7 @@ import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common.api import ms_function
 from mindspore.ops import operations as P
+from mindspore.ops.functional import vmap

 context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

@@ -73,3 +74,50 @@ def test_conv2d_backprop_input():
                           [-3, -2, 0, -14, 3, 16]]]]).astype(np.float32)

     assert (abs(output.asnumpy() - expect) < np.ones(shape=[1, 1, 6, 6]) * 1.0e-4).all()
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_conv2d_backprop_input_vmap():
+    """
+    Feature: Conv2DBackpropInput op
+    Description: Test vmap rule for Conv2DBackpropInput op
+    Expectation: The output matches the expected value
+    """
+    conv2d_input = Conv2dInput()
+
+    batch_dout = Tensor(np.arange(1 * 2 * 1 * 4 * 4).reshape(1, 2, 1, 4, 4).astype(np.float32))
+    x = Tensor(np.arange(1 * 1 * 3 * 3).reshape(1, 1, 3, 3).astype(np.float32))
+    w = Tensor(np.ones([1, 1, 6, 6]).astype(np.float32))
+    expected1 = np.array([[[[[0., 0., 1., 4., 7., 6.], [0., 7., 23., 38., 41., 29.],
+                             [12., 45., 102., 138., 126., 81.], [48., 129., 246., 282., 234., 141.],
+                             [84., 197., 341., 374., 287., 163.], [72., 162., 271., 292., 217., 120.]]]],
+                          [[[[0., 16., 49., 52., 55., 38.], [48., 135., 263., 278., 233., 141.],
+                             [156., 381., 678., 714., 558., 321.], [192., 465., 822., 858., 666., 381.],
+                             [228., 517., 869., 902., 671., 371.], [168., 370., 607., 628., 457., 248.]]]]]
+                         ).astype(np.float32)
+    output1 = vmap(conv2d_input, (1, None, None))(batch_dout, x, w)
+    assert np.allclose(output1.asnumpy(), expected1, 0.0001, 0.0001)
+
+    dout = Tensor(np.arange(1 * 1 * 4 * 4).reshape(1, 1, 4, 4).astype(np.float32))
+    batch_x = Tensor(np.arange(2 * 1 * 1 * 3 * 3).reshape(2, 1, 1, 3, 3).astype(np.float32))
+    expected2 = np.array([[[[[0., 0., 1., 4., 7., 6.], [0., 7., 23., 38., 41., 29.],
+                             [12., 45., 102., 138., 126., 81.], [48., 129., 246., 282., 234., 141.],
+                             [84., 197., 341., 374., 287., 163.], [72., 162., 271., 292., 217., 120.]]]],
+                          [[[[0., 9., 28., 58., 52., 33.], [36., 97., 185., 254., 203., 119.],
+                             [120., 288., 507., 624., 477., 270.], [264., 588., 975., 1092., 801., 438.],
+                             [264., 575., 935., 1022., 737., 397.], [180., 387., 622., 670., 478., 255.]]]]]
+                         ).astype(np.float32)
+    output2 = vmap(conv2d_input, (None, 0, None))(dout, batch_x, w)
+    assert np.allclose(output2.asnumpy(), expected2, 0.0001, 0.0001)
+
+    expected3 = np.array([[[[[0., 0., 1., 4., 7., 6.], [0., 7., 23., 38., 41., 29.],
+                             [12., 45., 102., 138., 126., 81.], [48., 129., 246., 282., 234., 141.],
+                             [84., 197., 341., 374., 287., 163.], [72., 162., 271., 292., 217., 120.]]]],
+                          [[[[144., 313., 508., 538., 388., 209.], [372., 801., 1289., 1358., 971., 519.],
+                             [696., 1488., 2379., 2496., 1773., 942.], [840., 1788., 2847., 2964., 2097., 1110.],
+                             [696., 1471., 2327., 2414., 1697., 893.], [420., 883., 1390., 1438., 1006., 527.]]]]]
+                         ).astype(np.float32)
+    output3 = vmap(conv2d_input, (1, 0, None))(batch_dout, batch_x, w)
+    assert np.allclose(output3.asnumpy(), expected3, 0.0001, 0.0001)

@@ -23,6 +23,7 @@ from mindspore.ops import operations as P
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
+from mindspore.ops.functional import vmap


 class NetConv2d(nn.Cell):
@@ -265,3 +266,37 @@ def test_conv_NHWC():
     conv2d = NetConvNHWC(w1, x1)
     output = conv2d()
     assert (output.asnumpy() == expected).all()
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_conv2d_vmap():
+    """
+    Feature: Conv2D op
+    Description: Test vmap rule for Conv2D op
+    Expectation: The output matches the expected value
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    conv2d = NetConv2d()
+
+    batch_x = Tensor(np.arange(2 * 1 * 3 * 3 * 3).reshape(2, 1, 3, 3, 3).astype(np.float32))
+    w = Tensor(np.ones([2, 3, 1, 1]).astype(np.float32))
+    expected1 = np.array([[[[[27., 30., 33.], [36., 39., 42.], [45., 48., 51.]],
+                            [[27., 30., 33.], [36., 39., 42.], [45., 48., 51.]]]],
+                          [[[[108., 111., 114.], [117., 120., 123.], [126., 129., 132.]],
+                            [[108., 111., 114.], [117., 120., 123.], [126., 129., 132.]]]]]).astype(np.float32)
+    output1 = vmap(conv2d, (0, None))(batch_x, w)
+    assert np.allclose(output1.asnumpy(), expected1, 0.0001, 0.0001)
+
+    x = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
+    batch_w = Tensor(np.ones([2, 2, 3, 1, 1]).astype(np.float32))
+    expected2 = np.array([[[[[27., 30., 33.], [36., 39., 42.], [45., 48., 51.]],
+                            [[27., 30., 33.], [36., 39., 42.], [45., 48., 51.]]]],
+                          [[[[27., 30., 33.], [36., 39., 42.], [45., 48., 51.]],
+                            [[27., 30., 33.], [36., 39., 42.], [45., 48., 51.]]]]]).astype(np.float32)
+    output2 = vmap(conv2d, (None, 0))(x, batch_w)
+    assert np.allclose(output2.asnumpy(), expected2, 0.0001, 0.0001)
+
+    output3 = vmap(conv2d, (0, 0))(batch_x, batch_w)
+    assert np.allclose(output3.asnumpy(), expected1, 0.0001, 0.0001)

@@ -22,6 +22,7 @@ from mindspore import Tensor
 from mindspore.common.parameter import ParameterTuple
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
+from mindspore.ops.functional import vmap


 class NetConv3d(nn.Cell):
@@ -154,3 +155,43 @@ def test_conv3d_grad():
     output = grad_net(x, dy)
     optimizer(output[1])
     assert np.allclose(net.cv1.weight.asnumpy(), w_exp, atol=1.0e-4)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_conv3d_vmap():
+    """
+    Feature: Conv3D op
+    Description: Test vmap rule for Conv3D op
+    Expectation: The output matches the expected value
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    conv3d = NetConv3d()
+
+    batch_x = Tensor(np.arange(2 * 1 * 1 * 3 * 3 * 3).reshape(2, 1, 1, 3, 3, 3).astype(np.float32))
+    w = Tensor(np.ones([4, 1, 2, 2, 2]).astype(np.float32))
+    expected1 = np.array([[[[[[52., 60.], [76., 84.]], [[124., 132.], [148., 156.]]],
+                            [[[52., 60.], [76., 84.]], [[124., 132.], [148., 156.]]],
+                            [[[52., 60.], [76., 84.]], [[124., 132.], [148., 156.]]],
+                            [[[52., 60.], [76., 84.]], [[124., 132.], [148., 156.]]]]],
+                          [[[[[268., 276.], [292., 300.]], [[340., 348.], [364., 372.]]],
+                            [[[268., 276.], [292., 300.]], [[340., 348.], [364., 372.]]],
+                            [[[268., 276.], [292., 300.]], [[340., 348.], [364., 372.]]],
+                            [[[268., 276.], [292., 300.]], [[340., 348.], [364., 372.]]]]]]).astype(np.float32)
+    output1 = vmap(conv3d, (0, None))(batch_x, w)
+    assert np.allclose(output1.asnumpy(), expected1, 0.0001, 0.0001)
+
+    x = Tensor(np.arange(1 * 1 * 3 * 3 * 3).reshape(1, 1, 3, 3, 3).astype(np.float32))
+    batch_w = Tensor(np.arange(2 * 4 * 1 * 2 * 2 * 2).reshape(2, 4, 1, 2, 2, 2).astype(np.float32))
+    expected2 = np.array([[[[[[268., 296.], [352., 380.]], [[520., 548.], [604., 632.]]],
+                            [[[684., 776.], [960., 1052.]], [[1512., 1604.], [1788., 1880.]]],
+                            [[[1100., 1256.], [1568., 1724.]], [[2504., 2660.], [2972., 3128.]]],
+                            [[[1516., 1736.], [2176., 2396.]], [[3496., 3716.], [4156., 4376.]]]]],
+                          [[[[[1932., 2216.], [2784., 3068.]], [[4488., 4772.], [5340., 5624.]]],
+                            [[[2348., 2696.], [3392., 3740.]], [[5480., 5828.], [6524., 6872.]]],
+                            [[[2764., 3176.], [4000., 4412.]], [[6472., 6884.], [7708., 8120.]]],
+                            [[[3180., 3656.], [4608., 5084.]], [[7464., 7940.], [8892., 9368.]]]]]]
+                         ).astype(np.float32)
+    output2 = vmap(conv3d, (None, 0))(x, batch_w)
+    assert np.allclose(output2.asnumpy(), expected2, 0.0001, 0.0001)

@@ -20,6 +20,7 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
+from mindspore.ops.functional import vmap


 class NetConv3dTranspose(nn.Cell):
@@ -63,3 +64,34 @@ def test_conv3d_transpose():
     conv3dtranspose = NetConv3dTranspose()
     output = conv3dtranspose(x, w)
     assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_conv3d_transpose_vmap():
+    """
+    Feature: Conv3DTranspose op
+    Description: Test vmap rule for Conv3DTranspose op
+    Expectation: The output matches the expected value
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    conv3d_trans = NetConv3dTranspose()
+
+    batch_dout = Tensor(np.arange(2 * 1 * 2 * 3 * 3 * 3).reshape(2, 1, 2, 3, 3, 3).astype(np.float32))
+    weight = Tensor(np.ones([2, 2, 2, 2, 2]).astype(np.float32))
+    expected1 = np.array([[[[[[320., 336.], [368., 384.]], [[464., 480.], [512., 528.]]],
+                            [[[320., 336.], [368., 384.]], [[464., 480.], [512., 528.]]]]],
+                          [[[[[1184., 1200.], [1232., 1248.]], [[1328., 1344.], [1376., 1392.]]],
+                            [[[1184., 1200.], [1232., 1248.]], [[1328., 1344.], [1376., 1392.]]]]]]).astype(np.float32)
+    output1 = vmap(conv3d_trans, (0, None))(batch_dout, weight)
+    assert np.allclose(output1.asnumpy(), expected1, 0.0001, 0.0001)
+
+    dout = Tensor(np.arange(1 * 2 * 3 * 3 * 3).reshape(1, 2, 3, 3, 3).astype(np.float32))
+    batch_weight = Tensor(np.ones([2, 2, 2, 2, 2, 2]).astype(np.float32))
+    expected2 = np.array([[[[[[320., 336.], [368., 384.]], [[464., 480.], [512., 528.]]],
+                            [[[320., 336.], [368., 384.]], [[464., 480.], [512., 528.]]]]],
+                          [[[[[320., 336.], [368., 384.]], [[464., 480.], [512., 528.]]],
+                            [[[320., 336.], [368., 384.]], [[464., 480.], [512., 528.]]]]]]).astype(np.float32)
+    output2 = vmap(conv3d_trans, (None, 0))(dout, batch_weight)
+    assert np.allclose(output2.asnumpy(), expected2, 0.0001, 0.0001)