forked from mindspore-Ecosystem/mindspore
!5696 [Auto parallel] Move 'multi-subgraphs' interface to internal
Merge pull request !5696 from Xiaoda/20-moving-multi-graph-interface-internal
This commit is contained in:
commit
9018737e99
|
@ -17,7 +17,5 @@ This interface is ONLY used in Auto-parallel procedure.
|
||||||
"""
|
"""
|
||||||
from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
|
from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
|
||||||
set_algo_parameters
|
set_algo_parameters
|
||||||
from ._cost_model_context import set_multi_subgraphs, get_multi_subgraphs
|
|
||||||
|
|
||||||
__all__ = ["set_multi_subgraphs", "get_multi_subgraphs",
|
__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
|
||||||
"get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
|
|
||||||
|
|
|
@ -589,7 +589,7 @@ def reset_cost_model_context():
|
||||||
"""Reset cost model context attributes."""
|
"""Reset cost model context attributes."""
|
||||||
cost_model_context().reset_cost_model()
|
cost_model_context().reset_cost_model()
|
||||||
|
|
||||||
def set_multi_subgraphs(multi_subgraph=True):
|
def _set_multi_subgraphs(multi_subgraph=True):
|
||||||
"""
|
"""
|
||||||
Set the flag of ANF graph containing multiple subgraphs.
|
Set the flag of ANF graph containing multiple subgraphs.
|
||||||
|
|
||||||
|
@ -598,7 +598,7 @@ def set_multi_subgraphs(multi_subgraph=True):
|
||||||
"""
|
"""
|
||||||
cost_model_context().set_multi_subgraphs(multi_subgraph)
|
cost_model_context().set_multi_subgraphs(multi_subgraph)
|
||||||
|
|
||||||
def get_multi_subgraphs():
|
def _get_multi_subgraphs():
|
||||||
"""
|
"""
|
||||||
Get the flag of ANF graph containing multiple subgraphs.
|
Get the flag of ANF graph containing multiple subgraphs.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -32,6 +32,7 @@ from .. import nn
|
||||||
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
|
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||||
from ..context import ParallelMode
|
from ..context import ParallelMode
|
||||||
from ..parallel._utils import _need_to_full, _to_full_tensor
|
from ..parallel._utils import _need_to_full, _to_full_tensor
|
||||||
|
from ..parallel._cost_model_context import _set_multi_subgraphs
|
||||||
from ..common import dtype as mstype
|
from ..common import dtype as mstype
|
||||||
from .dataset_helper import DatasetHelper
|
from .dataset_helper import DatasetHelper
|
||||||
from . import amp
|
from . import amp
|
||||||
|
@ -166,6 +167,9 @@ class Model:
|
||||||
|
|
||||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||||
network.set_auto_parallel()
|
network.set_auto_parallel()
|
||||||
|
if self._optimizer is None:
|
||||||
|
# In this case, multiple optimizer(s) is supposed to be included in 'self._network'
|
||||||
|
_set_multi_subgraphs()
|
||||||
return network
|
return network
|
||||||
|
|
||||||
def _build_eval_network(self, metrics, eval_network, eval_indexes):
|
def _build_eval_network(self, metrics, eval_network, eval_indexes):
|
||||||
|
@ -190,6 +194,9 @@ class Model:
|
||||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||||
if self._optimizer:
|
if self._optimizer:
|
||||||
self._eval_network = _VirtualDatasetCell(self._eval_network)
|
self._eval_network = _VirtualDatasetCell(self._eval_network)
|
||||||
|
if self._optimizer is None:
|
||||||
|
# In this case, multiple optimizer(s) is supposed to be included in 'self._network'
|
||||||
|
_set_multi_subgraphs()
|
||||||
self._eval_network.set_auto_parallel()
|
self._eval_network.set_auto_parallel()
|
||||||
|
|
||||||
def _build_predict_network(self):
|
def _build_predict_network(self):
|
||||||
|
@ -197,6 +204,7 @@ class Model:
|
||||||
self._predict_network = self._network
|
self._predict_network = self._network
|
||||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||||
self._predict_network = _VirtualDatasetCell(self._network)
|
self._predict_network = _VirtualDatasetCell(self._network)
|
||||||
|
_set_multi_subgraphs()
|
||||||
self._predict_network.set_auto_parallel()
|
self._predict_network.set_auto_parallel()
|
||||||
|
|
||||||
def _clear_metrics(self):
|
def _clear_metrics(self):
|
||||||
|
|
|
@ -22,7 +22,6 @@ from mindspore import Model, context
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.communication.management import get_rank, get_group_size, init
|
from mindspore.communication.management import get_rank, get_group_size, init
|
||||||
from mindspore.parallel import set_multi_subgraphs
|
|
||||||
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
|
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
|
||||||
|
|
||||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||||
|
@ -145,7 +144,6 @@ if __name__ == "__main__":
|
||||||
device_target=wide_deep_config.device_target, save_graphs=True)
|
device_target=wide_deep_config.device_target, save_graphs=True)
|
||||||
context.set_context(variable_memory_max_size="24GB")
|
context.set_context(variable_memory_max_size="24GB")
|
||||||
context.set_context(enable_sparse=True)
|
context.set_context(enable_sparse=True)
|
||||||
set_multi_subgraphs()
|
|
||||||
init()
|
init()
|
||||||
if wide_deep_config.host_device_mix == 1:
|
if wide_deep_config.host_device_mix == 1:
|
||||||
context.set_auto_parallel_context(
|
context.set_auto_parallel_context(
|
||||||
|
|
|
@ -21,7 +21,6 @@ from mindspore import Model, context
|
||||||
from mindspore.train.callback import TimeMonitor
|
from mindspore.train.callback import TimeMonitor
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.communication.management import get_rank, get_group_size, init
|
from mindspore.communication.management import get_rank, get_group_size, init
|
||||||
from mindspore.parallel import set_multi_subgraphs
|
|
||||||
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
|
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
|
||||||
|
|
||||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||||
|
@ -33,7 +32,6 @@ from src.config import WideDeepConfig
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
|
||||||
set_multi_subgraphs()
|
|
||||||
init()
|
init()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,13 +17,13 @@ import numpy as np
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import Tensor, Parameter, ParameterTuple
|
from mindspore import Tensor, Parameter, ParameterTuple
|
||||||
from mindspore import context
|
from mindspore import context, Model
|
||||||
from mindspore.common.api import _executor
|
from mindspore.common.api import _executor
|
||||||
from mindspore.nn.optim import Adam, FTRL
|
from mindspore.nn.optim import Adam, FTRL
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.parallel import set_multi_subgraphs
|
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
|
||||||
from mindspore.parallel._utils import _reset_op_id as reset_op_id
|
from mindspore.parallel._utils import _reset_op_id as reset_op_id
|
||||||
|
|
||||||
|
|
||||||
|
@ -103,7 +103,7 @@ class TrainStepWarp(nn.Cell):
|
||||||
|
|
||||||
|
|
||||||
def test_double_subgraphs():
|
def test_double_subgraphs():
|
||||||
set_multi_subgraphs()
|
_set_multi_subgraphs()
|
||||||
context.set_context(save_graphs=True)
|
context.set_context(save_graphs=True)
|
||||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||||
|
@ -120,3 +120,50 @@ def test_double_subgraphs():
|
||||||
'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
|
'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
|
||||||
'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]}
|
'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]}
|
||||||
assert strategies == expected_strategies
|
assert strategies == expected_strategies
|
||||||
|
|
||||||
|
class DatasetLenet():
|
||||||
|
def __init__(self, predict, label, length=3):
|
||||||
|
self.predict = predict
|
||||||
|
self.label = label
|
||||||
|
self.index = 0
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.index >= self.length:
|
||||||
|
raise StopIteration
|
||||||
|
self.index += 1
|
||||||
|
return self.predict
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.index = 0
|
||||||
|
|
||||||
|
def get_dataset_size(self):
|
||||||
|
return 32
|
||||||
|
|
||||||
|
def get_repeat_count(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def create_tuple_iterator(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def test_double_subgraphs_train():
|
||||||
|
context.set_context(save_graphs=True)
|
||||||
|
context.set_auto_parallel_context(device_num=1, global_rank=0)
|
||||||
|
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||||
|
net = TrainStepWarp(NetWithLoss(Net()))
|
||||||
|
|
||||||
|
batch_ids = np.ones([8, 8, 8, 8]).astype(np.int32)
|
||||||
|
ds_train = DatasetLenet(Tensor(batch_ids), None)
|
||||||
|
model = Model(net)
|
||||||
|
model.train(1, ds_train, dataset_sink_mode=False)
|
||||||
|
strategies = _executor._get_strategy(net)
|
||||||
|
expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op3': [[1, 1, 1, 1]],
|
||||||
|
'Default/network-NetWithLoss/net-Net/ReLU-op4': [[1, 1, 1, 1]],
|
||||||
|
'Default/network-NetWithLoss/net-Net/Mul-op5': [[1, 1, 1, 1], [1, 1, 1, 1]],
|
||||||
|
'Default/network-NetWithLoss/net-Net/Mul-op6': [[1, 1, 1, 1], [1, 1, 1, 1]],
|
||||||
|
'Default/network-NetWithLoss/net-Net/Cast-op1': [[1, 1, 1, 1]],
|
||||||
|
'Default/network-NetWithLoss/ReduceSum-op7': [[1, 1, 1, 1]]}
|
||||||
|
assert strategies == expected_strategies
|
||||||
|
|
Loading…
Reference in New Issue