forked from mindspore-Ecosystem/mindspore
!14219 add parallel mode validation for shard strategy
From: @gong_zi_yan Reviewed-by: @stsuteng,@kisnwang Signed-off-by: @stsuteng
This commit is contained in:
commit
197db11fe4
|
@ -17,7 +17,7 @@
|
|||
import inspect
|
||||
import copy
|
||||
from mindspore.common.api import _wrap_func
|
||||
from mindspore import context
|
||||
from mindspore import context, log as logger
|
||||
from .._c_expression import Primitive_, real_run_op, prim_type
|
||||
from .._checkparam import Validator
|
||||
from . import signature as sig
|
||||
|
@ -142,6 +142,10 @@ class Primitive(Primitive_):
|
|||
Args:
|
||||
strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
|
||||
"""
|
||||
if context.get_auto_parallel_context("parallel_mode") not in [context.ParallelMode.AUTO_PARALLEL,
|
||||
context.ParallelMode.SEMI_AUTO_PARALLEL]:
|
||||
logger.warning("Shard strategy is not valid in ", context.get_auto_parallel_context("parallel_mode"),
|
||||
" mode. Please use semi auto or auto parallel mode.")
|
||||
self.add_prim_attr("strategy", strategy)
|
||||
return self
|
||||
|
||||
|
|
|
@ -809,12 +809,15 @@ class Model:
|
|||
Dict, Parameter layout dictionary used for load distributed checkpoint
|
||||
|
||||
Examples:
|
||||
>>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
|
||||
>>> # mindspore.cn.
|
||||
>>> import numpy as np
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore import Model, context, Tensor
|
||||
>>> from mindspore.context import ParallelMode
|
||||
>>>
|
||||
>>> context.set_context(mode=context.GRAPH_MODE)
|
||||
>>> init()
|
||||
>>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
|
||||
>>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), ms.float32)
|
||||
>>> model = Model(Net())
|
||||
|
|
Loading…
Reference in New Issue