forked from mindspore-Ecosystem/mindspore
enable optimizer parallel
This commit is contained in:
parent
7976d77593
commit
39f08eb7dd
|
@ -100,7 +100,10 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
|
|||
|
||||
auto parallel_context = parallel::ParallelContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(parallel_context);
|
||||
const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group);
|
||||
std::vector<uint32_t> split_indices;
|
||||
if (!parallel_context->enable_parallel_optimizer()) {
|
||||
split_indices = parallel_context->GetAllReduceFusionSplitIndices(group);
|
||||
}
|
||||
|
||||
size_t segments = 0;
|
||||
if (split_indices.size() != 0) {
|
||||
|
|
|
@ -443,7 +443,7 @@ def _context():
|
|||
|
||||
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
|
||||
auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
|
||||
strategy_ckpt_save_file=str, full_batch=bool)
|
||||
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
|
||||
def set_auto_parallel_context(**kwargs):
|
||||
"""
|
||||
Set auto parallel context.
|
||||
|
@ -487,6 +487,9 @@ def set_auto_parallel_context(**kwargs):
|
|||
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
|
||||
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
|
||||
full_batch (bool): Whether to load the whole batch on each device. Default: False.
|
||||
enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in
|
||||
data parallel training in the benefit of time and memory saving.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not attribute in auto parallel context.
|
||||
|
@ -532,6 +535,7 @@ def reset_auto_parallel_context():
|
|||
- parameter_broadcast: False.
|
||||
- strategy_ckpt_load_file: "".
|
||||
- strategy_ckpt_save_file: "".
|
||||
- enable_parallel_optimizer: False.
|
||||
"""
|
||||
_reset_auto_parallel_context()
|
||||
|
||||
|
|
|
@ -28,8 +28,8 @@ from mindspore._checkparam import Validator as validator
|
|||
from mindspore._checkparam import Rel
|
||||
from mindspore import log as logger
|
||||
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from mindspore import context
|
||||
|
||||
__all__ = ['Optimizer']
|
||||
|
||||
|
@ -157,13 +157,12 @@ class Optimizer(Cell):
|
|||
self.param_length = len(self.parameters)
|
||||
self.map_ = C.Map()
|
||||
|
||||
use_parallel = auto_parallel_context().get_enable_parallel_optimizer()
|
||||
use_parallel = context.get_auto_parallel_context("enable_parallel_optimizer")
|
||||
self.use_parallel = use_parallel
|
||||
if use_parallel:
|
||||
if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]:
|
||||
raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
|
||||
if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL,
|
||||
ParallelMode.AUTO_PARALLEL]:
|
||||
if _get_parallel_mode() != ParallelMode.DATA_PARALLEL:
|
||||
raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
|
||||
(_get_parallel_mode()))
|
||||
self.dev_num = _get_device_num()
|
||||
|
@ -175,6 +174,7 @@ class Optimizer(Cell):
|
|||
self.param_names = []
|
||||
for param in self.parameters:
|
||||
self.param_names.append(param.name)
|
||||
|
||||
else:
|
||||
self.optim_filter = (True,) * self.param_length
|
||||
|
||||
|
|
|
@ -13,107 +13,95 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""grad reducer cell for distributed training"""
|
||||
from mindspore import context
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.communication.management import GlobalComm, get_group_size
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp, AllGather
|
||||
from mindspore.ops.operations.comm_ops import AllReduce, AllGather
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
reduce_opt = C.MultitypeFuncGraph("reduce_opt")
|
||||
|
||||
_all_reduce = AllReduce()
|
||||
_all_gather = None
|
||||
|
||||
def _init_allreduce_operators(length):
|
||||
""" initialize allreduce communication operators"""
|
||||
is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
|
||||
split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
|
||||
if is_parallel_optimizer and split_indices:
|
||||
group = 1
|
||||
fusion = ()
|
||||
for i in range(length):
|
||||
fusion = fusion + (group,)
|
||||
if split_indices[group - 1] <= i + 1:
|
||||
if group >= len(split_indices):
|
||||
continue
|
||||
group = group + 1
|
||||
index = tuple(range(1, length + 1))
|
||||
else:
|
||||
fusion = (1,) * length
|
||||
index = (0,) * length
|
||||
opt_list = ()
|
||||
for i in range(length):
|
||||
opt = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
|
||||
opt.add_prim_attr('fusion', fusion[i])
|
||||
opt.add_prim_attr('index', index[i])
|
||||
opt_list = opt_list + (opt,)
|
||||
return opt_list
|
||||
|
||||
|
||||
def _init_optimizer_communication():
|
||||
global _all_reduce
|
||||
global _all_gather
|
||||
|
||||
_all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
|
||||
_all_reduce.add_prim_attr('fusion', 1)
|
||||
_all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP)
|
||||
|
||||
|
||||
@reduce_opt.register("Function", "Number", "Bool", "Tensor")
|
||||
def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad):
|
||||
"""
|
||||
Apply mean and allreduce on gradient. Allreduce is a communication operation used for distributed deep learning.
|
||||
|
||||
Args:
|
||||
mul (Primitive): Div operation.
|
||||
degree (int): The mean coefficient.
|
||||
allreduce_filter (bool): When it is true, allreduce would apply.
|
||||
grad (Tensor): The gradient tensor before operation.
|
||||
|
||||
Returns:
|
||||
Tensor, the gradient tensor after operation.
|
||||
"""
|
||||
if allreduce_filter:
|
||||
degree = F.scalar_cast(degree, F.dtype(grad))
|
||||
grad = _all_reduce(grad)
|
||||
cast_op = P.Cast()
|
||||
return mul(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad)))
|
||||
return grad
|
||||
|
||||
|
||||
@reduce_opt.register("Function", "Number", "Bool", "Tuple")
|
||||
def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad):
|
||||
"""
|
||||
Apply mean and allgather on gradient instead of allreduce for sparse feature.
|
||||
Allgather is a communication operation used for distributed deep learning.
|
||||
|
||||
Args:
|
||||
mul (Primitive): Div operation.
|
||||
degree (int): The mean coefficient.
|
||||
allreduce_filter (bool): When it is true, allgather would apply.
|
||||
grad (Tuple): The indices, gradient tensor and tensor_shape before operation.
|
||||
|
||||
Returns:
|
||||
Tuple, include indices, the gradient tensor and tensor_shape after operation.
|
||||
"""
|
||||
if allreduce_filter:
|
||||
indices = _all_gather(grad[0])
|
||||
degree = F.scalar_cast(degree, F.dtype(grad[1]))
|
||||
dout = _all_gather(grad[1])
|
||||
cast_op = P.Cast()
|
||||
dout = mul(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
|
||||
grad = (indices, dout, grad[2])
|
||||
return grad
|
||||
|
||||
|
||||
@reduce_opt.register("Bool", "Tensor")
|
||||
def _tensors_allreduce(allreduce_filter, grad):
|
||||
@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function")
|
||||
def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce):
|
||||
"""
|
||||
Apply allreduce on gradient.
|
||||
|
||||
Args:
|
||||
degree (int): The mean coefficient.
|
||||
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
|
||||
allgather (Primitive): The communication operator for sparse gradients.
|
||||
allreduce_filter (bool): When it is true, allreduce would apply.
|
||||
grad (Tensor): The gradient tensor before operation.
|
||||
allreduce (Primitive): The communication operator for gradients.
|
||||
|
||||
Returns:
|
||||
Tensor, the gradient tensor after operation.
|
||||
"""
|
||||
if allreduce_filter:
|
||||
return _all_reduce(grad)
|
||||
grad = allreduce(grad)
|
||||
if mean:
|
||||
degree = F.scalar_cast(degree, F.dtype(grad))
|
||||
cast_op = P.Cast()
|
||||
mul_op = P.Mul()
|
||||
grad = mul_op(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad)))
|
||||
return grad
|
||||
return grad
|
||||
|
||||
|
||||
@reduce_opt.register("Bool", "Tuple")
|
||||
def _tensors_allreduce_with_sparse(allreduce_filter, grad):
|
||||
@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tuple", "Function")
|
||||
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, grad, allreduce):
|
||||
"""
|
||||
Apply mean and allgather on gradient instead of allreduce for sparse feature.
|
||||
Apply allgather on gradient instead of allreduce for sparse feature.
|
||||
Allgather is a communication operation used for distributed deep learning.
|
||||
|
||||
Args:
|
||||
degree (int): The mean coefficient.
|
||||
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
|
||||
allgather (Primitive): The communication operator for sparse gradients.
|
||||
allreduce_filter (bool): When it is true, allgather would apply.
|
||||
grad (Tuple): The indices, gradient tensor and tensor_shape before operation.
|
||||
grad (tuple): The indices, gradient tensor and tensor_shape before operation.
|
||||
allreduce (Primitive): The communication operator for gradients.
|
||||
|
||||
Returns:
|
||||
Tuple, include indices, the gradient tensor and tensor_shape after operation.
|
||||
"""
|
||||
if allreduce_filter:
|
||||
indices = _all_gather(grad[0])
|
||||
dout = _all_gather(grad[1])
|
||||
indices = allgather(grad[0])
|
||||
dout = allgather(grad[1])
|
||||
if mean:
|
||||
degree = F.scalar_cast(degree, F.dtype(grad[1]))
|
||||
cast_op = P.Cast()
|
||||
mul_op = P.Mul()
|
||||
dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
|
||||
grad = (indices, dout, grad[2])
|
||||
return grad
|
||||
|
||||
|
@ -259,7 +247,6 @@ class DistributedGradReducer(Cell):
|
|||
def __init__(self, parameters, mean=True, degree=None):
|
||||
super(DistributedGradReducer, self).__init__(auto_prefix=False)
|
||||
self.map_ = C.Map()
|
||||
self.mul = P.Mul()
|
||||
if degree is None:
|
||||
self.degree = get_group_size()
|
||||
else:
|
||||
|
@ -268,7 +255,8 @@ class DistributedGradReducer(Cell):
|
|||
self.degree = degree
|
||||
self.mean = mean
|
||||
self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
|
||||
_init_optimizer_communication()
|
||||
self.opt_list = _init_allreduce_operators(len(parameters))
|
||||
self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP)
|
||||
|
||||
def construct(self, grads):
|
||||
"""
|
||||
|
@ -284,11 +272,8 @@ class DistributedGradReducer(Cell):
|
|||
"""
|
||||
datatypes = self.map_(F.partial(_get_datatype), grads)
|
||||
grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
|
||||
|
||||
if self.mean:
|
||||
new_grad = self.map_(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
|
||||
else:
|
||||
new_grad = self.map_(F.partial(reduce_opt), self.allreduce_filter, grads)
|
||||
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
|
||||
self.allreduce_filter, grads, self.opt_list)
|
||||
|
||||
new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
|
||||
return new_grad
|
||||
|
|
|
@ -513,7 +513,7 @@ def _set_auto_parallel_context(**kwargs):
|
|||
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
|
||||
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
|
||||
full_batch (bool): Whether to load the whole batch on each device. Default: False.
|
||||
enable_parallel_optimizer (bool): Enable using optimizer segmentation or noe. Default: False.
|
||||
enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not attribute in auto parallel context.
|
||||
|
|
|
@ -22,7 +22,6 @@ from mindspore.common.api import _executor
|
|||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore import context
|
||||
|
||||
|
||||
|
@ -54,8 +53,7 @@ class Net(nn.Cell):
|
|||
|
||||
def test_AdamWeightDecayDynamicLR():
|
||||
""" test_AdamWeightDecayDynamicLR """
|
||||
auto_parallel_context().set_enable_parallel_optimizer(True)
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
|
||||
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
||||
label = Tensor(np.zeros([32, 768]).astype(np.float32))
|
||||
net = Net()
|
||||
|
@ -70,8 +68,7 @@ def test_AdamWeightDecayDynamicLR():
|
|||
|
||||
def test_AdamWeightDecay():
|
||||
""" test_AdamWeightDecayDynamicLR """
|
||||
auto_parallel_context().set_enable_parallel_optimizer(True)
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
|
||||
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
||||
label = Tensor(np.zeros([32, 768]).astype(np.float32))
|
||||
net = Net()
|
||||
|
@ -86,8 +83,7 @@ def test_AdamWeightDecay():
|
|||
|
||||
def test_lamb_compile():
|
||||
""" test_Lamb_compile """
|
||||
auto_parallel_context().set_enable_parallel_optimizer(True)
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2)
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
|
||||
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
||||
label = Tensor(np.zeros([32, 768]).astype(np.float32))
|
||||
net = Net()
|
||||
|
@ -102,7 +98,7 @@ def test_lamb_compile():
|
|||
|
||||
def test_edge_case():
|
||||
""" test_edge_case """
|
||||
auto_parallel_context().set_enable_parallel_optimizer(True)
|
||||
context.set_auto_parallel_context(enable_parallel_optimizer=True)
|
||||
net = Net()
|
||||
with pytest.raises(RuntimeError):
|
||||
context.set_auto_parallel_context(parallel_mode="stand_alone")
|
||||
|
|
|
@ -81,8 +81,8 @@ def test_set_auto_parallel_context():
|
|||
with pytest.raises(ValueError):
|
||||
set_algo_parameters(tensor_slice_align_size=1025)
|
||||
|
||||
auto_parallel_context().set_enable_parallel_optimizer(True)
|
||||
assert auto_parallel_context().get_enable_parallel_optimizer() is True
|
||||
context.set_auto_parallel_context(enable_parallel_optimizer=True)
|
||||
assert context.get_auto_parallel_context("enable_parallel_optimizer")
|
||||
assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue