From dfb958de1e575f6fc81ec69b62fdcd421d40c995 Mon Sep 17 00:00:00 2001 From: wilfChen Date: Wed, 15 Jul 2020 16:44:31 +0800 Subject: [PATCH] Gpu support BroadcastTo kernel --- .../gpu/arrays/broadcast_to_gpu_kernel.cc | 26 ++++++ .../gpu/arrays/broadcast_to_gpu_kernel.h | 83 +++++++++++++++++++ .../gpu/cuda_impl/broadcast_impl.cu | 40 +++++++-- .../gpu/cuda_impl/broadcast_impl.cuh | 4 + tests/st/ops/gpu/test_broadcast_to_ops.py | 40 +++++++++ 5 files changed, 187 insertions(+), 6 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_broadcast_to_ops.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.cc new file mode 100644 index 00000000000..96e82bc5f3d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastToGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastToGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h new file mode 100644 index 00000000000..459471ed763 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_TO_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_TO_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class BroadcastToGpuKernel : public GpuKernel { + public: + BroadcastToGpuKernel() {} + ~BroadcastToGpuKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + + BroadcastTo(input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], output_shape_[0], output_shape_[1], + output_shape_[2], output_shape_[3], input_addr, output_addr, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); + if (input_shapes.size() > 4 || output_shapes.size() > 4) { + MS_LOG(EXCEPTION) << "BroadcastTo operation not support dim greater than 4"; + } + + for (int i = input_shapes.size() - 1; i >= 0; i--) { + input_shape_[i] = input_shapes[i]; + } + + for (int j = output_shapes.size() - 1; j >= 0; j--) { + output_shape_[j] = output_shapes[j]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T)); + output_size_list_.push_back(output_shape_[0] * output_shape_[1] * output_shape_[2] * output_shape_[3] * sizeof(T)); + } + + private: + int input_shape_[4] = {1, 1, 1, 1}; + int output_shape_[4] = {1, 1, 1, 1}; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_TO_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu index a72daa42346..f5c88e7ebfc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu @@ -116,16 +116,16 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const output); case BROADCAST_TYPE_REALDIV: return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); + output); case BROADCAST_TYPE_MUL: return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); + output); case BROADCAST_TYPE_SUB: return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); + output); case BROADCAST_TYPE_ADD: return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); + output); } } @@ -176,6 +176,28 @@ void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, cons NoBroadcastKernel<<>>(nums, op, input0, input1, output); } +template +__global__ void BroadcastToKernel(const int i0, const int i1, const int i2, const int i3, const int o0, + const int o1, const int o2, const int o3, const T *input_addr, T *output_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < o0 * o1 * o2 * o3; pos += blockDim.x * gridDim.x) { + int i = pos / (o1 * o2 * o3) % o0; + int j = pos / (o2 * o3) % o1; + int k = pos / o3 % o2; + int l = pos % o3; + + int input_idx = Index(i, i0) * i1 * i2 * i3 + Index(j, i1) * i2 * i3 + Index(k, i2) * i3 + Index(l, i3); + output_addr[pos] = input_addr[input_idx]; + } +} + +template +void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, + const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream) { + int nums = o0 * o1 * o2 * o3; + BroadcastToKernel<<>>(i0, i1, i2, i3, o0, o1, o2, o3, input_addr, + output_addr); +} + template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, const float *input0, const float *input1, bool *output, @@ -204,5 +226,11 @@ template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half * bool *output, cudaStream_t stream); template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1, half *output, cudaStream_t stream); -template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, - int *output, cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, int *output, + cudaStream_t stream); + +template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, + const int &o2, const int &o3, const float *input_addr, float *output_addr, + cudaStream_t stream); +template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, + const int &o2, const int &o3, const half *input_addr, half *output_addr, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh index dfc4c75c932..62a3baad0e6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh @@ -41,4 +41,8 @@ template void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output, cudaStream_t stream); +template +void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, + const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream); + #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ diff --git a/tests/st/ops/gpu/test_broadcast_to_ops.py b/tests/st/ops/gpu/test_broadcast_to_ops.py new file mode 100644 index 00000000000..828e72c4d00 --- /dev/null +++ b/tests/st/ops/gpu/test_broadcast_to_ops.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore.common.tensor import Tensor +from mindspore.ops import operations as P + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_broadcast(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + + x_np = np.random.rand(3, 1, 5, 1).astype(np.float32) + shape = (3, 4, 5, 6) + + output = P.BroadcastTo(shape)(Tensor(x_np)) + expect = np.broadcast_to(x_np, shape) + assert np.allclose(output.asnumpy(), expect) + + x1_np = np.random.rand(3, 1, 5, 1).astype(np.float16) + output = P.BroadcastTo(shape)(Tensor(x1_np)) + expect = np.broadcast_to(x1_np, shape) + assert np.allclose(output.asnumpy(), expect)