forked from mindspore-Ecosystem/mindspore
!17491 replace internal interface from models ncf, centerface and mass
From: @zhouneng2 Reviewed-by: @oacjiewen,@c_34 Signed-off-by: @c_34
commit 9cb53c0a5e
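The change is the same across all three models: calls into the private helpers in mindspore.parallel._utils and the private _InternalCallbackParam class are replaced with public equivalents. A minimal sketch of the parallel-context part of the pattern, assembled from the hunks below rather than copied from any single file:

from mindspore import context
from mindspore.context import ParallelMode

# Old (internal) style, removed by this commit:
#   from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
#   parallel_mode = _get_parallel_mode()
# New (public) style:
parallel_mode = context.get_auto_parallel_context("parallel_mode")
reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
if reducer_flag:
    mean = context.get_auto_parallel_context("gradients_mean")
    degree = context.get_auto_parallel_context("device_num")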
@@ -30,7 +30,7 @@ from mindspore.nn.optim.sgd import SGD
 from mindspore import Tensor
 from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.train.callback import ModelCheckpoint, RunContext
-from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
+from mindspore.train.callback import CheckpointConfig
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.profiler.profiling import Profiler
 from mindspore.common import set_seed
@@ -125,6 +125,16 @@ def convert_training_shape(args_):
     return training_shape


+class InternalCallbackParam(dict):
+    """Internal callback object's parameters."""
+
+    def __getattr__(self, para_name):
+        return self[para_name]
+
+    def __setattr__(self, para_name, para_value):
+        self[para_name] = para_value
+
+
 if __name__ == "__main__":
     # init distributed
     if args.is_distributed:
@@ -285,7 +295,7 @@ if __name__ == "__main__":
         ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                   directory=args.outputs_dir,
                                   prefix='{}'.format(args.rank))
-        cb_params = _InternalCallbackParam()
+        cb_params = InternalCallbackParam()
         cb_params.train_network = network
         cb_params.epoch_num = ckpt_max_num
         cb_params.cur_epoch_num = 1
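The new InternalCallbackParam defined above is a plain dict with attribute access, so the manual checkpoint loop in centerface can keep writing cb_params.train_network and so on; only the class it instantiates changes. A small standalone illustration of that behaviour (not taken from train.py):

class InternalCallbackParam(dict):
    """Internal callback object's parameters."""

    def __getattr__(self, para_name):
        return self[para_name]

    def __setattr__(self, para_name, para_value):
        self[para_name] = para_value


cb_params = InternalCallbackParam()
cb_params.cur_epoch_num = 1            # attribute write stores a dict entry
assert cb_params["cur_epoch_num"] == 1
assert cb_params.cur_epoch_num == 1    # attribute read looks the same entry up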
@@ -23,7 +23,6 @@ from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean

 from .transformer import Transformer
 from .grad_clip import GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE, ClipGradients
@@ -245,15 +244,15 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
         self.reducer_flag = False
         self.all_reduce = P.AllReduce()

-        self.parallel_mode = _get_parallel_mode()
+        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if self.parallel_mode not in ParallelMode.MODE_LIST:
             raise ValueError("Parallel mode does not support: ", self.parallel_mode)
         if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
             self.reducer_flag = True
         self.grad_reducer = None
         if self.reducer_flag:
-            mean = _get_gradients_mean()
-            degree = _get_device_num()
+            mean = context.get_auto_parallel_context("gradients_mean")
+            degree = context.get_auto_parallel_context("device_num")
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
         self.clip_gradients = ClipGradients()
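For reference, the values returned by these public getters are whatever the training script configured during startup; the real scripts call init() and set_auto_parallel_context() with model-specific arguments. A hypothetical round trip, just to show where the queried values come from:

from mindspore import context
from mindspore.context import ParallelMode

# Hypothetical data-parallel setup over 8 devices (illustrative values only).
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True,
                                  device_num=8)

assert context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL
mean = context.get_auto_parallel_context("gradients_mean")    # -> True
degree = context.get_auto_parallel_context("device_num")      # -> 8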
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Neural Collaborative Filtering Model"""
-from mindspore import nn
+from mindspore import nn, context
 from mindspore import Tensor, Parameter, ParameterTuple
 from mindspore._checkparam import Validator as validator
 from mindspore.nn.layer.activation import get_activation
@@ -22,7 +22,6 @@ from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
 from mindspore.context import ParallelMode
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

@@ -245,12 +244,12 @@ class TrainStepWrap(nn.Cell):

         self.reducer_flag = False
         self.grad_reducer = None
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
-            mean = _get_gradients_mean()
-            degree = _get_device_num()
+            mean = context.get_auto_parallel_context("gradients_mean")
+            degree = context.get_auto_parallel_context("device_num")
             self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)

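In both train-step cells the reducer built from these values is applied to the gradient tuple in construct(); the guarded call typically looks like the sketch below (the surrounding gradient computation is assumed and is not part of this diff):

# Inside construct(), after grads have been computed by the grad operation:
if self.reducer_flag:
    # all-reduce the gradients across devices in data/hybrid parallel mode
    grads = self.grad_reducer(grads)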