add validation

This commit is contained in:
Jiaqi 2020-09-14 10:18:31 +08:00
parent 2f14c40934
commit 8cc5767d0c
1 changed files with 9 additions and 5 deletions

View File

@ -179,7 +179,7 @@ class TrainOneStepWithLossScaleCell(Cell):
network (Cell): The training network. The network only supports single output.
optimizer (Cell): Optimizer for updating the weights.
scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value
is Tensor type, Tensor with shape :math:`()`. Default: None.
is Tensor type, Tensor with shape :math:`()`.
Inputs:
- **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
@ -189,6 +189,7 @@ class TrainOneStepWithLossScaleCell(Cell):
- **loss** (Tensor) - Tensor with shape :math:`()`.
- **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
- **loss scaling value** (Tensor) - Tensor with shape :math:`()`
Examples:
>>> net_with_loss = Net()
@ -203,7 +204,7 @@ class TrainOneStepWithLossScaleCell(Cell):
>>> output = train_network(inputs, label, scaling_sens)
"""
def __init__(self, network, optimizer, scale_sense=None):
def __init__(self, network, optimizer, scale_sense):
super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
@ -236,14 +237,15 @@ class TrainOneStepWithLossScaleCell(Cell):
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
self.scale_sense = None
self.loss_scaling_manager = None
if isinstance(scale_sense, Cell):
self.loss_scaling_manager = scale_sense
self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
name="scale_sense")
if isinstance(scale_sense, Tensor):
elif isinstance(scale_sense, Tensor):
self.scale_sense = Parameter(scale_sense, name='scale_sense')
else:
raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
@C.add_flags(has_effect=True)
def construct(self, *inputs):
@ -293,4 +295,6 @@ class TrainOneStepWithLossScaleCell(Cell):
"""If the user has set the sens in the training process and wants to reassign the value, he can call
this function again to make modification, and sens needs to be of type Tensor."""
if self.scale_sense and isinstance(sens, Tensor):
self.self.scale_sense.set_data(sens)
self.scale_sense.set_data(sens)
else:
raise TypeError("The input type must be Tensor,but got {}".format(type(sens)))