From c7b918f537a97b8ef58b57086d06b8c6a3b94d45 Mon Sep 17 00:00:00 2001
From: jinjiali
Date: Mon, 11 Oct 2021 12:56:55 +0800
Subject: [PATCH] correct adasum

---
 mindspore/boost/adasum.py             | 147 ++++++++++++++++++--------
 mindspore/boost/boost_cell_wrapper.py |   7 +-
 2 files changed, 107 insertions(+), 47 deletions(-)

diff --git a/mindspore/boost/adasum.py b/mindspore/boost/adasum.py
index d42bbce0573..5abc0ddcad9 100644
--- a/mindspore/boost/adasum.py
+++ b/mindspore/boost/adasum.py
@@ -22,6 +22,7 @@ from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
 from mindspore.ops.operations._inner_ops import Send, Receive
+from mindspore.common.tensor import Tensor
 
 __all__ = ["AdaSum"]
 
@@ -52,39 +53,64 @@ def _receive_before_send(send_part, send, recv):
     return F.depend(receive_ok, send(send_part))
 
 
-def _send_recv_res(recv_part, local_part, allreduce):
+def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
     """send result and receive result."""
-    recv_part = P.Squeeze()(recv_part)
-    local_part = F.depend(local_part, recv_part)
-    eps = 1e-12
-    value_0 = P.ReduceSum()(local_part * recv_part) + eps
-    value_1 = P.ReduceSum()(local_part * local_part) + eps
-    value_2 = P.ReduceSum()(recv_part * recv_part) + eps
-    value_0 = allreduce(value_0)
-    value_1 = allreduce(value_1)
-    value_2 = allreduce(value_2)
-    res = (1 - (value_0 / (2 * value_1))) * local_part + (1 - (value_0 / (2 * value_2))) * recv_part
+    if parameter_divisibility:
+        recv_part = P.Squeeze()(recv_part)
+        if F.shape(recv_part) is None:
+            recv_part = Tensor([recv_part])
+        local_part = F.depend(local_part, recv_part)
+        eps = 1e-12
+        value_0 = P.ReduceSum()(local_part * recv_part) + eps
+        if left_send:
+            value_1 = P.ReduceSum()(local_part * local_part) + eps
+            value_2 = P.ReduceSum()(recv_part * recv_part) + eps
+        else:
+            value_1 = P.ReduceSum()(recv_part * recv_part) + eps
+            value_2 = P.ReduceSum()(local_part * local_part) + eps
+        value_0 = allreduce(value_0)
+        value_1 = F.depend(allreduce(value_1), value_0)
+        value_2 = F.depend(allreduce(value_2), value_1)
+        if left_send:
+            res = (1 - (value_0 / (2 * value_1))) * local_part + (1 - (value_0 / (2 * value_2))) * recv_part
+        else:
+            res = (1 - (value_0 / (2 * value_1))) * recv_part + (1 - (value_0 / (2 * value_2))) * local_part
+    else:
+        res = allreduce(local_part)
+        res /= allreduce_node_num
     return res
 
 
 _adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")
 
 
-@_adasum_opt_forward.register("Bool", "Function", "Function", "Function", "Tensor")
-def _adasum_opt_forward_process(left_send, allreduce, send, recv, delta_w):
+@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
+def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
     """adasum optimizer process."""
-    delta_w = P.Squeeze()(delta_w)
-    ori_len = F.shape(delta_w)[0]
-    divide_len = ori_len / 2
-    left_part = delta_w[:divide_len]
-    right_part = delta_w[divide_len:]
+    if parameter_divisibility:
+        delta_w = P.Squeeze()(delta_w)
+        ori_len = F.shape(delta_w)[0]
+        divide_len = ori_len / 2
+        left_part = delta_w[:divide_len]
+        right_part = delta_w[divide_len:]
+    else:
+        left_part = delta_w
+        right_part = delta_w
 
     if left_send:
-        recv_part = _send_before_receive(left_part, send, recv)
-        update_delta_w = _send_recv_res(recv_part, right_part, allreduce)
+        if parameter_divisibility:
+            recv_part = _send_before_receive(left_part, send, recv)
+        else:
+            recv_part = right_part
+        update_delta_w = _send_recv_res(left_send, recv_part, right_part, allreduce, parameter_divisibility,
+                                        allreduce_node_num)
     else:
-        recv_part = _receive_before_send(right_part, send, recv)
-        update_delta_w = _send_recv_res(recv_part, left_part, allreduce)
+        if parameter_divisibility:
+            recv_part = _receive_before_send(right_part, send, recv)
+        else:
+            recv_part = left_part
+        update_delta_w = _send_recv_res(left_send, recv_part, left_part, allreduce, parameter_divisibility,
+                                        allreduce_node_num)
 
     return update_delta_w
 
@@ -92,18 +118,29 @@ def _adasum_opt_forward_process(left_send, allreduce, send, recv, delta_w):
 _adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")
 
 
-@_adasum_opt_rollback.register("Bool", "Tensor", "Function", "Function")
-def _adasum_opt_rollback_process(left_send, delta_w, send, recv):
+@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
+def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
     """adasum optimizer rollback process."""
-    if left_send:
-        recv_part = _send_before_receive(delta_w, send, recv)
+    if parameter_divisibility:
+        if left_send:
+            recv_part = _send_before_receive(delta_w, send, recv)
+        else:
+            recv_part = _receive_before_send(delta_w, send, recv)
+
+        recv_part = P.Squeeze()(recv_part)
+        if F.shape(recv_part) is None:
+            recv_part = Tensor([recv_part])
+        if F.shape(delta_w) is None:
+            delta_w = Tensor([delta_w])
+        recv_part = P.Reshape()(recv_part, (-1,))
+        delta_w = P.Reshape()(delta_w, (-1,))
+
+        if left_send:
+            res = P.Concat()((recv_part, delta_w))
+        else:
+            res = P.Concat()((delta_w, recv_part))
     else:
-        recv_part = _receive_before_send(delta_w, send, recv)
-    recv_part = P.Squeeze()(recv_part)
-    if left_send:
-        res = P.Concat()((recv_part, delta_w))
-    else:
-        res = P.Concat()((delta_w, recv_part))
+        res = delta_w
     return res
 
 
@@ -144,6 +181,8 @@ class AdaSum(Cell):
         self.recv_list_rollback = []
         self.allreduce_list = []
         self.broadcast_list = []
+        self.parameter_divisibility_list = []
+        self.allreduce_node_num_list = []
         last_delta_weights = []
         group_start_rank = (self.rank // self.device_number) * self.device_number
 
@@ -152,21 +191,31 @@ class AdaSum(Cell):
             sr_target = self.rank
             if (sr_target // current_group) % 2 == 0:
                 dest_target = sr_target + current_group
-                group_name = "adasum_" + str(step) + "_" + str(sr_target)
-                create_group(group_name, [sr_target, dest_target])
                 self.send_node.append(True)
             else:
                 dest_target = sr_target - current_group
-                group_name = "adasum_" + str(step) + "_" + str(dest_target)
-                create_group(group_name, [dest_target, sr_target])
                 self.send_node.append(False)
 
+            neighbor_ids = []
+            group_name_last = 0
+            for index in range(2 ** (step + 1)):
+                node_rank = self.rank // self.device_number
+                double_d = 2 ** (step + 1)
+                neighbor_id = (node_rank // double_d * double_d + index) * self.device_number + \
+                              self.rank % self.device_number
+                neighbor_ids.append(neighbor_id)
+                group_name_last += neighbor_id
+            group_name = "adasum_" + str(step) + "_" + str(group_name_last)
+            create_group(group_name, neighbor_ids)
+
             send_left = []
             send_right = []
             recv_left = []
             recv_right = []
-            left_delta_weights, right_delta_weights = \
+            allreduce_node_num = ()
+            left_delta_weights, right_delta_weights, delta_weights_divisibility = \
                 self._get_delta_weights_info(last_delta_weights)
+            self.parameter_divisibility_list.append(delta_weights_divisibility)
             weights_index = 0
             fusion_id = (step + 1) * 3
             for shape, dtype in left_delta_weights:
@@ -209,6 +258,13 @@ class AdaSum(Cell):
             server_all_reduce.add_prim_attr("fusion", fusion_id + 2)
             self.allreduce_list.append(server_all_reduce)
 
+            for param_divisibility in delta_weights_divisibility:
+                if param_divisibility:
+                    allreduce_node_num += (0,)
+                else:
+                    allreduce_node_num += (2 ** (step + 1),)
+            self.allreduce_node_num_list.append(allreduce_node_num)
+
         broadcast_group = [x for x in range(group_start_rank, group_start_rank + self.device_number)]
         broadcast_group_name = "broadcast_group_" + str(group_start_rank)
         create_group(broadcast_group_name, broadcast_group)
@@ -227,17 +283,21 @@ class AdaSum(Cell):
             half_delta_weights.append((new_shape, parameter.dtype))
         left_delta_weights = []
         right_delta_weights = []
+        delta_weights_divisibility = ()
         for shape, dtype in half_delta_weights:
             left_shape = copy.deepcopy(shape)
             right_shape = copy.deepcopy(shape)
+            divisibility_flag = False
             for i in range(len(shape)):
                 if shape[i] > 1:
                     left_shape[i] = int(shape[i] // 2)
                     right_shape[i] = shape[i] - int(shape[i] // 2)
+                    divisibility_flag = True
                     break
             left_delta_weights.append((left_shape, dtype))
             right_delta_weights.append((right_shape, dtype))
-        return left_delta_weights, right_delta_weights
+            delta_weights_divisibility += (divisibility_flag,)
+        return left_delta_weights, right_delta_weights, delta_weights_divisibility
 
     def _hash(self, step, target, weights_index):
         target = "tag" + str(step) + str(target) + str(weights_index)
@@ -248,15 +308,16 @@ class AdaSum(Cell):
     def construct(self, delta_weights, parameters, old_parameters):
         forward_weights = [delta_weights]
         for i in range(self.calc_times):
-            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]),\
+            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]),
+                self.parameter_divisibility_list[i], self.allreduce_node_num_list[i],
                 self.send_list_forward[i], self.recv_list_forward[i], forward_weights[-1])
             forward_weights.append(process_weights)
         for i in range(self.calc_times):
             j = self.calc_times - i - 1
-            process_weights = self.hyper_map(F.partial(_adasum_opt_rollback, self.send_node[i]), forward_weights[j + 1],
-                self.send_list_rollback[i], self.recv_list_rollback[i])
+            process_weights = self.hyper_map(F.partial(_adasum_opt_rollback, self.send_node[j]),
+                self.parameter_divisibility_list[j], forward_weights[j + 1],
+                self.send_list_rollback[j], self.recv_list_rollback[j])
             forward_weights[j] = process_weights
-        adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],\
+        adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],
             parameters, old_parameters)
         return adasum_parameters
-
\ No newline at end of file
diff --git a/mindspore/boost/boost_cell_wrapper.py b/mindspore/boost/boost_cell_wrapper.py
index 1a3ad4bf798..1089ad38ed9 100644
--- a/mindspore/boost/boost_cell_wrapper.py
+++ b/mindspore/boost/boost_cell_wrapper.py
@@ -180,7 +180,7 @@ class BoostTrainOneStepCell(TrainOneStepCell):
             self.end = [(x + 1) * parameter_rank_number for x in range(_device_number)]
             self.end[-1] = len(self.weights)
 
-            current_weights = self.weights[self.start[self.server_rank] : self.end[self.server_rank]]
+            current_weights = self.weights[self.start[self.server_rank]: self.end[self.server_rank]]
             self.grad_clone = ParameterTuple(current_weights).clone(prefix="delta_weight")
             self.adasum = AdaSum(_rank, _device_number, group_number, self.grad_clone)
 
@@ -191,7 +191,6 @@ class BoostTrainOneStepCell(TrainOneStepCell):
             create_group(server_group_name, group_list[current_index])
             self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=server_group_name)
 
-
     def construct(self, *inputs):
         if self.freeze:
             loss = self.gradient_freeze_process(*inputs)
@@ -245,14 +244,14 @@ class BoostTrainOneStepCell(TrainOneStepCell):
     def adasum_process(self, loss, grads):
         """adasum algorithm process."""
         loss = F.depend(loss, self.optimizer(grads))
-        rank_weights = self.weights[self.start[self.server_rank] : self.end[self.server_rank]]
+        rank_weights = self.weights[self.start[self.server_rank]: self.end[self.server_rank]]
         grad_clone = F.depend(self.grad_clone, loss)
         delta_w = self.hyper_map(F.partial(_get_delta_weight), rank_weights, grad_clone)
         adasum_res = self.adasum(delta_w, rank_weights, grad_clone)
         sync_tensor = F.depend(self.sync_tensor, adasum_res)
         sync_flag = self.adasum.sync_barrier(sync_tensor)
         for i in range(self.device_number):
-            weight_tuple = self.weights[self.start[i] : self.end[i]]
+            weight_tuple = self.weights[self.start[i]: self.end[i]]
             node_rank = F.depend(weight_tuple, sync_flag)
             update_weights = self.adasum.broadcast_list[i](node_rank)
             if i == self.server_rank:
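
Reviewer note (not part of the patch): the mixing rule that the reworked _send_recv_res applies can be checked outside the MindSpore graph. The sketch below is a minimal NumPy illustration under stated assumptions: adasum_combine and adasum_fallback are hypothetical helper names, the AllReduce of the three scalars is simulated by combining a single pair of halves, and allreduce_node_num is taken to be the number of contributing ranks.

# Illustrative sketch only -- NumPy stand-in for the graph ops used by the patch.
import numpy as np

def adasum_combine(local_part, recv_part, eps=1e-12):
    """Hypothetical helper mirroring the divisible branch of _send_recv_res."""
    value_0 = np.sum(local_part * recv_part) + eps   # <local, recv> dot product
    value_1 = np.sum(local_part * local_part) + eps  # ||local||^2
    value_2 = np.sum(recv_part * recv_part) + eps    # ||recv||^2
    # In the patch these three scalars go through AllReduce before the mix;
    # with a single pair of halves that step is a no-op here.
    return (1 - value_0 / (2 * value_1)) * local_part + \
           (1 - value_0 / (2 * value_2)) * recv_part

def adasum_fallback(parts):
    """Hypothetical helper mirroring the non-divisible branch: AllReduce (sum)
    of the whole tensor, then division by allreduce_node_num."""
    return np.sum(parts, axis=0) / len(parts)

local_half = np.array([1.0, 2.0, 3.0])
peer_half = np.array([0.5, 1.5, 2.5])
print(adasum_combine(local_half, peer_half))     # orthogonality-aware mix of the two halves
print(adasum_fallback([local_half, peer_half]))  # plain average for non-splittable parameters

The divisibility flag exists because parameters whose every dimension is 1 cannot be split into a left and right half; for those the patch falls back to an ordinary AllReduce divided by 2 ** (step + 1) nodes instead of the dot-product-weighted combination.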