From 8cc5767d0c4bf6a665e02951ea6d7452afcaeea9 Mon Sep 17 00:00:00 2001
From: Jiaqi
Date: Mon, 14 Sep 2020 10:18:31 +0800
Subject: [PATCH] add validation

---
 mindspore/nn/wrap/loss_scale.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py
index 1dfc91743c8..dce621a7650 100644
--- a/mindspore/nn/wrap/loss_scale.py
+++ b/mindspore/nn/wrap/loss_scale.py
@@ -179,7 +179,7 @@ class TrainOneStepWithLossScaleCell(Cell):
         network (Cell): The training network. The network only supports single output.
         optimizer (Cell): Optimizer for updating the weights.
         scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value
-            is Tensor type, Tensor with shape :math:`()`. Default: None.
+            is Tensor type, Tensor with shape :math:`()`.
 
     Inputs:
         - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
@@ -189,6 +189,7 @@ class TrainOneStepWithLossScaleCell(Cell):
 
         - **loss** (Tensor) - Tensor with shape :math:`()`.
         - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
+        - **loss scaling value** (Tensor) - Tensor with shape :math:`()`
 
     Examples:
         >>> net_with_loss = Net()
@@ -203,7 +204,7 @@ class TrainOneStepWithLossScaleCell(Cell):
         >>> output = train_network(inputs, label, scaling_sens)
     """
 
-    def __init__(self, network, optimizer, scale_sense=None):
+    def __init__(self, network, optimizer, scale_sense):
         super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
@@ -236,14 +237,15 @@ class TrainOneStepWithLossScaleCell(Cell):
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
 
-        self.scale_sense = None
         self.loss_scaling_manager = None
         if isinstance(scale_sense, Cell):
             self.loss_scaling_manager = scale_sense
             self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
                                          name="scale_sense")
-        if isinstance(scale_sense, Tensor):
+        elif isinstance(scale_sense, Tensor):
             self.scale_sense = Parameter(scale_sense, name='scale_sense')
+        else:
+            raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
 
     @C.add_flags(has_effect=True)
     def construct(self, *inputs):
@@ -293,4 +295,6 @@ class TrainOneStepWithLossScaleCell(Cell):
         """If the user has set the sens in the training process and wants to reassign the value, he can call this
            function again to make modification, and sens needs to be of type Tensor."""
         if self.scale_sense and isinstance(sens, Tensor):
-            self.self.scale_sense.set_data(sens)
+            self.scale_sense.set_data(sens)
+        else:
+            raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))
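
Note (not part of the patch): a minimal usage sketch of the changed constructor, assuming a placeholder `Net` network and the standard `nn.Momentum`, `nn.WithLossCell`, and `nn.DynamicLossScaleUpdateCell` APIs from MindSpore. It illustrates the two `scale_sense` types the constructor now accepts and the `TypeError` the new validation raises for anything else.

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype

# Placeholder network/loss/optimizer setup mirroring the docstring example; Net is assumed.
net_with_loss = nn.WithLossCell(Net(), nn.SoftmaxCrossEntropyWithLogits())
optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)

# Option 1: fixed loss scale, scale_sense is a scalar Tensor with shape ().
scaling_sens = Tensor(np.full((), 1024.0), dtype=mstype.float32)
train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)

# Option 2: dynamic loss scale, scale_sense is a loss-scale update cell.
manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=1024.0, scale_factor=2, scale_window=1000)
train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)

# Any other type now fails fast with an explicit error:
# TypeError: The scale_sense must be Cell or Tensor, but got <class 'float'>
train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=1024.0)
```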