!12352 Add float64 support to Gather GPU

From: @TFbunny
Reviewed-by: @tom__chen, @robingrosman
Signed-off-by: @robingrosman
mindspore-ci-bot committed on 2021-02-13 00:03:14 +08:00 via Gitee
commit 8d936a6589
5 changed files with 70 additions and 11 deletions
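
For context, a minimal usage sketch of what this change enables from the Python side. The class name GatherFp64Net is illustrative only and mirrors the GatherNet1 pattern used in the tests; it assumes the standard P.Gather primitive and is not part of this commit:

# Hypothetical usage sketch (not part of this diff): float64 Gather on GPU.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P

class GatherFp64Net(nn.Cell):
    """Illustrative wrapper, following the GatherNet1 pattern in the tests."""
    def __init__(self):
        super(GatherFp64Net, self).__init__()
        self.gather = P.Gather()

    def construct(self, params, indices):
        return self.gather(params, indices, 1)  # gather along axis 1

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# float64 params with int32 indices now dispatch to the newly registered GPU kernel.
params = Tensor(np.arange(2 * 3, dtype=np.float64).reshape(2, 3))
indices = Tensor(np.array([0, 2], dtype=np.int32))
output = GatherFp64Net()(params, indices)
print(output.dtype)  # Float64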

View File

@@ -1,5 +1,5 @@
/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -18,6 +18,17 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
GatherV2GpuFwdKernel, double, int)
MS_REG_GPU_KERNEL_TWO(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
GatherV2GpuFwdKernel, double, int64_t)
MS_REG_GPU_KERNEL_TWO(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),

View File

@@ -1,5 +1,5 @@
/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
 * limitations under the License.
 */
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERV2_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERV2_GPU_KERNEL_H_
#include <vector>
#include <algorithm>
@@ -149,4 +149,4 @@ class GatherV2GpuFwdKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore
-#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
+#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERV2_GPU_KERNEL_H_

View File

@@ -1,5 +1,5 @@
/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -55,3 +55,7 @@ template void GatherV2<half, int>(half *input, int *indices, half *output, size_
size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template void GatherV2<half, int64_t>(half *input, int64_t *indices, half *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template void GatherV2<double, int>(double *input, int *indices, double *output, size_t output_dim0, size_t output_dim1,
size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template void GatherV2<double, int64_t>(double *input, int64_t *indices, double *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);

View File

@@ -1,5 +1,5 @@
/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,10 +14,10 @@
 * limitations under the License.
 */
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GATHER_V2_CU_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERV2_CUH_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GATHER_V2_CU_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERV2_CUH_
template <typename T, typename S>
void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
size_t input_dim1, cudaStream_t stream);
-#endif
+#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERV2_CUH_

View File

@@ -1,4 +1,4 @@
-# Copyright 2019-2020 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1134,3 +1134,47 @@ def test_gatherV2_dyn_b():
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather1_float64():
x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float64).reshape(2, 3, 4, 5))
indices = Tensor(np.array([1, 3, 4], dtype='i4'))
expect = np.array([[[[1., 3., 4.],
[6., 8., 9.],
[11., 13., 14.],
[16., 18., 19.]],
[[21., 23., 24.],
[26., 28., 29.],
[31., 33., 34.],
[36., 38., 39.]],
[[41., 43., 44.],
[46., 48., 49.],
[51., 53., 54.],
[56., 58., 59.]]],
[[[61., 63., 64.],
[66., 68., 69.],
[71., 73., 74.],
[76., 78., 79.]],
[[81., 83., 84.],
[86., 88., 89.],
[91., 93., 94.],
[96., 98., 99.]],
[[101., 103., 104.],
[106., 108., 109.],
[111., 113., 114.],
[116., 118., 119.]]]]).astype(np.float64)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
gather = GatherNet1()
output = gather(x, indices)
error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
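
The kernel registrations above also add a float64 variant with int64 indices, which this test does not cover; a hypothetical self-contained check of that path (assuming the same P.Gather primitive, not part of this commit) might look like:

# Hypothetical sketch for the float64 / int64-indices registration (not part of this diff).
import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float64).reshape(2, 3, 4, 5))
indices = Tensor(np.array([1, 3, 4], dtype=np.int64))  # int64 indices
output = P.Gather()(x, indices, 3)  # gather along the last axis, as in the expected values above
assert output.asnumpy().dtype == np.float64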