forked from mindspore-Ecosystem/mindspore
!18015 [clean code] nn-opt/wrap
Merge pull request !18015 from kingxian/master
This commit is contained in:
commit
465d7f84a5
|
@ -365,15 +365,16 @@ class Adam(Optimizer):
|
||||||
|
|
||||||
@Optimizer.target.setter
|
@Optimizer.target.setter
|
||||||
def target(self, value):
|
def target(self, value):
|
||||||
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
"""
|
||||||
optimizer operation."""
|
If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||||
|
optimizer operation."""
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
|
raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
|
||||||
|
|
||||||
if value not in ('CPU', 'Ascend', 'GPU'):
|
if value not in ('CPU', 'Ascend', 'GPU'):
|
||||||
raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))
|
raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))
|
||||||
|
|
||||||
if self._target == "CPU" and value in('Ascend', 'GPU'):
|
if self._target == "CPU" and value in ('Ascend', 'GPU'):
|
||||||
raise ValueError("In the CPU environment, target cannot be set to 'GPU' and 'Ascend'.")
|
raise ValueError("In the CPU environment, target cannot be set to 'GPU' and 'Ascend'.")
|
||||||
|
|
||||||
if self._target == "Ascend" and value == 'GPU':
|
if self._target == "Ascend" and value == 'GPU':
|
||||||
|
|
|
@ -234,8 +234,9 @@ class FTRL(Optimizer):
|
||||||
|
|
||||||
@Optimizer.target.setter
|
@Optimizer.target.setter
|
||||||
def target(self, value):
|
def target(self, value):
|
||||||
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
"""
|
||||||
optimizer operation."""
|
If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||||
|
optimizer operation."""
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
|
raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
|
||||||
|
|
||||||
|
|
|
@ -283,8 +283,9 @@ class LazyAdam(Optimizer):
|
||||||
|
|
||||||
@Optimizer.target.setter
|
@Optimizer.target.setter
|
||||||
def target(self, value):
|
def target(self, value):
|
||||||
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
"""
|
||||||
optimizer operation."""
|
If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||||
|
optimizer operation."""
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
|
raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
|
||||||
|
|
||||||
|
|
|
@ -236,14 +236,16 @@ class Optimizer(Cell):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def target(self):
|
def target(self):
|
||||||
"""The method is used to determine whether the parameter is updated on host or device. The input type is str
|
"""
|
||||||
and can only be 'CPU', 'Ascend' or 'GPU'."""
|
The method is used to determine whether the parameter is updated on host or device. The input type is str
|
||||||
|
and can only be 'CPU', 'Ascend' or 'GPU'."""
|
||||||
return self._target
|
return self._target
|
||||||
|
|
||||||
@target.setter
|
@target.setter
|
||||||
def target(self, value):
|
def target(self, value):
|
||||||
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
"""
|
||||||
optimizer operation."""
|
If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||||
|
optimizer operation."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def decay_weight(self, gradients):
|
def decay_weight(self, gradients):
|
||||||
|
|
|
@ -183,8 +183,9 @@ class ProximalAdagrad(Optimizer):
|
||||||
|
|
||||||
@Optimizer.target.setter
|
@Optimizer.target.setter
|
||||||
def target(self, value):
|
def target(self, value):
|
||||||
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
"""
|
||||||
optimizer operation."""
|
If the input value is set to "CPU", the parameters will be updated on the host using the Fused
|
||||||
|
optimizer operation."""
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
|
raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
|
||||||
|
|
||||||
|
|
|
@ -326,7 +326,8 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
||||||
return loss, cond, scaling_sens
|
return loss, cond, scaling_sens
|
||||||
|
|
||||||
def set_sense_scale(self, sens):
|
def set_sense_scale(self, sens):
|
||||||
"""If the user has set the sens in the training process and wants to reassign the value, he can call
|
"""
|
||||||
|
If the user has set the sens in the training process and wants to reassign the value, he can call
|
||||||
this function again to make modification, and sens needs to be of type Tensor."""
|
this function again to make modification, and sens needs to be of type Tensor."""
|
||||||
if self.scale_sense and isinstance(sens, Tensor):
|
if self.scale_sense and isinstance(sens, Tensor):
|
||||||
self.scale_sense.set_data(sens)
|
self.scale_sense.set_data(sens)
|
||||||
|
|
Loading…
Reference in New Issue