From 3f1c71e04a0d5de6c9cf28b854825be34e5be17c Mon Sep 17 00:00:00 2001 From: liqiliang Date: Fri, 27 May 2022 17:51:32 +0800 Subject: [PATCH] Add vmap and its tests. --- .../gpu/kernel/arrays/transpose_gpu_kernel.cc | 16 +++- .../cuda_impl/cuda_ops/transpose_impl.cu | 25 ++++- .../cuda_impl/cuda_ops/transpose_impl_opt.cu | 76 +++++++++++++-- mindspore/core/ops/grad/sqrt_grad.cc | 4 + mindspore/core/ops/sqrt.cc | 1 + mindspore/python/mindspore/common/tensor.py | 1 - .../mindspore/ops/_vmap/vmap_array_ops.py | 53 +++++++++- .../python/mindspore/ops/_vmap/vmap_base.py | 34 ++++++- .../mindspore/ops/_vmap/vmap_math_ops.py | 9 +- .../python/mindspore/ops/_vmap/vmap_nn_ops.py | 38 +++++--- .../mindspore/ops/function/array_func.py | 1 - .../mindspore/ops/operations/array_ops.py | 1 - tests/st/ops/cpu/test_arithmetic_self_op.py | 96 +++++++++++++++++++ tests/st/ops/cpu/test_inv_grad_op.py | 58 +++++++++-- tests/st/ops/cpu/test_matrix_band_part.py | 51 +++++++++- tests/st/ops/cpu/test_mish_grad_op.py | 65 +++++++++++++ tests/st/ops/cpu/test_padding.py | 57 +++++++++++ tests/st/ops/gpu/test_inv_grad_op.py | 58 +++++++++-- tests/st/ops/gpu/test_inv_op.py | 31 ++++++ tests/st/ops/gpu/test_invert_op.py | 29 ++++++ tests/st/ops/gpu/test_matrix_band_part.py | 51 +++++++++- tests/st/ops/gpu/test_mish_grad_op.py | 65 +++++++++++++ tests/st/ops/gpu/test_padding.py | 57 +++++++++++ 23 files changed, 820 insertions(+), 57 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.cc index 99d148bcbaa..2d77ed5357b 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.cc @@ -25,9 +25,21 @@ MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).A TransposeFwdGpuKernelMod, float) MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), TransposeFwdGpuKernelMod, half) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - TransposeFwdGpuKernelMod, int) MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TransposeFwdGpuKernelMod, int64_t) +MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + TransposeFwdGpuKernelMod, int) +MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + TransposeFwdGpuKernelMod, int16_t) +MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + TransposeFwdGpuKernelMod, int8_t) +MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + TransposeFwdGpuKernelMod, uint64_t) +MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + TransposeFwdGpuKernelMod, uint32_t) +MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + TransposeFwdGpuKernelMod, uint16_t) +MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + TransposeFwdGpuKernelMod, uint8_t) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cu
b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cu index b7770b53598..9200b82c4be 100755 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cu @@ -74,11 +74,32 @@ template CUDA_LIB_EXPORT void CalTranspose(const size_t size, const float template CUDA_LIB_EXPORT void CalTranspose(const size_t size, const half *input, const size_t *input_shape, const size_t *input_axis, const size_t shape_size, half *output, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTranspose(const size_t size, const int64_t *input, const size_t *input_shape, + const size_t *input_axis, const size_t shape_size, int64_t *output, + cudaStream_t cuda_stream); template CUDA_LIB_EXPORT void CalTranspose(const size_t size, const int *input, const size_t *input_shape, const size_t *input_axis, const size_t shape_size, int *output, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CalTranspose(const size_t size, const int64_t *input, const size_t *input_shape, - const size_t *input_axis, const size_t shape_size, int64_t *output, +template CUDA_LIB_EXPORT void CalTranspose(const size_t size, const int16_t *input, const size_t *input_shape, + const size_t *input_axis, const size_t shape_size, int16_t *output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTranspose(const size_t size, const int8_t *input, const size_t *input_shape, + const size_t *input_axis, const size_t shape_size, int8_t *output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTranspose(const size_t size, const uint64_t *input, + const size_t *input_shape, const size_t *input_axis, + const size_t shape_size, uint64_t *output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTranspose(const size_t size, const uint32_t *input, + const size_t *input_shape, const size_t *input_axis, + const size_t shape_size, uint32_t *output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTranspose(const size_t size, const uint16_t *input, + const size_t *input_shape, const size_t *input_axis, + const size_t shape_size, uint16_t *output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalTranspose(const size_t size, const uint8_t *input, const size_t *input_shape, + const size_t *input_axis, const size_t shape_size, uint8_t *output, cudaStream_t cuda_stream); template CUDA_LIB_EXPORT void CalTranspose>(const size_t size, const Complex *input, const size_t *input_shape, const size_t *input_axis, diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl_opt.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl_opt.cu index 9cec09b022f..851df327e19 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl_opt.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl_opt.cu @@ -275,15 +275,45 @@ template CUDA_LIB_EXPORT void CalNHWC2NCHWInterface(const size_t size, con const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis, half *d_output, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, const int *d_input, - const size_t *input_shape, const size_t *input_axis, - const size_t *d_input_shape, const size_t *d_input_axis, - int *d_output, cudaStream_t cuda_stream); template CUDA_LIB_EXPORT void CalNHWC2NCHWInterface(const size_t size, const 
size_t shape_size, const int64_t *d_input, const size_t *input_shape, const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis, int64_t *d_output, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, const int *d_input, + const size_t *input_shape, const size_t *input_axis, + const size_t *d_input_shape, const size_t *d_input_axis, + int *d_output, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, + const int16_t *d_input, const size_t *input_shape, + const size_t *input_axis, const size_t *d_input_shape, + const size_t *d_input_axis, int16_t *d_output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, + const int8_t *d_input, const size_t *input_shape, + const size_t *input_axis, const size_t *d_input_shape, + const size_t *d_input_axis, int8_t *d_output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, + const uint64_t *d_input, const size_t *input_shape, + const size_t *input_axis, const size_t *d_input_shape, + const size_t *d_input_axis, uint64_t *d_output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, + const uint32_t *d_input, const size_t *input_shape, + const size_t *input_axis, const size_t *d_input_shape, + const size_t *d_input_axis, uint32_t *d_output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, + const uint16_t *d_input, const size_t *input_shape, + const size_t *input_axis, const size_t *d_input_shape, + const size_t *d_input_axis, uint16_t *d_output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, + const uint8_t *d_input, const size_t *input_shape, + const size_t *input_axis, const size_t *d_input_shape, + const size_t *d_input_axis, uint8_t *d_output, + cudaStream_t cuda_stream); template CUDA_LIB_EXPORT void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const bool *d_input, const size_t *input_shape, @@ -305,12 +335,42 @@ template CUDA_LIB_EXPORT void CalNCHW2NHWCInterface(const size_t size, con const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis, half *d_output, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const int *d_input, - const size_t *input_shape, const size_t *input_axis, - const size_t *d_input_shape, const size_t *d_input_axis, - int *d_output, cudaStream_t cuda_stream); template CUDA_LIB_EXPORT void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const int64_t *d_input, const size_t *input_shape, const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis, int64_t *d_output, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const int *d_input, + const size_t *input_shape, const size_t *input_axis, + const size_t *d_input_shape, const size_t *d_input_axis, + int *d_output, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, + const int16_t *d_input, const size_t *input_shape, + const size_t 
*input_axis, const size_t *d_input_shape, + const size_t *d_input_axis, int16_t *d_output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, + const int8_t *d_input, const size_t *input_shape, + const size_t *input_axis, const size_t *d_input_shape, + const size_t *d_input_axis, int8_t *d_output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, + const uint64_t *d_input, const size_t *input_shape, + const size_t *input_axis, const size_t *d_input_shape, + const size_t *d_input_axis, uint64_t *d_output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, + const uint32_t *d_input, const size_t *input_shape, + const size_t *input_axis, const size_t *d_input_shape, + const size_t *d_input_axis, uint32_t *d_output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, + const uint16_t *d_input, const size_t *input_shape, + const size_t *input_axis, const size_t *d_input_shape, + const size_t *d_input_axis, uint16_t *d_output, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, + const uint8_t *d_input, const size_t *input_shape, + const size_t *input_axis, const size_t *d_input_shape, + const size_t *d_input_axis, uint8_t *d_output, + cudaStream_t cuda_stream); diff --git a/mindspore/core/ops/grad/sqrt_grad.cc b/mindspore/core/ops/grad/sqrt_grad.cc index f515f6a1643..53f69e24653 100644 --- a/mindspore/core/ops/grad/sqrt_grad.cc +++ b/mindspore/core/ops/grad/sqrt_grad.cc @@ -15,6 +15,10 @@ */ #include "ops/grad/sqrt_grad.h" +#include +#include +#include +#include #include "ops/op_utils.h" #include "utils/check_convert_utils.h" #include "abstract/ops/primitive_infer_map.h" diff --git a/mindspore/core/ops/sqrt.cc b/mindspore/core/ops/sqrt.cc index d0b23cc6ef1..94851c5ec5c 100644 --- a/mindspore/core/ops/sqrt.cc +++ b/mindspore/core/ops/sqrt.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include "ops/op_utils.h" #include "utils/check_convert_utils.h" #include "abstract/ops/primitive_infer_map.h" diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index 0288bb0a184..e75bfa835ed 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -1244,7 +1244,6 @@ class Tensor(Tensor_): [1. 1. 1. 0.] [1. 1. 1. 1.] [0. 1. 1. 1.]] - [[1. 1. 0. 0.] [1. 1. 1. 0.] [1. 1. 1. 1.] 
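The Transpose registrations and CalTranspose/CalNHWC2NCHWInterface instantiations above extend the GPU kernel to the remaining signed and unsigned integer types. This is what allows the vmap rules added below to batch integer inputs (for example the int16 Invert test), since moving a batch axis to the front (_bdim_at_front / mnp.moveaxis) is expected to lower to a Transpose. A minimal sketch of the newly covered dtype path; the int16 values and shape here are illustrative assumptions, not part of the patch:

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# An int16 Transpose now resolves to one of the newly registered GPU kernels.
x = Tensor(np.arange(6, dtype=np.int16).reshape(2, 3))
out = P.Transpose()(x, (1, 0))
print(out.shape)  # (3, 2)
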
diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py index 314affcba57..4b4bca5a016 100644 --- a/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py +++ b/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py @@ -447,8 +447,8 @@ def _get_one_hot_vmap_axis(orig_axis, ndim, indices_dim): """Find vmap axis for OneHot.""" if orig_axis >= 0 and indices_dim <= orig_axis: return orig_axis + 1 - if indices_dim == (ndim-1) and orig_axis in (-1, (ndim-1)): - return ndim-1 + if indices_dim == (ndim - 1) and orig_axis in (-1, (ndim - 1)): + return ndim - 1 return orig_axis @@ -522,6 +522,55 @@ def get_masked_select_vmap_rule(prim, axis_size): return vmap_rule +@vmap_rules_getters.register(P.array_ops.MatrixBandPart) +def get_matrix_band_part_vmap_rule(prim, axis_size): + """VmapRule for `MatrixBandPart` operation.""" + if isinstance(prim, str): + prim = Primitive(prim) + + def vmap_rule(x_bdim, lower_bdim, upper_bdim): + is_all_none, result = vmap_general_preprocess(prim, x_bdim, lower_bdim, upper_bdim) + if is_all_none: + return result + + x, x_dim = x_bdim + lower, lower_dim = lower_bdim + upper, upper_dim = upper_bdim + if lower_dim is not None: + _raise_value_error("The source axis of `lower` in `P.array_ops.MatrixBandPart` currently does not support" + "setting to None, but got {}.".format(lower_dim)) + if upper_dim is not None: + _raise_value_error("The source axis of `upper` in `P.array_ops.MatrixBandPart` currently does not support" + "setting to None, but got {}.".format(upper_dim)) + x = _bdim_at_front(x, x_dim, axis_size) + out = prim(x, lower, upper) + return (out, 0) + + return vmap_rule + + +@vmap_rules_getters.register(P.Padding) +def get_padding_vmap_rule(prim, axis_size): + """VmapRule for `Padding` operation.""" + if isinstance(prim, str): + prim = Primitive(prim) + + def vmap_rule(x_bdim): + is_all_none, result = vmap_general_preprocess(prim, x_bdim) + if is_all_none: + return result + + x, x_dim = x_bdim + if F.rank(x) and x_dim in (-1, F.rank(x) - 1): + x = _bdim_at_front(x, x_dim, axis_size) + output = prim(x) + return (output, 0) + output = prim(x) + return (output, x_dim) + + return vmap_rule + + @vmap_rules_getters.register(P.Ger) def get_ger_vmap_rule(prim, axis_size): """VmapRule for `Ger`.""" diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_base.py b/mindspore/python/mindspore/ops/_vmap/vmap_base.py index 3cffa86e70e..977aae9cebc 100644 --- a/mindspore/python/mindspore/ops/_vmap/vmap_base.py +++ b/mindspore/python/mindspore/ops/_vmap/vmap_base.py @@ -24,7 +24,6 @@ from ..composite import _VmapGeneralPreprocess from ..primitive import Primitive from ...common import Tensor - vmap_rules_getters = Registry() vmap_rules = Registry() @@ -185,6 +184,7 @@ def vmap_monad_rule(prim, axis_size): vals = vals + (val,) out = prim(*vals) return (out, None) + return vmap_rule @@ -279,3 +279,35 @@ def get_unsupported_dynamic_vmap_rule(prim, axis_size): return result return vmap_rule + + +def get_unary_grad_vmap_rule(prim, axis_size): + """VmapRule for `UnaryGrad`.""" + if isinstance(prim, str): + prim = Primitive(prim) + + def vmap_rule(x_bdim, dout_bdim): + x, x_dim = x_bdim + dout, dout_dim = dout_bdim + x_shape = F.shape(x) + dout_shape = F.shape(dout) + if x_dim == dout_dim and x_shape == dout_shape: + out = prim(x, dout) + return (out, x_dim) + + # This branch means (x_dim is None) and (dout_dim is not None). 
+ if x_dim is None: + x = _broadcast_by_axis(x, dout_dim, axis_size) + out_dim = dout_dim + # This branch means (x_dim is not None) and (dout_dim is None). + elif dout_dim is None: + dout = _broadcast_by_axis(dout, x_dim, axis_size) + out_dim = x_dim + # This branch means (x_dim is not None) and (dout_dim is not None). + else: + dout = mnp.moveaxis(dout, dout_dim, x_dim) + out_dim = x_dim + out = prim(x, dout) + return (out, out_dim) + + return vmap_rule diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py index 1caee125ba0..5e6a37682ec 100644 --- a/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py +++ b/mindspore/python/mindspore/ops/_vmap/vmap_math_ops.py @@ -23,9 +23,11 @@ from mindspore.common import Tensor from mindspore.ops.operations.math_ops import Lerp from mindspore.ops.operations.math_ops import LpNorm from mindspore.ops.operations import linalg_ops +from mindspore.ops.operations import _grad_ops as G from ..primitive import Primitive from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_assign_vmap_rule, \ - get_unop_vmap_rule, _raise_value_error, _bdim_at_front, _broadcast_by_axis, _handle_broadcasting + get_unop_vmap_rule, _raise_value_error, _bdim_at_front, _broadcast_by_axis, _handle_broadcasting, \ + get_unary_grad_vmap_rule from ..operations.math_ops import (Bernoulli, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1, BesselK1e) @@ -406,6 +408,7 @@ def get_lp_norm_vmap_rule(prim, axis_size): get_assign_vmap_rule = vmap_rules_getters.register(P.AssignAdd)(get_assign_vmap_rule) get_assign_vmap_rule = vmap_rules_getters.register(P.AssignSub)(get_assign_vmap_rule) +# Unary vmap get_unop_vmap_rule = vmap_rules_getters.register(P.Abs)(get_unop_vmap_rule) get_unop_vmap_rule = vmap_rules_getters.register(P.ACos)(get_unop_vmap_rule) get_unop_vmap_rule = vmap_rules_getters.register(P.Acosh)(get_unop_vmap_rule) @@ -431,6 +434,8 @@ get_unop_vmap_rule = vmap_rules_getters.register(P.Log1p)(get_unop_vmap_rule) get_unop_vmap_rule = vmap_rules_getters.register(P.LogicalNot)(get_unop_vmap_rule) get_unop_vmap_rule = vmap_rules_getters.register(P.Neg)(get_unop_vmap_rule) get_unop_vmap_rule = vmap_rules_getters.register(P.Reciprocal)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.Inv)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.Invert)(get_unop_vmap_rule) get_unop_vmap_rule = vmap_rules_getters.register(P.Rint)(get_unop_vmap_rule) get_unop_vmap_rule = vmap_rules_getters.register(P.Round)(get_unop_vmap_rule) get_unop_vmap_rule = vmap_rules_getters.register(P.Rsqrt)(get_unop_vmap_rule) @@ -454,3 +459,5 @@ get_unop_vmap_rule = vmap_rules_getters.register(P.BesselI1)(get_unop_vmap_rule) get_unop_vmap_rule = vmap_rules_getters.register(P.BesselI1e)(get_unop_vmap_rule) get_unop_vmap_rule = vmap_rules_getters.register(BesselK1)(get_unop_vmap_rule) get_unop_vmap_rule = vmap_rules_getters.register(BesselK1e)(get_unop_vmap_rule) +# UnaryGrad vmap +get_unary_grad_vmap_rule = vmap_rules_getters.register(G.InvGrad)(get_unary_grad_vmap_rule) diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py index 6fe60c5079c..909302165f9 100644 --- a/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py +++ b/mindspore/python/mindspore/ops/_vmap/vmap_nn_ops.py @@ -21,7 +21,8 @@ from mindspore.ops.operations import nn_ops as NN from mindspore.ops import functional as F from mindspore.ops 
import constexpr from ..primitive import Primitive -from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_unop_vmap_rule, _bdim_at_front +from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_unop_vmap_rule, _bdim_at_front, \ + get_unary_grad_vmap_rule @vmap_rules_getters.register(P.BiasAdd) @@ -163,6 +164,7 @@ def get_pdist_vmap_rule(prim, axis_size): if isinstance(prim, str): prim = Primitive(prim) prim.add_prim_attr('p', 2.0) + def vmap_rule(x_bdim): is_all_none, result = vmap_general_preprocess(prim, x_bdim) if is_all_none: @@ -171,22 +173,10 @@ def get_pdist_vmap_rule(prim, axis_size): x = _bdim_at_front(x, x_dim, axis_size) out = prim(x) return out, 0 + return vmap_rule -get_unop_vmap_rule = vmap_rules_getters.register(P.Elu)(get_unop_vmap_rule) -get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU)(get_unop_vmap_rule) -get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU6)(get_unop_vmap_rule) -get_unop_vmap_rule = vmap_rules_getters.register(P.SeLU)(get_unop_vmap_rule) -get_unop_vmap_rule = vmap_rules_getters.register(P.HSigmoid)(get_unop_vmap_rule) -get_unop_vmap_rule = vmap_rules_getters.register(P.Softplus)(get_unop_vmap_rule) -get_unop_vmap_rule = vmap_rules_getters.register(P.SoftShrink)(get_unop_vmap_rule) -get_unop_vmap_rule = vmap_rules_getters.register(P.HShrink)(get_unop_vmap_rule) -get_unop_vmap_rule = vmap_rules_getters.register(P.GeLU)(get_unop_vmap_rule) -get_unop_vmap_rule = vmap_rules_getters.register(P.FastGeLU)(get_unop_vmap_rule) -get_unop_vmap_rule = vmap_rules_getters.register(P.HSwish)(get_unop_vmap_rule) - - @vmap_rules_getters.register(P.KLDivLoss) def get_kl_div_loss_vmap_rule(prim, axis_size): """VmapRule for `KLDivLoss` operation.""" @@ -240,4 +230,24 @@ def get_kl_div_loss_vmap_rule(prim, axis_size): raise RuntimeError("For KLDivLoss vmap, reduction should be one of " "['none', 'mean', 'batchmean', 'sum'], but got '{}'".format(prim_reduction)) return (out, 0) + return vmap_rule + + +# Unary vmap +get_unop_vmap_rule = vmap_rules_getters.register(P.Elu)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU6)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.SeLU)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.HSigmoid)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.Softplus)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.Softsign)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.SoftShrink)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.HShrink)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.GeLU)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.FastGeLU)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.HSwish)(get_unop_vmap_rule) +get_unop_vmap_rule = vmap_rules_getters.register(P.Tanh)(get_unop_vmap_rule) +# UnaryGrad vmap +get_unary_grad_vmap_rule = vmap_rules_getters.register(G.TanhGrad)(get_unary_grad_vmap_rule) +get_unary_grad_vmap_rule = vmap_rules_getters.register(G.SoftplusGrad)(get_unary_grad_vmap_rule) diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index c3393e2a6c2..3801bb491af 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ 
-154,7 +154,6 @@ def matrix_band_part(x, lower, upper): [1. 1. 1. 0.] [1. 1. 1. 1.] [0. 1. 1. 1.]] - [[1. 1. 0. 0.] [1. 1. 1. 0.] [1. 1. 1. 1.] diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 24db5d64a4d..fd085709aa9 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -1420,7 +1420,6 @@ class MatrixBandPart(Primitive): [1. 1. 1. 0.] [1. 1. 1. 1.] [0. 1. 1. 1.]] - [[1. 1. 0. 0.] [1. 1. 1. 0.] [1. 1. 1. 1.] diff --git a/tests/st/ops/cpu/test_arithmetic_self_op.py b/tests/st/ops/cpu/test_arithmetic_self_op.py index 2a0dbd2cc2c..6759eeb0b02 100644 --- a/tests/st/ops/cpu/test_arithmetic_self_op.py +++ b/tests/st/ops/cpu/test_arithmetic_self_op.py @@ -19,6 +19,7 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops import functional as F context.set_context(mode=context.GRAPH_MODE, device_target='CPU') @@ -240,6 +241,36 @@ def test_inv(shape, dtype, tol): assert np.all(np.abs(diff) < error) +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_inv_vmap(mode): + """ + Feature: test inv vmap feature. + Description: test inv vmap feature. + Expectation: Success. + """ + context.set_context(mode=mode, device_target="CPU") + x = Tensor(np.array([[0.25, 0.4, 0.31, 0.52], [0.5, 0.12, 0.31, 0.58]], dtype=np.float32)) + # Case 1 + output = F.vmap(F.inv, 0, 0)(x) + expect_output = np.array([[4., 2.5, 3.2258065, 1.923077], [2., 8.333334, 3.2258065, 1.724138]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 2 + output = F.vmap(F.inv, 1, 0)(x) + expect_output = np.array([[4., 2.], [2.5, 8.333334], [3.2258065, 3.2258065], [1.923077, 1.724138]], + dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 3 + output = F.vmap(F.inv, 0, 1)(x) + expect_output = np.array([[4., 2.], [2.5, 8.333334], [3.2258065, 3.2258065], [1.923077, 1.724138]], + dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -259,6 +290,34 @@ def test_invert(shape, dtype): np.testing.assert_almost_equal(output.asnumpy(), expect_output) +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_invert_vmap(mode): + """ + Feature: test invert vmap feature. + Description: test invert vmap feature. + Expectation: Success. 
+ """ + context.set_context(mode=mode, device_target="CPU") + x = Tensor(np.array([[25, 4, 13, 9], [2, -1, 0, -5]], dtype=np.int16)) + # Case 1 + output = F.vmap(F.invert, 0, 0)(x) + expect_output = np.array([[-26, -5, -14, -10], [-3, 0, -1, 4]], dtype=np.int16) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 2 + output = F.vmap(F.invert, 1, 0)(x) + expect_output = np.array([[-26, -3], [-5, 0], [-14, -1], [-10, 4]], dtype=np.int16) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 3 + output = F.vmap(F.invert, 0, 1)(x) + expect_output = np.array([[-26, -3], [-5, 0], [-14, -1], [-10, 4]], dtype=np.int16) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -280,6 +339,43 @@ def test_softsign(shape, dtype, tol): assert np.all(np.abs(diff) < error) +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_softsign_vmap(mode): + """ + Feature: test softsign vmap feature. + Description: test softsign vmap feature. + Expectation: Success. + """ + context.set_context(mode=mode, device_target="CPU") + x = Tensor(np.array([[0, -1, 2, 30, -30], [2, -1, 0, -5, 50]], dtype=np.float32)) + # Case 1 + output = F.vmap(F.softsign, 0, 0)(x) + expect_output = np.array([[0., -0.5, 0.6666667, 0.9677419, -0.9677419], + [0.6666667, -0.5, 0., -0.8333333, 0.98039216]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 2 + output = F.vmap(F.softsign, 1, 0)(x) + expect_output = np.array([[0., 0.6666667], + [-0.5, -0.5], + [0.6666667, 0.], + [0.9677419, -0.8333333], + [-0.9677419, 0.98039216]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 3 + output = F.vmap(F.softsign, 0, 1)(x) + expect_output = np.array([[0., 0.6666667], + [-0.5, -0.5], + [0.6666667, 0.], + [0.9677419, -0.8333333], + [-0.9677419, 0.98039216]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard diff --git a/tests/st/ops/cpu/test_inv_grad_op.py b/tests/st/ops/cpu/test_inv_grad_op.py index 0bd1c9c72f2..a1702be7f29 100644 --- a/tests/st/ops/cpu/test_inv_grad_op.py +++ b/tests/st/ops/cpu/test_inv_grad_op.py @@ -20,6 +20,7 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore.ops.operations import _grad_ops as G +from mindspore.ops import functional as F class NetInvGrad(nn.Cell): @@ -90,15 +91,54 @@ def test_inv_grad_int(mode, dtype): Expectation: the result match to numpy """ context.set_context(mode=mode, device_target="CPU") - y = Tensor(np.array([[[[-1, 1, 5], - [5, 3, 6], - [3, 2, -1]]]]).astype(dtype)) - dy = Tensor(np.array([[[[29, 1, -2], - [2, -1, 2], - [3, 1, 12]]]]).astype(dtype)) - expect = np.array([[[[-29, -1, 50], - [-50, 9, -72], - [-27, -4, -12]]]]).astype(dtype) + y = Tensor(np.array([[-1, 1, 5], + [5, 3, 6], + [3, 2, -1]]).astype(dtype)) + dy = Tensor(np.array([[29, 1, -2], + [2, -1, 2], + [3, 1, 12]]).astype(dtype)) + expect = np.array([[-29, -1, 50], + [-50, 9, -72], + [-27, -4, -12]]).astype(dtype) net = NetInvGrad() output = net(y, dy) np.testing.assert_array_almost_equal(output.asnumpy(), expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', 
[context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_inv_grad_vmap(mode): + """ + Feature: test inv_grad vmap feature. + Description: test inv_grad vmap feature. + Expectation: Success. + """ + context.set_context(mode=mode, device_target="CPU") + y = Tensor(np.array([[-1, 1, 12], + [5, 34, 6], + [10, 2, -1]]).astype(np.float32)) + dout = Tensor(np.array([[29, 1, 55], + [2.2, 63, 2], + [3, 3, 12]]).astype(np.float32)) + # Case 1 + output = F.vmap(NetInvGrad(), (0, 0), 0)(y, dout) + expect_output = np.array([[-29, -1, -7920], + [-55, -72828, -72], + [-300, -12, -12]]).astype(np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 2 + output = F.vmap(NetInvGrad(), (0, 1), 0)(y, dout) + expect_output = np.array([[-29, -2.2, -432], + [-25, -72828, -108], + [-5500, -8, -12]]).astype(np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 3 + output = F.vmap(NetInvGrad(), (0, 0), 1)(y, dout) + expect_output = np.array([[-29, -55, -300], + [-1, -72828, -12], + [-7920, -72, -12]]).astype(np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) diff --git a/tests/st/ops/cpu/test_matrix_band_part.py b/tests/st/ops/cpu/test_matrix_band_part.py index b3202a7553d..99b23327499 100644 --- a/tests/st/ops/cpu/test_matrix_band_part.py +++ b/tests/st/ops/cpu/test_matrix_band_part.py @@ -43,7 +43,52 @@ def test_matrix_band_part(mode, dtype, batch_shape, rows, cols): np_output = np.triu(np_output, -lower) if upper >= 0: np_output = np.tril(np_output, upper) - if batch_shape: - np_output = np.tile(np_output, batch_shape + [1, 1]) - ms_output = F.matrix_band_part(Tensor(np_output), lower, upper) + ms_output = F.matrix_band_part(Tensor(input_x), lower, upper) np.testing.assert_array_almost_equal(ms_output.asnumpy(), np_output) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_matrix_band_part_vmap(mode): + """ + Feature: test inv vmap feature. + Description: test inv vmap feature. + Expectation: Success. 
+ """ + context.set_context(mode=mode, device_target="CPU") + x = Tensor(np.ones((2, 2, 3, 5)).astype(np.float32)) + lower = 1 + upper = 1 + # Case 1 + output = F.vmap(F.matrix_band_part, (0, None, None), 0)(x, lower, upper) + expect_output = np.array([[[[1., 1., 0., 0., 0.], + [1., 1., 1., 0., 0.], + [0., 1., 1., 1., 0.]], + [[1., 1., 0., 0., 0.], + [1., 1., 1., 0., 0.], + [0., 1., 1., 1., 0.]]], + [[[1., 1., 0., 0., 0.], + [1., 1., 1., 0., 0.], + [0., 1., 1., 1., 0.]], + [[1., 1., 0., 0., 0.], + [1., 1., 1., 0., 0.], + [0., 1., 1., 1., 0.]]]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # # Case 2 + output = F.vmap(F.matrix_band_part, (-1, None, None), -1)(x, lower, upper) + expect_output = np.array([[[[1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.], + [0., 0., 0., 0., 0.]], + [[1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.]]], + [[[1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.], + [0., 0., 0., 0., 0.]], + [[1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.]]]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) diff --git a/tests/st/ops/cpu/test_mish_grad_op.py b/tests/st/ops/cpu/test_mish_grad_op.py index becb97d03c1..38b7202204b 100644 --- a/tests/st/ops/cpu/test_mish_grad_op.py +++ b/tests/st/ops/cpu/test_mish_grad_op.py @@ -21,6 +21,7 @@ import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import composite as C from mindspore.ops import operations as P +from mindspore.ops import functional as F class MishNet(nn.Cell): @@ -83,3 +84,67 @@ def test_mish_grad(mode, dtype, tol): grad = MishGradNet(net) output = grad(x, dy) assert np.allclose(output[0].asnumpy(), expect, atol=tol, rtol=tol, equal_nan=True) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_mish_grad_vmap(mode): + """ + Feature: test mish_grad vmap feature. + Description: test mish_grad vmap feature. + Expectation: Success. 
+ """ + context.set_context(mode=mode, device_target="CPU") + x = Tensor(np.array([[[[1.7641, 0.4002, 0.9787], + [2.2409, 1.8676, -0.9773]], + [[0.9501, -0.1514, -0.1032], + [0.4106, 0.1440, 1.4543]]], + [[[0.7610, 0.1217, 0.4439], + [0.3337, 1.4941, -0.2052]], + [[0.3131, -0.8541, -2.5530], + [0.6536, 0.8644, -0.7422]]]]).astype(np.float32)) + dout = Tensor(np.array([[[[2.2698, -1.4544, 0.0458], + [-0.1872, 1.5328, 1.4694]], + [[0.1549, 0.3782, -0.8878], + [-1.9808, -0.3479, 0.1563]]], + [[[1.2303, 1.2024, -0.3873], + [-0.3023, -1.0486, -1.4200]], + [[-1.7063, 1.9508, -0.5097], + [-0.4381, -1.2528, 0.7775]]]]).astype(np.float32)) + # Case 1 + output = F.vmap(MishGradNet(MishNet()), (0, 0), 0)(x, dout) + expect_output = np.array([[[[2.4551497, -1.2175097, 0.0478603], + [-0.1975334, 1.6502883, 0.09884691]], + [[0.16096735, 0.19009684, -0.47376704], + [-1.6688112, -0.24026634, 0.17010784]]], + [[[1.2171272, 0.81384104, -0.33282074], + [-0.24231759, -1.1413976, -0.6648671]], + [[-1.3482722, 0.22441024, 0.05531986], + [-0.41696107, -1.2767013, 0.1277946]]]]).astype(np.float32) + assert np.allclose(output[0].asnumpy(), expect_output, atol=1e-4, rtol=1e-4, equal_nan=True) + + # # Case 2 + output = F.vmap(MishGradNet(MishNet()), (0, 1), 0)(x, dout) + expect_output = np.array([[[[2.4551497, -1.2175097, 0.0478603], + [-0.1975334, 1.6502883, 0.09884691]], + [[1.2784901, 0.6043692, -0.20667942], + [-0.2546858, -0.724183, -1.5454454]]], + [[[0.15324152, 0.2559836, -0.7629183], + [-1.5877694, -0.378688, 0.0731822]], + [[-1.3482722, 0.22441024, 0.05531986], + [-0.41696107, -1.2767013, 0.1277946]]]]).astype(np.float32) + assert np.allclose(output[0].asnumpy(), expect_output, atol=1e-4, rtol=1e-4, equal_nan=True) + + # # Case 3 + output = F.vmap(MishGradNet(MishNet()), (0, 0), 1)(x, dout) + expect_output = np.array([[[[2.4551497, -1.2175097, 0.0478603], + [-0.1975334, 1.6502883, 0.09884691]], + [[1.2171272, 0.81384104, -0.33282074], + [-0.24231759, -1.1413976, -0.6648671]]], + [[[0.16096735, 0.19009684, -0.47376704], + [-1.6688112, -0.24026634, 0.17010784]], + [[-1.3482722, 0.22441024, 0.05531986], + [-0.41696107, -1.2767013, 0.1277946]]]]).astype(np.float32) + assert np.allclose(output[0].asnumpy(), expect_output, atol=1e-4, rtol=1e-4, equal_nan=True) diff --git a/tests/st/ops/cpu/test_padding.py b/tests/st/ops/cpu/test_padding.py index 006869f46a1..03d25697831 100644 --- a/tests/st/ops/cpu/test_padding.py +++ b/tests/st/ops/cpu/test_padding.py @@ -19,6 +19,7 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops import functional as F class Net(nn.Cell): @@ -52,3 +53,59 @@ def test_padding(mode, shape, dtype, pad_dim_size): pad_width.append((0, pad_dim_size - 1)) expect = np.pad(x, tuple(pad_width), 'constant', constant_values=0) np.testing.assert_array_almost_equal(output.asnumpy(), expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_padding_vmap(mode): + """ + Feature: test padding vmap feature. + Description: test padding vmap feature. + Expectation: Success. 
+ """ + context.set_context(mode=mode, device_target="CPU") + x = Tensor(np.array([[[-270.0144], + [19.09283], + [43.96024], + [257.01694]], + [[-104.56876], + [42.85809], + [-123.558815], + [54.194077]]], dtype=np.float32)) + # Case 1 + output = F.vmap(Net(4), 0, 0)(x) + expect_output = np.array([[[-270.0144, 0, 0, 0], + [19.09283, 0, 0, 0], + [43.96024, 0, 0, 0], + [257.01694, 0, 0, 0]], + [[-104.56876, 0, 0, 0], + [42.85809, 0, 0, 0], + [-123.558815, 0, 0, 0], + [54.194077, 0, 0, 0]]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 2 + output = F.vmap(Net(4), 0, 1)(x) + expect_output = np.array([[[-270.0144, 0., 0., 0.], + [-104.56876, 0., 0., 0.]], + [[19.09283, 0., 0., 0.], + [42.85809, 0., 0., 0.]], + [[43.96024, 0., 0., 0.], + [-123.558815, 0., 0., 0.]], + [[257.01694, 0., 0., 0.], + [54.194077, 0., 0., 0.]]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # # Case 3 + output = F.vmap(Net(4), 1, 0)(x) + expect_output = np.array([[[-270.0144, 0., 0., 0.], + [-104.56876, 0., 0., 0.]], + [[19.09283, 0., 0., 0.], + [42.85809, 0., 0., 0.]], + [[43.96024, 0., 0., 0.], + [-123.558815, 0., 0., 0.]], + [[257.01694, 0., 0., 0.], + [54.194077, 0., 0., 0.]]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) diff --git a/tests/st/ops/gpu/test_inv_grad_op.py b/tests/st/ops/gpu/test_inv_grad_op.py index a911a9d5aab..7f8d6ed8028 100644 --- a/tests/st/ops/gpu/test_inv_grad_op.py +++ b/tests/st/ops/gpu/test_inv_grad_op.py @@ -20,6 +20,7 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore.ops.operations import _grad_ops as G +from mindspore.ops import functional as F class NetInvGrad(nn.Cell): @@ -90,15 +91,54 @@ def test_inv_grad_int(mode, dtype): Expectation: the result match to numpy """ context.set_context(mode=mode, device_target="GPU") - y = Tensor(np.array([[[[-1, 1, 5], - [5, 3, 6], - [3, 2, -1]]]]).astype(dtype)) - dy = Tensor(np.array([[[[29, 1, -2], - [2, -1, 2], - [3, 1, 12]]]]).astype(dtype)) - expect = np.array([[[[-29, -1, 50], - [-50, 9, -72], - [-27, -4, -12]]]]).astype(dtype) + y = Tensor(np.array([[-1, 1, 5], + [5, 3, 6], + [3, 2, -1]]).astype(dtype)) + dy = Tensor(np.array([[29, 1, -2], + [2, -1, 2], + [3, 1, 12]]).astype(dtype)) + expect = np.array([[-29, -1, 50], + [-50, 9, -72], + [-27, -4, -12]]).astype(dtype) net = NetInvGrad() output = net(y, dy) np.testing.assert_array_almost_equal(output.asnumpy(), expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_inv_grad_vmap(mode): + """ + Feature: test inv_grad vmap feature. + Description: test inv_grad vmap feature. + Expectation: Success. 
+ """ + context.set_context(mode=mode, device_target="GPU") + y = Tensor(np.array([[-1, 1, 12], + [5, 34, 6], + [10, 2, -1]]).astype(np.float32)) + dout = Tensor(np.array([[29, 1, 55], + [2.2, 63, 2], + [3, 3, 12]]).astype(np.float32)) + # Case 1 + output = F.vmap(NetInvGrad(), (0, 0), 0)(y, dout) + expect_output = np.array([[-29, -1, -7920], + [-55, -72828, -72], + [-300, -12, -12]]).astype(np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 2 + output = F.vmap(NetInvGrad(), (0, 1), 0)(y, dout) + expect_output = np.array([[-29, -2.2, -432], + [-25, -72828, -108], + [-5500, -8, -12]]).astype(np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 3 + output = F.vmap(NetInvGrad(), (0, 0), 1)(y, dout) + expect_output = np.array([[-29, -55, -300], + [-1, -72828, -12], + [-7920, -72, -12]]).astype(np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) diff --git a/tests/st/ops/gpu/test_inv_op.py b/tests/st/ops/gpu/test_inv_op.py index 0163aee22b9..08456384207 100644 --- a/tests/st/ops/gpu/test_inv_op.py +++ b/tests/st/ops/gpu/test_inv_op.py @@ -20,6 +20,7 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops import functional as F class NetInv(nn.Cell): @@ -53,3 +54,33 @@ def test_inv(mode, shape, dtype, tol): diff = output.asnumpy() - expect_output error = np.ones(shape=expect_output.shape) * tol assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_inv_vmap(mode): + """ + Feature: test inv vmap feature. + Description: test inv vmap feature. + Expectation: Success. + """ + context.set_context(mode=mode, device_target="GPU") + x = Tensor(np.array([[0.25, 0.4, 0.31, 0.52], [0.5, 0.12, 0.31, 0.58]], dtype=np.float32)) + # Case 1 + output = F.vmap(F.inv, 0, 0)(x) + expect_output = np.array([[4., 2.5, 3.2258065, 1.923077], [2., 8.333334, 3.2258065, 1.724138]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 2 + output = F.vmap(F.inv, 1, 0)(x) + expect_output = np.array([[4., 2.], [2.5, 8.333334], [3.2258065, 3.2258065], [1.923077, 1.724138]], + dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 3 + output = F.vmap(F.inv, 0, 1)(x) + expect_output = np.array([[4., 2.], [2.5, 8.333334], [3.2258065, 3.2258065], [1.923077, 1.724138]], + dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) diff --git a/tests/st/ops/gpu/test_invert_op.py b/tests/st/ops/gpu/test_invert_op.py index 17ee3262c46..f3685d838a1 100644 --- a/tests/st/ops/gpu/test_invert_op.py +++ b/tests/st/ops/gpu/test_invert_op.py @@ -19,6 +19,7 @@ import pytest import mindspore.context as context from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops import functional as F @pytest.mark.level0 @@ -41,3 +42,31 @@ def test_invert(mode, shape, dtype): output = invert(Tensor(input_x)) expect_output = np.invert(input_x) np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_invert_vmap(mode): + """ + Feature: test invert vmap feature. 
+ Description: test invert vmap feature. + Expectation: Success. + """ + context.set_context(mode=mode, device_target="GPU") + x = Tensor(np.array([[25, 4, 13, 9], [2, -1, 0, -5]], dtype=np.int16)) + # Case 1 + output = F.vmap(F.invert, 0, 0)(x) + expect_output = np.array([[-26, -5, -14, -10], [-3, 0, -1, 4]], dtype=np.int16) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 2 + output = F.vmap(F.invert, 1, 0)(x) + expect_output = np.array([[-26, -3], [-5, 0], [-14, -1], [-10, 4]], dtype=np.int16) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 3 + output = F.vmap(F.invert, 0, 1)(x) + expect_output = np.array([[-26, -3], [-5, 0], [-14, -1], [-10, 4]], dtype=np.int16) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) diff --git a/tests/st/ops/gpu/test_matrix_band_part.py b/tests/st/ops/gpu/test_matrix_band_part.py index ecca7db5072..fb26afd0b85 100644 --- a/tests/st/ops/gpu/test_matrix_band_part.py +++ b/tests/st/ops/gpu/test_matrix_band_part.py @@ -43,7 +43,52 @@ def test_matrix_band_part(mode, dtype, batch_shape, rows, cols): np_output = np.triu(np_output, -lower) if upper >= 0: np_output = np.tril(np_output, upper) - if batch_shape: - np_output = np.tile(np_output, batch_shape + [1, 1]) - ms_output = F.matrix_band_part(Tensor(np_output), lower, upper) + ms_output = F.matrix_band_part(Tensor(input_x), lower, upper) np.testing.assert_array_almost_equal(ms_output.asnumpy(), np_output) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) +def test_matrix_band_part_vmap(mode): + """ + Feature: test inv vmap feature. + Description: test inv vmap feature. + Expectation: Success. + """ + context.set_context(mode=mode, device_target="GPU") + x = Tensor(np.ones((2, 2, 3, 5)).astype(np.float32)) + lower = 1 + upper = 1 + # Case 1 + output = F.vmap(F.matrix_band_part, (0, None, None), 0)(x, lower, upper) + expect_output = np.array([[[[1., 1., 0., 0., 0.], + [1., 1., 1., 0., 0.], + [0., 1., 1., 1., 0.]], + [[1., 1., 0., 0., 0.], + [1., 1., 1., 0., 0.], + [0., 1., 1., 1., 0.]]], + [[[1., 1., 0., 0., 0.], + [1., 1., 1., 0., 0.], + [0., 1., 1., 1., 0.]], + [[1., 1., 0., 0., 0.], + [1., 1., 1., 0., 0.], + [0., 1., 1., 1., 0.]]]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # # Case 2 + output = F.vmap(F.matrix_band_part, (-1, None, None), -1)(x, lower, upper) + expect_output = np.array([[[[1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.], + [0., 0., 0., 0., 0.]], + [[1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.]]], + [[[1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.], + [0., 0., 0., 0., 0.]], + [[1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1.]]]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) diff --git a/tests/st/ops/gpu/test_mish_grad_op.py b/tests/st/ops/gpu/test_mish_grad_op.py index c700322adc9..6ad0283484e 100644 --- a/tests/st/ops/gpu/test_mish_grad_op.py +++ b/tests/st/ops/gpu/test_mish_grad_op.py @@ -21,6 +21,7 @@ import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import composite as C from mindspore.ops import operations as P +from mindspore.ops import functional as F class MishNet(nn.Cell): @@ -83,3 +84,67 @@ def test_mish_grad(mode, dtype, tol): grad = MishGradNet(net) output = grad(x, dy) assert np.allclose(output[0].asnumpy(), expect, atol=tol, rtol=tol, equal_nan=True) + + 
+@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_mish_grad_vmap(mode): + """ + Feature: test mish_grad vmap feature. + Description: test mish_grad vmap feature. + Expectation: Success. + """ + context.set_context(mode=mode, device_target="GPU") + x = Tensor(np.array([[[[1.7641, 0.4002, 0.9787], + [2.2409, 1.8676, -0.9773]], + [[0.9501, -0.1514, -0.1032], + [0.4106, 0.1440, 1.4543]]], + [[[0.7610, 0.1217, 0.4439], + [0.3337, 1.4941, -0.2052]], + [[0.3131, -0.8541, -2.5530], + [0.6536, 0.8644, -0.7422]]]]).astype(np.float32)) + dout = Tensor(np.array([[[[2.2698, -1.4544, 0.0458], + [-0.1872, 1.5328, 1.4694]], + [[0.1549, 0.3782, -0.8878], + [-1.9808, -0.3479, 0.1563]]], + [[[1.2303, 1.2024, -0.3873], + [-0.3023, -1.0486, -1.4200]], + [[-1.7063, 1.9508, -0.5097], + [-0.4381, -1.2528, 0.7775]]]]).astype(np.float32)) + # Case 1 + output = F.vmap(MishGradNet(MishNet()), (0, 0), 0)(x, dout) + expect_output = np.array([[[[2.4551494, -1.2175093, 0.04786031], + [-0.1975334, 1.6502876, 0.098847]], + [[0.16096734, 0.19009684, -0.4737671], + [-1.6688104, -0.24026635, 0.17010784]]], + [[[1.2171272, 0.8138411, -0.33282048], + [-0.24231756, -1.1413976, -0.6648672]], + [[-1.3482721, 0.22441003, 0.05531899], + [-0.41695648, -1.2767013, 0.12779452]]]]).astype(np.float32) + assert np.allclose(output[0].asnumpy(), expect_output, atol=1e-4, rtol=1e-4, equal_nan=True) + + # # Case 2 + output = F.vmap(MishGradNet(MishNet()), (0, 1), 0)(x, dout) + expect_output = np.array([[[[2.4551494, -1.2175093, 0.04786031], + [-0.1975334, 1.6502876, 0.098847]], + [[1.2784901, 0.6043692, -0.20667945], + [-0.25468567, -0.7241831, -1.5454454]]], + [[[0.1532415, 0.25598362, -0.76291764], + [-1.5877693, -0.378688, 0.07318222]], + [[-1.3482721, 0.22441003, 0.05531899], + [-0.41695648, -1.2767013, 0.12779452]]]]).astype(np.float32) + assert np.allclose(output[0].asnumpy(), expect_output, atol=1e-4, rtol=1e-4, equal_nan=True) + + # # Case 3 + output = F.vmap(MishGradNet(MishNet()), (0, 0), 1)(x, dout) + expect_output = np.array([[[[2.4551494, -1.2175093, 0.04786031], + [-0.1975334, 1.6502876, 0.098847]], + [[1.2171272, 0.8138411, -0.33282048], + [-0.24231756, -1.1413976, -0.6648672]]], + [[[0.16096734, 0.19009684, -0.4737671], + [-1.6688104, -0.24026635, 0.17010784]], + [[-1.3482721, 0.22441003, 0.05531899], + [-0.41695648, -1.2767013, 0.12779452]]]]).astype(np.float32) + assert np.allclose(output[0].asnumpy(), expect_output, atol=1e-4, rtol=1e-4, equal_nan=True) diff --git a/tests/st/ops/gpu/test_padding.py b/tests/st/ops/gpu/test_padding.py index 8971bf68c69..c2d45b83ed3 100644 --- a/tests/st/ops/gpu/test_padding.py +++ b/tests/st/ops/gpu/test_padding.py @@ -19,6 +19,7 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops import functional as F class Net(nn.Cell): @@ -52,3 +53,59 @@ def test_padding(mode, shape, dtype, pad_dim_size): pad_width.append((0, pad_dim_size - 1)) expect = np.pad(x, tuple(pad_width), 'constant', constant_values=0) np.testing.assert_array_almost_equal(output.asnumpy(), expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_padding_vmap(mode): + """ + Feature: test padding vmap feature. + Description: test padding vmap feature. + Expectation: Success. 
+ """ + context.set_context(mode=mode, device_target="GPU") + x = Tensor(np.array([[[-270.0144], + [19.09283], + [43.96024], + [257.01694]], + [[-104.56876], + [42.85809], + [-123.558815], + [54.194077]]], dtype=np.float32)) + # Case 1 + output = F.vmap(Net(4), 0, 0)(x) + expect_output = np.array([[[-270.0144, 0, 0, 0], + [19.09283, 0, 0, 0], + [43.96024, 0, 0, 0], + [257.01694, 0, 0, 0]], + [[-104.56876, 0, 0, 0], + [42.85809, 0, 0, 0], + [-123.558815, 0, 0, 0], + [54.194077, 0, 0, 0]]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # Case 2 + output = F.vmap(Net(4), 0, 1)(x) + expect_output = np.array([[[-270.0144, 0., 0., 0.], + [-104.56876, 0., 0., 0.]], + [[19.09283, 0., 0., 0.], + [42.85809, 0., 0., 0.]], + [[43.96024, 0., 0., 0.], + [-123.558815, 0., 0., 0.]], + [[257.01694, 0., 0., 0.], + [54.194077, 0., 0., 0.]]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output) + + # # Case 3 + output = F.vmap(Net(4), 1, 0)(x) + expect_output = np.array([[[-270.0144, 0., 0., 0.], + [-104.56876, 0., 0., 0.]], + [[19.09283, 0., 0., 0.], + [42.85809, 0., 0., 0.]], + [[43.96024, 0., 0., 0.], + [-123.558815, 0., 0., 0.]], + [[257.01694, 0., 0., 0.], + [54.194077, 0., 0., 0.]]], dtype=np.float32) + np.testing.assert_almost_equal(output.asnumpy(), expect_output)
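The new tests all follow the same pattern: wrap the primitive (or a small Cell around it) in F.vmap with explicit input and output axes, then compare against hand-computed expectations. For reference, a minimal standalone sketch of those axis semantics, reusing the input from the test_inv_vmap cases above (running on CPU in graph mode is an assumption):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import functional as F

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

# Same (2, 4) float32 input as the test_inv_vmap cases; outputs are element-wise reciprocals.
x = Tensor(np.array([[0.25, 0.4, 0.31, 0.52],
                     [0.5, 0.12, 0.31, 0.58]], dtype=np.float32))

# in_axes=0, out_axes=0: map over rows, keep the batch axis in front -> shape (2, 4).
print(F.vmap(F.inv, 0, 0)(x).shape)
# in_axes=1, out_axes=0: map over columns, stack the results along axis 0 -> shape (4, 2).
print(F.vmap(F.inv, 1, 0)(x).shape)
# in_axes=0, out_axes=1: map over rows, place the batch axis at position 1 -> shape (4, 2).
print(F.vmap(F.inv, 0, 1)(x).shape)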