forked from mindspore-Ecosystem/mindspore
!6144 add validation
Merge pull request !6144 from lijiaqi/add_validation
This commit is contained in:
commit
2fb2228af3
|
@ -179,7 +179,7 @@ class TrainOneStepWithLossScaleCell(Cell):
|
||||||
network (Cell): The training network. The network only supports single output.
|
network (Cell): The training network. The network only supports single output.
|
||||||
optimizer (Cell): Optimizer for updating the weights.
|
optimizer (Cell): Optimizer for updating the weights.
|
||||||
scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value
|
scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value
|
||||||
is Tensor type, Tensor with shape :math:`()`. Default: None.
|
is Tensor type, Tensor with shape :math:`()`.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
|
- **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
|
||||||
|
@ -189,6 +189,7 @@ class TrainOneStepWithLossScaleCell(Cell):
|
||||||
|
|
||||||
- **loss** (Tensor) - Tensor with shape :math:`()`.
|
- **loss** (Tensor) - Tensor with shape :math:`()`.
|
||||||
- **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
|
- **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
|
||||||
|
- **loss scaling value** (Tensor) - Tensor with shape :math:`()`
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> net_with_loss = Net()
|
>>> net_with_loss = Net()
|
||||||
|
@ -203,7 +204,7 @@ class TrainOneStepWithLossScaleCell(Cell):
|
||||||
>>> output = train_network(inputs, label, scaling_sens)
|
>>> output = train_network(inputs, label, scaling_sens)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, network, optimizer, scale_sense=None):
|
def __init__(self, network, optimizer, scale_sense):
|
||||||
super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||||
self.network = network
|
self.network = network
|
||||||
self.network.set_grad()
|
self.network.set_grad()
|
||||||
|
@ -236,14 +237,15 @@ class TrainOneStepWithLossScaleCell(Cell):
|
||||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||||
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
|
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
|
||||||
|
|
||||||
self.scale_sense = None
|
|
||||||
self.loss_scaling_manager = None
|
self.loss_scaling_manager = None
|
||||||
if isinstance(scale_sense, Cell):
|
if isinstance(scale_sense, Cell):
|
||||||
self.loss_scaling_manager = scale_sense
|
self.loss_scaling_manager = scale_sense
|
||||||
self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
|
self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
|
||||||
name="scale_sense")
|
name="scale_sense")
|
||||||
if isinstance(scale_sense, Tensor):
|
elif isinstance(scale_sense, Tensor):
|
||||||
self.scale_sense = Parameter(scale_sense, name='scale_sense')
|
self.scale_sense = Parameter(scale_sense, name='scale_sense')
|
||||||
|
else:
|
||||||
|
raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
|
||||||
|
|
||||||
@C.add_flags(has_effect=True)
|
@C.add_flags(has_effect=True)
|
||||||
def construct(self, *inputs):
|
def construct(self, *inputs):
|
||||||
|
@ -293,4 +295,6 @@ class TrainOneStepWithLossScaleCell(Cell):
|
||||||
"""If the user has set the sens in the training process and wants to reassign the value, he can call
|
"""If the user has set the sens in the training process and wants to reassign the value, he can call
|
||||||
this function again to make modification, and sens needs to be of type Tensor."""
|
this function again to make modification, and sens needs to be of type Tensor."""
|
||||||
if self.scale_sense and isinstance(sens, Tensor):
|
if self.scale_sense and isinstance(sens, Tensor):
|
||||||
self.self.scale_sense.set_data(sens)
|
self.scale_sense.set_data(sens)
|
||||||
|
else:
|
||||||
|
raise TypeError("The input type must be Tensor,but got {}".format(type(sens)))
|
||||||
|
|
Loading…
Reference in New Issue