diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/lamb_fission.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/lamb_fission.cc
index 707389a9db7..ca38f65a6f0 100644
--- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/lamb_fission.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/lamb_fission.cc
@@ -24,7 +24,7 @@
 namespace mindspore {
 namespace opt {
 namespace {
-// Lamb's inputs: param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient, decay_flag
+// Lamb's inputs: param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient
 constexpr size_t kParamIndex = 1;
 constexpr size_t kMIndex = 2;
 constexpr size_t kVIndex = 3;
@@ -35,10 +35,9 @@ constexpr size_t kEpsilonIndex = 7;
 constexpr size_t kWeightDecayIndex = 8;
 constexpr size_t kGlobalStepIndex = 9;
 constexpr size_t kGradientIndex = 10;
-constexpr size_t kDecayFlagIndex = 11;
-constexpr size_t kUMonadIndex = 12;
-constexpr size_t kLambInputNum = 11;
-constexpr size_t kLambInputNumWithUMonad = 12;
+constexpr size_t kUMonadIndex = 11;
+constexpr size_t kLambInputNum = 10;
+constexpr size_t kLambInputNumWithUMonad = 11;
 
 constexpr size_t kLambApplyOptimizerAssignOutputNum = 3;
 constexpr size_t kLambApplyOptimizerAssignUpdateIndex = 0;
@@ -246,7 +245,7 @@ const AnfNodePtr LambFission::Process(const FuncGraphPtr &graph, const AnfNodePt
   auto new_global_step = CreateCastNode(graph, global_step_node, kNumberTypeFloat32);
 
-  // cast delay flag to float32
-  auto weight_decay_flag = CreateCastNode(graph, ori_inputs[kDecayFlagIndex], kNumberTypeFloat32);
+  // the decay flag input is removed: weight decay is always enabled, so the flag is a constant 1.0
+  auto weight_decay_flag = CreateValueNode(graph, 1.0);
   auto num_one = CreateValueNode(graph, 1.0);
 
   // create 1-beta1
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lamb_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lamb_impl.cu
index f5a4641611b..09535f895e0 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lamb_impl.cu
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lamb_impl.cu
@@ -21,9 +21,9 @@ const int32_t kSqareNum = 2;
 
 template <typename T>
 __global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
-                                     const float *epsilon, const T *decay, const int32_t *global_step,
-                                     const T *gradient, const bool *decay_flag, float *update, float *var_float,
-                                     float *grad_float, float *g_hat_var) {
+                                     const float *epsilon, const float *decay, const int32_t *global_step,
+                                     const T *gradient, float *update, float *var_float, float *grad_float,
+                                     float *g_hat_var) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     float next_m = (beta1[0] * m[i] + (1 - beta1[0]) * gradient[i]);
     float next_v = (beta2[0] * v[i] + (1 - beta2[0]) * pow(gradient[i], kSqareNum));
@@ -33,9 +33,7 @@ __global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v,
     grad_float[i] = gradient[i];
     g_hat_var[i] = (next_mm / sqrt(next_vv + epsilon[0]) + decay[0] * variable[i]);
     update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
-    if (decay_flag[0]) {
-      update[i] += decay[0] * variable[i];
-    }
+    update[i] += decay[0] * variable[i];
     m[i] = next_m;
     v[i] = next_v;
   }
@@ -43,13 +41,13 @@
 
 template <>
 __global__ void ApplyLambEralyKernel(const size_t size, half *variable, half *m, half *v, const float *beta1,
-                                     const float *beta2, const float *epsilon, const half *decay,
-                                     const int32_t *global_step, const half *gradient, const bool *decay_flag,
-                                     float *update, float *var_float, float *grad_float, float *g_hat_var) {
+                                     const float *beta2, const float *epsilon, const float *decay,
+                                     const int32_t *global_step, const half *gradient, float *update, float *var_float,
+                                     float *grad_float, float *g_hat_var) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     float float_gradient = __half2float(gradient[i]);
     float float_var = __half2float(variable[i]);
-    float float_decay = __half2float(decay[0]);
+    float float_decay = decay[0];
     float next_m = (beta1[0] * __half2float(m[i]) + (1 - beta1[0]) * float_gradient);
     float next_v = (beta2[0] * __half2float(v[i]) + (1 - beta2[0]) * pow(float_gradient, kSqareNum));
 
@@ -59,16 +57,14 @@ __global__ void ApplyLambEralyKernel(const size_t size, half *variable, half *m,
     grad_float[i] = float_gradient;
     g_hat_var[i] = next_mm / sqrt(next_vv + epsilon[0]) + float_decay * float_var;
     update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
-    if (decay_flag[0]) {
-      update[i] += float_decay * float_var;
-    }
+    update[i] += float_decay * float_var;
     m[i] = __float2half(next_m);
     v[i] = __float2half(next_v);
   }
 }
 
 template <typename T>
-__global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const T *lr, const float *update,
+__global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const float *lr, const float *update,
                                          const float *trust_ratio) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     variable[i] = variable[i] - trust_ratio[0] * lr[0] * update[i];
@@ -76,25 +72,24 @@
 }
 
 template <>
-__global__ void ApplyLambAfterNormKernel(const size_t size, half *variable, const half *lr, const float *update,
+__global__ void ApplyLambAfterNormKernel(const size_t size, half *variable, const float *lr, const float *update,
                                          const float *trust_ratio) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
-    variable[i] = __float2half(__half2float(variable[i]) - trust_ratio[0] * __half2float(lr[0]) * update[i]);
+    variable[i] = __float2half(__half2float(variable[i]) - trust_ratio[0] * lr[0] * update[i]);
   }
 }
 
 template <typename T>
 void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
-                    const float *epsilon, const T *decay, const int32_t *global_step, const T *gradient,
-                    const bool *decay_flag, float *update, float *var_float, float *grad_float, float *g_hat_var,
-                    cudaStream_t cuda_stream) {
+                    const float *epsilon, const float *decay, const int32_t *global_step, const T *gradient,
+                    float *update, float *var_float, float *grad_float, float *g_hat_var, cudaStream_t cuda_stream) {
   ApplyLambEralyKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, m, v, beta1, beta2, epsilon,
-                                                                          decay, global_step, gradient, decay_flag,
-                                                                          update, var_float, grad_float, g_hat_var);
+                                                                          decay, global_step, gradient, update,
+                                                                          var_float, grad_float, g_hat_var);
 }
 
 template <typename T>
-CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr, const float *update,
+CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const float *lr, const float *update,
                                     const float *trust_ratio, cudaStream_t cuda_stream) {
   ApplyLambAfterNormKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, lr, update, trust_ratio);
 }
@@ -102,21 +97,20 @@ CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr,
 template CUDA_LIB_EXPORT void ApplyLambEraly(const size_t size, float *variable, float *m, float *v,
                                              const float *beta1, const float *beta2, const float *epsilon,
                                              const float *decay, const int32_t *global_step,
-                                             const float *gradient, const bool *decay_flag, float *update,
-                                             float *w_square_ptr, float *g_square_ptr, float *g_hat_square_ptr,
+                                             const float *gradient, float *update, float *w_square_ptr,
+                                             float *g_square_ptr, float *g_hat_square_ptr,
                                              cudaStream_t cuda_stream);
 
 template CUDA_LIB_EXPORT void ApplyLambEraly(const size_t size, half *variable, half *m, half *v,
                                              const float *beta1, const float *beta2, const float *epsilon,
-                                             const half *decay, const int32_t *global_step, const half *gradient,
-                                             const bool *decay_flag, float *update, float *w_square_ptr,
-                                             float *g_square_ptr, float *g_hat_square_ptr,
-                                             cudaStream_t cuda_stream);
+                                             const float *decay, const int32_t *global_step, const half *gradient,
+                                             float *update, float *w_square_ptr, float *g_square_ptr,
+                                             float *g_hat_square_ptr, cudaStream_t cuda_stream);
 
 template CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, float *variable, const float *lr,
                                              const float *update, const float *trust_ratio,
                                              cudaStream_t cuda_stream);
 
-template CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, half *variable, const half *lr,
+template CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, half *variable, const float *lr,
                                              const float *update, const float *trust_ratio,
                                              cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lamb_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lamb_impl.cuh
index ca0d76e0d48..0655cc3c24f 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lamb_impl.cuh
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/lamb_impl.cuh
@@ -19,12 +19,12 @@
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
 template <typename T>
 CUDA_LIB_EXPORT void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
-                                    const float *epsilon, const T *decay, const int32_t *global_step, const T *gradient,
-                                    const bool *decay_flag, float *update, float *var_float, float *grad_float,
+                                    const float *epsilon, const float *decay, const int32_t *global_step,
+                                    const T *gradient, float *update, float *var_float, float *grad_float,
                                     float *g_hat_var, cudaStream_t cuda_stream);
 
 template <typename T>
-CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr, const float *update,
+CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const float *lr, const float *update,
                                     const float *trust_ratio, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LAMB_IMPL_CUH_
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/lamb_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/lamb_gpu_kernel.cc
index 7c17a348810..84139b8a5b5 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/lamb_gpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/lamb_gpu_kernel.cc
@@ -30,7 +30,6 @@ MS_REG_GPU_KERNEL_ONE(Lamb,
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeInt32)
                         .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeBool)
                         .AddOutputAttr(kNumberTypeFloat32),
                       LambGpuKernelMod, float)
 MS_REG_GPU_KERNEL_ONE(Lamb,
@@ -45,7 +44,6 @@ MS_REG_GPU_KERNEL_ONE(Lamb,
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeInt32)
                         .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeBool)
                         .AddOutputAttr(kNumberTypeFloat16),
                       LambGpuKernelMod, half)
 }  // namespace kernel
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/lamb_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/lamb_gpu_kernel.h
index ae80691b06a..51d473e91ab 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/lamb_gpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/lamb_gpu_kernel.h
@@ -28,7 +28,7 @@
 
 namespace mindspore {
 namespace kernel {
-constexpr size_t INPUT_NUM = 11;
+constexpr size_t INPUT_NUM = 10;
 constexpr size_t kArgMaxDim = 7;
 constexpr float ten = 10;
 
@@ -43,7 +43,6 @@ constexpr size_t kEpsilonIndex = 6;
 constexpr size_t kWeightDecayIndex = 7;
 constexpr size_t kGlobalStepIndex = 8;
 constexpr size_t kGradIndex = 9;
-constexpr size_t kDecayFlagIndex = 10;
 
 // workspaces param index
 constexpr size_t kUpdateIndex = 0;
@@ -72,21 +71,20 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
     T *variable = GetDeviceAddress<T>(inputs, kVarIndex);
     T *m = GetDeviceAddress<T>(inputs, kMIndex);
     T *v = GetDeviceAddress<T>(inputs, kVIndex);
-    T *learning_rate = GetDeviceAddress<T>(inputs, kLearningRateIndex);
+    float *learning_rate = GetDeviceAddress<float>(inputs, kLearningRateIndex);
     float *beta1 = GetDeviceAddress<float>(inputs, kBeta1Index);
     float *beta2 = GetDeviceAddress<float>(inputs, kBeta2Index);
     float *epsilon = GetDeviceAddress<float>(inputs, kEpsilonIndex);
-    T *decay = GetDeviceAddress<T>(inputs, kWeightDecayIndex);
+    float *decay = GetDeviceAddress<float>(inputs, kWeightDecayIndex);
     int32_t *global_step = GetDeviceAddress<int32_t>(inputs, kGlobalStepIndex);
     T *gradient = GetDeviceAddress<T>(inputs, kGradIndex);
-    bool *decay_flag = GetDeviceAddress<bool>(inputs, kDecayFlagIndex);
     float *update = GetDeviceAddress<float>(workspaces, kUpdateIndex);
     float *var_float = GetDeviceAddress<float>(workspaces, kVarFloatIndex);
     float *grad_float = GetDeviceAddress<float>(workspaces, kGradFloatIndex);
     float *g_hat_var = GetDeviceAddress<float>(workspaces, kGHatValIndex);
 
     ApplyLambEraly(inputs[0]->size / sizeof(T), variable, m, v, beta1, beta2, epsilon, decay, global_step, gradient,
-                   decay_flag, update, var_float, grad_float, g_hat_var, reinterpret_cast<cudaStream_t>(stream_ptr));
+                   update, var_float, grad_float, g_hat_var, reinterpret_cast<cudaStream_t>(stream_ptr));
 
     float trust_ratio{0};
     CalcTrustRatio(workspaces, var_float, grad_float, g_hat_var, stream_ptr, &trust_ratio);
@@ -176,7 +174,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
     input_size_list_.push_back(decay_size_);
     input_size_list_.push_back(global_step_size_);
     input_size_list_.push_back(gradient_size_);
-    input_size_list_.push_back(decay_flag_size_);
     workspace_size_list_.push_back(update_size_);
     workspace_size_list_.push_back(var_float_size_);
     workspace_size_list_.push_back(grad_float_size_);
@@ -253,7 +250,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
     decay_size_ = sizeof(T);
     global_step_size_ = sizeof(int32_t);
     gradient_size_ = sizeof(T);
-    decay_flag_size_ = sizeof(bool);
     update_size_ = sizeof(float);
     var_float_size_ = sizeof(float);
     grad_float_size_ = sizeof(float);
@@ -386,7 +382,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
   size_t decay_size_{0};
   size_t global_step_size_{0};
   size_t gradient_size_{0};
-  size_t decay_flag_size_{0};
   size_t update_size_{0};
   size_t var_float_size_{0};
   size_t grad_float_size_{0};
diff --git a/mindspore/core/ops/lamb.cc b/mindspore/core/ops/lamb.cc
index 6d2551152b8..7bf28f054de 100644
--- a/mindspore/core/ops/lamb.cc
+++ b/mindspore/core/ops/lamb.cc
@@ -35,15 +35,12 @@ TypePtr LambInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
   auto decay_type = input_args[kInputIndex7]->BuildType();
   auto global_step_type = input_args[kInputIndex8]->BuildType();
   auto grad_type = input_args[kInputIndex9]->BuildType();
-  auto decay_flag_type = input_args[kInputIndex10]->BuildType();
 
   std::map<std::string, TypePtr> type_dict;
   type_dict.emplace("var", var_type);
   type_dict.emplace("m", m_type);
   type_dict.emplace("v", v_type);
   type_dict.emplace("grad", grad_type);
-  type_dict.emplace("lr", lr_type);
-  type_dict.emplace("decay", decay_type);
   std::set<TypePtr> num_type = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64,
                                 kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
   (void)CheckAndConvertUtils::CheckTensorTypeSame(type_dict, num_type, prim_name);
@@ -51,12 +48,12 @@ TypePtr LambInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
   type_dict1.emplace("beta1", beta1_type);
   type_dict1.emplace("beta2", beta2_type);
   type_dict1.emplace("epsilon", epsilon_type);
-  std::set<TypePtr> float_set = {kFloat16, kFloat32};
+  type_dict1.emplace("lr", lr_type);
+  type_dict1.emplace("decay", decay_type);
+  std::set<TypePtr> float_set = {kFloat32};
   (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(type_dict1, float_set, prim_name, true);
 
-  std::set<TypePtr> bool_set = {kBool};
   (void)CheckAndConvertUtils::CheckTypeValid("global_step", global_step_type, num_type, prim_name);
-  (void)CheckAndConvertUtils::CheckTypeValid("decay_flag", decay_flag_type, bool_set, prim_name);
   return var_type;
 }
 
@@ -97,7 +94,7 @@ AbstractBasePtr LambInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
   for (auto item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  const int64_t kInputNum = 11;
+  const int64_t kInputNum = 10;
   CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
   auto infer_type = LambInferType(primitive, input_args);
   auto infer_shape = LambInferShape(primitive, input_args);
diff --git a/mindspore/python/mindspore/nn/optim/lamb.py b/mindspore/python/mindspore/nn/optim/lamb.py
index ce9725462d4..0a2825e6489 100755
--- a/mindspore/python/mindspore/nn/optim/lamb.py
+++ b/mindspore/python/mindspore/nn/optim/lamb.py
@@ -15,8 +15,6 @@
 """lamb"""
 import numpy as np
 from mindspore import context
-from mindspore.common import dtype as mstype
-from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops.operations import _inner_ops as inner
@@ -26,11 +24,6 @@ from mindspore._checkparam import Rel
 from .optimizer import Optimizer
 from .optimizer import opt_init_args_register
 
-from .. import layer
-
-
-num_one = Tensor(np.ones([1]), mstype.float32)
-
 
 _lamb_opt = C.MultitypeFuncGraph("lamb_opt")
 
@@ -40,89 +33,6 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
     """
     Update parameters.
 
-    Args:
-        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
-        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
-        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
-        lr (Tensor): Learning rate.
-        weight_decay (numbers.Number): Weight decay. Should be equal to or greater than 0.
-        global_step (Tensor): Global step.
-        param (Tensor): Parameters.
-        m (Tensor): m value of parameters.
-        v (Tensor): v value of parameters.
-        gradient (Tensor): Gradient of parameters.
-        decay_flag (bool): Specifies whether param update with weight decay.
-        optim_filter(bool): Applies parameter update or not.
-
-    Returns:
-        Tensor, the new value of v after updating.
- """ - if optim_filter: - op_mul = P.Mul() - op_sqrt = P.Sqrt() - op_rsqrt = P.Rsqrt() - op_square = P.Square() - op_cast = P.Cast() - op_reshape = P.Reshape() - op_shape = P.Shape() - op_pow = P.Pow() - op_norm = layer.Norm() - op_select = P.Select() - op_greater = P.Greater() - op_fill = P.Fill() - op_dtype = P.DType() - - param_fp32 = op_cast(param, mstype.float32) - m_fp32 = op_cast(m, mstype.float32) - v_fp32 = op_cast(v, mstype.float32) - gradient_fp32 = op_cast(gradient, mstype.float32) - - next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32) - - next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32)) - - next_mm = next_m / (op_cast(num_one, mstype.float32) - - op_pow(beta1, op_cast(global_step, mstype.float32))) - next_vv = next_v / (op_cast(num_one, mstype.float32) - - op_pow(beta2, op_cast(global_step, mstype.float32))) - w_norm = op_norm(param_fp32) - g_norm = op_norm(gradient_fp32) - - g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay * param_fp32) - zeros = F.zeros_like(w_norm) - ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0) - trust_ratio = op_select( - op_greater(w_norm, zeros), - op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones), - ones) - tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0) - trust_ratio = C.clip_by_value(trust_ratio, zeros, tens) - update = next_mm / (op_sqrt(next_vv) + eps) - - if decay_flag: - update = update + op_mul(weight_decay, param_fp32) - - update_with_lr = op_mul(op_mul(trust_ratio, lr), update) - - next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) - - next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) - next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m)))) - next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v)))) - - return op_cast(next_param, F.dtype(param)) - return gradient - -_lamb_opt_ascend = C.MultitypeFuncGraph("lamb_opt_ascend") - - -@_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "Bool", "Bool") -def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag, - optim_filter): - """ - Update parameters function when device target is ascend. - Args: beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). 
@@ -142,9 +52,13 @@ def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, para
     """
     if optim_filter:
         op_lamb = inner.Lamb()
-        return op_lamb(param, m, v, lr, beta1, beta2, eps, weight_decay, global_step,
-                       gradient, decay_flag)
-    return gradient
+        if decay_flag:
+            ret = op_lamb(param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient)
+        else:
+            ret = op_lamb(param, m, v, lr, beta1, beta2, eps, 0.0, global_step, gradient)
+    else:
+        ret = gradient
+    return ret
 
 
 def _check_param_value(beta1, beta2, eps, prim_name):
@@ -345,7 +259,7 @@ class Lamb(Optimizer):
         lr = self.get_lr()
         if not self.is_dynamic_lr_or_weight_decay():
             self.assignadd(self.global_step, self.global_step_increase_tensor)
-        lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
+        lamb_opt = _lamb_opt
         gradients = self.flatten_gradients(gradients)
         gradients = self.gradients_centralization(gradients)
         if self.is_group:
diff --git a/mindspore/python/mindspore/ops/operations/_inner_ops.py b/mindspore/python/mindspore/ops/operations/_inner_ops.py
index 1516f4155ca..3dcb0273b84 100755
--- a/mindspore/python/mindspore/ops/operations/_inner_ops.py
+++ b/mindspore/python/mindspore/ops/operations/_inner_ops.py
@@ -300,7 +300,6 @@ class Lamb(PrimitiveWithInfer):
           Default: 0.0.
         - **global_step** (Tensor) - Tensor to record current global step.
         - **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.
-        - **decay_flag** (bool) - Specifies whether param update with weight decay.
     Outputs:
         Tensor, the updated parameters.
         - **var** (Tensor) - The same shape and data type as `var`.
@@ -315,19 +314,19 @@
         self.add_prim_attr('side_effect_mem', True)
 
     def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
-                    epsilon_shape, decay_shape, global_step_shape, gradient_shape, decay_flag_shape):
+                    epsilon_shape, decay_shape, global_step_shape, gradient_shape):
         validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "gradient_shape", gradient_shape, Rel.EQ, self.name)
         return var_shape
 
     def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
-                    epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype, decay_flag_dtype):
-        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "lr": lr_dtype, "grad": gradient_dtype,
-                "decay": decay_dtype}
+                    epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype):
+        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": gradient_dtype}
         validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
 
-        args = {"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
+        args = {"lr": lr_dtype, "decay": decay_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype,
+                "epsilon": epsilon_dtype}
         validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
         return var_dtype
 
diff --git a/tests/st/optimizer/test_lamb_op.py b/tests/st/optimizer/test_lamb_op.py
index b188f12f963..53a49285c09 100644
--- a/tests/st/optimizer/test_lamb_op.py
+++ b/tests/st/optimizer/test_lamb_op.py
@@ -142,9 +142,8 @@ class MyLamb(nn.Cell):
         self.gradient = Parameter(gradient, name="grad")
         self.lamb = inner.Lamb()
 
-    def construct(self, beta1, beta2, eps, global_step, lr, weight_decay, decay_flag):
-        return self.lamb(self.param, self.m, self.v, lr, beta1, beta2, eps, weight_decay, global_step, self.gradient,
-                         decay_flag)
+    def construct(self, beta1, beta2, eps, global_step, lr, weight_decay):
+        return self.lamb(self.param, self.m, self.v, lr, beta1, beta2, eps, weight_decay, global_step, self.gradient)
 
 
 def test_gpu_net():
@@ -154,7 +153,7 @@ def test_gpu_net():
     Expectation: get the same result when use new lamb kernel and old kernel
     """
     my_lamb = MyLamb(param_val, m_val, v_val, grad_val)
-    my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
+    my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val)
 
     lamb_gpu_origin = LambGPUOrigin(param_val, m_val, v_val, grad_val)
     lamb_gpu_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
@@ -169,7 +168,7 @@ def test_ascend_net():
     Expectation: get the same result when use new lamb kernel and old kernel
    """
     my_lamb = MyLamb(param_val, m_val, v_val, grad_val)
-    my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
+    my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val)
 
     lamb_ascend_origin = LambAscendOrigin(param_val, m_val, v_val, grad_val)
     lamb_ascend_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
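The net effect of the patch, summarized as a minimal NumPy sketch (illustrative only; the function name and standalone form are assumptions, not MindSpore API, and it mirrors the composite implementation removed from lamb.py rather than the exact CUDA kernels): the `decay_flag` input is gone from the `Lamb` primitive and its kernels, weight decay is always folded into the update term, and a caller that previously passed `decay_flag=False` now passes `weight_decay=0.0` instead, exactly as the revised `_update_run_op` does.

```python
# Reference sketch of one LAMB step after this change (NumPy only, hypothetical helper).
import numpy as np


def lamb_step(param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, grad):
    """One LAMB update without a decay_flag: pass weight_decay=0.0 to disable decay."""
    next_m = beta1 * m + (1.0 - beta1) * grad
    next_v = beta2 * v + (1.0 - beta2) * grad * grad
    next_mm = next_m / (1.0 - beta1 ** global_step)  # bias correction
    next_vv = next_v / (1.0 - beta2 ** global_step)

    w_norm = np.linalg.norm(param)
    g_norm = np.linalg.norm(grad)
    g_norm_hat = np.linalg.norm(next_mm / np.sqrt(next_vv + eps) + weight_decay * param)

    trust_ratio = w_norm / g_norm_hat if (w_norm > 0 and g_norm > 0) else 1.0
    trust_ratio = float(np.clip(trust_ratio, 0.0, 10.0))

    # Weight decay is now added unconditionally; weight_decay == 0.0 reproduces
    # the old decay_flag=False behaviour.
    update = next_mm / (np.sqrt(next_vv) + eps) + weight_decay * param

    new_param = param - trust_ratio * lr * update
    return new_param, next_m, next_v
```

The sketch follows the structure of the composite `_update_run_op` removed from lamb.py (bias-corrected moments, trust ratio clipped to [0, 10]); the CUDA kernels above differ only in low-level details such as where epsilon enters the denominators.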