forked from mindspore-Ecosystem/mindspore
!24541 fix the case where the parameters are inseparable and correct the targets of allreduce for adasum
Merge pull request !24541 from jinjiali-kali/adasum_1008
commit 92ae56d043
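Note on the change (not part of the patch): when a parameter can still be split, the original AdaSum mixing rule is kept, and the three scalar allreduces are now serialized with F.depend; when a parameter is inseparable, the patch falls back to a plain allreduce averaged over the 2**(step + 1) participating ranks. Below is a minimal standalone NumPy sketch of those two paths; the helper names (adasum_combine, plain_mean) are illustrative, and a single process stands in for the real Send/Receive/AllReduce collectives.

# Standalone NumPy sketch (not MindSpore code) of the two combine paths in this patch.
import numpy as np

def adasum_combine(local_part, recv_part, eps=1e-12):
    """Mirror of the divisible branch in _send_recv_res: mix the local half
    with the received half using the AdaSum weighting. In the real graph the
    three scalars are summed across the pair group via allreduce first."""
    value_0 = np.sum(local_part * recv_part) + eps   # dot product of the two halves
    value_1 = np.sum(local_part * local_part) + eps  # squared norm of the local half
    value_2 = np.sum(recv_part * recv_part) + eps    # squared norm of the received half
    return (1 - value_0 / (2 * value_1)) * local_part + \
           (1 - value_0 / (2 * value_2)) * recv_part

def plain_mean(local_grad_sum, allreduce_node_num):
    """Mirror of the new inseparable-parameter branch: the result of
    allreduce(local_part) divided by the number of contributing nodes."""
    return local_grad_sum / allreduce_node_num

grad_a = np.array([0.2, -0.1, 0.4])
grad_b = np.array([0.1, -0.3, 0.5])
print(adasum_combine(grad_a, grad_b))
print(plain_mean(grad_a + grad_b, allreduce_node_num=2))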
@@ -22,6 +22,7 @@ from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
 from mindspore.ops.operations._inner_ops import Send, Receive
+from mindspore.common.tensor import Tensor


 __all__ = ["AdaSum"]
@@ -52,39 +53,64 @@ def _receive_before_send(send_part, send, recv):
     return F.depend(receive_ok, send(send_part))


-def _send_recv_res(recv_part, local_part, allreduce):
+def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
     """send result and receive result."""
-    recv_part = P.Squeeze()(recv_part)
-    local_part = F.depend(local_part, recv_part)
-    eps = 1e-12
-    value_0 = P.ReduceSum()(local_part * recv_part) + eps
-    value_1 = P.ReduceSum()(local_part * local_part) + eps
-    value_2 = P.ReduceSum()(recv_part * recv_part) + eps
-    value_0 = allreduce(value_0)
-    value_1 = allreduce(value_1)
-    value_2 = allreduce(value_2)
-    res = (1 - (value_0 / (2 * value_1))) * local_part + (1 - (value_0 / (2 * value_2))) * recv_part
+    if parameter_divisibility:
+        recv_part = P.Squeeze()(recv_part)
+        if F.shape(recv_part) is None:
+            recv_part = Tensor([recv_part])
+        local_part = F.depend(local_part, recv_part)
+        eps = 1e-12
+        value_0 = P.ReduceSum()(local_part * recv_part) + eps
+        if left_send:
+            value_1 = P.ReduceSum()(local_part * local_part) + eps
+            value_2 = P.ReduceSum()(recv_part * recv_part) + eps
+        else:
+            value_1 = P.ReduceSum()(recv_part * recv_part) + eps
+            value_2 = P.ReduceSum()(local_part * local_part) + eps
+        value_0 = allreduce(value_0)
+        value_1 = F.depend(allreduce(value_1), value_0)
+        value_2 = F.depend(allreduce(value_2), value_1)
+        if left_send:
+            res = (1 - (value_0 / (2 * value_1))) * local_part + (1 - (value_0 / (2 * value_2))) * recv_part
+        else:
+            res = (1 - (value_0 / (2 * value_1))) * recv_part + (1 - (value_0 / (2 * value_2))) * local_part
+    else:
+        res = allreduce(local_part)
+        res /= allreduce_node_num
     return res


 _adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")


-@_adasum_opt_forward.register("Bool", "Function", "Function", "Function", "Tensor")
-def _adasum_opt_forward_process(left_send, allreduce, send, recv, delta_w):
+@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
+def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
     """adasum optimizer process."""
-    delta_w = P.Squeeze()(delta_w)
-    ori_len = F.shape(delta_w)[0]
-    divide_len = ori_len / 2
-    left_part = delta_w[:divide_len]
-    right_part = delta_w[divide_len:]
+    if parameter_divisibility:
+        delta_w = P.Squeeze()(delta_w)
+        ori_len = F.shape(delta_w)[0]
+        divide_len = ori_len / 2
+        left_part = delta_w[:divide_len]
+        right_part = delta_w[divide_len:]
+    else:
+        left_part = delta_w
+        right_part = delta_w

     if left_send:
-        recv_part = _send_before_receive(left_part, send, recv)
-        update_delta_w = _send_recv_res(recv_part, right_part, allreduce)
+        if parameter_divisibility:
+            recv_part = _send_before_receive(left_part, send, recv)
+        else:
+            recv_part = right_part
+        update_delta_w = _send_recv_res(left_send, recv_part, right_part, allreduce, parameter_divisibility,
+                                        allreduce_node_num)
     else:
-        recv_part = _receive_before_send(right_part, send, recv)
-        update_delta_w = _send_recv_res(recv_part, left_part, allreduce)
+        if parameter_divisibility:
+            recv_part = _receive_before_send(right_part, send, recv)
+        else:
+            recv_part = left_part
+        update_delta_w = _send_recv_res(left_send, recv_part, left_part, allreduce, parameter_divisibility,
+                                        allreduce_node_num)

     return update_delta_w
@@ -92,18 +118,29 @@ def _adasum_opt_forward_process(left_send, allreduce, send, recv, delta_w):
 _adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")


-@_adasum_opt_rollback.register("Bool", "Tensor", "Function", "Function")
-def _adasum_opt_rollback_process(left_send, delta_w, send, recv):
+@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
+def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
     """adasum optimizer rollback process."""
-    if left_send:
-        recv_part = _send_before_receive(delta_w, send, recv)
+    if parameter_divisibility:
+        if left_send:
+            recv_part = _send_before_receive(delta_w, send, recv)
+        else:
+            recv_part = _receive_before_send(delta_w, send, recv)
+
+        recv_part = P.Squeeze()(recv_part)
+        if F.shape(recv_part) is None:
+            recv_part = Tensor([recv_part])
+        if F.shape(delta_w) is None:
+            delta_w = Tensor([delta_w])
+        recv_part = P.Reshape()(recv_part, (-1,))
+        delta_w = P.Reshape()(delta_w, (-1,))
+
+        if left_send:
+            res = P.Concat()((recv_part, delta_w))
+        else:
+            res = P.Concat()((delta_w, recv_part))
     else:
-        recv_part = _receive_before_send(delta_w, send, recv)
-    recv_part = P.Squeeze()(recv_part)
-    if left_send:
-        res = P.Concat()((recv_part, delta_w))
-    else:
-        res = P.Concat()((delta_w, recv_part))
+        res = delta_w
     return res

@@ -144,6 +181,8 @@ class AdaSum(Cell):
         self.recv_list_rollback = []
         self.allreduce_list = []
         self.broadcast_list = []
+        self.parameter_divisibility_list = []
+        self.allreduce_node_num_list = []
         last_delta_weights = []
         group_start_rank = (self.rank // self.device_number) * self.device_number

@@ -152,21 +191,31 @@
             sr_target = self.rank
             if (sr_target // current_group) % 2 == 0:
                 dest_target = sr_target + current_group
-                group_name = "adasum_" + str(step) + "_" + str(sr_target)
-                create_group(group_name, [sr_target, dest_target])
                 self.send_node.append(True)
             else:
                 dest_target = sr_target - current_group
-                group_name = "adasum_" + str(step) + "_" + str(dest_target)
-                create_group(group_name, [dest_target, sr_target])
                 self.send_node.append(False)

+            neighbor_ids = []
+            group_name_last = 0
+            for index in range(2 ** (step + 1)):
+                node_rank = self.rank // self.device_number
+                double_d = 2 ** (step + 1)
+                neighbor_id = (node_rank // double_d * double_d + index) * self.device_number + \
+                              self.rank % self.device_number
+                neighbor_ids.append(neighbor_id)
+                group_name_last += neighbor_id
+            group_name = "adasum_" + str(step) + "_" + str(group_name_last)
+            create_group(group_name, neighbor_ids)
+
             send_left = []
             send_right = []
             recv_left = []
             recv_right = []
-            left_delta_weights, right_delta_weights = \
+            allreduce_node_num = ()
+            left_delta_weights, right_delta_weights, delta_weights_divisibility = \
                 self._get_delta_weights_info(last_delta_weights)
+            self.parameter_divisibility_list.append(delta_weights_divisibility)
             weights_index = 0
             fusion_id = (step + 1) * 3
             for shape, dtype in left_delta_weights:

@@ -209,6 +258,13 @@
                 server_all_reduce.add_prim_attr("fusion", fusion_id + 2)
                 self.allreduce_list.append(server_all_reduce)

+            for param_divisibility in delta_weights_divisibility:
+                if param_divisibility:
+                    allreduce_node_num += (0,)
+                else:
+                    allreduce_node_num += (2 ** (step + 1),)
+            self.allreduce_node_num_list.append(allreduce_node_num)
+
         broadcast_group = [x for x in range(group_start_rank, group_start_rank + self.device_number)]
         broadcast_group_name = "broadcast_group_" + str(group_start_rank)
         create_group(broadcast_group_name, broadcast_group)

@@ -227,17 +283,21 @@
             half_delta_weights.append((new_shape, parameter.dtype))
         left_delta_weights = []
         right_delta_weights = []
+        delta_weights_divisibility = ()
         for shape, dtype in half_delta_weights:
             left_shape = copy.deepcopy(shape)
             right_shape = copy.deepcopy(shape)
+            divisibility_flag = False
             for i in range(len(shape)):
                 if shape[i] > 1:
                     left_shape[i] = int(shape[i] // 2)
                     right_shape[i] = shape[i] - int(shape[i] // 2)
+                    divisibility_flag = True
                     break
             left_delta_weights.append((left_shape, dtype))
             right_delta_weights.append((right_shape, dtype))
-        return left_delta_weights, right_delta_weights
+            delta_weights_divisibility += (divisibility_flag,)
+        return left_delta_weights, right_delta_weights, delta_weights_divisibility

     def _hash(self, step, target, weights_index):
         target = "tag" + str(step) + str(target) + str(weights_index)

@@ -248,15 +308,16 @@
     def construct(self, delta_weights, parameters, old_parameters):
         forward_weights = [delta_weights]
         for i in range(self.calc_times):
-            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]),\
+            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]),
+                                             self.parameter_divisibility_list[i], self.allreduce_node_num_list[i],
                                              self.send_list_forward[i], self.recv_list_forward[i], forward_weights[-1])
             forward_weights.append(process_weights)
         for i in range(self.calc_times):
             j = self.calc_times - i - 1
-            process_weights = self.hyper_map(F.partial(_adasum_opt_rollback, self.send_node[i]), forward_weights[j + 1],
-                                             self.send_list_rollback[i], self.recv_list_rollback[i])
+            process_weights = self.hyper_map(F.partial(_adasum_opt_rollback, self.send_node[j]),
+                                             self.parameter_divisibility_list[j], forward_weights[j + 1],
+                                             self.send_list_rollback[j], self.recv_list_rollback[j])
             forward_weights[j] = process_weights
-        adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],\
+        adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],
                                            parameters, old_parameters)
         return adasum_parameters

@@ -180,7 +180,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
             self.end = [(x + 1) * parameter_rank_number for x in range(_device_number)]
             self.end[-1] = len(self.weights)

-            current_weights = self.weights[self.start[self.server_rank] : self.end[self.server_rank]]
+            current_weights = self.weights[self.start[self.server_rank]: self.end[self.server_rank]]
             self.grad_clone = ParameterTuple(current_weights).clone(prefix="delta_weight")
             self.adasum = AdaSum(_rank, _device_number, group_number, self.grad_clone)

@@ -191,7 +191,6 @@ class BoostTrainOneStepCell(TrainOneStepCell):
         create_group(server_group_name, group_list[current_index])
         self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=server_group_name)

-
     def construct(self, *inputs):
         if self.freeze:
             loss = self.gradient_freeze_process(*inputs)

@@ -245,14 +244,14 @@
     def adasum_process(self, loss, grads):
         """adasum algorithm process."""
         loss = F.depend(loss, self.optimizer(grads))
-        rank_weights = self.weights[self.start[self.server_rank] : self.end[self.server_rank]]
+        rank_weights = self.weights[self.start[self.server_rank]: self.end[self.server_rank]]
         grad_clone = F.depend(self.grad_clone, loss)
        delta_w = self.hyper_map(F.partial(_get_delta_weight), rank_weights, grad_clone)
         adasum_res = self.adasum(delta_w, rank_weights, grad_clone)
         sync_tensor = F.depend(self.sync_tensor, adasum_res)
         sync_flag = self.adasum.sync_barrier(sync_tensor)
         for i in range(self.device_number):
-            weight_tuple = self.weights[self.start[i] : self.end[i]]
+            weight_tuple = self.weights[self.start[i]: self.end[i]]
             node_rank = F.depend(weight_tuple, sync_flag)
             update_weights = self.adasum.broadcast_list[i](node_rank)
             if i == self.server_rank:
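The BoostTrainOneStepCell hunks above are whitespace and style cleanups (slice spacing, a stray blank line); the per-rank weight partitioning itself is unchanged. For readers of this diff, here is a small standalone sketch of that partitioning pattern; the helper name partition_bounds is illustrative, and it assumes parameter_rank_number is the per-rank weight count, with the last rank absorbing any remainder as in self.end[-1] = len(self.weights).

# Illustrative sketch (not MindSpore code) of the start/end slicing used above.
def partition_bounds(num_weights, device_number):
    per_rank = num_weights // device_number          # assumed meaning of parameter_rank_number
    start = [x * per_rank for x in range(device_number)]
    end = [(x + 1) * per_rank for x in range(device_number)]
    end[-1] = num_weights                            # last slice takes leftover weights
    return start, end

start, end = partition_bounds(num_weights=10, device_number=4)
print([list(range(10))[start[i]: end[i]] for i in range(4)])
# [[0, 1], [2, 3], [4, 5], [6, 7, 8, 9]]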