forked from mindspore-Ecosystem/mindspore
add validation
This commit is contained in:
parent
2f14c40934
commit
8cc5767d0c
|
@ -179,7 +179,7 @@ class TrainOneStepWithLossScaleCell(Cell):
|
|||
network (Cell): The training network. The network only supports single output.
|
||||
optimizer (Cell): Optimizer for updating the weights.
|
||||
scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value
|
||||
is Tensor type, Tensor with shape :math:`()`. Default: None.
|
||||
is Tensor type, Tensor with shape :math:`()`.
|
||||
|
||||
Inputs:
|
||||
- **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
|
||||
|
@ -189,6 +189,7 @@ class TrainOneStepWithLossScaleCell(Cell):
|
|||
|
||||
- **loss** (Tensor) - Tensor with shape :math:`()`.
|
||||
- **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
|
||||
- **loss scaling value** (Tensor) - Tensor with shape :math:`()`
|
||||
|
||||
Examples:
|
||||
>>> net_with_loss = Net()
|
||||
|
@ -203,7 +204,7 @@ class TrainOneStepWithLossScaleCell(Cell):
|
|||
>>> output = train_network(inputs, label, scaling_sens)
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, scale_sense=None):
|
||||
def __init__(self, network, optimizer, scale_sense):
|
||||
super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
|
@ -236,14 +237,15 @@ class TrainOneStepWithLossScaleCell(Cell):
|
|||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
|
||||
|
||||
self.scale_sense = None
|
||||
self.loss_scaling_manager = None
|
||||
if isinstance(scale_sense, Cell):
|
||||
self.loss_scaling_manager = scale_sense
|
||||
self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
|
||||
name="scale_sense")
|
||||
if isinstance(scale_sense, Tensor):
|
||||
elif isinstance(scale_sense, Tensor):
|
||||
self.scale_sense = Parameter(scale_sense, name='scale_sense')
|
||||
else:
|
||||
raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
|
||||
|
||||
@C.add_flags(has_effect=True)
|
||||
def construct(self, *inputs):
|
||||
|
@ -293,4 +295,6 @@ class TrainOneStepWithLossScaleCell(Cell):
|
|||
"""If the user has set the sens in the training process and wants to reassign the value, he can call
|
||||
this function again to make modification, and sens needs to be of type Tensor."""
|
||||
if self.scale_sense and isinstance(sens, Tensor):
|
||||
self.self.scale_sense.set_data(sens)
|
||||
self.scale_sense.set_data(sens)
|
||||
else:
|
||||
raise TypeError("The input type must be Tensor,but got {}".format(type(sens)))
|
||||
|
|
Loading…
Reference in New Issue