!1650 Support forward reduce scatter for Matmul

Merge pull request !1650 from yangzhenzhang/forward-reduce-scatter
mindspore-ci-bot 2020-05-29 14:45:32 +08:00 committed by Gitee
commit 8316f736ea
6 changed files with 134 additions and 39 deletions
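The feature is opt-in through a primitive attribute on MatMul. A minimal user-side sketch of how it is enabled, mirroring the test cases added at the end of this change (the strategy values are those of the tests, not requirements):

from mindspore.ops import operations as P

# split both MatMul inputs; the reduction dimension is split 2-way
matmul = P.MatMul().set_strategy(((2, 2), (2, 2)))
# the attribute introduced by this PR: combine the forward partial results
# with ReduceScatter instead of AllReduce
matmul.add_prim_attr("forward_reduce_scatter", True)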

View File

@ -94,6 +94,17 @@ Status MatMulBase::GetAttrs() {
}
}
auto forward_reduce_scatter_iter = attrs_.find(FORWARD_REDUCE_SCATTER);
if (forward_reduce_scatter_iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(forward_reduce_scatter_iter->second);
if (forward_reduce_scatter_iter->second->isa<BoolImm>()) {
forward_reduce_scatter_ = forward_reduce_scatter_iter->second->cast<BoolImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << " : The value of forward reduce scatter is not bool.";
return FAILED;
}
}
// infer inputs dimension size
if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
@ -174,6 +185,13 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
}
}
if ((mat_a_dimension_ != 2 || mat_b_dimension_ != 2) && forward_reduce_scatter_) {
MS_LOG(WARNING) << name_
<< ": The dimension of mat a and mat b must be 2 in forward reduce scatter mode, "
"setting the forward reduce scatter mode to false here";
forward_reduce_scatter_ = false;
}
return SUCCESS;
}
@ -231,25 +249,16 @@ Status MatMulBase::InferForwardCommunication() {
return SUCCESS;
}
Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
forward_op_.push_back(op);
MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
return SUCCESS;
}
// dev_matrix_shape: [a, b, c, d, e], then output strategy: [a, b, c, e];
Dimensions GetOutputStrategy(const Shape &dev_matrix_shape, int32_t repeated_calculation_num) {
Dimensions output_strategy = dev_matrix_shape;
if (repeated_calculation_num > 1) {
// move the first dimension(repeated_calc_num_)
(void)output_strategy.erase(output_strategy.begin());
}
// delete the second-to-last element
(void)output_strategy.erase(output_strategy.begin() +
static_cast<different_type>(SECOND_FROM_END(output_strategy.size())));
return output_strategy;
}
Operator op;
if (forward_reduce_scatter_) {
op = CreateReduceScatterOp(REDUCE_OP_SUM, group_list[0].name());
} else {
op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
}
forward_op_.push_back(op);
MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
return SUCCESS;
}
Status MatMulBase::InferTensorMap() {
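Why ReduceScatter can stand in for AllReduce here: when the reduction dimension of the MatMul is split across k devices, every device holds a partial product of the full output. AllReduce sums the partials and leaves the complete output on every device, while ReduceScatter sums them and leaves each device with only its row slice of the output. A plain numpy sketch of that arithmetic (illustration only, not MindSpore code):

import numpy as np

k = 2                                    # reduction dimension split across k devices
x = np.arange(8 * 4, dtype=np.float32).reshape(8, 4)
w = np.ones((4, 6), dtype=np.float32)

# each device computes a partial product from its share of the reduction dimension
partials = [x[:, i::k] @ w[i::k, :] for i in range(k)]

full = sum(partials)                     # AllReduce: every device ends up with all rows
rows = np.array_split(full, k, axis=0)   # ReduceScatter: device i ends up with rows[i] only

assert np.allclose(np.concatenate(rows), x @ w)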
@ -295,6 +304,23 @@ Status MatMulBase::InferTensorMap() {
mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(mat_b_tensor_map.size())), last_value);
}
if (forward_reduce_scatter_) {
if (dev_matrix_shape_.size() != 3) {
MS_LOG(WARNING) << name_
<< ": The dimension of dev matrix shape must be 3 in forward reduce scatter mode, "
"setting the forward reduce scatter mode to false here";
forward_reduce_scatter_ = false;
} else if (outputs_shape_[0][0] % (dev_matrix_shape_[0] * dev_matrix_shape_[1]) != 0) {
MS_LOG(WARNING) << name_
<< ": The first dimension of output should be split by dev_matrix[0]*dev_matrix[1] in "
"forward reduce scatter mode, setting the forward reduce scatter mode to false here";
forward_reduce_scatter_ = false;
} else {
// forward reduce scatter only supports a 2-dimensional output
output_tensor_map = {1, 0};
}
}
inputs_tensor_map_.push_back(mat_a_tensor_map);
inputs_tensor_map_.push_back(mat_b_tensor_map);
outputs_tensor_map_.push_back(output_tensor_map);
@ -302,10 +328,21 @@ Status MatMulBase::InferTensorMap() {
}
Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
Shape output_dev_matrix_shape;
if (forward_reduce_scatter_) {
if (dev_matrix_shape_.size() != 3) {
MS_LOG(ERROR) << "The size of origin dev matrix shape must be 3 in forward reduce scatter mode";
return FAILED;
}
output_dev_matrix_shape = {dev_matrix_shape_[0] * dev_matrix_shape_[1], dev_matrix_shape_[2]};
} else {
output_dev_matrix_shape = dev_matrix_shape_;
}
TensorLayout mat_a_layout, mat_b_layout, output_layout;
if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) ||
(mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) ||
(output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) {
(output_layout.InitFromVector(output_dev_matrix_shape, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) {
return FAILED;
}
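A worked example of the layout produced by these two hunks, using the numbers from test_matmul_forward_reduce_scatter further down (the [2, 2, 2] device matrix for strategy ((2, 2), (2, 2)) on 8 devices is my assumption about what gets inferred):

dev_matrix = [2, 2, 2]       # assumed device matrix: [row split, reduction split, col split]
output_shape = [128, 64]

# divisibility check from InferTensorMap: dim 0 must be divisible by dev_matrix[0] * dev_matrix[1]
assert output_shape[0] % (dev_matrix[0] * dev_matrix[1]) == 0

# output device matrix from InferTensorLayout: the first two axes are merged
output_dev_matrix = [dev_matrix[0] * dev_matrix[1], dev_matrix[2]]   # [4, 2]

# with output tensor map {1, 0}, each device keeps a [128 / 4, 64 / 2] = [32, 32] slice
slice_shape = [output_shape[0] // output_dev_matrix[0],
               output_shape[1] // output_dev_matrix[1]]
print(output_dev_matrix, slice_shape)    # [4, 2] [32, 32]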
@ -316,24 +353,6 @@ Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts
}
Status MatMulBase::InferTensorInfo() {
// infer tensor shape
Shape mat_a_shape = inputs_shape_.at(0);
Shape mat_b_shape = inputs_shape_.at(1);
Shape output_shape = outputs_shape_.at(0);
// infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape;
Dimensions output_strategy = GetOutputStrategy(dev_matrix_shape_, repeated_calc_num_);
Strategys inputs_strategy = strategy_->GetInputDim();
Strategys outputs_strategy = {output_strategy};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
return FAILED;
}
Shape mat_a_slice_shape = inputs_slice_shape.at(0);
Shape mat_b_slice_shape = inputs_slice_shape.at(1);
Shape output_slice_shape = outputs_slice_shape.at(0);
// infer tensor layout
TensorLayouts inputs_layout, outputs_layout;
if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
@ -343,9 +362,9 @@ Status MatMulBase::InferTensorInfo() {
TensorLayout mat_a_layout = inputs_layout.at(0);
TensorLayout mat_b_layout = inputs_layout.at(1);
TensorLayout output_layout = outputs_layout.at(0);
TensorInfo mat_a_tensor_info(mat_a_layout, mat_a_shape, mat_a_slice_shape);
TensorInfo mat_b_tensor_info(mat_b_layout, mat_b_shape, mat_b_slice_shape);
TensorInfo output_tensor_info(output_layout, output_shape, output_slice_shape);
TensorInfo mat_a_tensor_info(mat_a_layout);
TensorInfo mat_b_tensor_info(mat_b_layout);
TensorInfo output_tensor_info(output_layout);
inputs_tensor_info_.push_back(mat_a_tensor_info);
inputs_tensor_info_.push_back(mat_b_tensor_info);
@ -359,6 +378,11 @@ Status MatMulBase::Init(const StrategyPtr &strategy) {
return FAILED;
}
if (forward_reduce_scatter_) {
virtual_div_op_.clear();
MS_LOG(INFO) << "The forward reduce scatter mode does not involve repeated calculation, clear the virtual div op";
}
MS_LOG(INFO) << name_ << " : Init success.";
return SUCCESS;
}

View File

@ -61,6 +61,7 @@ class MatMulBase : public OperatorInfo {
bool transpose_a_ = false;
bool transpose_b_ = false;
bool forward_reduce_scatter_ = false;
size_t mat_a_dimension_ = 0;
size_t mat_b_dimension_ = 0;
};

View File

@ -208,6 +208,24 @@ Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &grou
return op;
}
Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group) {
OperatorName operator_name = REDUCE_SCATTER;
ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM
ValuePtr attr1_value = MakeValue(group); // group
Attr attr0 = std::make_pair(OP, attr0_value);
Attr attr1 = std::make_pair(GROUP, attr1_value);
OperatorAttrs operator_attrs;
operator_attrs.push_back(attr0);
operator_attrs.push_back(attr1);
OperatorParams operator_param;
OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
Operator op = std::make_pair(operator_name, operator_arg);
MS_LOG(INFO) << "Create reduce scatter op success, the reduce_op is " << reduce_op << ", the group is " << group;
return op;
}
// use for get tensor slice
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
Shape tensor_map = tensor_layout.tensor_map().array();

View File

@ -263,6 +263,7 @@ Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy);
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool);
Operator CreateVirtualDivOp(int32_t div_num);
Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group);
Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group);
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);

View File

@ -126,6 +126,7 @@ constexpr char MIRROR_OP[] = "mirror_op";
constexpr char FORWARD_OP[] = "forward_op";
constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
constexpr char DARA_PARALLEL[] = "data_parallel";
constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
// Operator
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";

View File

@ -121,3 +121,53 @@ def test_two_matmul_repeated_calculation2():
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_forward_reduce_scatter():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul().set_strategy(strategy1)
            self.matmul.add_prim_attr("forward_reduce_scatter", True)
            self.mul = P.Mul().set_strategy(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.mul(out, b)
            return out

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    context.set_context(save_graphs=True)
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((4, 2), (4, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
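A note on how the Mul strategies relate to the MatMul strategy, as I read these tests: the MatMul strategy splits the [128, 64] output 2-way per dimension, and forward reduce scatter additionally splits dim 0 by the reduction factor, so the element-wise Mul that consumes the output uses the matching layout.

# strategy1 = ((2, 2), (2, 2)): output split (2, 2), reduction dimension split 2-way
out_split = (2 * 2, 2)       # reduce scatter folds the reduction split into dim 0
assert out_split == (4, 2)   # hence strategy2 = ((4, 2), (4, 2)) for the Mul
# the 16-device transpose variant below: (2 * 4, 2) == (8, 2), hence ((8, 2), (8, 2))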
def test_matmul_forward_reduce_scatter_transpose():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul = P.MatMul(transpose_b=True).set_strategy(strategy1)
            self.matmul.add_prim_attr("forward_reduce_scatter", True)
            self.mul = P.Mul().set_strategy(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.mul(out, b)
            return out

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    context.set_context(save_graphs=True)
    strategy1 = ((2, 4), (2, 4))
    strategy2 = ((8, 2), (8, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([64, 32]), dtype=ms.float32)
    b = Tensor(np.ones([128, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)