!18015 [clean code] nn-opt/wrap

Merge pull request !18015 from kingxian/master
This commit is contained in:
i-robot 2021-06-09 09:40:50 +08:00 committed by Gitee
commit 465d7f84a5
6 changed files with 21 additions and 14 deletions

View File

@ -365,15 +365,16 @@ class Adam(Optimizer):
@Optimizer.target.setter @Optimizer.target.setter
def target(self, value): def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused """
optimizer operation.""" If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
if not isinstance(value, str): if not isinstance(value, str):
raise TypeError("The value must be str type, but got value type is {}".format(type(value))) raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
if value not in ('CPU', 'Ascend', 'GPU'): if value not in ('CPU', 'Ascend', 'GPU'):
raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value)) raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))
if self._target == "CPU" and value in('Ascend', 'GPU'): if self._target == "CPU" and value in ('Ascend', 'GPU'):
raise ValueError("In the CPU environment, target cannot be set to 'GPU' and 'Ascend'.") raise ValueError("In the CPU environment, target cannot be set to 'GPU' and 'Ascend'.")
if self._target == "Ascend" and value == 'GPU': if self._target == "Ascend" and value == 'GPU':

View File

@ -234,8 +234,9 @@ class FTRL(Optimizer):
@Optimizer.target.setter @Optimizer.target.setter
def target(self, value): def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused """
optimizer operation.""" If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
if not isinstance(value, str): if not isinstance(value, str):
raise TypeError("The value must be str type, but got value type is {}".format(type(value))) raise TypeError("The value must be str type, but got value type is {}".format(type(value)))

View File

@ -283,8 +283,9 @@ class LazyAdam(Optimizer):
@Optimizer.target.setter @Optimizer.target.setter
def target(self, value): def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused """
optimizer operation.""" If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
if not isinstance(value, str): if not isinstance(value, str):
raise TypeError("The value must be str type, but got value type is {}".format(type(value))) raise TypeError("The value must be str type, but got value type is {}".format(type(value)))

View File

@ -236,14 +236,16 @@ class Optimizer(Cell):
@property @property
def target(self): def target(self):
"""The method is used to determine whether the parameter is updated on host or device. The input type is str """
and can only be 'CPU', 'Ascend' or 'GPU'.""" The method is used to determine whether the parameter is updated on host or device. The input type is str
and can only be 'CPU', 'Ascend' or 'GPU'."""
return self._target return self._target
@target.setter @target.setter
def target(self, value): def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused """
optimizer operation.""" If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
raise NotImplementedError raise NotImplementedError
def decay_weight(self, gradients): def decay_weight(self, gradients):

View File

@ -183,8 +183,9 @@ class ProximalAdagrad(Optimizer):
@Optimizer.target.setter @Optimizer.target.setter
def target(self, value): def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused """
optimizer operation.""" If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
if not isinstance(value, str): if not isinstance(value, str):
raise TypeError("The value must be str type, but got value type is {}".format(type(value))) raise TypeError("The value must be str type, but got value type is {}".format(type(value)))

View File

@ -326,7 +326,8 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
return loss, cond, scaling_sens return loss, cond, scaling_sens
def set_sense_scale(self, sens): def set_sense_scale(self, sens):
"""If the user has set the sens in the training process and wants to reassign the value, he can call """
If the user has set the sens in the training process and wants to reassign the value, he can call
this function again to make modification, and sens needs to be of type Tensor.""" this function again to make modification, and sens needs to be of type Tensor."""
if self.scale_sense and isinstance(sens, Tensor): if self.scale_sense and isinstance(sens, Tensor):
self.scale_sense.set_data(sens) self.scale_sense.set_data(sens)