forked from mindspore-Ecosystem/mindspore
Fixing problem issues including class slice example cannot run, adding an example for class SigmoidCrossEntropyWithLogits etc.
This commit is contained in:
parent
8f6b941a97
commit
4ba6f7884d
|
@ -81,8 +81,22 @@ class Optimizer(Cell):
|
|||
else:
|
||||
raise TypeError("Learning rate should be float, Tensor or Iterable.")
|
||||
|
||||
if isinstance(weight_decay, int):
|
||||
weight_decay = float(weight_decay)
|
||||
|
||||
if not isinstance(weight_decay, float):
|
||||
raise TypeError("weight_decay should be a float number!")
|
||||
|
||||
if isinstance(loss_scale, int):
|
||||
loss_scale = float(loss_scale)
|
||||
|
||||
if not isinstance(loss_scale, float):
|
||||
raise TypeError("loss_scale should be a float number!")
|
||||
|
||||
if loss_scale <= 0.0:
|
||||
raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale))
|
||||
self.loss_scale = loss_scale
|
||||
|
||||
if weight_decay < 0.0:
|
||||
raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay))
|
||||
|
||||
|
|
|
@ -61,7 +61,8 @@ class SGD(Optimizer):
|
|||
dampening (float): A floating point value of dampening for momentum. Default: 0.
|
||||
weight_decay (float): Weight decay (L2 penalty). Default: 0.
|
||||
nesterov (bool): Enables the Nesterov momentum. Default: False.
|
||||
loss_scale (float): A floating point value for the loss scale. Default: 1.0.
|
||||
loss_scale (float): A floating point value for the loss scale, which should be larger
|
||||
than 0.0. Default: 1.0.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
@ -83,9 +84,18 @@ class SGD(Optimizer):
|
|||
|
||||
super(SGD, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
|
||||
if not isinstance(momentum, float):
|
||||
raise TypeError("momentum should be float number!")
|
||||
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
|
||||
if not isinstance(dampening, float):
|
||||
raise TypeError("dampening should be float number")
|
||||
|
||||
if isinstance(dampening, int):
|
||||
dampening = float(dampening)
|
||||
|
||||
if dampening < 0.0:
|
||||
raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
|
||||
self.dampening = dampening
|
||||
|
|
|
@ -1008,6 +1008,7 @@ class Argmax(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_subclass("input_x", x_dtype, mstype.tensor)
|
||||
validator.check_typename('input_x', x_dtype, [mstype.float32, mstype.float16])
|
||||
return mstype.tensor_type(self.output_type)
|
||||
|
||||
|
||||
|
@ -1500,7 +1501,9 @@ class Slice(PrimitiveWithInfer):
|
|||
Tensor.
|
||||
|
||||
Examples:
|
||||
>>> data = Tensor(np.array([3,2,3]).astype(np.int32))
|
||||
>>> data = Tensor(np.array([[[1, 1, 1], [2, 2, 2]],
|
||||
>>> [[3, 3, 3], [4, 4, 4]],
|
||||
>>> [[5, 5, 5], [6, 6, 6]]]).astype(np.int32))
|
||||
>>> type = P.Slice()(data, (1, 0, 0), (1, 1, 3))
|
||||
"""
|
||||
|
||||
|
|
|
@ -1436,9 +1436,9 @@ class SGD(PrimitiveWithInfer):
|
|||
nesterov (bool): Enable Nesterov momentum. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **parameters** (Tensor) - Parameters to be updated.
|
||||
- **parameters** (Tensor) - Parameters to be updated. Their data type can be list or tuple.
|
||||
- **gradient** (Tensor) - Gradients.
|
||||
- **learning_rate** (Tensor) - Learning rate. e.g. Tensor(0.1, mindspore.float32).
|
||||
- **learning_rate** (Tensor) - Learning rate. Must be float value. e.g. Tensor(0.1, mindspore.float32).
|
||||
- **accum** (Tensor) - Accum(velocity) to be updated.
|
||||
- **momentum** (Tensor) - Momentum. e.g. Tensor(0.1, mindspore.float32).
|
||||
- **stat** (Tensor) - States to be updated with the same shape as gradient.
|
||||
|
@ -1449,6 +1449,7 @@ class SGD(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self, dampening=0.0, weight_decay=0.0, nesterov=False):
|
||||
validator.check_type("nesterov", nesterov, [bool])
|
||||
self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'],
|
||||
outputs=['output'])
|
||||
|
||||
|
|
Loading…
Reference in New Issue