forked from mindspore-Ecosystem/mindspore
Fix NASNet built-in function
This commit is contained in:
parent
af16989394
commit
9ef713119b
|
@ -15,6 +15,7 @@
|
||||||
"""NASNet-A-Mobile model definition"""
|
"""NASNet-A-Mobile model definition"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import context
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore.nn.loss.loss import Loss
|
from mindspore.nn.loss.loss import Loss
|
||||||
|
@ -24,7 +25,6 @@ import mindspore.ops.composite as C
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
|
|
||||||
|
|
||||||
|
|
||||||
GRADIENT_CLIP_TYPE = 1
|
GRADIENT_CLIP_TYPE = 1
|
||||||
|
@ -917,12 +917,12 @@ class NASNetAMobileTrainOneStepWithClipGradient(nn.Cell):
|
||||||
self.sens = sens
|
self.sens = sens
|
||||||
self.reducer_flag = False
|
self.reducer_flag = False
|
||||||
self.grad_reducer = None
|
self.grad_reducer = None
|
||||||
parallel_mode = _get_parallel_mode()
|
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||||
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||||
self.reducer_flag = True
|
self.reducer_flag = True
|
||||||
if self.reducer_flag:
|
if self.reducer_flag:
|
||||||
mean = _get_gradients_mean()
|
mean = context.get_auto_parallel_context("gradients_mean")
|
||||||
degree = _get_device_num()
|
degree = context.get_auto_parallel_context("device_num")
|
||||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||||
|
|
||||||
def construct(self, *inputs):
|
def construct(self, *inputs):
|
||||||
|
|
Loading…
Reference in New Issue