fix nasnet built in function

This commit is contained in:
panfengfeng 2021-06-14 03:54:20 +08:00
parent af16989394
commit 9ef713119b
1 changed files with 4 additions and 4 deletions

View File

@ -15,6 +15,7 @@
"""NASNet-A-Mobile model definition""" """NASNet-A-Mobile model definition"""
import numpy as np import numpy as np
from mindspore import context
from mindspore import Tensor from mindspore import Tensor
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.nn.loss.loss import Loss from mindspore.nn.loss.loss import Loss
@ -24,7 +25,6 @@ import mindspore.ops.composite as C
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
GRADIENT_CLIP_TYPE = 1 GRADIENT_CLIP_TYPE = 1
@ -917,12 +917,12 @@ class NASNetAMobileTrainOneStepWithClipGradient(nn.Cell):
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
parallel_mode = _get_parallel_mode() parallel_mode = context.get_auto_parallel_context("parallel_mode")
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True self.reducer_flag = True
if self.reducer_flag: if self.reducer_flag:
mean = _get_gradients_mean() mean = context.get_auto_parallel_context("gradients_mean")
degree = _get_device_num() degree = context.get_auto_parallel_context("device_num")
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, *inputs): def construct(self, *inputs):