fix optimizer parallel problems

This commit is contained in:
Ziyan 2020-07-24 11:20:50 +08:00
parent 7cb567ebbe
commit 98e2ee90de
9 changed files with 54 additions and 22 deletions

View File

@ -26,13 +26,13 @@ namespace mindspore {
namespace kernel { namespace kernel {
namespace { namespace {
std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) { std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
auto parallel_context_instance = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context_instance);
if (parallel_context_instance->enable_parallel_optimizer()) {
return kOpFormat_DEFAULT;
}
const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0}; const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0};
auto op_name = AnfAlgo::GetCNodeName(kernel_node); auto op_name = AnfAlgo::GetCNodeName(kernel_node);
auto parallel_context_instance = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context_instance);
if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) {
return kOpFormat_DEFAULT;
}
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index); auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
if (op_name != kReduceScatter && op_name != kAllGatherOpName) { if (op_name != kReduceScatter && op_name != kAllGatherOpName) {
return format; return format;

View File

@ -65,6 +65,8 @@ void ParallelContext::Reset() {
strategy_ckpt_load_file_ = ""; strategy_ckpt_load_file_ = "";
strategy_ckpt_save_file_ = ""; strategy_ckpt_save_file_ = "";
enable_parallel_optimizer_ = false; enable_parallel_optimizer_ = false;
all_reduce_fusion_split_indices_.clear();
all_reduce_fusion_split_sizes_.clear();
} }
void ParallelContext::set_device_num(int32_t device_num) { void ParallelContext::set_device_num(int32_t device_num) {

View File

@ -371,5 +371,5 @@ class AdamWeightDecay(Optimizer):
self.parameters, self.moments1, self.moments2, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter) gradients, self.decay_flags, self.optim_filter)
if self.use_parallel: if self.use_parallel:
optim_result = self.broadcast_params(optim_result) self.broadcast_params(optim_result)
return optim_result return optim_result

View File

@ -312,7 +312,7 @@ class Lamb(Optimizer):
self.decay_flags, self.optim_filter) self.decay_flags, self.optim_filter)
if self.use_parallel: if self.use_parallel:
optim_result = self.broadcast_params(optim_result) self.broadcast_params(optim_result)
if not self.dynamic_lr: if not self.dynamic_lr:
F.control_depend(lr, self.assignadd(self.global_step, 1)) F.control_depend(lr, self.assignadd(self.global_step, 1))

View File

@ -466,7 +466,7 @@ class Optimizer(Cell):
param_group.append(F.make_tuple()) param_group.append(F.make_tuple())
key_group.append(F.make_tuple()) key_group.append(F.make_tuple())
for i in range(self.param_length): for i in range(self.param_length):
param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],) param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],)
key = P.MakeRefKey(self.param_names[i])() key = P.MakeRefKey(self.param_names[i])()
key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,) key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
new_param_group = [] new_param_group = []
@ -476,9 +476,9 @@ class Optimizer(Cell):
new_param_group.append(next_params) new_param_group.append(next_params)
for i in range(F.tuple_len(next_params)): for i in range(F.tuple_len(next_params)):
F.assign(key_group[root][i], next_params[i]) F.assign(key_group[root][i], next_params[i])
status = True status = F.control_depend(optim_result, new_param_group[0][0])
for i in range(self.dev_num - 1): for i in range(self.dev_num - 1):
status = F.control_depend(new_param_group[i][0], new_param_group[i+1]) status = F.depend(F.control_depend(new_param_group[i], new_param_group[i+1][0]), status)
return status return status

View File

@ -25,7 +25,7 @@ import mindspore.common.dtype as mstype
reduce_opt = C.MultitypeFuncGraph("reduce_opt") reduce_opt = C.MultitypeFuncGraph("reduce_opt")
def _init_allreduce_operators(length): def _init_allreduce_operators(length, split_indices):
""" initialize allreduce communication operators""" """ initialize allreduce communication operators"""
group = 1 group = 1
fusion = () fusion = ()
@ -318,7 +318,7 @@ class DistributedGradReducer(Cell):
split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices() split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
if is_parallel_optimizer and split_indices: if is_parallel_optimizer and split_indices:
self.split_fusion = True self.split_fusion = True
self.op_list = _init_allreduce_operators(len(parameters)) self.op_list = _init_allreduce_operators(len(parameters), split_indices)
else: else:
self.split_fusion = False self.split_fusion = False
self.allreduce = AllReduce().add_prim_attr('fusion', 1) self.allreduce = AllReduce().add_prim_attr('fusion', 1)
@ -344,10 +344,10 @@ class DistributedGradReducer(Cell):
if self.split_fusion: if self.split_fusion:
if self.enable_parameter_server: if self.enable_parameter_server:
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
self.opt_list, self.allreduce_filter, grads, self.ps_parameters) self.op_list, self.allreduce_filter, grads, self.ps_parameters)
else: else:
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather), new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
self.opt_list, self.allreduce_filter, grads) self.op_list, self.allreduce_filter, grads)
else: else:
if self.enable_parameter_server: if self.enable_parameter_server:
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather, new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather,

View File

@ -16,8 +16,6 @@
import ctypes import ctypes
from mindspore import log as logger
_MAX_GROUP_NAME_LEN = 127 _MAX_GROUP_NAME_LEN = 127
_HCCL_LIB = 'libhccl.so' _HCCL_LIB = 'libhccl.so'
@ -25,8 +23,8 @@ _HCCL_LIB = 'libhccl.so'
def _load_lib(): def _load_lib():
try: try:
hccl_lib = ctypes.CDLL(_HCCL_LIB) hccl_lib = ctypes.CDLL(_HCCL_LIB)
except RuntimeError: except Exception:
logger.error('Get hccl lib error') raise RuntimeError('Get hccl lib error')
return hccl_lib return hccl_lib
@ -69,8 +67,9 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
try: try:
lib_ctype = _load_lib() lib_ctype = _load_lib()
except RuntimeError: except RuntimeError:
logger.error('Load HCCL lib failed') import hccl_test.manage.api as hccl
hccl.set_fusion_strategy_by_idx()
return
if isinstance(group, (str)): if isinstance(group, (str)):
group_len = len(group) group_len = len(group)
if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0): if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0):
@ -126,7 +125,9 @@ def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
try: try:
lib_ctype = _load_lib() lib_ctype = _load_lib()
except RuntimeError: except RuntimeError:
logger.error('Load HCCL lib failed') import hccl_test.manage.api as hccl
hccl.set_fusion_strategy_by_size()
return
if isinstance(group, (str)): if isinstance(group, (str)):
group_len = len(group) group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN or group_len == 0: if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:

View File

@ -86,3 +86,13 @@ def create_group(group, rank_size, rank_ids):
# pylint: disable=unused-argument # pylint: disable=unused-argument
def destroy_group(group): def destroy_group(group):
pass pass
# pylint: disable=unused-argument
def set_fusion_strategy_by_idx(idxList=None, group="hccl_world_group"):
    """No-op stub for setting the fusion strategy by gradient index.

    Presumably the fallback used when the real HCCL library cannot be
    loaded -- confirm against the caller. Both parameters are accepted and
    ignored; they are optional so that a zero-argument call keeps working.

    Args:
        idxList: Ignored. Split indices a caller may forward.
        group (str): Ignored. Name of the communication group.
    """
# pylint: disable=unused-argument
def set_fusion_strategy_by_size(dataSizeList=None, group="hccl_world_group"):
    """No-op stub for setting the fusion strategy by data size.

    Presumably the fallback used when the real HCCL library cannot be
    loaded -- confirm against the caller. Both parameters are accepted and
    ignored; they are optional so that a zero-argument call keeps working.

    Args:
        dataSizeList: Ignored. Fusion segment sizes a caller may forward.
        group (str): Ignored. Name of the communication group.
    """

View File

@ -23,7 +23,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore import context from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
class Net(nn.Cell): class Net(nn.Cell):
"""Net definition""" """Net definition"""
@ -64,6 +64,7 @@ def test_AdamWeightDecay():
net_with_loss = WithLossCell(net, loss) net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer) train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label) _executor.compile(train_network, inputs, label)
context.reset_auto_parallel_context()
def test_lamb_compile(): def test_lamb_compile():
@ -79,8 +80,25 @@ def test_lamb_compile():
net_with_loss = WithLossCell(net, loss) net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer) train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label) _executor.compile(train_network, inputs, label)
context.reset_auto_parallel_context()
def test_lamb_split_fusion():
    """Compile a Lamb train step with the parallel optimizer and split all-reduce fusion.

    Enables data-parallel mode on 2 devices with the parallel optimizer,
    sets the all-reduce fusion split indices to [2, 4, 6, 8], and checks that
    the wrapped training network compiles.
    """
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
    try:
        auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8])
        inputs = Tensor(np.ones([32, 128]).astype(np.float32))
        label = Tensor(np.zeros([32, 768]).astype(np.float32))
        net = Net()
        net.set_train()
        loss = nn.SoftmaxCrossEntropyWithLogits()
        optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
        net_with_loss = WithLossCell(net, loss)
        train_network = TrainOneStepCell(net_with_loss, optimizer)
        _executor.compile(train_network, inputs, label)
    finally:
        # Reset even when compilation fails, so the settings made above
        # cannot leak into tests that run afterwards.
        context.reset_auto_parallel_context()
def test_edge_case(): def test_edge_case():
""" test_edge_case """ """ test_edge_case """
context.set_auto_parallel_context(enable_parallel_optimizer=True) context.set_auto_parallel_context(enable_parallel_optimizer=True)
@ -93,3 +111,4 @@ def test_edge_case():
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
context.set_auto_parallel_context(device_num=16) context.set_auto_parallel_context(device_num=16)
Lamb(net.trainable_params(), learning_rate=0.1) Lamb(net.trainable_params(), learning_rate=0.1)
context.reset_auto_parallel_context()