diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/depthtospace_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/depthtospace_gpu_kernel.cc new file mode 100644 index 00000000000..0f7467f3757 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/depthtospace_gpu_kernel.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/depthtospace_gpu_kernel.h" +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + DepthToSpaceFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + DepthToSpaceFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + DepthToSpaceFwdKernel, int) +MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + DepthToSpaceFwdKernel, int64_t) +MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + DepthToSpaceFwdKernel, int16_t) +MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + DepthToSpaceFwdKernel, int8_t) +MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + DepthToSpaceFwdKernel, uint8_t) +MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + DepthToSpaceFwdKernel, uint16_t) +MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + DepthToSpaceFwdKernel, uint32_t) +MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + DepthToSpaceFwdKernel, uint64_t) + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/depthtospace_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/depthtospace_gpu_kernel.h new file mode 100644 index 00000000000..8fbbf01869e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/depthtospace_gpu_kernel.h @@ -0,0 +1,146 @@ + +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEPTHTOSPACE_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEPTHTOSPACE_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/depthtospace_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class DepthToSpaceFwdKernel : public GpuKernel { + public: + DepthToSpaceFwdKernel() { ResetResource(); } + ~DepthToSpaceFwdKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + // get device buffer ptr + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + + // get device buffer shape ptr + size_t *input_shape = GetDeviceAddress(workspace, 0); + size_t *output_shape = GetDeviceAddress(workspace, 1); + + // buffer shape memcpy from host to device + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size1_, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_shape failed"); + + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(output_shape, &output_shape_[0], workspace_size2_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_shape failed"); + // get input size + size_t size = input_size_ / sizeof(T); + + // call cuda kernel + CalDepthToSpace(size, input, input_shape, output_shape, block_size_, output, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + kernel_node_ = kernel_node; + block_size_ = static_cast(GetAttr(kernel_node, "block_size")); + // check input num and output num + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but DepthToSpace needs 1 input."; + return false; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", DepthToSpace needs 1 output."; + return false; + } + // check input_shape + auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0); + shape_size_ = input_shape.size(); + if (shape_size_ != DEPTHTOSPACE_BUFFER_DIMENSION) { + MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but DepthToSpace supports 4-D tensor."; + } + // get input and out put information + input_size_ = 1; + for (size_t i = 0; i < shape_size_; i++) { + input_size_ *= input_shape[i]; + input_shape_.push_back(input_shape[i]); + } + input_size_ *= sizeof(T); + output_size_ = input_size_; + output_shape_.push_back(input_shape[0]); + output_shape_.push_back(input_shape[1] / block_size_ / block_size_); + output_shape_.push_back(input_shape[2] * block_size_); + output_shape_.push_back(input_shape[3] * block_size_); + // Private members Initialize + InitSizeLists(); + return true; + } + + void ResetResource() noexcept override { + shape_size_ = 0; + input_size_ = 0; + output_size_ = 0; + block_size_ = 0; + workspace_size1_ = 0; + workspace_size2_ = 0; + + input_shape_.clear(); + output_shape_.clear(); + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + workspace_size1_ = shape_size_ * sizeof(size_t); + workspace_size2_ = shape_size_ * sizeof(size_t); + workspace_size_list_.push_back(workspace_size1_); + workspace_size_list_.push_back(workspace_size2_); + return; + } + + private: + std::vector input_shape_; + std::vector output_shape_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + size_t shape_size_; + size_t input_size_; + size_t output_size_; + size_t block_size_; + size_t workspace_size1_; + size_t workspace_size2_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEPTHTOSPACE_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetodepth_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetodepth_gpu_kernel.cc new file mode 100644 index 00000000000..f5d88297024 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetodepth_gpu_kernel.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/spacetodepth_gpu_kernel.h" +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SpaceToDepthFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + SpaceToDepthFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + SpaceToDepthFwdKernel, int) +MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + SpaceToDepthFwdKernel, int64_t) +MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), + SpaceToDepthFwdKernel, int16_t) +MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), + SpaceToDepthFwdKernel, int8_t) +MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), + SpaceToDepthFwdKernel, uint8_t) +MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), + SpaceToDepthFwdKernel, uint16_t) +MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + SpaceToDepthFwdKernel, uint32_t) +MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + SpaceToDepthFwdKernel, uint64_t) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetodepth_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetodepth_gpu_kernel.h new file mode 100644 index 00000000000..de0c9490969 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/spacetodepth_gpu_kernel.h @@ -0,0 +1,146 @@ + +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPACETODEPTH_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPACETODEPTH_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/spacetodepth_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class SpaceToDepthFwdKernel : public GpuKernel { + public: + SpaceToDepthFwdKernel() { ResetResource(); } + ~SpaceToDepthFwdKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + // get device buffer ptr + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + + // get device buffer shape ptr + size_t *input_shape = GetDeviceAddress(workspace, 0); + size_t *output_shape = GetDeviceAddress(workspace, 1); + + // buffer shape memcpy from host to device + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size1_, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_shape failed"); + + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, + cudaMemcpyAsync(output_shape, &output_shape_[0], workspace_size2_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_shape failed"); + // get input size + size_t size = input_size_ / sizeof(T); + + // call cuda kernel + CalSpaceToDepth(size, input, input_shape, output_shape, block_size_, output, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + kernel_node_ = kernel_node; + block_size_ = static_cast(GetAttr(kernel_node, "block_size")); + // check input num and output num + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but SpaceToDepth needs 1 input."; + return false; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", SpaceToDepth needs 1 output."; + return false; + } + // check input_shape + auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0); + shape_size_ = input_shape.size(); + if (shape_size_ != SPACETODEPTH_BUFFER_DIMENSION) { + MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but SpaceToDepth supports 4-D tensor."; + } + // get input and out put information + input_size_ = 1; + for (size_t i = 0; i < shape_size_; i++) { + input_size_ *= input_shape[i]; + input_shape_.push_back(input_shape[i]); + } + input_size_ *= sizeof(T); + output_size_ = input_size_; + output_shape_.push_back(input_shape[0]); + output_shape_.push_back(input_shape[1] * block_size_ * block_size_); + output_shape_.push_back(input_shape[2] / block_size_); + output_shape_.push_back(input_shape[3] / block_size_); + // Private members Initialize + InitSizeLists(); + return true; + } + + void ResetResource() noexcept override { + shape_size_ = 0; + input_size_ = 0; + output_size_ = 0; + block_size_ = 0; + workspace_size1_ = 0; + workspace_size2_ = 0; + + input_shape_.clear(); + output_shape_.clear(); + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + workspace_size1_ = shape_size_ * sizeof(size_t); + workspace_size2_ = shape_size_ * sizeof(size_t); + workspace_size_list_.push_back(workspace_size1_); + workspace_size_list_.push_back(workspace_size2_); + return; + } + + private: + std::vector input_shape_; + std::vector output_shape_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + size_t shape_size_; + size_t input_size_; + size_t output_size_; + size_t block_size_; + size_t workspace_size1_; + size_t workspace_size2_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPACETODEPTH_KERNEL__H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/depthtospace_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/depthtospace_impl.cu new file mode 100644 index 00000000000..0d6b2fb5a1f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/depthtospace_impl.cu @@ -0,0 +1,91 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "depthtospace_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void DepthToSpace(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape, + const size_t r, T *output) { + size_t temp_stride = 0; + size_t temp_pos = 0; + size_t input_pos = 0; + size_t input_pos_array[DEPTHTOSPACE_BUFFER_DIMENSION]; + size_t output_pos_array[DEPTHTOSPACE_BUFFER_DIMENSION]; + + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + temp_stride = output_shape[1] * output_shape[2] * output_shape[3]; + output_pos_array[0] = pos / temp_stride; + temp_pos = pos % temp_stride; + + for (size_t i = 1; i < DEPTHTOSPACE_BUFFER_DIMENSION; i++) { + temp_stride /= output_shape[i]; + output_pos_array[i] = temp_pos / temp_stride; + temp_pos %= temp_stride; + } + + input_pos_array[0] = output_pos_array[0]; + input_pos_array[1] = output_pos_array[1] * r * r + r * (output_pos_array[2] % r) + output_pos_array[3] % r; + input_pos_array[2] = output_pos_array[2] / r; + input_pos_array[3] = output_pos_array[3] / r; + + for (size_t i = 0; i < 3; ++i) { + input_pos += input_pos_array[i]; + input_pos *= input_shape[i + 1]; + } + input_pos += input_pos_array[3]; + + output[pos] = input[input_pos]; + } + return; +} + +template +void CalDepthToSpace(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape, + const size_t r, T *output, cudaStream_t cuda_stream) { + DepthToSpace<<>>(size, input, input_shape, output_shape, r, output); + return; +} + +template void CalDepthToSpace(const size_t size, const float *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, float *output, + cudaStream_t cuda_stream); +template void CalDepthToSpace(const size_t size, const half *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, half *output, cudaStream_t cuda_stream); +template void CalDepthToSpace(const size_t size, const int *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, int *output, cudaStream_t cuda_stream); +template void CalDepthToSpace(const size_t size, const int64_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, int64_t *output, + cudaStream_t cuda_stream); +template void CalDepthToSpace(const size_t size, const int16_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, int16_t *output, + cudaStream_t cuda_stream); +template void CalDepthToSpace(const size_t size, const int8_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, int8_t *output, + cudaStream_t cuda_stream); +template void CalDepthToSpace(const size_t size, const uint8_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, uint8_t *output, + cudaStream_t cuda_stream); +template void CalDepthToSpace(const size_t size, const uint16_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, uint16_t *output, + cudaStream_t cuda_stream); +template void CalDepthToSpace(const size_t size, const uint32_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, uint32_t *output, + cudaStream_t cuda_stream); +template void CalDepthToSpace(const size_t size, const uint64_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, uint64_t *output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/depthtospace_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/depthtospace_impl.cuh new file mode 100644 index 00000000000..4c727f973ef --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/depthtospace_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DEPTHTOSPACE_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DEPTHTOSPACE_H_ + +#define DEPTHTOSPACE_BUFFER_DIMENSION 4 +template +void CalDepthToSpace(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape, + const size_t r, T *output, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DEPTHTOSPACE_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/spacetodepth_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/spacetodepth_impl.cu new file mode 100644 index 00000000000..168022cc8e4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/spacetodepth_impl.cu @@ -0,0 +1,91 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "spacetodepth_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void SpaceToDepth(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape, + const size_t r, T *output) { + size_t temp_stride = 0; + size_t temp_pos = 0; + size_t output_pos = 0; + size_t input_pos_array[SPACETODEPTH_BUFFER_DIMENSION]; + size_t output_pos_array[SPACETODEPTH_BUFFER_DIMENSION]; + + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + temp_stride = input_shape[1] * input_shape[2] * input_shape[3]; + input_pos_array[0] = pos / temp_stride; + temp_pos = pos % temp_stride; + + for (size_t i = 1; i < SPACETODEPTH_BUFFER_DIMENSION; i++) { + temp_stride /= input_shape[i]; + input_pos_array[i] = temp_pos / temp_stride; + temp_pos %= temp_stride; + } + + output_pos_array[0] = input_pos_array[0]; + output_pos_array[1] = input_pos_array[1] * r * r + r * (input_pos_array[2] % r) + input_pos_array[3] % r; + output_pos_array[2] = input_pos_array[2] / r; + output_pos_array[3] = input_pos_array[3] / r; + + for (size_t i = 0; i < 3; ++i) { + output_pos += output_pos_array[i]; + output_pos *= output_shape[i + 1]; + } + output_pos += output_pos_array[3]; + + output[output_pos] = input[pos]; + } + return; +} + +template +void CalSpaceToDepth(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape, + const size_t r, T *output, cudaStream_t cuda_stream) { + SpaceToDepth<<>>(size, input, input_shape, output_shape, r, output); + return; +} + +template void CalSpaceToDepth(const size_t size, const float *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, float *output, + cudaStream_t cuda_stream); +template void CalSpaceToDepth(const size_t size, const half *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, half *output, cudaStream_t cuda_stream); +template void CalSpaceToDepth(const size_t size, const int *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, int *output, cudaStream_t cuda_stream); +template void CalSpaceToDepth(const size_t size, const int64_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, int64_t *output, + cudaStream_t cuda_stream); +template void CalSpaceToDepth(const size_t size, const int16_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, int16_t *output, + cudaStream_t cuda_stream); +template void CalSpaceToDepth(const size_t size, const int8_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, int8_t *output, + cudaStream_t cuda_stream); +template void CalSpaceToDepth(const size_t size, const uint8_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, uint8_t *output, + cudaStream_t cuda_stream); +template void CalSpaceToDepth(const size_t size, const uint16_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, uint16_t *output, + cudaStream_t cuda_stream); +template void CalSpaceToDepth(const size_t size, const uint32_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, uint32_t *output, + cudaStream_t cuda_stream); +template void CalSpaceToDepth(const size_t size, const uint64_t *input, const size_t *input_shape, + const size_t *output_shape, const size_t r, uint64_t *output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/spacetodepth_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/spacetodepth_impl.cuh new file mode 100644 index 00000000000..6d888654257 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/spacetodepth_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPACETODEPTH_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPACETODEPTH_H_ + +#define SPACETODEPTH_BUFFER_DIMENSION 4 +template +void CalSpaceToDepth(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape, + const size_t r, T *output, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPACETODEPTH_H_ diff --git a/tests/st/ops/gpu/test_depthtospace_op.py b/tests/st/ops/gpu/test_depthtospace_op.py new file mode 100644 index 00000000000..f0bdcb5d165 --- /dev/null +++ b/tests/st/ops/gpu/test_depthtospace_op.py @@ -0,0 +1,200 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.array_ops as P +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter + +def DepthToSpaceNumpy(arr, block_size): + ''' + DepthToSpace ops use numpy + ''' + tmpshape = arr.shape + newshape = [] + newshape.append(tmpshape[0]) + newshape.append(tmpshape[1]//block_size//block_size) + newshape.append(tmpshape[2]*block_size) + newshape.append(tmpshape[3]*block_size) + output = arr.reshape(newshape[0], newshape[1], block_size, block_size, tmpshape[2], tmpshape[3]) + output = np.transpose(output, (0, 1, 4, 2, 5, 3)) + output = output.reshape(newshape) + return output + +class DepthToSpaceNet(nn.Cell): + def __init__(self, nptype, block_size=2, input_shape=(1, 4, 3, 3)): + super(DepthToSpaceNet, self).__init__() + self.DepthToSpace = P.DepthToSpace(2) + input_size = 1 + for i in input_shape: + input_size = input_size*i + self.data_np = np.arange(input_size).reshape(input_shape).astype(nptype) + self.x = Parameter(initializer(Tensor(self.data_np), input_shape), name='x') + + @ms_function + def construct(self): + return self.DepthToSpace(self.x) + + +def DepthToSpace(nptype, block_size=2, input_shape=(1, 4, 3, 3)): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + input_size = 1 + for i in input_shape: + input_size = input_size*i + expect = np.arange(input_size).reshape(input_shape).astype(nptype) + expect = DepthToSpaceNumpy(expect, block_size) + + dts = DepthToSpaceNet(nptype, block_size, input_shape) + output = dts() + assert (output.asnumpy() == expect).all() + +def DepthToSpace_pynative(nptype, block_size=2, input_shape=(1, 4, 3, 3)): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + input_size = 1 + for i in input_shape: + input_size = input_size*i + expect = np.arange(input_size).reshape(input_shape).astype(nptype) + expect = DepthToSpaceNumpy(expect, block_size) + + dts = P.DepthToSpace(2) + arr_input = Tensor(np.arange(input_size).reshape(input_shape).astype(nptype)) + output = dts(arr_input) + + assert (output.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_float32(): + DepthToSpace(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_float16(): + DepthToSpace(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_int32(): + DepthToSpace(np.int32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_int64(): + DepthToSpace(np.int64) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_int8(): + DepthToSpace(np.int8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_int16(): + DepthToSpace(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_uint8(): + DepthToSpace(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_uint16(): + DepthToSpace(np.uint16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_uint32(): + DepthToSpace(np.uint32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_graph_uint64(): + DepthToSpace(np.uint64) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_pynative_float32(): + DepthToSpace_pynative(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_pynative_float16(): + DepthToSpace_pynative(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_pynative_int32(): + DepthToSpace_pynative(np.int32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_pynative_int64(): + DepthToSpace_pynative(np.int64) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_pynative_int8(): + DepthToSpace_pynative(np.int8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_pynative_int16(): + DepthToSpace_pynative(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_pynative_uint8(): + DepthToSpace_pynative(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_pynative_uint16(): + DepthToSpace_pynative(np.uint16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_pynative_uint32(): + DepthToSpace_pynative(np.uint32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_depthtospace_pynative_uint64(): + DepthToSpace_pynative(np.uint64) diff --git a/tests/st/ops/gpu/test_spacetodepth_op.py b/tests/st/ops/gpu/test_spacetodepth_op.py new file mode 100644 index 00000000000..90001689e1a --- /dev/null +++ b/tests/st/ops/gpu/test_spacetodepth_op.py @@ -0,0 +1,210 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.array_ops as P +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter + +def DepthToSpaceNumpy(arr, block_size): + ''' + DepthToSpace ops use numpy + DepthToSpace ops is reverse ops to SpaceToDepth ops + therefore DepthToSpace's output can be SpaceToDepth's input + ''' + tmpshape = arr.shape + newshape = [] + newshape.append(tmpshape[0]) + newshape.append(tmpshape[1]//block_size//block_size) + newshape.append(tmpshape[2]*block_size) + newshape.append(tmpshape[3]*block_size) + output = arr.reshape(newshape[0], newshape[1], block_size, block_size, tmpshape[2], tmpshape[3]) + output = np.transpose(output, (0, 1, 4, 2, 5, 3)) + output = output.reshape(newshape) + return output + +class SpaceToDepthNet(nn.Cell): + def __init__(self, nptype, block_size=2, input_shape=(1, 4, 3, 3)): + super(SpaceToDepthNet, self).__init__() + self.SpaceToDepth = P.SpaceToDepth(block_size) + input_size = 1 + for i in input_shape: + input_size = input_size*i + + data_np = np.arange(input_size).reshape(input_shape).astype(nptype)# data_np shape is (N,C,H,W) + data_np = DepthToSpaceNumpy(data_np, block_size)#now data_np shape is (N,C/(block_size*block_size),H*block_size,W*block_size) + self.data_np = data_np + new_shape = [] + new_shape.append(input_shape[0]) + new_shape.append(input_shape[1]//(block_size*block_size)) + new_shape.append(input_shape[2]*block_size) + new_shape.append(input_shape[3]*block_size) + self.x = Parameter(initializer(Tensor(self.data_np), new_shape), name='x') + + @ms_function + def construct(self): + return self.SpaceToDepth(self.x) + + +def SpaceToDepth(nptype, block_size=2, input_shape=(1, 4, 3, 3)): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + input_size = 1 + for i in input_shape: + input_size = input_size*i + + expect = np.arange(input_size).reshape(input_shape).astype(nptype) + + std = SpaceToDepthNet(nptype, block_size, input_shape) + output = std() + assert (output.asnumpy() == expect).all() + +def SpaceToDepth_pynative(nptype, block_size=2, input_shape=(1, 4, 3, 3)): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + input_size = 1 + for i in input_shape: + input_size = input_size*i + expect = np.arange(input_size).reshape(input_shape).astype(nptype) + arrinput = DepthToSpaceNumpy(expect, block_size) + + std = P.SpaceToDepth(block_size) + arrinput = Tensor(arrinput) + output = std(arrinput) + + assert (output.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_float32(): + SpaceToDepth(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_float16(): + SpaceToDepth(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_int32(): + SpaceToDepth(np.int32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_int64(): + SpaceToDepth(np.int64) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_int8(): + SpaceToDepth(np.int8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_int16(): + SpaceToDepth(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_uint8(): + SpaceToDepth(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_uint16(): + SpaceToDepth(np.uint16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_uint32(): + SpaceToDepth(np.uint32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_graph_uint64(): + SpaceToDepth(np.uint64) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_pynative_float32(): + SpaceToDepth_pynative(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_pynative_float16(): + SpaceToDepth_pynative(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_pynative_int32(): + SpaceToDepth_pynative(np.int32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_pynative_int64(): + SpaceToDepth_pynative(np.int64) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_pynative_int8(): + SpaceToDepth_pynative(np.int8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_pynative_int16(): + SpaceToDepth_pynative(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_pynative_uint8(): + SpaceToDepth_pynative(np.uint8) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_pynative_uint16(): + SpaceToDepth_pynative(np.uint16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_pynative_uint32(): + SpaceToDepth_pynative(np.uint32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_spacetodepth_pynative_uint64(): + SpaceToDepth_pynative(np.uint64)