optimize lamb

This commit is contained in:
wangchangheng 2022-04-19 16:48:56 +08:00
parent 8167c4dbe4
commit ab78aa86ee
9 changed files with 56 additions and 161 deletions

View File

@ -24,7 +24,7 @@
namespace mindspore {
namespace opt {
namespace {
// Lamb's inputs: param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient, decay_flag
// Lamb's inputs: param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient
constexpr size_t kParamIndex = 1;
constexpr size_t kMIndex = 2;
constexpr size_t kVIndex = 3;
@ -35,10 +35,9 @@ constexpr size_t kEpsilonIndex = 7;
constexpr size_t kWeightDecayIndex = 8;
constexpr size_t kGlobalStepIndex = 9;
constexpr size_t kGradientIndex = 10;
constexpr size_t kDecayFlagIndex = 11;
constexpr size_t kUMonadIndex = 12;
constexpr size_t kLambInputNum = 11;
constexpr size_t kLambInputNumWithUMonad = 12;
constexpr size_t kUMonadIndex = 11;
constexpr size_t kLambInputNum = 10;
constexpr size_t kLambInputNumWithUMonad = 11;
constexpr size_t kLambApplyOptimizerAssignOutputNum = 3;
constexpr size_t kLambApplyOptimizerAssignUpdateIndex = 0;
@ -246,7 +245,7 @@ const AnfNodePtr LambFission::Process(const FuncGraphPtr &graph, const AnfNodePt
auto new_global_step = CreateCastNode(graph, global_step_node, kNumberTypeFloat32);
// cast delay flag to float32
auto weight_decay_flag = CreateCastNode(graph, ori_inputs[kDecayFlagIndex], kNumberTypeFloat32);
auto weight_decay_flag = CreateValueNode(graph, 1.0);
auto num_one = CreateValueNode(graph, 1.0);
// create 1-beta1

View File

@ -21,9 +21,9 @@ const int32_t kSqareNum = 2;
template <typename T>
__global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
const float *epsilon, const T *decay, const int32_t *global_step,
const T *gradient, const bool *decay_flag, float *update, float *var_float,
float *grad_float, float *g_hat_var) {
const float *epsilon, const float *decay, const int32_t *global_step,
const T *gradient, float *update, float *var_float, float *grad_float,
float *g_hat_var) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
float next_m = (beta1[0] * m[i] + (1 - beta1[0]) * gradient[i]);
float next_v = (beta2[0] * v[i] + (1 - beta2[0]) * pow(gradient[i], kSqareNum));
@ -33,9 +33,7 @@ __global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v,
grad_float[i] = gradient[i];
g_hat_var[i] = (next_mm / sqrt(next_vv + epsilon[0]) + decay[0] * variable[i]);
update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
if (decay_flag[0]) {
update[i] += decay[0] * variable[i];
}
m[i] = next_m;
v[i] = next_v;
}
@ -43,13 +41,13 @@ __global__ void ApplyLambEralyKernel(const size_t size, T *variable, T *m, T *v,
template <>
__global__ void ApplyLambEralyKernel(const size_t size, half *variable, half *m, half *v, const float *beta1,
const float *beta2, const float *epsilon, const half *decay,
const int32_t *global_step, const half *gradient, const bool *decay_flag,
float *update, float *var_float, float *grad_float, float *g_hat_var) {
const float *beta2, const float *epsilon, const float *decay,
const int32_t *global_step, const half *gradient, float *update, float *var_float,
float *grad_float, float *g_hat_var) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
float float_gradient = __half2float(gradient[i]);
float float_var = __half2float(variable[i]);
float float_decay = __half2float(decay[0]);
float float_decay = decay[0];
float next_m = (beta1[0] * __half2float(m[i]) + (1 - beta1[0]) * float_gradient);
float next_v = (beta2[0] * __half2float(v[i]) + (1 - beta2[0]) * pow(float_gradient, kSqareNum));
@ -59,16 +57,14 @@ __global__ void ApplyLambEralyKernel(const size_t size, half *variable, half *m,
grad_float[i] = float_gradient;
g_hat_var[i] = next_mm / sqrt(next_vv + epsilon[0]) + float_decay * float_var;
update[i] = next_mm / (sqrt(next_vv) - epsilon[0]);
if (decay_flag[0]) {
update[i] += float_decay * float_var;
}
m[i] = __float2half(next_m);
v[i] = __float2half(next_v);
}
}
template <typename T>
__global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const T *lr, const float *update,
__global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const float *lr, const float *update,
const float *trust_ratio) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
variable[i] = variable[i] - trust_ratio[0] * lr[0] * update[i];
@ -76,25 +72,24 @@ __global__ void ApplyLambAfterNormKernel(const size_t size, T *variable, const T
}
template <>
__global__ void ApplyLambAfterNormKernel(const size_t size, half *variable, const half *lr, const float *update,
__global__ void ApplyLambAfterNormKernel(const size_t size, half *variable, const float *lr, const float *update,
const float *trust_ratio) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
variable[i] = __float2half(__half2float(variable[i]) - trust_ratio[0] * __half2float(lr[0]) * update[i]);
variable[i] = __float2half(__half2float(variable[i]) - trust_ratio[0] * lr[0] * update[i]);
}
}
template <typename T>
void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
const float *epsilon, const T *decay, const int32_t *global_step, const T *gradient,
const bool *decay_flag, float *update, float *var_float, float *grad_float, float *g_hat_var,
cudaStream_t cuda_stream) {
const float *epsilon, const float *decay, const int32_t *global_step, const T *gradient,
float *update, float *var_float, float *grad_float, float *g_hat_var, cudaStream_t cuda_stream) {
ApplyLambEralyKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, m, v, beta1, beta2, epsilon,
decay, global_step, gradient, decay_flag,
update, var_float, grad_float, g_hat_var);
decay, global_step, gradient, update,
var_float, grad_float, g_hat_var);
}
template <typename T>
CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr, const float *update,
CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const float *lr, const float *update,
const float *trust_ratio, cudaStream_t cuda_stream) {
ApplyLambAfterNormKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, lr, update, trust_ratio);
}
@ -102,21 +97,20 @@ CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr,
template CUDA_LIB_EXPORT void ApplyLambEraly<float>(const size_t size, float *variable, float *m, float *v,
const float *beta1, const float *beta2, const float *epsilon,
const float *decay, const int32_t *global_step,
const float *gradient, const bool *decay_flag, float *update,
float *w_square_ptr, float *g_square_ptr, float *g_hat_square_ptr,
const float *gradient, float *update, float *w_square_ptr,
float *g_square_ptr, float *g_hat_square_ptr,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyLambEraly<half>(const size_t size, half *variable, half *m, half *v,
const float *beta1, const float *beta2, const float *epsilon,
const half *decay, const int32_t *global_step, const half *gradient,
const bool *decay_flag, float *update, float *w_square_ptr,
float *g_square_ptr, float *g_hat_square_ptr,
cudaStream_t cuda_stream);
const float *decay, const int32_t *global_step, const half *gradient,
float *update, float *w_square_ptr, float *g_square_ptr,
float *g_hat_square_ptr, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyLambLater<float>(const size_t size, float *variable, const float *lr,
const float *update, const float *trust_ratio,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void ApplyLambLater<half>(const size_t size, half *variable, const half *lr,
template CUDA_LIB_EXPORT void ApplyLambLater<half>(const size_t size, half *variable, const float *lr,
const float *update, const float *trust_ratio,
cudaStream_t cuda_stream);

View File

@ -19,12 +19,12 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T>
CUDA_LIB_EXPORT void ApplyLambEraly(const size_t size, T *variable, T *m, T *v, const float *beta1, const float *beta2,
const float *epsilon, const T *decay, const int32_t *global_step, const T *gradient,
const bool *decay_flag, float *update, float *var_float, float *grad_float,
const float *epsilon, const float *decay, const int32_t *global_step,
const T *gradient, float *update, float *var_float, float *grad_float,
float *g_hat_var, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const T *lr, const float *update,
CUDA_LIB_EXPORT void ApplyLambLater(const size_t size, T *variable, const float *lr, const float *update,
const float *trust_ratio, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LAMB_IMPL_CUH_

View File

@ -30,7 +30,6 @@ MS_REG_GPU_KERNEL_ONE(Lamb,
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeFloat32),
LambGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Lamb,
@ -45,7 +44,6 @@ MS_REG_GPU_KERNEL_ONE(Lamb,
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeFloat16),
LambGpuKernelMod, half)
} // namespace kernel

View File

@ -28,7 +28,7 @@
namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUM = 11;
constexpr size_t INPUT_NUM = 10;
constexpr size_t kArgMaxDim = 7;
constexpr float ten = 10;
@ -43,7 +43,6 @@ constexpr size_t kEpsilonIndex = 6;
constexpr size_t kWeightDecayIndex = 7;
constexpr size_t kGlobalStepIndex = 8;
constexpr size_t kGradIndex = 9;
constexpr size_t kDecayFlagIndex = 10;
// workspaces param index
constexpr size_t kUpdateIndex = 0;
@ -72,21 +71,20 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
T *variable = GetDeviceAddress<T>(inputs, kVarIndex);
T *m = GetDeviceAddress<T>(inputs, kMIndex);
T *v = GetDeviceAddress<T>(inputs, kVIndex);
T *learning_rate = GetDeviceAddress<T>(inputs, kLearningRateIndex);
float *learning_rate = GetDeviceAddress<float>(inputs, kLearningRateIndex);
float *beta1 = GetDeviceAddress<float>(inputs, kBeta1Index);
float *beta2 = GetDeviceAddress<float>(inputs, kBeta2Index);
float *epsilon = GetDeviceAddress<float>(inputs, kEpsilonIndex);
T *decay = GetDeviceAddress<T>(inputs, kWeightDecayIndex);
float *decay = GetDeviceAddress<float>(inputs, kWeightDecayIndex);
int32_t *global_step = GetDeviceAddress<int32_t>(inputs, kGlobalStepIndex);
T *gradient = GetDeviceAddress<T>(inputs, kGradIndex);
bool *decay_flag = GetDeviceAddress<bool>(inputs, kDecayFlagIndex);
float *update = GetDeviceAddress<float>(workspaces, kUpdateIndex);
float *var_float = GetDeviceAddress<float>(workspaces, kVarFloatIndex);
float *grad_float = GetDeviceAddress<float>(workspaces, kGradFloatIndex);
float *g_hat_var = GetDeviceAddress<float>(workspaces, kGHatValIndex);
ApplyLambEraly(inputs[0]->size / sizeof(T), variable, m, v, beta1, beta2, epsilon, decay, global_step, gradient,
decay_flag, update, var_float, grad_float, g_hat_var, reinterpret_cast<cudaStream_t>(stream_ptr));
update, var_float, grad_float, g_hat_var, reinterpret_cast<cudaStream_t>(stream_ptr));
float trust_ratio{0};
CalcTrustRatio(workspaces, var_float, grad_float, g_hat_var, stream_ptr, &trust_ratio);
@ -176,7 +174,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
input_size_list_.push_back(decay_size_);
input_size_list_.push_back(global_step_size_);
input_size_list_.push_back(gradient_size_);
input_size_list_.push_back(decay_flag_size_);
workspace_size_list_.push_back(update_size_);
workspace_size_list_.push_back(var_float_size_);
workspace_size_list_.push_back(grad_float_size_);
@ -253,7 +250,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
decay_size_ = sizeof(T);
global_step_size_ = sizeof(int32_t);
gradient_size_ = sizeof(T);
decay_flag_size_ = sizeof(bool);
update_size_ = sizeof(float);
var_float_size_ = sizeof(float);
grad_float_size_ = sizeof(float);
@ -386,7 +382,6 @@ class LambGpuKernelMod : public NativeGpuKernelMod {
size_t decay_size_{0};
size_t global_step_size_{0};
size_t gradient_size_{0};
size_t decay_flag_size_{0};
size_t update_size_{0};
size_t var_float_size_{0};
size_t grad_float_size_{0};

View File

@ -35,15 +35,12 @@ TypePtr LambInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
auto decay_type = input_args[kInputIndex7]->BuildType();
auto global_step_type = input_args[kInputIndex8]->BuildType();
auto grad_type = input_args[kInputIndex9]->BuildType();
auto decay_flag_type = input_args[kInputIndex10]->BuildType();
std::map<std::string, TypePtr> type_dict;
type_dict.emplace("var", var_type);
type_dict.emplace("m", m_type);
type_dict.emplace("v", v_type);
type_dict.emplace("grad", grad_type);
type_dict.emplace("lr", lr_type);
type_dict.emplace("decay", decay_type);
std::set<TypePtr> num_type = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
(void)CheckAndConvertUtils::CheckTensorTypeSame(type_dict, num_type, prim_name);
@ -51,12 +48,12 @@ TypePtr LambInferType(const PrimitivePtr &primitive, const std::vector<AbstractB
type_dict1.emplace("beta1", beta1_type);
type_dict1.emplace("beta2", beta2_type);
type_dict1.emplace("epsilon", epsilon_type);
std::set<TypePtr> float_set = {kFloat16, kFloat32};
type_dict1.emplace("lr", lr_type);
type_dict1.emplace("decay", decay_type);
std::set<TypePtr> float_set = {kFloat32};
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(type_dict1, float_set, prim_name, true);
std::set<TypePtr> bool_set = {kBool};
(void)CheckAndConvertUtils::CheckTypeValid("global_step", global_step_type, num_type, prim_name);
(void)CheckAndConvertUtils::CheckTypeValid("decay_flag", decay_flag_type, bool_set, prim_name);
return var_type;
}
@ -97,7 +94,7 @@ AbstractBasePtr LambInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const int64_t kInputNum = 11;
const int64_t kInputNum = 10;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
auto infer_type = LambInferType(primitive, input_args);
auto infer_shape = LambInferShape(primitive, input_args);

View File

@ -15,8 +15,6 @@
"""lamb"""
import numpy as np
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops.operations import _inner_ops as inner
@ -26,11 +24,6 @@ from mindspore._checkparam import Rel
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from .. import layer
num_one = Tensor(np.ones([1]), mstype.float32)
_lamb_opt = C.MultitypeFuncGraph("lamb_opt")
@ -40,89 +33,6 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
"""
Update parameters.
Args:
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
lr (Tensor): Learning rate.
weight_decay (numbers.Number): Weight decay. Should be equal to or greater than 0.
global_step (Tensor): Global step.
param (Tensor): Parameters.
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Specifies whether param update with weight decay.
optim_filter(bool): Applies parameter update or not.
Returns:
Tensor, the new value of v after updating.
"""
if optim_filter:
op_mul = P.Mul()
op_sqrt = P.Sqrt()
op_rsqrt = P.Rsqrt()
op_square = P.Square()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
op_pow = P.Pow()
op_norm = layer.Norm()
op_select = P.Select()
op_greater = P.Greater()
op_fill = P.Fill()
op_dtype = P.DType()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
next_mm = next_m / (op_cast(num_one, mstype.float32)
- op_pow(beta1, op_cast(global_step, mstype.float32)))
next_vv = next_v / (op_cast(num_one, mstype.float32) -
op_pow(beta2, op_cast(global_step, mstype.float32)))
w_norm = op_norm(param_fp32)
g_norm = op_norm(gradient_fp32)
g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay * param_fp32)
zeros = F.zeros_like(w_norm)
ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
trust_ratio = op_select(
op_greater(w_norm, zeros),
op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
ones)
tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
update = next_mm / (op_sqrt(next_vv) + eps)
if decay_flag:
update = update + op_mul(weight_decay, param_fp32)
update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
return op_cast(next_param, F.dtype(param))
return gradient
_lamb_opt_ascend = C.MultitypeFuncGraph("lamb_opt_ascend")
@_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag,
optim_filter):
"""
Update parameters function when device target is ascend.
Args:
beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
@ -142,9 +52,13 @@ def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, para
"""
if optim_filter:
op_lamb = inner.Lamb()
return op_lamb(param, m, v, lr, beta1, beta2, eps, weight_decay, global_step,
gradient, decay_flag)
return gradient
if decay_flag:
ret = op_lamb(param, m, v, lr, beta1, beta2, eps, weight_decay, global_step, gradient)
else:
ret = op_lamb(param, m, v, lr, beta1, beta2, eps, 0.0, global_step, gradient)
else:
ret = gradient
return ret
def _check_param_value(beta1, beta2, eps, prim_name):
@ -345,7 +259,7 @@ class Lamb(Optimizer):
lr = self.get_lr()
if not self.is_dynamic_lr_or_weight_decay():
self.assignadd(self.global_step, self.global_step_increase_tensor)
lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
lamb_opt = _lamb_opt
gradients = self.flatten_gradients(gradients)
gradients = self.gradients_centralization(gradients)
if self.is_group:

View File

@ -300,7 +300,6 @@ class Lamb(PrimitiveWithInfer):
Default: 0.0.
- **global_step** (Tensor) - Tensor to record current global step.
- **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.
- **decay_flag** (bool) - Specifies whether param update with weight decay.
Outputs:
Tensor, the updated parameters.
- **var** (Tensor) - The same shape and data type as `var`.
@ -315,19 +314,19 @@ class Lamb(PrimitiveWithInfer):
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
epsilon_shape, decay_shape, global_step_shape, gradient_shape, decay_flag_shape):
epsilon_shape, decay_shape, global_step_shape, gradient_shape):
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "gradient_shape", gradient_shape, Rel.EQ, self.name)
return var_shape
def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype, decay_flag_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "lr": lr_dtype, "grad": gradient_dtype,
"decay": decay_dtype}
epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": gradient_dtype}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
args = {"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
args = {"lr": lr_dtype, "decay": decay_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype,
"epsilon": epsilon_dtype}
validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
return var_dtype

View File

@ -142,9 +142,8 @@ class MyLamb(nn.Cell):
self.gradient = Parameter(gradient, name="grad")
self.lamb = inner.Lamb()
def construct(self, beta1, beta2, eps, global_step, lr, weight_decay, decay_flag):
return self.lamb(self.param, self.m, self.v, lr, beta1, beta2, eps, weight_decay, global_step, self.gradient,
decay_flag)
def construct(self, beta1, beta2, eps, global_step, lr, weight_decay):
return self.lamb(self.param, self.m, self.v, lr, beta1, beta2, eps, weight_decay, global_step, self.gradient)
def test_gpu_net():
@ -154,7 +153,7 @@ def test_gpu_net():
Expectation: get the same result when use new lamb kernel and old kernel
"""
my_lamb = MyLamb(param_val, m_val, v_val, grad_val)
my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val)
lamb_gpu_origin = LambGPUOrigin(param_val, m_val, v_val, grad_val)
lamb_gpu_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
@ -169,7 +168,7 @@ def test_ascend_net():
Expectation: get the same result when use new lamb kernel and old kernel
"""
my_lamb = MyLamb(param_val, m_val, v_val, grad_val)
my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)
my_lamb(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val)
lamb_ascend_origin = LambAscendOrigin(param_val, m_val, v_val, grad_val)
lamb_ascend_origin(beta1_val, beta2_val, eps_val, global_step_val, lr_val, weight_decay_val, True)