forked from mindspore-Ecosystem/mindspore
!9747 add int64 support for indices of GPU gatherNd
From: @TFbunny Reviewed-by: @robingrosman,@tom__chen Signed-off-by: @tom__chen
This commit is contained in:
commit
bb241214f8
|
@ -38,5 +38,25 @@ MS_REG_GPU_KERNEL_TWO(
|
|||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherNd, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
GatherNdGpuFwdKernel, bool, int)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherNd,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
GatherNdGpuFwdKernel, float, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherNd,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||
GatherNdGpuFwdKernel, half, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
GatherNdGpuFwdKernel, int, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
|
||||
GatherNdGpuFwdKernel, short, int64_t) // NOLINT
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherNd, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
|
||||
GatherNdGpuFwdKernel, uchar, int64_t)
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
GatherNd, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
GatherNdGpuFwdKernel, bool, int64_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_GATHERND_GPU_KERNEL_H
|
||||
#define MINDSPORE_GATHERND_GPU_KERNEL_H
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERND_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERND_GPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
|
@ -171,4 +171,4 @@ class GatherNdGpuFwdKernel : public GpuKernel {
|
|||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_GATHERND_GPU_KERNEL_H
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERND_GPU_KERNEL_H
|
||||
|
|
|
@ -73,3 +73,22 @@ template void GatherNd<unsigned char, int>(unsigned char *input, int *indices, u
|
|||
template void GatherNd<bool, int>(bool *input, int *indices, bool *output, const size_t &output_dim0,
|
||||
const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
|
||||
int *batch_strides, cudaStream_t stream);
|
||||
template void GatherNd<float, int64_t>(float *input, int64_t *indices, float *output, const size_t &output_dim0,
|
||||
const size_t &output_dim1, const size_t &indices_dim1, int64_t *batch_indices,
|
||||
int64_t *batch_strides, cudaStream_t stream);
|
||||
template void GatherNd<half, int64_t>(half *input, int64_t *indices, half *output, const size_t &output_dim0,
|
||||
const size_t &output_dim1, const size_t &indices_dim1, int64_t *batch_indices,
|
||||
int64_t *batch_strides, cudaStream_t stream);
|
||||
template void GatherNd<int, int64_t>(int *input, int64_t *indices, int *output, const size_t &output_dim0,
|
||||
const size_t &output_dim1, const size_t &indices_dim1, int64_t *batch_indices,
|
||||
int64_t *batch_strides, cudaStream_t stream);
|
||||
template void GatherNd<short, int64_t>(short *input, int64_t *indices, short *output, // NOLINT
|
||||
const size_t &output_dim0, const size_t &output_dim1, const size_t &indices_dim1,
|
||||
int64_t *batch_indices, int64_t *batch_strides, cudaStream_t stream);
|
||||
template void GatherNd<unsigned char, int64_t>(unsigned char *input, int64_t *indices, unsigned char *output,
|
||||
const size_t &output_dim0, const size_t &output_dim1,
|
||||
const size_t &indices_dim1, int64_t *batch_indices,
|
||||
int64_t *batch_strides, cudaStream_t stream);
|
||||
template void GatherNd<bool, int64_t>(bool *input, int64_t *indices, bool *output, const size_t &output_dim0,
|
||||
const size_t &output_dim1, const size_t &indices_dim1, int64_t *batch_indices,
|
||||
int64_t *batch_strides, cudaStream_t stream);
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_GATHERND_GPU_CU_H
|
||||
#define MINDSPORE_GATHERND_GPU_CU_H
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERND_GPU_CU_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERND_GPU_CU_H
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
|
@ -23,4 +23,4 @@ template <typename T, typename S>
|
|||
void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const size_t &output_dim1,
|
||||
const size_t &indices_dim1, S *batch_indices, S *batch_strides, cudaStream_t stream);
|
||||
|
||||
#endif // MINDSPORE_GATHERND_GPU_CU_H
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERND_GPU_CU_H
|
||||
|
|
|
@ -206,3 +206,17 @@ def test_gathernd_bool():
|
|||
output = gathernd(x, indices)
|
||||
|
||||
assert np.array_equal(output.asnumpy(), expect)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gathernd_indices_int64():
|
||||
x = Tensor(np.array([[True, False], [False, False]]).astype(np.bool))
|
||||
indices = Tensor(np.array([[0, 0], [0, 1], [1, 0], [1, 1]]).astype(np.int64))
|
||||
expect = np.array([True, False, False, False]).astype(np.bool)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
gathernd = GatherNdNet()
|
||||
output = gathernd(x, indices)
|
||||
|
||||
assert np.array_equal(output.asnumpy(), expect)
|
||||
|
|
Loading…
Reference in New Issue