forked from mindspore-Ecosystem/mindspore
!5696 [Auto parallel] Move 'multi-subgraphs' interface to internal
Merge pull request !5696 from Xiaoda/20-moving-multi-graph-interface-internal
This commit is contained in:
commit
9018737e99
|
@ -17,7 +17,5 @@ This interface is ONLY used in Auto-parallel procedure.
|
|||
"""
|
||||
from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
|
||||
set_algo_parameters
|
||||
from ._cost_model_context import set_multi_subgraphs, get_multi_subgraphs
|
||||
|
||||
__all__ = ["set_multi_subgraphs", "get_multi_subgraphs",
|
||||
"get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
|
||||
__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
|
||||
|
|
|
@ -589,7 +589,7 @@ def reset_cost_model_context():
|
|||
"""Reset cost model context attributes."""
|
||||
cost_model_context().reset_cost_model()
|
||||
|
||||
def set_multi_subgraphs(multi_subgraph=True):
|
||||
def _set_multi_subgraphs(multi_subgraph=True):
|
||||
"""
|
||||
Set the flag of ANF graph containing multiple subgraphs.
|
||||
|
||||
|
@ -598,7 +598,7 @@ def set_multi_subgraphs(multi_subgraph=True):
|
|||
"""
|
||||
cost_model_context().set_multi_subgraphs(multi_subgraph)
|
||||
|
||||
def get_multi_subgraphs():
|
||||
def _get_multi_subgraphs():
|
||||
"""
|
||||
Get the flag of ANF graph containing multiple subgraphs.
|
||||
"""
|
||||
|
|
|
@ -32,6 +32,7 @@ from .. import nn
|
|||
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||
from ..context import ParallelMode
|
||||
from ..parallel._utils import _need_to_full, _to_full_tensor
|
||||
from ..parallel._cost_model_context import _set_multi_subgraphs
|
||||
from ..common import dtype as mstype
|
||||
from .dataset_helper import DatasetHelper
|
||||
from . import amp
|
||||
|
@ -166,6 +167,9 @@ class Model:
|
|||
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
network.set_auto_parallel()
|
||||
if self._optimizer is None:
|
||||
# In this case, multiple optimizer(s) is supposed to be included in 'self._network'
|
||||
_set_multi_subgraphs()
|
||||
return network
|
||||
|
||||
def _build_eval_network(self, metrics, eval_network, eval_indexes):
|
||||
|
@ -190,6 +194,9 @@ class Model:
|
|||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
if self._optimizer:
|
||||
self._eval_network = _VirtualDatasetCell(self._eval_network)
|
||||
if self._optimizer is None:
|
||||
# In this case, multiple optimizer(s) is supposed to be included in 'self._network'
|
||||
_set_multi_subgraphs()
|
||||
self._eval_network.set_auto_parallel()
|
||||
|
||||
def _build_predict_network(self):
|
||||
|
@ -197,6 +204,7 @@ class Model:
|
|||
self._predict_network = self._network
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
self._predict_network = _VirtualDatasetCell(self._network)
|
||||
_set_multi_subgraphs()
|
||||
self._predict_network.set_auto_parallel()
|
||||
|
||||
def _clear_metrics(self):
|
||||
|
|
|
@ -22,7 +22,6 @@ from mindspore import Model, context
|
|||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import get_rank, get_group_size, init
|
||||
from mindspore.parallel import set_multi_subgraphs
|
||||
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
|
||||
|
||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||
|
@ -145,7 +144,6 @@ if __name__ == "__main__":
|
|||
device_target=wide_deep_config.device_target, save_graphs=True)
|
||||
context.set_context(variable_memory_max_size="24GB")
|
||||
context.set_context(enable_sparse=True)
|
||||
set_multi_subgraphs()
|
||||
init()
|
||||
if wide_deep_config.host_device_mix == 1:
|
||||
context.set_auto_parallel_context(
|
||||
|
|
|
@ -21,7 +21,6 @@ from mindspore import Model, context
|
|||
from mindspore.train.callback import TimeMonitor
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import get_rank, get_group_size, init
|
||||
from mindspore.parallel import set_multi_subgraphs
|
||||
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
|
||||
|
||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||
|
@ -33,7 +32,6 @@ from src.config import WideDeepConfig
|
|||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
|
||||
set_multi_subgraphs()
|
||||
init()
|
||||
|
||||
|
||||
|
|
|
@ -17,13 +17,13 @@ import numpy as np
|
|||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter, ParameterTuple
|
||||
from mindspore import context
|
||||
from mindspore import context, Model
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn.optim import Adam, FTRL
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.parallel import set_multi_subgraphs
|
||||
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
|
||||
from mindspore.parallel._utils import _reset_op_id as reset_op_id
|
||||
|
||||
|
||||
|
@ -103,7 +103,7 @@ class TrainStepWarp(nn.Cell):
|
|||
|
||||
|
||||
def test_double_subgraphs():
|
||||
set_multi_subgraphs()
|
||||
_set_multi_subgraphs()
|
||||
context.set_context(save_graphs=True)
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
|
@ -120,3 +120,50 @@ def test_double_subgraphs():
|
|||
'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]}
|
||||
assert strategies == expected_strategies
|
||||
|
||||
class DatasetLenet():
|
||||
def __init__(self, predict, label, length=3):
|
||||
self.predict = predict
|
||||
self.label = label
|
||||
self.index = 0
|
||||
self.length = length
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.index >= self.length:
|
||||
raise StopIteration
|
||||
self.index += 1
|
||||
return self.predict
|
||||
|
||||
def reset(self):
|
||||
self.index = 0
|
||||
|
||||
def get_dataset_size(self):
|
||||
return 32
|
||||
|
||||
def get_repeat_count(self):
|
||||
return 1
|
||||
|
||||
def create_tuple_iterator(self):
|
||||
return self
|
||||
|
||||
def test_double_subgraphs_train():
|
||||
context.set_context(save_graphs=True)
|
||||
context.set_auto_parallel_context(device_num=1, global_rank=0)
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
net = TrainStepWarp(NetWithLoss(Net()))
|
||||
|
||||
batch_ids = np.ones([8, 8, 8, 8]).astype(np.int32)
|
||||
ds_train = DatasetLenet(Tensor(batch_ids), None)
|
||||
model = Model(net)
|
||||
model.train(1, ds_train, dataset_sink_mode=False)
|
||||
strategies = _executor._get_strategy(net)
|
||||
expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op3': [[1, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/ReLU-op4': [[1, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/Mul-op5': [[1, 1, 1, 1], [1, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/Mul-op6': [[1, 1, 1, 1], [1, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/net-Net/Cast-op1': [[1, 1, 1, 1]],
|
||||
'Default/network-NetWithLoss/ReduceSum-op7': [[1, 1, 1, 1]]}
|
||||
assert strategies == expected_strategies
|
||||
|
|
Loading…
Reference in New Issue