diff --git a/docs/api/api_python/ops/mindspore.ops.func_addn.rst b/docs/api/api_python/ops/mindspore.ops.func_addn.rst
index fd666197de1..2c658acb9e4 100644
--- a/docs/api/api_python/ops/mindspore.ops.func_addn.rst
+++ b/docs/api/api_python/ops/mindspore.ops.func_addn.rst
@@ -8,7 +8,7 @@ mindspore.ops.addn
     所有输入Tensor必须具有相同的shape。

    参数:
-        - **x** (Union(tuple[Tensor], list[Tensor])) - Tensor组成的tuble或list,类型为 `bool_ `_ 或 `number `_ 。
+        - **x** (Union(tuple[Tensor], list[Tensor])) - Tensor组成的tuple或list。

    返回:
        Tensor,与 `x` 的每个Tensor具有相同的shape和数据类型。
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/addn_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/addn_cpu_kernel.cc
index 0d1ad473e55..7a2d7502d8d 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/addn_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/addn_cpu_kernel.cc
@@ -17,6 +17,7 @@
 #include "plugin/device/cpu/kernel/mkldnn/addn_cpu_kernel.h"
 #include
 #include
+#include <complex>
 #include "plugin/device/cpu/hal/device/cpu_device_address.h"
 #include "plugin/device/cpu/kernel/nnacl/fp32/add_fp32.h"
 #include "plugin/device/cpu/kernel/nnacl/errorcode.h"
@@ -26,6 +27,9 @@
 namespace mindspore {
 namespace kernel {
 namespace {
+using complex64 = std::complex<float>;
+using complex128 = std::complex<double>;
+
 constexpr size_t kAddNInputsMinNum = 2;
 constexpr size_t kAddNOutputsNum = 1;

@@ -164,7 +168,11 @@ std::vector<std::pair<KernelAttr, AddNCpuKernelMod::AddNFunc>> AddNCpuKernelMod:
   {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
    &AddNCpuKernelMod::LaunchKernel<float>},
   {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
-   &AddNCpuKernelMod::LaunchKernel<double>}};
+   &AddNCpuKernelMod::LaunchKernel<double>},
+  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
+   &AddNCpuKernelMod::LaunchKernel<complex64>},
+  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
+   &AddNCpuKernelMod::LaunchKernel<complex128>}};

 std::vector<KernelAttr> AddNCpuKernelMod::GetOpSupport() {
   std::vector<KernelAttr> support_list;
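Note on the CPU change above: the two new table entries route complex64/complex128 inputs into the same templated LaunchKernel, which accumulates every input buffer into the output elementwise. A minimal NumPy sketch of those semantics, for reference only (not part of the patch):

```python
import numpy as np

# Reference semantics of AddN for complex inputs: per-element complex
# addition of all input buffers into one output buffer.
inputs = [np.array([1.5 + 0.4j, 2.0 - 1.0j], dtype=np.complex64),
          np.array([2.5 + 0.4j, 0.5 + 1.0j], dtype=np.complex64)]
out = np.zeros_like(inputs[0])
for x in inputs:
    out += x  # same accumulation the kernel performs per element
print(out)  # -> [4. +0.8j 2.5+0.j ]
```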
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/addn_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/addn_gpu_kernel.cc
index f8356aaa325..f552529b44a 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/addn_gpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/addn_gpu_kernel.cc
@@ -27,11 +27,35 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   AddNFwdGpuKernelMod, half)
+MS_REG_GPU_KERNEL_ONE(AddN,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+                      AddNFwdGpuKernelMod, int8_t)
+MS_REG_GPU_KERNEL_ONE(AddN,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
+                      AddNFwdGpuKernelMod, int16_t)
 MS_REG_GPU_KERNEL_ONE(AddN,
                       KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      AddNFwdGpuKernelMod, int)
+                      AddNFwdGpuKernelMod, int32_t)
 MS_REG_GPU_KERNEL_ONE(AddN,
                       KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
                       AddNFwdGpuKernelMod, int64_t)
+MS_REG_GPU_KERNEL_ONE(AddN,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+                      AddNFwdGpuKernelMod, uint8_t)
+MS_REG_GPU_KERNEL_ONE(
+  AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
+  AddNFwdGpuKernelMod, uint16_t)
+MS_REG_GPU_KERNEL_ONE(
+  AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
+  AddNFwdGpuKernelMod, uint32_t)
+MS_REG_GPU_KERNEL_ONE(
+  AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
+  AddNFwdGpuKernelMod, uint64_t)
+MS_REG_GPU_KERNEL_ONE(
+  AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
+  AddNFwdGpuKernelMod, Complex<float>)
+MS_REG_GPU_KERNEL_ONE(
+  AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
+  AddNFwdGpuKernelMod, Complex<double>)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/addn_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/addn_gpu_kernel.h
index a1cf71ace3c..e8e68660865 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/addn_gpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/addn_gpu_kernel.h
@@ -24,9 +24,13 @@
 #include "plugin/device/gpu/kernel/math/broadcast_gpu_kernel.h"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/slice_impl.cuh"
 #include "plugin/device/gpu/kernel/kernel_constants.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"

 namespace mindspore {
 namespace kernel {
+template <typename T>
+using Complex = mindspore::utils::Complex<T>;
+
 template <typename T>
 class AddNFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
  public:
@@ -50,8 +54,13 @@ class AddNFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
     FillDeviceArray(outputs[0]->size / sizeof(T), work_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
     for (size_t i = 0; i < num_input_; i++) {
       T *input_addr = GetDeviceAddress<T>(inputs, i);
-      ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
-                   reinterpret_cast<cudaStream_t>(stream_ptr));
+      if constexpr (std::is_same<T, Complex<float>>::value || std::is_same<T, Complex<double>>::value) {
+        ElewiseComplexArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
+                            reinterpret_cast<cudaStream_t>(stream_ptr));
+      } else {
+        ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
+      }
     }
     if (work_addr != output_addr) {
       CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
@@ -61,6 +70,7 @@ class AddNFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
     }
     return true;
   }
+
   bool Init(const CNodePtr &kernel_node) override {
     auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
     kernel_node_ = kernel_node;
diff --git a/mindspore/core/ops/addn.cc b/mindspore/core/ops/addn.cc
index 500cf938583..e16fcba1ab7 100644
--- a/mindspore/core/ops/addn.cc
+++ b/mindspore/core/ops/addn.cc
@@ -48,7 +48,7 @@ bool AddNDynShapeJoin(ShapeVector *shape1, const ShapeVector *shape2) {
     if ((*shape1)[i] == (*shape2)[i]) {
      continue;
     }
-    // If shape1 is dynamic, use shape of shape2.If shape2 is dynamic, keep shape1.
+    // If shape1 is dynamic, use shape of shape2. If shape2 is dynamic, keep shape1.
     if ((*shape1)[i] == abstract::Shape::kShapeDimAny) {
       (*shape1)[i] = (*shape2)[i];
       continue;
     }
@@ -115,8 +115,7 @@ TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
     std::string element_i = "element_" + std::to_string(i);
     (void)types.emplace(element_i, elements[i]->BuildType());
   }
-  std::set<TypePtr> valid_types = common_valid_types;
-  valid_types.insert(kBool);
+  std::set<TypePtr> valid_types = common_valid_types_with_complex_and_bool;
   (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
   return elements[0]->BuildType();
 }
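The infer-type change above is what actually unlocks the new dtypes at graph-compile time: the valid-type set now includes complex64/complex128 alongside bool and the numeric types. A rough Python sketch of the check that `CheckTensorTypeSame` enforces (the set and helper names below are illustrative, not MindSpore API):

```python
# Illustrative sketch of AddN's dtype validation: all inputs must share one
# dtype, and that dtype must be in the valid set, which now includes complex.
VALID_TYPES = {"bool", "int8", "int16", "int32", "int64",
               "uint8", "uint16", "uint32", "uint64",
               "float16", "float32", "float64",
               "complex64", "complex128"}


def check_addn_dtypes(dtypes):
    if len(set(dtypes)) != 1:
        raise TypeError(f"AddN inputs must share one dtype, got {dtypes}")
    if dtypes[0] not in VALID_TYPES:
        raise TypeError(f"AddN does not support dtype {dtypes[0]}")


check_addn_dtypes(["complex64", "complex64"])  # passes after this patch
```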
diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py
index e9e37fe1aa1..3b8797967da 100644
--- a/mindspore/python/mindspore/ops/function/math_func.py
+++ b/mindspore/python/mindspore/ops/function/math_func.py
@@ -176,9 +176,7 @@ def addn(x):
     All input tensors must have the same shape.

     Args:
-        x (Union(tuple[Tensor], list[Tensor])): A tuple or list composed of Tensor, the data type is
-            `bool_ `_ or
-            `number `_ .
+        x (Union(tuple[Tensor], list[Tensor])): A tuple or list composed of Tensor.

     Returns:
         Tensor, has the same shape and dtype as each Tensor of `x`.
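With the dtype restriction dropped from the docstring, complex inputs go through `ops.addn` like any other supported dtype. A short usage example that mirrors the new st tests below (values and expected output taken from those tests):

```python
import numpy as np
from mindspore import Tensor, ops
from mindspore import dtype as mstype

# Complex inputs are accepted on the CPU and GPU backends touched by this patch.
x = Tensor(np.asarray(1.5 + 0.4j), mstype.complex64)
y = Tensor(np.asarray(2.5 + 0.4j), mstype.complex64)
print(ops.addn([x, y]))  # expected: (4+0.8j)
```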
+ """ + out_fp16 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float16), Tensor([4.5, 5.5, 6.5], mstype.float16)]) + out_fp32 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float32), Tensor([4.5, 5.5, 6.5], mstype.float32)]) + out_fp64 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float64), Tensor([4.5, 5.5, 6.5], mstype.float64)]) + out_int8 = ops.addn([Tensor([1, 2, 3], mstype.int8), Tensor([4, 5, 6], mstype.int8)]) + out_int16 = ops.addn([Tensor([1, 2, 3], mstype.int16), Tensor([4, 5, 6], mstype.int16)]) + out_int32 = ops.addn([Tensor([1, 2, 3], mstype.int32), Tensor([4, 5, 6], mstype.int32)]) + out_int64 = ops.addn([Tensor([1, 2, 3], mstype.int64), Tensor([4, 5, 6], mstype.int64)]) + out_uint8 = ops.addn([Tensor([1, 2, 3], mstype.uint8), Tensor([4, 5, 6], mstype.uint8)]) + out_uint16 = ops.addn([Tensor([1, 2, 3], mstype.uint16), Tensor([4, 5, 6], mstype.uint16)]) + out_uint32 = ops.addn([Tensor([1, 2, 3], mstype.uint32), Tensor([4, 5, 6], mstype.uint32)]) + out_uint64 = ops.addn([Tensor([1, 2, 3], mstype.uint64), Tensor([4, 5, 6], mstype.uint64)]) + out_complex64 = ops.addn([Tensor(np.asarray(np.complex(1.5 + 0.4j)), mstype.complex64), + Tensor(np.asarray(np.complex(2.5 + 0.4j)), mstype.complex64)]) + out_complex128 = ops.addn([Tensor(np.asarray(np.complex(1.5 + 0.4j)), mstype.complex128), + Tensor(np.asarray(np.complex(2.5 + 0.4j)), mstype.complex128)]) + + assert np.allclose(out_fp16.asnumpy(), Tensor([6., 8., 10.], mstype.float16).asnumpy(), rtol=1e-5, atol=1e-5) + assert np.allclose(out_fp32.asnumpy(), Tensor([6., 8., 10.], mstype.float32).asnumpy(), rtol=1e-5, atol=1e-5) + assert np.allclose(out_fp64.asnumpy(), Tensor([6., 8., 10.], mstype.float64).asnumpy(), rtol=1e-5, atol=1e-5) + assert np.all(out_int8.asnumpy() == Tensor([5, 7, 9], mstype.int8).asnumpy()) + assert np.all(out_int16.asnumpy() == Tensor([5, 7, 9], mstype.int16).asnumpy()) + assert np.all(out_int32.asnumpy() == Tensor([5, 7, 9], mstype.int32).asnumpy()) + assert np.all(out_int64.asnumpy() == Tensor([5, 7, 9], mstype.int64).asnumpy()) + assert np.all(out_uint8.asnumpy() == Tensor([5, 7, 9], mstype.uint8).asnumpy()) + assert np.all(out_uint16.asnumpy() == Tensor([5, 7, 9], mstype.uint16).asnumpy()) + assert np.all(out_uint32.asnumpy() == Tensor([5, 7, 9], mstype.uint32).asnumpy()) + assert np.all(out_uint64.asnumpy() == Tensor([5, 7, 9], mstype.uint64).asnumpy()) + assert np.all(out_complex64.asnumpy() == Tensor(np.asarray(np.complex(4 + 0.8j)), mstype.complex64).asnumpy()) + assert np.all(out_complex128.asnumpy() == Tensor(np.asarray(np.complex(4 + 0.8j)), mstype.complex128).asnumpy()) diff --git a/tests/st/ops/gpu/test_addn_op.py b/tests/st/ops/gpu/test_addn_op.py index e952f020c2e..564510f4b5c 100644 --- a/tests/st/ops/gpu/test_addn_op.py +++ b/tests/st/ops/gpu/test_addn_op.py @@ -18,7 +18,7 @@ import pytest import mindspore.context as context import mindspore.nn as nn -from mindspore import Tensor +from mindspore import Tensor, ops from mindspore.common.api import ms_function from mindspore.ops import operations as P from mindspore import dtype as mstype @@ -169,3 +169,43 @@ def test_net_int64(): [84., 87., 90., 93.], [96., 99., 102., 105.]]]]).astype(np.int64) assert (output.asnumpy() == expect_result).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_addn_support_type(): + """ + Feature: test ops.addn. + Description: test ops.addn with different types. + Expectation: the result match with expected result. 
+ """ + out_fp16 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float16), Tensor([4.5, 5.5, 6.5], mstype.float16)]) + out_fp32 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float32), Tensor([4.5, 5.5, 6.5], mstype.float32)]) + out_fp64 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float64), Tensor([4.5, 5.5, 6.5], mstype.float64)]) + out_int8 = ops.addn([Tensor([1, 2, 3], mstype.int8), Tensor([4, 5, 6], mstype.int8)]) + out_int16 = ops.addn([Tensor([1, 2, 3], mstype.int16), Tensor([4, 5, 6], mstype.int16)]) + out_int32 = ops.addn([Tensor([1, 2, 3], mstype.int32), Tensor([4, 5, 6], mstype.int32)]) + out_int64 = ops.addn([Tensor([1, 2, 3], mstype.int64), Tensor([4, 5, 6], mstype.int64)]) + out_uint8 = ops.addn([Tensor([1, 2, 3], mstype.uint8), Tensor([4, 5, 6], mstype.uint8)]) + out_uint16 = ops.addn([Tensor([1, 2, 3], mstype.uint16), Tensor([4, 5, 6], mstype.uint16)]) + out_uint32 = ops.addn([Tensor([1, 2, 3], mstype.uint32), Tensor([4, 5, 6], mstype.uint32)]) + out_uint64 = ops.addn([Tensor([1, 2, 3], mstype.uint64), Tensor([4, 5, 6], mstype.uint64)]) + out_complex64 = ops.addn([Tensor(np.asarray(np.complex(1.5 + 0.4j)), mstype.complex64), + Tensor(np.asarray(np.complex(2.5 + 0.4j)), mstype.complex64)]) + out_complex128 = ops.addn([Tensor(np.asarray(np.complex(1.5 + 0.4j)), mstype.complex128), + Tensor(np.asarray(np.complex(2.5 + 0.4j)), mstype.complex128)]) + + assert np.allclose(out_fp16.asnumpy(), Tensor([6., 8., 10.], mstype.float16).asnumpy(), rtol=1e-5, atol=1e-5) + assert np.allclose(out_fp32.asnumpy(), Tensor([6., 8., 10.], mstype.float32).asnumpy(), rtol=1e-5, atol=1e-5) + assert np.allclose(out_fp64.asnumpy(), Tensor([6., 8., 10.], mstype.float64).asnumpy(), rtol=1e-5, atol=1e-5) + assert np.all(out_int8.asnumpy() == Tensor([5, 7, 9], mstype.int8).asnumpy()) + assert np.all(out_int16.asnumpy() == Tensor([5, 7, 9], mstype.int16).asnumpy()) + assert np.all(out_int32.asnumpy() == Tensor([5, 7, 9], mstype.int32).asnumpy()) + assert np.all(out_int64.asnumpy() == Tensor([5, 7, 9], mstype.int64).asnumpy()) + assert np.all(out_uint8.asnumpy() == Tensor([5, 7, 9], mstype.uint8).asnumpy()) + assert np.all(out_uint16.asnumpy() == Tensor([5, 7, 9], mstype.uint16).asnumpy()) + assert np.all(out_uint32.asnumpy() == Tensor([5, 7, 9], mstype.uint32).asnumpy()) + assert np.all(out_uint64.asnumpy() == Tensor([5, 7, 9], mstype.uint64).asnumpy()) + assert np.all(out_complex64.asnumpy() == Tensor(np.asarray(np.complex(4 + 0.8j)), mstype.complex64).asnumpy()) + assert np.all(out_complex128.asnumpy() == Tensor(np.asarray(np.complex(4 + 0.8j)), mstype.complex128).asnumpy())