Merge pull request !22984 from zhaosida/zsd_pangu
i-robot 2021-09-09 12:02:04 +00:00 committed by Gitee
commit 63e2098f83
7 changed files with 16 additions and 21 deletions

View File

@@ -103,7 +103,7 @@ bool AdamWeightDecayCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
}
if (inputs[VAR]->size != inputs[M]->size || inputs[VAR]->size != inputs[V]->size ||
inputs[VAR]->size != inputs[GRAD]->size) {
- MS_LOG(EXCEPTION) << "Error input data size!";
+ MS_LOG(EXCEPTION) << "Var, m, v, grad input data size must be same!";
}
if (inputs[LR]->size != kSizeFloat32 || inputs[BETA1]->size != kSizeFloat32 || inputs[BETA2]->size != kSizeFloat32 ||
inputs[EPSILON]->size != kSizeFloat32 || inputs[DECAY]->size != kSizeFloat32) {

View File

@@ -75,7 +75,7 @@ MS_REG_CPU_KERNEL(FusedCastAdamWeightDecay,
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
- .AddOutputAttr(kNumberTypeFloat32)
+ .AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedCastAdamWeightDecayCPUKernel)

View File

@@ -35,7 +35,7 @@ adam_weight_decay_op_info = CpuRegOp("AdamWeightDecay") \
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
- DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
+ DataType.F32_Default, DataType.F32_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.get_op_info()

View File

@@ -481,7 +481,7 @@ class FusedCastAdamWeightDecay(PrimitiveWithInfer):
:math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
`gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
- :math:`\lr` represents `learning_rate`, :math:`w` represents `var`, :math:`decay` represents `weight_decay`,
+ :math:`lr` represents `learning_rate`, :math:`w` represents `var`, :math:`decay` represents `weight_decay`,
:math:`\epsilon` represents `epsilon`.

Args:

@@ -547,12 +547,13 @@ class FusedCastAdamWeightDecay(PrimitiveWithInfer):
return var_shape, m_shape, v_shape
def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
- epsilon_dtype, decay, grad_dtype):
+ epsilon_dtype, decay_dtype, grad_dtype):
args = {"m": m_dtype, "v": v_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same({"var": var_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16], self.name)
- args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype, "decay": decay}
+ args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
+ "decay": decay_dtype}
validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
return var_dtype, m_dtype, v_dtype
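
The hunks above touch FusedCastAdamWeightDecay, which applies the decoupled-weight-decay Adam update with the float16/float32 casts fused into the operator, so the parameter and gradient may arrive as float16 while the moments stay float32 (the dtype checks above allow exactly that). As a reading aid only, here is a minimal NumPy sketch of such a mixed-precision step; the function name, the absence of bias correction, and the exact cast points are assumptions, not the MindSpore implementation.

import numpy as np

# Reference sketch, not the MindSpore kernel: float16 parameter/gradient,
# float32 moments, arithmetic carried out in float32.
def fused_cast_adam_weight_decay_step(w16, m, v, g16, lr, beta1, beta2, eps, decay):
    w = w16.astype(np.float32)             # cast of the fp16 parameter (fused in the real op)
    g = g16.astype(np.float32)             # cast of the fp16 gradient
    m = beta1 * m + (1.0 - beta1) * g      # 1st moment estimate
    v = beta2 * v + (1.0 - beta2) * g * g  # 2nd moment estimate
    w = w - lr * (m / (np.sqrt(v) + eps) + decay * w)
    return w.astype(np.float16), m, v      # parameter written back as float16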

View File

@@ -4573,7 +4573,7 @@ class AdamWeightDecay(PrimitiveWithInfer):
:math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
`gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
- :math:`\lr` represents `learning_rate`, :math:`w` represents `var`, :math:`decay` represents `weight_decay`,
+ :math:`lr` represents `learning_rate`, :math:`w` represents `var`, :math:`decay` represents `weight_decay`,
:math:`\epsilon` represents `epsilon`.

Args:

@@ -4643,11 +4643,12 @@ class AdamWeightDecay(PrimitiveWithInfer):
return var_shape, m_shape, v_shape
def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
- epsilon_dtype, decay, grad_dtype):
+ epsilon_dtype, decay_dtype, grad_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
- args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype, "decay": decay}
+ args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
+ "decay": decay_dtype}
validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
return var_dtype, m_dtype, v_dtype
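
The AdamWeightDecay docstring above only names the symbols; for reference, a generic decoupled-weight-decay Adam step can be written in a few lines of NumPy. This is an illustration of the usual formula with hypothetical names and default hyperparameters; whether the primitive applies bias correction is not visible in this diff.

import numpy as np

# Generic AdamWeightDecay step for reference; not taken from the MindSpore source.
def adam_weight_decay_step(w, m, v, g, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, decay=0.01):
    m = beta1 * m + (1.0 - beta1) * g      # update the 1st moment with the gradient
    v = beta2 * v + (1.0 - beta2) * g * g  # update the 2nd moment with the squared gradient
    update = m / (np.sqrt(v) + eps)        # Adam direction
    w = w - lr * (update + decay * w)      # weight decay applied to the weights, not to the gradient
    return w, m, v

With var, m, v and grad of the same shape and the scalars as float32, this mirrors the shape and dtype constraints enforced by infer_shape and infer_dtype above.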

View File

@@ -30,11 +30,6 @@ _adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
- op_assign = P.Assign()
- op_assign.add_prim_attr("primitive_target", "CPU")
- op_cast = P.Cast()
- op_cast.add_prim_attr("primitive_target", "CPU")
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")

View File

@@ -33,7 +33,8 @@ class NetAdamWeightDecay(nn.Cell):
self.batch_size = 1
self.reshape = P.Reshape()
weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01)
- self.fc1 = Dense(16, 10, weight_init=weight)
+ bias = Tensor(np.zeros(10).astype(np.float32))
+ self.fc1 = Dense(16, 10, weight_init=weight, bias_init=bias)
def construct(self, input_x):
output = self.reshape(input_x, (self.batch_size, -1))

@@ -47,18 +48,15 @@
def test_adam_weight_decay():
epoch = 3
net = NetAdamWeightDecay()
- optimizer = AdamWeightDecayOp(filter(lambda x: x.requires_grad,
- net.get_parameters()), learning_rate=0.01)
+ optimizer = AdamWeightDecayOp(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=0.01)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
- train_network = TrainOneStepCell(
- net_with_criterion, optimizer)
+ train_network = TrainOneStepCell(net_with_criterion, optimizer)
train_network.set_train()
losses1 = []
for _ in range(epoch):
- data = Tensor(np.arange(0, 16).reshape(
- 1, 1, 4, 4).astype(np.float32) * 0.01)
+ data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01)
label = Tensor(np.array([0]).astype(np.int32))
loss = train_network(data, label)
losses1.append(loss.asnumpy())