enable optimizer parallel

This commit is contained in:
Ziyan 2020-07-14 20:44:16 +08:00
parent 7976d77593
commit 39f08eb7dd
7 changed files with 80 additions and 92 deletions

View File

@ -100,7 +100,10 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
auto parallel_context = parallel::ParallelContext::GetInstance(); auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context); MS_EXCEPTION_IF_NULL(parallel_context);
const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group); std::vector<uint32_t> split_indices;
if (!parallel_context->enable_parallel_optimizer()) {
split_indices = parallel_context->GetAllReduceFusionSplitIndices(group);
}
size_t segments = 0; size_t segments = 0;
if (split_indices.size() != 0) { if (split_indices.size() != 0) {

View File

@ -443,7 +443,7 @@ def _context():
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str, @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool) strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
def set_auto_parallel_context(**kwargs): def set_auto_parallel_context(**kwargs):
""" """
Set auto parallel context. Set auto parallel context.
@ -487,6 +487,9 @@ def set_auto_parallel_context(**kwargs):
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
full_batch (bool): Whether to load the whole batch on each device. Default: False. full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in
data parallel training in the benefit of time and memory saving.
Raises: Raises:
ValueError: If input key is not attribute in auto parallel context. ValueError: If input key is not attribute in auto parallel context.
@ -532,6 +535,7 @@ def reset_auto_parallel_context():
- parameter_broadcast: False. - parameter_broadcast: False.
- strategy_ckpt_load_file: "". - strategy_ckpt_load_file: "".
- strategy_ckpt_save_file: "". - strategy_ckpt_save_file: "".
- enable_parallel_optimizer: False.
""" """
_reset_auto_parallel_context() _reset_auto_parallel_context()

View File

@ -28,8 +28,8 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from mindspore import log as logger from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.parallel_utils import ParallelMode from mindspore.train.parallel_utils import ParallelMode
from mindspore import context
__all__ = ['Optimizer'] __all__ = ['Optimizer']
@ -157,13 +157,12 @@ class Optimizer(Cell):
self.param_length = len(self.parameters) self.param_length = len(self.parameters)
self.map_ = C.Map() self.map_ = C.Map()
use_parallel = auto_parallel_context().get_enable_parallel_optimizer() use_parallel = context.get_auto_parallel_context("enable_parallel_optimizer")
self.use_parallel = use_parallel self.use_parallel = use_parallel
if use_parallel: if use_parallel:
if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]: if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]:
raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name)) raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL, if _get_parallel_mode() != ParallelMode.DATA_PARALLEL:
ParallelMode.AUTO_PARALLEL]:
raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
(_get_parallel_mode())) (_get_parallel_mode()))
self.dev_num = _get_device_num() self.dev_num = _get_device_num()
@ -175,6 +174,7 @@ class Optimizer(Cell):
self.param_names = [] self.param_names = []
for param in self.parameters: for param in self.parameters:
self.param_names.append(param.name) self.param_names.append(param.name)
else: else:
self.optim_filter = (True,) * self.param_length self.optim_filter = (True,) * self.param_length

View File

@ -13,107 +13,95 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""grad reducer cell for distributed training""" """grad reducer cell for distributed training"""
from mindspore import context
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore.communication.management import GlobalComm, get_group_size from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.ops import functional as F, composite as C, operations as P from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp, AllGather from mindspore.ops.operations.comm_ops import AllReduce, AllGather
from mindspore.parallel._auto_parallel_context import auto_parallel_context
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
reduce_opt = C.MultitypeFuncGraph("reduce_opt") reduce_opt = C.MultitypeFuncGraph("reduce_opt")
_all_reduce = AllReduce()
_all_gather = None def _init_allreduce_operators(length):
""" initialize allreduce communication operators"""
is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
if is_parallel_optimizer and split_indices:
group = 1
fusion = ()
for i in range(length):
fusion = fusion + (group,)
if split_indices[group - 1] <= i + 1:
if group >= len(split_indices):
continue
group = group + 1
index = tuple(range(1, length + 1))
else:
fusion = (1,) * length
index = (0,) * length
opt_list = ()
for i in range(length):
opt = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
opt.add_prim_attr('fusion', fusion[i])
opt.add_prim_attr('index', index[i])
opt_list = opt_list + (opt,)
return opt_list
def _init_optimizer_communication(): @reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function")
global _all_reduce def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce):
global _all_gather
_all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
_all_reduce.add_prim_attr('fusion', 1)
_all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP)
@reduce_opt.register("Function", "Number", "Bool", "Tensor")
def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad):
"""
Apply mean and allreduce on gradient. Allreduce is a communication operation used for distributed deep learning.
Args:
mul (Primitive): Div operation.
degree (int): The mean coefficient.
allreduce_filter (bool): When it is true, allreduce would apply.
grad (Tensor): The gradient tensor before operation.
Returns:
Tensor, the gradient tensor after operation.
"""
if allreduce_filter:
degree = F.scalar_cast(degree, F.dtype(grad))
grad = _all_reduce(grad)
cast_op = P.Cast()
return mul(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad)))
return grad
@reduce_opt.register("Function", "Number", "Bool", "Tuple")
def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad):
"""
Apply mean and allgather on gradient instead of allreduce for sparse feature.
Allgather is a communication operation used for distributed deep learning.
Args:
mul (Primitive): Div operation.
degree (int): The mean coefficient.
allreduce_filter (bool): When it is true, allgather would apply.
grad (Tuple): The indices, gradient tensor and tensor_shape before operation.
Returns:
Tuple, include indices, the gradient tensor and tensor_shape after operation.
"""
if allreduce_filter:
indices = _all_gather(grad[0])
degree = F.scalar_cast(degree, F.dtype(grad[1]))
dout = _all_gather(grad[1])
cast_op = P.Cast()
dout = mul(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
grad = (indices, dout, grad[2])
return grad
@reduce_opt.register("Bool", "Tensor")
def _tensors_allreduce(allreduce_filter, grad):
""" """
Apply allreduce on gradient. Apply allreduce on gradient.
Args: Args:
degree (int): The mean coefficient.
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
allgather (Primitive): The communication operator for sparse gradients.
allreduce_filter (bool): When it is true, allreduce would apply. allreduce_filter (bool): When it is true, allreduce would apply.
grad (Tensor): The gradient tensor before operation. grad (Tensor): The gradient tensor before operation.
allreduce (Primitive): The communication operator for gradients.
Returns: Returns:
Tensor, the gradient tensor after operation. Tensor, the gradient tensor after operation.
""" """
if allreduce_filter: if allreduce_filter:
return _all_reduce(grad) grad = allreduce(grad)
if mean:
degree = F.scalar_cast(degree, F.dtype(grad))
cast_op = P.Cast()
mul_op = P.Mul()
grad = mul_op(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad)))
return grad
return grad return grad
@reduce_opt.register("Bool", "Tuple") @reduce_opt.register("Number", "Bool", "Function", "Bool", "Tuple", "Function")
def _tensors_allreduce_with_sparse(allreduce_filter, grad): def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, grad, allreduce):
""" """
Apply mean and allgather on gradient instead of allreduce for sparse feature. Apply allgather on gradient instead of allreduce for sparse feature.
Allgather is a communication operation used for distributed deep learning. Allgather is a communication operation used for distributed deep learning.
Args: Args:
degree (int): The mean coefficient.
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
allgather (Primitive): The communication operator for sparse gradients.
allreduce_filter (bool): When it is true, allgather would apply. allreduce_filter (bool): When it is true, allgather would apply.
grad (Tuple): The indices, gradient tensor and tensor_shape before operation. grad (tuple): The indices, gradient tensor and tensor_shape before operation.
allreduce (Primitive): The communication operator for gradients.
Returns: Returns:
Tuple, include indices, the gradient tensor and tensor_shape after operation. Tuple, include indices, the gradient tensor and tensor_shape after operation.
""" """
if allreduce_filter: if allreduce_filter:
indices = _all_gather(grad[0]) indices = allgather(grad[0])
dout = _all_gather(grad[1]) dout = allgather(grad[1])
if mean:
degree = F.scalar_cast(degree, F.dtype(grad[1]))
cast_op = P.Cast()
mul_op = P.Mul()
dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
grad = (indices, dout, grad[2]) grad = (indices, dout, grad[2])
return grad return grad
@ -259,7 +247,6 @@ class DistributedGradReducer(Cell):
def __init__(self, parameters, mean=True, degree=None): def __init__(self, parameters, mean=True, degree=None):
super(DistributedGradReducer, self).__init__(auto_prefix=False) super(DistributedGradReducer, self).__init__(auto_prefix=False)
self.map_ = C.Map() self.map_ = C.Map()
self.mul = P.Mul()
if degree is None: if degree is None:
self.degree = get_group_size() self.degree = get_group_size()
else: else:
@ -268,7 +255,8 @@ class DistributedGradReducer(Cell):
self.degree = degree self.degree = degree
self.mean = mean self.mean = mean
self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
_init_optimizer_communication() self.opt_list = _init_allreduce_operators(len(parameters))
self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP)
def construct(self, grads): def construct(self, grads):
""" """
@ -284,11 +272,8 @@ class DistributedGradReducer(Cell):
""" """
datatypes = self.map_(F.partial(_get_datatype), grads) datatypes = self.map_(F.partial(_get_datatype), grads)
grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads) grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
if self.mean: self.allreduce_filter, grads, self.opt_list)
new_grad = self.map_(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
else:
new_grad = self.map_(F.partial(reduce_opt), self.allreduce_filter, grads)
new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad) new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
return new_grad return new_grad

View File

@ -513,7 +513,7 @@ def _set_auto_parallel_context(**kwargs):
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
full_batch (bool): Whether to load the whole batch on each device. Default: False. full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer (bool): Enable using optimizer segmentation or noe. Default: False. enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
Raises: Raises:
ValueError: If input key is not attribute in auto parallel context. ValueError: If input key is not attribute in auto parallel context.

View File

@ -22,7 +22,6 @@ from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore import context from mindspore import context
@ -54,8 +53,7 @@ class Net(nn.Cell):
def test_AdamWeightDecayDynamicLR(): def test_AdamWeightDecayDynamicLR():
""" test_AdamWeightDecayDynamicLR """ """ test_AdamWeightDecayDynamicLR """
auto_parallel_context().set_enable_parallel_optimizer(True) context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
inputs = Tensor(np.ones([32, 128]).astype(np.float32)) inputs = Tensor(np.ones([32, 128]).astype(np.float32))
label = Tensor(np.zeros([32, 768]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32))
net = Net() net = Net()
@ -70,8 +68,7 @@ def test_AdamWeightDecayDynamicLR():
def test_AdamWeightDecay(): def test_AdamWeightDecay():
""" test_AdamWeightDecayDynamicLR """ """ test_AdamWeightDecayDynamicLR """
auto_parallel_context().set_enable_parallel_optimizer(True) context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
inputs = Tensor(np.ones([32, 128]).astype(np.float32)) inputs = Tensor(np.ones([32, 128]).astype(np.float32))
label = Tensor(np.zeros([32, 768]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32))
net = Net() net = Net()
@ -86,8 +83,7 @@ def test_AdamWeightDecay():
def test_lamb_compile(): def test_lamb_compile():
""" test_Lamb_compile """ """ test_Lamb_compile """
auto_parallel_context().set_enable_parallel_optimizer(True) context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2)
inputs = Tensor(np.ones([32, 128]).astype(np.float32)) inputs = Tensor(np.ones([32, 128]).astype(np.float32))
label = Tensor(np.zeros([32, 768]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32))
net = Net() net = Net()
@ -102,7 +98,7 @@ def test_lamb_compile():
def test_edge_case(): def test_edge_case():
""" test_edge_case """ """ test_edge_case """
auto_parallel_context().set_enable_parallel_optimizer(True) context.set_auto_parallel_context(enable_parallel_optimizer=True)
net = Net() net = Net()
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
context.set_auto_parallel_context(parallel_mode="stand_alone") context.set_auto_parallel_context(parallel_mode="stand_alone")

View File

@ -81,8 +81,8 @@ def test_set_auto_parallel_context():
with pytest.raises(ValueError): with pytest.raises(ValueError):
set_algo_parameters(tensor_slice_align_size=1025) set_algo_parameters(tensor_slice_align_size=1025)
auto_parallel_context().set_enable_parallel_optimizer(True) context.set_auto_parallel_context(enable_parallel_optimizer=True)
assert auto_parallel_context().get_enable_parallel_optimizer() is True assert context.get_auto_parallel_context("enable_parallel_optimizer")
assert not auto_parallel_context().get_all_reduce_fusion_split_indices() assert not auto_parallel_context().get_all_reduce_fusion_split_indices()