forked from mindspore-Ecosystem/mindspore
commit 071366a5f8
@@ -18,16 +18,21 @@
import types
from collections import OrderedDict
from functools import wraps

from mindspore import context
from mindspore import log as logger
from .tensor import Tensor as MsTensor
from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
from .tensor import Tensor as MsTensor
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor
from ..parallel._ps_context import _is_role_pserver
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
    _get_parameter_broadcast

# store ms_function class compiled pipeline cache
ms_compile_cache = {}

BROADCAST_PHASE = "_broadcast_"


def _convert_function_arguments(fn, *args):
    """
@@ -362,6 +367,27 @@ class _Executor:
    def _build_data_graph(self, obj, phase):
        self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())

    def _get_auto_split_param_names(self, parameter_layout_dict):
        """Get the names of parameters that the auto parallel strategy splits across devices."""
        auto_split_params = {}
        for key, value in parameter_layout_dict.items():
            for dim in value[1]:
                if dim != -1:
                    auto_split_params[key] = value
                    break
        # materialize the names as a list so membership tests can run repeatedly
        auto_split_param_names = [param_name for param_name in auto_split_params]
        return auto_split_param_names

    def _build_broadcast_graph(self, obj, broadcast_params, broadcast_phase):
        """Build broadcast graph."""
        from mindspore.nn.wrap.cell_wrapper import _BroadCastCell

        _broadcast_net = _BroadCastCell(broadcast_params)
        _broadcast_net.phase = broadcast_phase
        broadcasted_params = _broadcast_net()
        parameters_broadcast_dict = obj.parameters_broadcast_dict()
        for param_name, param in zip(parameters_broadcast_dict, broadcasted_params):
            parameters_broadcast_dict[param_name].set_data(param)

    def _set_dataset_mode(self, args_list):
        """Set dataset mode."""
        # decide whether to sink based on whether the inputs are virtual or args_list is ()
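For reference, the split test used by _get_auto_split_param_names can be illustrated in plain Python; this is a minimal sketch with hypothetical layout entries (not part of the change), assuming value[1] holds the tensor map, where -1 means a dimension is replicated rather than split:

# Plain-Python sketch of the filtering rule: a parameter counts as auto-split
# when any entry of its tensor map (value[1]) is not -1.
def get_auto_split_param_names(parameter_layout_dict):
    names = []
    for name, layout in parameter_layout_dict.items():
        tensor_map = layout[1]
        if any(dim != -1 for dim in tensor_map):
            names.append(name)
    return names

toy_layout = {
    "fc1.weight": ([8], [0, -1], [64, 32]),   # first dim mapped to a device axis -> split
    "fc1.bias":   ([8], [-1], [32]),          # all -1 -> replicated, not split
}
print(get_auto_split_param_names(toy_layout))  # ['fc1.weight']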
@@ -444,6 +470,15 @@ class _Executor:
            _exec_init_graph(obj, init_phase)
        elif not enable_ge and "export" in phase:
            self._build_data_graph(obj, phase)
        elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
            auto_split_param_names = []
            if auto_parallel_mode:
                auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict)
            broadcast_params = [param for param_name, param in obj.parameters_broadcast_dict().items() if
                                param_name not in auto_split_param_names]
            broadcast_phase = "broadcast_subgraph" + "." + str(obj.create_time)
            self._build_broadcast_graph(obj, broadcast_params, broadcast_phase)
            self.compile_cache[phase] = broadcast_phase

        return phase, True
@@ -377,9 +377,7 @@ def set_auto_parallel_context(**kwargs):
            - recursive_programming: Recursive programming search mode.

            - dynamic_programming: Dynamic programming search mode.
        parameter_broadcast (bool): A developing feature. Whether to broadcast parameters before training.
            "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
            broadcast. Default: False.
        parameter_broadcast (bool): Whether to broadcast parameters before training. Default: False.
        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
        full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
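A usage sketch of the documented flag follows; the surrounding set_context and init() calls are the usual prerequisites and are illustrative, not taken from this change:

# Illustrative only: enable parameter broadcast for data-parallel training.
from mindspore import context
from mindspore.communication.management import init

context.set_context(mode=context.GRAPH_MODE)
init()  # communication backend must be initialized before parallel settings take effect
context.set_auto_parallel_context(parallel_mode="data_parallel",
                                  parameter_broadcast=True)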
@@ -25,6 +25,40 @@ from ...ops.operations.comm_ops import _VirtualDataset
from ..cell import Cell
from .grad_reducer import DistributedGradReducer

_get_datatype = C.MultitypeFuncGraph("_get_datatype")


@_get_datatype.register("Tensor")
def _tensors_get_datatype(param):
    """
    Acquire parameter datatype.

    Args:
        param (Tensor): The parameter before operation.

    Returns:
        mstype, the datatype of the parameter.
    """
    return F.dtype(param)


_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, param):
    """
    Cast the parameter to the given datatype.

    Args:
        datatype (mstype): The destination datatype of the parameter.
        param (Tensor): The parameter before operation.

    Returns:
        Tensor, the parameter after the cast.
    """
    return F.cast(param, datatype)


class WithLossCell(Cell):
    r"""
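The two registered implementations above remain plain functions under the decorator, so their behavior can be checked directly in PyNative mode; a minimal sketch, assuming it runs inside this module (the helper names are module-private) and that the printed dtypes are illustrative:

# Minimal PyNative check of the dtype round-trip used later by _BroadCastCell.
import numpy as np
from mindspore import Tensor, context
from mindspore.common import dtype as mstype

context.set_context(mode=context.PYNATIVE_MODE)

x = Tensor(np.ones((2, 2)), mstype.float16)
print(_tensors_get_datatype(x))                 # Float16
y = _tensors_cast_datatype(mstype.float32, x)   # cast up for the broadcast
print(_tensors_get_datatype(y))                 # Float32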
@@ -175,6 +209,7 @@ class TrainOneStepCell(Cell):
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
@@ -314,7 +349,6 @@ class WithEvalCell(Cell):
        self._loss_fn = loss_fn
        self.add_cast_fp32 = add_cast_fp32

    def construct(self, data, label):
        outputs = self._network(data)
        if self.add_cast_fp32:
@@ -354,3 +388,25 @@ class ParameterUpdate(Cell):
    def construct(self, x):
        F.assign(self._param, x)
        return x


class _BroadCastCell(Cell):
    """
    Broadcast the parameters from device 0 to other devices.

    Args:
        params (list): The parameters of the network.
    """

    def __init__(self, params):
        super(_BroadCastCell, self).__init__()
        self.map_ = C.Map()
        self.params = tuple(params)
        self.broadcast = P.Broadcast(0)

    def construct(self):
        datatypes = self.map_(F.partial(_get_datatype), self.params)
        params = self.map_(F.partial(_cast_datatype, mstype.float32), self.params)
        params = self.broadcast(params)
        new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
        return new_params
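A sketch of how _Executor._build_broadcast_graph (earlier in this change) drives this cell, assuming the communication backend has already been initialized with init() and that `net` is a hypothetical, already-constructed Cell:

# Illustrative mirror of _build_broadcast_graph: broadcast every parameter in the
# broadcast dict from device 0, then write the synced values back.
params_dict = net.parameters_broadcast_dict()
broadcast_net = _BroadCastCell(list(params_dict.values()))
broadcast_net.phase = "broadcast_subgraph." + str(net.create_time)
synced = broadcast_net()
for name, value in zip(params_dict, synced):
    params_dict[name].set_data(value)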
@@ -207,9 +207,6 @@ class _AutoParallelContext:
            parameter_broadcast (bool): Parameter broadcast or not.
        """
        self.check_context_handle()
        if parameter_broadcast is True and context.get_context("enable_ge") is False:
            raise RuntimeError("Parameter broadcast is a developing feature. For now we suggest using"
                               " mindspore.common.set_seed() to share parameters among devices.")
        self._context_handle.set_parameter_broadcast(parameter_broadcast)

    def get_parameter_broadcast(self):
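The removed error message points to seeding as the interim way to keep initial parameters identical across devices; a minimal sketch (the seed value and script structure are illustrative):

# Illustrative only: with the same seed fixed on every device before the network is
# built, parameter initialization matches across devices without an explicit broadcast.
from mindspore.common import set_seed

set_seed(1)
# ... construct the network and optimizer after the seed is fixed ...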