!12368 Add type support to Split gpu op
From: @peilin-wang Reviewed-by: @robingrosman,@tom__chen Signed-off-by: @robingrosman
This commit is contained in:
commit
191b3f0c8c
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,17 +18,26 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SplitGpuFwdKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SplitGpuFwdKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SplitGpuFwdKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(Split,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SplitGpuFwdKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SplitGpuFwdKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
SplitGpuFwdKernel, uint32_t)
|
||||
MS_REG_GPU_KERNEL_ONE(Split,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SplitGpuFwdKernel, int64_t)
|
||||
MS_REG_GPU_KERNEL_ONE(Split,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SplitGpuFwdKernel, bool)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -39,15 +39,24 @@ void SplitKernel(const size_t size, const int axis_step, const int all_size_befo
|
|||
return;
|
||||
}
|
||||
|
||||
template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const half* input, half** outputs,
|
||||
cudaStream_t cuda_stream);
|
||||
template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const float* input, float** outputs,
|
||||
cudaStream_t cuda_stream);
|
||||
template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const double* input, double** outputs,
|
||||
cudaStream_t cuda_stream);
|
||||
template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const int* input, int** outputs,
|
||||
cudaStream_t cuda_stream);
|
||||
template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const half* input, half** outputs,
|
||||
cudaStream_t cuda_stream);
|
||||
template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const uint32_t* input, uint32_t** outputs,
|
||||
cudaStream_t cuda_stream);
|
||||
template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const int64_t* input, int64_t** outputs,
|
||||
cudaStream_t cuda_stream);
|
||||
template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const bool* input, bool** outputs,
|
||||
cudaStream_t cuda_stream);
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -46,13 +46,10 @@ class NetDynamic(nn.Cell):
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_split():
|
||||
def split_basic(nptype):
|
||||
x = np.array([[[1, -1, 1], [2, -2, 2]],
|
||||
[[3, -3, 3], [4, -4, 4]],
|
||||
[[5, -5, 5], [6, -6, 6]]]).astype(np.float32)
|
||||
[[5, -5, 5], [6, -6, 6]]]).astype(nptype)
|
||||
|
||||
split_op = Net(0, 3)
|
||||
outputs = split_op(Tensor(x))
|
||||
|
@ -60,6 +57,55 @@ def test_split():
|
|||
assert (out.asnumpy() == x[i]).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_split_basic_float16():
|
||||
split_basic(np.float16)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_split_basic_float32():
|
||||
split_basic(np.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_split_basic_float64():
|
||||
split_basic(np.float64)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_split_basic_int32():
|
||||
split_basic(np.int32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_split_basic_uint32():
|
||||
split_basic(np.uint32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_split_basic_int64():
|
||||
split_basic(np.int64)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_split_basic_bool():
|
||||
split_basic(np.bool)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
|
|
Loading…
Reference in New Issue