fix nasnet built in function

This commit is contained in:
panfengfeng 2021-06-14 03:54:20 +08:00
parent af16989394
commit 9ef713119b
1 changed files with 4 additions and 4 deletions

View File

@ -15,6 +15,7 @@
"""NASNet-A-Mobile model definition"""
import numpy as np
from mindspore import context
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.nn.loss.loss import Loss
@ -24,7 +25,6 @@ import mindspore.ops.composite as C
import mindspore.common.dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
GRADIENT_CLIP_TYPE = 1
@ -917,12 +917,12 @@ class NASNetAMobileTrainOneStepWithClipGradient(nn.Cell):
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
parallel_mode = _get_parallel_mode()
parallel_mode = context.get_auto_parallel_context("parallel_mode")
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
mean = context.get_auto_parallel_context("gradients_mean")
degree = context.get_auto_parallel_context("device_num")
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, *inputs):