From 9ef713119bbd3c99916a7977cb7d43561f6856d3 Mon Sep 17 00:00:00 2001 From: panfengfeng Date: Mon, 14 Jun 2021 03:54:20 +0800 Subject: [PATCH] fix nasnet built in function --- model_zoo/official/cv/nasnet/src/nasnet_a_mobile.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/model_zoo/official/cv/nasnet/src/nasnet_a_mobile.py b/model_zoo/official/cv/nasnet/src/nasnet_a_mobile.py index 83242fef1e6..22fbac224e3 100755 --- a/model_zoo/official/cv/nasnet/src/nasnet_a_mobile.py +++ b/model_zoo/official/cv/nasnet/src/nasnet_a_mobile.py @@ -15,6 +15,7 @@ """NASNet-A-Mobile model definition""" import numpy as np +from mindspore import context from mindspore import Tensor import mindspore.nn as nn from mindspore.nn.loss.loss import Loss @@ -24,7 +25,6 @@ import mindspore.ops.composite as C import mindspore.common.dtype as mstype from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.context import ParallelMode -from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean GRADIENT_CLIP_TYPE = 1 @@ -917,12 +917,12 @@ class NASNetAMobileTrainOneStepWithClipGradient(nn.Cell): self.sens = sens self.reducer_flag = False self.grad_reducer = None - parallel_mode = _get_parallel_mode() + parallel_mode = context.get_auto_parallel_context("parallel_mode") if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): self.reducer_flag = True if self.reducer_flag: - mean = _get_gradients_mean() - degree = _get_device_num() + mean = context.get_auto_parallel_context("gradients_mean") + degree = context.get_auto_parallel_context("device_num") self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) def construct(self, *inputs):