ops.addn supports more types

huangbingjian 2022-10-08 09:39:24 +08:00
parent 8388e614f8
commit 11ef11c2f6
8 changed files with 132 additions and 13 deletions

View File

@@ -8,7 +8,7 @@ mindspore.ops.addn
All input Tensors must have the same shape.
Parameters:
- **x** (Union(tuple[Tensor], list[Tensor])) - A tuple or list composed of Tensors, of type `bool_ <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.dtype.html#mindspore.dtype>`_ or `number <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.dtype.html#mindspore.dtype>`_ .
- **x** (Union(tuple[Tensor], list[Tensor])) - A tuple or list composed of Tensors.
Returns:
Tensor, has the same shape and data type as each Tensor of `x`.

View File

@@ -17,6 +17,7 @@
#include "plugin/device/cpu/kernel/mkldnn/addn_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include <complex>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/nnacl/fp32/add_fp32.h"
#include "plugin/device/cpu/kernel/nnacl/errorcode.h"
@@ -26,6 +27,9 @@
namespace mindspore {
namespace kernel {
namespace {
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
constexpr size_t kAddNInputsMinNum = 2;
constexpr size_t kAddNOutputsNum = 1;
@@ -164,7 +168,11 @@ std::vector<std::pair<KernelAttr, AddNCpuKernelMod::AddNFunc>> AddNCpuKernelMod:
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&AddNCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&AddNCpuKernelMod::LaunchKernel<double>}};
&AddNCpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&AddNCpuKernelMod::LaunchKernel<complex64>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&AddNCpuKernelMod::LaunchKernel<complex128>}};
std::vector<KernelAttr> AddNCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;

View File

@@ -27,11 +27,35 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
AddNFwdGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(AddN,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
AddNFwdGpuKernelMod, int8_t)
MS_REG_GPU_KERNEL_ONE(AddN,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
AddNFwdGpuKernelMod, int16_t)
MS_REG_GPU_KERNEL_ONE(AddN,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
AddNFwdGpuKernelMod, int)
AddNFwdGpuKernelMod, int32_t)
MS_REG_GPU_KERNEL_ONE(AddN,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
AddNFwdGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(AddN,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
AddNFwdGpuKernelMod, uint8_t)
MS_REG_GPU_KERNEL_ONE(
AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
AddNFwdGpuKernelMod, uint16_t)
MS_REG_GPU_KERNEL_ONE(
AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
AddNFwdGpuKernelMod, uint32_t)
MS_REG_GPU_KERNEL_ONE(
AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
AddNFwdGpuKernelMod, uint64_t)
MS_REG_GPU_KERNEL_ONE(
AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
AddNFwdGpuKernelMod, Complex<float>)
MS_REG_GPU_KERNEL_ONE(
AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
AddNFwdGpuKernelMod, Complex<double>)
} // namespace kernel
} // namespace mindspore

View File

@@ -24,9 +24,13 @@
#include "plugin/device/gpu/kernel/math/broadcast_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/slice_impl.cuh"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename T>
class AddNFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
public:
@@ -50,8 +54,13 @@ class AddNFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
FillDeviceArray(outputs[0]->size / sizeof(T), work_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
for (size_t i = 0; i < num_input_; i++) {
T *input_addr = GetDeviceAddress<T>(inputs, i);
ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
if constexpr (std::is_same<T, Complex<float>>::value || std::is_same<T, Complex<double>>::value) {
ElewiseComplexArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
}
if (work_addr != output_addr) {
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
@@ -61,6 +70,7 @@ class AddNFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;

View File

@@ -48,7 +48,7 @@ bool AddNDynShapeJoin(ShapeVector *shape1, const ShapeVector *shape2) {
if ((*shape1)[i] == (*shape2)[i]) {
continue;
}
// If shape1 is dynamic, use shape of shape2.If shape2 is dynamic, keep shape1.
// If shape1 is dynamic, use shape of shape2. If shape2 is dynamic, keep shape1.
if ((*shape1)[i] == abstract::Shape::kShapeDimAny) {
(*shape1)[i] = (*shape2)[i];
continue;
@@ -115,8 +115,7 @@ TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
std::string element_i = "element_" + std::to_string(i);
(void)types.emplace(element_i, elements[i]->BuildType());
}
std::set<TypePtr> valid_types = common_valid_types;
valid_types.insert(kBool);
std::set<TypePtr> valid_types = common_valid_types_with_complex_and_bool;
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
return elements[0]->BuildType();
}
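The valid-type set used by AddN type inference is now common_valid_types_with_complex_and_bool, so bool and complex element types pass the check while all inputs must still share a single dtype. A minimal sketch of that constraint at the Python level, assuming a dtype mismatch is rejected during inference (the exact exception class is not pinned down here, so it is caught broadly):

import numpy as np
from mindspore import Tensor, ops
from mindspore import dtype as mstype

# All inputs to addn must share one dtype drawn from the valid set
# (which now includes bool and complex); mixing dtypes is expected to
# be rejected during type inference.
try:
    ops.addn([Tensor(np.ones(3), mstype.float32), Tensor(np.ones(3), mstype.float64)])
except (TypeError, ValueError) as err:
    print("mixed dtypes rejected:", type(err).__name__)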

View File

@@ -176,9 +176,7 @@ def addn(x):
All input tensors must have the same shape.
Args:
x (Union(tuple[Tensor], list[Tensor])): A tuple or list composed of Tensor, the data type is
`bool_ <https://www.mindspore.cn/docs/api/en/master/api_python/mindspore.html#mindspore.dtype>`_ or
`number <https://www.mindspore.cn/docs/api/en/master/api_python/mindspore.html#mindspore.dtype>`_ .
x (Union(tuple[Tensor], list[Tensor])): A tuple or list composed of Tensor.
Returns:
Tensor, has the same shape and dtype as each Tensor of `x`.
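For a quick sense of the broadened dtype support, here is a minimal usage sketch (values mirror the new tests below; the mstype aliases are the same ones those tests import):

import numpy as np
from mindspore import Tensor, ops
from mindspore import dtype as mstype

# Complex accumulation, newly supported by the CPU and GPU AddN kernels in this commit.
x = Tensor(np.asarray(1.5 + 0.4j), mstype.complex64)
y = Tensor(np.asarray(2.5 + 0.4j), mstype.complex64)
print(ops.addn([x, y]))  # roughly (4+0.8j)

# Unsigned integer inputs go through the newly registered GPU kernels as well.
a = Tensor([1, 2, 3], mstype.uint16)
b = Tensor([4, 5, 6], mstype.uint16)
print(ops.addn([a, b]))  # [5 7 9]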

View File

@@ -18,7 +18,7 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Tensor, ops
from mindspore.ops import operations as P
from mindspore import dtype as mstype
@@ -118,3 +118,43 @@ def test_four_tensors_add():
expect_result = (x + y + m + n).astype(dtype)
assert output.asnumpy().dtype == expect_result.dtype
assert np.array_equal(output.asnumpy(), expect_result)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_addn_support_type():
"""
Feature: test ops.addn.
Description: test ops.addn with different types.
Expectation: the result match with expected result.
"""
out_fp16 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float16), Tensor([4.5, 5.5, 6.5], mstype.float16)])
out_fp32 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float32), Tensor([4.5, 5.5, 6.5], mstype.float32)])
out_fp64 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float64), Tensor([4.5, 5.5, 6.5], mstype.float64)])
out_int8 = ops.addn([Tensor([1, 2, 3], mstype.int8), Tensor([4, 5, 6], mstype.int8)])
out_int16 = ops.addn([Tensor([1, 2, 3], mstype.int16), Tensor([4, 5, 6], mstype.int16)])
out_int32 = ops.addn([Tensor([1, 2, 3], mstype.int32), Tensor([4, 5, 6], mstype.int32)])
out_int64 = ops.addn([Tensor([1, 2, 3], mstype.int64), Tensor([4, 5, 6], mstype.int64)])
out_uint8 = ops.addn([Tensor([1, 2, 3], mstype.uint8), Tensor([4, 5, 6], mstype.uint8)])
out_uint16 = ops.addn([Tensor([1, 2, 3], mstype.uint16), Tensor([4, 5, 6], mstype.uint16)])
out_uint32 = ops.addn([Tensor([1, 2, 3], mstype.uint32), Tensor([4, 5, 6], mstype.uint32)])
out_uint64 = ops.addn([Tensor([1, 2, 3], mstype.uint64), Tensor([4, 5, 6], mstype.uint64)])
out_complex64 = ops.addn([Tensor(np.asarray(np.complex(1.5 + 0.4j)), mstype.complex64),
Tensor(np.asarray(np.complex(2.5 + 0.4j)), mstype.complex64)])
out_complex128 = ops.addn([Tensor(np.asarray(np.complex(1.5 + 0.4j)), mstype.complex128),
Tensor(np.asarray(np.complex(2.5 + 0.4j)), mstype.complex128)])
assert np.allclose(out_fp16.asnumpy(), Tensor([6., 8., 10.], mstype.float16).asnumpy(), rtol=1e-5, atol=1e-5)
assert np.allclose(out_fp32.asnumpy(), Tensor([6., 8., 10.], mstype.float32).asnumpy(), rtol=1e-5, atol=1e-5)
assert np.allclose(out_fp64.asnumpy(), Tensor([6., 8., 10.], mstype.float64).asnumpy(), rtol=1e-5, atol=1e-5)
assert np.all(out_int8.asnumpy() == Tensor([5, 7, 9], mstype.int8).asnumpy())
assert np.all(out_int16.asnumpy() == Tensor([5, 7, 9], mstype.int16).asnumpy())
assert np.all(out_int32.asnumpy() == Tensor([5, 7, 9], mstype.int32).asnumpy())
assert np.all(out_int64.asnumpy() == Tensor([5, 7, 9], mstype.int64).asnumpy())
assert np.all(out_uint8.asnumpy() == Tensor([5, 7, 9], mstype.uint8).asnumpy())
assert np.all(out_uint16.asnumpy() == Tensor([5, 7, 9], mstype.uint16).asnumpy())
assert np.all(out_uint32.asnumpy() == Tensor([5, 7, 9], mstype.uint32).asnumpy())
assert np.all(out_uint64.asnumpy() == Tensor([5, 7, 9], mstype.uint64).asnumpy())
assert np.all(out_complex64.asnumpy() == Tensor(np.asarray(np.complex(4 + 0.8j)), mstype.complex64).asnumpy())
assert np.all(out_complex128.asnumpy() == Tensor(np.asarray(np.complex(4 + 0.8j)), mstype.complex128).asnumpy())

View File

@@ -18,7 +18,7 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Tensor, ops
from mindspore.common.api import ms_function
from mindspore.ops import operations as P
from mindspore import dtype as mstype
@@ -169,3 +169,43 @@ def test_net_int64():
[84., 87., 90., 93.],
[96., 99., 102., 105.]]]]).astype(np.int64)
assert (output.asnumpy() == expect_result).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_addn_support_type():
"""
Feature: test ops.addn.
Description: test ops.addn with different types.
Expectation: the result match with expected result.
"""
out_fp16 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float16), Tensor([4.5, 5.5, 6.5], mstype.float16)])
out_fp32 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float32), Tensor([4.5, 5.5, 6.5], mstype.float32)])
out_fp64 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float64), Tensor([4.5, 5.5, 6.5], mstype.float64)])
out_int8 = ops.addn([Tensor([1, 2, 3], mstype.int8), Tensor([4, 5, 6], mstype.int8)])
out_int16 = ops.addn([Tensor([1, 2, 3], mstype.int16), Tensor([4, 5, 6], mstype.int16)])
out_int32 = ops.addn([Tensor([1, 2, 3], mstype.int32), Tensor([4, 5, 6], mstype.int32)])
out_int64 = ops.addn([Tensor([1, 2, 3], mstype.int64), Tensor([4, 5, 6], mstype.int64)])
out_uint8 = ops.addn([Tensor([1, 2, 3], mstype.uint8), Tensor([4, 5, 6], mstype.uint8)])
out_uint16 = ops.addn([Tensor([1, 2, 3], mstype.uint16), Tensor([4, 5, 6], mstype.uint16)])
out_uint32 = ops.addn([Tensor([1, 2, 3], mstype.uint32), Tensor([4, 5, 6], mstype.uint32)])
out_uint64 = ops.addn([Tensor([1, 2, 3], mstype.uint64), Tensor([4, 5, 6], mstype.uint64)])
out_complex64 = ops.addn([Tensor(np.asarray(np.complex(1.5 + 0.4j)), mstype.complex64),
Tensor(np.asarray(np.complex(2.5 + 0.4j)), mstype.complex64)])
out_complex128 = ops.addn([Tensor(np.asarray(np.complex(1.5 + 0.4j)), mstype.complex128),
Tensor(np.asarray(np.complex(2.5 + 0.4j)), mstype.complex128)])
assert np.allclose(out_fp16.asnumpy(), Tensor([6., 8., 10.], mstype.float16).asnumpy(), rtol=1e-5, atol=1e-5)
assert np.allclose(out_fp32.asnumpy(), Tensor([6., 8., 10.], mstype.float32).asnumpy(), rtol=1e-5, atol=1e-5)
assert np.allclose(out_fp64.asnumpy(), Tensor([6., 8., 10.], mstype.float64).asnumpy(), rtol=1e-5, atol=1e-5)
assert np.all(out_int8.asnumpy() == Tensor([5, 7, 9], mstype.int8).asnumpy())
assert np.all(out_int16.asnumpy() == Tensor([5, 7, 9], mstype.int16).asnumpy())
assert np.all(out_int32.asnumpy() == Tensor([5, 7, 9], mstype.int32).asnumpy())
assert np.all(out_int64.asnumpy() == Tensor([5, 7, 9], mstype.int64).asnumpy())
assert np.all(out_uint8.asnumpy() == Tensor([5, 7, 9], mstype.uint8).asnumpy())
assert np.all(out_uint16.asnumpy() == Tensor([5, 7, 9], mstype.uint16).asnumpy())
assert np.all(out_uint32.asnumpy() == Tensor([5, 7, 9], mstype.uint32).asnumpy())
assert np.all(out_uint64.asnumpy() == Tensor([5, 7, 9], mstype.uint64).asnumpy())
assert np.all(out_complex64.asnumpy() == Tensor(np.asarray(np.complex(4 + 0.8j)), mstype.complex64).asnumpy())
assert np.all(out_complex128.asnumpy() == Tensor(np.asarray(np.complex(4 + 0.8j)), mstype.complex128).asnumpy())