ops.addn supports more types
parent 8388e614f8
commit 11ef11c2f6
@@ -8,7 +8,7 @@ mindspore.ops.addn
All input Tensors must have the same shape.

Args:
    - **x** (Union(tuple[Tensor], list[Tensor])) - A tuple or list made up of Tensors, whose dtype is `bool_ <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.dtype.html#mindspore.dtype>`_ or `number <https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.dtype.html#mindspore.dtype>`_ .
    - **x** (Union(tuple[Tensor], list[Tensor])) - A tuple or list made up of Tensors.

Returns:
    Tensor, with the same shape and dtype as each Tensor in `x`.
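As a quick illustration of the broadened dtype coverage this commit targets, the sketch below calls ops.addn with an unsigned-integer pair and a complex64 pair. It is a minimal sketch: the values simply mirror the tests added later in this commit, and it assumes a build where the new CPU/GPU kernels are registered.

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

# Unsigned integers and complex numbers both appear in the new kernel registrations and tests below.
print(ops.addn([Tensor([1, 2, 3], ms.uint16), Tensor([4, 5, 6], ms.uint16)]))   # [5 7 9]
print(ops.addn([Tensor(np.array(1.5 + 0.4j), ms.complex64),
                Tensor(np.array(2.5 + 0.4j), ms.complex64)]))                   # roughly (4+0.8j)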
@@ -17,6 +17,7 @@
#include "plugin/device/cpu/kernel/mkldnn/addn_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include <complex>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/nnacl/fp32/add_fp32.h"
#include "plugin/device/cpu/kernel/nnacl/errorcode.h"
@@ -26,6 +27,9 @@
namespace mindspore {
namespace kernel {
namespace {
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;

constexpr size_t kAddNInputsMinNum = 2;
constexpr size_t kAddNOutputsNum = 1;
@@ -164,7 +168,11 @@ std::vector<std::pair<KernelAttr, AddNCpuKernelMod::AddNFunc>> AddNCpuKernelMod:
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&AddNCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&AddNCpuKernelMod::LaunchKernel<double>}};
&AddNCpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&AddNCpuKernelMod::LaunchKernel<complex64>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&AddNCpuKernelMod::LaunchKernel<complex128>}};

std::vector<KernelAttr> AddNCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
@@ -27,11 +27,35 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
AddNFwdGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(AddN,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
AddNFwdGpuKernelMod, int8_t)
MS_REG_GPU_KERNEL_ONE(AddN,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
AddNFwdGpuKernelMod, int16_t)
MS_REG_GPU_KERNEL_ONE(AddN,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
AddNFwdGpuKernelMod, int)
AddNFwdGpuKernelMod, int32_t)
MS_REG_GPU_KERNEL_ONE(AddN,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
AddNFwdGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(AddN,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
AddNFwdGpuKernelMod, uint8_t)
MS_REG_GPU_KERNEL_ONE(
AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
AddNFwdGpuKernelMod, uint16_t)
MS_REG_GPU_KERNEL_ONE(
AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
AddNFwdGpuKernelMod, uint32_t)
MS_REG_GPU_KERNEL_ONE(
AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
AddNFwdGpuKernelMod, uint64_t)
MS_REG_GPU_KERNEL_ONE(
AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
AddNFwdGpuKernelMod, Complex<float>)
MS_REG_GPU_KERNEL_ONE(
AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
AddNFwdGpuKernelMod, Complex<double>)
} // namespace kernel
} // namespace mindspore
@@ -24,9 +24,13 @@
#include "plugin/device/gpu/kernel/math/broadcast_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/slice_impl.cuh"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"

namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;

template <typename T>
class AddNFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
public:
@@ -50,8 +54,13 @@ class AddNFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
FillDeviceArray(outputs[0]->size / sizeof(T), work_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
for (size_t i = 0; i < num_input_; i++) {
T *input_addr = GetDeviceAddress<T>(inputs, i);
ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
if constexpr (std::is_same<T, Complex<float>>::value || std::is_same<T, Complex<double>>::value) {
ElewiseComplexArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
}
if (work_addr != output_addr) {
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
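The launch path above zero-fills a work buffer and then folds each input into it with one elementwise add per input, selecting ElewiseComplexArith at compile time when T is Complex<float> or Complex<double>. A host-side NumPy sketch of that accumulation strategy (illustrative only, not MindSpore API):

import numpy as np

def addn_reference(inputs):
    # FillDeviceArray(..., 0.0f, ...): start from a zeroed accumulator of the same dtype.
    acc = np.zeros_like(inputs[0])
    # One elementwise add per input, real or complex alike.
    for x in inputs:
        acc = acc + x
    return acc

print(addn_reference([np.array([1.5 + 0.4j], np.complex64),
                      np.array([2.5 + 0.4j], np.complex64)]))   # [4.+0.8j]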
@@ -61,6 +70,7 @@ class AddNFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
}
return true;
}

bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
@@ -48,7 +48,7 @@ bool AddNDynShapeJoin(ShapeVector *shape1, const ShapeVector *shape2) {
if ((*shape1)[i] == (*shape2)[i]) {
continue;
}
// If shape1 is dynamic, use shape of shape2.If shape2 is dynamic, keep shape1.
// If shape1 is dynamic, use shape of shape2. If shape2 is dynamic, keep shape1.
if ((*shape1)[i] == abstract::Shape::kShapeDimAny) {
(*shape1)[i] = (*shape2)[i];
continue;
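The corrected comment describes the per-dimension join rule used when inputs carry dynamic shapes. A small Python sketch of that rule, with -1 standing in for abstract::Shape::kShapeDimAny and assuming the function reports a conflict when both sizes are static but differ:

DIM_ANY = -1  # stands in for abstract::Shape::kShapeDimAny

def addn_dyn_shape_join(shape1, shape2):
    # Mutates shape1 in place; returns False when the shapes cannot be joined.
    for i, (d1, d2) in enumerate(zip(shape1, shape2)):
        if d1 == d2:
            continue
        if d1 == DIM_ANY:   # shape1 is dynamic: adopt shape2's size.
            shape1[i] = d2
            continue
        if d2 == DIM_ANY:   # shape2 is dynamic: keep shape1's size.
            continue
        return False        # both static and different: conflict.
    return True

s = [DIM_ANY, 3]
assert addn_dyn_shape_join(s, [2, 3]) and s == [2, 3]
assert not addn_dyn_shape_join([2, 3], [4, 3])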
@@ -115,8 +115,7 @@ TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
std::string element_i = "element_" + std::to_string(i);
(void)types.emplace(element_i, elements[i]->BuildType());
}
std::set<TypePtr> valid_types = common_valid_types;
valid_types.insert(kBool);
std::set<TypePtr> valid_types = common_valid_types_with_complex_and_bool;
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
return elements[0]->BuildType();
}
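AddNInferType now draws its allowed element dtypes from common_valid_types_with_complex_and_bool instead of appending kBool by hand, so complex64/complex128 pass the dtype check while all elements must still share a single dtype. A toy sketch of that rule (the names here are illustrative, not the MindSpore utilities):

# Roughly the contents of common_valid_types_with_complex_and_bool.
VALID_ADDN_DTYPES = {
    "bool", "int8", "int16", "int32", "int64",
    "uint8", "uint16", "uint32", "uint64",
    "float16", "float32", "float64", "complex64", "complex128",
}

def check_addn_dtypes(element_dtypes):
    # Mirrors CheckTensorTypeSame: one shared dtype, drawn from the valid set.
    if len(set(element_dtypes)) != 1:
        raise TypeError("all inputs of AddN must have the same dtype")
    if element_dtypes[0] not in VALID_ADDN_DTYPES:
        raise TypeError("AddN does not support dtype " + element_dtypes[0])

check_addn_dtypes(["complex64", "complex64"])   # accepted after this change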
@@ -176,9 +176,7 @@ def addn(x):
    All input tensors must have the same shape.

    Args:
        x (Union(tuple[Tensor], list[Tensor])): A tuple or list composed of Tensor, the data type is
            `bool_ <https://www.mindspore.cn/docs/api/en/master/api_python/mindspore.html#mindspore.dtype>`_ or
            `number <https://www.mindspore.cn/docs/api/en/master/api_python/mindspore.html#mindspore.dtype>`_ .
        x (Union(tuple[Tensor], list[Tensor])): A tuple or list composed of Tensor.

    Returns:
        Tensor, has the same shape and dtype as each Tensor of `x`.
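The docstring's same-shape requirement, and the fact that a tuple works as well as a list, can be shown with a short sketch. The values are hypothetical, and the exact exception raised for mismatched shapes is not asserted here since it depends on the backend:

import mindspore as ms
from mindspore import Tensor, ops

a = Tensor([1, 2, 3], ms.int32)
b = Tensor([4, 5, 6], ms.int32)
c = Tensor([7, 8, 9], ms.int32)
print(ops.addn((a, b, c)))          # [12 15 18]; tuple input is equivalent to a list

try:
    ops.addn((a, Tensor([1, 2], ms.int32)))   # shapes (3,) and (2,) cannot be joined
except Exception as err:                      # exact exception type depends on the backend
    print(err)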
@@ -18,7 +18,7 @@ import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Tensor, ops
from mindspore.ops import operations as P
from mindspore import dtype as mstype
@@ -118,3 +118,43 @@ def test_four_tensors_add():
    expect_result = (x + y + m + n).astype(dtype)
    assert output.asnumpy().dtype == expect_result.dtype
    assert np.array_equal(output.asnumpy(), expect_result)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_addn_support_type():
    """
    Feature: test ops.addn.
    Description: test ops.addn with different types.
    Expectation: the result match with expected result.
    """
    out_fp16 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float16), Tensor([4.5, 5.5, 6.5], mstype.float16)])
    out_fp32 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float32), Tensor([4.5, 5.5, 6.5], mstype.float32)])
    out_fp64 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float64), Tensor([4.5, 5.5, 6.5], mstype.float64)])
    out_int8 = ops.addn([Tensor([1, 2, 3], mstype.int8), Tensor([4, 5, 6], mstype.int8)])
    out_int16 = ops.addn([Tensor([1, 2, 3], mstype.int16), Tensor([4, 5, 6], mstype.int16)])
    out_int32 = ops.addn([Tensor([1, 2, 3], mstype.int32), Tensor([4, 5, 6], mstype.int32)])
    out_int64 = ops.addn([Tensor([1, 2, 3], mstype.int64), Tensor([4, 5, 6], mstype.int64)])
    out_uint8 = ops.addn([Tensor([1, 2, 3], mstype.uint8), Tensor([4, 5, 6], mstype.uint8)])
    out_uint16 = ops.addn([Tensor([1, 2, 3], mstype.uint16), Tensor([4, 5, 6], mstype.uint16)])
    out_uint32 = ops.addn([Tensor([1, 2, 3], mstype.uint32), Tensor([4, 5, 6], mstype.uint32)])
    out_uint64 = ops.addn([Tensor([1, 2, 3], mstype.uint64), Tensor([4, 5, 6], mstype.uint64)])
    out_complex64 = ops.addn([Tensor(np.asarray(np.complex(1.5 + 0.4j)), mstype.complex64),
                              Tensor(np.asarray(np.complex(2.5 + 0.4j)), mstype.complex64)])
    out_complex128 = ops.addn([Tensor(np.asarray(np.complex(1.5 + 0.4j)), mstype.complex128),
                               Tensor(np.asarray(np.complex(2.5 + 0.4j)), mstype.complex128)])

    assert np.allclose(out_fp16.asnumpy(), Tensor([6., 8., 10.], mstype.float16).asnumpy(), rtol=1e-5, atol=1e-5)
    assert np.allclose(out_fp32.asnumpy(), Tensor([6., 8., 10.], mstype.float32).asnumpy(), rtol=1e-5, atol=1e-5)
    assert np.allclose(out_fp64.asnumpy(), Tensor([6., 8., 10.], mstype.float64).asnumpy(), rtol=1e-5, atol=1e-5)
    assert np.all(out_int8.asnumpy() == Tensor([5, 7, 9], mstype.int8).asnumpy())
    assert np.all(out_int16.asnumpy() == Tensor([5, 7, 9], mstype.int16).asnumpy())
    assert np.all(out_int32.asnumpy() == Tensor([5, 7, 9], mstype.int32).asnumpy())
    assert np.all(out_int64.asnumpy() == Tensor([5, 7, 9], mstype.int64).asnumpy())
    assert np.all(out_uint8.asnumpy() == Tensor([5, 7, 9], mstype.uint8).asnumpy())
    assert np.all(out_uint16.asnumpy() == Tensor([5, 7, 9], mstype.uint16).asnumpy())
    assert np.all(out_uint32.asnumpy() == Tensor([5, 7, 9], mstype.uint32).asnumpy())
    assert np.all(out_uint64.asnumpy() == Tensor([5, 7, 9], mstype.uint64).asnumpy())
    assert np.all(out_complex64.asnumpy() == Tensor(np.asarray(np.complex(4 + 0.8j)), mstype.complex64).asnumpy())
    assert np.all(out_complex128.asnumpy() == Tensor(np.asarray(np.complex(4 + 0.8j)), mstype.complex128).asnumpy())
@@ -18,7 +18,7 @@ import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Tensor, ops
from mindspore.common.api import ms_function
from mindspore.ops import operations as P
from mindspore import dtype as mstype
@@ -169,3 +169,43 @@ def test_net_int64():
    [84., 87., 90., 93.],
    [96., 99., 102., 105.]]]]).astype(np.int64)
    assert (output.asnumpy() == expect_result).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_addn_support_type():
    """
    Feature: test ops.addn.
    Description: test ops.addn with different types.
    Expectation: the result match with expected result.
    """
    out_fp16 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float16), Tensor([4.5, 5.5, 6.5], mstype.float16)])
    out_fp32 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float32), Tensor([4.5, 5.5, 6.5], mstype.float32)])
    out_fp64 = ops.addn([Tensor([1.5, 2.5, 3.5], mstype.float64), Tensor([4.5, 5.5, 6.5], mstype.float64)])
    out_int8 = ops.addn([Tensor([1, 2, 3], mstype.int8), Tensor([4, 5, 6], mstype.int8)])
    out_int16 = ops.addn([Tensor([1, 2, 3], mstype.int16), Tensor([4, 5, 6], mstype.int16)])
    out_int32 = ops.addn([Tensor([1, 2, 3], mstype.int32), Tensor([4, 5, 6], mstype.int32)])
    out_int64 = ops.addn([Tensor([1, 2, 3], mstype.int64), Tensor([4, 5, 6], mstype.int64)])
    out_uint8 = ops.addn([Tensor([1, 2, 3], mstype.uint8), Tensor([4, 5, 6], mstype.uint8)])
    out_uint16 = ops.addn([Tensor([1, 2, 3], mstype.uint16), Tensor([4, 5, 6], mstype.uint16)])
    out_uint32 = ops.addn([Tensor([1, 2, 3], mstype.uint32), Tensor([4, 5, 6], mstype.uint32)])
    out_uint64 = ops.addn([Tensor([1, 2, 3], mstype.uint64), Tensor([4, 5, 6], mstype.uint64)])
    out_complex64 = ops.addn([Tensor(np.asarray(np.complex(1.5 + 0.4j)), mstype.complex64),
                              Tensor(np.asarray(np.complex(2.5 + 0.4j)), mstype.complex64)])
    out_complex128 = ops.addn([Tensor(np.asarray(np.complex(1.5 + 0.4j)), mstype.complex128),
                               Tensor(np.asarray(np.complex(2.5 + 0.4j)), mstype.complex128)])

    assert np.allclose(out_fp16.asnumpy(), Tensor([6., 8., 10.], mstype.float16).asnumpy(), rtol=1e-5, atol=1e-5)
    assert np.allclose(out_fp32.asnumpy(), Tensor([6., 8., 10.], mstype.float32).asnumpy(), rtol=1e-5, atol=1e-5)
    assert np.allclose(out_fp64.asnumpy(), Tensor([6., 8., 10.], mstype.float64).asnumpy(), rtol=1e-5, atol=1e-5)
    assert np.all(out_int8.asnumpy() == Tensor([5, 7, 9], mstype.int8).asnumpy())
    assert np.all(out_int16.asnumpy() == Tensor([5, 7, 9], mstype.int16).asnumpy())
    assert np.all(out_int32.asnumpy() == Tensor([5, 7, 9], mstype.int32).asnumpy())
    assert np.all(out_int64.asnumpy() == Tensor([5, 7, 9], mstype.int64).asnumpy())
    assert np.all(out_uint8.asnumpy() == Tensor([5, 7, 9], mstype.uint8).asnumpy())
    assert np.all(out_uint16.asnumpy() == Tensor([5, 7, 9], mstype.uint16).asnumpy())
    assert np.all(out_uint32.asnumpy() == Tensor([5, 7, 9], mstype.uint32).asnumpy())
    assert np.all(out_uint64.asnumpy() == Tensor([5, 7, 9], mstype.uint64).asnumpy())
    assert np.all(out_complex64.asnumpy() == Tensor(np.asarray(np.complex(4 + 0.8j)), mstype.complex64).asnumpy())
    assert np.all(out_complex128.asnumpy() == Tensor(np.asarray(np.complex(4 + 0.8j)), mstype.complex128).asnumpy())