!7776 add broadcast

From: @jinyaohui
mindspore-ci-bot 2020-11-18 18:57:37 +08:00 committed by Gitee
commit 071366a5f8
4 changed files with 95 additions and 9 deletions

View File

@@ -18,16 +18,21 @@
import types
from collections import OrderedDict
from functools import wraps
from mindspore import context
from mindspore import log as logger
from .tensor import Tensor as MsTensor
from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
from .tensor import Tensor as MsTensor
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor
from ..parallel._ps_context import _is_role_pserver
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
_get_parameter_broadcast
# store ms_function class compiled pipeline cache
ms_compile_cache = {}
BROADCAST_PHASE = "_broadcast_"
def _convert_function_arguments(fn, *args):
"""
@@ -362,6 +367,27 @@ class _Executor:
def _build_data_graph(self, obj, phase):
self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
def _get_auto_split_param_names(self, parameter_layout_dict):
"""Collect the names of parameters that the auto-parallel planner splits across devices."""
auto_split_params = {}
for key, value in parameter_layout_dict.items():
for dim in value[1]:
if dim != -1:
auto_split_params[key] = value
break
auto_split_param_names = [param_name for param_name in auto_split_params]
return auto_split_param_names
def _build_broadcast_graph(self, obj, broadcast_params, broadcast_phase):
"""Build broadcast graph."""
from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
_broadcast_net = _BroadCastCell(broadcast_params)
_broadcast_net.phase = broadcast_phase
broadcasted_params = _broadcast_net()
parameters_broadcast_dict = obj.parameters_broadcast_dict()
for param_name, param in zip(parameters_broadcast_dict, broadcasted_params):
parameters_broadcast_dict[param_name].set_data(param)
def _set_dataset_mode(self, args_list):
"""set dataset mode."""
# decide whether to sink based on whether the inputs is virtual or args_list is ()
@@ -444,6 +470,15 @@ class _Executor:
_exec_init_graph(obj, init_phase)
elif not enable_ge and "export" in phase:
self._build_data_graph(obj, phase)
elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
auto_split_param_names = []
if auto_parallel_mode:
auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict)
broadcast_params = [param for param_name, param in obj.parameters_broadcast_dict().items() if
param_name not in auto_split_param_names]
broadcast_phase = BROADCAST_PHASE + "subgraph" + "." + str(obj.create_time)
self._build_broadcast_graph(obj, broadcast_params, broadcast_phase)
self.compile_cache[phase] = broadcast_phase
return phase, True

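The executor changes above add the compile-time hook: when parameter broadcast is enabled and the current phase is not itself a broadcast phase, a _BroadCastCell subgraph is built over every parameter that the auto-parallel planner has not already split. The filter reads the tensor map stored as value[1] in parameter_layout_dict: any entry other than -1 means that dimension is sharded, so the parameter is excluded from the broadcast. A toy, self-contained reproduction of that filter (the layout tuples and parameter names below are invented for illustration):

# value[1] is the tensor map; -1 means the corresponding dimension is not split.
def get_auto_split_param_names(parameter_layout_dict):
    split_names = []
    for name, layout in parameter_layout_dict.items():
        if any(dim != -1 for dim in layout[1]):
            split_names.append(name)
    return split_names

# Hypothetical layouts: 'fc.weight' is split along its first axis, 'fc.bias' is replicated.
layouts = {
    "fc.weight": ([8], [0, -1]),
    "fc.bias": ([8], [-1]),
}
print(get_auto_split_param_names(layouts))  # ['fc.weight']
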
View File

@@ -377,9 +377,7 @@ def set_auto_parallel_context(**kwargs):
- recursive_programming: Recursive programming search mode.
- dynamic_programming: Dynamic programming search mode.
parameter_broadcast (bool): A developing feature. Whether to broadcast parameters before training.
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
broadcast. Default: False.
parameter_broadcast (bool): Whether to broadcast parameters before training. Default: False.
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter

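For reference, the simplified parameter_broadcast documentation corresponds to usage along these lines (a sketch only; it assumes a multi-device job launched with the usual distributed tooling, and device/rank setup is omitted):

from mindspore import context
from mindspore.communication.management import init

context.set_context(mode=context.GRAPH_MODE)
init()  # initialize the HCCL/NCCL communication backend
# Broadcast device 0's initial parameter values to the other devices before training,
# rather than relying on a shared random seed for identical initialization.
context.set_auto_parallel_context(parallel_mode=context.ParallelMode.DATA_PARALLEL,
                                  parameter_broadcast=True)
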
View File

@@ -25,6 +25,40 @@ from ...ops.operations.comm_ops import _VirtualDataset
from ..cell import Cell
from .grad_reducer import DistributedGradReducer
_get_datatype = C.MultitypeFuncGraph("_get_datatype")
@_get_datatype.register("Tensor")
def _tensors_get_datatype(param):
"""
Acquire parameter datatype.
Args:
param (Tensor): The parameter before operation.
Returns:
mstype, the datatype of the parameter.
"""
return F.dtype(param)
_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, param):
"""
Cast the parameter to the given datatype.
Args:
datatype (mstype): The destination datatype of the parameter.
param (Tensor): The parameter before operation.
Returns:
Tensor, the parameter after operation.
"""
return F.cast(param, datatype)
class WithLossCell(Cell):
r"""
@@ -175,6 +209,7 @@ class TrainOneStepCell(Cell):
>>> loss_net = MyWithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
"""
def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
@@ -314,7 +349,6 @@ class WithEvalCell(Cell):
self._loss_fn = loss_fn
self.add_cast_fp32 = add_cast_fp32
def construct(self, data, label):
outputs = self._network(data)
if self.add_cast_fp32:
@@ -354,3 +388,25 @@ class ParameterUpdate(Cell):
def construct(self, x):
F.assign(self._param, x)
return x
class _BroadCastCell(Cell):
"""
Broadcast the parameters from device 0 to other devices.
Args:
params (list): The parameters of the network.
"""
def __init__(self, params):
super(_BroadCastCell, self).__init__()
self.map_ = C.Map()
self.params = tuple(params)
self.broadcast = P.Broadcast(0)
def construct(self):
# Record each parameter's dtype, cast to float32 for the Broadcast collective,
# broadcast from device 0, then cast the results back to the original dtypes.
datatypes = self.map_(F.partial(_get_datatype), self.params)
params = self.map_(F.partial(_cast_datatype, mstype.float32), self.params)
params = self.broadcast(params)
new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
return new_params

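How _build_broadcast_graph drives this cell can be sketched as follows (illustrative only: `net` stands for an already-constructed Cell on each device of a distributed job, and in the real flow the executor builds the cell from parameters_broadcast_dict() during compilation):

from mindspore.nn.wrap.cell_wrapper import _BroadCastCell

broadcast_dict = net.parameters_broadcast_dict()  # assumes `net` is an initialized Cell
broadcast_net = _BroadCastCell(list(broadcast_dict.values()))
# Running the cell casts the parameters to float32, broadcasts them from device 0,
# casts the results back to their original dtypes and returns the synchronized values.
synced_params = broadcast_net()
for name, value in zip(broadcast_dict, synced_params):
    broadcast_dict[name].set_data(value)
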
View File

@@ -207,9 +207,6 @@ class _AutoParallelContext:
parameter_broadcast (bool): Parameter broadcast or not.
"""
self.check_context_handle()
if parameter_broadcast is True and context.get_context("enable_ge") is False:
raise RuntimeError("Parameter broadcast is a developing feature. For now we suggest to"
" use mindspore.common.set_seed() to share parameters among devices.")
self._context_handle.set_parameter_broadcast(parameter_broadcast)
def get_parameter_broadcast(self):
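With the guard removed, the _AutoParallelContext setter simply forwards the flag to the backend context, so enabling parameter broadcast outside GE no longer raises. A quick round-trip check using only the public context API (a sketch):

from mindspore import context

context.set_auto_parallel_context(parameter_broadcast=True)  # previously raised RuntimeError when enable_ge was False
print(context.get_auto_parallel_context("parameter_broadcast"))  # True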