!1650 Support forward reduce scatter for Matmul
Merge pull request !1650 from yangzhenzhang/forward-reduce-scatter
This commit is contained in:
commit
8316f736ea
|
@ -94,6 +94,17 @@ Status MatMulBase::GetAttrs() {
|
|||
}
|
||||
}
|
||||
|
||||
auto forward_reduce_scatter_iter = attrs_.find(FORWARD_REDUCE_SCATTER);
|
||||
if (forward_reduce_scatter_iter != attrs_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(forward_reduce_scatter_iter->second);
|
||||
if (forward_reduce_scatter_iter->second->isa<BoolImm>()) {
|
||||
forward_reduce_scatter_ = forward_reduce_scatter_iter->second->cast<BoolImmPtr>()->value();
|
||||
} else {
|
||||
MS_LOG(ERROR) << name_ << " : The value of forward reduce scatter is not bool.";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
// infer inputs dimension size
|
||||
if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
|
||||
MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
|
||||
|
@ -174,6 +185,13 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
|
|||
}
|
||||
}
|
||||
|
||||
if ((mat_a_dimension_ != 2 || mat_b_dimension_ != 2) && forward_reduce_scatter_) {
|
||||
MS_LOG(WARNING) << name_
|
||||
<< ": The dimension of mat a and mat b must be 2 in forward reduce scatter mode, "
|
||||
"setting the forward reduce scatter mode to false here";
|
||||
forward_reduce_scatter_ = false;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -231,25 +249,16 @@ Status MatMulBase::InferForwardCommunication() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
|
||||
forward_op_.push_back(op);
|
||||
|
||||
MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
// dev_matrix_shape: [a, b, c, d, e], then output strategy: [a, b, c, e];
|
||||
Dimensions GetOutputStrategy(const Shape &dev_matrix_shape, int32_t repeated_calculation_num) {
|
||||
Dimensions output_strategy = dev_matrix_shape;
|
||||
if (repeated_calculation_num > 1) {
|
||||
// move the first dimension(repeated_calc_num_)
|
||||
(void)output_strategy.erase(output_strategy.begin());
|
||||
Operator op;
|
||||
if (forward_reduce_scatter_) {
|
||||
op = CreateReduceScatterOp(REDUCE_OP_SUM, group_list[0].name());
|
||||
} else {
|
||||
op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
|
||||
}
|
||||
|
||||
// delete the second-to-last element
|
||||
(void)output_strategy.erase(output_strategy.begin() +
|
||||
static_cast<different_type>(SECOND_FROM_END(output_strategy.size())));
|
||||
return output_strategy;
|
||||
forward_op_.push_back(op);
|
||||
MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MatMulBase::InferTensorMap() {
|
||||
|
@ -295,6 +304,23 @@ Status MatMulBase::InferTensorMap() {
|
|||
mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(mat_b_tensor_map.size())), last_value);
|
||||
}
|
||||
|
||||
if (forward_reduce_scatter_) {
|
||||
if (dev_matrix_shape_.size() != 3) {
|
||||
MS_LOG(WARNING) << name_
|
||||
<< ": The dimension of dev matrix shape must be 3 in forward reduce scatter mode, "
|
||||
"setting the forward reduce scatter mode to false here";
|
||||
forward_reduce_scatter_ = false;
|
||||
} else if (outputs_shape_[0][0] % (dev_matrix_shape_[0] * dev_matrix_shape_[1]) != 0) {
|
||||
MS_LOG(WARNING) << name_
|
||||
<< ": The first dimension of output should be split by dev_matrix[0]*dev_matrix[1] in "
|
||||
"forward reduce scatter mode, setting the forward reduce scatter mode to false here";
|
||||
forward_reduce_scatter_ = false;
|
||||
} else {
|
||||
// the forward reduce scatter only support that the dimension of output is 2
|
||||
output_tensor_map = {1, 0};
|
||||
}
|
||||
}
|
||||
|
||||
inputs_tensor_map_.push_back(mat_a_tensor_map);
|
||||
inputs_tensor_map_.push_back(mat_b_tensor_map);
|
||||
outputs_tensor_map_.push_back(output_tensor_map);
|
||||
|
@ -302,10 +328,21 @@ Status MatMulBase::InferTensorMap() {
|
|||
}
|
||||
|
||||
Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
|
||||
Shape output_dev_matrix_shape;
|
||||
if (forward_reduce_scatter_) {
|
||||
if (dev_matrix_shape_.size() != 3) {
|
||||
MS_LOG(ERROR) << "The size of origin dev matrix shape must be 3 in forward reduce scatter mode";
|
||||
return FAILED;
|
||||
}
|
||||
output_dev_matrix_shape = {dev_matrix_shape_[0] * dev_matrix_shape_[1], dev_matrix_shape_[2]};
|
||||
} else {
|
||||
output_dev_matrix_shape = dev_matrix_shape_;
|
||||
}
|
||||
|
||||
TensorLayout mat_a_layout, mat_b_layout, output_layout;
|
||||
if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) ||
|
||||
(mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) ||
|
||||
(output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) {
|
||||
(output_layout.InitFromVector(output_dev_matrix_shape, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
|
@ -316,24 +353,6 @@ Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts
|
|||
}
|
||||
|
||||
Status MatMulBase::InferTensorInfo() {
|
||||
// infer tensor shape
|
||||
Shape mat_a_shape = inputs_shape_.at(0);
|
||||
Shape mat_b_shape = inputs_shape_.at(1);
|
||||
Shape output_shape = outputs_shape_.at(0);
|
||||
|
||||
// infer slice shape
|
||||
Shapes inputs_slice_shape, outputs_slice_shape;
|
||||
Dimensions output_strategy = GetOutputStrategy(dev_matrix_shape_, repeated_calc_num_);
|
||||
|
||||
Strategys inputs_strategy = strategy_->GetInputDim();
|
||||
Strategys outputs_strategy = {output_strategy};
|
||||
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
Shape mat_a_slice_shape = inputs_slice_shape.at(0);
|
||||
Shape mat_b_slice_shape = inputs_slice_shape.at(1);
|
||||
Shape output_slice_shape = outputs_slice_shape.at(0);
|
||||
|
||||
// infer tensor layout
|
||||
TensorLayouts inputs_layout, outputs_layout;
|
||||
if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
|
||||
|
@ -343,9 +362,9 @@ Status MatMulBase::InferTensorInfo() {
|
|||
TensorLayout mat_a_layout = inputs_layout.at(0);
|
||||
TensorLayout mat_b_layout = inputs_layout.at(1);
|
||||
TensorLayout output_layout = outputs_layout.at(0);
|
||||
TensorInfo mat_a_tensor_info(mat_a_layout, mat_a_shape, mat_a_slice_shape);
|
||||
TensorInfo mat_b_tensor_info(mat_b_layout, mat_b_shape, mat_b_slice_shape);
|
||||
TensorInfo output_tensor_info(output_layout, output_shape, output_slice_shape);
|
||||
TensorInfo mat_a_tensor_info(mat_a_layout);
|
||||
TensorInfo mat_b_tensor_info(mat_b_layout);
|
||||
TensorInfo output_tensor_info(output_layout);
|
||||
|
||||
inputs_tensor_info_.push_back(mat_a_tensor_info);
|
||||
inputs_tensor_info_.push_back(mat_b_tensor_info);
|
||||
|
@ -359,6 +378,11 @@ Status MatMulBase::Init(const StrategyPtr &strategy) {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
if (forward_reduce_scatter_) {
|
||||
virtual_div_op_.clear();
|
||||
MS_LOG(INFO) << "The forward reduce scatter mode does not involve repeated calculation, clear the virtual div op";
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << name_ << " : Init success.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
|
|
@ -61,6 +61,7 @@ class MatMulBase : public OperatorInfo {
|
|||
|
||||
bool transpose_a_ = false;
|
||||
bool transpose_b_ = false;
|
||||
bool forward_reduce_scatter_ = false;
|
||||
size_t mat_a_dimension_ = 0;
|
||||
size_t mat_b_dimension_ = 0;
|
||||
};
|
||||
|
|
|
@ -208,6 +208,24 @@ Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &grou
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group) {
|
||||
OperatorName operator_name = REDUCE_SCATTER;
|
||||
ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM
|
||||
ValuePtr attr1_value = MakeValue(group); // group
|
||||
Attr attr0 = std::make_pair(OP, attr0_value);
|
||||
Attr attr1 = std::make_pair(GROUP, attr1_value);
|
||||
OperatorAttrs operator_attrs;
|
||||
operator_attrs.push_back(attr0);
|
||||
operator_attrs.push_back(attr1);
|
||||
|
||||
OperatorParams operator_param;
|
||||
OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
|
||||
|
||||
Operator op = std::make_pair(operator_name, operator_arg);
|
||||
MS_LOG(INFO) << "Create reduce scatter op success, the reduce_op is " << reduce_op << ", the group is " << group;
|
||||
return op;
|
||||
}
|
||||
|
||||
// use for get tensor slice
|
||||
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
|
||||
Shape tensor_map = tensor_layout.tensor_map().array();
|
||||
|
|
|
@ -263,6 +263,7 @@ Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy);
|
|||
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool);
|
||||
Operator CreateVirtualDivOp(int32_t div_num);
|
||||
Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group);
|
||||
Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group);
|
||||
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
|
||||
OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
|
||||
int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);
|
||||
|
|
|
@ -126,6 +126,7 @@ constexpr char MIRROR_OP[] = "mirror_op";
|
|||
constexpr char FORWARD_OP[] = "forward_op";
|
||||
constexpr char REDISTRIBUTION_OP[] = "redistribution_op";
|
||||
constexpr char DARA_PARALLEL[] = "data_parallel";
|
||||
constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
|
||||
|
||||
// Operator
|
||||
constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
|
||||
|
|
|
@ -121,3 +121,53 @@ def test_two_matmul_repeated_calculation2():
|
|||
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
|
||||
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
def test_matmul_forward_reduce_scatter():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul().set_strategy(strategy1)
|
||||
self.matmul.add_prim_attr("forward_reduce_scatter", True)
|
||||
self.mul = P.Mul().set_strategy(strategy2)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.mul(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
context.set_context(save_graphs=True)
|
||||
strategy1 = ((2, 2), (2, 2))
|
||||
strategy2 = ((4, 2), (4, 2))
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
|
||||
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
|
||||
b = Tensor(np.ones([128, 64]), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
||||
|
||||
|
||||
def test_matmul_forward_reduce_scatter_transpose():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, strategy1, strategy2):
|
||||
super().__init__()
|
||||
self.matmul = P.MatMul(transpose_b=True).set_strategy(strategy1)
|
||||
self.matmul.add_prim_attr("forward_reduce_scatter", True)
|
||||
self.mul = P.Mul().set_strategy(strategy2)
|
||||
|
||||
def construct(self, x, y, b):
|
||||
out = self.matmul(x, y)
|
||||
out = self.mul(out, b)
|
||||
return out
|
||||
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
|
||||
context.set_context(save_graphs=True)
|
||||
strategy1 = ((2, 4), (2, 4))
|
||||
strategy2 = ((8, 2), (8, 2))
|
||||
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
|
||||
|
||||
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
b = Tensor(np.ones([128, 64]), dtype=ms.float32)
|
||||
compile_net(net, x, y, b)
|
Loading…
Reference in New Issue