optimize lamb

This commit is contained in:
wangchangheng 2022-04-19 16:48:56 +08:00
parent 8167c4dbe4
commit ab78aa86ee
9 changed files with 56 additions and 161 deletions

View File

@ -24,7 +24,7 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
// Lamb's inputs: param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient, decay_flag // Lamb's inputs: param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient
constexpr size_t kParamIndex = 1; constexpr size_t kParamIndex = 1;
constexpr size_t kMIndex = 2; constexpr size_t kMIndex = 2;
constexpr size_t kVIndex = 3; constexpr size_t kVIndex = 3;
@ -35,10 +35,9 @@ constexpr size_t kEpsilonIndex = 7;
constexpr size_t kWeightDecayIndex = 8; constexpr size_t kWeightDecayIndex = 8;
constexpr size_t kGlobalStepIndex = 9; constexpr size_t kGlobalStepIndex = 9;
constexpr size_t kGradientIndex = 10; constexpr size_t kGradientIndex = 10;
constexpr size_t kDecayFlagIndex = 11; constexpr size_t kUMonadIndex = 11;
constexpr size_t kUMonadIndex = 12; constexpr size_t kLambInputNum = 10;
constexpr size_t kLambInputNum = 11; constexpr size_t kLambInputNumWithUMonad = 11;
constexpr size_t kLambInputNumWithUMonad = 12;
constexpr size_t kLambApplyOptimizerAssignOutputNum = 3; constexpr size_t kLambApplyOptimizerAssignOutputNum = 3;
constexpr size_t kLambApplyOptimizerAssignUpdateIndex = 0; constexpr size_t kLambApplyOptimizerAssignUpdateIndex = 0;
@ -246,7 +245,7 @@ const AnfNodePtr LambFission::Process(const FuncGraphPtr &graph, const AnfNodePt
auto new_global_step = CreateCastNode(graph, global_step_node, kNumberTypeFloat32); auto new_global_step = CreateCastNode(graph, global_step_node, kNumberTypeFloat32);
// cast delay flag to float32 // cast delay flag to float32
auto weight_decay_flag = CreateCastNode(graph, ori_inputs[kDecayFlagIndex], kNumberTypeFloat32); auto weight_decay_flag = CreateValueNode(graph, 1.0);
auto num_one = CreateValueNode(graph, 1.0); auto num_one = CreateValueNode(graph, 1.0);
// create 1-beta1 // create 1-beta1

View File

@ -21,9 +21,9 @@ const int32_t kSqareNum = 2;
template <typename T> template <typename T>
__global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2, __global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
const float *epsilon, const T *decay, const int32_t *global_step, const float *epsilon, const float *decay, const int32_t *global_step,
const T *gradient, const bool *decay_flag, float *update, float *var_float, const T *gradient, float *update, float *var_float, float *grad_float,
float *grad_float, float *g_hat_var) { float *g_hat_var) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
float next_m = (beta1[0] * m[i] + (1 - beta1[0]) * gradient[i]); float next_m = (beta1[0] * m[i] + (1 - beta1[0]) * gradient[i]);
float next_v = (beta2[0] * v[i] + (1 - beta2[0]) * pow(gradient[i], kSqareNum)); float next_v = (beta2[0] * v[i] + (1 - beta2[0]) * pow(gradient[i], kSqareNum));
@ -33,9 +33,7 @@ __global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v,
grad_float[i] = gradient[i]; grad_float[i] = gradient[i];
g_hat_var[i] = (next_mm / sqrt(next_vv + epsilon[0]) + decay[0] * variable[i]); g_hat_var[i] = (next_mm / sqrt(next_vv + epsilon[0]) + decay[0] * variable[i]);
update[i] = next_mm / (sqrt(next_vv) - epsilon[0]); update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
if (decay_flag[0]) { update[i] += decay[0] * variable[i];
update[i] += decay[0] * variable[i];
}
m[i] = next_m; m[i] = next_m;
v[i] = next_v; v[i] = next_v;
} }
@ -43,13 +41,13 @@ __global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v,
template <> template <>
__global__ void ApplyLambEralyKernel(const size_t size, half *variable, half *m, half *v, const float *beta1, __global__ void ApplyLambEralyKernel(const size_t size, half *variable, half *m, half *v, const float *beta1,
const float *beta2, const float *epsilon, const half *decay, const float *beta2, const float *epsilon, const float *decay,
const int32_t *global_step, const half *gradient, const bool *decay_flag, const int32_t *global_step, const half *gradient, float *update, float *var_float,
float *update, float *var_float, float *grad_float, float *g_hat_var) { float *grad_float, float *g_hat_var) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
float float_gradient = __half2float(gradient[i]); float float_gradient = __half2float(gradient[i]);
float float_var = __half2float(variable[i]); float float_var = __half2float(variable[i]);
float float_decay = __half2float(decay[0]); float float_decay = decay[0];
float next_m = (beta1[0] * __half2float(m[i]) + (1 - beta1[0]) * float_gradient); float next_m = (beta1[0] * __half2float(m[i]) + (1 - beta1[0]) * float_gradient);
float next_v = (beta2[0] * __half2float(v[i]) + (1 - beta2[0]) * pow(float_gradient, kSqareNum)); float next_v = (beta2[0] * __half2float(v[i]) + (1 - beta2[0]) * pow(float_gradient, kSqareNum));
@ -59,16 +57,14 @@ __global__ void ApplyLambEralyKernel(const size_t size, half *variable, half *m,
grad_float[i] = float_gradient; grad_float[i] = float_gradient;
g_hat_var[i] = next_mm / sqrt(next_vv + epsilon[0]) + float_decay * float_var; g_hat_var[i] = next_mm / sqrt(next_vv + epsilon[0]) + float_decay * float_var;
update[i] = next_mm / (sqrt(next_vv) - epsilon[0]); update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
if (decay_flag[0]) { update[i] += float_decay * float_var;
update[i] += float_decay * float_var;
}
m[i] = __float2half(next_m); m[i] = __float2half(next_m);
v[i] = __float2half(next_v); v[i] = __float2half(next_v);
} }
} }
template <typename T> template <typename T>
__global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const T *lr, const float *update, __global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const float *lr, const float *update,
const float *trust_ratio) { const float *trust_ratio) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
variable[i] = variable[i] - trust_ratio[0] * lr[0] * update[i]; variable[i] = variable[i] - trust_ratio[0] * lr[0] * update[i];
@ -76,25 +72,24 @@ __global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const T
} }
template <> template <>
__global__ void ApplyLambAfterNormKernel(const size_t size, half *variable, const half *lr, const float *update, __global__ void ApplyLambAfterNormKernel(const size_t size, half *variable, const float *lr, const float *update,
const float *trust_ratio) { const float *trust_ratio) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
variable[i] = __float2half(__half2float(variable[i]) - trust_ratio[0] * __half2float(lr[0]) * update[i]); variable[i] = __float2half(__half2float(variable[i]) - trust_ratio[0] * lr[0] * update[i]);
} }
} }
template <typename T> template <typename T>
void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2, void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
const float *epsilon, const T *decay, const int32_t *global_step, const T *gradient, const float *epsilon, const float *decay, const int32_t *global_step, const T *gradient,
const bool *decay_flag, float *update, float *var_float, float *grad_float, float *g_hat_var, float *update, float *var_float, float *grad_float, float *g_hat_var, cudaStream_t cuda_stream) {
cudaStream_t cuda_stream) {
ApplyLambEralyKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, m, v, beta1, beta2, epsilon, ApplyLambEralyKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, m, v, beta1, beta2, epsilon,
decay, global_step, gradient, decay_flag, decay, global_step, gradient, update,
update, var_float, grad_float, g_hat_var); var_float, grad_float, g_hat_var);
} }
template <typename T> template <typename T>
CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr, const float *update, CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const float *lr, const float *update,
const float *trust_ratio, cudaStream_t cuda_stream) { const float *trust_ratio, cudaStream_t cuda_stream) {
ApplyLambAfterNormKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, lr, update, trust_ratio); ApplyLambAfterNormKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, lr, update, trust_ratio);
} }
@ -102,21 +97,20 @@ CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr,
template CUDA_LIB_EXPORT void ApplyLambEraly<float>(const size_t size, float *variable, float *m, float *v, template CUDA_LIB_EXPORT void ApplyLambEraly<float>(const size_t size, float *variable, float *m, float *v,
const float *beta1, const float *beta2, const float *epsilon, const float *beta1, const float *beta2, const float *epsilon,
const float *decay, const int32_t *global_step, const float *decay, const int32_t *global_step,
const float *gradient, const bool *decay_flag, float *update, const float *gradient, float *update, float *w_square_ptr,
float *w_square_ptr, float *g_square_ptr, float *g_hat_square_ptr, float *g_square_ptr, float *g_hat_square_ptr,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyLambEraly<half>(const size_t size, half *variable, half *m, half *v, template CUDA_LIB_EXPORT void ApplyLambEraly<half>(const size_t size, half *variable, half *m, half *v,
const float *beta1, const float *beta2, const float *epsilon, const float *beta1, const float *beta2, const float *epsilon,
const half *decay, const int32_t *global_step, const half *gradient, const float *decay, const int32_t *global_step, const half *gradient,
const bool *decay_flag, float *update, float *w_square_ptr, float *update, float *w_square_ptr, float *g_square_ptr,
float *g_square_ptr, float *g_hat_square_ptr, float *g_hat_square_ptr, cudaStream_t cuda_stream);
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyLambLater<float>(const size_t size, float *variable, const float *lr, template CUDA_LIB_EXPORT void ApplyLambLater<float>(const size_t size, float *variable, const float *lr,
const float *update, const float *trust_ratio, const float *update, const float *trust_ratio,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyLambLater<half>(const size_t size, half *variable, const half *lr, template CUDA_LIB_EXPORT void ApplyLambLater<half>(const size_t size, half *variable, const float *lr,
const float *update, const float *trust_ratio, const float *update, const float *trust_ratio,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);

View File

@ -19,12 +19,12 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T> template <typename T>
CUDA_LIB_EXPORT void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2, CUDA_LIB_EXPORT void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
const float *epsilon, const T *decay, const int32_t *global_step, const T *gradient, const float *epsilon, const float *decay, const int32_t *global_step,
const bool *decay_flag, float *update, float *var_float, float *grad_float, const T *gradient, float *update, float *var_float, float *grad_float,
float *g_hat_var, cudaStream_t cuda_stream); float *g_hat_var, cudaStream_t cuda_stream);
template <typename T> template <typename T>
CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr, const float *update, CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const float *lr, const float *update,
const float *trust_ratio, cudaStream_t cuda_stream); const float *trust_ratio, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LAMB_IMPL_CUH_ #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LAMB_IMPL_CUH_

View File

@ -30,7 +30,6 @@ MS_REG_GPU_KERNEL_ONE(Lamb,
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeFloat32), .AddOutputAttr(kNumberTypeFloat32),
LambGpuKernelMod, float) LambGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Lamb, MS_REG_GPU_KERNEL_ONE(Lamb,
@ -45,7 +44,6 @@ MS_REG_GPU_KERNEL_ONE(Lamb,
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32) .AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16) .AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeFloat16), .AddOutputAttr(kNumberTypeFloat16),
LambGpuKernelMod, half) LambGpuKernelMod, half)
} // namespace kernel } // namespace kernel

View File

@ -28,7 +28,7 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
constexpr size_t INPUT_NUM = 11; constexpr size_t INPUT_NUM = 10;
constexpr size_t kArgMaxDim = 7; constexpr size_t kArgMaxDim = 7;
constexpr float ten = 10; constexpr float ten = 10;
@ -43,7 +43,6 @@ constexpr size_t kEpsilonIndex = 6;
constexpr size_t kWeightDecayIndex = 7; constexpr size_t kWeightDecayIndex = 7;
constexpr size_t kGlobalStepIndex = 8; constexpr size_t kGlobalStepIndex = 8;
constexpr size_t kGradIndex = 9; constexpr size_t kGradIndex = 9;
constexpr size_t kDecayFlagIndex = 10;
// workspaces param index // workspaces param index
constexpr size_t kUpdateIndex = 0; constexpr size_t kUpdateIndex = 0;
@ -72,21 +71,20 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
T *variable = GetDeviceAddress<T>(inputs, kVarIndex); T *variable = GetDeviceAddress<T>(inputs, kVarIndex);
T *m = GetDeviceAddress<T>(inputs, kMIndex); T *m = GetDeviceAddress<T>(inputs, kMIndex);
T *v = GetDeviceAddress<T>(inputs, kVIndex); T *v = GetDeviceAddress<T>(inputs, kVIndex);
T *learning_rate = GetDeviceAddress<T>(inputs, kLearningRateIndex); float *learning_rate = GetDeviceAddress<float>(inputs, kLearningRateIndex);
float *beta1 = GetDeviceAddress<float>(inputs, kBeta1Index); float *beta1 = GetDeviceAddress<float>(inputs, kBeta1Index);
float *beta2 = GetDeviceAddress<float>(inputs, kBeta2Index); float *beta2 = GetDeviceAddress<float>(inputs, kBeta2Index);
float *epsilon = GetDeviceAddress<float>(inputs, kEpsilonIndex); float *epsilon = GetDeviceAddress<float>(inputs, kEpsilonIndex);
T *decay = GetDeviceAddress<T>(inputs, kWeightDecayIndex); float *decay = GetDeviceAddress<float>(inputs, kWeightDecayIndex);
int32_t *global_step = GetDeviceAddress<int32_t>(inputs, kGlobalStepIndex); int32_t *global_step = GetDeviceAddress<int32_t>(inputs, kGlobalStepIndex);
T *gradient = GetDeviceAddress<T>(inputs, kGradIndex); T *gradient = GetDeviceAddress<T>(inputs, kGradIndex);
bool *decay_flag = GetDeviceAddress<bool>(inputs, kDecayFlagIndex);
float *update = GetDeviceAddress<float>(workspaces, kUpdateIndex); float *update = GetDeviceAddress<float>(workspaces, kUpdateIndex);
float *var_float = GetDeviceAddress<float>(workspaces, kVarFloatIndex); float *var_float = GetDeviceAddress<float>(workspaces, kVarFloatIndex);
float *grad_float = GetDeviceAddress<float>(workspaces, kGradFloatIndex); float *grad_float = GetDeviceAddress<float>(workspaces, kGradFloatIndex);
float *g_hat_var = GetDeviceAddress<float>(workspaces, kGHatValIndex); float *g_hat_var = GetDeviceAddress<float>(workspaces, kGHatValIndex);
ApplyLambEraly(inputs[0]->size / sizeof(T), variable, m, v, beta1, beta2, epsilon, decay, global_step, gradient, ApplyLambEraly(inputs[0]->size / sizeof(T), variable, m, v, beta1, beta2, epsilon, decay, global_step, gradient,
decay_flag, update, var_float, grad_float, g_hat_var, reinterpret_cast<cudaStream_t>(stream_ptr)); update, var_float, grad_float, g_hat_var, reinterpret_cast<cudaStream_t>(stream_ptr));
float trust_ratio{0}; float trust_ratio{0};
CalcTrustRatio(workspaces, var_float, grad_float, g_hat_var, stream_ptr, &trust_ratio); CalcTrustRatio(workspaces, var_float, grad_float, g_hat_var, stream_ptr, &trust_ratio);
@ -176,7 +174,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
input_size_list_.push_back(decay_size_); input_size_list_.push_back(decay_size_);
input_size_list_.push_back(global_step_size_); input_size_list_.push_back(global_step_size_);
input_size_list_.push_back(gradient_size_); input_size_list_.push_back(gradient_size_);
input_size_list_.push_back(decay_flag_size_);
workspace_size_list_.push_back(update_size_); workspace_size_list_.push_back(update_size_);
workspace_size_list_.push_back(var_float_size_); workspace_size_list_.push_back(var_float_size_);
workspace_size_list_.push_back(grad_float_size_); workspace_size_list_.push_back(grad_float_size_);
@ -253,7 +250,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
decay_size_ = sizeof(T); decay_size_ = sizeof(T);
global_step_size_ = sizeof(int32_t); global_step_size_ = sizeof(int32_t);
gradient_size_ = sizeof(T); gradient_size_ = sizeof(T);
decay_flag_size_ = sizeof(bool);
update_size_ = sizeof(float); update_size_ = sizeof(float);
var_float_size_ = sizeof(float); var_float_size_ = sizeof(float);
grad_float_size_ = sizeof(float); grad_float_size_ = sizeof(float);
@ -386,7 +382,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
size_t decay_size_{0}; size_t decay_size_{0};
size_t global_step_size_{0}; size_t global_step_size_{0};
size_t gradient_size_{0}; size_t gradient_size_{0};
size_t decay_flag_size_{0};
size_t update_size_{0}; size_t update_size_{0};
size_t var_float_size_{0}; size_t var_float_size_{0};
size_t grad_float_size_{0}; size_t grad_float_size_{0};

View File

@ -35,15 +35,12 @@ TypePtr LambInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
auto decay_type = input_args[kInputIndex7]->BuildType(); auto decay_type = input_args[kInputIndex7]->BuildType();
auto global_step_type = input_args[kInputIndex8]->BuildType(); auto global_step_type = input_args[kInputIndex8]->BuildType();
auto grad_type = input_args[kInputIndex9]->BuildType(); auto grad_type = input_args[kInputIndex9]->BuildType();
auto decay_flag_type = input_args[kInputIndex10]->BuildType();
std::map<std::string, TypePtr> type_dict; std::map<std::string, TypePtr> type_dict;
type_dict.emplace("var", var_type); type_dict.emplace("var", var_type);
type_dict.emplace("m", m_type); type_dict.emplace("m", m_type);
type_dict.emplace("v", v_type); type_dict.emplace("v", v_type);
type_dict.emplace("grad", grad_type); type_dict.emplace("grad", grad_type);
type_dict.emplace("lr", lr_type);
type_dict.emplace("decay", decay_type);
std::set<TypePtr> num_type = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, std::set<TypePtr> num_type = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128}; kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
(void)CheckAndConvertUtils::CheckTensorTypeSame(type_dict, num_type, prim_name); (void)CheckAndConvertUtils::CheckTensorTypeSame(type_dict, num_type, prim_name);
@ -51,12 +48,12 @@ TypePtr LambInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
type_dict1.emplace("beta1", beta1_type); type_dict1.emplace("beta1", beta1_type);
type_dict1.emplace("beta2", beta2_type); type_dict1.emplace("beta2", beta2_type);
type_dict1.emplace("epsilon", epsilon_type); type_dict1.emplace("epsilon", epsilon_type);
std::set<TypePtr> float_set = {kFloat16, kFloat32}; type_dict1.emplace("lr", lr_type);
type_dict1.emplace("decay", decay_type);
std::set<TypePtr> float_set = {kFloat32};
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(type_dict1, float_set, prim_name, true); (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(type_dict1, float_set, prim_name, true);
std::set<TypePtr> bool_set = {kBool};
(void)CheckAndConvertUtils::CheckTypeValid("global_step", global_step_type, num_type, prim_name); (void)CheckAndConvertUtils::CheckTypeValid("global_step", global_step_type, num_type, prim_name);
(void)CheckAndConvertUtils::CheckTypeValid("decay_flag", decay_flag_type, bool_set, prim_name);
return var_type; return var_type;
} }
@ -97,7 +94,7 @@ AbstractBasePtr LambInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
for (auto item : input_args) { for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
const int64_t kInputNum = 11; const int64_t kInputNum = 10;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name); CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
auto infer_type = LambInferType(primitive, input_args); auto infer_type = LambInferType(primitive, input_args);
auto infer_shape = LambInferShape(primitive, input_args); auto infer_shape = LambInferShape(primitive, input_args);

View File

@ -15,8 +15,6 @@
"""lamb""" """lamb"""
import numpy as np import numpy as np
from mindspore import context from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops.operations import _inner_ops as inner from mindspore.ops.operations import _inner_ops as inner
@ -26,11 +24,6 @@ from mindspore._checkparam import Rel
from .optimizer import Optimizer from .optimizer import Optimizer
from .optimizer import opt_init_args_register from .optimizer import opt_init_args_register
from .. import layer
num_one = Tensor(np.ones([1]), mstype.float32)
_lamb_opt = C.MultitypeFuncGraph("lamb_opt") _lamb_opt = C.MultitypeFuncGraph("lamb_opt")
@ -40,89 +33,6 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
""" """
Update parameters. Update parameters.
Args:
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
lr (Tensor): Learning rate.
weight_decay (numbers.Number): Weight decay. Should be equal to or greater than 0.
global_step (Tensor): Global step.
param (Tensor): Parameters.
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Specifies whether param update with weight decay.
optim_filter(bool): Applies parameter update or not.
Returns:
Tensor, the new value of v after updating.
"""
if optim_filter:
op_mul = P.Mul()
op_sqrt = P.Sqrt()
op_rsqrt = P.Rsqrt()
op_square = P.Square()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
op_pow = P.Pow()
op_norm = layer.Norm()
op_select = P.Select()
op_greater = P.Greater()
op_fill = P.Fill()
op_dtype = P.DType()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
next_mm = next_m / (op_cast(num_one, mstype.float32)
- op_pow(beta1, op_cast(global_step, mstype.float32)))
next_vv = next_v / (op_cast(num_one, mstype.float32) -
op_pow(beta2, op_cast(global_step, mstype.float32)))
w_norm = op_norm(param_fp32)
g_norm = op_norm(gradient_fp32)
g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay * param_fp32)
zeros = F.zeros_like(w_norm)
ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
trust_ratio = op_select(
op_greater(w_norm, zeros),
op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
ones)
tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
update = next_mm / (op_sqrt(next_vv) + eps)
if decay_flag:
update = update + op_mul(weight_decay, param_fp32)
update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
return op_cast(next_param, F.dtype(param))
return gradient
_lamb_opt_ascend = C.MultitypeFuncGraph("lamb_opt_ascend")
@_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag,
optim_filter):
"""
Update parameters function when device target is ascend.
Args: Args:
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
@ -142,9 +52,13 @@ def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, para
""" """
if optim_filter: if optim_filter:
op_lamb = inner.Lamb() op_lamb = inner.Lamb()
return op_lamb(param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, if decay_flag:
gradient, decay_flag) ret = op_lamb(param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient)
return gradient else:
ret = op_lamb(param, m, v, lr, beta1, beta2, eps, 0.0, global_step, gradient)
else:
ret = gradient
return ret
def _check_param_value(beta1, beta2, eps, prim_name): def _check_param_value(beta1, beta2, eps, prim_name):
@ -345,7 +259,7 @@ class Lamb(Optimizer):
lr = self.get_lr() lr = self.get_lr()
if not self.is_dynamic_lr_or_weight_decay(): if not self.is_dynamic_lr_or_weight_decay():
self.assignadd(self.global_step, self.global_step_increase_tensor) self.assignadd(self.global_step, self.global_step_increase_tensor)
lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt lamb_opt = _lamb_opt
gradients = self.flatten_gradients(gradients) gradients = self.flatten_gradients(gradients)
gradients = self.gradients_centralization(gradients) gradients = self.gradients_centralization(gradients)
if self.is_group: if self.is_group:

View File

@ -300,7 +300,6 @@ class Lamb(PrimitiveWithInfer):
Default: 0.0. Default: 0.0.
- **global_step** (Tensor) - Tensor to record current global step. - **global_step** (Tensor) - Tensor to record current global step.
- **gradient** (Tensor) - Gradient, has the same shape and data type as `var`. - **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.
- **decay_flag** (bool) - Specifies whether param update with weight decay.
Outputs: Outputs:
Tensor, the updated parameters. Tensor, the updated parameters.
- **var** (Tensor) - The same shape and data type as `var`. - **var** (Tensor) - The same shape and data type as `var`.
@ -315,19 +314,19 @@ class Lamb(PrimitiveWithInfer):
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape, def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
epsilon_shape, decay_shape, global_step_shape, gradient_shape, decay_flag_shape): epsilon_shape, decay_shape, global_step_shape, gradient_shape):
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "gradient_shape", gradient_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "gradient_shape", gradient_shape, Rel.EQ, self.name)
return var_shape return var_shape
def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype, def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype, decay_flag_dtype): epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "lr": lr_dtype, "grad": gradient_dtype, args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": gradient_dtype}
"decay": decay_dtype}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
args = {"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} args = {"lr": lr_dtype, "decay": decay_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype,
"epsilon": epsilon_dtype}
validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True) validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
return var_dtype return var_dtype

View File

@ -142,9 +142,8 @@ class MyLamb(nn.Cell):
self.gradient = Parameter(gradient, name="grad") self.gradient = Parameter(gradient, name="grad")
self.lamb = inner.Lamb() self.lamb = inner.Lamb()
def construct(self, beta1, beta2, eps, global_step, lr, weight_decay, decay_flag): def construct(self, beta1, beta2, eps, global_step, lr, weight_decay):
return self.lamb(self.param, self.m, self.v, lr, beta1, beta2, eps, weight_decay, global_step, self.gradient, return self.lamb(self.param, self.m, self.v, lr, beta1, beta2, eps, weight_decay, global_step, self.gradient)
decay_flag)
def test_gpu_net(): def test_gpu_net():
@ -154,7 +153,7 @@ def test_gpu_net():
Expectation: get the same result when use new lamb kernel and old kernel Expectation: get the same result when use new lamb kernel and old kernel
""" """
my_lamb = MyLamb(param_val, m_val, v_val, grad_val) my_lamb = MyLamb(param_val, m_val, v_val, grad_val)
my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True) my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val)
lamb_gpu_origin = LambGPUOrigin(param_val, m_val, v_val, grad_val) lamb_gpu_origin = LambGPUOrigin(param_val, m_val, v_val, grad_val)
lamb_gpu_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True) lamb_gpu_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
@ -169,7 +168,7 @@ def test_ascend_net():
Expectation: get the same result when use new lamb kernel and old kernel Expectation: get the same result when use new lamb kernel and old kernel
""" """
my_lamb = MyLamb(param_val, m_val, v_val, grad_val) my_lamb = MyLamb(param_val, m_val, v_val, grad_val)
my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True) my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val)
lamb_ascend_origin = LambAscendOrigin(param_val, m_val, v_val, grad_val) lamb_ascend_origin = LambAscendOrigin(param_val, m_val, v_val, grad_val)
lamb_ascend_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True) lamb_ascend_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)