forked from mindspore-Ecosystem/mindspore
!12352 Add float64 support to Gather GPU
From: @TFbunny
Reviewed-by: @tom__chen, @robingrosman
Signed-off-by: @robingrosman
Commit: 8d936a6589
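For context, the hunks below register the GPU Gather (GatherV2) kernel for float64 input with both int32 and int64 indices, add the matching double instantiations of the CUDA template, and cover the new dtype with a test. A minimal user-facing sketch, assuming the public ops.Gather primitive maps to this registration (the new test exercises the same pattern through a GatherNet1 wrapper whose definition is not part of this diff):

    import numpy as np
    from mindspore import Tensor, context, ops

    # Sketch only: assumes a MindSpore build containing this patch and an available GPU.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float64).reshape(2, 3, 4, 5))
    indices = Tensor(np.array([1, 3, 4], dtype=np.int32))  # int64 indices are registered as well

    gather = ops.Gather()
    output = gather(x, indices, 3)  # axis 3 (last axis), inferred from the expected values in the new test
    print(output.shape)             # (2, 3, 4, 3), dtype float64

Before this commit the GPU kernel selector had no float64 entry for Gather, which is exactly what the registrations below add.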
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,17 @@
 namespace mindspore {
 namespace kernel {

+MS_REG_GPU_KERNEL_TWO(
+  Gather,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
+  GatherV2GpuFwdKernel, double, int)
+MS_REG_GPU_KERNEL_TWO(
+  Gather,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
+  GatherV2GpuFwdKernel, double, int64_t)
 MS_REG_GPU_KERNEL_TWO(
   Gather,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERV2_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERV2_GPU_KERNEL_H_

 #include <vector>
 #include <algorithm>
@@ -149,4 +149,4 @@ class GatherV2GpuFwdKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore

-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERV2_GPU_KERNEL_H_
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -55,3 +55,7 @@ template void GatherV2<half, int>(half *input, int *indices, half *output, size_
                                   size_t output_dim2, size_t input_dim1, cudaStream_t stream);
 template void GatherV2<half, int64_t>(half *input, int64_t *indices, half *output, size_t output_dim0,
                                       size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template void GatherV2<double, int>(double *input, int *indices, double *output, size_t output_dim0, size_t output_dim1,
+                                    size_t output_dim2, size_t input_dim1, cudaStream_t stream);
+template void GatherV2<double, int64_t>(double *input, int64_t *indices, double *output, size_t output_dim0,
+                                        size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,10 +14,10 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GATHER_V2_CU_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GATHER_V2_CU_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERV2_CUH_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERV2_CUH_
 template <typename T, typename S>
 void GatherV2(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2,
               size_t input_dim1, cudaStream_t stream);

-#endif
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERV2_CUH_
@@ -1,4 +1,4 @@
-# Copyright 2019-2020 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1134,3 +1134,47 @@ def test_gatherV2_dyn_b():
     diff = output.asnumpy() - expect
     assert np.all(diff < error)
     assert np.all(-diff < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gather1_float64():
+    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float64).reshape(2, 3, 4, 5))
+    indices = Tensor(np.array([1, 3, 4], dtype='i4'))
+    expect = np.array([[[[1., 3., 4.],
+                         [6., 8., 9.],
+                         [11., 13., 14.],
+                         [16., 18., 19.]],
+
+                        [[21., 23., 24.],
+                         [26., 28., 29.],
+                         [31., 33., 34.],
+                         [36., 38., 39.]],
+
+                        [[41., 43., 44.],
+                         [46., 48., 49.],
+                         [51., 53., 54.],
+                         [56., 58., 59.]]],
+
+                       [[[61., 63., 64.],
+                         [66., 68., 69.],
+                         [71., 73., 74.],
+                         [76., 78., 79.]],
+
+                        [[81., 83., 84.],
+                         [86., 88., 89.],
+                         [91., 93., 94.],
+                         [96., 98., 99.]],
+
+                        [[101., 103., 104.],
+                         [106., 108., 109.],
+                         [111., 113., 114.],
+                         [116., 118., 119.]]]]).astype(np.float64)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    gather = GatherNet1()
+    output = gather(x, indices)
+    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)