From 19bd830539752a1b1af0f37fa240e749030e4bdc Mon Sep 17 00:00:00 2001
From: yangzhenzhang <285824651@qq.com>
Date: Fri, 29 May 2020 10:25:15 +0800
Subject: [PATCH] support forward reduce scatter for matmul

---
 .../ccsrc/parallel/ops_info/matmul_info.cc    | 102 +++++++++++-------
 .../ccsrc/parallel/ops_info/matmul_info.h     |   1 +
 .../ccsrc/parallel/ops_info/operator_info.cc  |  18 ++++
 .../ccsrc/parallel/ops_info/operator_info.h   |   1 +
 mindspore/ccsrc/parallel/ops_info/ops_utils.h |   1 +
 tests/ut/python/parallel/test_two_matmul.py   |  50 +++++++++
 6 files changed, 134 insertions(+), 39 deletions(-)

diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc
index 7752148b7d1..7d1ab8dc0fa 100644
--- a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc
@@ -94,6 +94,17 @@ Status MatMulBase::GetAttrs() {
     }
   }
 
+  auto forward_reduce_scatter_iter = attrs_.find(FORWARD_REDUCE_SCATTER);
+  if (forward_reduce_scatter_iter != attrs_.end()) {
+    MS_EXCEPTION_IF_NULL(forward_reduce_scatter_iter->second);
+    if (forward_reduce_scatter_iter->second->isa<BoolImm>()) {
+      forward_reduce_scatter_ = forward_reduce_scatter_iter->second->cast<BoolImmPtr>()->value();
+    } else {
+      MS_LOG(ERROR) << name_ << " : The value of forward reduce scatter is not bool.";
+      return FAILED;
+    }
+  }
+
   // infer inputs dimension size
   if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
     MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
@@ -174,6 +185,13 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
     }
   }
 
+  if ((mat_a_dimension_ != 2 || mat_b_dimension_ != 2) && forward_reduce_scatter_) {
+    MS_LOG(WARNING) << name_
+                    << ": The dimension of mat a and mat b must be 2 in forward reduce scatter mode, "
+                       "setting the forward reduce scatter mode to false here";
+    forward_reduce_scatter_ = false;
+  }
+
   return SUCCESS;
 }
 
@@ -231,25 +249,16 @@ Status MatMulBase::InferForwardCommunication() {
     return SUCCESS;
   }
 
-  Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
-  forward_op_.push_back(op);
-
-  MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
-  return SUCCESS;
-}
-
-// dev_matrix_shape: [a, b, c, d, e], then output strategy: [a, b, c, e];
-Dimensions GetOutputStrategy(const Shape &dev_matrix_shape, int32_t repeated_calculation_num) {
-  Dimensions output_strategy = dev_matrix_shape;
-  if (repeated_calculation_num > 1) {
-    // move the first dimension(repeated_calc_num_)
-    (void)output_strategy.erase(output_strategy.begin());
+  Operator op;
+  if (forward_reduce_scatter_) {
+    op = CreateReduceScatterOp(REDUCE_OP_SUM, group_list[0].name());
+  } else {
+    op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
   }
-  // delete the second-to-last element
-  (void)output_strategy.erase(output_strategy.begin() +
-                              static_cast<different_type>(SECOND_FROM_END(output_strategy.size())));
-  return output_strategy;
+  forward_op_.push_back(op);
+  MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
+  return SUCCESS;
 }
 
 Status MatMulBase::InferTensorMap() {
@@ -295,6 +304,23 @@ Status MatMulBase::InferTensorMap() {
       mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(mat_b_tensor_map.size())), last_value);
   }
 
+  if (forward_reduce_scatter_) {
+    if (dev_matrix_shape_.size() != 3) {
+      MS_LOG(WARNING) << name_
+                      << ": The dimension of dev matrix shape must be 3 in forward reduce scatter mode, "
+                         "setting the forward reduce scatter mode to false here";
+      forward_reduce_scatter_ = false;
+    } else if (outputs_shape_[0][0] % (dev_matrix_shape_[0] * dev_matrix_shape_[1]) != 0) {
+      MS_LOG(WARNING) << name_
+                      << ": The first dimension of output should be split by dev_matrix[0]*dev_matrix[1] in "
+                         "forward reduce scatter mode, setting the forward reduce scatter mode to false here";
+      forward_reduce_scatter_ = false;
+    } else {
+      // the forward reduce scatter only support that the dimension of output is 2
+      output_tensor_map = {1, 0};
+    }
+  }
+
   inputs_tensor_map_.push_back(mat_a_tensor_map);
   inputs_tensor_map_.push_back(mat_b_tensor_map);
   outputs_tensor_map_.push_back(output_tensor_map);
   return SUCCESS;
 }
 
 Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
+  Shape output_dev_matrix_shape;
+  if (forward_reduce_scatter_) {
+    if (dev_matrix_shape_.size() != 3) {
+      MS_LOG(ERROR) << "The size of origin dev matrix shape must be 3 in forward reduce scatter mode";
+      return FAILED;
+    }
+    output_dev_matrix_shape = {dev_matrix_shape_[0] * dev_matrix_shape_[1], dev_matrix_shape_[2]};
+  } else {
+    output_dev_matrix_shape = dev_matrix_shape_;
+  }
+
   TensorLayout mat_a_layout, mat_b_layout, output_layout;
   if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) ||
       (mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) ||
-      (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) {
+      (output_layout.InitFromVector(output_dev_matrix_shape, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) {
     return FAILED;
   }
 
@@ -316,24 +353,6 @@
 }
 
 Status MatMulBase::InferTensorInfo() {
-  // infer tensor shape
-  Shape mat_a_shape = inputs_shape_.at(0);
-  Shape mat_b_shape = inputs_shape_.at(1);
-  Shape output_shape = outputs_shape_.at(0);
-
-  // infer slice shape
-  Shapes inputs_slice_shape, outputs_slice_shape;
-  Dimensions output_strategy = GetOutputStrategy(dev_matrix_shape_, repeated_calc_num_);
-
-  Strategys inputs_strategy = strategy_->GetInputDim();
-  Strategys outputs_strategy = {output_strategy};
-  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
-    return FAILED;
-  }
-  Shape mat_a_slice_shape = inputs_slice_shape.at(0);
-  Shape mat_b_slice_shape = inputs_slice_shape.at(1);
-  Shape output_slice_shape = outputs_slice_shape.at(0);
-
   // infer tensor layout
   TensorLayouts inputs_layout, outputs_layout;
   if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
     return FAILED;
   }
 
   TensorLayout mat_a_layout = inputs_layout.at(0);
   TensorLayout mat_b_layout = inputs_layout.at(1);
   TensorLayout output_layout = outputs_layout.at(0);
-  TensorInfo mat_a_tensor_info(mat_a_layout, mat_a_shape, mat_a_slice_shape);
-  TensorInfo mat_b_tensor_info(mat_b_layout, mat_b_shape, mat_b_slice_shape);
-  TensorInfo output_tensor_info(output_layout, output_shape, output_slice_shape);
+  TensorInfo mat_a_tensor_info(mat_a_layout);
+  TensorInfo mat_b_tensor_info(mat_b_layout);
+  TensorInfo output_tensor_info(output_layout);
 
   inputs_tensor_info_.push_back(mat_a_tensor_info);
   inputs_tensor_info_.push_back(mat_b_tensor_info);
@@ -359,6 +378,11 @@ Status MatMulBase::Init(const StrategyPtr &strategy) {
     return FAILED;
   }
 
+  if (forward_reduce_scatter_) {
+    virtual_div_op_.clear();
+    MS_LOG(INFO) << "The forward reduce scatter mode does not involve repeated calculation, clear the virtual div op";
+  }
+
   MS_LOG(INFO) << name_ << " : Init success.";
   return SUCCESS;
 }
diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/parallel/ops_info/matmul_info.h
index 86a74f78f26..cb3e54a0489 100644
--- a/mindspore/ccsrc/parallel/ops_info/matmul_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.h
@@ -61,6 +61,7 @@ class MatMulBase : public OperatorInfo {
 
   bool transpose_a_ = false;
   bool transpose_b_ = false;
+  bool forward_reduce_scatter_ = false;
   size_t mat_a_dimension_ = 0;
   size_t mat_b_dimension_ = 0;
 };
diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/parallel/ops_info/operator_info.cc
index 6e09f7994f9..f9b294898cb 100644
--- a/mindspore/ccsrc/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/operator_info.cc
@@ -208,6 +208,24 @@ Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) {
   return op;
 }
 
+Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group) {
+  OperatorName operator_name = REDUCE_SCATTER;
+  ValuePtr attr0_value = MakeValue(reduce_op);  // ReduceOP.SUM
+  ValuePtr attr1_value = MakeValue(group);      // group
+  Attr attr0 = std::make_pair(OP, attr0_value);
+  Attr attr1 = std::make_pair(GROUP, attr1_value);
+  OperatorAttrs operator_attrs;
+  operator_attrs.push_back(attr0);
+  operator_attrs.push_back(attr1);
+
+  OperatorParams operator_param;
+  OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
+
+  Operator op = std::make_pair(operator_name, operator_arg);
+  MS_LOG(INFO) << "Create reduce scatter op success, the reduce_op is " << reduce_op << ", the group is " << group;
+  return op;
+}
+
 // use for get tensor slice
 Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
   Shape tensor_map = tensor_layout.tensor_map().array();
diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h
index 6888a88f720..21041c3e94b 100644
--- a/mindspore/ccsrc/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/operator_info.h
@@ -263,6 +263,7 @@ Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy);
 Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool);
 Operator CreateVirtualDivOp(int32_t div_num);
 Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group);
+Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group);
 Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
 OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
 int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);
diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
index 44c504c2421..4da54a358d5 100644
--- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
@@ -126,6 +126,7 @@ constexpr char MIRROR_OP[] = "mirror_op";
 constexpr char FORWARD_OP[] = "forward_op";
 constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
 constexpr char DARA_PARALLEL[] = "data_parallel";
+constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
 
 // Operator
 constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
diff --git a/tests/ut/python/parallel/test_two_matmul.py b/tests/ut/python/parallel/test_two_matmul.py
index daee920a910..6489cc90a8d 100644
--- a/tests/ut/python/parallel/test_two_matmul.py
+++ b/tests/ut/python/parallel/test_two_matmul.py
@@ -121,3 +121,53 @@ def test_two_matmul_repeated_calculation2():
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
     compile_net(net, x, y, b)
+
+
+def test_matmul_forward_reduce_scatter():
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.matmul = P.MatMul().set_strategy(strategy1)
+            self.matmul.add_prim_attr("forward_reduce_scatter", True)
+            self.mul = P.Mul().set_strategy(strategy2)
+
+        def construct(self, x, y, b):
+            out = self.matmul(x, y)
+            out = self.mul(out, b)
+            return out
+
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    context.set_context(save_graphs=True)
+    strategy1 = ((2, 2), (2, 2))
+    strategy2 = ((4, 2), (4, 2))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
+
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    b = Tensor(np.ones([128, 64]), dtype=ms.float32)
+    compile_net(net, x, y, b)
+
+
+def test_matmul_forward_reduce_scatter_transpose():
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.matmul = P.MatMul(transpose_b=True).set_strategy(strategy1)
+            self.matmul.add_prim_attr("forward_reduce_scatter", True)
+            self.mul = P.Mul().set_strategy(strategy2)
+
+        def construct(self, x, y, b):
+            out = self.matmul(x, y)
+            out = self.mul(out, b)
+            return out
+
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    context.set_context(save_graphs=True)
+    strategy1 = ((2, 4), (2, 4))
+    strategy2 = ((8, 2), (8, 2))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
+
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 32]), dtype=ms.float32)
+    b = Tensor(np.ones([128, 64]), dtype=ms.float32)
+    compile_net(net, x, y, b)
\ No newline at end of file
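-- 
Usage sketch (a hypothetical, minimal example mirroring test_matmul_forward_reduce_scatter above; the cell and variable names are illustrative and not part of this patch): the feature is opted into per MatMul primitive through the "forward_reduce_scatter" attribute registered in ops_utils.h, and it requires a 2-D MatMul whose first output dimension is divisible by dev_matrix[0] * dev_matrix[1].

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.ops import operations as P

    class MatMulCell(nn.Cell):
        def __init__(self):
            super().__init__()
            # shard both MatMul inputs 2x2; with device_num=8 the device matrix
            # is [2, 2, 2], so the forward pass emits a ReduceScatter (summing
            # the partial products over the contracted dimension and scattering
            # the output's first dimension) instead of an AllReduce
            self.matmul = P.MatMul().set_strategy(((2, 2), (2, 2)))
            self.matmul.add_prim_attr("forward_reduce_scatter", True)

        def construct(self, x, y):
            return self.matmul(x, y)

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    # 128 % (dev_matrix[0] * dev_matrix[1]) == 128 % 4 == 0, as InferTensorMap requires
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    net = MatMulCell()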