forked from mindspore-Ecosystem/mindspore
!11338 Support parameter broadcast in data parallel mode under PyNative
From: @zuochuanyong Reviewed-by: @zhoufeng54,@chujinjin Signed-off-by: @chujinjin
This commit is contained in:
commit
b988780fd7
|
@ -298,6 +298,49 @@ def _generate_pip_args(obj, *args, method="construct"):
|
|||
return args_names, args_list
|
||||
|
||||
|
||||
def _get_auto_split_param_names(parameter_layout_dict):
|
||||
auto_split_params = {}
|
||||
for key, value in parameter_layout_dict.items():
|
||||
for dim in value[1]:
|
||||
if dim != -1:
|
||||
auto_split_params[key] = value
|
||||
break
|
||||
auto_split_param_names = (param_name for param_name in auto_split_params)
|
||||
return auto_split_param_names
|
||||
|
||||
|
||||
def _build_broadcast_graph(broadcast_params_dict, broadcast_phase):
    """Build and run the graph that broadcasts parameters from rank 0.

    Copies each parameter's current value into a fresh Tensor, runs the
    broadcast cell under the given phase, then writes the broadcast results
    back into the original parameters via set_data.
    """
    # Imported lazily to avoid a circular import with the wrap package.
    from mindspore.nn.wrap.cell_wrapper import _BroadCastCell

    if not broadcast_params_dict:
        broadcast_params_dict = {}
    # Snapshot every parameter value as a detached Tensor before broadcasting.
    values_to_broadcast = [Tensor(p.asnumpy()) for p in broadcast_params_dict.values()]
    net = _BroadCastCell(values_to_broadcast)
    net.phase = broadcast_phase
    results = net()
    # Write the rank-0 values back into the corresponding parameters.
    for name, new_value in zip(broadcast_params_dict.keys(), results):
        broadcast_params_dict[name].set_data(new_value)
|
||||
|
||||
|
||||
def _parameter_broadcast(obj, auto_parallel_mode):
    """Broadcast `obj`'s parameters from rank 0 to all other devices.

    Parameters that the auto-parallel engine has already split across
    devices are excluded from the broadcast, since each device owns only
    its own slice of those tensors.

    Args:
        obj (Cell): The network whose parameters are broadcast.
        auto_parallel_mode (bool): When True, split parameters are filtered
            out using obj.parameter_layout_dict.
    """
    auto_split_param_names = set()
    if auto_parallel_mode:
        # Materialize into a set: gives a correct emptiness test and O(1)
        # repeated membership checks even if the helper yields lazily
        # (a raw generator is always truthy and single-use).
        auto_split_param_names = set(
            _get_auto_split_param_names(obj.parameter_layout_dict))

    # Fetch the broadcast dict once and reuse it (the original called the
    # accessor twice).
    broadcast_params_dict = obj.parameters_broadcast_dict()
    if auto_split_param_names and broadcast_params_dict:
        # Keep only the parameters that are NOT split, preserving order.
        broadcast_params_dict = OrderedDict(
            (param_name, param)
            for param_name, param in broadcast_params_dict.items()
            if param_name not in auto_split_param_names
        )
    broadcast_phase = "_broadcast_subgraph"
    _build_broadcast_graph(broadcast_params_dict, broadcast_phase)
|
||||
|
||||
|
||||
class _PynativeExecutor:
|
||||
"""
|
||||
An pynative executor used to compile/manage/run graph.
|
||||
|
@ -339,6 +382,10 @@ class _PynativeExecutor:
|
|||
    def leave_construct(self, cell):
        """Delegate to the backend executor that `cell`'s construct scope is exited."""
        self._executor.leave_construct(cell)
|
||||
|
||||
    def parameter_broadcast(self, obj, phase, auto_parallel_mode):
        """Broadcast `obj`'s parameters for data-parallel PyNative training.

        Runs only when `phase` is not itself a broadcast phase and the
        global parameter-broadcast switch is enabled.
        """
        if BROADCAST_PHASE not in phase and _get_parameter_broadcast():
            _parameter_broadcast(obj, auto_parallel_mode)
|
||||
|
||||
def __call__(self, obj, *args, **kwargs):
|
||||
args = args + tuple(kwargs.values())
|
||||
return self._executor(obj, args, "")
|
||||
|
@ -391,31 +438,6 @@ class _Executor:
|
|||
    def _build_data_graph(self, obj, phase):
        """Build the backend data graph for `obj` under the given phase.

        Passes both the full parameter dict and the broadcast dict so the
        backend can initialize parameter data.
        """
        self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
|
||||
|
||||
def _get_auto_split_param_names(self, parameter_layout_dict):
|
||||
auto_split_params = {}
|
||||
for key, value in parameter_layout_dict.items():
|
||||
for dim in value[1]:
|
||||
if dim != -1:
|
||||
auto_split_params[key] = value
|
||||
break
|
||||
auto_split_param_names = (param_name for param_name in auto_split_params)
|
||||
return auto_split_param_names
|
||||
|
||||
def _build_broadcast_graph(self, broadcast_params_dict, broadcast_phase):
|
||||
"""Build broadcast graph."""
|
||||
from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
|
||||
|
||||
if not broadcast_params_dict:
|
||||
broadcast_params_dict = {}
|
||||
broadcast_params = []
|
||||
for param in broadcast_params_dict.values():
|
||||
broadcast_params.append(Tensor(param.asnumpy()))
|
||||
_broadcast_net = _BroadCastCell(broadcast_params)
|
||||
_broadcast_net.phase = broadcast_phase
|
||||
broadcasted_params = _broadcast_net()
|
||||
for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
|
||||
broadcast_params_dict[param_name].set_data(param)
|
||||
|
||||
def _set_dataset_mode(self, args_list):
|
||||
"""set dataset mode."""
|
||||
# decide whether to sink based on whether the inputs is virtual or args_list is ()
|
||||
|
@ -500,18 +522,7 @@ class _Executor:
|
|||
elif not enable_ge and "export" in phase:
|
||||
self._build_data_graph(obj, phase)
|
||||
elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
|
||||
auto_split_param_names = []
|
||||
if auto_parallel_mode:
|
||||
auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict)
|
||||
|
||||
broadcast_params_dict = obj.parameters_broadcast_dict()
|
||||
if auto_split_param_names and broadcast_params_dict:
|
||||
broadcast_params_dict = OrderedDict()
|
||||
for param_name, param in obj.parameters_broadcast_dict().items():
|
||||
if param_name not in auto_split_param_names:
|
||||
broadcast_params_dict[param_name] = param
|
||||
broadcast_phase = "_broadcast_subgraph"
|
||||
self._build_broadcast_graph(broadcast_params_dict, broadcast_phase)
|
||||
_parameter_broadcast(obj, auto_parallel_mode)
|
||||
|
||||
return phase, True
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ import numpy
|
|||
|
||||
from mindspore import log as logger
|
||||
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
|
||||
from mindspore.context import ParallelMode
|
||||
from .. import context
|
||||
from .._c_expression import init_pipeline, Cell_
|
||||
from .._checkparam import Validator
|
||||
|
@ -90,6 +91,7 @@ class Cell(Cell_):
|
|||
self._parameter_layout_dict = {}
|
||||
self._create_time = int(time.time() * 1e9)
|
||||
self.phase_prefix = ""
|
||||
self.parameter_broadcast_done = False
|
||||
init_pipeline()
|
||||
|
||||
# call gc to release GE session resources used by non-used cell objects
|
||||
|
@ -300,6 +302,11 @@ class Cell(Cell_):
|
|||
out = self.compile_and_run(*inputs)
|
||||
return out
|
||||
|
||||
if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
|
||||
if not self.parameter_broadcast_done:
|
||||
_pynative_exec.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
|
||||
self.parameter_broadcast_done = True
|
||||
|
||||
for item in inputs:
|
||||
if isinstance(item, numpy.ndarray):
|
||||
raise TypeError("cell inputs should not be numpy array.")
|
||||
|
|
Loading…
Reference in New Issue