!14219 add parallel mode validation for shard strategy

From: @gong_zi_yan
Reviewed-by: @stsuteng, @kisnwang
Signed-off-by: @stsuteng
mindspore-ci-bot 2021-03-29 11:20:00 +08:00 committed by Gitee
commit 197db11fe4
2 changed files with 8 additions and 1 deletion

mindspore/ops/primitive.py

@@ -17,7 +17,7 @@
import inspect
import copy
from mindspore.common.api import _wrap_func
from mindspore import context
from mindspore import context, log as logger
from .._c_expression import Primitive_, real_run_op, prim_type
from .._checkparam import Validator
from . import signature as sig
@@ -142,6 +142,10 @@ class Primitive(Primitive_):
Args:
strategy (tuple): Describes the distributed parallel strategy of the current primitive.
"""
if context.get_auto_parallel_context("parallel_mode") not in [context.ParallelMode.AUTO_PARALLEL,
                                                               context.ParallelMode.SEMI_AUTO_PARALLEL]:
    logger.warning("Shard strategy is not valid in {} mode. Please use semi auto or auto parallel mode."
                   .format(context.get_auto_parallel_context("parallel_mode")))
self.add_prim_attr("strategy", strategy)
return self
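
For reference, a minimal sketch of how the new check behaves when a shard strategy is set (the MatMul primitive and the (2, 2)/(2, 1) strategy below are illustrative, not part of this commit):

from mindspore import context
from mindspore.ops import operations as P

# In the default STAND_ALONE mode the check above only logs a warning;
# the strategy attribute is still attached and shard() still returns the primitive.
matmul = P.MatMul().shard(((2, 2), (2, 1)))

# Under semi auto (or auto) parallel mode the strategy is accepted without a warning.
context.set_auto_parallel_context(parallel_mode=context.ParallelMode.SEMI_AUTO_PARALLEL)
matmul = P.MatMul().shard(((2, 2), (2, 1)))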

mindspore/train/model.py

@@ -809,12 +809,15 @@ class Model:
Dict, Parameter layout dictionary used for loading distributed checkpoint.
Examples:
>>> # This example should be run with multiple devices. Refer to the Distributed Training tutorial on
>>> # mindspore.cn.
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Model, context, Tensor
>>> from mindspore.context import ParallelMode
>>> from mindspore.communication import init
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> init()
>>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
>>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), ms.float32)
>>> model = Model(Net())
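
The layout dictionary returned here is what distributed checkpoint loading consumes. A hedged sketch of that flow, assuming the surrounding docstring belongs to Model.infer_predict_layout (the method name is not visible in this hunk) and using hypothetical checkpoint paths:

import numpy as np
import mindspore as ms
from mindspore import Model, Tensor
from mindspore.train.serialization import load_distributed_checkpoint

net = Net()  # Net is the user-defined network from the example above, not defined here
model = Model(net)
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), ms.float32)

# Infer how parameters are laid out across devices for this input.
layout_dict = model.infer_predict_layout(input_data)

# Hypothetical per-rank checkpoint files produced by distributed training.
ckpt_files = ["./rank_0/net.ckpt", "./rank_1/net.ckpt"]
load_distributed_checkpoint(net, ckpt_files, layout_dict)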