optimize lamb

parent 8167c4dbe4
commit ab78aa86ee
@@ -24,7 +24,7 @@
 namespace mindspore {
 namespace opt {
 namespace {
-// Lamb's inputs: param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient, decay_flag
+// Lamb's inputs: param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient
 constexpr size_t kParamIndex = 1;
 constexpr size_t kMIndex = 2;
 constexpr size_t kVIndex = 3;
@@ -35,10 +35,9 @@ constexpr size_t kEpsilonIndex = 7;
 constexpr size_t kWeightDecayIndex = 8;
 constexpr size_t kGlobalStepIndex = 9;
 constexpr size_t kGradientIndex = 10;
-constexpr size_t kDecayFlagIndex = 11;
-constexpr size_t kUMonadIndex = 12;
-constexpr size_t kLambInputNum = 11;
-constexpr size_t kLambInputNumWithUMonad = 12;
+constexpr size_t kUMonadIndex = 11;
+constexpr size_t kLambInputNum = 10;
+constexpr size_t kLambInputNumWithUMonad = 11;
 constexpr size_t kLambApplyOptimizerAssignOutputNum = 3;
 constexpr size_t kLambApplyOptimizerAssignUpdateIndex = 0;

@@ -246,7 +245,7 @@ const AnfNodePtr LambFission::Process(const FuncGraphPtr &graph, const AnfNodePt
 auto new_global_step = CreateCastNode(graph, global_step_node, kNumberTypeFloat32);

 // cast delay flag to float32
-auto weight_decay_flag = CreateCastNode(graph, ori_inputs[kDecayFlagIndex], kNumberTypeFloat32);
+auto weight_decay_flag = CreateValueNode(graph, 1.0);

 auto num_one = CreateValueNode(graph, 1.0);
 // create 1-beta1

@@ -21,9 +21,9 @@ const int32_t kSqareNum = 2;

 template <typename T>
 __global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
-const float *epsilon, const T *decay, const int32_t *global_step,
-const T *gradient, const bool *decay_flag, float *update, float *var_float,
-float *grad_float, float *g_hat_var) {
+const float *epsilon, const float *decay, const int32_t *global_step,
+const T *gradient, float *update, float *var_float, float *grad_float,
+float *g_hat_var) {
 for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
 float next_m = (beta1[0] * m[i] + (1 - beta1[0]) * gradient[i]);
 float next_v = (beta2[0] * v[i] + (1 - beta2[0]) * pow(gradient[i], kSqareNum));
@@ -33,9 +33,7 @@ __global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v,
 grad_float[i] = gradient[i];
 g_hat_var[i] = (next_mm / sqrt(next_vv + epsilon[0]) + decay[0] * variable[i]);
 update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
-if (decay_flag[0]) {
-update[i] += decay[0] * variable[i];
-}
+update[i] += decay[0] * variable[i];
 m[i] = next_m;
 v[i] = next_v;
 }
@@ -43,13 +41,13 @@ __global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v,

 template <>
 __global__ void ApplyLambEralyKernel(const size_t size, half *variable, half *m, half *v, const float *beta1,
-const float *beta2, const float *epsilon, const half *decay,
-const int32_t *global_step, const half *gradient, const bool *decay_flag,
-float *update, float *var_float, float *grad_float, float *g_hat_var) {
+const float *beta2, const float *epsilon, const float *decay,
+const int32_t *global_step, const half *gradient, float *update, float *var_float,
+float *grad_float, float *g_hat_var) {
 for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
 float float_gradient = __half2float(gradient[i]);
 float float_var = __half2float(variable[i]);
-float float_decay = __half2float(decay[0]);
+float float_decay = decay[0];

 float next_m = (beta1[0] * __half2float(m[i]) + (1 - beta1[0]) * float_gradient);
 float next_v = (beta2[0] * __half2float(v[i]) + (1 - beta2[0]) * pow(float_gradient, kSqareNum));
@@ -59,16 +57,14 @@ __global__ void ApplyLambEralyKernel(const size_t size, half *variable, half *m,
 grad_float[i] = float_gradient;
 g_hat_var[i] = next_mm / sqrt(next_vv + epsilon[0]) + float_decay * float_var;
 update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
-if (decay_flag[0]) {
-update[i] += float_decay * float_var;
-}
+update[i] += float_decay * float_var;
 m[i] = __float2half(next_m);
 v[i] = __float2half(next_v);
 }
 }

 template <typename T>
-__global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const T *lr, const float *update,
+__global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const float *lr, const float *update,
 const float *trust_ratio) {
 for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
 variable[i] = variable[i] - trust_ratio[0] * lr[0] * update[i];
@@ -76,25 +72,24 @@ __global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const T
 }

 template <>
-__global__ void ApplyLambAfterNormKernel(const size_t size, half *variable, const half *lr, const float *update,
+__global__ void ApplyLambAfterNormKernel(const size_t size, half *variable, const float *lr, const float *update,
 const float *trust_ratio) {
 for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
-variable[i] = __float2half(__half2float(variable[i]) - trust_ratio[0] * __half2float(lr[0]) * update[i]);
+variable[i] = __float2half(__half2float(variable[i]) - trust_ratio[0] * lr[0] * update[i]);
 }
 }

 template <typename T>
 void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
-const float *epsilon, const T *decay, const int32_t *global_step, const T *gradient,
-const bool *decay_flag, float *update, float *var_float, float *grad_float, float *g_hat_var,
-cudaStream_t cuda_stream) {
+const float *epsilon, const float *decay, const int32_t *global_step, const T *gradient,
+float *update, float *var_float, float *grad_float, float *g_hat_var, cudaStream_t cuda_stream) {
 ApplyLambEralyKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, m, v, beta1, beta2, epsilon,
-decay, global_step, gradient, decay_flag,
-update, var_float, grad_float, g_hat_var);
+decay, global_step, gradient, update,
+var_float, grad_float, g_hat_var);
 }

 template <typename T>
-CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr, const float *update,
+CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const float *lr, const float *update,
 const float *trust_ratio, cudaStream_t cuda_stream) {
 ApplyLambAfterNormKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, lr, update, trust_ratio);
 }
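For readers checking the kernel math, here is a minimal NumPy sketch of the quantities the ApplyLambEraly path above produces now that decay_flag is gone: passing decay = 0.0 reproduces the old "no weight decay" branch. The bias correction with global_step and the epsilon placement in the update term follow the Python composite implementation elsewhere in this change; the function and variable names are illustrative, not part of the kernel API.

    import numpy as np

    def apply_lamb_early_ref(variable, m, v, beta1, beta2, epsilon, decay, global_step, gradient):
        # First and second moment updates, mirroring the kernel body above.
        next_m = beta1 * m + (1 - beta1) * gradient
        next_v = beta2 * v + (1 - beta2) * gradient ** 2
        # Bias correction, assumed to match the removed Python composite implementation.
        next_mm = next_m / (1 - beta1 ** global_step)
        next_vv = next_v / (1 - beta2 ** global_step)
        # decay is an unconditional float now; 0.0 disables weight decay.
        g_hat_var = next_mm / np.sqrt(next_vv + epsilon) + decay * variable
        # Denominator follows the composite form (sqrt(next_vv) + eps).
        update = next_mm / (np.sqrt(next_vv) + epsilon) + decay * variable
        return next_m, next_v, update, g_hat_var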
@@ -102,21 +97,20 @@ CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr,
 template CUDA_LIB_EXPORT void ApplyLambEraly<float>(const size_t size, float *variable, float *m, float *v,
 const float *beta1, const float *beta2, const float *epsilon,
 const float *decay, const int32_t *global_step,
-const float *gradient, const bool *decay_flag, float *update,
-float *w_square_ptr, float *g_square_ptr, float *g_hat_square_ptr,
+const float *gradient, float *update, float *w_square_ptr,
+float *g_square_ptr, float *g_hat_square_ptr,
 cudaStream_t cuda_stream);

 template CUDA_LIB_EXPORT void ApplyLambEraly<half>(const size_t size, half *variable, half *m, half *v,
 const float *beta1, const float *beta2, const float *epsilon,
-const half *decay, const int32_t *global_step, const half *gradient,
-const bool *decay_flag, float *update, float *w_square_ptr,
-float *g_square_ptr, float *g_hat_square_ptr,
-cudaStream_t cuda_stream);
+const float *decay, const int32_t *global_step, const half *gradient,
+float *update, float *w_square_ptr, float *g_square_ptr,
+float *g_hat_square_ptr, cudaStream_t cuda_stream);

 template CUDA_LIB_EXPORT void ApplyLambLater<float>(const size_t size, float *variable, const float *lr,
 const float *update, const float *trust_ratio,
 cudaStream_t cuda_stream);

-template CUDA_LIB_EXPORT void ApplyLambLater<half>(const size_t size, half *variable, const half *lr,
+template CUDA_LIB_EXPORT void ApplyLambLater<half>(const size_t size, half *variable, const float *lr,
 const float *update, const float *trust_ratio,
 cudaStream_t cuda_stream);

@@ -19,12 +19,12 @@
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
 template <typename T>
 CUDA_LIB_EXPORT void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
-const float *epsilon, const T *decay, const int32_t *global_step, const T *gradient,
-const bool *decay_flag, float *update, float *var_float, float *grad_float,
+const float *epsilon, const float *decay, const int32_t *global_step,
+const T *gradient, float *update, float *var_float, float *grad_float,
 float *g_hat_var, cudaStream_t cuda_stream);

 template <typename T>
-CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr, const float *update,
+CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const float *lr, const float *update,
 const float *trust_ratio, cudaStream_t cuda_stream);

 #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LAMB_IMPL_CUH_
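A rough NumPy sketch of the second phase behind the ApplyLambLater declaration above: a trust ratio is built from the norms of the variable and of g_hat_var and clamped to [0, 10], then the variable is updated with a learning rate that now arrives as a float32 scalar for both the float32 and float16 kernels. The zero-norm guard is simplified here and the names are illustrative.

    import numpy as np

    def apply_lamb_later_ref(variable, lr, update, g_hat_var):
        w_norm = np.linalg.norm(variable)
        g_hat_norm = np.linalg.norm(g_hat_var)
        # Trust ratio guarded against zero norms, then clamped to [0, 10].
        trust_ratio = w_norm / g_hat_norm if w_norm > 0 and g_hat_norm > 0 else 1.0
        trust_ratio = float(np.clip(trust_ratio, 0.0, 10.0))
        # Final parameter update applied by ApplyLambLater / ApplyLambAfterNormKernel.
        return variable - trust_ratio * lr * update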
@@ -30,7 +30,6 @@ MS_REG_GPU_KERNEL_ONE(Lamb,
 .AddInputAttr(kNumberTypeFloat32)
 .AddInputAttr(kNumberTypeInt32)
 .AddInputAttr(kNumberTypeFloat32)
-.AddInputAttr(kNumberTypeBool)
 .AddOutputAttr(kNumberTypeFloat32),
 LambGpuKernelMod, float)
 MS_REG_GPU_KERNEL_ONE(Lamb,
@@ -45,7 +44,6 @@ MS_REG_GPU_KERNEL_ONE(Lamb,
 .AddInputAttr(kNumberTypeFloat32)
 .AddInputAttr(kNumberTypeInt32)
 .AddInputAttr(kNumberTypeFloat16)
-.AddInputAttr(kNumberTypeBool)
 .AddOutputAttr(kNumberTypeFloat16),
 LambGpuKernelMod, half)
 } // namespace kernel

@@ -28,7 +28,7 @@

 namespace mindspore {
 namespace kernel {
-constexpr size_t INPUT_NUM = 11;
+constexpr size_t INPUT_NUM = 10;
 constexpr size_t kArgMaxDim = 7;
 constexpr float ten = 10;

@@ -43,7 +43,6 @@ constexpr size_t kEpsilonIndex = 6;
 constexpr size_t kWeightDecayIndex = 7;
 constexpr size_t kGlobalStepIndex = 8;
 constexpr size_t kGradIndex = 9;
-constexpr size_t kDecayFlagIndex = 10;

 // workspaces param index
 constexpr size_t kUpdateIndex = 0;
@@ -72,21 +71,20 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
 T *variable = GetDeviceAddress<T>(inputs, kVarIndex);
 T *m = GetDeviceAddress<T>(inputs, kMIndex);
 T *v = GetDeviceAddress<T>(inputs, kVIndex);
-T *learning_rate = GetDeviceAddress<T>(inputs, kLearningRateIndex);
+float *learning_rate = GetDeviceAddress<float>(inputs, kLearningRateIndex);
 float *beta1 = GetDeviceAddress<float>(inputs, kBeta1Index);
 float *beta2 = GetDeviceAddress<float>(inputs, kBeta2Index);
 float *epsilon = GetDeviceAddress<float>(inputs, kEpsilonIndex);
-T *decay = GetDeviceAddress<T>(inputs, kWeightDecayIndex);
+float *decay = GetDeviceAddress<float>(inputs, kWeightDecayIndex);
 int32_t *global_step = GetDeviceAddress<int32_t>(inputs, kGlobalStepIndex);
 T *gradient = GetDeviceAddress<T>(inputs, kGradIndex);
-bool *decay_flag = GetDeviceAddress<bool>(inputs, kDecayFlagIndex);
 float *update = GetDeviceAddress<float>(workspaces, kUpdateIndex);
 float *var_float = GetDeviceAddress<float>(workspaces, kVarFloatIndex);
 float *grad_float = GetDeviceAddress<float>(workspaces, kGradFloatIndex);
 float *g_hat_var = GetDeviceAddress<float>(workspaces, kGHatValIndex);

 ApplyLambEraly(inputs[0]->size / sizeof(T), variable, m, v, beta1, beta2, epsilon, decay, global_step, gradient,
-decay_flag, update, var_float, grad_float, g_hat_var, reinterpret_cast<cudaStream_t>(stream_ptr));
+update, var_float, grad_float, g_hat_var, reinterpret_cast<cudaStream_t>(stream_ptr));

 float trust_ratio{0};
 CalcTrustRatio(workspaces, var_float, grad_float, g_hat_var, stream_ptr, &trust_ratio);
@@ -176,7 +174,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
 input_size_list_.push_back(decay_size_);
 input_size_list_.push_back(global_step_size_);
 input_size_list_.push_back(gradient_size_);
-input_size_list_.push_back(decay_flag_size_);
 workspace_size_list_.push_back(update_size_);
 workspace_size_list_.push_back(var_float_size_);
 workspace_size_list_.push_back(grad_float_size_);
@@ -253,7 +250,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
 decay_size_ = sizeof(T);
 global_step_size_ = sizeof(int32_t);
 gradient_size_ = sizeof(T);
-decay_flag_size_ = sizeof(bool);
 update_size_ = sizeof(float);
 var_float_size_ = sizeof(float);
 grad_float_size_ = sizeof(float);
@@ -386,7 +382,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
 size_t decay_size_{0};
 size_t global_step_size_{0};
 size_t gradient_size_{0};
-size_t decay_flag_size_{0};
 size_t update_size_{0};
 size_t var_float_size_{0};
 size_t grad_float_size_{0};

@@ -35,15 +35,12 @@ TypePtr LambInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
 auto decay_type = input_args[kInputIndex7]->BuildType();
 auto global_step_type = input_args[kInputIndex8]->BuildType();
 auto grad_type = input_args[kInputIndex9]->BuildType();
-auto decay_flag_type = input_args[kInputIndex10]->BuildType();

 std::map<std::string, TypePtr> type_dict;
 type_dict.emplace("var", var_type);
 type_dict.emplace("m", m_type);
 type_dict.emplace("v", v_type);
 type_dict.emplace("grad", grad_type);
-type_dict.emplace("lr", lr_type);
-type_dict.emplace("decay", decay_type);
 std::set<TypePtr> num_type = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
 kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
 (void)CheckAndConvertUtils::CheckTensorTypeSame(type_dict, num_type, prim_name);
@@ -51,12 +48,12 @@ TypePtr LambInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
 type_dict1.emplace("beta1", beta1_type);
 type_dict1.emplace("beta2", beta2_type);
 type_dict1.emplace("epsilon", epsilon_type);
-std::set<TypePtr> float_set = {kFloat16, kFloat32};
+type_dict1.emplace("lr", lr_type);
+type_dict1.emplace("decay", decay_type);
+std::set<TypePtr> float_set = {kFloat32};
 (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(type_dict1, float_set, prim_name, true);

-std::set<TypePtr> bool_set = {kBool};
 (void)CheckAndConvertUtils::CheckTypeValid("global_step", global_step_type, num_type, prim_name);
-(void)CheckAndConvertUtils::CheckTypeValid("decay_flag", decay_flag_type, bool_set, prim_name);

 return var_type;
 }
@@ -97,7 +94,7 @@ AbstractBasePtr LambInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
 for (auto item : input_args) {
 MS_EXCEPTION_IF_NULL(item);
 }
-const int64_t kInputNum = 11;
+const int64_t kInputNum = 10;
 CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
 auto infer_type = LambInferType(primitive, input_args);
 auto infer_shape = LambInferShape(primitive, input_args);

@@ -15,8 +15,6 @@
 """lamb"""
-import numpy as np
-from mindspore import context
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops.operations import _inner_ops as inner
@@ -26,11 +24,6 @@ from mindspore._checkparam import Rel
 from .optimizer import Optimizer
 from .optimizer import opt_init_args_register

-from .. import layer
-
-
-num_one = Tensor(np.ones([1]), mstype.float32)
-
 _lamb_opt = C.MultitypeFuncGraph("lamb_opt")

@@ -40,89 +33,6 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
-    """
-    Update parameters.
-
-    Args:
-        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
-        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
-        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
-        lr (Tensor): Learning rate.
-        weight_decay (numbers.Number): Weight decay. Should be equal to or greater than 0.
-        global_step (Tensor): Global step.
-        param (Tensor): Parameters.
-        m (Tensor): m value of parameters.
-        v (Tensor): v value of parameters.
-        gradient (Tensor): Gradient of parameters.
-        decay_flag (bool): Specifies whether param update with weight decay.
-        optim_filter(bool): Applies parameter update or not.
-
-    Returns:
-        Tensor, the new value of v after updating.
-    """
-    if optim_filter:
-        op_mul = P.Mul()
-        op_sqrt = P.Sqrt()
-        op_rsqrt = P.Rsqrt()
-        op_square = P.Square()
-        op_cast = P.Cast()
-        op_reshape = P.Reshape()
-        op_shape = P.Shape()
-        op_pow = P.Pow()
-        op_norm = layer.Norm()
-        op_select = P.Select()
-        op_greater = P.Greater()
-        op_fill = P.Fill()
-        op_dtype = P.DType()
-
-        param_fp32 = op_cast(param, mstype.float32)
-        m_fp32 = op_cast(m, mstype.float32)
-        v_fp32 = op_cast(v, mstype.float32)
-        gradient_fp32 = op_cast(gradient, mstype.float32)
-
-        next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)
-
-        next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
-
-        next_mm = next_m / (op_cast(num_one, mstype.float32)
-                            - op_pow(beta1, op_cast(global_step, mstype.float32)))
-        next_vv = next_v / (op_cast(num_one, mstype.float32) -
-                            op_pow(beta2, op_cast(global_step, mstype.float32)))
-        w_norm = op_norm(param_fp32)
-        g_norm = op_norm(gradient_fp32)
-
-        g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay * param_fp32)
-        zeros = F.zeros_like(w_norm)
-        ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
-        trust_ratio = op_select(
-            op_greater(w_norm, zeros),
-            op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
-            ones)
-        tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
-        trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
-        update = next_mm / (op_sqrt(next_vv) + eps)
-
-        if decay_flag:
-            update = update + op_mul(weight_decay, param_fp32)
-
-        update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
-
-        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
-
-        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
-        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
-        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
-
-        return op_cast(next_param, F.dtype(param))
-    return gradient
-
-_lamb_opt_ascend = C.MultitypeFuncGraph("lamb_opt_ascend")
-
-
-@_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-                           "Tensor", "Bool", "Bool")
-def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag,
-                          optim_filter):
     """
     Update parameters function when device target is ascend.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
@@ -142,9 +52,13 @@ def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, para
     """
     if optim_filter:
         op_lamb = inner.Lamb()
-        return op_lamb(param, m, v, lr, beta1, beta2, eps, weight_decay, global_step,
-                       gradient, decay_flag)
-    return gradient
+        if decay_flag:
+            ret = op_lamb(param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient)
+        else:
+            ret = op_lamb(param, m, v, lr, beta1, beta2, eps, 0.0, global_step, gradient)
+    else:
+        ret = gradient
+    return ret


 def _check_param_value(beta1, beta2, eps, prim_name):
@@ -345,7 +259,7 @@ class Lamb(Optimizer):
         lr = self.get_lr()
         if not self.is_dynamic_lr_or_weight_decay():
             self.assignadd(self.global_step, self.global_step_increase_tensor)
-        lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
+        lamb_opt = _lamb_opt
         gradients = self.flatten_gradients(gradients)
         gradients = self.gradients_centralization(gradients)
         if self.is_group:

@@ -300,7 +300,6 @@ class Lamb(PrimitiveWithInfer):
          Default: 0.0.
        - **global_step** (Tensor) - Tensor to record current global step.
        - **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.
-        - **decay_flag** (bool) - Specifies whether param update with weight decay.
    Outputs:
        Tensor, the updated parameters.
        - **var** (Tensor) - The same shape and data type as `var`.
@@ -315,19 +314,19 @@ class Lamb(PrimitiveWithInfer):
         self.add_prim_attr('side_effect_mem', True)

     def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
-                    epsilon_shape, decay_shape, global_step_shape, gradient_shape, decay_flag_shape):
+                    epsilon_shape, decay_shape, global_step_shape, gradient_shape):
         validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "gradient_shape", gradient_shape, Rel.EQ, self.name)
         return var_shape

     def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
-                    epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype, decay_flag_dtype):
-        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "lr": lr_dtype, "grad": gradient_dtype,
-                "decay": decay_dtype}
+                    epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype):
+        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": gradient_dtype}
         validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

-        args = {"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
+        args = {"lr": lr_dtype, "decay": decay_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype,
+                "epsilon": epsilon_dtype}
         validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
         return var_dtype

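With decay_flag removed, the primitive now takes ten inputs and requires lr, decay, beta1, beta2 and epsilon to be float32. A minimal usage sketch, with made-up shapes and values (the test below exercises the same call through MyLamb):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, Parameter
    from mindspore.common import dtype as mstype
    from mindspore.ops.operations import _inner_ops as inner

    class LambStep(nn.Cell):
        def __init__(self, shape):
            super().__init__()
            self.param = Parameter(Tensor(np.ones(shape), mstype.float32), name="param")
            self.m = Parameter(Tensor(np.zeros(shape), mstype.float32), name="m")
            self.v = Parameter(Tensor(np.zeros(shape), mstype.float32), name="v")
            self.lamb = inner.Lamb()

        def construct(self, lr, beta1, beta2, eps, weight_decay, global_step, gradient):
            # decay_flag is gone: pass weight_decay = 0.0 to disable weight decay.
            return self.lamb(self.param, self.m, self.v, lr, beta1, beta2, eps,
                             weight_decay, global_step, gradient)

    # Requires a backend that provides the fused Lamb kernel (GPU/Ascend after this change).
    net = LambStep((2, 2))
    out = net(Tensor(0.01, mstype.float32), Tensor(0.9, mstype.float32), Tensor(0.999, mstype.float32),
              Tensor(1e-6, mstype.float32), Tensor(0.01, mstype.float32), Tensor(10, mstype.int32),
              Tensor(np.ones((2, 2)), mstype.float32))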
@@ -142,9 +142,8 @@ class MyLamb(nn.Cell):
         self.gradient = Parameter(gradient, name="grad")
         self.lamb = inner.Lamb()

-    def construct(self, beta1, beta2, eps, global_step, lr, weight_decay, decay_flag):
-        return self.lamb(self.param, self.m, self.v, lr, beta1, beta2, eps, weight_decay, global_step, self.gradient,
-                         decay_flag)
+    def construct(self, beta1, beta2, eps, global_step, lr, weight_decay):
+        return self.lamb(self.param, self.m, self.v, lr, beta1, beta2, eps, weight_decay, global_step, self.gradient)


 def test_gpu_net():
@@ -154,7 +153,7 @@ def test_gpu_net():
     Expectation: get the same result when use new lamb kernel and old kernel
     """
     my_lamb = MyLamb(param_val, m_val, v_val, grad_val)
-    my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
+    my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val)

     lamb_gpu_origin = LambGPUOrigin(param_val, m_val, v_val, grad_val)
     lamb_gpu_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
@@ -169,7 +168,7 @@ def test_ascend_net():
     Expectation: get the same result when use new lamb kernel and old kernel
     """
     my_lamb = MyLamb(param_val, m_val, v_val, grad_val)
-    my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
+    my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val)

     lamb_ascend_origin = LambAscendOrigin(param_val, m_val, v_val, grad_val)
     lamb_ascend_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)