forked from mindspore-Ecosystem/mindspore
!18015 [clean code] nn-opt/wrap
Merge pull request !18015 from kingxian/master
This commit is contained in:
commit
465d7f84a5
|
@ -365,15 +365,16 @@ class Adam(Optimizer):
|
|||
|
||||
@Optimizer.target.setter
|
||||
def target(self, value):
|
||||
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||
optimizer operation."""
|
||||
"""
|
||||
If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||
optimizer operation."""
|
||||
if not isinstance(value, str):
|
||||
raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
|
||||
|
||||
if value not in ('CPU', 'Ascend', 'GPU'):
|
||||
raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))
|
||||
|
||||
if self._target == "CPU" and value in('Ascend', 'GPU'):
|
||||
if self._target == "CPU" and value in ('Ascend', 'GPU'):
|
||||
raise ValueError("In the CPU environment, target cannot be set to 'GPU' and 'Ascend'.")
|
||||
|
||||
if self._target == "Ascend" and value == 'GPU':
|
||||
|
|
|
@ -234,8 +234,9 @@ class FTRL(Optimizer):
|
|||
|
||||
@Optimizer.target.setter
|
||||
def target(self, value):
|
||||
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||
optimizer operation."""
|
||||
"""
|
||||
If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||
optimizer operation."""
|
||||
if not isinstance(value, str):
|
||||
raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
|
||||
|
||||
|
|
|
@ -283,8 +283,9 @@ class LazyAdam(Optimizer):
|
|||
|
||||
@Optimizer.target.setter
|
||||
def target(self, value):
|
||||
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||
optimizer operation."""
|
||||
"""
|
||||
If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||
optimizer operation."""
|
||||
if not isinstance(value, str):
|
||||
raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
|
||||
|
||||
|
|
|
@ -236,14 +236,16 @@ class Optimizer(Cell):
|
|||
|
||||
@property
|
||||
def target(self):
|
||||
"""The method is used to determine whether the parameter is updated on host or device. The input type is str
|
||||
and can only be 'CPU', 'Ascend' or 'GPU'."""
|
||||
"""
|
||||
The method is used to determine whether the parameter is updated on host or device. The input type is str
|
||||
and can only be 'CPU', 'Ascend' or 'GPU'."""
|
||||
return self._target
|
||||
|
||||
@target.setter
|
||||
def target(self, value):
|
||||
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||
optimizer operation."""
|
||||
"""
|
||||
If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||
optimizer operation."""
|
||||
raise NotImplementedError
|
||||
|
||||
def decay_weight(self, gradients):
|
||||
|
|
|
@ -183,8 +183,9 @@ class ProximalAdagrad(Optimizer):
|
|||
|
||||
@Optimizer.target.setter
|
||||
def target(self, value):
|
||||
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||
optimizer operation."""
|
||||
"""
|
||||
If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||
optimizer operation."""
|
||||
if not isinstance(value, str):
|
||||
raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
|
||||
|
||||
|
|
|
@ -326,7 +326,8 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|||
return loss, cond, scaling_sens
|
||||
|
||||
def set_sense_scale(self, sens):
|
||||
"""If the user has set the sens in the training process and wants to reassign the value, he can call
|
||||
"""
|
||||
If the user has set the sens in the training process and wants to reassign the value, he can call
|
||||
this function again to make modification, and sens needs to be of type Tensor."""
|
||||
if self.scale_sense and isinstance(sens, Tensor):
|
||||
self.scale_sense.set_data(sens)
|
||||
|
|
Loading…
Reference in New Issue