!6144 add validation

Merge pull request !6144 from lijiaqi/add_validation
mindspore-ci-bot 2020-09-14 18:49:29 +08:00 committed by Gitee
commit 2fb2228af3
1 changed file with 9 additions and 5 deletions


@@ -179,7 +179,7 @@ class TrainOneStepWithLossScaleCell(Cell):
network (Cell): The training network. The network only supports single output.
optimizer (Cell): Optimizer for updating the weights.
scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell. If this value
is Tensor type, Tensor with shape :math:`()`. Default: None.
is Tensor type, Tensor with shape :math:`()`.
Inputs:
- **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
@@ -189,6 +189,7 @@ class TrainOneStepWithLossScaleCell(Cell):
- **loss** (Tensor) - Tensor with shape :math:`()`.
- **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
- **loss scaling value** (Tensor) - Tensor with shape :math:`()`.
Examples:
>>> net_with_loss = Net()
@@ -203,7 +204,7 @@ class TrainOneStepWithLossScaleCell(Cell):
>>> output = train_network(inputs, label, scaling_sens)
"""
def __init__(self, network, optimizer, scale_sense=None):
def __init__(self, network, optimizer, scale_sense):
super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
@@ -236,14 +237,15 @@ class TrainOneStepWithLossScaleCell(Cell):
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
self.scale_sense = None
self.loss_scaling_manager = None
if isinstance(scale_sense, Cell):
self.loss_scaling_manager = scale_sense
self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
name="scale_sense")
if isinstance(scale_sense, Tensor):
elif isinstance(scale_sense, Tensor):
self.scale_sense = Parameter(scale_sense, name='scale_sense')
else:
raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
@C.add_flags(has_effect=True)
def construct(self, *inputs):
@@ -293,4 +295,6 @@ class TrainOneStepWithLossScaleCell(Cell):
"""If the user has set the sens in the training process and wants to reassign the value, he can call
this function again to make modification, and sens needs to be of type Tensor."""
if self.scale_sense and isinstance(sens, Tensor):
self.self.scale_sense.set_data(sens)
self.scale_sense.set_data(sens)
else:
raise TypeError("The input type must be Tensor,but got {}".format(type(sens)))