fix optimizer parallel problems
This commit is contained in:
parent
7cb567ebbe
commit
98e2ee90de
|
@ -26,13 +26,13 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
namespace {
|
||||
std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
|
||||
auto parallel_context_instance = parallel::ParallelContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(parallel_context_instance);
|
||||
if (parallel_context_instance->enable_parallel_optimizer()) {
|
||||
return kOpFormat_DEFAULT;
|
||||
}
|
||||
const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0};
|
||||
auto op_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
auto parallel_context_instance = parallel::ParallelContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(parallel_context_instance);
|
||||
if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) {
|
||||
return kOpFormat_DEFAULT;
|
||||
}
|
||||
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
|
||||
if (op_name != kReduceScatter && op_name != kAllGatherOpName) {
|
||||
return format;
|
||||
|
|
|
@ -65,6 +65,8 @@ void ParallelContext::Reset() {
|
|||
strategy_ckpt_load_file_ = "";
|
||||
strategy_ckpt_save_file_ = "";
|
||||
enable_parallel_optimizer_ = false;
|
||||
all_reduce_fusion_split_indices_.clear();
|
||||
all_reduce_fusion_split_sizes_.clear();
|
||||
}
|
||||
|
||||
void ParallelContext::set_device_num(int32_t device_num) {
|
||||
|
|
|
@ -371,5 +371,5 @@ class AdamWeightDecay(Optimizer):
|
|||
self.parameters, self.moments1, self.moments2,
|
||||
gradients, self.decay_flags, self.optim_filter)
|
||||
if self.use_parallel:
|
||||
optim_result = self.broadcast_params(optim_result)
|
||||
self.broadcast_params(optim_result)
|
||||
return optim_result
|
||||
|
|
|
@ -312,7 +312,7 @@ class Lamb(Optimizer):
|
|||
self.decay_flags, self.optim_filter)
|
||||
|
||||
if self.use_parallel:
|
||||
optim_result = self.broadcast_params(optim_result)
|
||||
self.broadcast_params(optim_result)
|
||||
|
||||
if not self.dynamic_lr:
|
||||
F.control_depend(lr, self.assignadd(self.global_step, 1))
|
||||
|
|
|
@ -466,7 +466,7 @@ class Optimizer(Cell):
|
|||
param_group.append(F.make_tuple())
|
||||
key_group.append(F.make_tuple())
|
||||
for i in range(self.param_length):
|
||||
param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],)
|
||||
param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],)
|
||||
key = P.MakeRefKey(self.param_names[i])()
|
||||
key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
|
||||
new_param_group = []
|
||||
|
@ -476,9 +476,9 @@ class Optimizer(Cell):
|
|||
new_param_group.append(next_params)
|
||||
for i in range(F.tuple_len(next_params)):
|
||||
F.assign(key_group[root][i], next_params[i])
|
||||
status = True
|
||||
status = F.control_depend(optim_result, new_param_group[0][0])
|
||||
for i in range(self.dev_num - 1):
|
||||
status = F.control_depend(new_param_group[i][0], new_param_group[i+1])
|
||||
status = F.depend(F.control_depend(new_param_group[i], new_param_group[i+1][0]), status)
|
||||
|
||||
return status
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ import mindspore.common.dtype as mstype
|
|||
reduce_opt = C.MultitypeFuncGraph("reduce_opt")
|
||||
|
||||
|
||||
def _init_allreduce_operators(length):
|
||||
def _init_allreduce_operators(length, split_indices):
|
||||
""" initialize allreduce communication operators"""
|
||||
group = 1
|
||||
fusion = ()
|
||||
|
@ -318,7 +318,7 @@ class DistributedGradReducer(Cell):
|
|||
split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
|
||||
if is_parallel_optimizer and split_indices:
|
||||
self.split_fusion = True
|
||||
self.op_list = _init_allreduce_operators(len(parameters))
|
||||
self.op_list = _init_allreduce_operators(len(parameters), split_indices)
|
||||
else:
|
||||
self.split_fusion = False
|
||||
self.allreduce = AllReduce().add_prim_attr('fusion', 1)
|
||||
|
@ -344,10 +344,10 @@ class DistributedGradReducer(Cell):
|
|||
if self.split_fusion:
|
||||
if self.enable_parameter_server:
|
||||
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
|
||||
self.opt_list, self.allreduce_filter, grads, self.ps_parameters)
|
||||
self.op_list, self.allreduce_filter, grads, self.ps_parameters)
|
||||
else:
|
||||
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
|
||||
self.opt_list, self.allreduce_filter, grads)
|
||||
self.op_list, self.allreduce_filter, grads)
|
||||
else:
|
||||
if self.enable_parameter_server:
|
||||
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather,
|
||||
|
|
|
@ -16,8 +16,6 @@
|
|||
|
||||
import ctypes
|
||||
|
||||
from mindspore import log as logger
|
||||
|
||||
_MAX_GROUP_NAME_LEN = 127
|
||||
_HCCL_LIB = 'libhccl.so'
|
||||
|
||||
|
@ -25,8 +23,8 @@ _HCCL_LIB = 'libhccl.so'
|
|||
def _load_lib():
|
||||
try:
|
||||
hccl_lib = ctypes.CDLL(_HCCL_LIB)
|
||||
except RuntimeError:
|
||||
logger.error('Get hccl lib error')
|
||||
except Exception:
|
||||
raise RuntimeError('Get hccl lib error')
|
||||
|
||||
return hccl_lib
|
||||
|
||||
|
@ -69,8 +67,9 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
|
|||
try:
|
||||
lib_ctype = _load_lib()
|
||||
except RuntimeError:
|
||||
logger.error('Load HCCL lib failed')
|
||||
|
||||
import hccl_test.manage.api as hccl
|
||||
hccl.set_fusion_strategy_by_idx()
|
||||
return
|
||||
if isinstance(group, (str)):
|
||||
group_len = len(group)
|
||||
if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0):
|
||||
|
@ -126,7 +125,9 @@ def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
|
|||
try:
|
||||
lib_ctype = _load_lib()
|
||||
except RuntimeError:
|
||||
logger.error('Load HCCL lib failed')
|
||||
import hccl_test.manage.api as hccl
|
||||
hccl.set_fusion_strategy_by_size()
|
||||
return
|
||||
if isinstance(group, (str)):
|
||||
group_len = len(group)
|
||||
if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
|
||||
|
|
|
@ -86,3 +86,13 @@ def create_group(group, rank_size, rank_ids):
|
|||
# pylint: disable=unused-argument
|
||||
def destroy_group(group):
|
||||
pass
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def set_fusion_strategy_by_idx():
|
||||
pass
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def set_fusion_strategy_by_size():
|
||||
pass
|
||||
|
|
|
@ -23,7 +23,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
|
|||
from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import context
|
||||
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""Net definition"""
|
||||
|
@ -64,6 +64,7 @@ def test_AdamWeightDecay():
|
|||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
_executor.compile(train_network, inputs, label)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_lamb_compile():
|
||||
|
@ -79,8 +80,25 @@ def test_lamb_compile():
|
|||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
_executor.compile(train_network, inputs, label)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_lamb_split_fusion():
|
||||
""" test_Lamb_split_fusion """
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8])
|
||||
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
||||
label = Tensor(np.zeros([32, 768]).astype(np.float32))
|
||||
net = Net()
|
||||
net.set_train()
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
_executor.compile(train_network, inputs, label)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
def test_edge_case():
|
||||
""" test_edge_case """
|
||||
context.set_auto_parallel_context(enable_parallel_optimizer=True)
|
||||
|
@ -93,3 +111,4 @@ def test_edge_case():
|
|||
with pytest.raises(RuntimeError):
|
||||
context.set_auto_parallel_context(device_num=16)
|
||||
Lamb(net.trainable_params(), learning_rate=0.1)
|
||||
context.reset_auto_parallel_context()
|
||||
|
|
Loading…
Reference in New Issue