!17491 replace internal interfaces in models ncf, centerface and mass

From: @zhouneng2
Reviewed-by: @oacjiewen, @c_34
Signed-off-by: @c_34
mindspore-ci-bot 2021-06-02 16:36:45 +08:00 committed by Gitee
commit 9cb53c0a5e
3 changed files with 19 additions and 11 deletions

View File

@@ -30,7 +30,7 @@ from mindspore.nn.optim.sgd import SGD
 from mindspore import Tensor
 from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.train.callback import ModelCheckpoint, RunContext
-from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
+from mindspore.train.callback import CheckpointConfig
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.profiler.profiling import Profiler
 from mindspore.common import set_seed
@@ -125,6 +125,16 @@ def convert_training_shape(args_):
     return training_shape
+class InternalCallbackParam(dict):
+    """Internal callback object's parameters."""
+    def __getattr__(self, para_name):
+        return self[para_name]
+    def __setattr__(self, para_name, para_value):
+        self[para_name] = para_value
 if __name__ == "__main__":
     # init distributed
     if args.is_distributed:
@@ -285,7 +295,7 @@ if __name__ == "__main__":
         ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                   directory=args.outputs_dir,
                                   prefix='{}'.format(args.rank))
-        cb_params = _InternalCallbackParam()
+        cb_params = InternalCallbackParam()
         cb_params.train_network = network
         cb_params.epoch_num = ckpt_max_num
         cb_params.cur_epoch_num = 1
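
The InternalCallbackParam defined in this patch is a plain dict that exposes its keys as attributes, so the existing cb_params assignments keep working without the internal _InternalCallbackParam class. A minimal standalone sketch of that behaviour (the sample values below are illustrative, not part of the patch):

class InternalCallbackParam(dict):
    """Internal callback object's parameters."""
    def __getattr__(self, para_name):
        return self[para_name]
    def __setattr__(self, para_name, para_value):
        self[para_name] = para_value

cb_params = InternalCallbackParam()
cb_params.epoch_num = 10                      # stored as cb_params["epoch_num"]
assert cb_params["epoch_num"] == 10           # dict access still works
assert cb_params.epoch_num == 10              # attribute access reads the same entry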

View File

@@ -23,7 +23,6 @@ from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
 from .transformer import Transformer
 from .grad_clip import GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE, ClipGradients
@@ -245,15 +244,15 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
         self.reducer_flag = False
         self.all_reduce = P.AllReduce()
-        self.parallel_mode = _get_parallel_mode()
+        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if self.parallel_mode not in ParallelMode.MODE_LIST:
             raise ValueError("Parallel mode does not support: ", self.parallel_mode)
         if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
             self.reducer_flag = True
         self.grad_reducer = None
         if self.reducer_flag:
-            mean = _get_gradients_mean()
-            degree = _get_device_num()
+            mean = context.get_auto_parallel_context("gradients_mean")
+            degree = context.get_auto_parallel_context("device_num")
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
         self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
         self.clip_gradients = ClipGradients()
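
In the mass transformer cell, the internal helpers _get_parallel_mode, _get_gradients_mean and _get_device_num from mindspore.parallel._utils are replaced by the public auto-parallel context getters. A minimal sketch of the new lookup pattern, assuming a MindSpore 1.x environment where the auto parallel context has already been configured:

from mindspore import context
from mindspore.context import ParallelMode

# ParallelMode members are string constants, so direct comparison works.
parallel_mode = context.get_auto_parallel_context("parallel_mode")
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
    mean = context.get_auto_parallel_context("gradients_mean")   # whether gradients are averaged across devices
    degree = context.get_auto_parallel_context("device_num")     # number of devices in the group
    print("gradient reduction: mean=%s over %s devices" % (mean, degree))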

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Neural Collaborative Filtering Model"""
-from mindspore import nn
+from mindspore import nn, context
 from mindspore import Tensor, Parameter, ParameterTuple
 from mindspore._checkparam import Validator as validator
 from mindspore.nn.layer.activation import get_activation
@@ -22,7 +22,6 @@ from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
 from mindspore.context import ParallelMode
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
@@ -245,12 +244,12 @@ class TrainStepWrap(nn.Cell):
         self.reducer_flag = False
         self.grad_reducer = None
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
-            mean = _get_gradients_mean()
-            degree = _get_device_num()
+            mean = context.get_auto_parallel_context("gradients_mean")
+            degree = context.get_auto_parallel_context("device_num")
             self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)
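
TrainStepWrap in the NCF model now reads the same values from the auto parallel context, so a distributed launch is expected to configure that context before the cell is built. A hedged sketch of such a setup (the chosen mode and flags are illustrative, not mandated by the patch):

from mindspore import context
from mindspore.communication.management import init, get_group_size
from mindspore.context import ParallelMode

init()  # requires a distributed launch environment (e.g. rank table or OpenMPI)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True,
                                  device_num=get_group_size())

# The getters used inside TrainStepWrap now return the values configured above.
assert context.get_auto_parallel_context("gradients_mean") is True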