From 6fad165e431bb7cf3098b018a52ee2602d96b635 Mon Sep 17 00:00:00 2001 From: TFBunny Date: Thu, 11 Feb 2021 12:01:04 -0500 Subject: [PATCH] add float64 support to gather gpu --- .../gpu/arrays/gatherv2_gpu_kernel.cc | 13 +++++- .../gpu/arrays/gatherv2_gpu_kernel.h | 8 ++-- .../kernel_compiler/gpu/cuda_impl/gatherv2.cu | 6 ++- .../gpu/cuda_impl/gatherv2.cuh | 8 ++-- tests/st/ops/gpu/test_gatherV2_op.py | 46 ++++++++++++++++++- 5 files changed, 70 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc index 7a17c34ced8..daad7939e63 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,17 @@ namespace mindspore { namespace kernel { + +MS_REG_GPU_KERNEL_TWO( + Gather, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), + GatherV2GpuFwdKernel, double, int) + +MS_REG_GPU_KERNEL_TWO( + Gather, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), + GatherV2GpuFwdKernel, double, int64_t) + MS_REG_GPU_KERNEL_TWO( Gather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h index eb4e86d4247..8cd42aa1986 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERV2_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERV2_GPU_KERNEL_H_ #include #include @@ -149,4 +149,4 @@ class GatherV2GpuFwdKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERV2_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu index 9ef45329194..a2469dd0a59 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -55,3 +55,7 @@ template void GatherV2(half *input, int *indices, half *output, size_ size_t output_dim2, size_t input_dim1, cudaStream_t stream); template void GatherV2(half *input, int64_t *indices, half *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template void GatherV2(double *input, int *indices, double *output, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template void GatherV2(double *input, int64_t *indices, double *output, size_t output_dim0, + size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh index 9af9fd1b71f..db6725a4523 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,10 +14,10 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GATHER_V2_CU_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GATHER_V2_CU_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERV2_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERV2_CUH_ template void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); -#endif +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERV2_CUH_ diff --git a/tests/st/ops/gpu/test_gatherV2_op.py b/tests/st/ops/gpu/test_gatherV2_op.py index 747586f2958..dc9a758bcb1 100644 --- a/tests/st/ops/gpu/test_gatherV2_op.py +++ b/tests/st/ops/gpu/test_gatherV2_op.py @@ -1,4 +1,4 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd +# Copyright 2019-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1134,3 +1134,47 @@ def test_gatherV2_dyn_b(): diff = output.asnumpy() - expect assert np.all(diff < error) assert np.all(-diff < error) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather1_float64(): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float64).reshape(2, 3, 4, 5)) + indices = Tensor(np.array([1, 3, 4], dtype='i4')) + expect = np.array([[[[1., 3., 4.], + [6., 8., 9.], + [11., 13., 14.], + [16., 18., 19.]], + + [[21., 23., 24.], + [26., 28., 29.], + [31., 33., 34.], + [36., 38., 39.]], + + [[41., 43., 44.], + [46., 48., 49.], + [51., 53., 54.], + [56., 58., 59.]]], + + [[[61., 63., 64.], + [66., 68., 69.], + [71., 73., 74.], + [76., 78., 79.]], + + [[81., 83., 84.], + [86., 88., 89.], + [91., 93., 94.], + [96., 98., 99.]], + + [[101., 103., 104.], + [106., 108., 109.], + [111., 113., 114.], + [116., 118., 119.]]]]).astype(np.float64) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNet1() + output = gather(x, indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error)