From a11287c33272f7f5bd868ccda62e276dc7a08f6b Mon Sep 17 00:00:00 2001
From: zhuyuxiao
Date: Tue, 23 Mar 2021 17:09:34 +0800
Subject: [PATCH] adagrad: support output on gpu

---
 .../gpu/cuda_impl/adagrad_impl.cu   | 46 +++----------------
 .../gpu/cuda_impl/adagrad_impl.cuh  |  2 -
 .../gpu/nn/adagrad_gpu_kernel.h     | 18 ++++++--
 tests/st/ops/gpu/test_adagrad_op.py |  4 +-
 4 files changed, 22 insertions(+), 48 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cu
index 15785e639c1..53fa79f0244 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cu
@@ -32,16 +32,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                                     const S *learning_rate,
                                     const G *gradient,
                                     T *variable,
-                                    T *accumulation,
-                                    T *variable_out,
-                                    T *accumulation_out) {
+                                    T *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += gradient[i] * gradient[i];
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= learning_rate[0] * gradient[i] / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }
 
@@ -51,16 +47,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                                     const float *learning_rate,
                                     const half *gradient,
                                     half *variable,
-                                    half *accumulation,
-                                    half *variable_out,
-                                    half *accumulation_out) {
+                                    half *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += gradient[i] * gradient[i];
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= __float2half(learning_rate[0]) * gradient[i] / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }
 
@@ -70,16 +62,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                                     const float *learning_rate,
                                     const half *gradient,
                                     float *variable,
-                                    float *accumulation,
-                                    float *variable_out,
-                                    float *accumulation_out) {
+                                    float *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
      accumulation[i] += __half2float(gradient[i]) * __half2float(gradient[i]);
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= learning_rate[0] * __half2float(gradient[i]) / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }
 
@@ -89,16 +77,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                                     const half *learning_rate,
                                     const float *gradient,
                                     float *variable,
-                                    float *accumulation,
-                                    float *variable_out,
-                                    float *accumulation_out) {
+                                    float *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += gradient[i] * gradient[i];
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= __half2float(learning_rate[0]) * gradient[i] / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }
 
@@ -108,16 +92,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                                     const float *learning_rate,
                                     const float *gradient,
                                     half *variable,
-                                    half *accumulation,
-                                    half *variable_out,
-                                    half *accumulation_out) {
+                                    half *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += __float2half(gradient[i]) * __float2half(gradient[i]);
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= __float2half(learning_rate[0]) * __float2half(gradient[i]) / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }
 
@@ -128,11 +108,9 @@ void ApplyAdagrad(const size_t size,
                   const G *gradient,
                   T *variable,
                   T *accumulation,
-                  T *variable_out,
-                  T *accumulation_out,
                   cudaStream_t cuda_stream) {
   ApplyAdagradKernel<<< GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
-    size, update_slots, learning_rate, gradient, variable, accumulation, variable_out, accumulation_out);
+    size, update_slots, learning_rate, gradient, variable, accumulation);
 }
 
 template void ApplyAdagrad(const size_t size,
@@ -141,8 +119,6 @@ template void ApplyAdagrad(const size_t size,
                            const float *gradient,
                            float *variable,
                            float *accumulation,
-                           float *variable_out,
-                           float *accumulation_out,
                            cudaStream_t cuda_stream);
 
 template void ApplyAdagrad(const size_t size,
@@ -151,8 +127,6 @@ template void ApplyAdagrad(const size_t size,
                            const half *gradient,
                            half *variable,
                            half *accumulation,
-                           half *variable_out,
-                           half *accumulation_out,
                            cudaStream_t cuda_stream);
 
 template void ApplyAdagrad(const size_t size,
@@ -161,8 +135,6 @@ template void ApplyAdagrad(const size_t size,
                            const half *gradient,
                            half *variable,
                            half *accumulation,
-                           half *variable_out,
-                           half *accumulation_out,
                            cudaStream_t cuda_stream);
 
 template void ApplyAdagrad(const size_t size,
@@ -171,8 +143,6 @@ template void ApplyAdagrad(const size_t size,
                            const half *gradient,
                            float *variable,
                            float *accumulation,
-                           float *variable_out,
-                           float *accumulation_out,
                            cudaStream_t cuda_stream);
 
 template void ApplyAdagrad(const size_t size,
@@ -181,8 +151,6 @@ template void ApplyAdagrad(const size_t size,
                            const float *gradient,
                            float *variable,
                            float *accumulation,
-                           float *variable_out,
-                           float *accumulation_out,
                            cudaStream_t cuda_stream);
 
 template void ApplyAdagrad(const size_t size,
@@ -191,6 +159,4 @@ template void ApplyAdagrad(const size_t size,
                            const float *gradient,
                            half *variable,
                            half *accumulation,
-                           half *variable_out,
-                           half *accumulation_out,
                            cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh
index dc6e02f1be9..79819a7cfef 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh
@@ -25,8 +25,6 @@ void ApplyAdagrad(const size_t size,
                   const G *gradient,
                   T *variable,
                   T *accumulation,
-                  T *variable_out,
-                  T *accumulation_out,
                   cudaStream_t stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h
index bf1290beab5..e875e2fd029 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h
@@ -45,7 +45,17 @@ class AdagradGpuKernel : public GpuKernel {
     T *variable_out = GetDeviceAddress<T>(outputs, 0);
     T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
     ApplyAdagrad(inputs[0]->size / sizeof(T), update_slots, learning_rate, gradient, variable, accumulation,
-                 variable_out, accumulation_out, reinterpret_cast<cudaStream_t>(stream_ptr));
+                 reinterpret_cast<cudaStream_t>(stream_ptr));
+
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(&variable_out[0], &variable[0], variable_size_, cudaMemcpyDeviceToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed");
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(&accumulation_out[0], &accumulation[0], accumulation_size_,
+                                               cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed");
+
     return true;
   }
 
@@ -61,17 +71,17 @@ class AdagradGpuKernel : public GpuKernel {
     learning_rate_size_ = sizeof(S);
     gradient_size_ = sizeof(G);
 
-    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     for (size_t i = 0; i < variable_shape.size(); i++) {
       variable_size_ *= variable_shape[i];
     }
 
-    auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
+    auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
     for (size_t i = 0; i < accumulation_shape.size(); i++) {
       accumulation_size_ *= accumulation_shape[i];
     }
 
-    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
     for (size_t i = 0; i < gradient_shape.size(); i++) {
       gradient_size_ *= gradient_shape[i];
     }
diff --git a/tests/st/ops/gpu/test_adagrad_op.py b/tests/st/ops/gpu/test_adagrad_op.py
index 7153595f55d..ffe34aa96fd 100644
--- a/tests/st/ops/gpu/test_adagrad_op.py
+++ b/tests/st/ops/gpu/test_adagrad_op.py
@@ -36,8 +36,8 @@ class Net(nn.Cell):
         self.accum = Parameter(Tensor(accum_np), name="accum")
 
     def construct(self, lr, grad):
-        self.apply_adagrad(self.var, self.accum, lr, grad)
-        return self.var, self.accum
+        z = self.apply_adagrad(self.var, self.accum, lr, grad)
+        return z
 
 
 @pytest.mark.level0