forked from mindspore-Ecosystem/mindspore

enable optimizer parallel

parent 7976d77593
commit 39f08eb7dd
@@ -100,7 +100,10 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
   auto parallel_context = parallel::ParallelContext::GetInstance();
   MS_EXCEPTION_IF_NULL(parallel_context);
-  const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group);
+  std::vector<uint32_t> split_indices;
+  if (!parallel_context->enable_parallel_optimizer()) {
+    split_indices = parallel_context->GetAllReduceFusionSplitIndices(group);
+  }
 
   size_t segments = 0;
   if (split_indices.size() != 0) {
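Note: after this hunk, GetSplitSegments only honours the user-configured all-reduce fusion split indices when the parallel optimizer is disabled; with the optimizer enabled, split_indices stays empty and the fusion pass falls back to its default segmentation. A minimal Python sketch of that guard (illustrative names, not the C++ API):

```python
# Illustrative sketch of the guard above; names are hypothetical, not the C++ API.
def get_split_indices(parallel_context, group):
    split_indices = []
    if not parallel_context.enable_parallel_optimizer():
        # Pre-commit behaviour: honour the configured fusion split for this group.
        split_indices = parallel_context.get_all_reduce_fusion_split_indices(group)
    # Empty indices mean the caller derives segments from its default policy.
    return split_indices
```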
@@ -443,7 +443,7 @@ def _context():
 @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
                  auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
-                 strategy_ckpt_save_file=str, full_batch=bool)
+                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
 def set_auto_parallel_context(**kwargs):
     """
     Set auto parallel context.
@@ -487,6 +487,9 @@ def set_auto_parallel_context(**kwargs):
         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
         full_batch (bool): Whether to load the whole batch on each device. Default: False.
+        enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update
+            computation in data parallel training to save time and memory.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
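The new keyword is switched on through the same context API as the other auto-parallel options; a minimal usage sketch, assuming a two-device data-parallel setup and one of the optimizers the feature currently supports (AdamWeightDecay here):

```python
# Minimal usage sketch; device_num=2 and the AdamWeightDecay choice are assumptions.
from mindspore import context, nn

context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2,
                                  enable_parallel_optimizer=True)

net = Net()  # Net is a user-defined Cell, as in the tests later in this commit
optimizer = nn.AdamWeightDecay(net.trainable_params())

# reset_auto_parallel_context() restores enable_parallel_optimizer to False.
```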
@@ -532,6 +535,7 @@ def reset_auto_parallel_context():
     - parameter_broadcast: False.
     - strategy_ckpt_load_file: "".
     - strategy_ckpt_save_file: "".
+    - enable_parallel_optimizer: False.
     """
    _reset_auto_parallel_context()
@@ -28,8 +28,8 @@ from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore import log as logger
 from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.train.parallel_utils import ParallelMode
+from mindspore import context
 
 __all__ = ['Optimizer']
@@ -157,13 +157,12 @@ class Optimizer(Cell):
         self.param_length = len(self.parameters)
         self.map_ = C.Map()
 
-        use_parallel = auto_parallel_context().get_enable_parallel_optimizer()
+        use_parallel = context.get_auto_parallel_context("enable_parallel_optimizer")
         self.use_parallel = use_parallel
         if use_parallel:
             if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]:
                 raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
-            if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL,
-                                            ParallelMode.AUTO_PARALLEL]:
+            if _get_parallel_mode() != ParallelMode.DATA_PARALLEL:
                 raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
                                    (_get_parallel_mode()))
             self.dev_num = _get_device_num()
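With the tightened guard, the flag only takes effect for Lamb, AdamWeightDecayDynamicLR and AdamWeightDecay in data-parallel mode; any other optimizer raises at construction time. A hedged sketch of that failure mode (Momentum is only an example of an unsupported optimizer):

```python
# Sketch of the expected failure mode; Momentum is just an example of an unsupported optimizer.
from mindspore import context, nn

context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2,
                                  enable_parallel_optimizer=True)
net = Net()  # any user-defined Cell
try:
    opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
except RuntimeError as err:
    print(err)  # expected to read "Optimizer segmentation does not support optimizer Momentum"
```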
@@ -175,6 +174,7 @@ class Optimizer(Cell):
             self.param_names = []
             for param in self.parameters:
                 self.param_names.append(param.name)
 
         else:
             self.optim_filter = (True,) * self.param_length
@@ -13,107 +13,95 @@
 # limitations under the License.
 # ============================================================================
 """grad reducer cell for distributed training"""
+from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp, AllGather
+from mindspore.ops.operations.comm_ops import AllReduce, AllGather
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
 import mindspore.common.dtype as mstype
 
 reduce_opt = C.MultitypeFuncGraph("reduce_opt")
 
-_all_reduce = AllReduce()
-_all_gather = None
+def _init_allreduce_operators(length):
+    """ initialize allreduce communication operators"""
+    is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
+    split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
+    if is_parallel_optimizer and split_indices:
+        group = 1
+        fusion = ()
+        for i in range(length):
+            fusion = fusion + (group,)
+            if split_indices[group - 1] <= i + 1:
+                if group >= len(split_indices):
+                    continue
+                group = group + 1
+        index = tuple(range(1, length + 1))
+    else:
+        fusion = (1,) * length
+        index = (0,) * length
+    opt_list = ()
+    for i in range(length):
+        opt = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
+        opt.add_prim_attr('fusion', fusion[i])
+        opt.add_prim_attr('index', index[i])
+        opt_list = opt_list + (opt,)
+    return opt_list
 
 
-def _init_optimizer_communication():
-    global _all_reduce
-    global _all_gather
-
-    _all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
-    _all_reduce.add_prim_attr('fusion', 1)
-    _all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP)
-
-
-@reduce_opt.register("Function", "Number", "Bool", "Tensor")
-def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad):
-    """
-    Apply mean and allreduce on gradient. Allreduce is a communication operation used for distributed deep learning.
-
-    Args:
-        mul (Primitive): Div operation.
-        degree (int): The mean coefficient.
-        allreduce_filter (bool): When it is true, allreduce would apply.
-        grad (Tensor): The gradient tensor before operation.
-
-    Returns:
-        Tensor, the gradient tensor after operation.
-    """
-    if allreduce_filter:
-        degree = F.scalar_cast(degree, F.dtype(grad))
-        grad = _all_reduce(grad)
-        cast_op = P.Cast()
-        return mul(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad)))
-    return grad
-
-
-@reduce_opt.register("Function", "Number", "Bool", "Tuple")
-def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad):
-    """
-    Apply mean and allgather on gradient instead of allreduce for sparse feature.
-    Allgather is a communication operation used for distributed deep learning.
-
-    Args:
-        mul (Primitive): Div operation.
-        degree (int): The mean coefficient.
-        allreduce_filter (bool): When it is true, allgather would apply.
-        grad (Tuple): The indices, gradient tensor and tensor_shape before operation.
-
-    Returns:
-        Tuple, include indices, the gradient tensor and tensor_shape after operation.
-    """
-    if allreduce_filter:
-        indices = _all_gather(grad[0])
-        degree = F.scalar_cast(degree, F.dtype(grad[1]))
-        dout = _all_gather(grad[1])
-        cast_op = P.Cast()
-        dout = mul(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
-        grad = (indices, dout, grad[2])
-    return grad
-
-
-@reduce_opt.register("Bool", "Tensor")
-def _tensors_allreduce(allreduce_filter, grad):
+@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function")
+def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce):
     """
     Apply allreduce on gradient.
 
     Args:
+        degree (int): The mean coefficient.
+        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
+        allgather (Primitive): The communication operator for sparse gradients.
         allreduce_filter (bool): When it is true, allreduce would apply.
         grad (Tensor): The gradient tensor before operation.
+        allreduce (Primitive): The communication operator for gradients.
 
     Returns:
         Tensor, the gradient tensor after operation.
     """
     if allreduce_filter:
-        return _all_reduce(grad)
+        grad = allreduce(grad)
+        if mean:
+            degree = F.scalar_cast(degree, F.dtype(grad))
+            cast_op = P.Cast()
+            mul_op = P.Mul()
+            grad = mul_op(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad)))
+        return grad
     return grad
 
 
-@reduce_opt.register("Bool", "Tuple")
-def _tensors_allreduce_with_sparse(allreduce_filter, grad):
+@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tuple", "Function")
+def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, grad, allreduce):
     """
-    Apply mean and allgather on gradient instead of allreduce for sparse feature.
+    Apply allgather on gradient instead of allreduce for sparse feature.
     Allgather is a communication operation used for distributed deep learning.
 
     Args:
+        degree (int): The mean coefficient.
+        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
+        allgather (Primitive): The communication operator for sparse gradients.
         allreduce_filter (bool): When it is true, allgather would apply.
-        grad (Tuple): The indices, gradient tensor and tensor_shape before operation.
+        grad (tuple): The indices, gradient tensor and tensor_shape before operation.
+        allreduce (Primitive): The communication operator for gradients.
 
     Returns:
         Tuple, include indices, the gradient tensor and tensor_shape after operation.
     """
     if allreduce_filter:
-        indices = _all_gather(grad[0])
-        dout = _all_gather(grad[1])
+        indices = allgather(grad[0])
+        dout = allgather(grad[1])
+        if mean:
+            degree = F.scalar_cast(degree, F.dtype(grad[1]))
+            cast_op = P.Cast()
+            mul_op = P.Mul()
+            dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
         grad = (indices, dout, grad[2])
     return grad
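The new _init_allreduce_operators builds one AllReduce per parameter and tags each with a 'fusion' bucket id derived from the global all-reduce fusion split indices, plus an 'index' attribute used when the parallel optimizer is on. The bucketing logic can be traced on its own; the sketch below copies it out so it runs without MindSpore, with made-up split indices:

```python
# Standalone trace of the fusion/index bucketing above; runnable without MindSpore.
# The split_indices values in the calls below are illustrative.
def bucket(length, split_indices):
    if split_indices:
        group = 1
        fusion = ()
        for i in range(length):
            fusion = fusion + (group,)
            if split_indices[group - 1] <= i + 1:
                if group >= len(split_indices):
                    continue
                group = group + 1
        index = tuple(range(1, length + 1))
    else:
        # No split configured (or parallel optimizer off): a single fusion bucket, index 0.
        fusion = (1,) * length
        index = (0,) * length
    return fusion, index

print(bucket(5, [2, 4]))  # ((1, 1, 2, 2, 2), (1, 2, 3, 4, 5))
print(bucket(5, []))      # ((1, 1, 1, 1, 1), (0, 0, 0, 0, 0))
```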
@@ -259,7 +247,6 @@ class DistributedGradReducer(Cell):
     def __init__(self, parameters, mean=True, degree=None):
         super(DistributedGradReducer, self).__init__(auto_prefix=False)
         self.map_ = C.Map()
-        self.mul = P.Mul()
         if degree is None:
             self.degree = get_group_size()
         else:
@@ -268,7 +255,8 @@ class DistributedGradReducer(Cell):
             self.degree = degree
         self.mean = mean
         self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
-        _init_optimizer_communication()
+        self.opt_list = _init_allreduce_operators(len(parameters))
+        self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP)
 
     def construct(self, grads):
         """
@@ -284,11 +272,8 @@ class DistributedGradReducer(Cell):
         """
         datatypes = self.map_(F.partial(_get_datatype), grads)
         grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
-        if self.mean:
-            new_grad = self.map_(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
-        else:
-            new_grad = self.map_(F.partial(reduce_opt), self.allreduce_filter, grads)
-
+        new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
+                             self.allreduce_filter, grads, self.opt_list)
         new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
         return new_grad
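construct now always takes a single path: degree, the mean flag and the shared AllGather are bound with F.partial, and each gradient is mapped together with its own pre-built AllReduce from opt_list, so the mean division happens inside the reduce_opt overloads rather than in construct. A toy, MindSpore-free analogue of the dense path:

```python
# Toy analogue of the dense reduce_opt path: allreduce sums per-device gradients,
# then the mean flag divides by degree. Pure Python, illustrative only.
def fake_allreduce(per_device_grads):
    return [sum(vals) for vals in zip(*per_device_grads)]

def reduce_dense(degree, mean, grad_per_device):
    grad = fake_allreduce(grad_per_device)
    if mean:
        grad = [g / degree for g in grad]
    return grad

# Two devices, degree=2: gradients are summed, then averaged.
print(reduce_dense(2, True, [[1.0, 2.0], [3.0, 4.0]]))  # [2.0, 3.0]
```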
@@ -513,7 +513,7 @@ def _set_auto_parallel_context(**kwargs):
         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
         full_batch (bool): Whether to load the whole batch on each device. Default: False.
-        enable_parallel_optimizer (bool): Enable using optimizer segmentation or noe. Default: False.
+        enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
@@ -22,7 +22,6 @@ from mindspore.common.api import _executor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb
 from mindspore.ops import operations as P
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore import context
 
 
@@ -54,8 +53,7 @@ class Net(nn.Cell):
 
 def test_AdamWeightDecayDynamicLR():
     """ test_AdamWeightDecayDynamicLR """
-    auto_parallel_context().set_enable_parallel_optimizer(True)
-    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
+    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
     inputs = Tensor(np.ones([32, 128]).astype(np.float32))
     label = Tensor(np.zeros([32, 768]).astype(np.float32))
     net = Net()
@@ -70,8 +68,7 @@ def test_AdamWeightDecayDynamicLR():
 
 def test_AdamWeightDecay():
     """ test_AdamWeightDecayDynamicLR """
-    auto_parallel_context().set_enable_parallel_optimizer(True)
-    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
+    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
     label = Tensor(np.zeros([32, 768]).astype(np.float32))
     net = Net()
@@ -86,8 +83,7 @@ def test_AdamWeightDecay():
 
 def test_lamb_compile():
     """ test_Lamb_compile """
-    auto_parallel_context().set_enable_parallel_optimizer(True)
-    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2)
+    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
     inputs = Tensor(np.ones([32, 128]).astype(np.float32))
     label = Tensor(np.zeros([32, 768]).astype(np.float32))
     net = Net()
@@ -102,7 +98,7 @@ def test_lamb_compile():
 
 def test_edge_case():
     """ test_edge_case """
-    auto_parallel_context().set_enable_parallel_optimizer(True)
+    context.set_auto_parallel_context(enable_parallel_optimizer=True)
     net = Net()
     with pytest.raises(RuntimeError):
         context.set_auto_parallel_context(parallel_mode="stand_alone")
@@ -81,8 +81,8 @@ def test_set_auto_parallel_context():
     with pytest.raises(ValueError):
         set_algo_parameters(tensor_slice_align_size=1025)
 
-    auto_parallel_context().set_enable_parallel_optimizer(True)
-    assert auto_parallel_context().get_enable_parallel_optimizer() is True
+    context.set_auto_parallel_context(enable_parallel_optimizer=True)
+    assert context.get_auto_parallel_context("enable_parallel_optimizer")
     assert not auto_parallel_context().get_all_reduce_fusion_split_indices()