From 1240d57cea25fc5c0f8b4596c6eee9d1ba7499e3 Mon Sep 17 00:00:00 2001
From: zhuyuxiao
Date: Tue, 16 Mar 2021 15:27:08 +0800
Subject: [PATCH] I3AP06: dtype and return value

---
 .../gpu/cuda_impl/adagrad_impl.cu       | 158 ++++++++++++++++--
 .../gpu/cuda_impl/adagrad_impl.cuh      |   8 +-
 .../gpu/nn/adagrad_gpu_kernel.cc        |  72 ++++++--
 .../gpu/nn/adagrad_gpu_kernel.h         |  41 ++---
 4 files changed, 222 insertions(+), 57 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cu
index b1a0eb2514e..15785e639c1 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cu
@@ -26,45 +26,171 @@ __device__ __forceinline__ half SqrtFunc(half input) {
   return hsqrt(input);
 }
 
-template <typename T>
+template <typename T, typename S, typename G>
 __global__ void ApplyAdagradKernel(const size_t size,
                                    const bool update_slots,
-                                   const T *learning_rate,
-                                   const T *gradient,
+                                   const S *learning_rate,
+                                   const G *gradient,
                                    T *variable,
-                                   T *accumulation) {
+                                   T *accumulation,
+                                   T *variable_out,
+                                   T *accumulation_out) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += gradient[i] * gradient[i];
+      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= learning_rate[0] * gradient[i] / SqrtFunc(accumulation[i]);
+    variable_out[i] = variable[i];
   }
 }
 
-template <typename T>
-void ApplyAdagrad(const size_t size,
-                  const bool update_slots,
-                  const T *learning_rate,
-                  const T *gradient,
-                  T *variable,
-                  T *accumulation,
-                  cudaStream_t cuda_stream) {
-  ApplyAdagradKernel<<< GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
-    size, update_slots, learning_rate, gradient, variable, accumulation);
+template <>
+__global__ void ApplyAdagradKernel(const size_t size,
+                                   const bool update_slots,
+                                   const float *learning_rate,
+                                   const half *gradient,
+                                   half *variable,
+                                   half *accumulation,
+                                   half *variable_out,
+                                   half *accumulation_out) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    if (update_slots) {
+      accumulation[i] += gradient[i] * gradient[i];
+      accumulation_out[i] = accumulation[i];
+    }
+    variable[i] -= __float2half(learning_rate[0]) * gradient[i] / SqrtFunc(accumulation[i]);
+    variable_out[i] = variable[i];
+  }
 }
 
-template void ApplyAdagrad<float>(const size_t size,
+template <>
+__global__ void ApplyAdagradKernel(const size_t size,
+                                   const bool update_slots,
+                                   const float *learning_rate,
+                                   const half *gradient,
+                                   float *variable,
+                                   float *accumulation,
+                                   float *variable_out,
+                                   float *accumulation_out) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    if (update_slots) {
+      accumulation[i] += __half2float(gradient[i]) * __half2float(gradient[i]);
+      accumulation_out[i] = accumulation[i];
+    }
+    variable[i] -= learning_rate[0] * __half2float(gradient[i]) / SqrtFunc(accumulation[i]);
+    variable_out[i] = variable[i];
+  }
+}
+
+template <>
+__global__ void ApplyAdagradKernel(const size_t size,
+                                   const bool update_slots,
+                                   const half *learning_rate,
+                                   const float *gradient,
+                                   float *variable,
+                                   float *accumulation,
+                                   float *variable_out,
+                                   float *accumulation_out) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    if (update_slots) {
+      accumulation[i] += gradient[i] * gradient[i];
+      accumulation_out[i] = accumulation[i];
+    }
+    variable[i] -= __half2float(learning_rate[0]) * gradient[i] / SqrtFunc(accumulation[i]);
+    variable_out[i] = variable[i];
+  }
+}
+
+template <>
+__global__ void ApplyAdagradKernel(const size_t size,
+                                   const bool update_slots,
+                                   const float *learning_rate,
+                                   const float *gradient,
+                                   half *variable,
+                                   half *accumulation,
+                                   half *variable_out,
+                                   half *accumulation_out) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    if (update_slots) {
+      accumulation[i] += __float2half(gradient[i]) * __float2half(gradient[i]);
+      accumulation_out[i] = accumulation[i];
+    }
+    variable[i] -= __float2half(learning_rate[0]) * __float2half(gradient[i]) / SqrtFunc(accumulation[i]);
+    variable_out[i] = variable[i];
+  }
+}
+
+template <typename T, typename S, typename G>
+void ApplyAdagrad(const size_t size,
+                  const bool update_slots,
+                  const S *learning_rate,
+                  const G *gradient,
+                  T *variable,
+                  T *accumulation,
+                  T *variable_out,
+                  T *accumulation_out,
+                  cudaStream_t cuda_stream) {
+  ApplyAdagradKernel<<< GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
+    size, update_slots, learning_rate, gradient, variable, accumulation, variable_out, accumulation_out);
+}
+
+template void ApplyAdagrad<float, float, float>(const size_t size,
                                   const bool update_slots,
                                   const float *learning_rate,
                                   const float *gradient,
                                   float *variable,
                                   float *accumulation,
+                                  float *variable_out,
+                                  float *accumulation_out,
                                   cudaStream_t cuda_stream);
-template void ApplyAdagrad<half>(const size_t size,
+template void ApplyAdagrad<half, half, half>(const size_t size,
                                  const bool update_slots,
                                  const half *learning_rate,
                                  const half *gradient,
                                  half *variable,
                                  half *accumulation,
+                                 half *variable_out,
+                                 half *accumulation_out,
+                                 cudaStream_t cuda_stream);
+
+template void ApplyAdagrad<half, float, half>(const size_t size,
+                                              const bool update_slots,
+                                              const float *learning_rate,
+                                              const half *gradient,
+                                              half *variable,
+                                              half *accumulation,
+                                              half *variable_out,
+                                              half *accumulation_out,
+                                              cudaStream_t cuda_stream);
+
+template void ApplyAdagrad<float, float, half>(const size_t size,
+                                               const bool update_slots,
+                                               const float *learning_rate,
+                                               const half *gradient,
+                                               float *variable,
+                                               float *accumulation,
+                                               float *variable_out,
+                                               float *accumulation_out,
+                                               cudaStream_t cuda_stream);
+
+template void ApplyAdagrad<float, half, float>(const size_t size,
+                                               const bool update_slots,
+                                               const half *learning_rate,
+                                               const float *gradient,
+                                               float *variable,
+                                               float *accumulation,
+                                               float *variable_out,
+                                               float *accumulation_out,
+                                               cudaStream_t cuda_stream);
+
+template void ApplyAdagrad<half, float, float>(const size_t size,
+                                               const bool update_slots,
+                                               const float *learning_rate,
+                                               const float *gradient,
+                                               half *variable,
+                                               half *accumulation,
+                                               half *variable_out,
+                                               half *accumulation_out,
                                  cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh
index 3cfbd776e95..dc6e02f1be9 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh
@@ -18,13 +18,15 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_
 #include "runtime/device/gpu/cuda_common.h"
 
-template <typename T>
+template <typename T, typename S, typename G>
 void ApplyAdagrad(const size_t size,
                   const bool update_slots,
-                  const T *learning_rate,
-                  const T *gradient,
+                  const S *learning_rate,
+                  const G *gradient,
                   T *variable,
                   T *accumulation,
+                  T *variable_out,
+                  T *accumulation_out,
                   cudaStream_t stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.cc
index 25c459c14bc..5a4d27c8bd1 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.cc
@@ -18,23 +18,59 @@
 
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(ApplyAdagrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddOutputAttr(kNumberTypeFloat32)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      AdagradGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(ApplyAdagrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      AdagradGpuKernel, half)
+MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddOutputAttr(kNumberTypeFloat32)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        AdagradGpuKernel, float, float, float)
+MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddOutputAttr(kNumberTypeFloat16)
+                          .AddOutputAttr(kNumberTypeFloat16),
+                        AdagradGpuKernel, half, half, half)
+MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddOutputAttr(kNumberTypeFloat16)
+                          .AddOutputAttr(kNumberTypeFloat16),
+                        AdagradGpuKernel, half, float, half)
+MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddOutputAttr(kNumberTypeFloat32)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        AdagradGpuKernel, float, float, half)
+MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddOutputAttr(kNumberTypeFloat32)
+                          .AddOutputAttr(kNumberTypeFloat32),
+                        AdagradGpuKernel, float, half, float)
+MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
+                        KernelAttr()
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat16)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddInputAttr(kNumberTypeFloat32)
+                          .AddOutputAttr(kNumberTypeFloat16)
+                          .AddOutputAttr(kNumberTypeFloat16),
+                        AdagradGpuKernel, half, float, float)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h
index 7f9a87b5bb3..bf1290beab5 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h
@@ -24,7 +24,7 @@
 
 namespace mindspore {
 namespace kernel {
-template <typename T>
+template <typename T, typename S, typename G>
 class AdagradGpuKernel : public GpuKernel {
  public:
   AdagradGpuKernel()
@@ -36,6 +36,19 @@ class AdagradGpuKernel : public GpuKernel {
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
 
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *variable = GetDeviceAddress<T>(inputs, 0);
+    T *accumulation = GetDeviceAddress<T>(inputs, 1);
+    S *learning_rate = GetDeviceAddress<S>(inputs, 2);
+    G *gradient = GetDeviceAddress<G>(inputs, 3);
+    T *variable_out = GetDeviceAddress<T>(outputs, 0);
+    T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
+    ApplyAdagrad(inputs[0]->size / sizeof(T), update_slots, learning_rate, gradient, variable, accumulation,
+                 variable_out, accumulation_out, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
   bool Init(const CNodePtr &kernel_node) override {
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     update_slots = AnfAlgo::GetNodeAttr<bool>(kernel_node, "update_slots");
@@ -45,47 +58,35 @@ class AdagradGpuKernel : public GpuKernel {
     }
     variable_size_ = sizeof(T);
     accumulation_size_ = sizeof(T);
-    learning_rate_size_ = sizeof(T);
-    gradient_size_ = sizeof(T);
+    learning_rate_size_ = sizeof(S);
+    gradient_size_ = sizeof(G);
 
-    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
     for (size_t i = 0; i < variable_shape.size(); i++) {
       variable_size_ *= variable_shape[i];
     }
 
-    auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
    for (size_t i = 0; i < accumulation_shape.size(); i++) {
       accumulation_size_ *= accumulation_shape[i];
     }
 
-    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
+    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
     for (size_t i = 0; i < gradient_shape.size(); i++) {
       gradient_size_ *= gradient_shape[i];
     }
-
     InitSizeLists();
     return true;
   }
 
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
-              void *stream_ptr) override {
-    T *variable = GetDeviceAddress<T>(inputs, 0);
-    T *accumulation = GetDeviceAddress<T>(inputs, 1);
-    T *learning_rate = GetDeviceAddress<T>(inputs, 2);
-    T *gradient = GetDeviceAddress<T>(inputs, 3);
-    ApplyAdagrad(inputs[0]->size / sizeof(T), update_slots, learning_rate, gradient, variable, accumulation,
-                 reinterpret_cast<cudaStream_t>(stream_ptr));
-    return true;
-  }
-
  protected:
   void InitSizeLists() override {
     input_size_list_.push_back(variable_size_);
     input_size_list_.push_back(accumulation_size_);
     input_size_list_.push_back(learning_rate_size_);
     input_size_list_.push_back(gradient_size_);
-    output_size_list_.push_back(0);
-    output_size_list_.push_back(0);
+    output_size_list_.push_back(variable_size_);
+    output_size_list_.push_back(accumulation_size_);
   }
 
  private:
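
Note (not part of the patch): below is a minimal standalone sketch of how the new three-type launcher could be exercised for one of the mixed-precision cases this change adds, fp16 variable/accumulation/gradient with an fp32 learning rate, i.e. ApplyAdagrad<half, float, half>. The include path, the standalone main(), and the buffer setup are illustrative assumptions only; inside MindSpore the launcher is driven by AdagradGpuKernel::Launch as shown in the diff above.

// Sketch only: assumes the patched adagrad_impl.cu/.cuh are built into the
// link, and that mindspore/ccsrc is on the include path.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh"

int main() {
  constexpr size_t n = 1024;
  half *variable = nullptr, *accumulation = nullptr, *gradient = nullptr;
  half *variable_out = nullptr, *accumulation_out = nullptr;
  float *learning_rate = nullptr;
  // Device buffers: fp16 parameter state and gradient, fp32 scalar learning rate.
  cudaMalloc(reinterpret_cast<void **>(&variable), n * sizeof(half));
  cudaMalloc(reinterpret_cast<void **>(&accumulation), n * sizeof(half));
  cudaMalloc(reinterpret_cast<void **>(&gradient), n * sizeof(half));
  cudaMalloc(reinterpret_cast<void **>(&variable_out), n * sizeof(half));
  cudaMalloc(reinterpret_cast<void **>(&accumulation_out), n * sizeof(half));
  cudaMalloc(reinterpret_cast<void **>(&learning_rate), sizeof(float));

  const float lr = 0.01f;
  cudaMemcpy(learning_rate, &lr, sizeof(float), cudaMemcpyHostToDevice);
  // variable / accumulation / gradient would be initialized here as well.

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // Deduces ApplyAdagrad<half, float, half>: updates variable/accumulation in
  // place and mirrors the updated values into the two new *_out outputs.
  ApplyAdagrad(n, true, learning_rate, gradient, variable, accumulation,
               variable_out, accumulation_out, stream);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  // cudaFree(...) for all buffers omitted for brevity.
  return 0;
}

Because <half, float, half> is one of the explicit instantiations the patch adds to adagrad_impl.cu, the deduced call above links without any further instantiation.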