forked from mindspore-Ecosystem/mindspore
update repeated calculation
This commit is contained in:
parent
359543d663
commit
eb6f4e3ce8
|
@ -205,6 +205,7 @@ Status MatMulBase::InferDevMatrixShape() {
|
|||
Dimensions mat_b_strategy = stra.at(1);
|
||||
|
||||
SetDevMatrixShape(mat_a_strategy, mat_b_strategy, transpose_b_, &dev_matrix_shape_);
|
||||
origin_dev_matrix_shape_ = dev_matrix_shape_;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -236,10 +237,11 @@ Status MatMulBase::InferMirrorOps() {
|
|||
|
||||
Status MatMulBase::InferForwardCommunication() {
|
||||
forward_op_.clear();
|
||||
size_t dimension = dev_matrix_shape_.size();
|
||||
size_t dimension = origin_dev_matrix_shape_.size();
|
||||
size_t relevant_dimension_index = SECOND_FROM_END(dimension);
|
||||
// Relevant dimension is not split and all reduce is not required
|
||||
if (dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) {
|
||||
// Relevant dimension is not split and all reduce is not required,
|
||||
// need to use origin_dev_matrix_shape_ here, since the dev_matrix_shape_ will be changed if repeated calculation.
|
||||
if (origin_dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) {
|
||||
MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
|
|
@ -65,6 +65,7 @@ class MatMulBase : public OperatorInfo {
|
|||
int32_t field_size_ = 0;
|
||||
size_t mat_a_dimension_ = 0;
|
||||
size_t mat_b_dimension_ = 0;
|
||||
Shape origin_dev_matrix_shape_;
|
||||
};
|
||||
|
||||
class MatMul : public MatMulBase {
|
||||
|
|
|
@ -74,7 +74,7 @@ Status OneHotInfo::InferDevMatrixShape() {
|
|||
dev_matrix_shape_.push_back(input_strategy[1]); // the depth is un-splittable
|
||||
}
|
||||
old_dev_matrix_back_ = dev_matrix_shape_.back();
|
||||
|
||||
repeated_num_in_dev_matrix_right_ = false;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
|
|
@ -164,21 +164,24 @@ Status OperatorInfo::InferRepeatedCalcInfo() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
// If repeated calculation, need to set the repeated_calc_num as the last dimension of dev-matrix,
|
||||
// only use for infer tensor layout. Because if the previous shard is (a, b), and the next shard is
|
||||
// (a, 1), adding the repeated_calc_num to the last dimension of dev-matrix, there is no need to redistribution.
|
||||
// If repeated calculation, set the repeated_calc_num as the last dimension of dev-matrix in default,
|
||||
// because if the previous shard is (a, b), and the next shard is (a, 1), adding the repeated_calc_num
|
||||
// to the last dimension of dev-matrix, there is no need to redistribution.
|
||||
void OperatorInfo::SetRepeatedCalcDevMatrix() {
|
||||
if (repeated_calc_num_ <= 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
(void)dev_matrix_shape_.push_back(repeated_calc_num_);
|
||||
if (repeated_num_in_dev_matrix_right_) {
|
||||
dev_matrix_shape_.push_back(repeated_calc_num_);
|
||||
} else {
|
||||
(void)dev_matrix_shape_.insert(dev_matrix_shape_.begin(), repeated_calc_num_);
|
||||
}
|
||||
}
|
||||
|
||||
// If repeated calculation, since the repeated_calc_num is added to the last dimension of the dev-matrix,
|
||||
// If repeated calculation, and the repeated_calc_num is inserted to the last dimension of the dev-matrix,
|
||||
// the index value of tensor map needs to be increased by 1.
|
||||
void OperatorInfo::ResetTensorMapIfRepeatedCalc() {
|
||||
if (repeated_calc_num_ <= 1) {
|
||||
if ((repeated_calc_num_ <= 1) || !repeated_num_in_dev_matrix_right_) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -214,7 +214,7 @@ class OperatorInfo {
|
|||
StrategyPtr strategy_;
|
||||
std::vector<TensorInfo> inputs_tensor_info_;
|
||||
std::vector<TensorInfo> outputs_tensor_info_;
|
||||
Shape dev_matrix_shape_; // if repeated calculation, it contains the repeated_calc_num as the first dimension
|
||||
Shape dev_matrix_shape_; // if repeated calculation, it contains the repeated_calc_num_
|
||||
int32_t repeated_calc_num_ = 1;
|
||||
int32_t as_loss_divisor_ = 1;
|
||||
TensorMaps inputs_tensor_map_;
|
||||
|
@ -263,6 +263,8 @@ class OperatorInfo {
|
|||
std::string refkey_parameter_name_;
|
||||
CNodePtr cnode_;
|
||||
int32_t used_devices_ = -1;
|
||||
// the repeated_calc_num_ will be inserted to the last dimension of dev matrix in default
|
||||
bool repeated_num_in_dev_matrix_right_ = true;
|
||||
|
||||
private:
|
||||
OperatorCostPtr operator_cost_;
|
||||
|
|
|
@ -158,7 +158,10 @@ Status ReduceMethod::InferForwardCommunication() {
|
|||
size_t size = stra.size();
|
||||
// judge if the reduce dim is partitioned.
|
||||
Shape group_creat_map;
|
||||
if (dev_matrix_shape_.size() > size) {
|
||||
|
||||
// if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
|
||||
// it need to handle the first dimention of map.
|
||||
if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
|
||||
group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
|
||||
}
|
||||
for (size_t index = 0; index < size; ++index) {
|
||||
|
@ -169,6 +172,18 @@ Status ReduceMethod::InferForwardCommunication() {
|
|||
}
|
||||
group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1);
|
||||
}
|
||||
|
||||
// if repeated calculation and the repeated_calc_num_ insert to the last dimension of dev matrix,
|
||||
// it need to handle the group_creat_map and insert the 0 to the last dimension of the group_creat_map.
|
||||
if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) {
|
||||
for (auto &ele : group_creat_map) {
|
||||
if (ele == MAP_NONE) {
|
||||
continue;
|
||||
}
|
||||
ele += 1;
|
||||
}
|
||||
group_creat_map.push_back(0);
|
||||
}
|
||||
std::vector<Group> forward_group;
|
||||
if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed.";
|
||||
|
@ -220,9 +235,13 @@ Status ReduceMeanInfo::InferForwardCommunication() {
|
|||
size_t size = stra.size();
|
||||
// judge if the reduce dim is partitioned.
|
||||
Shape group_creat_map;
|
||||
if (dev_matrix_shape_.size() > size) {
|
||||
|
||||
// if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
|
||||
// it need to handle the first dimention of map.
|
||||
if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
|
||||
group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
|
||||
}
|
||||
|
||||
for (size_t index = 0; index < size; ++index) {
|
||||
auto pos =
|
||||
std::find_if(dim_list.begin(), dim_list.end(), [index](const int32_t &dim) { return SizeToInt(index) == dim; });
|
||||
|
@ -231,6 +250,19 @@ Status ReduceMeanInfo::InferForwardCommunication() {
|
|||
}
|
||||
group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1);
|
||||
}
|
||||
|
||||
// if repeated calculation and the repeated_calc_num_ insert to the last dimension of dev matrix,
|
||||
// it need to handle the group_creat_map and insert the 0 to the last dimension of the group_creat_map.
|
||||
if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) {
|
||||
for (auto &ele : group_creat_map) {
|
||||
if (ele == MAP_NONE) {
|
||||
continue;
|
||||
}
|
||||
ele += 1;
|
||||
}
|
||||
group_creat_map.push_back(0);
|
||||
}
|
||||
|
||||
std::vector<Group> forward_group;
|
||||
if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed.";
|
||||
|
|
|
@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape2) {
|
|||
ASSERT_EQ(status, SUCCESS);
|
||||
Shape dev_matrix_shape = onehot_info->dev_matrix_shape();
|
||||
|
||||
Shape expect = {4, 1, 2};
|
||||
Shape expect = {2, 4, 1};
|
||||
ASSERT_EQ(dev_matrix_shape, expect);
|
||||
}
|
||||
|
||||
|
|
|
@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape2) {
|
|||
ASSERT_EQ(status, SUCCESS);
|
||||
Shape dev_matrix_shape = onehot_info2->dev_matrix_shape();
|
||||
|
||||
Shape expect = {4, 1, 2};
|
||||
Shape expect = {2, 4, 1};
|
||||
ASSERT_EQ(dev_matrix_shape, expect);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue