forked from mindspore-Ecosystem/mindspore
commit
63e2098f83
|
@ -103,7 +103,7 @@ bool AdamWeightDecayCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|||
}
|
||||
if (inputs[VAR]->size != inputs[M]->size || inputs[VAR]->size != inputs[V]->size ||
|
||||
inputs[VAR]->size != inputs[GRAD]->size) {
|
||||
MS_LOG(EXCEPTION) << "Error input data size!";
|
||||
MS_LOG(EXCEPTION) << "Var, m, v, grad input data size must be same!";
|
||||
}
|
||||
if (inputs[LR]->size != kSizeFloat32 || inputs[BETA1]->size != kSizeFloat32 || inputs[BETA2]->size != kSizeFloat32 ||
|
||||
inputs[EPSILON]->size != kSizeFloat32 || inputs[DECAY]->size != kSizeFloat32) {
|
||||
|
|
|
@ -75,7 +75,7 @@ MS_REG_CPU_KERNEL(FusedCastAdamWeightDecay,
|
|||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedCastAdamWeightDecayCPUKernel)
|
||||
|
|
|
@ -35,7 +35,7 @@ adam_weight_decay_op_info = CpuRegOp("AdamWeightDecay") \
|
|||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -481,7 +481,7 @@ class FusedCastAdamWeightDecay(PrimitiveWithInfer):
|
|||
|
||||
:math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
|
||||
`gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
|
||||
:math:`\lr` represents `learning_rate`, :math:`w` represents `var`, :math:`decay` represents `weight_decay`,
|
||||
:math:`lr` represents `learning_rate`, :math:`w` represents `var`, :math:`decay` represents `weight_decay`,
|
||||
:math:`\epsilon` represents `epsilon`.
|
||||
|
||||
Args:
|
||||
|
@ -547,12 +547,13 @@ class FusedCastAdamWeightDecay(PrimitiveWithInfer):
|
|||
return var_shape, m_shape, v_shape
|
||||
|
||||
def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
|
||||
epsilon_dtype, decay, grad_dtype):
|
||||
epsilon_dtype, decay_dtype, grad_dtype):
|
||||
args = {"m": m_dtype, "v": v_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
|
||||
validator.check_scalar_or_tensor_types_same({"var": var_dtype}, [mstype.float16, mstype.float32], self.name)
|
||||
validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16], self.name)
|
||||
|
||||
args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype, "decay": decay}
|
||||
args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
|
||||
"decay": decay_dtype}
|
||||
validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
|
||||
return var_dtype, m_dtype, v_dtype
|
||||
|
|
|
@ -4573,7 +4573,7 @@ class AdamWeightDecay(PrimitiveWithInfer):
|
|||
|
||||
:math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
|
||||
`gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
|
||||
:math:`\lr` represents `learning_rate`, :math:`w` represents `var`, :math:`decay` represents `weight_decay`,
|
||||
:math:`lr` represents `learning_rate`, :math:`w` represents `var`, :math:`decay` represents `weight_decay`,
|
||||
:math:`\epsilon` represents `epsilon`.
|
||||
|
||||
Args:
|
||||
|
@ -4643,11 +4643,12 @@ class AdamWeightDecay(PrimitiveWithInfer):
|
|||
return var_shape, m_shape, v_shape
|
||||
|
||||
def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
|
||||
epsilon_dtype, decay, grad_dtype):
|
||||
epsilon_dtype, decay_dtype, grad_dtype):
|
||||
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
|
||||
|
||||
args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype, "decay": decay}
|
||||
args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
|
||||
"decay": decay_dtype}
|
||||
validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
|
||||
return var_dtype, m_dtype, v_dtype
|
||||
|
||||
|
|
|
@ -30,11 +30,6 @@ _adam_opt = C.MultitypeFuncGraph("adam_opt")
|
|||
_scaler_one = Tensor(1, mstype.int32)
|
||||
_scaler_ten = Tensor(10, mstype.float32)
|
||||
|
||||
op_assign = P.Assign()
|
||||
op_assign.add_prim_attr("primitive_target", "CPU")
|
||||
op_cast = P.Cast()
|
||||
op_cast.add_prim_attr("primitive_target", "CPU")
|
||||
|
||||
|
||||
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
|
||||
"Tensor", "Bool", "Bool")
|
||||
|
|
|
@ -33,7 +33,8 @@ class NetAdamWeightDecay(nn.Cell):
|
|||
self.batch_size = 1
|
||||
self.reshape = P.Reshape()
|
||||
weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01)
|
||||
self.fc1 = Dense(16, 10, weight_init=weight)
|
||||
bias = Tensor(np.zeros(10).astype(np.float32))
|
||||
self.fc1 = Dense(16, 10, weight_init=weight, bias_init=bias)
|
||||
|
||||
def construct(self, input_x):
|
||||
output = self.reshape(input_x, (self.batch_size, -1))
|
||||
|
@ -47,18 +48,15 @@ class NetAdamWeightDecay(nn.Cell):
|
|||
def test_adam_weight_decay():
|
||||
epoch = 3
|
||||
net = NetAdamWeightDecay()
|
||||
optimizer = AdamWeightDecayOp(filter(lambda x: x.requires_grad,
|
||||
net.get_parameters()), learning_rate=0.01)
|
||||
optimizer = AdamWeightDecayOp(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=0.01)
|
||||
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
net_with_criterion = WithLossCell(net, criterion)
|
||||
train_network = TrainOneStepCell(
|
||||
net_with_criterion, optimizer)
|
||||
train_network = TrainOneStepCell(net_with_criterion, optimizer)
|
||||
train_network.set_train()
|
||||
|
||||
losses1 = []
|
||||
for _ in range(epoch):
|
||||
data = Tensor(np.arange(0, 16).reshape(
|
||||
1, 1, 4, 4).astype(np.float32) * 0.01)
|
||||
data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01)
|
||||
label = Tensor(np.array([0]).astype(np.int32))
|
||||
loss = train_network(data, label)
|
||||
losses1.append(loss.asnumpy())
|
||||
|
|
Loading…
Reference in New Issue