forked from mindspore-Ecosystem/mindspore
remove internal interface in wide&deep
parent 9dd4ab0e3e
commit a9a8e323b2
@@ -17,5 +17,7 @@ This interface is ONLY used in Auto-parallel procedure.
 """
 from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
     set_algo_parameters
+from ._cost_model_context import set_multi_subgraphs, get_multi_subgraphs

-__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
+__all__ = ["set_multi_subgraphs", "get_multi_subgraphs",
+           "get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]

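For reference, a minimal usage sketch of the two helpers this hunk re-exports through the public mindspore.parallel package (assuming a MindSpore build that includes this commit):

from mindspore.parallel import set_multi_subgraphs, get_multi_subgraphs

set_multi_subgraphs()   # mark the ANF graph as containing multiple subgraphs (defaults to True)
get_multi_subgraphs()   # query the flag back from the cost model context
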
@@ -479,7 +479,6 @@ set_cost_model_context_func_map = {
     "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
     "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
-    "multi_subgraphs": cost_model_context().set_multi_subgraphs,
     "run_phase": cost_model_context().set_run_phase,
     "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
     "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,

@@ -501,7 +500,6 @@ get_cost_model_context_func_map = {
     "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
     "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
-    "multi_subgraphs": cost_model_context().get_multi_subgraphs,
     "run_phase": cost_model_context().get_run_phase,
     "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
     "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,

@@ -538,7 +536,6 @@ def set_cost_model_context(**kwargs):
         costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
         costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
         costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
-        multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
         run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
         costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
             0: bypass allreduce fusion;

@@ -591,3 +588,18 @@ def get_cost_model_context(attr_key):
 def reset_cost_model_context():
     """Reset cost model context attributes."""
     cost_model_context().reset_cost_model()
+
+def set_multi_subgraphs(multi_subgraph=True):
+    """
+    Set the flag of ANF graph containing multiple subgraphs.
+
+    Args:
+        multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
+    """
+    cost_model_context().set_multi_subgraphs(multi_subgraph)
+
+def get_multi_subgraphs():
+    """
+    Get the flag of ANF graph containing multiple subgraphs.
+    """
+    cost_model_context().get_multi_subgraphs()

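In practice this moves the multi-subgraph flag out of set_cost_model_context and into a dedicated public function. A hedged before/after sketch of the call-site migration that the training scripts below perform:

# Old, internal form removed by this commit:
#   from mindspore.parallel import _cost_model_context as cost_model_context
#   cost_model_context.set_cost_model_context(multi_subgraphs=True)

# New, public form:
from mindspore.parallel import set_multi_subgraphs
set_multi_subgraphs()
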
@@ -14,7 +14,7 @@
 # ============================================================================
 """wide and deep model"""
 import numpy as np
-from mindspore import nn
+from mindspore import nn, context
 from mindspore import Parameter, ParameterTuple
 import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F

@@ -22,10 +22,7 @@ from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.nn import Dropout
 from mindspore.nn.optim import Adam, FTRL, LazyAdam
-# from mindspore.nn.metrics import Metric
 from mindspore.common.initializer import Uniform, initializer
-# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.communication.management import get_group_size

@@ -142,7 +139,7 @@ class WideDeepModel(nn.Cell):
         self.batch_size = config.batch_size
         host_device_mix = bool(config.host_device_mix)
         parameter_server = bool(config.parameter_server)
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         if is_auto_parallel:
             self.batch_size = self.batch_size * get_group_size()

@@ -259,7 +256,7 @@ class NetWithLossClass(nn.Cell):
         super(NetWithLossClass, self).__init__(auto_prefix=False)
         host_device_mix = bool(config.host_device_mix)
         parameter_server = bool(config.parameter_server)
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.no_l2loss = (is_auto_parallel if host_device_mix else parameter_server)
         self.network = network

@@ -312,7 +309,7 @@ class TrainStepWrap(nn.Cell):

     def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False):
         super(TrainStepWrap, self).__init__()
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.network = network
         self.network.set_train()

@@ -351,12 +348,11 @@ class TrainStepWrap(nn.Cell):
         self.reducer_flag = False
         self.grad_reducer_w = None
         self.grad_reducer_d = None
-        parallel_mode = _get_parallel_mode()
         self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL,
                                               ParallelMode.HYBRID_PARALLEL)
         if self.reducer_flag:
-            mean = _get_mirror_mean()
-            degree = _get_device_num()
+            mean = context.get_auto_parallel_context("mirror_mean")
+            degree = context.get_auto_parallel_context("device_num")
             self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree)
             self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree)

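The model hunks above swap the internal _get_parallel_mode() helper for the public context getter. A minimal sketch of that check, assuming a MindSpore build of the same era (the getter returns a mode value that compares equal to the ParallelMode constants, which is what the diff itself relies on):

from mindspore import context
from mindspore.train.parallel_utils import ParallelMode

parallel_mode = context.get_auto_parallel_context("parallel_mode")
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
print(is_auto_parallel)  # False unless set_auto_parallel_context() selected an auto-parallel mode
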
@@ -22,7 +22,7 @@ from mindspore import Model, context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train import ParallelMode
 from mindspore.communication.management import get_rank, get_group_size, init
-from mindspore.parallel import _cost_model_context as cost_model_context
+from mindspore.parallel import set_multi_subgraphs
 from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple

 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel

@@ -127,7 +127,7 @@ if __name__ == "__main__":
     context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
     context.set_context(variable_memory_max_size="24GB")
     context.set_context(enable_sparse=True)
-    cost_model_context.set_cost_model_context(multi_subgraphs=True)
+    set_multi_subgraphs()
     if wide_deep_config.device_target == "Ascend":
         init("hccl")
     elif wide_deep_config.device_target == "GPU":

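Putting the script-side changes together, the migrated setup sequence looks roughly as follows; the arguments are copied from the hunk above, while the Ascend target and a distributed launch environment are assumptions:

from mindspore import context
from mindspore.communication.management import init
from mindspore.parallel import set_multi_subgraphs

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(variable_memory_max_size="24GB")
context.set_context(enable_sparse=True)
set_multi_subgraphs()   # replaces cost_model_context.set_cost_model_context(multi_subgraphs=True)
init("hccl")            # collective communication init on Ascend
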
@@ -16,7 +16,7 @@
 import numpy as np

 import mindspore.common.dtype as mstype
-from mindspore import nn
+from mindspore import nn, context
 from mindspore import Tensor, Parameter, ParameterTuple
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C

@@ -24,7 +24,6 @@ from mindspore.ops import operations as P
 from mindspore.nn import Dropout, Flatten
 from mindspore.nn.optim import Adam, FTRL
 from mindspore.common.initializer import Uniform, initializer
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

@@ -552,13 +551,13 @@ class TrainStepWrap(nn.Cell):
         self.reducer_flag = False
         self.grad_reducer_w = None
         self.grad_reducer_d = None
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if parallel_mode in (ParallelMode.DATA_PARALLEL,
                              ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
-            mean = _get_mirror_mean()
-            degree = _get_device_num()
+            mean = context.get_auto_parallel_context("mirror_mean")
+            degree = context.get_auto_parallel_context("device_num")
             self.grad_reducer_w = DistributedGradReducer(
                 self.optimizer_w.parameters, mean, degree)
             self.grad_reducer_d = DistributedGradReducer(

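The gradient-reducer setup above now reads its settings from the public auto-parallel context. A hedged sketch of the same pattern; build_grad_reducer and optimizer are illustrative names, not part of the diff:

from mindspore import context
from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

def build_grad_reducer(optimizer):
    """Return a DistributedGradReducer in data/hybrid parallel modes, else None."""
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
        mean = context.get_auto_parallel_context("mirror_mean")
        degree = context.get_auto_parallel_context("device_num")
        return DistributedGradReducer(optimizer.parameters, mean, degree)
    return None
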
@@ -21,7 +21,7 @@ from mindspore import Model, context
 from mindspore.train.callback import TimeMonitor
 from mindspore.train import ParallelMode
 from mindspore.communication.management import get_rank, get_group_size, init
-from mindspore.parallel import _cost_model_context as cost_model_context
+from mindspore.parallel import set_multi_subgraphs
 from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple

 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel

@@ -33,7 +33,7 @@ from src.config import WideDeepConfig
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
 context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
-cost_model_context.set_cost_model_context(multi_subgraphs=True)
+set_multi_subgraphs()
 init()

@@ -23,7 +23,7 @@ from mindspore.nn.optim import Adam, FTRL
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-from mindspore.parallel import _cost_model_context as cost_model_context
+from mindspore.parallel import set_multi_subgraphs
 from mindspore.parallel._utils import _reset_op_id as reset_op_id

@@ -103,7 +103,7 @@ class TrainStepWarp(nn.Cell):


 def test_double_subgraphs():
-    cost_model_context.set_cost_model_context(multi_subgraphs=True)
+    set_multi_subgraphs()
     context.set_context(save_graphs=True)
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     context.set_auto_parallel_context(parallel_mode="auto_parallel")

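For completeness, a condensed sketch of the setup exercised by test_double_subgraphs after the change (same MindSpore-build assumption as above; these calls only configure the context, so no real devices are required):

from mindspore import context
from mindspore.parallel import set_multi_subgraphs

set_multi_subgraphs()                                             # public replacement for the internal cost-model call
context.set_context(save_graphs=True)                             # dump IR graphs for inspection
context.set_auto_parallel_context(device_num=8, global_rank=0)    # logical 8-device layout, rank 0
context.set_auto_parallel_context(parallel_mode="auto_parallel")  # let the cost model pick strategies
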