forked from mindspore-Ecosystem/mindspore

fix optimizer parallel problems

parent 7cb567ebbe
commit 98e2ee90de

@@ -26,13 +26,13 @@ namespace mindspore {
 namespace kernel {
 namespace {
 std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
-  auto parallel_context_instance = parallel::ParallelContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(parallel_context_instance);
-  if (parallel_context_instance->enable_parallel_optimizer()) {
-    return kOpFormat_DEFAULT;
-  }
   const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0};
   auto op_name = AnfAlgo::GetCNodeName(kernel_node);
+  auto parallel_context_instance = parallel::ParallelContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(parallel_context_instance);
+  if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) {
+    return kOpFormat_DEFAULT;
+  }
   auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
   if (op_name != kReduceScatter && op_name != kAllGatherOpName) {
     return format;

@@ -65,6 +65,8 @@ void ParallelContext::Reset() {
   strategy_ckpt_load_file_ = "";
   strategy_ckpt_save_file_ = "";
   enable_parallel_optimizer_ = false;
+  all_reduce_fusion_split_indices_.clear();
+  all_reduce_fusion_split_sizes_.clear();
 }
 
 void ParallelContext::set_device_num(int32_t device_num) {

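ParallelContext::Reset() now also clears the all-reduce fusion split plan, so a reset context no longer leaks the previous fusion configuration. A hedged usage sketch from the Python side (the calls mirror the test added at the end of this commit):

from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context

# Configure data-parallel training with the parallel optimizer and a fusion plan.
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2,
                                  enable_parallel_optimizer=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8])

# ... build and compile the network ...

# Reset() now also clears all_reduce_fusion_split_indices_/sizes_, so the next
# job starts from a clean fusion configuration.
context.reset_auto_parallel_context()
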
@@ -371,5 +371,5 @@ class AdamWeightDecay(Optimizer):
                                         self.parameters, self.moments1, self.moments2,
                                         gradients, self.decay_flags, self.optim_filter)
         if self.use_parallel:
-            optim_result = self.broadcast_params(optim_result)
+            self.broadcast_params(optim_result)
         return optim_result

@@ -312,7 +312,7 @@ class Lamb(Optimizer):
                                   self.decay_flags, self.optim_filter)
 
         if self.use_parallel:
-            optim_result = self.broadcast_params(optim_result)
+            self.broadcast_params(optim_result)
 
         if not self.dynamic_lr:
             F.control_depend(lr, self.assignadd(self.global_step, 1))

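Both AdamWeightDecay and Lamb now call broadcast_params() only for its side effect: the broadcast values are assigned onto self.parameters inside broadcast_params() (see the Optimizer hunks below) and only a status is returned, so assigning that return value back to optim_result was wrong. The new-side pattern in AdamWeightDecay, with explanatory comments (a sketch, indentation assumed; Lamb is analogous):

        if self.use_parallel:
            # Broadcast the rank-owned parameter shards to every device; the
            # assignment happens inside broadcast_params(), and its return value
            # is only a dependency status that must not replace optim_result.
            self.broadcast_params(optim_result)
        # The optimizer still returns the locally computed update result.
        return optim_result
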
@@ -466,7 +466,7 @@ class Optimizer(Cell):
             param_group.append(F.make_tuple())
             key_group.append(F.make_tuple())
         for i in range(self.param_length):
-            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],)
+            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],)
             key = P.MakeRefKey(self.param_names[i])()
             key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
         new_param_group = []

@@ -476,9 +476,9 @@ class Optimizer(Cell):
             new_param_group.append(next_params)
             for i in range(F.tuple_len(next_params)):
                 F.assign(key_group[root][i], next_params[i])
-        status = True
+        status = F.control_depend(optim_result, new_param_group[0][0])
         for i in range(self.dev_num - 1):
-            status = F.control_depend(new_param_group[i][0], new_param_group[i+1])
+            status = F.depend(F.control_depend(new_param_group[i], new_param_group[i+1][0]), status)
 
         return status

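Two fixes land in broadcast_params(): the broadcast groups are built from self.parameters rather than from the optimizer outputs, and the control dependencies are chained into the returned status so the graph cannot prune them. The new-side chaining lines again, with explanatory comments (a sketch; indentation assumed, the broadcast loop between the two hunks is elided):

        # Anchor the first broadcast group behind the optimizer update, so no
        # parameter is broadcast before it has been updated locally.
        status = F.control_depend(optim_result, new_param_group[0][0])
        for i in range(self.dev_num - 1):
            # Order group i before the first element of group i+1, and fold the
            # resulting edge into `status` with F.depend so every edge reaches
            # the value broadcast_params() returns and none is eliminated.
            status = F.depend(F.control_depend(new_param_group[i], new_param_group[i+1][0]), status)
        return status
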
@@ -25,7 +25,7 @@ import mindspore.common.dtype as mstype
 reduce_opt = C.MultitypeFuncGraph("reduce_opt")
 
 
-def _init_allreduce_operators(length):
+def _init_allreduce_operators(length, split_indices):
     """ initialize allreduce communication operators"""
     group = 1
     fusion = ()

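The helper now takes the user-configured split indices instead of deriving a fusion plan on its own. A hedged illustration only (not the MindSpore implementation; the function below is hypothetical) of how such cut points can be turned into per-gradient fusion ids:

def fusion_ids(length, split_indices):
    """Hypothetical helper: map each of `length` gradients to a fusion group id,
    treating `split_indices` (e.g. [2, 4, 6, 8]) as cut points between groups."""
    ids = []
    group = 1
    for i in range(length):
        if group - 1 < len(split_indices) and i >= split_indices[group - 1]:
            group += 1
        ids.append(group)
    return ids

# fusion_ids(10, [2, 4, 6, 8]) -> [1, 1, 2, 2, 3, 3, 4, 4, 5, 5]
# AllReduce operators that share an id can then carry the same 'fusion' attribute.
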
@@ -318,7 +318,7 @@ class DistributedGradReducer(Cell):
         split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
         if is_parallel_optimizer and split_indices:
             self.split_fusion = True
-            self.op_list = _init_allreduce_operators(len(parameters))
+            self.op_list = _init_allreduce_operators(len(parameters), split_indices)
         else:
             self.split_fusion = False
             self.allreduce = AllReduce().add_prim_attr('fusion', 1)

@@ -344,10 +344,10 @@ class DistributedGradReducer(Cell):
         if self.split_fusion:
             if self.enable_parameter_server:
                 new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
-                                     self.opt_list, self.allreduce_filter, grads, self.ps_parameters)
+                                     self.op_list, self.allreduce_filter, grads, self.ps_parameters)
             else:
                 new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
-                                     self.opt_list, self.allreduce_filter, grads)
+                                     self.op_list, self.allreduce_filter, grads)
         else:
             if self.enable_parameter_server:
                 new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather,

@@ -16,8 +16,6 @@
 
 import ctypes
 
-from mindspore import log as logger
-
 _MAX_GROUP_NAME_LEN = 127
 _HCCL_LIB = 'libhccl.so'
 

@@ -25,8 +23,8 @@ _HCCL_LIB = 'libhccl.so'
 def _load_lib():
     try:
         hccl_lib = ctypes.CDLL(_HCCL_LIB)
-    except RuntimeError:
-        logger.error('Get hccl lib error')
+    except Exception:
+        raise RuntimeError('Get hccl lib error')
 
     return hccl_lib
 

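ctypes.CDLL raises OSError when a shared library cannot be loaded, so the old `except RuntimeError` never matched and a missing libhccl.so escaped as a raw OSError instead of the RuntimeError that the strategy setters handle. Catching Exception and re-raising RuntimeError gives callers a single error type. A minimal, self-contained sketch of the new behaviour (the wrapper name is hypothetical):

import ctypes

def load_hccl_sketch(lib_name='libhccl.so'):
    try:
        return ctypes.CDLL(lib_name)
    except Exception:
        # Any loader failure (typically OSError) surfaces as a RuntimeError.
        raise RuntimeError('Get hccl lib error')

try:
    load_hccl_sketch()
except RuntimeError as err:
    print(err)  # callers such as _set_fusion_strategy_by_idx() catch this
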
@@ -69,8 +67,9 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
     try:
         lib_ctype = _load_lib()
     except RuntimeError:
-        logger.error('Load HCCL lib failed')
+        import hccl_test.manage.api as hccl
+        hccl.set_fusion_strategy_by_idx()
+        return
     if isinstance(group, (str)):
         group_len = len(group)
         if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0):

@@ -126,7 +125,9 @@ def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
     try:
         lib_ctype = _load_lib()
     except RuntimeError:
-        logger.error('Load HCCL lib failed')
+        import hccl_test.manage.api as hccl
+        hccl.set_fusion_strategy_by_size()
+        return
     if isinstance(group, (str)):
         group_len = len(group)
         if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:

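When the real HCCL library is unavailable (for example on CPU-only CI machines), both strategy setters now delegate to the simulation stubs added to hccl_test.manage.api below and return early, instead of only logging and then continuing without a loaded library. The new-side fallback, annotated (indentation assumed):

    try:
        lib_ctype = _load_lib()
    except RuntimeError:
        # No usable libhccl.so: route the call to the test stub and stop here,
        # so the ctypes path below is never reached with an unloaded handle.
        import hccl_test.manage.api as hccl
        hccl.set_fusion_strategy_by_size()
        return
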
@@ -86,3 +86,13 @@ def create_group(group, rank_size, rank_ids):
 # pylint: disable=unused-argument
 def destroy_group(group):
     pass
+
+
+# pylint: disable=unused-argument
+def set_fusion_strategy_by_idx():
+    pass
+
+
+# pylint: disable=unused-argument
+def set_fusion_strategy_by_size():
+    pass

@@ -23,7 +23,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb
 from mindspore.ops import operations as P
 from mindspore import context
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
 
 class Net(nn.Cell):
     """Net definition"""

@@ -64,6 +64,7 @@ def test_AdamWeightDecay():
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
+    context.reset_auto_parallel_context()
 
 
 def test_lamb_compile():

@@ -79,8 +80,25 @@ def test_lamb_compile():
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
+    context.reset_auto_parallel_context()
 
 
+def test_lamb_split_fusion():
+    """ test_Lamb_split_fusion """
+    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
+    auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8])
+    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+    label = Tensor(np.zeros([32, 768]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Lamb(net.trainable_params(), learning_rate=0.1)
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
+    context.reset_auto_parallel_context()
+
 def test_edge_case():
     """ test_edge_case """
     context.set_auto_parallel_context(enable_parallel_optimizer=True)

@@ -93,3 +111,4 @@ def test_edge_case():
     with pytest.raises(RuntimeError):
         context.set_auto_parallel_context(device_num=16)
         Lamb(net.trainable_params(), learning_rate=0.1)
+    context.reset_auto_parallel_context()