!13404 update applyAdagrad

From: @zyx5256
Reviewed-by: @wuxuejian,@liangchenghui
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-03-17 09:41:18 +08:00 committed by Gitee
commit 8e617629fc
4 changed files with 222 additions and 57 deletions

View File

@ -26,45 +26,171 @@ __device__ __forceinline__ half SqrtFunc(half input) {
return hsqrt(input); return hsqrt(input);
} }
template <typename T> template <typename T, typename S, typename G>
__global__ void ApplyAdagradKernel(const size_t size, __global__ void ApplyAdagradKernel(const size_t size,
const bool update_slots, const bool update_slots,
const T *learning_rate, const S *learning_rate,
const T *gradient, const G *gradient,
T *variable, T *variable,
T *accumulation) { T *accumulation,
T *variable_out,
T *accumulation_out) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
if (update_slots) { if (update_slots) {
accumulation[i] += gradient[i] * gradient[i]; accumulation[i] += gradient[i] * gradient[i];
accumulation_out[i] = accumulation[i];
} }
variable[i] -= learning_rate[0] * gradient[i] / SqrtFunc(accumulation[i]); variable[i] -= learning_rate[0] * gradient[i] / SqrtFunc(accumulation[i]);
variable_out[i] = variable[i];
} }
} }
template <typename T> template <>
void ApplyAdagrad(const size_t size, __global__ void ApplyAdagradKernel(const size_t size,
const bool update_slots, const bool update_slots,
const T *learning_rate, const float *learning_rate,
const T *gradient, const half *gradient,
T *variable, half *variable,
T *accumulation, half *accumulation,
cudaStream_t cuda_stream) { half *variable_out,
ApplyAdagradKernel<<< GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>( half *accumulation_out) {
size, update_slots, learning_rate, gradient, variable, accumulation); for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
if (update_slots) {
accumulation[i] += gradient[i] * gradient[i];
accumulation_out[i] = accumulation[i];
}
variable[i] -= __float2half(learning_rate[0]) * gradient[i] / SqrtFunc(accumulation[i]);
variable_out[i] = variable[i];
}
} }
template void ApplyAdagrad<float>(const size_t size, template <>
__global__ void ApplyAdagradKernel(const size_t size,
const bool update_slots,
const float *learning_rate,
const half *gradient,
float *variable,
float *accumulation,
float *variable_out,
float *accumulation_out) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
if (update_slots) {
accumulation[i] += __half2float(gradient[i]) * __half2float(gradient[i]);
accumulation_out[i] = accumulation[i];
}
variable[i] -= learning_rate[0] * __half2float(gradient[i]) / SqrtFunc(accumulation[i]);
variable_out[i] = variable[i];
}
}
template <>
__global__ void ApplyAdagradKernel(const size_t size,
const bool update_slots,
const half *learning_rate,
const float *gradient,
float *variable,
float *accumulation,
float *variable_out,
float *accumulation_out) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
if (update_slots) {
accumulation[i] += gradient[i] * gradient[i];
accumulation_out[i] = accumulation[i];
}
variable[i] -= __half2float(learning_rate[0]) * gradient[i] / SqrtFunc(accumulation[i]);
variable_out[i] = variable[i];
}
}
template <>
__global__ void ApplyAdagradKernel(const size_t size,
const bool update_slots,
const float *learning_rate,
const float *gradient,
half *variable,
half *accumulation,
half *variable_out,
half *accumulation_out) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
if (update_slots) {
accumulation[i] += __float2half(gradient[i]) * __float2half(gradient[i]);
accumulation_out[i] = accumulation[i];
}
variable[i] -= __float2half(learning_rate[0]) * __float2half(gradient[i]) / SqrtFunc(accumulation[i]);
variable_out[i] = variable[i];
}
}
template <typename T, typename S, typename G>
void ApplyAdagrad(const size_t size,
const bool update_slots,
const S *learning_rate,
const G *gradient,
T *variable,
T *accumulation,
T *variable_out,
T *accumulation_out,
cudaStream_t cuda_stream) {
ApplyAdagradKernel<<< GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
size, update_slots, learning_rate, gradient, variable, accumulation, variable_out, accumulation_out);
}
template void ApplyAdagrad<float, float, float>(const size_t size,
const bool update_slots, const bool update_slots,
const float *learning_rate, const float *learning_rate,
const float *gradient, const float *gradient,
float *variable, float *variable,
float *accumulation, float *accumulation,
float *variable_out,
float *accumulation_out,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void ApplyAdagrad<half>(const size_t size, template void ApplyAdagrad<half, half, half>(const size_t size,
const bool update_slots, const bool update_slots,
const half *learning_rate, const half *learning_rate,
const half *gradient, const half *gradient,
half *variable, half *variable,
half *accumulation, half *accumulation,
half *variable_out,
half *accumulation_out,
cudaStream_t cuda_stream);
template void ApplyAdagrad<half, float, half>(const size_t size,
const bool update_slots,
const float *learning_rate,
const half *gradient,
half *variable,
half *accumulation,
half *variable_out,
half *accumulation_out,
cudaStream_t cuda_stream);
template void ApplyAdagrad<float, float, half>(const size_t size,
const bool update_slots,
const float *learning_rate,
const half *gradient,
float *variable,
float *accumulation,
float *variable_out,
float *accumulation_out,
cudaStream_t cuda_stream);
template void ApplyAdagrad<float, half, float>(const size_t size,
const bool update_slots,
const half *learning_rate,
const float *gradient,
float *variable,
float *accumulation,
float *variable_out,
float *accumulation_out,
cudaStream_t cuda_stream);
template void ApplyAdagrad<half, float, float>(const size_t size,
const bool update_slots,
const float *learning_rate,
const float *gradient,
half *variable,
half *accumulation,
half *variable_out,
half *accumulation_out,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);

View File

@ -18,13 +18,15 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_ #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_
#include "runtime/device/gpu/cuda_common.h" #include "runtime/device/gpu/cuda_common.h"
template <typename T> template <typename T, typename S, typename G>
void ApplyAdagrad(const size_t size, void ApplyAdagrad(const size_t size,
const bool update_slots, const bool update_slots,
const T *learning_rate, const S *learning_rate,
const T *gradient, const G *gradient,
T *variable, T *variable,
T *accumulation, T *accumulation,
T *variable_out,
T *accumulation_out,
cudaStream_t stream); cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_ #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_

View File

@ -18,23 +18,59 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
MS_REG_GPU_KERNEL_ONE(ApplyAdagrad, MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
KernelAttr() KernelAttr()
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32), .AddOutputAttr(kNumberTypeFloat32),
AdagradGpuKernel, float) AdagradGpuKernel, float, float, float)
MS_REG_GPU_KERNEL_ONE(ApplyAdagrad, MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
KernelAttr() KernelAttr()
.AddInputAttr(kNumberTypeFloat16) .AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16) .AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16) .AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16) .AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16), .AddOutputAttr(kNumberTypeFloat16),
AdagradGpuKernel, half) AdagradGpuKernel, half, half, half)
MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
AdagradGpuKernel, half, float, half)
MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
AdagradGpuKernel, float, float, half)
MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
AdagradGpuKernel, float, half, float)
MS_REG_GPU_KERNEL_THREE(ApplyAdagrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
AdagradGpuKernel, half, float, float)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -24,7 +24,7 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
template <typename T> template <typename T, typename S, typename G>
class AdagradGpuKernel : public GpuKernel { class AdagradGpuKernel : public GpuKernel {
public: public:
AdagradGpuKernel() AdagradGpuKernel()
@ -36,6 +36,19 @@ class AdagradGpuKernel : public GpuKernel {
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *variable = GetDeviceAddress<T>(inputs, 0);
T *accumulation = GetDeviceAddress<T>(inputs, 1);
S *learning_rate = GetDeviceAddress<S>(inputs, 2);
G *gradient = GetDeviceAddress<G>(inputs, 3);
T *variable_out = GetDeviceAddress<T>(outputs, 0);
T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
ApplyAdagrad(inputs[0]->size / sizeof(T), update_slots, learning_rate, gradient, variable, accumulation,
variable_out, accumulation_out, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
update_slots = AnfAlgo::GetNodeAttr<bool>(kernel_node, "update_slots"); update_slots = AnfAlgo::GetNodeAttr<bool>(kernel_node, "update_slots");
@ -45,47 +58,35 @@ class AdagradGpuKernel : public GpuKernel {
} }
variable_size_ = sizeof(T); variable_size_ = sizeof(T);
accumulation_size_ = sizeof(T); accumulation_size_ = sizeof(T);
learning_rate_size_ = sizeof(T); learning_rate_size_ = sizeof(S);
gradient_size_ = sizeof(T); gradient_size_ = sizeof(G);
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
for (size_t i = 0; i < variable_shape.size(); i++) { for (size_t i = 0; i < variable_shape.size(); i++) {
variable_size_ *= variable_shape[i]; variable_size_ *= variable_shape[i];
} }
auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
for (size_t i = 0; i < accumulation_shape.size(); i++) { for (size_t i = 0; i < accumulation_shape.size(); i++) {
accumulation_size_ *= accumulation_shape[i]; accumulation_size_ *= accumulation_shape[i];
} }
auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
for (size_t i = 0; i < gradient_shape.size(); i++) { for (size_t i = 0; i < gradient_shape.size(); i++) {
gradient_size_ *= gradient_shape[i]; gradient_size_ *= gradient_shape[i];
} }
InitSizeLists(); InitSizeLists();
return true; return true;
} }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
void *stream_ptr) override {
T *variable = GetDeviceAddress<T>(inputs, 0);
T *accumulation = GetDeviceAddress<T>(inputs, 1);
T *learning_rate = GetDeviceAddress<T>(inputs, 2);
T *gradient = GetDeviceAddress<T>(inputs, 3);
ApplyAdagrad(inputs[0]->size / sizeof(T), update_slots, learning_rate, gradient, variable, accumulation,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
protected: protected:
void InitSizeLists() override { void InitSizeLists() override {
input_size_list_.push_back(variable_size_); input_size_list_.push_back(variable_size_);
input_size_list_.push_back(accumulation_size_); input_size_list_.push_back(accumulation_size_);
input_size_list_.push_back(learning_rate_size_); input_size_list_.push_back(learning_rate_size_);
input_size_list_.push_back(gradient_size_); input_size_list_.push_back(gradient_size_);
output_size_list_.push_back(0); output_size_list_.push_back(variable_size_);
output_size_list_.push_back(0); output_size_list_.push_back(accumulation_size_);
} }
private: private: