!11338 Support parameter broadcast in data parallel mode under PyNative

From: @zuochuanyong
Reviewed-by: @zhoufeng54,@chujinjin
Signed-off-by: @chujinjin
mindspore-ci-bot 2021-01-18 15:11:13 +08:00 committed by Gitee
commit b988780fd7
2 changed files with 55 additions and 37 deletions


@@ -298,6 +298,49 @@ def _generate_pip_args(obj, *args, method="construct"):
    return args_names, args_list


def _get_auto_split_param_names(parameter_layout_dict):
    """Get the names of parameters whose layout splits any dimension (dim != -1)."""
    auto_split_params = {}
    for key, value in parameter_layout_dict.items():
        for dim in value[1]:
            if dim != -1:
                auto_split_params[key] = value
                break
    auto_split_param_names = (param_name for param_name in auto_split_params)
    return auto_split_param_names
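For illustration, a minimal sketch of this helper on a hand-built layout dict (the layout tuples and parameter names below are made up; the code above only relies on value[1] holding the split dimensions, with -1 meaning "not split"). Note that the helper returns a generator, which is exhausted after a single pass:

layout = {
    "fc.weight": ([4, 2], [0, -1]),  # dim 0 is split -> counts as auto-split
    "fc.bias": ([4, 2], [-1]),       # no dim split -> eligible for broadcast
}
names = _get_auto_split_param_names(layout)
print(list(names))  # ['fc.weight']
print(list(names))  # [] -- the generator is single-use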
def _build_broadcast_graph(broadcast_params_dict, broadcast_phase):
    """Build broadcast graph."""
    from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
    if not broadcast_params_dict:
        broadcast_params_dict = {}
    broadcast_params = []
    for param in broadcast_params_dict.values():
        # Snapshot each parameter's current value as the broadcast input.
        broadcast_params.append(Tensor(param.asnumpy()))
    _broadcast_net = _BroadCastCell(broadcast_params)
    _broadcast_net.phase = broadcast_phase
    broadcasted_params = _broadcast_net()
    for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
        # Write the broadcast results back into the original Parameter objects.
        broadcast_params_dict[param_name].set_data(param)
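_BroadCastCell is internal to mindspore.nn.wrap.cell_wrapper; as a rough, assumed sketch (not the real implementation), it can be pictured as a Cell wrapping the Broadcast communication op with root rank 0, so every device receives rank 0's copy of each tensor:

import mindspore.nn as nn
import mindspore.ops as ops

class _BroadCastCellSketch(nn.Cell):
    """Assumed stand-in for _BroadCastCell: broadcast tensors from rank 0."""
    def __init__(self):
        super().__init__()
        self.broadcast = ops.Broadcast(0)  # root rank 0 is an assumption

    def construct(self, *params):
        return self.broadcast(params)  # returns the broadcast tuple of tensors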
def _parameter_broadcast(obj, auto_parallel_mode):
    """Parameter broadcast."""
    auto_split_param_names = []
    if auto_parallel_mode:
        auto_split_param_names = _get_auto_split_param_names(obj.parameter_layout_dict)

    broadcast_params_dict = obj.parameters_broadcast_dict()
    if auto_split_param_names and broadcast_params_dict:
        # Auto-split parameters already hold per-device shards, so they are
        # excluded from the broadcast set.
        broadcast_params_dict = OrderedDict()
        for param_name, param in obj.parameters_broadcast_dict().items():
            if param_name not in auto_split_param_names:
                broadcast_params_dict[param_name] = param
    broadcast_phase = "_broadcast_subgraph"
    _build_broadcast_graph(broadcast_params_dict, broadcast_phase)
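The filtering branch above can be exercised in isolation; a plain-Python sketch with made-up names shows that only parameters outside the auto-split set survive into the rebuilt OrderedDict:

from collections import OrderedDict

auto_split = {"fc.weight"}  # already sharded per device by auto parallel
candidates = OrderedDict([("fc.weight", 1.0), ("fc.bias", 0.0)])
filtered = OrderedDict(
    (name, p) for name, p in candidates.items() if name not in auto_split
)
print(list(filtered))  # ['fc.bias'] -- sharded parameters keep local shards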
class _PynativeExecutor:
    """
    A PyNative executor used to compile/manage/run graphs.
@@ -339,6 +382,10 @@ class _PynativeExecutor:
    def leave_construct(self, cell):
        self._executor.leave_construct(cell)

    def parameter_broadcast(self, obj, phase, auto_parallel_mode):
        """Broadcast parameters if enabled and not already in a broadcast phase."""
        if BROADCAST_PHASE not in phase and _get_parameter_broadcast():
            _parameter_broadcast(obj, auto_parallel_mode)

    def __call__(self, obj, *args, **kwargs):
        args = args + tuple(kwargs.values())
        return self._executor(obj, args, "")
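The guard in parameter_broadcast mirrors the graph-mode check further below; a small sketch of its intent (the exact value of BROADCAST_PHASE is an assumption, taken to be the marker embedded in broadcast subgraph phase names):

BROADCAST_PHASE = "_broadcast_"  # assumed marker value

def _should_broadcast(phase, broadcast_enabled):
    # Skip while compiling/running a broadcast subgraph (avoids recursion)
    # and when parameter_broadcast is disabled in the parallel context.
    return BROADCAST_PHASE not in phase and broadcast_enabled

print(_should_broadcast("train", True))                # True
print(_should_broadcast("_broadcast_subgraph", True))  # False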
@@ -391,31 +438,6 @@ class _Executor:
    def _build_data_graph(self, obj, phase):
        self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())

    def _get_auto_split_param_names(self, parameter_layout_dict):
        auto_split_params = {}
        for key, value in parameter_layout_dict.items():
            for dim in value[1]:
                if dim != -1:
                    auto_split_params[key] = value
                    break
        auto_split_param_names = (param_name for param_name in auto_split_params)
        return auto_split_param_names

    def _build_broadcast_graph(self, broadcast_params_dict, broadcast_phase):
        """Build broadcast graph."""
        from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
        if not broadcast_params_dict:
            broadcast_params_dict = {}
        broadcast_params = []
        for param in broadcast_params_dict.values():
            broadcast_params.append(Tensor(param.asnumpy()))
        _broadcast_net = _BroadCastCell(broadcast_params)
        _broadcast_net.phase = broadcast_phase
        broadcasted_params = _broadcast_net()
        for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
            broadcast_params_dict[param_name].set_data(param)
    def _set_dataset_mode(self, args_list):
        """Set dataset mode."""
        # Decide whether to sink based on whether the inputs are virtual or args_list is ().
@@ -500,18 +522,7 @@ class _Executor:
        elif not enable_ge and "export" in phase:
            self._build_data_graph(obj, phase)
        elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
            auto_split_param_names = []
            if auto_parallel_mode:
                auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict)

            broadcast_params_dict = obj.parameters_broadcast_dict()
            if auto_split_param_names and broadcast_params_dict:
                broadcast_params_dict = OrderedDict()
                for param_name, param in obj.parameters_broadcast_dict().items():
                    if param_name not in auto_split_param_names:
                        broadcast_params_dict[param_name] = param
            broadcast_phase = "_broadcast_subgraph"
            self._build_broadcast_graph(broadcast_params_dict, broadcast_phase)
            _parameter_broadcast(obj, auto_parallel_mode)

        return phase, True


@@ -23,6 +23,7 @@ import numpy
from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from mindspore.context import ParallelMode
from .. import context
from .._c_expression import init_pipeline, Cell_
from .._checkparam import Validator
@@ -90,6 +91,7 @@ class Cell(Cell_):
        self._parameter_layout_dict = {}
        self._create_time = int(time.time() * 1e9)
        self.phase_prefix = ""
        self.parameter_broadcast_done = False
        init_pipeline()

        # Call gc to release GE session resources used by unused cell objects.
@@ -300,6 +302,11 @@ class Cell(Cell_):
            out = self.compile_and_run(*inputs)
            return out

        if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
            if not self.parameter_broadcast_done:
                _pynative_exec.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
                self.parameter_broadcast_done = True

        for item in inputs:
            if isinstance(item, numpy.ndarray):
                raise TypeError("cell inputs should not be numpy array.")
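Taken together, a data parallel PyNative run now broadcasts parameters exactly once, on the first forward call. A hedged end-to-end sketch (API names follow MindSpore 1.x; the launch/backend details such as init() arguments are assumptions and depend on the cluster setup):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.communication.management import init
from mindspore.context import ParallelMode

context.set_context(mode=context.PYNATIVE_MODE)
init()  # one process per device; sets up HCCL/NCCL communication
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  parameter_broadcast=True)

net = nn.Dense(4, 2)
x = Tensor(np.ones((8, 4), np.float32))

# First call: parameter_broadcast_done is False, so Cell.__call__ asks the
# PyNative executor to broadcast rank 0's parameters to every other rank,
# then sets the flag; later calls skip the broadcast entirely.
y = net(x)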