diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/batchtospace_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/batchtospace_gpu_kernel.cc new file mode 100644 index 00000000000..5d88f7a0401 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/batchtospace_gpu_kernel.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/batchtospace_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(BatchToSpace, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BatchToSpaceGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(BatchToSpace, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BatchToSpaceGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(BatchToSpace, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BatchToSpaceGpuKernel, int) +MS_REG_GPU_KERNEL_ONE(BatchToSpace, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + BatchToSpaceGpuKernel, int64_t) +MS_REG_GPU_KERNEL_ONE(BatchToSpace, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + BatchToSpaceGpuKernel, int16_t) +MS_REG_GPU_KERNEL_ONE(BatchToSpace, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + BatchToSpaceGpuKernel, int8_t) +MS_REG_GPU_KERNEL_ONE(BatchToSpace, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + BatchToSpaceGpuKernel, uint8_t) +MS_REG_GPU_KERNEL_ONE(BatchToSpace, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + BatchToSpaceGpuKernel, uint16_t) +MS_REG_GPU_KERNEL_ONE(BatchToSpace, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + BatchToSpaceGpuKernel, uint32_t) +MS_REG_GPU_KERNEL_ONE(BatchToSpace, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + BatchToSpaceGpuKernel, uint64_t) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/batchtospace_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/batchtospace_gpu_kernel.h new file mode 100644 index 00000000000..e43aa0d6746 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/batchtospace_gpu_kernel.h @@ -0,0 +1,178 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BATCHOSPACE_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BATCHOSPACE_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/batchtospace_impl.cuh" + +namespace mindspore { +namespace kernel { +constexpr size_t SHAPE_SIZE = 4; +constexpr size_t CROPS_SHAPE_0 = 2; +constexpr size_t CROPS_SHAPE_1 = 2; +template +class BatchToSpaceGpuKernel : public GpuKernel { + public: + BatchToSpaceGpuKernel() { ResetResource(); } + ~BatchToSpaceGpuKernel() = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + + size_t size = output_size_ / sizeof(T); + + CalBatchToSpace(size, input, in_, ih_, iw_, ic_, on_, oh_, ow_, oc_, crops_[0][0], crops_[0][1], crops_[1][0], + crops_[1][1], block_size_, output, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + if (!CheckParam(kernel_node)) { + return false; + } + input_size_ = sizeof(T); + for (size_t idx = 0; idx < input_shape_.size(); ++idx) { + input_size_ *= input_shape_[idx]; + } + + in_ = input_shape_[0]; + ic_ = input_shape_[1]; + ih_ = input_shape_[2]; + iw_ = input_shape_[3]; + + on_ = in_ / (block_size_ * block_size_); + oc_ = ic_; + oh_ = ih_ * block_size_ - crops_[0][0] - crops_[0][1]; + ow_ = iw_ * block_size_ - crops_[1][0] - crops_[1][1]; + output_size_ = on_ * oc_ * oh_ * ow_ * sizeof(T); + InitSizeLists(); + return true; + } + void ResetResource() noexcept override { + in_ = 0; + ic_ = 0; + ih_ = 0; + iw_ = 0; + on_ = 0; + oc_ = 0; + oh_ = 0; + ow_ = 0; + input_size_list_.clear(); + output_size_list_.clear(); + crops_.clear(); + input_shape_.clear(); + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + bool CheckParam(const CNodePtr &kernel_node) { + block_size_ = GetAttr(kernel_node, "block_size"); + if (block_size_ < 1) { + MS_LOG(ERROR) << "block_size can not be less than 1."; + return false; + } + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "input_num is " << input_num << ", but BatchToSpace needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "output_num is " << output_num << ", but BatchToSpace needs 1 output."; + return false; + } + + // check input_shape + auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0); + if (input_shape.size() != SHAPE_SIZE) { + MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but BatchToSpace supports 4-D tensor."; + return false; + } + if ((input_shape[0] % (block_size_ * block_size_)) != 0) { + MS_LOG(ERROR) << "input_shape[0] must be divisible by product of block_shape"; + return false; + } + for (size_t idx = 0; idx < SHAPE_SIZE; ++idx) { + if (input_shape[idx] < 1) { + MS_LOG(ERROR) << "input_shape[" << idx << "] can not less than 1"; + return false; + } + } + input_shape_.assign(input_shape.begin(), input_shape.end()); + + // check crops + crops_ = (GetAttr>>(kernel_node, "crops")); + + if (crops_.size() != CROPS_SHAPE_0) { + MS_LOG(ERROR) << "crops.size() in BatchToSpace needs 2."; + return false; + } + if (crops_[0].size() != CROPS_SHAPE_1 || crops_[1].size() != CROPS_SHAPE_1) { + MS_LOG(ERROR) << "crops[i].size() in BatchToSpace needs 2."; + return false; + } else { + for (size_t idx_i = 0; idx_i < CROPS_SHAPE_0; ++idx_i) { + for (size_t idx_j = 0; idx_j < CROPS_SHAPE_1; ++idx_j) { + if (crops_[idx_i][idx_j] < 0) { + MS_LOG(ERROR) << "the number in crops can not be less than 0."; + return false; + } + } + auto tmp_shape = input_shape[idx_i + CROPS_SHAPE_1] * block_size_ - crops_[idx_i][0] - crops_[idx_i][1]; + if (tmp_shape <= 0) { + MS_LOG(ERROR) << "out_shape can not be less 1."; + return false; + } + } + } + return true; + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + std::vector> crops_; + std::vector input_shape_; + size_t block_size_; + size_t input_size_; + size_t output_size_; + size_t in_; + size_t ic_; + size_t ih_; + size_t iw_; + size_t on_; + size_t oc_; + size_t oh_; + size_t ow_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BATCHOSPACE_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetobatch_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetobatch_gpu_kernel.cc new file mode 100644 index 00000000000..e1f54f7bec5 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetobatch_gpu_kernel.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/spacetobatch_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(SpaceToBatch, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SpaceToBatchGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(SpaceToBatch, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + SpaceToBatchGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(SpaceToBatch, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + SpaceToBatchGpuKernel, int) +MS_REG_GPU_KERNEL_ONE(SpaceToBatch, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + SpaceToBatchGpuKernel, int64_t) +MS_REG_GPU_KERNEL_ONE(SpaceToBatch, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + SpaceToBatchGpuKernel, int16_t) +MS_REG_GPU_KERNEL_ONE(SpaceToBatch, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + SpaceToBatchGpuKernel, int8_t) +MS_REG_GPU_KERNEL_ONE(SpaceToBatch, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + SpaceToBatchGpuKernel, uint8_t) +MS_REG_GPU_KERNEL_ONE(SpaceToBatch, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + SpaceToBatchGpuKernel, uint16_t) +MS_REG_GPU_KERNEL_ONE(SpaceToBatch, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + SpaceToBatchGpuKernel, uint32_t) +MS_REG_GPU_KERNEL_ONE(SpaceToBatch, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + SpaceToBatchGpuKernel, uint64_t) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetobatch_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetobatch_gpu_kernel.h new file mode 100644 index 00000000000..7211a755bd5 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetobatch_gpu_kernel.h @@ -0,0 +1,170 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPACETOBATCH_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPACETOBATCH_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/spacetobatch_impl.cuh" + +namespace mindspore { +namespace kernel { +constexpr size_t SHAPE_SIZE = 4; +constexpr size_t PADDING_SHAPE_0 = 2; +constexpr size_t PADDING_SHAPE_1 = 2; + +template +class SpaceToBatchGpuKernel : public GpuKernel { + public: + SpaceToBatchGpuKernel() { ResetResource(); } + ~SpaceToBatchGpuKernel() = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + + size_t size = input_size_ / sizeof(T); + + CalSpaceToBatch(size, input, in_, ih_, iw_, ic_, on_, oh_, ow_, oc_, paddings_[0][0], paddings_[0][1], + paddings_[1][0], paddings_[1][1], block_size_, output, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + if (!CheckParam(kernel_node)) { + return false; + } + input_size_ = sizeof(T); + for (size_t idx = 0; idx < input_shape_.size(); ++idx) { + input_size_ *= input_shape_[idx]; + } + in_ = input_shape_[0]; + ic_ = input_shape_[1]; + ih_ = input_shape_[2]; + iw_ = input_shape_[3]; + + on_ = in_ * block_size_ * block_size_; + oc_ = ic_; + oh_ = (ih_ + paddings_[0][0] + paddings_[0][1]) / block_size_; + ow_ = (iw_ + paddings_[1][0] + paddings_[1][1]) / block_size_; + output_size_ = on_ * oc_ * oh_ * ow_ * sizeof(T); + InitSizeLists(); + return true; + } + void ResetResource() noexcept override { + in_ = 0; + ic_ = 0; + ih_ = 0; + iw_ = 0; + on_ = 0; + oc_ = 0; + oh_ = 0; + ow_ = 0; + input_size_list_.clear(); + output_size_list_.clear(); + paddings_.clear(); + input_shape_.clear(); + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + block_size_ = static_cast(GetAttr(kernel_node, "block_size")); + if (block_size_ < 1) { + MS_LOG(ERROR) << "block_size can not be less than 1."; + return false; + } + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "input_num is " << input_num << ", but BatchToSpace needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "output_num is " << output_num << ", but BatchToSpace needs 1 output."; + return false; + } + + // check input_shape + auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0); + if (input_shape.size() != SHAPE_SIZE) { + MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but BatchToSpace supports 4-D tensor."; + return false; + } + input_shape_.assign(input_shape.begin(), input_shape.end()); + // check paddings_ + paddings_ = GetAttr>>(kernel_node, "paddings"); + if (paddings_.size() != PADDING_SHAPE_0) { + MS_LOG(ERROR) << "paddings.size() in BatchToSpace needs 2."; + return false; + } + if (paddings_[0].size() != PADDING_SHAPE_1 || paddings_[1].size() != PADDING_SHAPE_1) { + MS_LOG(ERROR) << "paddings[i].size() in BatchToSpace needs 2."; + return false; + } else { + for (size_t idx_i = 0; idx_i < PADDING_SHAPE_0; ++idx_i) { + for (size_t idx_j = 0; idx_j < PADDING_SHAPE_1; ++idx_j) { + if (paddings_[idx_i][idx_j] < 0) { + MS_LOG(ERROR) << "the number in paddings can not be less than 0."; + return false; + } + } + auto tmp_shape = input_shape[idx_i + PADDING_SHAPE_1] + paddings_[idx_i][0] + paddings_[idx_i][1]; + if ((tmp_shape % block_size_) != 0) { + MS_LOG(ERROR) << "padded shape must be divisible by block_size"; + return false; + } + if ((tmp_shape / block_size_) == 0) { + MS_LOG(ERROR) << "padded shape can not be less than block_size"; + return false; + } + } + } + return true; + } + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + std::vector> paddings_; + std::vector input_shape_; + size_t block_size_; + size_t input_size_; + size_t output_size_; + size_t in_; + size_t ic_; + size_t ih_; + size_t iw_; + size_t on_; + size_t oc_; + size_t oh_; + size_t ow_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPACETOBATCH_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchtospace_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchtospace_impl.cu new file mode 100644 index 00000000000..fc5bf8b5969 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchtospace_impl.cu @@ -0,0 +1,133 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "batchtospace_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void BatchToSpace(const size_t size, const T *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t crop_up, const size_t crop_dn, + const size_t crop_lft, const size_t crop_rht, const size_t block_num, + T *output) { + size_t temp_stride = 0; + size_t temp_pos = 0; + size_t idx_on = 0; + size_t idx_oc = 0; + size_t idx_oh = 0; + size_t idx_ow = 0; + size_t idx_in = 0; + size_t input_pos = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; + pos += blockDim.x * gridDim.x) { + temp_stride = oc * oh * ow; + idx_on = pos / temp_stride; + temp_pos = pos % temp_stride; + + temp_stride /= oc; + idx_oc = temp_pos / temp_stride; + temp_pos = pos % temp_stride; + + temp_stride /= oh; + idx_oh = temp_pos / temp_stride; + temp_pos = pos % temp_stride; + + temp_stride /= ow; + idx_ow = temp_pos / temp_stride; + + idx_in = (((idx_oh + crop_up) % block_num) * block_num + ((idx_ow + crop_lft) % block_num)) * on + idx_on; + input_pos = idx_in * ic; + input_pos = (input_pos + idx_oc) * ih; + input_pos = (input_pos + ((idx_oh + crop_up) - (idx_in / (on * block_num))) / block_num) * iw; + input_pos = (input_pos + ((idx_ow + crop_lft) - ((idx_in / on) % block_num)) / block_num); + output[pos] = input[input_pos]; + } + return; +} + +template +void CalBatchToSpace(const size_t size, const T *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t crop_up, const size_t crop_dn, + const size_t crop_lft, const size_t crop_rht, const size_t block_num, + T *output, cudaStream_t cuda_stream) { + BatchToSpace<<>>( + size, input, in, ih, iw, ic, on, oh, ow, oc, crop_up, crop_dn, crop_lft, crop_rht, block_num, output); + return; +} + +template void CalBatchToSpace(const size_t size, const float *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t crop_up, const size_t crop_dn, + const size_t crop_lft, const size_t crop_rht, const size_t block_num, + float *output, cudaStream_t cuda_stream); +template void CalBatchToSpace(const size_t size, const half *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t crop_up, const size_t crop_dn, + const size_t crop_lft, const size_t crop_rht, const size_t block_num, + half *output, cudaStream_t cuda_stream); +template void CalBatchToSpace(const size_t size, const int *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t crop_up, const size_t crop_dn, + const size_t crop_lft, const size_t crop_rht, const size_t block_num, + int *output, cudaStream_t cuda_stream); +template void CalBatchToSpace(const size_t size, const int64_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t crop_up, const size_t crop_dn, + const size_t crop_lft, const size_t crop_rht, const size_t block_num, + int64_t *output, cudaStream_t cuda_stream); +template void CalBatchToSpace(const size_t size, const int16_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t crop_up, const size_t crop_dn, + const size_t crop_lft, const size_t crop_rht, const size_t block_num, + int16_t *output, cudaStream_t cuda_stream); +template void CalBatchToSpace(const size_t size, const int8_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t crop_up, const size_t crop_dn, + const size_t crop_lft, const size_t crop_rht, const size_t block_num, + int8_t *output, cudaStream_t cuda_stream); +template void CalBatchToSpace(const size_t size, const uint8_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t crop_up, const size_t crop_dn, + const size_t crop_lft, const size_t crop_rht, const size_t block_num, + uint8_t *output, cudaStream_t cuda_stream); +template void CalBatchToSpace(const size_t size, const uint16_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t crop_up, const size_t crop_dn, + const size_t crop_lft, const size_t crop_rht, const size_t block_num, + uint16_t *output, cudaStream_t cuda_stream); +template void CalBatchToSpace(const size_t size, const uint32_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t crop_up, const size_t crop_dn, + const size_t crop_lft, const size_t crop_rht, const size_t block_num, + uint32_t *output, cudaStream_t cuda_stream); +template void CalBatchToSpace(const size_t size, const uint64_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t crop_up, const size_t crop_dn, + const size_t crop_lft, const size_t crop_rht, const size_t block_num, + uint64_t *output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchtospace_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchtospace_impl.cuh new file mode 100644 index 00000000000..cbf6a3976a6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchtospace_impl.cuh @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHTOSPACE_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHTOSPACE_H_ +template +void CalBatchToSpace(const size_t size, const T *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t crop_up, const size_t crop_dn, + const size_t crop_lft, const size_t crop_rht, const size_t block_num, + T *output, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHTOSPACE_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/spacetobatch_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/spacetobatch_impl.cu new file mode 100644 index 00000000000..c770726f2b5 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/spacetobatch_impl.cu @@ -0,0 +1,134 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "spacetobatch_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void SpaceToBatch(const size_t size, const T *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t pad_up, const size_t pad_dn, + const size_t pad_lft, const size_t pad_rht, const size_t block_num, + T *output) { + size_t temp_stride = 0; + size_t temp_pos = 0; + size_t idx_in = 0; + size_t idx_ic = 0; + size_t idx_ih = 0; + size_t idx_iw = 0; + size_t idx_on = 0; + size_t output_pos = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; + pos += blockDim.x * gridDim.x) { + temp_stride = ic * ih * iw; + idx_in = pos / temp_stride; + temp_pos = pos % temp_stride; + + temp_stride /= ic; + idx_ic = temp_pos / temp_stride; + temp_pos = pos % temp_stride; + + temp_stride /= ih; + idx_ih = temp_pos / temp_stride; + temp_pos = pos % temp_stride; + + temp_stride /= iw; + idx_iw = temp_pos / temp_stride; + + idx_on = (((idx_ih + pad_up) % block_num) * block_num + ((idx_iw + pad_lft) % block_num)) * in + idx_in; + output_pos = idx_on * oc; + output_pos = (output_pos + idx_ic) * oh; + output_pos = (output_pos + ((idx_ih + pad_up) - (idx_on / (in * block_num))) / block_num) * ow; + output_pos = (output_pos + ((idx_iw + pad_lft) - ((idx_on / in) % block_num)) / block_num); + output[output_pos] = input[pos]; + } + return; +} + +template +void CalSpaceToBatch(const size_t size, const T *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t pad_up, const size_t pad_dn, + const size_t pad_lft, const size_t pad_rht, const size_t block_num, + T *output, cudaStream_t cuda_stream) { + cudaMemset(output, 0, on * oc * oh * ow * sizeof(T)); + SpaceToBatch<<>>( + size, input, in, ih, iw, ic, on, oh, ow, oc, pad_up, pad_dn, pad_lft, pad_rht, block_num, output); + return; +} + +template void CalSpaceToBatch(const size_t size, const float *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t pad_up, const size_t pad_dn, + const size_t pad_lft, const size_t pad_rht, const size_t block_num, + float *output, cudaStream_t cuda_stream); +template void CalSpaceToBatch(const size_t size, const half *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t pad_up, const size_t pad_dn, + const size_t pad_lft, const size_t pad_rht, const size_t block_num, + half *output, cudaStream_t cuda_stream); +template void CalSpaceToBatch(const size_t size, const int *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t pad_up, const size_t pad_dn, + const size_t pad_lft, const size_t pad_rht, const size_t block_num, + int *output, cudaStream_t cuda_stream); +template void CalSpaceToBatch(const size_t size, const int64_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t pad_up, const size_t pad_dn, + const size_t pad_lft, const size_t pad_rht, const size_t block_num, + int64_t *output, cudaStream_t cuda_stream); +template void CalSpaceToBatch(const size_t size, const int16_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t pad_up, const size_t pad_dn, + const size_t pad_lft, const size_t pad_rht, const size_t block_num, + int16_t *output, cudaStream_t cuda_stream); +template void CalSpaceToBatch(const size_t size, const int8_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t pad_up, const size_t pad_dn, + const size_t pad_lft, const size_t pad_rht, const size_t block_num, + int8_t *output, cudaStream_t cuda_stream); +template void CalSpaceToBatch(const size_t size, const uint8_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t pad_up, const size_t pad_dn, + const size_t pad_lft, const size_t pad_rht, const size_t block_num, + uint8_t *output, cudaStream_t cuda_stream); +template void CalSpaceToBatch(const size_t size, const uint16_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t pad_up, const size_t pad_dn, + const size_t pad_lft, const size_t pad_rht, const size_t block_num, + uint16_t *output, cudaStream_t cuda_stream); +template void CalSpaceToBatch(const size_t size, const uint32_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t pad_up, const size_t pad_dn, + const size_t pad_lft, const size_t pad_rht, const size_t block_num, + uint32_t *output, cudaStream_t cuda_stream); +template void CalSpaceToBatch(const size_t size, const uint64_t *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t pad_up, const size_t pad_dn, + const size_t pad_lft, const size_t pad_rht, const size_t block_num, + uint64_t *output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/spacetobatch_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/spacetobatch_impl.cuh new file mode 100644 index 00000000000..93209f3235c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/spacetobatch_impl.cuh @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPACETOBATCH_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPACETOBATCH_H_ +template +void CalSpaceToBatch(const size_t size, const T *input, const size_t in, + const size_t ih, const size_t iw, const size_t ic, + const size_t on, const size_t oh, const size_t ow, + const size_t oc, const size_t pad_up, const size_t pad_dn, + const size_t pad_lft, const size_t pad_rht, const size_t block_num, + T *output, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPACETOBATCH_H_ diff --git a/tests/st/ops/gpu/test_batchtospace_op.py b/tests/st/ops/gpu/test_batchtospace_op.py new file mode 100644 index 00000000000..0f33d7f3055 --- /dev/null +++ b/tests/st/ops/gpu/test_batchtospace_op.py @@ -0,0 +1,84 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.array_ops as P +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter + +class BatchToSpaceNet(nn.Cell): + def __init__(self, nptype, block_size=2, input_shape=(4, 1, 2, 2)): + super(BatchToSpaceNet, self).__init__() + self.BatchToSpace = P.BatchToSpace(block_size=block_size, crops=[[0, 0], [0, 0]]) + input_size = 1 + for i in input_shape: + input_size = input_size*i + data_np = np.arange(input_size).reshape(input_shape).astype(nptype) + self.x1 = Parameter(initializer(Tensor(data_np), input_shape), name='x1') + + + @ms_function + def construct(self): + y1 = self.BatchToSpace(self.x1) + return y1 + + +def BatchToSpace(nptype, block_size=2, input_shape=(4, 1, 2, 2)): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + input_size = 1 + for i in input_shape: + input_size = input_size*i + expect = np.array([[[[0, 4, 1, 5], + [8, 12, 9, 13], + [2, 6, 3, 7], + [10, 14, 11, 15]]]]).astype(nptype) + + dts = BatchToSpaceNet(nptype, block_size, input_shape) + output = dts() + + assert (output.asnumpy() == expect).all() + +def BatchToSpace_pynative(nptype, block_size=2, input_shape=(4, 1, 2, 2)): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + input_size = 1 + for i in input_shape: + input_size = input_size*i + expect = np.array([[[[0, 4, 1, 5], + [8, 12, 9, 13], + [2, 6, 3, 7], + [10, 14, 11, 15]]]]).astype(nptype) + + dts = P.BatchToSpace(block_size=block_size, crops=[[0, 0], [0, 0]]) + arr_input = Tensor(np.arange(input_size).reshape(input_shape).astype(nptype)) + output = dts(arr_input) + + assert (output.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_batchtospace_graph_float32(): + BatchToSpace(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_batchtospace_graph_float16(): + BatchToSpace(np.float16) diff --git a/tests/st/ops/gpu/test_spacetobatch_op.py b/tests/st/ops/gpu/test_spacetobatch_op.py new file mode 100644 index 00000000000..a89751c0fdc --- /dev/null +++ b/tests/st/ops/gpu/test_spacetobatch_op.py @@ -0,0 +1,92 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.array_ops as P +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter + +class SpaceToBatchNet(nn.Cell): + def __init__(self, nptype, block_size=2, input_shape=(1, 1, 4, 4)): + super(SpaceToBatchNet, self).__init__() + self.SpaceToBatch = P.SpaceToBatch(block_size=block_size, paddings=[[0, 0], [0, 0]]) + input_size = 1 + for i in input_shape: + input_size = input_size*i + data_np = np.arange(input_size).reshape(input_shape).astype(nptype) + self.x1 = Parameter(initializer(Tensor(data_np), input_shape), name='x1') + + + @ms_function + def construct(self): + y1 = self.SpaceToBatch(self.x1) + return y1 + + +def SpaceToBatch(nptype, block_size=2, input_shape=(1, 1, 4, 4)): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + input_size = 1 + for i in input_shape: + input_size = input_size*i + expect = np.array([[[[0, 2], + [8, 10]]], + [[[1, 3], + [9, 11]]], + [[[4, 6], + [12, 14]]], + [[[5, 7], + [13, 15]]]]).astype(nptype) + + dts = SpaceToBatchNet(nptype, block_size, input_shape) + output = dts() + + assert (output.asnumpy() == expect).all() + +def SpaceToBatch_pynative(nptype, block_size=2, input_shape=(1, 1, 4, 4)): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + input_size = 1 + for i in input_shape: + input_size = input_size*i + expect = np.array([[[[0, 2], + [8, 10]]], + [[[1, 3], + [9, 11]]], + [[[4, 6], + [12, 14]]], + [[[5, 7], + [13, 15]]]]).astype(nptype) + + dts = P.SpaceToBatch(block_size=block_size, paddings=[[0, 0], [0, 0]]) + arr_input = Tensor(np.arange(input_size).reshape(input_shape).astype(nptype)) + output = dts(arr_input) + + assert (output.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetobatch_graph_float32(): + SpaceToBatch(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetobatch_graph_float16(): + SpaceToBatch(np.float16)