move 'multi-subgraphs' to internal

Xiaoda Zhang 2020-09-03 14:37:50 +08:00
parent 90fa4c9d94
commit 42f1241270
6 changed files with 61 additions and 12 deletions

View File

@@ -17,7 +17,5 @@ This interface is ONLY used in Auto-parallel procedure.
 """
 from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
     set_algo_parameters
-from ._cost_model_context import set_multi_subgraphs, get_multi_subgraphs
-__all__ = ["set_multi_subgraphs", "get_multi_subgraphs",
-           "get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
+__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]

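After this hunk, the only names left in the package's public surface are the algorithm-parameter helpers. A minimal import sketch, assuming a build that includes this commit:

# Still public after this commit:
from mindspore.parallel import get_algo_parameters, reset_algo_parameters, set_algo_parameters

# Now internal -- the re-export was removed from __init__.py, so this
# import raises ImportError on builds with this commit:
# from mindspore.parallel import set_multi_subgraphs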
View File

@@ -589,7 +589,7 @@ def reset_cost_model_context():
     """Reset cost model context attributes."""
     cost_model_context().reset_cost_model()

-def set_multi_subgraphs(multi_subgraph=True):
+def _set_multi_subgraphs(multi_subgraph=True):
     """
     Set the flag of ANF graph containing multiple subgraphs.
@@ -598,7 +598,7 @@ def set_multi_subgraphs(multi_subgraph=True):
     """
     cost_model_context().set_multi_subgraphs(multi_subgraph)

-def get_multi_subgraphs():
+def _get_multi_subgraphs():
     """
     Get the flag of ANF graph containing multiple subgraphs.
     """

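The rename is behavior-preserving; only visibility changes. A hedged sketch of the internal call path, using exactly the names the hunks above define:

# Internal-only now: note the underscore-prefixed module and functions.
from mindspore.parallel._cost_model_context import _set_multi_subgraphs, _get_multi_subgraphs

_set_multi_subgraphs()         # flag the ANF graph as multi-subgraph (defaults to True)
flag = _get_multi_subgraphs()  # read the flag back from the cost model context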
View File

@@ -32,6 +32,7 @@ from .. import nn
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from ..context import ParallelMode
 from ..parallel._utils import _need_to_full, _to_full_tensor
+from ..parallel._cost_model_context import _set_multi_subgraphs
 from ..common import dtype as mstype
 from .dataset_helper import DatasetHelper
 from . import amp
@@ -166,6 +167,9 @@ class Model:
         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
             network.set_auto_parallel()
+            if self._optimizer is None:
+                # In this case, multiple optimizers are supposed to be included in 'self._network'
+                _set_multi_subgraphs()
         return network

     def _build_eval_network(self, metrics, eval_network, eval_indexes):
@@ -190,6 +194,9 @@ class Model:
         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
             if self._optimizer:
                 self._eval_network = _VirtualDatasetCell(self._eval_network)
+            if self._optimizer is None:
+                # In this case, multiple optimizers are supposed to be included in 'self._network'
+                _set_multi_subgraphs()
             self._eval_network.set_auto_parallel()

     def _build_predict_network(self):
@@ -197,6 +204,7 @@ class Model:
         self._predict_network = self._network
         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
             self._predict_network = _VirtualDatasetCell(self._network)
+            _set_multi_subgraphs()
         self._predict_network.set_auto_parallel()

     def _clear_metrics(self):
def _clear_metrics(self): def _clear_metrics(self):

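The effect of these hunks is that users no longer toggle the flag themselves: when Model is constructed without an optimizer (the case where the network embeds its own optimizers, as in wide_and_deep), building the train, eval, or predict network sets it automatically. A minimal sketch of the user-facing result; 'TinyNet' is a hypothetical stand-in, and the comment describes what the hunks above add rather than a guaranteed call order:

import mindspore.nn as nn
from mindspore import Model, context
from mindspore.context import ParallelMode

class TinyNet(nn.Cell):
    """Hypothetical stand-in for a cell that wraps its own optimizer(s)."""
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(x)

context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL)
# No 'optimizer' argument: per this commit, _build_train_network() calls
# _set_multi_subgraphs() on our behalf, so no user-side call is needed.
model = Model(TinyNet())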
View File

@@ -22,7 +22,6 @@ from mindspore import Model, context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.context import ParallelMode
 from mindspore.communication.management import get_rank, get_group_size, init
-from mindspore.parallel import set_multi_subgraphs
 from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
@@ -145,7 +144,6 @@ if __name__ == "__main__":
                         device_target=wide_deep_config.device_target, save_graphs=True)
     context.set_context(variable_memory_max_size="24GB")
     context.set_context(enable_sparse=True)
-    set_multi_subgraphs()
     init()
     if wide_deep_config.host_device_mix == 1:
         context.set_auto_parallel_context(

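Both wide_and_deep launch scripts shrink the same way (the next file repeats the pattern). A condensed sketch of the setup sequence that remains after this removal, with values taken from the script above:

from mindspore import context
from mindspore.communication.management import init

context.set_context(variable_memory_max_size="24GB")
context.set_context(enable_sparse=True)
# The explicit set_multi_subgraphs() call is gone: Model now sets the
# flag itself when it builds the train network without an optimizer.
init()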
View File

@@ -21,7 +21,6 @@ from mindspore import Model, context
 from mindspore.train.callback import TimeMonitor
 from mindspore.context import ParallelMode
 from mindspore.communication.management import get_rank, get_group_size, init
-from mindspore.parallel import set_multi_subgraphs
 from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
@@ -33,7 +32,6 @@ from src.config import WideDeepConfig
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
 context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
-set_multi_subgraphs()
 init()

View File

@@ -17,13 +17,13 @@ import numpy as np
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import Tensor, Parameter, ParameterTuple
-from mindspore import context
+from mindspore import context, Model
 from mindspore.common.api import _executor
 from mindspore.nn.optim import Adam, FTRL
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-from mindspore.parallel import set_multi_subgraphs
+from mindspore.parallel._cost_model_context import _set_multi_subgraphs
 from mindspore.parallel._utils import _reset_op_id as reset_op_id
@@ -103,7 +103,7 @@ class TrainStepWarp(nn.Cell):
 def test_double_subgraphs():
-    set_multi_subgraphs()
+    _set_multi_subgraphs()
     context.set_context(save_graphs=True)
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
@@ -120,3 +120,50 @@ def test_double_subgraphs():
                            'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
                            'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]}
     assert strategies == expected_strategies
+
+class DatasetLenet():
+    def __init__(self, predict, label, length=3):
+        self.predict = predict
+        self.label = label
+        self.index = 0
+        self.length = length
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.index >= self.length:
+            raise StopIteration
+        self.index += 1
+        return self.predict
+
+    def reset(self):
+        self.index = 0
+
+    def get_dataset_size(self):
+        return 32
+
+    def get_repeat_count(self):
+        return 1
+
+    def create_tuple_iterator(self):
+        return self
+
+def test_double_subgraphs_train():
+    context.set_context(save_graphs=True)
+    context.set_auto_parallel_context(device_num=1, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net = TrainStepWarp(NetWithLoss(Net()))
+    batch_ids = np.ones([8, 8, 8, 8]).astype(np.int32)
+    ds_train = DatasetLenet(Tensor(batch_ids), None)
+    model = Model(net)
+    model.train(1, ds_train, dataset_sink_mode=False)
+    strategies = _executor._get_strategy(net)
+    expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op3': [[1, 1, 1, 1]],
+                           'Default/network-NetWithLoss/net-Net/ReLU-op4': [[1, 1, 1, 1]],
+                           'Default/network-NetWithLoss/net-Net/Mul-op5': [[1, 1, 1, 1], [1, 1, 1, 1]],
+                           'Default/network-NetWithLoss/net-Net/Mul-op6': [[1, 1, 1, 1], [1, 1, 1, 1]],
+                           'Default/network-NetWithLoss/net-Net/Cast-op1': [[1, 1, 1, 1]],
+                           'Default/network-NetWithLoss/ReduceSum-op7': [[1, 1, 1, 1]]}
+    assert strategies == expected_strategies
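Note the contrast the two tests set up: test_double_subgraphs() flips the flag explicitly, while the new test_double_subgraphs_train() never touches it and instead relies on Model.train() setting it inside _build_train_network(). A hedged sketch of that distinction, using only names from the hunks above:

from mindspore.parallel._cost_model_context import _set_multi_subgraphs, _get_multi_subgraphs

# Explicit path (test_double_subgraphs): set the flag by hand.
_set_multi_subgraphs()
print(_get_multi_subgraphs())  # expected: True

# Implicit path (test_double_subgraphs_train): constructing Model without an
# optimizer and calling train() lets the framework set the flag instead.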