From e6f9806cfb6b82eda0b43dddb7499ba575171cc3 Mon Sep 17 00:00:00 2001
From: jinyaohui
Date: Wed, 18 Nov 2020 09:32:27 +0800
Subject: [PATCH] add broadcast

---
 mindspore/common/api.py                      | 39 ++++++++++++-
 mindspore/context.py                         |  4 +-
 mindspore/nn/wrap/cell_wrapper.py            | 58 +++++++++++++++++++-
 mindspore/parallel/_auto_parallel_context.py |  3 -
 4 files changed, 95 insertions(+), 9 deletions(-)

diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 16c4a10a4db..0d7578fd5e0 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -18,16 +18,21 @@
 import types
 from collections import OrderedDict
 from functools import wraps
+
 from mindspore import context
 from mindspore import log as logger
+from .tensor import Tensor as MsTensor
 from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
 from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
-from .tensor import Tensor as MsTensor
-from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor
 from ..parallel._ps_context import _is_role_pserver
+from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
+    _get_parameter_broadcast
+
 # store ms_function class compiled pipeline cache
 ms_compile_cache = {}
 
+BROADCAST_PHASE = "_broadcast_"
+
 
 def _convert_function_arguments(fn, *args):
     """
@@ -362,6 +367,27 @@ class _Executor:
     def _build_data_graph(self, obj, phase):
         self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
 
+    def _get_auto_split_param_names(self, parameter_layout_dict):
+        auto_split_params = {}
+        for key, value in parameter_layout_dict.items():
+            for dim in value[1]:
+                if dim != -1:
+                    auto_split_params[key] = value
+                    break
+        auto_split_param_names = (param_name for param_name in auto_split_params)
+        return auto_split_param_names
+
+    def _build_broadcast_graph(self, obj, broadcast_params, broadcast_phase):
+        """Build broadcast graph."""
+        from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
+
+        _broadcast_net = _BroadCastCell(broadcast_params)
+        _broadcast_net.phase = broadcast_phase
+        broadcasted_params = _broadcast_net()
+        parameters_broadcast_dict = obj.parameters_broadcast_dict()
+        for param_name, param in zip(parameters_broadcast_dict, broadcasted_params):
+            parameters_broadcast_dict[param_name].set_data(param)
+
     def _set_dataset_mode(self, args_list):
         """set dataset mode."""
         # decide whether to sink based on whether the inputs is virtual or args_list is ()
@@ -444,6 +470,15 @@ class _Executor:
                 _exec_init_graph(obj, init_phase)
         elif not enable_ge and "export" in phase:
             self._build_data_graph(obj, phase)
+        elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
+            auto_split_param_names = []
+            if auto_parallel_mode:
+                auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict)
+            broadcast_params = [param for param_name, param in obj.parameters_broadcast_dict().items() if
+                                param_name not in auto_split_param_names]
+            broadcast_phase = "broadcast_subgraph" + "." + str(obj.create_time)
+            self._build_broadcast_graph(obj, broadcast_params, broadcast_phase)
+            self.compile_cache[phase] = broadcast_phase
 
         return phase, True
 
diff --git a/mindspore/context.py b/mindspore/context.py
index 44daf6d4646..e9fac987cbb 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -377,9 +377,7 @@ def set_auto_parallel_context(**kwargs):
                      - recursive_programming: Recursive programming search mode.
 
                      - dynamic_programming: Dynamic programming search mode.
-        parameter_broadcast (bool): A developing feature. Whether to broadcast parameters before training.
-                       "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
-                       broadcast. Default: False.
+        parameter_broadcast (bool): Whether to broadcast parameters before training. Default: False.
         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
         full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py
index bdf60aa9534..3cd33932309 100644
--- a/mindspore/nn/wrap/cell_wrapper.py
+++ b/mindspore/nn/wrap/cell_wrapper.py
@@ -25,6 +25,40 @@ from ...ops.operations.comm_ops import _VirtualDataset
 from ..cell import Cell
 from .grad_reducer import DistributedGradReducer
 
+_get_datatype = C.MultitypeFuncGraph("_get_datatype")
+
+
+@_get_datatype.register("Tensor")
+def _tensors_get_datatype(param):
+    """
+    Acquire parameter datatype.
+
+    Args:
+        param (Tensor): The parameter before operation.
+
+    Returns:
+        mstype, the datatype of parameter.
+    """
+    return F.dtype(param)
+
+
+_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
+
+
+@_cast_datatype.register("TypeType", "Tensor")
+def _tensors_cast_datatype(datatype, param):
+    """
+    Cast gradient to datatype.
+
+    Args:
+        datatype (mstype): the destination datatype of parameter.
+        param (Tensor): The parameter before operation.
+
+    Returns:
+        Tensor, the parameter after operation.
+    """
+    return F.cast(param, datatype)
+
 
 class WithLossCell(Cell):
     r"""
@@ -175,6 +209,7 @@ class TrainOneStepCell(Cell):
         >>> loss_net = MyWithLossCell(net, loss_fn)
         >>> train_net = nn.TrainOneStepCell(loss_net, optim)
     """
+
     def __init__(self, network, optimizer, sens=1.0):
         super(TrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
@@ -314,7 +349,6 @@ class WithEvalCell(Cell):
         self._loss_fn = loss_fn
         self.add_cast_fp32 = add_cast_fp32
 
-
     def construct(self, data, label):
         outputs = self._network(data)
         if self.add_cast_fp32:
@@ -354,3 +388,25 @@ class ParameterUpdate(Cell):
     def construct(self, x):
         F.assign(self._param, x)
         return x
+
+
+class _BroadCastCell(Cell):
+    """
+    Broadcast the parameters from device 0 to other devices.
+
+    Args:
+        params (list): The parameters of Net.
+    """
+
+    def __init__(self, params):
+        super(_BroadCastCell, self).__init__()
+        self.map_ = C.Map()
+        self.params = tuple(params)
+        self.broadcast = P.Broadcast(0)
+
+    def construct(self):
+        datatypes = self.map_(F.partial(_get_datatype), self.params)
+        params = self.map_(F.partial(_cast_datatype, mstype.float32), self.params)
+        params = self.broadcast(params)
+        new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
+        return new_params
diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py
index 2185cf50a28..24cb8702af4 100644
--- a/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/parallel/_auto_parallel_context.py
@@ -207,9 +207,6 @@ class _AutoParallelContext:
             parameter_broadcast (bool): Parameter broadcast or not.
         """
         self.check_context_handle()
-        if parameter_broadcast is True and context.get_context("enable_ge") is False:
-            raise RuntimeError("Parameter broadcast is a developing feature. For now we suggest to"
-                               " use mindspore.common.set_seed() to share parameters among devices.")
        self._context_handle.set_parameter_broadcast(parameter_broadcast)
 
     def get_parameter_broadcast(self):