forked from mindspore-Ecosystem/mindspore
fixed LeakyReLU, Optimizer
This commit is contained in:
parent
aef9c4d838
commit
eb4571a67f
|
@ -250,7 +250,7 @@ class LeakyReLU(Cell):
|
|||
|
||||
def construct(self, x):
|
||||
alpha = P.Cast()(F.scalar_to_array(self.alpha), P.DType()(x))
|
||||
if self.alpha <= 1:
|
||||
if alpha <= 1:
|
||||
out = P.Maximum()(alpha * x, x)
|
||||
else:
|
||||
out = P.Minimum()(alpha * x, x)
|
||||
|
|
|
@ -45,8 +45,8 @@ class Embedding(Cell):
|
|||
|
||||
Inputs:
|
||||
- **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The element of
|
||||
the Tensor should be integer and not larger than vocab_size. else the corresponding embedding vector is zero
|
||||
if larger than vocab_size.
|
||||
the Tensor should be integer and not larger than vocab_size. else the corresponding embedding vector is zero
|
||||
if larger than vocab_size.
|
||||
Outputs:
|
||||
Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`.
|
||||
|
||||
|
|
|
@ -93,13 +93,13 @@ class Optimizer(Cell):
|
|||
|
||||
if isinstance(loss_scale, int):
|
||||
loss_scale = float(loss_scale)
|
||||
validator.check_value_type("loss_scale", loss_scale, [float], None)
|
||||
validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, None)
|
||||
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
|
||||
validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name)
|
||||
|
||||
if isinstance(weight_decay, int):
|
||||
weight_decay = float(weight_decay)
|
||||
validator.check_value_type("weight_decay", weight_decay, [float], None)
|
||||
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None)
|
||||
validator.check_value_type("weight_decay", weight_decay, [float], self.cls_name)
|
||||
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
|
||||
|
||||
self.is_group = False
|
||||
self.is_group_lr = False
|
||||
|
|
Loading…
Reference in New Issue