remove internal interface in wide&deep

yao_yf 2020-08-24 10:41:31 +08:00
parent 9dd4ab0e3e
commit a9a8e323b2
7 changed files with 34 additions and 25 deletions
mindspore/parallel
model_zoo/official/recommend
    wide_and_deep
    wide_and_deep_multitable/src
tests
    st/model_zoo_tests/wide_and_deep/python_file_for_ci
    ut/python/parallel

View File

@@ -17,5 +17,7 @@ This interface is ONLY used in Auto-parallel procedure.
 """
 from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
     set_algo_parameters
+from ._cost_model_context import set_multi_subgraphs, get_multi_subgraphs
-__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
+__all__ = ["set_multi_subgraphs", "get_multi_subgraphs",
+           "get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]

View File

@@ -479,7 +479,6 @@ set_cost_model_context_func_map = {
     "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
     "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
-    "multi_subgraphs": cost_model_context().set_multi_subgraphs,
     "run_phase": cost_model_context().set_run_phase,
     "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
     "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
@@ -501,7 +500,6 @@ get_cost_model_context_func_map = {
     "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
     "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
-    "multi_subgraphs": cost_model_context().get_multi_subgraphs,
     "run_phase": cost_model_context().get_run_phase,
     "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
     "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
@@ -538,7 +536,6 @@ def set_cost_model_context(**kwargs):
         costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
         costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
         costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
-        multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
         run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
         costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
             0: bypass allreduce fusion;
@@ -591,3 +588,18 @@ def get_cost_model_context(attr_key):
 def reset_cost_model_context():
     """Reset cost model context attributes."""
     cost_model_context().reset_cost_model()
+
+def set_multi_subgraphs(multi_subgraph=True):
+    """
+    Set the flag of ANF graph containing multiple subgraphs.
+
+    Args:
+        multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
+    """
+    cost_model_context().set_multi_subgraphs(multi_subgraph)
+
+def get_multi_subgraphs():
+    """
+    Get the flag of ANF graph containing multiple subgraphs.
+    """
+    return cost_model_context().get_multi_subgraphs()
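For callers, the multi_subgraphs flag moves out of the set_cost_model_context(**kwargs) map and into these dedicated wrappers. A rough before/after sketch of the call sites updated in the rest of this commit:

    # Before (internal interface, removed here):
    # from mindspore.parallel import _cost_model_context as cost_model_context
    # cost_model_context.set_cost_model_context(multi_subgraphs=True)

    # After (wrapper re-exported from mindspore.parallel):
    from mindspore.parallel import set_multi_subgraphs
    set_multi_subgraphs(True)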

View File

@@ -14,7 +14,7 @@
 # ============================================================================
 """wide and deep model"""
 import numpy as np
-from mindspore import nn
+from mindspore import nn, context
 from mindspore import Parameter, ParameterTuple
 import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
@@ -22,10 +22,7 @@ from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.nn import Dropout
 from mindspore.nn.optim import Adam, FTRL, LazyAdam
-# from mindspore.nn.metrics import Metric
 from mindspore.common.initializer import Uniform, initializer
-# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.communication.management import get_group_size
@@ -142,7 +139,7 @@ class WideDeepModel(nn.Cell):
         self.batch_size = config.batch_size
         host_device_mix = bool(config.host_device_mix)
         parameter_server = bool(config.parameter_server)
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         if is_auto_parallel:
             self.batch_size = self.batch_size * get_group_size()
@@ -259,7 +256,7 @@ class NetWithLossClass(nn.Cell):
         super(NetWithLossClass, self).__init__(auto_prefix=False)
         host_device_mix = bool(config.host_device_mix)
         parameter_server = bool(config.parameter_server)
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.no_l2loss = (is_auto_parallel if host_device_mix else parameter_server)
         self.network = network
@@ -312,7 +309,7 @@ class TrainStepWrap(nn.Cell):
     def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False):
         super(TrainStepWrap, self).__init__()
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.network = network
         self.network.set_train()
@@ -351,12 +348,11 @@ class TrainStepWrap(nn.Cell):
         self.reducer_flag = False
         self.grad_reducer_w = None
         self.grad_reducer_d = None
-        parallel_mode = _get_parallel_mode()
         self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL,
                                               ParallelMode.HYBRID_PARALLEL)
         if self.reducer_flag:
-            mean = _get_mirror_mean()
-            degree = _get_device_num()
+            mean = context.get_auto_parallel_context("mirror_mean")
+            degree = context.get_auto_parallel_context("device_num")
             self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree)
             self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree)
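The model code now reads its parallel settings from the public auto-parallel context instead of the internal mindspore.parallel._utils helpers. A standalone sketch of the pattern (the batch-size literal is an illustrative placeholder for config.batch_size, and get_group_size() assumes a distributed job that has already called init()):

    from mindspore import context
    from mindspore.train.parallel_utils import ParallelMode
    from mindspore.communication.management import get_group_size

    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)

    batch_size = 16000  # illustrative placeholder for config.batch_size
    if is_auto_parallel:
        # Under (semi-)auto parallel, the per-step batch covers the whole device group.
        batch_size = batch_size * get_group_size()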

View File

@@ -22,7 +22,7 @@ from mindspore import Model, context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train import ParallelMode
 from mindspore.communication.management import get_rank, get_group_size, init
-from mindspore.parallel import _cost_model_context as cost_model_context
+from mindspore.parallel import set_multi_subgraphs
 from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
@@ -127,7 +127,7 @@ if __name__ == "__main__":
     context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
     context.set_context(variable_memory_max_size="24GB")
     context.set_context(enable_sparse=True)
-    cost_model_context.set_cost_model_context(multi_subgraphs=True)
+    set_multi_subgraphs()
     if wide_deep_config.device_target == "Ascend":
         init("hccl")
     elif wide_deep_config.device_target == "GPU":
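Put together, the training-script setup after this change looks roughly like the following (the device target is hard-coded here purely for illustration; the script actually takes it from wide_deep_config):

    from mindspore import context
    from mindspore.parallel import set_multi_subgraphs

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    context.set_context(enable_sparse=True)
    # Replaces cost_model_context.set_cost_model_context(multi_subgraphs=True).
    set_multi_subgraphs()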

View File

@@ -16,7 +16,7 @@
 import numpy as np
 import mindspore.common.dtype as mstype
-from mindspore import nn
+from mindspore import nn, context
 from mindspore import Tensor, Parameter, ParameterTuple
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
@@ -24,7 +24,6 @@ from mindspore.ops import operations as P
 from mindspore.nn import Dropout, Flatten
 from mindspore.nn.optim import Adam, FTRL
 from mindspore.common.initializer import Uniform, initializer
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
@@ -552,13 +551,13 @@ class TrainStepWrap(nn.Cell):
         self.reducer_flag = False
         self.grad_reducer_w = None
         self.grad_reducer_d = None
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if parallel_mode in (ParallelMode.DATA_PARALLEL,
                              ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
-            mean = _get_mirror_mean()
-            degree = _get_device_num()
+            mean = context.get_auto_parallel_context("mirror_mean")
+            degree = context.get_auto_parallel_context("device_num")
             self.grad_reducer_w = DistributedGradReducer(
                 self.optimizer_w.parameters, mean, degree)
             self.grad_reducer_d = DistributedGradReducer(
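The same substitution drives the gradient-reducer setup. A hedged sketch of that logic, with a hypothetical helper name and any nn.Optimizer standing in for the model's FTRL/Adam optimizers:

    from mindspore import context
    from mindspore.train.parallel_utils import ParallelMode
    from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

    def build_grad_reducer(optimizer):
        """Return a DistributedGradReducer for data/hybrid parallel modes, else None."""
        parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            mean = context.get_auto_parallel_context("mirror_mean")
            degree = context.get_auto_parallel_context("device_num")
            return DistributedGradReducer(optimizer.parameters, mean, degree)
        return None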

View File

@@ -21,7 +21,7 @@ from mindspore import Model, context
 from mindspore.train.callback import TimeMonitor
 from mindspore.train import ParallelMode
 from mindspore.communication.management import get_rank, get_group_size, init
-from mindspore.parallel import _cost_model_context as cost_model_context
+from mindspore.parallel import set_multi_subgraphs
 from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
@@ -33,7 +33,7 @@ from src.config import WideDeepConfig
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
 context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
-cost_model_context.set_cost_model_context(multi_subgraphs=True)
+set_multi_subgraphs()
 init()

View File

@@ -23,7 +23,7 @@ from mindspore.nn.optim import Adam, FTRL
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-from mindspore.parallel import _cost_model_context as cost_model_context
+from mindspore.parallel import set_multi_subgraphs
 from mindspore.parallel._utils import _reset_op_id as reset_op_id
@@ -103,7 +103,7 @@ class TrainStepWarp(nn.Cell):
 def test_double_subgraphs():
-    cost_model_context.set_cost_model_context(multi_subgraphs=True)
+    set_multi_subgraphs()
     context.set_context(save_graphs=True)
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
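A compact sketch of the updated test setup, wrapped in a hypothetical helper (the network construction and compilation steps of test_double_subgraphs are omitted):

    from mindspore import context
    from mindspore.parallel import set_multi_subgraphs

    def setup_double_subgraphs_env():
        # Was: cost_model_context.set_cost_model_context(multi_subgraphs=True)
        set_multi_subgraphs()
        context.set_context(save_graphs=True)
        context.set_auto_parallel_context(device_num=8, global_rank=0)
        context.set_auto_parallel_context(parallel_mode="auto_parallel")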