update repeated calculation

This commit is contained in:
yangzhenzhang 2020-10-21 14:46:05 +08:00
parent 359543d663
commit eb6f4e3ce8
8 changed files with 56 additions and 16 deletions

View File

@ -205,6 +205,7 @@ Status MatMulBase::InferDevMatrixShape() {
Dimensions mat_b_strategy = stra.at(1);
SetDevMatrixShape(mat_a_strategy, mat_b_strategy, transpose_b_, &dev_matrix_shape_);
origin_dev_matrix_shape_ = dev_matrix_shape_;
return SUCCESS;
}
@ -236,10 +237,11 @@ Status MatMulBase::InferMirrorOps() {
Status MatMulBase::InferForwardCommunication() {
forward_op_.clear();
size_t dimension = dev_matrix_shape_.size();
size_t dimension = origin_dev_matrix_shape_.size();
size_t relevant_dimension_index = SECOND_FROM_END(dimension);
// Relevant dimension is not split and all reduce is not required
if (dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) {
// Relevant dimension is not split and all reduce is not required,
// need to use origin_dev_matrix_shape_ here, since the dev_matrix_shape_ will be changed if repeated calculation.
if (origin_dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) {
MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
return SUCCESS;
}

View File

@ -65,6 +65,7 @@ class MatMulBase : public OperatorInfo {
int32_t field_size_ = 0;
size_t mat_a_dimension_ = 0;
size_t mat_b_dimension_ = 0;
Shape origin_dev_matrix_shape_;
};
class MatMul : public MatMulBase {

View File

@ -74,7 +74,7 @@ Status OneHotInfo::InferDevMatrixShape() {
dev_matrix_shape_.push_back(input_strategy[1]); // the depth is un-splittable
}
old_dev_matrix_back_ = dev_matrix_shape_.back();
repeated_num_in_dev_matrix_right_ = false;
return SUCCESS;
}

View File

@ -164,21 +164,24 @@ Status OperatorInfo::InferRepeatedCalcInfo() {
return SUCCESS;
}
// If repeated calculation, need to set the repeated_calc_num as the last dimension of dev-matrix,
// only use for infer tensor layout. Because if the previous shard is (a, b), and the next shard is
// (a, 1), adding the repeated_calc_num to the last dimension of dev-matrix, there is no need to redistribution.
// If repeated calculation, set the repeated_calc_num as the last dimension of dev-matrix in default,
// because if the previous shard is (a, b), and the next shard is (a, 1), adding the repeated_calc_num
// to the last dimension of dev-matrix, there is no need to redistribution.
void OperatorInfo::SetRepeatedCalcDevMatrix() {
if (repeated_calc_num_ <= 1) {
return;
}
(void)dev_matrix_shape_.push_back(repeated_calc_num_);
if (repeated_num_in_dev_matrix_right_) {
dev_matrix_shape_.push_back(repeated_calc_num_);
} else {
(void)dev_matrix_shape_.insert(dev_matrix_shape_.begin(), repeated_calc_num_);
}
}
// If repeated calculation, since the repeated_calc_num is added to the last dimension of the dev-matrix,
// If repeated calculation, and the repeated_calc_num is inserted to the last dimension of the dev-matrix,
// the index value of tensor map needs to be increased by 1.
void OperatorInfo::ResetTensorMapIfRepeatedCalc() {
if (repeated_calc_num_ <= 1) {
if ((repeated_calc_num_ <= 1) || !repeated_num_in_dev_matrix_right_) {
return;
}

View File

@ -214,7 +214,7 @@ class OperatorInfo {
StrategyPtr strategy_;
std::vector<TensorInfo> inputs_tensor_info_;
std::vector<TensorInfo> outputs_tensor_info_;
Shape dev_matrix_shape_; // if repeated calculation, it contains the repeated_calc_num as the first dimension
Shape dev_matrix_shape_; // if repeated calculation, it contains the repeated_calc_num_
int32_t repeated_calc_num_ = 1;
int32_t as_loss_divisor_ = 1;
TensorMaps inputs_tensor_map_;
@ -263,6 +263,8 @@ class OperatorInfo {
std::string refkey_parameter_name_;
CNodePtr cnode_;
int32_t used_devices_ = -1;
// the repeated_calc_num_ will be inserted to the last dimension of dev matrix in default
bool repeated_num_in_dev_matrix_right_ = true;
private:
OperatorCostPtr operator_cost_;

View File

@ -158,7 +158,10 @@ Status ReduceMethod::InferForwardCommunication() {
size_t size = stra.size();
// judge if the reduce dim is partitioned.
Shape group_creat_map;
if (dev_matrix_shape_.size() > size) {
// if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
// it need to handle the first dimention of map.
if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
}
for (size_t index = 0; index < size; ++index) {
@ -169,6 +172,18 @@ Status ReduceMethod::InferForwardCommunication() {
}
group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1);
}
// if repeated calculation and the repeated_calc_num_ insert to the last dimension of dev matrix,
// it need to handle the group_creat_map and insert the 0 to the last dimension of the group_creat_map.
if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) {
for (auto &ele : group_creat_map) {
if (ele == MAP_NONE) {
continue;
}
ele += 1;
}
group_creat_map.push_back(0);
}
std::vector<Group> forward_group;
if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed.";
@ -220,9 +235,13 @@ Status ReduceMeanInfo::InferForwardCommunication() {
size_t size = stra.size();
// judge if the reduce dim is partitioned.
Shape group_creat_map;
if (dev_matrix_shape_.size() > size) {
// if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
// it need to handle the first dimention of map.
if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
}
for (size_t index = 0; index < size; ++index) {
auto pos =
std::find_if(dim_list.begin(), dim_list.end(), [index](const int32_t &dim) { return SizeToInt(index) == dim; });
@ -231,6 +250,19 @@ Status ReduceMeanInfo::InferForwardCommunication() {
}
group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1);
}
// if repeated calculation and the repeated_calc_num_ insert to the last dimension of dev matrix,
// it need to handle the group_creat_map and insert the 0 to the last dimension of the group_creat_map.
if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) {
for (auto &ele : group_creat_map) {
if (ele == MAP_NONE) {
continue;
}
ele += 1;
}
group_creat_map.push_back(0);
}
std::vector<Group> forward_group;
if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed.";

View File

@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape2) {
ASSERT_EQ(status, SUCCESS);
Shape dev_matrix_shape = onehot_info->dev_matrix_shape();
Shape expect = {4, 1, 2};
Shape expect = {2, 4, 1};
ASSERT_EQ(dev_matrix_shape, expect);
}

View File

@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape2) {
ASSERT_EQ(status, SUCCESS);
Shape dev_matrix_shape = onehot_info2->dev_matrix_shape();
Shape expect = {4, 1, 2};
Shape expect = {2, 4, 1};
ASSERT_EQ(dev_matrix_shape, expect);
}