# Review notes for this patch. The patch text below these notes is unchanged.
#
# NOTE(review): in this copy the patch's line breaks and indentation have been
# collapsed, so hunk boundaries and Python nesting cannot be read off reliably.
# The notes are limited to what the tokens themselves show.
#
# What the patch does, per file:
#   * mindspore/parallel/__init__.py: drops the import of set_multi_subgraphs and
#     get_multi_subgraphs and removes both names from __all__.
#   * mindspore/parallel/_cost_model_context.py: renames set_multi_subgraphs to
#     _set_multi_subgraphs and get_multi_subgraphs to _get_multi_subgraphs.
#   * mindspore/train/model.py: imports _set_multi_subgraphs and adds three calls
#     inside class Model. Two are preceded by an added `if self._optimizer is None:`
#     line; the third, in the hunk headed by `def _build_predict_network`, is added
#     without such a line.
#   * model_zoo/.../train_and_eval_auto_parallel.py and
#     tests/st/.../train_and_test_multinpu_ci.py: remove the import of
#     set_multi_subgraphs and the explicit set_multi_subgraphs() call.
#   * tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py: imports Model,
#     switches to the private _set_multi_subgraphs import, and adds the iterator stub
#     class DatasetLenet plus test_double_subgraphs_train. The new test builds
#     Model(net), calls model.train(1, ds_train, dataset_sink_mode=False) and compares
#     _executor._get_strategy(net) with an expected dict. It does not call
#     _set_multi_subgraphs itself.
#
# Open questions for the author:
#   * NOTE(review): set_multi_subgraphs / get_multi_subgraphs disappear from
#     mindspore.parallel with no alias left behind; any caller outside this patch that
#     imports them would break. Confirm that this is intended.
#   * NOTE(review): the predict-network call lacks the `self._optimizer is None` check
#     used for the other two calls. Confirm that this is deliberate.
#   * NOTE(review): the added comment "multiple optimizer(s) is supposed to be
#     included" is ungrammatical; "one or more optimizers are expected to be included"
#     reads better.
#   * NOTE(review): test_double_subgraphs calls _set_multi_subgraphs() and nothing
#     visible here resets it. If that flag persists across tests (presumably it does,
#     verify), test_double_subgraphs_train could pass without Model setting it.
#   * NOTE(review): DatasetLenet stores `label` but no visible method reads it, and
#     get_dataset_size() returns 32 while `length` defaults to 3. Confirm that the
#     stub's shape is intentional.
diff --git a/mindspore/parallel/__init__.py b/mindspore/parallel/__init__.py index 170418fc929..79d8e67a8dc 100644 --- a/mindspore/parallel/__init__.py +++ b/mindspore/parallel/__init__.py @@ -17,7 +17,5 @@ This interface is ONLY used in Auto-parallel procedure. """ from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \ set_algo_parameters -from ._cost_model_context import set_multi_subgraphs, get_multi_subgraphs -__all__ = ["set_multi_subgraphs", "get_multi_subgraphs", - "get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"] +__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"] diff --git a/mindspore/parallel/_cost_model_context.py b/mindspore/parallel/_cost_model_context.py index 3b278caaffb..549efb20054 100644 --- a/mindspore/parallel/_cost_model_context.py +++ b/mindspore/parallel/_cost_model_context.py @@ -589,7 +589,7 @@ def reset_cost_model_context(): """Reset cost model context attributes.""" cost_model_context().reset_cost_model() -def set_multi_subgraphs(multi_subgraph=True): +def _set_multi_subgraphs(multi_subgraph=True): """ Set the flag of ANF graph containing multiple subgraphs. @@ -598,7 +598,7 @@ def set_multi_subgraphs(multi_subgraph=True): """ cost_model_context().set_multi_subgraphs(multi_subgraph) -def get_multi_subgraphs(): +def _get_multi_subgraphs(): """ Get the flag of ANF graph containing multiple subgraphs. """ diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 545b0330dcf..0a73f5b30f8 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -32,6 +32,7 @@ from .. import nn from ..nn.wrap.cell_wrapper import _VirtualDatasetCell from ..context import ParallelMode from ..parallel._utils import _need_to_full, _to_full_tensor +from ..parallel._cost_model_context import _set_multi_subgraphs from ..common import dtype as mstype from .dataset_helper import DatasetHelper from . 
import amp @@ -166,6 +167,9 @@ class Model: if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): network.set_auto_parallel() + if self._optimizer is None: + # In this case, multiple optimizer(s) is supposed to be included in 'self._network' + _set_multi_subgraphs() return network def _build_eval_network(self, metrics, eval_network, eval_indexes): @@ -190,6 +194,9 @@ class Model: if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): if self._optimizer: self._eval_network = _VirtualDatasetCell(self._eval_network) + if self._optimizer is None: + # In this case, multiple optimizer(s) is supposed to be included in 'self._network' + _set_multi_subgraphs() self._eval_network.set_auto_parallel() def _build_predict_network(self): @@ -197,6 +204,7 @@ class Model: self._predict_network = self._network if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): self._predict_network = _VirtualDatasetCell(self._network) + _set_multi_subgraphs() self._predict_network.set_auto_parallel() def _clear_metrics(self): diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py index 2661043d5ce..f20015c8079 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py @@ -22,7 +22,6 @@ from mindspore import Model, context from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.context import ParallelMode from mindspore.communication.management import get_rank, get_group_size, init -from mindspore.parallel import set_multi_subgraphs from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel @@ -145,7 +144,6 @@ if __name__ == 
"__main__": device_target=wide_deep_config.device_target, save_graphs=True) context.set_context(variable_memory_max_size="24GB") context.set_context(enable_sparse=True) - set_multi_subgraphs() init() if wide_deep_config.host_device_mix == 1: context.set_auto_parallel_context( diff --git a/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py b/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py index 22cb8bb40e2..857b9579c3f 100644 --- a/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py +++ b/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py @@ -21,7 +21,6 @@ from mindspore import Model, context from mindspore.train.callback import TimeMonitor from mindspore.context import ParallelMode from mindspore.communication.management import get_rank, get_group_size, init -from mindspore.parallel import set_multi_subgraphs from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel @@ -33,7 +32,6 @@ from src.config import WideDeepConfig sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True) -set_multi_subgraphs() init() diff --git a/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py b/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py index 70443858aa5..93e38263201 100644 --- a/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py +++ b/tests/ut/python/parallel/test_auto_parallel_double_subgraphs.py @@ -17,13 +17,13 @@ import numpy as np import mindspore as ms import mindspore.nn as nn from mindspore import Tensor, Parameter, ParameterTuple -from mindspore import context +from mindspore 
import context, Model from mindspore.common.api import _executor from mindspore.nn.optim import Adam, FTRL from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.ops import operations as P -from mindspore.parallel import set_multi_subgraphs +from mindspore.parallel._cost_model_context import _set_multi_subgraphs from mindspore.parallel._utils import _reset_op_id as reset_op_id @@ -103,7 +103,7 @@ class TrainStepWarp(nn.Cell): def test_double_subgraphs(): - set_multi_subgraphs() + _set_multi_subgraphs() context.set_context(save_graphs=True) context.set_auto_parallel_context(device_num=8, global_rank=0) context.set_auto_parallel_context(parallel_mode="auto_parallel") @@ -120,3 +120,50 @@ def test_double_subgraphs(): 'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]], 'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]} assert strategies == expected_strategies + +class DatasetLenet(): + def __init__(self, predict, label, length=3): + self.predict = predict + self.label = label + self.index = 0 + self.length = length + + def __iter__(self): + return self + + def __next__(self): + if self.index >= self.length: + raise StopIteration + self.index += 1 + return self.predict + + def reset(self): + self.index = 0 + + def get_dataset_size(self): + return 32 + + def get_repeat_count(self): + return 1 + + def create_tuple_iterator(self): + return self + +def test_double_subgraphs_train(): + context.set_context(save_graphs=True) + context.set_auto_parallel_context(device_num=1, global_rank=0) + context.set_auto_parallel_context(parallel_mode="auto_parallel") + net = TrainStepWarp(NetWithLoss(Net())) + + batch_ids = np.ones([8, 8, 8, 8]).astype(np.int32) + ds_train = DatasetLenet(Tensor(batch_ids), None) + model = Model(net) + model.train(1, ds_train, dataset_sink_mode=False) + strategies = _executor._get_strategy(net) + expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op3': [[1, 1, 1, 1]], 
+ 'Default/network-NetWithLoss/net-Net/ReLU-op4': [[1, 1, 1, 1]], + 'Default/network-NetWithLoss/net-Net/Mul-op5': [[1, 1, 1, 1], [1, 1, 1, 1]], + 'Default/network-NetWithLoss/net-Net/Mul-op6': [[1, 1, 1, 1], [1, 1, 1, 1]], + 'Default/network-NetWithLoss/net-Net/Cast-op1': [[1, 1, 1, 1]], + 'Default/network-NetWithLoss/ReduceSum-op7': [[1, 1, 1, 1]]} + assert strategies == expected_strategies