adagrad: support output on gpu
parent a45f03eebf
commit a11287c332
@@ -32,16 +32,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                                    const S *learning_rate,
                                    const G *gradient,
                                    T *variable,
-                                   T *accumulation,
-                                   T *variable_out,
-                                   T *accumulation_out) {
+                                   T *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += gradient[i] * gradient[i];
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= learning_rate[0] * gradient[i] / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }
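For reference, the per-element update the kernel body above computes, with the scalar learning rate \eta read from learning_rate[0], can be written as:

  \mathrm{accum}_i \leftarrow \mathrm{accum}_i + g_i^2 \quad (\text{only if } \text{update\_slots}), \qquad
  \mathrm{var}_i \leftarrow \mathrm{var}_i - \frac{\eta\, g_i}{\sqrt{\mathrm{accum}_i}}

The removed *_out writes become redundant because the op's outputs are now filled from this in-place state (see the Launch hunk further down).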
@@ -51,16 +47,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                                    const float *learning_rate,
                                    const half *gradient,
                                    half *variable,
-                                   half *accumulation,
-                                   half *variable_out,
-                                   half *accumulation_out) {
+                                   half *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += gradient[i] * gradient[i];
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= __float2half(learning_rate[0]) * gradient[i] / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }
@@ -70,16 +62,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                                    const float *learning_rate,
                                    const half *gradient,
                                    float *variable,
-                                   float *accumulation,
-                                   float *variable_out,
-                                   float *accumulation_out) {
+                                   float *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += __half2float(gradient[i]) * __half2float(gradient[i]);
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= learning_rate[0] * __half2float(gradient[i]) / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }
@@ -89,16 +77,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                                    const half *learning_rate,
                                    const float *gradient,
                                    float *variable,
-                                   float *accumulation,
-                                   float *variable_out,
-                                   float *accumulation_out) {
+                                   float *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += gradient[i] * gradient[i];
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= __half2float(learning_rate[0]) * gradient[i] / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }
@@ -108,16 +92,12 @@ __global__ void ApplyAdagradKernel(const size_t size,
                                    const float *learning_rate,
                                    const float *gradient,
                                    half *variable,
-                                   half *accumulation,
-                                   half *variable_out,
-                                   half *accumulation_out) {
+                                   half *accumulation) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (update_slots) {
       accumulation[i] += __float2half(gradient[i]) * __float2half(gradient[i]);
-      accumulation_out[i] = accumulation[i];
     }
     variable[i] -= __float2half(learning_rate[0]) * __float2half(gradient[i]) / SqrtFunc(accumulation[i]);
-    variable_out[i] = variable[i];
   }
 }
@@ -128,11 +108,9 @@ void ApplyAdagrad(const size_t size,
                   const G *gradient,
                   T *variable,
                   T *accumulation,
-                  T *variable_out,
-                  T *accumulation_out,
                   cudaStream_t cuda_stream) {
   ApplyAdagradKernel<<< GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
-    size, update_slots, learning_rate, gradient, variable, accumulation, variable_out, accumulation_out);
+    size, update_slots, learning_rate, gradient, variable, accumulation);
 }
 
 template void ApplyAdagrad<float, float, float>(const size_t size,
@@ -141,8 +119,6 @@ template void ApplyAdagrad<float, float, float>(const size_t size,
                                                 const float *gradient,
                                                 float *variable,
                                                 float *accumulation,
-                                                float *variable_out,
-                                                float *accumulation_out,
                                                 cudaStream_t cuda_stream);
 
 template void ApplyAdagrad<half, half, half>(const size_t size,
@@ -151,8 +127,6 @@ template void ApplyAdagrad<half, half, half>(const size_t size,
                                              const half *gradient,
                                              half *variable,
                                              half *accumulation,
-                                             half *variable_out,
-                                             half *accumulation_out,
                                              cudaStream_t cuda_stream);
 
 template void ApplyAdagrad<half, float, half>(const size_t size,
@@ -161,8 +135,6 @@ template void ApplyAdagrad<half, float, half>(const size_t size,
                                               const half *gradient,
                                               half *variable,
                                               half *accumulation,
-                                              half *variable_out,
-                                              half *accumulation_out,
                                               cudaStream_t cuda_stream);
 
 template void ApplyAdagrad<float, float, half>(const size_t size,
@@ -171,8 +143,6 @@ template void ApplyAdagrad<float, float, half>(const size_t size,
                                                const half *gradient,
                                                float *variable,
                                                float *accumulation,
-                                               float *variable_out,
-                                               float *accumulation_out,
                                                cudaStream_t cuda_stream);
 
 template void ApplyAdagrad<float, half, float>(const size_t size,
@@ -181,8 +151,6 @@ template void ApplyAdagrad<float, half, float>(const size_t size,
                                                const float *gradient,
                                                float *variable,
                                                float *accumulation,
-                                               float *variable_out,
-                                               float *accumulation_out,
                                                cudaStream_t cuda_stream);
 
 template void ApplyAdagrad<half, float, float>(const size_t size,
@@ -191,6 +159,4 @@ template void ApplyAdagrad<half, float, float>(const size_t size,
                                                const float *gradient,
                                                half *variable,
                                                half *accumulation,
-                                               half *variable_out,
-                                               half *accumulation_out,
                                                cudaStream_t cuda_stream);
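Taken together, the .cu hunks above drop the *_out parameters so the kernel now updates variable and accumulation strictly in place. Below is a minimal, self-contained CUDA sketch of that in-place pattern; it is illustrative only (a hand-rolled float kernel and launcher, not the MindSpore source, and without the SqrtFunc/GET_BLOCKS helpers).

#include <cuda_runtime.h>

// Grid-stride Adagrad step over `size` elements; variable/accumulation are
// modified in place, mirroring the trimmed kernel signature above.
__global__ void AdagradSketchKernel(const size_t size, const bool update_slots, const float *learning_rate,
                                    const float *gradient, float *variable, float *accumulation) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    if (update_slots) {
      accumulation[i] += gradient[i] * gradient[i];
    }
    variable[i] -= learning_rate[0] * gradient[i] / sqrtf(accumulation[i]);
  }
}

// Host-side launcher in the spirit of the trimmed ApplyAdagrad wrapper.
void AdagradSketch(const size_t size, const bool update_slots, const float *learning_rate, const float *gradient,
                   float *variable, float *accumulation, cudaStream_t stream) {
  const int threads = 256;
  const int blocks = static_cast<int>((size + threads - 1) / threads);
  AdagradSketchKernel<<<blocks, threads, 0, stream>>>(size, update_slots, learning_rate, gradient, variable,
                                                      accumulation);
}

All pointers are device pointers; the caller is responsible for producing any separate output tensors, which the kernel-class hunk further down does with cudaMemcpyAsync.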
@@ -25,8 +25,6 @@ void ApplyAdagrad(const size_t size,
                   const G *gradient,
                   T *variable,
                   T *accumulation,
-                  T *variable_out,
-                  T *accumulation_out,
                   cudaStream_t stream);
 
 #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_
@@ -45,7 +45,17 @@ class AdagradGpuKernel : public GpuKernel {
     T *variable_out = GetDeviceAddress<T>(outputs, 0);
     T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
     ApplyAdagrad(inputs[0]->size / sizeof(T), update_slots, learning_rate, gradient, variable, accumulation,
-                 variable_out, accumulation_out, reinterpret_cast<cudaStream_t>(stream_ptr));
+                 reinterpret_cast<cudaStream_t>(stream_ptr));
+
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(&variable_out[0], &variable[0], variable_size_, cudaMemcpyDeviceToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed");
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(&accumulation_out[0], &accumulation[0], accumulation_size_,
+                                               cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed");
+
     return true;
   }
 
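The Launch change above fills the op's two outputs by copying the in-place state on the same stream after the kernel call. A minimal sketch of that copy-to-output step, with plain cudaError_t checks standing in for CHECK_CUDA_RET_WITH_EXCEPT and with illustrative names:

#include <cuda_runtime.h>
#include <cstdio>

// Copies the updated variable/accumulation into the op's output buffers.
// All pointers are device pointers; sizes are in bytes.
bool CopyStateToOutputs(const float *variable, const float *accumulation, float *variable_out,
                        float *accumulation_out, size_t variable_bytes, size_t accumulation_bytes,
                        cudaStream_t stream) {
  if (cudaMemcpyAsync(variable_out, variable, variable_bytes, cudaMemcpyDeviceToDevice, stream) != cudaSuccess) {
    std::fprintf(stderr, "cudaMemcpyAsync output failed\n");
    return false;
  }
  if (cudaMemcpyAsync(accumulation_out, accumulation, accumulation_bytes, cudaMemcpyDeviceToDevice, stream) !=
      cudaSuccess) {
    std::fprintf(stderr, "cudaMemcpyAsync output failed\n");
    return false;
  }
  return true;
}

Because the copies are asynchronous on the same stream as the kernel, they are ordered after the update without any extra synchronization.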
@@ -61,17 +71,17 @@ class AdagradGpuKernel : public GpuKernel {
     learning_rate_size_ = sizeof(S);
     gradient_size_ = sizeof(G);
 
-    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     for (size_t i = 0; i < variable_shape.size(); i++) {
       variable_size_ *= variable_shape[i];
     }
 
-    auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
+    auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
     for (size_t i = 0; i < accumulation_shape.size(); i++) {
       accumulation_size_ *= accumulation_shape[i];
     }
 
-    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
     for (size_t i = 0; i < gradient_shape.size(); i++) {
       gradient_size_ *= gradient_shape[i];
     }
@@ -36,8 +36,8 @@ class Net(nn.Cell):
         self.accum = Parameter(Tensor(accum_np), name="accum")
 
     def construct(self, lr, grad):
-        self.apply_adagrad(self.var, self.accum, lr, grad)
-        return self.var, self.accum
+        z = self.apply_adagrad(self.var, self.accum, lr, grad)
+        return z
 
 
 @pytest.mark.level0