!6584 GPU momentum supports use_nesterov.

Merge pull request !6584 from ZPaC/master-momentum-supports-use_nesterov
mindspore-ci-bot 2020-09-20 11:20:37 +08:00 committed by Gitee
commit 2c1004eecd
3 changed files with 57 additions and 22 deletions

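For reference, the two branches added to the kernels below implement the standard and the Nesterov momentum updates. As a sketch in math form (lr is the scalar learning rate, m the momentum coefficient, g the gradient):

\begin{aligned}
\mathrm{accum} &\leftarrow m \cdot \mathrm{accum} + g \\
\mathrm{var} &\leftarrow \mathrm{var} - lr \cdot \mathrm{accum} && \text{(use\_nesterov = false)} \\
\mathrm{var} &\leftarrow \mathrm{var} - lr \cdot (g + m \cdot \mathrm{accum}) && \text{(use\_nesterov = true)}
\end{aligned}

The Nesterov branch in the kernels writes the last line as gradient[i] * learning_rate[0] + accumulation[i] * momentum[0] * learning_rate[0], which is the same quantity with lr distributed over the parentheses.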

@@ -17,36 +17,60 @@
#include "momentum_impl.cuh"
template <typename T, typename S, typename G>
__global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation, const S *learning_rate,
const G *gradient, const S *momentum) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
accumulation[i] = momentum[0] * accumulation[i] + gradient[i];
variable[i] -= learning_rate[0] * accumulation[i];
const G *gradient, const S *momentum, bool use_nesterov) {
if (use_nesterov) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
accumulation[i] = momentum[0] * accumulation[i] + gradient[i];
variable[i] -= gradient[i] * learning_rate[0] + accumulation[i] * momentum[0] * learning_rate[0];
}
} else {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
accumulation[i] = momentum[0] * accumulation[i] + gradient[i];
variable[i] -= learning_rate[0] * accumulation[i];
}
}
return;
}
template <>
__global__ void MomentumUpdateVariableKernel(const size_t size, half *variable, half *accumulation,
const float *learning_rate, const half *gradient, const float *momentum) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
accumulation[i] = __float2half(momentum[0]) * accumulation[i] + gradient[i];
variable[i] -= __float2half(learning_rate[0]) * accumulation[i];
const float *learning_rate, const half *gradient, const float *momentum,
bool use_nesterov) {
if (use_nesterov) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
accumulation[i] = __float2half(momentum[0]) * accumulation[i] + gradient[i];
variable[i] -= gradient[i] * __float2half(learning_rate[0]) +
accumulation[i] * __float2half(momentum[0]) * __float2half(learning_rate[0]);
}
} else {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
accumulation[i] = __float2half(momentum[0]) * accumulation[i] + gradient[i];
variable[i] -= __float2half(learning_rate[0]) * accumulation[i];
}
}
return;
}
template <>
__global__ void MomentumUpdateVariableKernel(const size_t size, float *variable, float *accumulation,
const float *learning_rate, const half *gradient, const float *momentum) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]);
variable[i] -= learning_rate[0] * accumulation[i];
const float *learning_rate, const half *gradient, const float *momentum,
bool use_nesterov) {
if (use_nesterov) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]);
variable[i] -= __half2float(gradient[i]) * learning_rate[0] + accumulation[i] * momentum[0] * learning_rate[0];
}
} else {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]);
variable[i] -= learning_rate[0] * accumulation[i];
}
}
return;
}
template <typename T, typename S, typename G>
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient,
const S *momentum, cudaStream_t cuda_stream) {
MomentumUpdateVariableKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, accumulation,
learning_rate, gradient, momentum);
const S *momentum, bool use_nesterov, cudaStream_t cuda_stream) {
MomentumUpdateVariableKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
size, variable, accumulation, learning_rate, gradient, momentum, use_nesterov);
return;
}
@@ -91,16 +115,20 @@ void FusedScaleMomentum(const size_t element_num, T *scale, T *variable, T *accu
template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation,
const float *learning_rate, const float *gradient,
const float *momentum, cudaStream_t cuda_stream);
const float *momentum, bool use_nesterov,
cudaStream_t cuda_stream);
template void MomentumUpdateVariable<half, half, half>(const size_t size, half *variable, half *accumulation,
const half *learning_rate, const half *gradient,
const half *momentum, cudaStream_t cuda_stream);
const half *momentum, bool use_nesterov,
cudaStream_t cuda_stream);
template void MomentumUpdateVariable<half, float, half>(const size_t size, half *variable, half *accumulation,
const float *learning_rate, const half *gradient,
const float *momentum, cudaStream_t cuda_stream);
const float *momentum, bool use_nesterov,
cudaStream_t cuda_stream);
template void MomentumUpdateVariable<float, float, half>(const size_t size, float *variable, float *accumulation,
const float *learning_rate, const half *gradient,
const float *momentum, cudaStream_t cuda_stream);
const float *momentum, bool use_nesterov,
cudaStream_t cuda_stream);
template void FusedWeightDecayScaleMomentum(const size_t element_num, float *weight_decay, float *scale,
float *variable, float *accumulation, const float *learning_rate,

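To try the new semantics outside the MindSpore tree, here is a minimal standalone CUDA sketch (the kernel name and the tiny host driver below are illustrative, not part of this change); it mirrors the float specialization's two branches, with the use_nesterov branch moved inside the loop for brevity:

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative only: mirrors the updated momentum kernel for float, keeping the
// use_nesterov branch inside the loop (the real kernels hoist it out).
__global__ void MomentumUpdateSketch(const size_t size, float *variable, float *accumulation,
                                     const float *learning_rate, const float *gradient,
                                     const float *momentum, bool use_nesterov) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
    accumulation[i] = momentum[0] * accumulation[i] + gradient[i];
    if (use_nesterov) {
      variable[i] -= gradient[i] * learning_rate[0] + accumulation[i] * momentum[0] * learning_rate[0];
    } else {
      variable[i] -= learning_rate[0] * accumulation[i];
    }
  }
}

int main() {
  const size_t n = 4;
  float h_var[4] = {1.0f, 1.0f, 1.0f, 1.0f};
  float h_acc[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  float h_grad[4] = {0.5f, 0.5f, 0.5f, 0.5f};
  float h_lr = 0.1f, h_m = 0.9f;
  float *d_var, *d_acc, *d_grad, *d_lr, *d_m;
  cudaMalloc(&d_var, n * sizeof(float));
  cudaMalloc(&d_acc, n * sizeof(float));
  cudaMalloc(&d_grad, n * sizeof(float));
  cudaMalloc(&d_lr, sizeof(float));
  cudaMalloc(&d_m, sizeof(float));
  cudaMemcpy(d_var, h_var, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_acc, h_acc, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_grad, h_grad, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_lr, &h_lr, sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_m, &h_m, sizeof(float), cudaMemcpyHostToDevice);
  // One Nesterov step per element: accum = 0.9 * 0 + 0.5 = 0.5,
  // var = 1 - (0.5 * 0.1 + 0.5 * 0.9 * 0.1) = 0.905.
  MomentumUpdateSketch<<<1, 128>>>(n, d_var, d_acc, d_lr, d_grad, d_m, true);
  cudaMemcpy(h_var, d_var, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("variable[0] after one Nesterov step: %f\n", h_var[0]);
  cudaFree(d_var); cudaFree(d_acc); cudaFree(d_grad); cudaFree(d_lr); cudaFree(d_m);
  return 0;
}

Compile with nvcc; for half inputs the kernels in this change use __float2half/__half2float conversions instead of plain float arithmetic.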

@@ -20,7 +20,7 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S, typename G>
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient,
const S *momentum, cudaStream_t cuda_stream);
const S *momentum, bool use_nesterov, cudaStream_t cuda_stream);
template <typename T, typename S>
void FusedWeightDecayScaleMomentum(const size_t element_num, T *weight_decay, T *scale, T *variable, T *accumulation,
const T *learning_rate, const S *gradient, const T *momentum,


@@ -27,7 +27,12 @@ template <typename T, typename S, typename G>
class MomentumGpuKernel : public GpuKernel {
public:
MomentumGpuKernel()
: variable_size_(0), accumulation_size_(0), learning_rate_size_(0), gradient_size_(0), momentum_size_(0) {}
: use_nesterov_(false),
variable_size_(0),
accumulation_size_(0),
learning_rate_size_(0),
gradient_size_(0),
momentum_size_(0) {}
~MomentumGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -41,7 +46,7 @@ class MomentumGpuKernel : public GpuKernel {
G *gradient = GetDeviceAddress<G>(inputs, 3);
S *momentum = GetDeviceAddress<S>(inputs, 4);
MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum,
reinterpret_cast<cudaStream_t>(stream_ptr));
use_nesterov_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
@@ -50,6 +55,7 @@ class MomentumGpuKernel : public GpuKernel {
MS_LOG(ERROR) << "Input number is " << input_num << ", but momentum needs 5 inputs.";
return false;
}
use_nesterov_ = GetAttr<bool>(kernel_node, "use_nesterov");
variable_size_ = sizeof(T);
accumulation_size_ = sizeof(T);
@@ -84,6 +90,7 @@ class MomentumGpuKernel : public GpuKernel {
}
private:
bool use_nesterov_;
size_t variable_size_;
size_t accumulation_size_;
size_t learning_rate_size_;