fix thor optimizer interface
This commit is contained in:
parent
341851b931
commit
3349d4372b
|
@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
|
|||
from mindspore._checkparam import check_bool
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
|
||||
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
|
||||
from src.grad_reducer_thor import DistributedGradReducerThor
|
||||
|
||||
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
|
||||
|
@ -85,7 +85,7 @@ class THOR_GPU(Optimizer):
|
|||
self.assign = P.Assign()
|
||||
self.mul = P.Mul()
|
||||
|
||||
mean = _get_mirror_mean()
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
|
||||
parameter_length = len(self.feature_map)
|
||||
|
@ -193,7 +193,7 @@ class THOR(Optimizer):
|
|||
1.0 / 196, 1.0 / 196, 1.0 / 196,
|
||||
1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
|
||||
1.0]
|
||||
mean = _get_mirror_mean()
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
parameter_length = len(self.feature_map)
|
||||
self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree)
|
||||
|
|
Loading…
Reference in New Issue