optimize lamb
This commit is contained in:
parent
8167c4dbe4
commit
ab78aa86ee
|
@ -24,7 +24,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
// Lamb's inputs: param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient, decay_flag
|
// Lamb's inputs: param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient
|
||||||
constexpr size_t kParamIndex = 1;
|
constexpr size_t kParamIndex = 1;
|
||||||
constexpr size_t kMIndex = 2;
|
constexpr size_t kMIndex = 2;
|
||||||
constexpr size_t kVIndex = 3;
|
constexpr size_t kVIndex = 3;
|
||||||
|
@ -35,10 +35,9 @@ constexpr size_t kEpsilonIndex = 7;
|
||||||
constexpr size_t kWeightDecayIndex = 8;
|
constexpr size_t kWeightDecayIndex = 8;
|
||||||
constexpr size_t kGlobalStepIndex = 9;
|
constexpr size_t kGlobalStepIndex = 9;
|
||||||
constexpr size_t kGradientIndex = 10;
|
constexpr size_t kGradientIndex = 10;
|
||||||
constexpr size_t kDecayFlagIndex = 11;
|
constexpr size_t kUMonadIndex = 11;
|
||||||
constexpr size_t kUMonadIndex = 12;
|
constexpr size_t kLambInputNum = 10;
|
||||||
constexpr size_t kLambInputNum = 11;
|
constexpr size_t kLambInputNumWithUMonad = 11;
|
||||||
constexpr size_t kLambInputNumWithUMonad = 12;
|
|
||||||
constexpr size_t kLambApplyOptimizerAssignOutputNum = 3;
|
constexpr size_t kLambApplyOptimizerAssignOutputNum = 3;
|
||||||
constexpr size_t kLambApplyOptimizerAssignUpdateIndex = 0;
|
constexpr size_t kLambApplyOptimizerAssignUpdateIndex = 0;
|
||||||
|
|
||||||
|
@ -246,7 +245,7 @@ const AnfNodePtr LambFission::Process(const FuncGraphPtr &graph, const AnfNodePt
|
||||||
auto new_global_step = CreateCastNode(graph, global_step_node, kNumberTypeFloat32);
|
auto new_global_step = CreateCastNode(graph, global_step_node, kNumberTypeFloat32);
|
||||||
|
|
||||||
// cast decay flag to float32
|
// decay flag is now a constant 1.0 (no decay_flag input; callers pass weight_decay = 0.0 to disable decay)
|
||||||
auto weight_decay_flag = CreateCastNode(graph, ori_inputs[kDecayFlagIndex], kNumberTypeFloat32);
|
auto weight_decay_flag = CreateValueNode(graph, 1.0);
|
||||||
|
|
||||||
auto num_one = CreateValueNode(graph, 1.0);
|
auto num_one = CreateValueNode(graph, 1.0);
|
||||||
// create 1-beta1
|
// create 1-beta1
|
||||||
|
|
|
@ -21,9 +21,9 @@ const int32_t kSqareNum = 2;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
|
__global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
|
||||||
const float *epsilon, const T *decay, const int32_t *global_step,
|
const float *epsilon, const float *decay, const int32_t *global_step,
|
||||||
const T *gradient, const bool *decay_flag, float *update, float *var_float,
|
const T *gradient, float *update, float *var_float, float *grad_float,
|
||||||
float *grad_float, float *g_hat_var) {
|
float *g_hat_var) {
|
||||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||||
float next_m = (beta1[0] * m[i] + (1 - beta1[0]) * gradient[i]);
|
float next_m = (beta1[0] * m[i] + (1 - beta1[0]) * gradient[i]);
|
||||||
float next_v = (beta2[0] * v[i] + (1 - beta2[0]) * pow(gradient[i], kSqareNum));
|
float next_v = (beta2[0] * v[i] + (1 - beta2[0]) * pow(gradient[i], kSqareNum));
|
||||||
|
@ -33,9 +33,7 @@ __global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v,
|
||||||
grad_float[i] = gradient[i];
|
grad_float[i] = gradient[i];
|
||||||
g_hat_var[i] = (next_mm / sqrt(next_vv + epsilon[0]) + decay[0] * variable[i]);
|
g_hat_var[i] = (next_mm / sqrt(next_vv + epsilon[0]) + decay[0] * variable[i]);
|
||||||
update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
|
update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
|
||||||
if (decay_flag[0]) {
|
update[i] += decay[0] * variable[i];
|
||||||
update[i] += decay[0] * variable[i];
|
|
||||||
}
|
|
||||||
m[i] = next_m;
|
m[i] = next_m;
|
||||||
v[i] = next_v;
|
v[i] = next_v;
|
||||||
}
|
}
|
||||||
|
@ -43,13 +41,13 @@ __global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v,
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
__global__ void ApplyLambEralyKernel(const size_t size, half *variable, half *m, half *v, const float *beta1,
|
__global__ void ApplyLambEralyKernel(const size_t size, half *variable, half *m, half *v, const float *beta1,
|
||||||
const float *beta2, const float *epsilon, const half *decay,
|
const float *beta2, const float *epsilon, const float *decay,
|
||||||
const int32_t *global_step, const half *gradient, const bool *decay_flag,
|
const int32_t *global_step, const half *gradient, float *update, float *var_float,
|
||||||
float *update, float *var_float, float *grad_float, float *g_hat_var) {
|
float *grad_float, float *g_hat_var) {
|
||||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||||
float float_gradient = __half2float(gradient[i]);
|
float float_gradient = __half2float(gradient[i]);
|
||||||
float float_var = __half2float(variable[i]);
|
float float_var = __half2float(variable[i]);
|
||||||
float float_decay = __half2float(decay[0]);
|
float float_decay = decay[0];
|
||||||
|
|
||||||
float next_m = (beta1[0] * __half2float(m[i]) + (1 - beta1[0]) * float_gradient);
|
float next_m = (beta1[0] * __half2float(m[i]) + (1 - beta1[0]) * float_gradient);
|
||||||
float next_v = (beta2[0] * __half2float(v[i]) + (1 - beta2[0]) * pow(float_gradient, kSqareNum));
|
float next_v = (beta2[0] * __half2float(v[i]) + (1 - beta2[0]) * pow(float_gradient, kSqareNum));
|
||||||
|
@ -59,16 +57,14 @@ __global__ void ApplyLambEralyKernel(const size_t size, half *variable, half *m,
|
||||||
grad_float[i] = float_gradient;
|
grad_float[i] = float_gradient;
|
||||||
g_hat_var[i] = next_mm / sqrt(next_vv + epsilon[0]) + float_decay * float_var;
|
g_hat_var[i] = next_mm / sqrt(next_vv + epsilon[0]) + float_decay * float_var;
|
||||||
update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
|
update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
|
||||||
if (decay_flag[0]) {
|
update[i] += float_decay * float_var;
|
||||||
update[i] += float_decay * float_var;
|
|
||||||
}
|
|
||||||
m[i] = __float2half(next_m);
|
m[i] = __float2half(next_m);
|
||||||
v[i] = __float2half(next_v);
|
v[i] = __float2half(next_v);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const T *lr, const float *update,
|
__global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const float *lr, const float *update,
|
||||||
const float *trust_ratio) {
|
const float *trust_ratio) {
|
||||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||||
variable[i] = variable[i] - trust_ratio[0] * lr[0] * update[i];
|
variable[i] = variable[i] - trust_ratio[0] * lr[0] * update[i];
|
||||||
|
@ -76,25 +72,24 @@ __global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const T
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
__global__ void ApplyLambAfterNormKernel(const size_t size, half *variable, const half *lr, const float *update,
|
__global__ void ApplyLambAfterNormKernel(const size_t size, half *variable, const float *lr, const float *update,
|
||||||
const float *trust_ratio) {
|
const float *trust_ratio) {
|
||||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||||
variable[i] = __float2half(__half2float(variable[i]) - trust_ratio[0] * __half2float(lr[0]) * update[i]);
|
variable[i] = __float2half(__half2float(variable[i]) - trust_ratio[0] * lr[0] * update[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
|
void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
|
||||||
const float *epsilon, const T *decay, const int32_t *global_step, const T *gradient,
|
const float *epsilon, const float *decay, const int32_t *global_step, const T *gradient,
|
||||||
const bool *decay_flag, float *update, float *var_float, float *grad_float, float *g_hat_var,
|
float *update, float *var_float, float *grad_float, float *g_hat_var, cudaStream_t cuda_stream) {
|
||||||
cudaStream_t cuda_stream) {
|
|
||||||
ApplyLambEralyKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, m, v, beta1, beta2, epsilon,
|
ApplyLambEralyKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, m, v, beta1, beta2, epsilon,
|
||||||
decay, global_step, gradient, decay_flag,
|
decay, global_step, gradient, update,
|
||||||
update, var_float, grad_float, g_hat_var);
|
var_float, grad_float, g_hat_var);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr, const float *update,
|
CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const float *lr, const float *update,
|
||||||
const float *trust_ratio, cudaStream_t cuda_stream) {
|
const float *trust_ratio, cudaStream_t cuda_stream) {
|
||||||
ApplyLambAfterNormKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, lr, update, trust_ratio);
|
ApplyLambAfterNormKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, lr, update, trust_ratio);
|
||||||
}
|
}
|
||||||
|
@ -102,21 +97,20 @@ CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr,
|
||||||
template CUDA_LIB_EXPORT void ApplyLambEraly<float>(const size_t size, float *variable, float *m, float *v,
|
template CUDA_LIB_EXPORT void ApplyLambEraly<float>(const size_t size, float *variable, float *m, float *v,
|
||||||
const float *beta1, const float *beta2, const float *epsilon,
|
const float *beta1, const float *beta2, const float *epsilon,
|
||||||
const float *decay, const int32_t *global_step,
|
const float *decay, const int32_t *global_step,
|
||||||
const float *gradient, const bool *decay_flag, float *update,
|
const float *gradient, float *update, float *w_square_ptr,
|
||||||
float *w_square_ptr, float *g_square_ptr, float *g_hat_square_ptr,
|
float *g_square_ptr, float *g_hat_square_ptr,
|
||||||
cudaStream_t cuda_stream);
|
cudaStream_t cuda_stream);
|
||||||
|
|
||||||
template CUDA_LIB_EXPORT void ApplyLambEraly<half>(const size_t size, half *variable, half *m, half *v,
|
template CUDA_LIB_EXPORT void ApplyLambEraly<half>(const size_t size, half *variable, half *m, half *v,
|
||||||
const float *beta1, const float *beta2, const float *epsilon,
|
const float *beta1, const float *beta2, const float *epsilon,
|
||||||
const half *decay, const int32_t *global_step, const half *gradient,
|
const float *decay, const int32_t *global_step, const half *gradient,
|
||||||
const bool *decay_flag, float *update, float *w_square_ptr,
|
float *update, float *w_square_ptr, float *g_square_ptr,
|
||||||
float *g_square_ptr, float *g_hat_square_ptr,
|
float *g_hat_square_ptr, cudaStream_t cuda_stream);
|
||||||
cudaStream_t cuda_stream);
|
|
||||||
|
|
||||||
template CUDA_LIB_EXPORT void ApplyLambLater<float>(const size_t size, float *variable, const float *lr,
|
template CUDA_LIB_EXPORT void ApplyLambLater<float>(const size_t size, float *variable, const float *lr,
|
||||||
const float *update, const float *trust_ratio,
|
const float *update, const float *trust_ratio,
|
||||||
cudaStream_t cuda_stream);
|
cudaStream_t cuda_stream);
|
||||||
|
|
||||||
template CUDA_LIB_EXPORT void ApplyLambLater<half>(const size_t size, half *variable, const half *lr,
|
template CUDA_LIB_EXPORT void ApplyLambLater<half>(const size_t size, half *variable, const float *lr,
|
||||||
const float *update, const float *trust_ratio,
|
const float *update, const float *trust_ratio,
|
||||||
cudaStream_t cuda_stream);
|
cudaStream_t cuda_stream);
|
||||||
|
|
|
@ -19,12 +19,12 @@
|
||||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
|
||||||
template <typename T>
|
template <typename T>
|
||||||
CUDA_LIB_EXPORT void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
|
CUDA_LIB_EXPORT void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
|
||||||
const float *epsilon, const T *decay, const int32_t *global_step, const T *gradient,
|
const float *epsilon, const float *decay, const int32_t *global_step,
|
||||||
const bool *decay_flag, float *update, float *var_float, float *grad_float,
|
const T *gradient, float *update, float *var_float, float *grad_float,
|
||||||
float *g_hat_var, cudaStream_t cuda_stream);
|
float *g_hat_var, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr, const float *update,
|
CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const float *lr, const float *update,
|
||||||
const float *trust_ratio, cudaStream_t cuda_stream);
|
const float *trust_ratio, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LAMB_IMPL_CUH_
|
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LAMB_IMPL_CUH_
|
||||||
|
|
|
@ -30,7 +30,6 @@ MS_REG_GPU_KERNEL_ONE(Lamb,
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
.AddInputAttr(kNumberTypeBool)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
.AddOutputAttr(kNumberTypeFloat32),
|
||||||
LambGpuKernelMod, float)
|
LambGpuKernelMod, float)
|
||||||
MS_REG_GPU_KERNEL_ONE(Lamb,
|
MS_REG_GPU_KERNEL_ONE(Lamb,
|
||||||
|
@ -45,7 +44,6 @@ MS_REG_GPU_KERNEL_ONE(Lamb,
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
.AddInputAttr(kNumberTypeBool)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat16),
|
.AddOutputAttr(kNumberTypeFloat16),
|
||||||
LambGpuKernelMod, half)
|
LambGpuKernelMod, half)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
|
|
|
@ -28,7 +28,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
constexpr size_t INPUT_NUM = 11;
|
constexpr size_t INPUT_NUM = 10;
|
||||||
constexpr size_t kArgMaxDim = 7;
|
constexpr size_t kArgMaxDim = 7;
|
||||||
constexpr float ten = 10;
|
constexpr float ten = 10;
|
||||||
|
|
||||||
|
@ -43,7 +43,6 @@ constexpr size_t kEpsilonIndex = 6;
|
||||||
constexpr size_t kWeightDecayIndex = 7;
|
constexpr size_t kWeightDecayIndex = 7;
|
||||||
constexpr size_t kGlobalStepIndex = 8;
|
constexpr size_t kGlobalStepIndex = 8;
|
||||||
constexpr size_t kGradIndex = 9;
|
constexpr size_t kGradIndex = 9;
|
||||||
constexpr size_t kDecayFlagIndex = 10;
|
|
||||||
|
|
||||||
// workspaces param index
|
// workspaces param index
|
||||||
constexpr size_t kUpdateIndex = 0;
|
constexpr size_t kUpdateIndex = 0;
|
||||||
|
@ -72,21 +71,20 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
|
||||||
T *variable = GetDeviceAddress<T>(inputs, kVarIndex);
|
T *variable = GetDeviceAddress<T>(inputs, kVarIndex);
|
||||||
T *m = GetDeviceAddress<T>(inputs, kMIndex);
|
T *m = GetDeviceAddress<T>(inputs, kMIndex);
|
||||||
T *v = GetDeviceAddress<T>(inputs, kVIndex);
|
T *v = GetDeviceAddress<T>(inputs, kVIndex);
|
||||||
T *learning_rate = GetDeviceAddress<T>(inputs, kLearningRateIndex);
|
float *learning_rate = GetDeviceAddress<float>(inputs, kLearningRateIndex);
|
||||||
float *beta1 = GetDeviceAddress<float>(inputs, kBeta1Index);
|
float *beta1 = GetDeviceAddress<float>(inputs, kBeta1Index);
|
||||||
float *beta2 = GetDeviceAddress<float>(inputs, kBeta2Index);
|
float *beta2 = GetDeviceAddress<float>(inputs, kBeta2Index);
|
||||||
float *epsilon = GetDeviceAddress<float>(inputs, kEpsilonIndex);
|
float *epsilon = GetDeviceAddress<float>(inputs, kEpsilonIndex);
|
||||||
T *decay = GetDeviceAddress<T>(inputs, kWeightDecayIndex);
|
float *decay = GetDeviceAddress<float>(inputs, kWeightDecayIndex);
|
||||||
int32_t *global_step = GetDeviceAddress<int32_t>(inputs, kGlobalStepIndex);
|
int32_t *global_step = GetDeviceAddress<int32_t>(inputs, kGlobalStepIndex);
|
||||||
T *gradient = GetDeviceAddress<T>(inputs, kGradIndex);
|
T *gradient = GetDeviceAddress<T>(inputs, kGradIndex);
|
||||||
bool *decay_flag = GetDeviceAddress<bool>(inputs, kDecayFlagIndex);
|
|
||||||
float *update = GetDeviceAddress<float>(workspaces, kUpdateIndex);
|
float *update = GetDeviceAddress<float>(workspaces, kUpdateIndex);
|
||||||
float *var_float = GetDeviceAddress<float>(workspaces, kVarFloatIndex);
|
float *var_float = GetDeviceAddress<float>(workspaces, kVarFloatIndex);
|
||||||
float *grad_float = GetDeviceAddress<float>(workspaces, kGradFloatIndex);
|
float *grad_float = GetDeviceAddress<float>(workspaces, kGradFloatIndex);
|
||||||
float *g_hat_var = GetDeviceAddress<float>(workspaces, kGHatValIndex);
|
float *g_hat_var = GetDeviceAddress<float>(workspaces, kGHatValIndex);
|
||||||
|
|
||||||
ApplyLambEraly(inputs[0]->size / sizeof(T), variable, m, v, beta1, beta2, epsilon, decay, global_step, gradient,
|
ApplyLambEraly(inputs[0]->size / sizeof(T), variable, m, v, beta1, beta2, epsilon, decay, global_step, gradient,
|
||||||
decay_flag, update, var_float, grad_float, g_hat_var, reinterpret_cast<cudaStream_t>(stream_ptr));
|
update, var_float, grad_float, g_hat_var, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
|
|
||||||
float trust_ratio{0};
|
float trust_ratio{0};
|
||||||
CalcTrustRatio(workspaces, var_float, grad_float, g_hat_var, stream_ptr, &trust_ratio);
|
CalcTrustRatio(workspaces, var_float, grad_float, g_hat_var, stream_ptr, &trust_ratio);
|
||||||
|
@ -176,7 +174,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
|
||||||
input_size_list_.push_back(decay_size_);
|
input_size_list_.push_back(decay_size_);
|
||||||
input_size_list_.push_back(global_step_size_);
|
input_size_list_.push_back(global_step_size_);
|
||||||
input_size_list_.push_back(gradient_size_);
|
input_size_list_.push_back(gradient_size_);
|
||||||
input_size_list_.push_back(decay_flag_size_);
|
|
||||||
workspace_size_list_.push_back(update_size_);
|
workspace_size_list_.push_back(update_size_);
|
||||||
workspace_size_list_.push_back(var_float_size_);
|
workspace_size_list_.push_back(var_float_size_);
|
||||||
workspace_size_list_.push_back(grad_float_size_);
|
workspace_size_list_.push_back(grad_float_size_);
|
||||||
|
@ -253,7 +250,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
|
||||||
decay_size_ = sizeof(T);
|
decay_size_ = sizeof(T);
|
||||||
global_step_size_ = sizeof(int32_t);
|
global_step_size_ = sizeof(int32_t);
|
||||||
gradient_size_ = sizeof(T);
|
gradient_size_ = sizeof(T);
|
||||||
decay_flag_size_ = sizeof(bool);
|
|
||||||
update_size_ = sizeof(float);
|
update_size_ = sizeof(float);
|
||||||
var_float_size_ = sizeof(float);
|
var_float_size_ = sizeof(float);
|
||||||
grad_float_size_ = sizeof(float);
|
grad_float_size_ = sizeof(float);
|
||||||
|
@ -386,7 +382,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
|
||||||
size_t decay_size_{0};
|
size_t decay_size_{0};
|
||||||
size_t global_step_size_{0};
|
size_t global_step_size_{0};
|
||||||
size_t gradient_size_{0};
|
size_t gradient_size_{0};
|
||||||
size_t decay_flag_size_{0};
|
|
||||||
size_t update_size_{0};
|
size_t update_size_{0};
|
||||||
size_t var_float_size_{0};
|
size_t var_float_size_{0};
|
||||||
size_t grad_float_size_{0};
|
size_t grad_float_size_{0};
|
||||||
|
|
|
@ -35,15 +35,12 @@ TypePtr LambInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
|
||||||
auto decay_type = input_args[kInputIndex7]->BuildType();
|
auto decay_type = input_args[kInputIndex7]->BuildType();
|
||||||
auto global_step_type = input_args[kInputIndex8]->BuildType();
|
auto global_step_type = input_args[kInputIndex8]->BuildType();
|
||||||
auto grad_type = input_args[kInputIndex9]->BuildType();
|
auto grad_type = input_args[kInputIndex9]->BuildType();
|
||||||
auto decay_flag_type = input_args[kInputIndex10]->BuildType();
|
|
||||||
|
|
||||||
std::map<std::string, TypePtr> type_dict;
|
std::map<std::string, TypePtr> type_dict;
|
||||||
type_dict.emplace("var", var_type);
|
type_dict.emplace("var", var_type);
|
||||||
type_dict.emplace("m", m_type);
|
type_dict.emplace("m", m_type);
|
||||||
type_dict.emplace("v", v_type);
|
type_dict.emplace("v", v_type);
|
||||||
type_dict.emplace("grad", grad_type);
|
type_dict.emplace("grad", grad_type);
|
||||||
type_dict.emplace("lr", lr_type);
|
|
||||||
type_dict.emplace("decay", decay_type);
|
|
||||||
std::set<TypePtr> num_type = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
|
std::set<TypePtr> num_type = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
|
||||||
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
|
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(type_dict, num_type, prim_name);
|
(void)CheckAndConvertUtils::CheckTensorTypeSame(type_dict, num_type, prim_name);
|
||||||
|
@ -51,12 +48,12 @@ TypePtr LambInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
|
||||||
type_dict1.emplace("beta1", beta1_type);
|
type_dict1.emplace("beta1", beta1_type);
|
||||||
type_dict1.emplace("beta2", beta2_type);
|
type_dict1.emplace("beta2", beta2_type);
|
||||||
type_dict1.emplace("epsilon", epsilon_type);
|
type_dict1.emplace("epsilon", epsilon_type);
|
||||||
std::set<TypePtr> float_set = {kFloat16, kFloat32};
|
type_dict1.emplace("lr", lr_type);
|
||||||
|
type_dict1.emplace("decay", decay_type);
|
||||||
|
std::set<TypePtr> float_set = {kFloat32};
|
||||||
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(type_dict1, float_set, prim_name, true);
|
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(type_dict1, float_set, prim_name, true);
|
||||||
|
|
||||||
std::set<TypePtr> bool_set = {kBool};
|
|
||||||
(void)CheckAndConvertUtils::CheckTypeValid("global_step", global_step_type, num_type, prim_name);
|
(void)CheckAndConvertUtils::CheckTypeValid("global_step", global_step_type, num_type, prim_name);
|
||||||
(void)CheckAndConvertUtils::CheckTypeValid("decay_flag", decay_flag_type, bool_set, prim_name);
|
|
||||||
|
|
||||||
return var_type;
|
return var_type;
|
||||||
}
|
}
|
||||||
|
@ -97,7 +94,7 @@ AbstractBasePtr LambInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
||||||
for (auto item : input_args) {
|
for (auto item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
const int64_t kInputNum = 11;
|
const int64_t kInputNum = 10;
|
||||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
|
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
|
||||||
auto infer_type = LambInferType(primitive, input_args);
|
auto infer_type = LambInferType(primitive, input_args);
|
||||||
auto infer_shape = LambInferShape(primitive, input_args);
|
auto infer_shape = LambInferShape(primitive, input_args);
|
||||||
|
|
|
@ -15,8 +15,6 @@
|
||||||
"""lamb"""
|
"""lamb"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.common import dtype as mstype
|
|
||||||
from mindspore.ops import operations as P
|
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops.operations import _inner_ops as inner
|
from mindspore.ops.operations import _inner_ops as inner
|
||||||
|
@ -26,11 +24,6 @@ from mindspore._checkparam import Rel
|
||||||
from .optimizer import Optimizer
|
from .optimizer import Optimizer
|
||||||
from .optimizer import opt_init_args_register
|
from .optimizer import opt_init_args_register
|
||||||
|
|
||||||
from .. import layer
|
|
||||||
|
|
||||||
|
|
||||||
num_one = Tensor(np.ones([1]), mstype.float32)
|
|
||||||
|
|
||||||
_lamb_opt = C.MultitypeFuncGraph("lamb_opt")
|
_lamb_opt = C.MultitypeFuncGraph("lamb_opt")
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,89 +33,6 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
|
||||||
"""
|
"""
|
||||||
Update parameters.
|
Update parameters.
|
||||||
|
|
||||||
Args:
|
|
||||||
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
|
|
||||||
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
|
|
||||||
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
|
|
||||||
lr (Tensor): Learning rate.
|
|
||||||
weight_decay (numbers.Number): Weight decay. Should be equal to or greater than 0.
|
|
||||||
global_step (Tensor): Global step.
|
|
||||||
param (Tensor): Parameters.
|
|
||||||
m (Tensor): m value of parameters.
|
|
||||||
v (Tensor): v value of parameters.
|
|
||||||
gradient (Tensor): Gradient of parameters.
|
|
||||||
decay_flag (bool): Specifies whether param update with weight decay.
|
|
||||||
optim_filter(bool): Applies parameter update or not.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor, the new value of v after updating.
|
|
||||||
"""
|
|
||||||
if optim_filter:
|
|
||||||
op_mul = P.Mul()
|
|
||||||
op_sqrt = P.Sqrt()
|
|
||||||
op_rsqrt = P.Rsqrt()
|
|
||||||
op_square = P.Square()
|
|
||||||
op_cast = P.Cast()
|
|
||||||
op_reshape = P.Reshape()
|
|
||||||
op_shape = P.Shape()
|
|
||||||
op_pow = P.Pow()
|
|
||||||
op_norm = layer.Norm()
|
|
||||||
op_select = P.Select()
|
|
||||||
op_greater = P.Greater()
|
|
||||||
op_fill = P.Fill()
|
|
||||||
op_dtype = P.DType()
|
|
||||||
|
|
||||||
param_fp32 = op_cast(param, mstype.float32)
|
|
||||||
m_fp32 = op_cast(m, mstype.float32)
|
|
||||||
v_fp32 = op_cast(v, mstype.float32)
|
|
||||||
gradient_fp32 = op_cast(gradient, mstype.float32)
|
|
||||||
|
|
||||||
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)
|
|
||||||
|
|
||||||
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
|
|
||||||
|
|
||||||
next_mm = next_m / (op_cast(num_one, mstype.float32)
|
|
||||||
- op_pow(beta1, op_cast(global_step, mstype.float32)))
|
|
||||||
next_vv = next_v / (op_cast(num_one, mstype.float32) -
|
|
||||||
op_pow(beta2, op_cast(global_step, mstype.float32)))
|
|
||||||
w_norm = op_norm(param_fp32)
|
|
||||||
g_norm = op_norm(gradient_fp32)
|
|
||||||
|
|
||||||
g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay * param_fp32)
|
|
||||||
zeros = F.zeros_like(w_norm)
|
|
||||||
ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
|
|
||||||
trust_ratio = op_select(
|
|
||||||
op_greater(w_norm, zeros),
|
|
||||||
op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
|
|
||||||
ones)
|
|
||||||
tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
|
|
||||||
trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
|
|
||||||
update = next_mm / (op_sqrt(next_vv) + eps)
|
|
||||||
|
|
||||||
if decay_flag:
|
|
||||||
update = update + op_mul(weight_decay, param_fp32)
|
|
||||||
|
|
||||||
update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
|
|
||||||
|
|
||||||
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
|
|
||||||
|
|
||||||
next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
|
|
||||||
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
|
|
||||||
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
|
|
||||||
|
|
||||||
return op_cast(next_param, F.dtype(param))
|
|
||||||
return gradient
|
|
||||||
|
|
||||||
_lamb_opt_ascend = C.MultitypeFuncGraph("lamb_opt_ascend")
|
|
||||||
|
|
||||||
|
|
||||||
@_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
|
|
||||||
"Tensor", "Bool", "Bool")
|
|
||||||
def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag,
|
|
||||||
optim_filter):
|
|
||||||
"""
|
|
||||||
Update parameters function when device target is ascend.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
|
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
|
||||||
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
|
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
|
||||||
|
@ -142,9 +52,13 @@ def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, para
|
||||||
"""
|
"""
|
||||||
if optim_filter:
|
if optim_filter:
|
||||||
op_lamb = inner.Lamb()
|
op_lamb = inner.Lamb()
|
||||||
return op_lamb(param, m, v, lr, beta1, beta2, eps, weight_decay, global_step,
|
if decay_flag:
|
||||||
gradient, decay_flag)
|
ret = op_lamb(param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient)
|
||||||
return gradient
|
else:
|
||||||
|
ret = op_lamb(param, m, v, lr, beta1, beta2, eps, 0.0, global_step, gradient)
|
||||||
|
else:
|
||||||
|
ret = gradient
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def _check_param_value(beta1, beta2, eps, prim_name):
|
def _check_param_value(beta1, beta2, eps, prim_name):
|
||||||
|
@ -345,7 +259,7 @@ class Lamb(Optimizer):
|
||||||
lr = self.get_lr()
|
lr = self.get_lr()
|
||||||
if not self.is_dynamic_lr_or_weight_decay():
|
if not self.is_dynamic_lr_or_weight_decay():
|
||||||
self.assignadd(self.global_step, self.global_step_increase_tensor)
|
self.assignadd(self.global_step, self.global_step_increase_tensor)
|
||||||
lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
|
lamb_opt = _lamb_opt
|
||||||
gradients = self.flatten_gradients(gradients)
|
gradients = self.flatten_gradients(gradients)
|
||||||
gradients = self.gradients_centralization(gradients)
|
gradients = self.gradients_centralization(gradients)
|
||||||
if self.is_group:
|
if self.is_group:
|
||||||
|
|
|
@ -300,7 +300,6 @@ class Lamb(PrimitiveWithInfer):
|
||||||
Default: 0.0.
|
Default: 0.0.
|
||||||
- **global_step** (Tensor) - Tensor to record current global step.
|
- **global_step** (Tensor) - Tensor to record current global step.
|
||||||
- **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.
|
- **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.
|
||||||
- **decay_flag** (bool) - Specifies whether param update with weight decay.
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor, the updated parameters.
|
Tensor, the updated parameters.
|
||||||
- **var** (Tensor) - The same shape and data type as `var`.
|
- **var** (Tensor) - The same shape and data type as `var`.
|
||||||
|
@ -315,19 +314,19 @@ class Lamb(PrimitiveWithInfer):
|
||||||
self.add_prim_attr('side_effect_mem', True)
|
self.add_prim_attr('side_effect_mem', True)
|
||||||
|
|
||||||
def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
|
def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
|
||||||
epsilon_shape, decay_shape, global_step_shape, gradient_shape, decay_flag_shape):
|
epsilon_shape, decay_shape, global_step_shape, gradient_shape):
|
||||||
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
|
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
|
||||||
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
|
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
|
||||||
validator.check("var_shape", var_shape, "gradient_shape", gradient_shape, Rel.EQ, self.name)
|
validator.check("var_shape", var_shape, "gradient_shape", gradient_shape, Rel.EQ, self.name)
|
||||||
return var_shape
|
return var_shape
|
||||||
|
|
||||||
def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
|
def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
|
||||||
epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype, decay_flag_dtype):
|
epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype):
|
||||||
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "lr": lr_dtype, "grad": gradient_dtype,
|
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": gradient_dtype}
|
||||||
"decay": decay_dtype}
|
|
||||||
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
|
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
|
||||||
|
|
||||||
args = {"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
|
args = {"lr": lr_dtype, "decay": decay_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype,
|
||||||
|
"epsilon": epsilon_dtype}
|
||||||
validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
|
validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
|
||||||
return var_dtype
|
return var_dtype
|
||||||
|
|
||||||
|
|
|
@ -142,9 +142,8 @@ class MyLamb(nn.Cell):
|
||||||
self.gradient = Parameter(gradient, name="grad")
|
self.gradient = Parameter(gradient, name="grad")
|
||||||
self.lamb = inner.Lamb()
|
self.lamb = inner.Lamb()
|
||||||
|
|
||||||
def construct(self, beta1, beta2, eps, global_step, lr, weight_decay, decay_flag):
|
def construct(self, beta1, beta2, eps, global_step, lr, weight_decay):
|
||||||
return self.lamb(self.param, self.m, self.v, lr, beta1, beta2, eps, weight_decay, global_step, self.gradient,
|
return self.lamb(self.param, self.m, self.v, lr, beta1, beta2, eps, weight_decay, global_step, self.gradient)
|
||||||
decay_flag)
|
|
||||||
|
|
||||||
|
|
||||||
def test_gpu_net():
|
def test_gpu_net():
|
||||||
|
@ -154,7 +153,7 @@ def test_gpu_net():
|
||||||
Expectation: get the same result when use new lamb kernel and old kernel
|
Expectation: get the same result when use new lamb kernel and old kernel
|
||||||
"""
|
"""
|
||||||
my_lamb = MyLamb(param_val, m_val, v_val, grad_val)
|
my_lamb = MyLamb(param_val, m_val, v_val, grad_val)
|
||||||
my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
|
my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val)
|
||||||
|
|
||||||
lamb_gpu_origin = LambGPUOrigin(param_val, m_val, v_val, grad_val)
|
lamb_gpu_origin = LambGPUOrigin(param_val, m_val, v_val, grad_val)
|
||||||
lamb_gpu_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
|
lamb_gpu_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
|
||||||
|
@ -169,7 +168,7 @@ def test_ascend_net():
|
||||||
Expectation: get the same result when use new lamb kernel and old kernel
|
Expectation: get the same result when use new lamb kernel and old kernel
|
||||||
"""
|
"""
|
||||||
my_lamb = MyLamb(param_val, m_val, v_val, grad_val)
|
my_lamb = MyLamb(param_val, m_val, v_val, grad_val)
|
||||||
my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
|
my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val)
|
||||||
|
|
||||||
lamb_ascend_origin = LambAscendOrigin(param_val, m_val, v_val, grad_val)
|
lamb_ascend_origin = LambAscendOrigin(param_val, m_val, v_val, grad_val)
|
||||||
lamb_ascend_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
|
lamb_ascend_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
|
||||||
|
|
Loading…
Reference in New Issue