From 99a995b4326a388dbfcd08eae31b9ec6298651c3 Mon Sep 17 00:00:00 2001 From: likesen Date: Wed, 23 Jun 2021 16:22:03 +0800 Subject: [PATCH] Implement UNet3d on GPU --- .../gpu/arrays/slice_gpu_kernel.h | 39 +++- .../gpu/cuda_impl/relu_grad_impl.cu | 26 ++- .../gpu/cuda_impl/relu_grad_impl.cuh | 6 +- .../gpu/cuda_impl/slice_impl.cu | 60 +++++- .../gpu/cuda_impl/slice_impl.cuh | 4 + .../gpu/nn/prelu_grad_kernel.cc | 38 ++++ .../gpu/nn/prelu_grad_kernel.h | 196 ++++++++++++++++++ model_zoo/official/cv/unet3d/README.md | 163 ++++++++++++--- .../official/cv/unet3d/default_config.yaml | 2 + model_zoo/official/cv/unet3d/eval.py | 9 +- .../scripts/run_distribute_train_gpu_fp16.sh | 59 ++++++ .../scripts/run_distribute_train_gpu_fp32.sh | 59 ++++++ .../scripts/run_standalone_eval_gpu_fp16.sh | 76 +++++++ .../scripts/run_standalone_eval_gpu_fp32.sh | 76 +++++++ .../scripts/run_standalone_train_gpu_fp16.sh | 60 ++++++ .../scripts/run_standalone_train_gpu_fp32.sh | 60 ++++++ .../cv/unet3d/src/model_utils/config.py | 2 +- .../official/cv/unet3d/src/unet3d_model.py | 42 ++++ model_zoo/official/cv/unet3d/train.py | 33 ++- tests/st/ops/gpu/test_prelu_grad_op.py | 61 ++++++ tests/st/ops/gpu/test_slice.py | 24 +++ 21 files changed, 1034 insertions(+), 61 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.h create mode 100644 model_zoo/official/cv/unet3d/scripts/run_distribute_train_gpu_fp16.sh create mode 100644 model_zoo/official/cv/unet3d/scripts/run_distribute_train_gpu_fp32.sh create mode 100644 model_zoo/official/cv/unet3d/scripts/run_standalone_eval_gpu_fp16.sh create mode 100644 model_zoo/official/cv/unet3d/scripts/run_standalone_eval_gpu_fp32.sh create mode 100644 model_zoo/official/cv/unet3d/scripts/run_standalone_train_gpu_fp16.sh create mode 100644 model_zoo/official/cv/unet3d/scripts/run_standalone_train_gpu_fp32.sh create mode 100644 tests/st/ops/gpu/test_prelu_grad_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h index 63f39986db2..c4cd841376f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h @@ -42,9 +42,9 @@ class SliceGpuFwdKernel : public GpuKernel { } T *input = GetDeviceAddress(inputs, 0); T *output = GetDeviceAddress(outputs, 0); - Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0], - input_shape_[1], input_shape_[2], input_shape_[3], input, output, - reinterpret_cast(stream_ptr)); + Slice5DKernel(begin_[0], begin_[1], begin_[2], begin_[3], begin_[4], size_[0], size_[1], size_[2], size_[3], + size_[4], input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], input_shape_[4], input, + output, reinterpret_cast(stream_ptr)); return true; } bool Init(const CNodePtr &kernel_node) override { @@ -53,28 +53,35 @@ class SliceGpuFwdKernel : public GpuKernel { } auto data_format = AnfAlgo::GetInputFormat(kernel_node, 0); auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - ShapeNdTo4d(input_shape, &input_shape_); + ShapeNdTo5d(input_shape, &input_shape_); - for (auto i = begin_.size(); i < 4; i++) { + for (auto i = begin_.size(); i < 5; i++) { (void)begin_.insert(begin_.begin(), 0); } - for (size_t i = size_.size(); i < 4; i++) { + for (size_t i = size_.size(); i < 5; i++) { (void)size_.insert(size_.begin(), 1); } - input_size_ = input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T); + input_size_ = input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * input_shape_[4] * sizeof(T); auto out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); output_size_ = sizeof(T); for (size_t x : out_shape) { output_size_ = output_size_ * x; } - // transpose begin and size for NHWC data + // transpose begin and size for NHWC and NDHWC data if (data_format == "NHWC") { std::swap(begin_[1], begin_[3]); std::swap(begin_[1], begin_[2]); std::swap(size_[1], size_[3]); std::swap(size_[1], size_[2]); + } else if (data_format == "NDHWC") { + std::swap(begin_[1], begin_[4]); + std::swap(begin_[1], begin_[3]); + std::swap(begin_[1], begin_[2]); + std::swap(size_[1], size_[4]); + std::swap(size_[1], size_[3]); + std::swap(size_[1], size_[2]); } InitSizeLists(); return true; @@ -87,6 +94,18 @@ class SliceGpuFwdKernel : public GpuKernel { } private: + // expand Nd Shape to 5d (N in [0,5]) + void ShapeNdTo5d(const std::vector &src, std::vector *dst) { + if (src.size() > 5) { + MS_EXCEPTION(ValueError) << src.size() << "-D data is not supported!"; + } + dst->push_back(src.size() < 5 ? 1 : src[src.size() - 5]); + dst->push_back(src.size() < 4 ? 1 : src[src.size() - 4]); + dst->push_back(src.size() < 3 ? 1 : src[src.size() - 3]); + dst->push_back(src.size() < 2 ? 1 : src[src.size() - 2]); + dst->push_back(src.size() == 0 ? 1 : src[src.size() - 1]); + } + bool CheckParam(const CNodePtr &kernel_node) { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 1) { @@ -99,8 +118,8 @@ class SliceGpuFwdKernel : public GpuKernel { return false; } auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", but SliceGpuFwdKernel olny support 4d or lower."; + if (input_shape.size() > 5) { + MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", but SliceGpuFwdKernel olny support 5d or lower."; return false; } if (input_shape.size() == 0) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu index 15d012a10d0..4c9af52e5da 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,26 @@ void CalReLUGrad(int size, T *dy, T *y, T *dx, cudaStream_t cuda_stream) { return; } +template +__global__ void PReluChannelSharedGradKernel(size_t size, T *dy_addr, T *x_addr, T *w_addr, T *dx_addr, T *dwc_addr) { + T zero = static_cast(0); + T w = w_addr[0]; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + T dy = dy_addr[pos]; + T x = x_addr[pos]; + dx_addr[pos] = x > zero ? dy : w * dy; + dwc_addr[pos] = x > zero ? zero : x * dy; + } +} + +template +void PReluChannelSharedGrad(size_t input_size, T *dy_addr, T *x_addr, T *w_addr, T *dx_addr, T *dwc_addr, + cudaStream_t cuda_stream) { + PReluChannelSharedGradKernel<<>>(input_size, dy_addr, x_addr, + w_addr, dx_addr, dwc_addr); + return; +} + template void CalReLUGrad(int size, double *dy, double *y, double *dx, cudaStream_t cuda_stream); template void CalReLUGrad(int size, float *dy, float *y, float *dx, cudaStream_t cuda_stream); template void CalReLUGrad(int size, half *dy, half *y, half *dx, cudaStream_t cuda_stream); @@ -38,3 +58,7 @@ template void CalReLUGrad(int size, int16_t *dy, int16_t *y, int16_t *dx, cudaSt template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream); template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream); template void CalReLUGrad(int size, uint8_t *dy, uint8_t *y, uint8_t *dx, cudaStream_t cuda_stream); +template void PReluChannelSharedGrad(size_t input_size, float *dy_addr, float *x_addr, float *w_addr, float *dx_addr, + float *dwc_addr, cudaStream_t cuda_stream); +template void PReluChannelSharedGrad(size_t input_size, half *dy_addr, half *x_addr, half *w_addr, half *dx_addr, + half *dwc_addr, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh index 1d1fbbde7c3..c55ce4d823b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,4 +20,8 @@ #include "runtime/device/gpu/cuda_common.h" template void CalReLUGrad(int input_size, T *dy, T *y, T *dx, cudaStream_t cuda_stream); + +template +void PReluChannelSharedGrad(size_t input_size, T *dy_addr, T *x_addr, T *w_addr, T *dx_addr, T *dwc_addr, + cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu index a31487f4025..2f362d6f3d3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu @@ -35,6 +35,25 @@ __global__ void Slice4D(const size_t s1, const size_t s2, const size_t s3, const } } +template +__global__ void Slice5D(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5, + const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5, + const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const T *input, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4 * l5); + pos += blockDim.x * gridDim.x) { + size_t i = pos / (l2 * l3 * l4 * l5) % l1; + size_t j = pos / (l3 * l4 * l5) % l2; + size_t k = pos / (l4 * l5) % l3; + size_t o = pos / l5 % l4; + size_t q = pos % l5; + + size_t offset = + (i + s1) * (d2 * d3 * d4 * d5) + (j + s2) * (d3 * d4 * d5) + (k + s3) * (d4 * d5) + (o + s4) * d5 + (q + s5); + output[pos] = input[offset]; + } +} + template __global__ void Slice4DGrad(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const size_t l2, const size_t l3, const size_t l4, @@ -70,7 +89,13 @@ void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size Slice4D<<>>(s1, s2, s3, s4, l1, l2, l3, l4, d1, d2, d3, d4, input, output); } - +template +void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5, const size_t l1, + const size_t l2, const size_t l3, const size_t l4, const size_t l5, const size_t d1, const size_t d2, + const size_t d3, const size_t d4, const size_t d5, const T *input, T *output, cudaStream_t stream) { + Slice5D<<>>(s1, s2, s3, s4, s5, l1, l2, l3, l4, l5, d1, + d2, d3, d4, d5, input, output); +} template void CalSlice4DGrad(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, @@ -184,6 +209,39 @@ template void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, c const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4, const bool *input, bool *output, cudaStream_t stream); +template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5, + const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5, + const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const double *input, double *output, cudaStream_t stream); +template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5, + const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5, + const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const float *input, float *output, cudaStream_t stream); +template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5, + const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5, + const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const half *input, half *output, cudaStream_t stream); +template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5, + const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5, + const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const int64_t *input, int64_t *output, cudaStream_t stream); +template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5, + const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5, + const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const int *input, int *output, cudaStream_t stream); +template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5, + const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5, + const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const short *input, short *output, cudaStream_t stream); // NOLINT +template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5, + const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5, + const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const unsigned char *input, unsigned char *output, cudaStream_t stream); +template void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5, + const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5, + const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const bool *input, bool *output, cudaStream_t stream); + template void CalSlice4DGrad(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4, diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh index 27cb53cf04e..da2d8bc8548 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh @@ -26,6 +26,10 @@ void Slice4DKernel(const size_t s1, const size_t s2, const size_t s3, const size const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4, const T *input, T *output, cudaStream_t stream); template +void Slice5DKernel(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t s5, const size_t l1, + const size_t l2, const size_t l3, const size_t l4, const size_t l5, const size_t d1, const size_t d2, + const size_t d3, const size_t d4, const size_t d5, const T *input, T *output, cudaStream_t stream); +template void CalSlice4DGrad(const size_t s1, const size_t s2, const size_t s3, const size_t s4, const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t d1, const size_t d2, const size_t d3, const size_t d4, const T *dy, T *dx, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.cc new file mode 100644 index 00000000000..b7b1bb0cf1a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/prelu_grad_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(PReLUGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + PReLUGpuGradKernel, float) +MS_REG_GPU_KERNEL_ONE(PReLUGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + PReLUGpuGradKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.h new file mode 100644 index 00000000000..21f36285d76 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.h @@ -0,0 +1,196 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class PReLUGpuGradKernel : public GpuKernel { + public: + PReLUGpuGradKernel() + : data_format_(kOpFormat_NCDHW), + input_size_(0), + weight_size_(0), + reduce_workspace_size_(0), + spatial_count_(1), + is_null_input_(false), + channel_shared_(false), + channel_last_(false) {} + ~PReLUGpuGradKernel() override { DestroyResource(); } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *dy_addr = GetDeviceAddress(inputs, 0); + T *x_addr = GetDeviceAddress(inputs, 1); + T *w_addr = GetDeviceAddress(inputs, 2); + T *dx_addr = GetDeviceAddress(outputs, 0); + T *dw_addr = GetDeviceAddress(outputs, 1); + T *dw_collector_addr = GetDeviceAddress(workspace, 0); + T *reduce_workspace_addr = GetDeviceAddress(workspace, 1); + + PReluChannelSharedGrad(input_size_ / sizeof(T), dy_addr, x_addr, w_addr, dx_addr, dw_collector_addr, + reinterpret_cast(stream_ptr)); + + if (data_type_ == CUDNN_DATA_DOUBLE) { + T alpha = static_cast(1.0f); + T beta = static_cast(0.0f); + CHECK_CUDNN_RET_WITH_EXCEPT( + kernel_node_, + cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr, + reduce_workspace_size_, &alpha, grad_weight_collector_descriptor_, dw_collector_addr, &beta, + grad_weight_descriptor_, dw_addr), + "cudnnReduceTensor failed."); + } else { + const float alphaf = static_cast(1.0f); + const float betaf = static_cast(0.0f); + CHECK_CUDNN_RET_WITH_EXCEPT( + kernel_node_, + cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr, + reduce_workspace_size_, &alphaf, grad_weight_collector_descriptor_, dw_collector_addr, &betaf, + grad_weight_descriptor_, dw_addr), + "cudnnReduceTensor failed."); + } + return true; + } + + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateReduceTensorDescriptor(&reduce_tensor_descriptor_), + "cudnnCreateReduceTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&grad_weight_collector_descriptor_), + "cudnnCreateTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&grad_weight_descriptor_), + "cudnnCreateTensorDescriptor failed."); + } + + void DestroyResource() noexcept override { + CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyReduceTensorDescriptor(reduce_tensor_descriptor_), + "cudnnDestroyReduceTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(grad_weight_collector_descriptor_), + "cudnnDestroyTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(grad_weight_descriptor_), + "cudnnDestroyTensorDescriptor failed."); + } + + bool Init(const CNodePtr &kernel_node) override { + kernel_node_ = kernel_node; + input_size_ = sizeof(T); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "PReLUGpuBwdKernel input is null."; + } + for (size_t i = 0; i < input_shape.size(); ++i) { + input_size_ *= input_shape[i]; + } + weight_size_ = sizeof(T); + auto weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); + is_null_input_ = CHECK_NULL_INPUT(weight_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "PReLUGpuBwdKernel input is null."; + } + for (auto dim : weight_shape) { + weight_size_ *= dim; + } + channel_shared_ = (weight_shape[0] == 1); + if (!channel_shared_) { + MS_LOG(WARNING) + << "PReLUGpuBwdKernel shares weight for all channels, but the given weight tensor has more than one element."; + } + + spatial_count_ = 1; + if (channel_last_) { + for (size_t i = 1; i < input_shape.size() - 1; ++i) { + spatial_count_ *= input_shape[i]; + } + } else { + for (size_t i = 2; i < input_shape.size(); ++i) { + spatial_count_ *= input_shape[i]; + } + } + + data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + int input_dim_length = input_shape.size(); + std::vector reduce_out_shape(input_dim_length, 1); + if (channel_last_) { + reduce_out_shape[input_dim_length - 1] = weight_shape[0]; + } else { + reduce_out_shape[1] = weight_shape[0]; + } + InitResource(); + CudnnSetTensorNdDescriptor(reduce_out_shape, grad_weight_descriptor_, data_type_, kernel_node_); + CudnnSetTensorNdDescriptor(input_shape, grad_weight_collector_descriptor_, data_type_, kernel_node_); + cudnnDataType_t comp_type = (data_type_ == CUDNN_DATA_DOUBLE) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT; + CHECK_CUDNN_RET_WITH_EXCEPT( + kernel_node_, + cudnnSetReduceTensorDescriptor(reduce_tensor_descriptor_, CUDNN_REDUCE_TENSOR_ADD, comp_type, + CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES), + "cudnnSetReduceTensorDescriptor failed"); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + input_size_list_.push_back(input_size_); + input_size_list_.push_back(weight_size_); + output_size_list_.push_back(input_size_); + output_size_list_.push_back(weight_size_); + CHECK_CUDNN_RET_WITH_EXCEPT( + kernel_node_, + cudnnGetReductionWorkspaceSize(cudnn_handle_, reduce_tensor_descriptor_, grad_weight_collector_descriptor_, + grad_weight_descriptor_, &reduce_workspace_size_), + "cudnnGetReductionWorkspaceSize failed."); + workspace_size_list_.push_back(input_size_); + workspace_size_list_.push_back(reduce_workspace_size_); + } + + private: + cudnnHandle_t cudnn_handle_; + cudnnDataType_t data_type_; + cudnnReduceTensorDescriptor_t reduce_tensor_descriptor_; + cudnnTensorDescriptor_t grad_weight_collector_descriptor_; + cudnnTensorDescriptor_t grad_weight_descriptor_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + std::string data_format_ = kOpFormat_NCDHW; + size_t input_size_; + size_t weight_size_; + size_t reduce_workspace_size_; + size_t spatial_count_; + bool is_null_input_ = false; + bool channel_shared_ = false; + bool channel_last_ = false; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_KERNEL_H_ diff --git a/model_zoo/official/cv/unet3d/README.md b/model_zoo/official/cv/unet3d/README.md index 6c0abe091a0..8470791c8ca 100644 --- a/model_zoo/official/cv/unet3d/README.md +++ b/model_zoo/official/cv/unet3d/README.md @@ -11,10 +11,15 @@ - [Script Parameters](#script-parameters) - [Training Process](#training-process) - [Training](#training) - - [running on Ascend](#running-on-ascend) - - [Distributed Training](#distributed-training) + - [Training on Ascend](#training-on-ascend) + - [Training on GPU](#training-on-gpu) + - [Distributed Training](#distributed-training) + - [Distributed training on Ascend](#distributed-training-on-ascend) + - [Distributed training on GPU](#distributed-training-on-gpu) - [Evaluation Process](#evaluation-process) - [Evaluation](#evaluation) + - [Evaluating on Ascend](#training-on-ascend) + - [Evaluating on GPU](#training-on-gpu) - [Model Description](#model-description) - [Performance](#performance) - [Evaluation Performance](#evaluation-performance) @@ -36,16 +41,29 @@ Dataset used: [LUNA16](https://luna16.grand-challenge.org/) - Description: The data is to automatically detect the location of nodules from volumetric CT images. 888 CT scans from LIDC-IDRI database are provided. The complete dataset is divided into 10 subsets that should be used for the 10-fold cross-validation. All subsets are available as compressed zip files. -- Dataset size:888 - - Train:878 images - - Test:10 images +- Dataset size:887 + - Train:877 images + - Test:10 images(last 10 images in subset9 with lexicographical order) - Data format:zip - - Note:Data will be processed in convert_nifti.py + - Note:Data will be processed in convert_nifti.py, and one of them will be ignored during data processing. +- Data Content Structure + +```text + +. +└─LUNA16 + ├── train + │ ├── image // contains 877 image files + | ├── seg // contains 877 seg files + ├── val + │ ├── image // contains 10 image files + | ├── seg // contains 10 seg files +``` ## [Environment Requirements](#contents) -- Hardware(Ascend) - - Prepare hardware environment with Ascend processor. +- Hardware(Ascend or GPU) + - Prepare hardware environment with Ascend or GPU. - Framework - [MindSpore](https://www.mindspore.cn/install/en) - For more information, please check the resources below: @@ -79,6 +97,25 @@ bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] # run evaluation example python eval.py --data_path=/path/to/data/ --checkpoint_file_path=/path/to/checkpoint/ > eval.log 2>&1 & +``` + +- Run on GPU + +```shell +# enter scripts directory +cd scripts +# run training example(fp32) +bash ./run_standalone_train_gpu_fp32.sh [TRAINING_DATA_PATH] +# run training example(fp16) +bash ./run_standalone_train_gpu_fp16.sh [TRAINING_DATA_PATH] +# run distributed training example(fp32) +bash ./run_distribute_train_gpu_fp32.sh [TRAINING_DATA_PATH] +# run distributed training example(fp16) +bash ./run_distribute_train_gpu_fp16.sh [TRAINING_DATA_PATH] +# run evaluation example(fp32) +bash ./run_standalone_eval_gpu_fp32.sh [VALIDATING_DATA_PATH] [CHECKPOINT_FILE_PATH] +# run evaluation example(fp16) +bash ./run_standalone_eval_gpu_fp16.sh [VALIDATING_DATA_PATH] [CHECKPOINT_FILE_PATH] ``` @@ -123,9 +160,15 @@ If you want to run in modelarts, please check the official documentation of [mod └─unet3d ├── README.md // descriptions about Unet3D ├── scripts - │ ├──run_disribute_train.sh // shell script for distributed on Ascend + │ ├──run_distribute_train.sh // shell script for distributed on Ascend │ ├──run_standalone_train.sh // shell script for standalone on Ascend │ ├──run_standalone_eval.sh // shell script for evaluation on Ascend + │ ├──run_distribute_train_gpu_fp32.sh // shell script for distributed on GPU fp32 + │ ├──run_distribute_train_gpu_fp16.sh // shell script for distributed on GPU fp16 + │ ├──run_standalone_train_gpu_fp32.sh // shell script for standalone on GPU fp32 + │ ├──run_standalone_train_gpu_fp16.sh // shell script for standalone on GPU fp16 + │ ├──run_standalone_eval_gpu_fp32.sh // shell script for evaluation on GPU fp32 + │ ├──run_standalone_eval_gpu_fp16.sh // shell script for evaluation on GPU fp16 ├── src │ ├──dataset.py // creating dataset │ ├──lr_schedule.py // learning rate scheduler @@ -177,7 +220,23 @@ Parameters for both training and evaluation can be set in config.py ### Training -#### running on Ascend +#### Training on GPU + +```shell +# enter scripts directory +cd scripts +# fp32 +bash ./run_standalone_train_gpu_fp32.sh /path_prefix/LUNA16/train +# fp16 +bash ./run_standalone_train_gpu_fp16.sh /path_prefix/LUNA16/train + +``` + +The python command above will run in the background, you can view the results through the file `train.log`. + +After training, you'll get some checkpoint files under the train_fp[32|16]/output/ckpt_0/ folder by default. + +#### Training on Ascend ```shell python train.py --data_path=/path/to/data/ > train.log 2>&1 & @@ -201,7 +260,25 @@ epoch time: 1180467.795 ms, per step time: 1380.664 ms ``` -#### Distributed Training +### Distributed Training + +#### Distributed training on GPU(8P) + +```shell +# enter scripts directory +cd scripts +# fp32 +bash ./run_distribute_train_gpu_fp32.sh /path_prefix/LUNA16/train +# fp16 +bash ./run_distribute_train_gpu_fp16.sh /path_prefix/LUNA16/train + +``` + +The above shell script will run distribute training in the background. You can view the results through the file `/train_parallel_fp[32|16]/train.log`. + +After training, you'll get some checkpoint files under the `train_parallel_fp[32|16]/output/ckpt_[X]/` folder by default. + +#### Distributed training on Ascend > Notes: > RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html) , and the device_ip can be got as [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could be timeout since compiling time increases with the growth of model size. @@ -235,6 +312,24 @@ epoch time: 140476.520 ms, per step time: 1312.865 ms ### Evaluation +#### Evaluating on GPU + +```shell +# enter scripts directory +cd ./script +# fp32, 1gpu +bash ./run_standalone_eval_gpu_fp32.sh /path_prefix/LUNA16/val /path_prefix/train_fp32/output/ckpt_0/Unet3d-10_877.ckpt +# fp16, 1gpu +bash ./run_standalone_eval_gpu_fp16.sh /path_prefix/LUNA16/val /path_prefix/train_fp16/output/ckpt_0/Unet3d-10_877.ckpt +# fp32, 8gpu +bash ./run_standalone_eval_gpu_fp32.sh /path_prefix/LUNA16/val /path_prefix/train_parallel_fp32/output/ckpt_0/Unet3d-10_110.ckpt +# fp16, 8gpu +bash ./run_standalone_eval_gpu_fp16.sh /path_prefix/LUNA16/val /path_prefix/train_parallel_fp16/output/ckpt_0/Unet3d-10_110.ckpt + +``` + +#### Evaluating on Ascend + - evaluation on dataset when running on Ascend Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/unet3d/Unet3d-10_110.ckpt". @@ -259,33 +354,33 @@ eval average dice is 0.9502010010453671 #### Evaluation Performance -| Parameters | Ascend | -| ------------------- | --------------------------------------------------------- | -| Model Version | Unet3D | -| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | -| uploaded Date | 03/18/2021 (month/day/year) | -| MindSpore Version | 1.2.0 | -| Dataset | LUNA16 | -| Training Parameters | epoch = 10, batch_size = 1 | -| Optimizer | Adam | -| Loss Function | SoftmaxCrossEntropyWithLogits | -| Speed | 8pcs: 1795ms/step | -| Total time | 8pcs: 0.62hours | -| Parameters (M) | 34 | +| Parameters | Ascend | GPU | +| ------------------- | --------------------------------------------------------- | ---------------------------------------------------- | +| Model Version | Unet3D | Unet3D | +| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | Nvidia V100 SXM2; CPU 1.526GHz; 72cores; Memory 42G; OS Ubuntu16| +| uploaded Date | 03/18/2021 (month/day/year) | 05/21/2021(month/day/year) | +| MindSpore Version | 1.2.0 | 1.2.0 | +| Dataset | LUNA16 | LUNA16 | +| Training Parameters | epoch = 10, batch_size = 1 | epoch = 10, batch_size = 1 | +| Optimizer | Adam | Adam | +| Loss Function | SoftmaxCrossEntropyWithLogits | SoftmaxCrossEntropyWithLogits | +| Speed | 8pcs: 1795ms/step | 8pcs: 1883ms/step | +| Total time | 8pcs: 0.62hours | 8pcs: 0.66hours | +| Parameters (M) | 34 | 34 | | Scripts | [unet3d script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet3d) | #### Inference Performance -| Parameters | Ascend | -| ------------------- | --------------------------- | -| Model Version | Unet3D | -| Resource | Ascend 910; OS Euler2.8 | -| Uploaded Date | 03/18/2021 (month/day/year) | -| MindSpore Version | 1.2.0 | -| Dataset | LUNA16 | -| batch_size | 1 | -| Dice | dice = 0.9502 | -| Model for inference | 56M(.ckpt file) | +| Parameters | Ascend | GPU | +| ------------------- | --------------------------- | --------------------------- | +| Model Version | Unet3D | Unet3D | +| Resource | Ascend 910; OS Euler2.8 | Nvidia V100 SXM2; OS Ubuntu16| +| Uploaded Date | 03/18/2021 (month/day/year) | 05/21/2021 (month/day/year) | +| MindSpore Version | 1.2.0 | 1.2.0 | +| Dataset | LUNA16 | LUNA16 | +| batch_size | 1 | 1 | +| Dice | dice = 0.9502 | dice = 0.9601 | +| Model for inference | 56M(.ckpt file) | 56M(.ckpt file) | # [Description of Random Situation](#contents) diff --git a/model_zoo/official/cv/unet3d/default_config.yaml b/model_zoo/official/cv/unet3d/default_config.yaml index bcad19a6eca..45717e0f577 100644 --- a/model_zoo/official/cv/unet3d/default_config.yaml +++ b/model_zoo/official/cv/unet3d/default_config.yaml @@ -1,4 +1,5 @@ # Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_fp16_gpu: False enable_modelarts: False # Url for modelarts data_url: "" @@ -41,6 +42,7 @@ file_format: "" --- # Help description for each configuration enable_modelarts: 'Whether training on modelarts, default: False' +enable_fp16_gpu: 'Whether training on gpu with fp16, default: False' data_url: 'Dataset url for obs' train_url: 'Training output url for obs' checkpoint_url: 'The location of checkpoint for obs' diff --git a/model_zoo/official/cv/unet3d/eval.py b/model_zoo/official/cv/unet3d/eval.py index e49ad5c77fb..db41847d486 100644 --- a/model_zoo/official/cv/unet3d/eval.py +++ b/model_zoo/official/cv/unet3d/eval.py @@ -19,13 +19,13 @@ from mindspore import dtype as mstype from mindspore import Model, context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.dataset import create_dataset -from src.unet3d_model import UNet3d +from src.unet3d_model import UNet3d, UNet3d_ from src.utils import create_sliding_window, CalculateDice from src.model_utils.config import config from src.model_utils.moxing_adapter import moxing_wrapper device_id = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False, device_id=device_id) @moxing_wrapper() def test_net(data_path, ckpt_path): @@ -35,7 +35,10 @@ def test_net(data_path, ckpt_path): eval_data_size = eval_dataset.get_dataset_size() print("train dataset length is:", eval_data_size) - network = UNet3d() + if config.device_target == 'Ascend': + network = UNet3d() + else: + network = UNet3d_() network.set_train(False) param_dict = load_checkpoint(ckpt_path) load_param_into_net(network, param_dict) diff --git a/model_zoo/official/cv/unet3d/scripts/run_distribute_train_gpu_fp16.sh b/model_zoo/official/cv/unet3d/scripts/run_distribute_train_gpu_fp16.sh new file mode 100644 index 00000000000..7d5244c5fa3 --- /dev/null +++ b/model_zoo/official/cv/unet3d/scripts/run_distribute_train_gpu_fp16.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# -ne 1 ] +then + echo "Usage: sh run_distribute_train_gpu.sh [DATA_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +echo $PATH1 +if [ ! -d $PATH1 ] +then + echo "error: IMAGE_PATH=$PATH1 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 + + +if [ -d "train_parallel_fp16" ]; +then + rm -rf ./train_parallel_fp16 +fi + +rm -rf ./train_parallel_fp16 +mkdir ./train_parallel_fp16 +cp ../*.py ./train_parallel_fp16 +cp *.sh ./train_parallel_fp16 +cp ../*.yaml ./train_parallel_fp16 +cp -r ../src ./train_parallel_fp16 +cd ./train_parallel_fp16 || exit +echo "start distributed training with $DEVICE_NUM GPUs." +env > env.log +mpirun --allow-run-as-root -n $DEVICE_NUM python train.py --run_distribute=True --data_path=$PATH1 --output_path './output' --device_target='GPU' --enable_fp16_gpu=True --checkpoint_path='./' > train.log 2>&1 & +cd .. diff --git a/model_zoo/official/cv/unet3d/scripts/run_distribute_train_gpu_fp32.sh b/model_zoo/official/cv/unet3d/scripts/run_distribute_train_gpu_fp32.sh new file mode 100644 index 00000000000..b9902c087ce --- /dev/null +++ b/model_zoo/official/cv/unet3d/scripts/run_distribute_train_gpu_fp32.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# -ne 1 ] +then + echo "Usage: sh run_distribute_train_gpu.sh [DATA_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +echo $PATH1 +if [ ! -d $PATH1 ] +then + echo "error: IMAGE_PATH=$PATH1 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 + + +if [ -d "train_parallel_fp32" ]; +then + rm -rf ./train_parallel_fp32 +fi + +rm -rf ./train_parallel_fp32 +mkdir ./train_parallel_fp32 +cp ../*.py ./train_parallel_fp32 +cp *.sh ./train_parallel_fp32 +cp ../*.yaml ./train_parallel_fp32 +cp -r ../src ./train_parallel_fp32 +cd ./train_parallel_fp32 || exit +echo "start distributed training with $DEVICE_NUM GPUs." +env > env.log +mpirun --allow-run-as-root -n $DEVICE_NUM python train.py --run_distribute=True --data_path=$PATH1 --output_path './output' --device_target='GPU' --checkpoint_path='./' > train.log 2>&1 & +cd .. diff --git a/model_zoo/official/cv/unet3d/scripts/run_standalone_eval_gpu_fp16.sh b/model_zoo/official/cv/unet3d/scripts/run_standalone_eval_gpu_fp16.sh new file mode 100644 index 00000000000..4de101935cd --- /dev/null +++ b/model_zoo/official/cv/unet3d/scripts/run_standalone_eval_gpu_fp16.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "==============================================================================================================" + echo "Please run the script as: " + echo "bash scripts/run_standalone_eval_gpu_fp16.sh [DATA_PATH] [CHECKPOINT]" + echo "for example: bash run_standalone_eval_gpu_fp16.sh /path/to/data/ /path/to/checkpoint/" + echo "==============================================================================================================" +fi + +if [ $# != 2 ] +then + echo "Usage: sh run_standalone_eval_gpu_fp16.sh [DATA_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +PATH1=$(get_real_path $1) +CHECKPOINT_FILE_PATH=$(get_real_path $2) +echo $PATH1 +echo $CHECKPOINT_FILE_PATH + +if [ ! -d $PATH1 ] +then + echo "error: PATH1=$PATH1 is not a path" +exit 1 +fi + +if [ ! -f $CHECKPOINT_FILE_PATH ] +then + echo "error: CHECKPOINT_FILE_PATH=$CHECKPOINT_FILE_PATH is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export RANK_SIZE=$DEVICE_NUM +export DEVICE_ID=0 +export RANK_ID=0 + +if [ -d "eval_fp16" ]; +then + rm -rf ./eval_fp16 +fi + +mkdir ./eval_fp16 +cp ../*.py ./eval_fp16 +cp *.sh ./eval_fp16 +cp ../*.yaml ./eval_fp16 +cp -r ../src ./eval_fp16 +cd ./eval_fp16 || exit +echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}" +python eval.py --data_path=$PATH1 --checkpoint_file_path=$CHECKPOINT_FILE_PATH --device_target='GPU' > eval.log 2>&1 & +echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}" +cd .. diff --git a/model_zoo/official/cv/unet3d/scripts/run_standalone_eval_gpu_fp32.sh b/model_zoo/official/cv/unet3d/scripts/run_standalone_eval_gpu_fp32.sh new file mode 100644 index 00000000000..e5478c2fbf5 --- /dev/null +++ b/model_zoo/official/cv/unet3d/scripts/run_standalone_eval_gpu_fp32.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "==============================================================================================================" + echo "Please run the script as: " + echo "bash scripts/run_standalone_eval_gpu.sh [DATA_PATH] [CHECKPOINT]" + echo "for example: bash run_standalone_eval_gpu.sh /path/to/data/ /path/to/checkpoint/" + echo "==============================================================================================================" +fi + +if [ $# != 2 ] +then + echo "Usage: sh run_standalone_eval_gpu.sh [DATA_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +PATH1=$(get_real_path $1) +CHECKPOINT_FILE_PATH=$(get_real_path $2) +echo $PATH1 +echo $CHECKPOINT_FILE_PATH + +if [ ! -d $PATH1 ] +then + echo "error: PATH1=$PATH1 is not a path" +exit 1 +fi + +if [ ! -f $CHECKPOINT_FILE_PATH ] +then + echo "error: CHECKPOINT_FILE_PATH=$CHECKPOINT_FILE_PATH is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export RANK_SIZE=$DEVICE_NUM +export DEVICE_ID=0 +export RANK_ID=0 + +if [ -d "eval_fp32" ]; +then + rm -rf ./eval_fp32 +fi + +mkdir ./eval_fp32 +cp ../*.py ./eval_fp32 +cp *.sh ./eval_fp32 +cp ../*.yaml ./eval_fp32 +cp -r ../src ./eval_fp32 +cd ./eval_fp32 || exit +echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}" +python eval.py --data_path=$PATH1 --checkpoint_file_path=$CHECKPOINT_FILE_PATH --device_target='GPU' > eval.log 2>&1 & +echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}" +cd .. diff --git a/model_zoo/official/cv/unet3d/scripts/run_standalone_train_gpu_fp16.sh b/model_zoo/official/cv/unet3d/scripts/run_standalone_train_gpu_fp16.sh new file mode 100644 index 00000000000..e84212b21e3 --- /dev/null +++ b/model_zoo/official/cv/unet3d/scripts/run_standalone_train_gpu_fp16.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# -ne 1 ] +then + echo "Usage: sh run_distribute_train_gpu_fp16.sh [DATA_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +echo $PATH1 +if [ ! -d $PATH1 ] +then + echo "error: IMAGE_PATH=$PATH1 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 + +if [ -d "train_fp16" ]; +then + rm -rf ./train_fp16 +fi + +rm -rf ./train_fp16 +mkdir ./train_fp16 +cp ../*.py ./train_fp16 +cp *.sh ./train_fp16 +cp ../*.yaml ./train_fp16 +cp -r ../src ./train_fp16 +cd ./train_fp16 || exit +echo "start training for device $DEVICE_ID" +env > env.log +python train.py --data_path=$PATH1 --output_path './output' --device_target='GPU' --checkpoint_path='./' --enable_fp16_gpu=True > train.log 2>&1 & +cd .. diff --git a/model_zoo/official/cv/unet3d/scripts/run_standalone_train_gpu_fp32.sh b/model_zoo/official/cv/unet3d/scripts/run_standalone_train_gpu_fp32.sh new file mode 100644 index 00000000000..edeed187855 --- /dev/null +++ b/model_zoo/official/cv/unet3d/scripts/run_standalone_train_gpu_fp32.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# -ne 1 ] +then + echo "Usage: sh run_distribute_train_gpu_fp32.sh [DATA_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +echo $PATH1 +if [ ! -d $PATH1 ] +then + echo "error: IMAGE_PATH=$PATH1 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 + +if [ -d "train_fp32" ]; +then + rm -rf ./train_fp32 +fi + +rm -rf ./train_fp32 +mkdir ./train_fp32 +cp ../*.py ./train_fp32 +cp *.sh ./train_fp32 +cp ../*.yaml ./train_fp32 +cp -r ../src ./train_fp32 +cd ./train_fp32 || exit +echo "start training for device $DEVICE_ID" +env > env.log +python train.py --data_path=$PATH1 --output_path './output' --device_target='GPU' --checkpoint_path='./' > train.log 2>&1 & +cd .. \ No newline at end of file diff --git a/model_zoo/official/cv/unet3d/src/model_utils/config.py b/model_zoo/official/cv/unet3d/src/model_utils/config.py index 92136db1e0c..4d022dd3c51 100644 --- a/model_zoo/official/cv/unet3d/src/model_utils/config.py +++ b/model_zoo/official/cv/unet3d/src/model_utils/config.py @@ -117,9 +117,9 @@ def get_config(): help="Config file path") path_args, _ = parser.parse_known_args() default, helper = parse_yaml(path_args.config_path) - pprint(default) args = parse_cli_to_yaml(parser, default, helper, path_args.config_path) final_config = merge(args, default) + pprint(final_config) return Config(final_config) config = get_config() diff --git a/model_zoo/official/cv/unet3d/src/unet3d_model.py b/model_zoo/official/cv/unet3d/src/unet3d_model.py index 85bae388e7a..34a32182169 100644 --- a/model_zoo/official/cv/unet3d/src/unet3d_model.py +++ b/model_zoo/official/cv/unet3d/src/unet3d_model.py @@ -19,6 +19,48 @@ from mindspore.ops import operations as P from src.unet3d_parts import Down, Up from src.model_utils.config import config +class UNet3d_(nn.Cell): + """ + UNet3d_ support fp32 and fp16(amp) training on GPU. + """ + def __init__(self): + super(UNet3d_, self).__init__() + self.n_channels = config.in_channels + self.n_classes = config.num_classes + + # down + self.down1 = Down(in_channel=self.n_channels, out_channel=16, dtype=mstype.float32) + self.down2 = Down(in_channel=16, out_channel=32, dtype=mstype.float32) + self.down3 = Down(in_channel=32, out_channel=64, dtype=mstype.float32) + self.down4 = Down(in_channel=64, out_channel=128, dtype=mstype.float32) + self.down5 = Down(in_channel=128, out_channel=256, stride=1, kernel_size=(1, 1, 1), \ + dtype=mstype.float32) + + # up + self.up1 = Up(in_channel=256, down_in_channel=128, out_channel=64, \ + dtype=mstype.float32) + self.up2 = Up(in_channel=64, down_in_channel=64, out_channel=32, \ + dtype=mstype.float32) + self.up3 = Up(in_channel=32, down_in_channel=32, out_channel=16, \ + dtype=mstype.float32) + self.up4 = Up(in_channel=16, down_in_channel=16, out_channel=self.n_classes, \ + dtype=mstype.float32, is_output=True) + + + def construct(self, input_data): + x1 = self.down1(input_data) + x2 = self.down2(x1) + x3 = self.down3(x2) + x4 = self.down4(x3) + x5 = self.down5(x4) + + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + return x + + class UNet3d(nn.Cell): def __init__(self): super(UNet3d, self).__init__() diff --git a/model_zoo/official/cv/unet3d/train.py b/model_zoo/official/cv/unet3d/train.py index eb85dabcf88..01af49b1d45 100644 --- a/model_zoo/official/cv/unet3d/train.py +++ b/model_zoo/official/cv/unet3d/train.py @@ -19,20 +19,23 @@ import mindspore.nn as nn import mindspore.common.dtype as mstype from mindspore import Tensor, Model, context from mindspore.context import ParallelMode -from mindspore.communication.management import init +from mindspore.communication.management import init, get_rank, get_group_size from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor from src.dataset import create_dataset -from src.unet3d_model import UNet3d +from src.unet3d_model import UNet3d, UNet3d_ from src.lr_schedule import dynamic_lr from src.loss import SoftmaxCrossEntropyWithLogits from src.model_utils.config import config from src.model_utils.moxing_adapter import moxing_wrapper from src.model_utils.device_adapter import get_device_id, get_device_num -device_id = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, \ - device_id=device_id) +if config.device_target == 'Ascend': + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False, \ + device_id=device_id) +else: + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False) mindspore.set_seed(1) @moxing_wrapper() @@ -42,8 +45,12 @@ def train_net(data_path, seg_dir = data_path + "/seg/" if run_distribute: init() - rank_id = get_device_id() - rank_size = get_device_num() + if config.device_target == 'Ascend': + rank_id = get_device_id() + rank_size = get_device_num() + else: + rank_id = get_rank() + rank_size = get_group_size() parallel_mode = ParallelMode.DATA_PARALLEL context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=rank_size, @@ -56,7 +63,10 @@ def train_net(data_path, train_data_size = train_dataset.get_dataset_size() print("train dataset length is:", train_data_size) - network = UNet3d() + if config.device_target == 'Ascend': + network = UNet3d() + else: + network = UNet3d_() loss = SoftmaxCrossEntropyWithLogits() lr = Tensor(dynamic_lr(config, train_data_size), mstype.float32) @@ -64,7 +74,10 @@ def train_net(data_path, scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) network.set_train() - model = Model(network, loss_fn=loss, optimizer=optimizer, loss_scale_manager=scale_manager) + if config.device_target == 'GPU' and config.enable_fp16_gpu: + model = Model(network, loss_fn=loss, optimizer=optimizer, loss_scale_manager=scale_manager, amp_level='O2') + else: + model = Model(network, loss_fn=loss, optimizer=optimizer, loss_scale_manager=scale_manager) time_cb = TimeMonitor(data_size=train_data_size) loss_cb = LossMonitor() @@ -72,7 +85,7 @@ def train_net(data_path, keep_checkpoint_max=config.keep_checkpoint_max) ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path) ckpoint_cb = ModelCheckpoint(prefix='Unet3d', - directory=ckpt_save_dir+'./ckpt_{}/'.format(device_id), + directory=ckpt_save_dir+'./ckpt_{}/'.format(rank_id), config=ckpt_config) callbacks_list = [loss_cb, time_cb, ckpoint_cb] print("============== Starting Training ==============") diff --git a/tests/st/ops/gpu/test_prelu_grad_op.py b/tests/st/ops/gpu/test_prelu_grad_op.py new file mode 100644 index 00000000000..1442d730055 --- /dev/null +++ b/tests/st/ops/gpu/test_prelu_grad_op.py @@ -0,0 +1,61 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G + + +class NetPReLUGrad(nn.Cell): + def __init__(self): + super(NetPReLUGrad, self).__init__() + self.prelu_grad = G.PReLUGrad() + + def construct(self, dout, x, w): + return self.prelu_grad(dout, x, w) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_prelu_grad_fp32_channel_shared(): + dout = Tensor(np.ones(shape=[2, 2, 2, 3]).astype(np.float32)) + x = Tensor(np.arange(-5, 19).reshape(2, 2, 2, 3).astype(np.float32)) + w = Tensor(np.array([-0.5]).astype(np.float32)) + expect_dx = np.array([[[[-0.5000, -0.5000, -0.5000], + [-0.5000, -0.5000, -0.5000]], + [[1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000]]], + [[[1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000]], + [[1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000]]]]).astype(np.float32) + expect_dw = np.array([-15.]).astype(np.float32) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + prelu_grad = NetPReLUGrad() + dx, dw = prelu_grad(dout, x, w) + assert (dx.asnumpy() == expect_dx).all() + assert (dw.asnumpy() == expect_dw).all() + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + prelu_grad = NetPReLUGrad() + dx, dw = prelu_grad(dout, x, w) + assert (dx.asnumpy() == expect_dx).all() + assert (dw.asnumpy() == expect_dw).all() diff --git a/tests/st/ops/gpu/test_slice.py b/tests/st/ops/gpu/test_slice.py index aef14084b5f..10575e94326 100644 --- a/tests/st/ops/gpu/test_slice.py +++ b/tests/st/ops/gpu/test_slice.py @@ -69,6 +69,30 @@ def test_slice_4d(): assert (output_ms.asnumpy() == output_np).all() +class Slice5DNet(nn.Cell): + def __init__(self): + super(Slice5DNet, self).__init__() + self.slice = P.Slice() + + def construct(self, x): + return self.slice(x, (0, 11, 1, 2, 3), (32, 7, 14, 10, 221)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_slice_5d(): + x_np = np.random.randn(32, 32, 24, 224, 224).astype(np.float32) + output_np = x_np[:, 11:18, 1:15, 2:12, 3:224] + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + x_ms = Tensor(x_np) + net = Slice5DNet() + output_ms = net(x_ms) + + assert (output_ms.asnumpy() == output_np).all() + + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard