From b61e963b9f108cd929ca265ceed26bcdc1e341de Mon Sep 17 00:00:00 2001 From: TFBunny Date: Wed, 14 Apr 2021 11:49:01 -0400 Subject: [PATCH] add bool support to gpu select --- .../gpu/arrays/select_gpu_kernel.cc | 7 ++++++ .../gpu/cuda_impl/select_impl.cu | 5 ++-- .../gpu/cuda_impl/select_impl.cuh | 8 +++--- tests/st/ops/gpu/test_select_op.py | 25 +++++++++++++------ 4 files changed, 32 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc index d0c9ba19438..53365066374 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc @@ -53,5 +53,12 @@ MS_REG_GPU_KERNEL_ONE(Select, .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeInt64), SelectGpuKernel, int64_t) +MS_REG_GPU_KERNEL_ONE(Select, + KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeBool) + .AddOutputAttr(kNumberTypeBool), + SelectGpuKernel, bool) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu index 393dd683bf6..e8ebe7e30fb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,4 +44,5 @@ template void CalSelect(const size_t size, const bool* cond, const half* i half* output, cudaStream_t cuda_stream); template void CalSelect(const size_t size, const bool* cond, const int64_t* input_X, const int64_t* input_y, int64_t* output, cudaStream_t cuda_stream); - +template void CalSelect(const size_t size, const bool *cond, const bool *input_X, const bool *input_y, + bool *output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh index e201ab352ce..300ed281f70 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,12 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SELECT_IMPL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SELECT_IMPL_H_ #include "runtime/device/gpu/cuda_common.h" template void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SELECT_IMPL_H_ diff --git a/tests/st/ops/gpu/test_select_op.py b/tests/st/ops/gpu/test_select_op.py index 1b1ccb7ef5c..0bc35116b45 100644 --- a/tests/st/ops/gpu/test_select_op.py +++ b/tests/st/ops/gpu/test_select_op.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P - class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -31,20 +30,32 @@ class Net(nn.Cell): return self.select(cond_op, input_x, input_y) -cond = np.array([[True, False], [True, False]]).astype(np.bool) -x = np.array([[1.2, 1], [1, 0]]).astype(np.float32) -y = np.array([[1, 2], [3, 4.0]]).astype(np.float32) - - @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_select(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") select = Net() + cond = np.array([[True, False], [True, False]]).astype(np.bool) + x = np.array([[1.2, 1], [1, 0]]).astype(np.float32) + y = np.array([[1, 2], [3, 4.0]]).astype(np.float32) output = select(Tensor(cond), Tensor(x), Tensor(y)) expect = [[1.2, 2], [1, 4.0]] error = np.ones(shape=[2, 2]) * 1.0e-6 diff = output.asnumpy() - expect assert np.all(diff < error) assert np.all(-diff < error) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + x = np.array([[1, 0], [1, 0]]).astype(np.bool) + y = np.array([[0, 0], [1, 1]]).astype(np.bool) + output = select(Tensor(cond), Tensor(x), Tensor(y)) + expect = np.array([[1, 0], [1, 1]]).astype(np.bool) + assert np.all(output.asnumpy() == expect) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + x = np.array([[1, 0], [1, 0]]).astype(np.bool) + y = np.array([[0, 0], [1, 1]]).astype(np.bool) + output = select(Tensor(cond), Tensor(x), Tensor(y)) + expect = np.array([[1, 0], [1, 1]]).astype(np.bool) + assert np.all(output.asnumpy() == expect)