!19431 compute attrs of inserted operator for conv2d transpose

Merge pull request !19431 from yangzhenzhang/compute-attrs-of-inserted-operator-for-conv2d-transpose
This commit is contained in:
i-robot 2021-07-07 02:26:04 +00:00 committed by Gitee
commit 2b2925f0d0
3 changed files with 180 additions and 16 deletions

View File

@ -17,6 +17,7 @@
#include "frontend/parallel/ops_info/conv2d_info.h" #include "frontend/parallel/ops_info/conv2d_info.h"
#include <algorithm> #include <algorithm>
#include <cmath>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -262,11 +263,15 @@ Status Conv2DInfo::InferDevMatrixShape() {
} }
Status Conv2DInfo::InferRankBias() { Status Conv2DInfo::InferRankBias() {
// the Conv2D operator:
// the origin dev_matrix is [n, i, h, w, o] // the origin dev_matrix is [n, i, h, w, o]
// if repeated calculation and repeated num in the left of dev matrix, the dev_matrix is [repeated_num, n, i, h, w, o] // if repeated calculation and repeated num in the left of dev matrix, the dev_matrix is [repeated_num, n, i, h, w, o]
// if repeated calculation and repeated num in the right of dev matrix, the dev_matrix is [n, i, h, w, o, // if repeated calculation and repeated num in the right of dev matrix, the dev_matrix is [n, i, h, w, o,
// repeated_num] the rank_bias_ is the position of the current rank in the w dimension of the dev_matrix(have not // repeated_num]
// split h dimension) //
// the Conv2DBackpropInput's origin dev_matrix is [n, o, h, w, i], w dimension's relative position is the same as
// Conv2D, the rank_bias_ is the position of the current rank in the w dimension of the dev_matrix(have not split h
// dimension)
if (!need_exchange_overlap_) { if (!need_exchange_overlap_) {
MS_LOG(INFO) << name_ << ": No need to infer rank bias"; MS_LOG(INFO) << name_ << ": No need to infer rank bias";
return SUCCESS; return SUCCESS;
@ -424,32 +429,43 @@ Status Conv2DInfo::InferForwardCommunication() {
return SUCCESS; return SUCCESS;
} }
void Conv2DInfo::InferNewOperatorAttrs() { void Conv2DInfo::InferNewPadList() {
// send/recv flag and new_pad_list
new_pad_list_ = pad_list_; new_pad_list_ = pad_list_;
if (rank_bias_ == 0) { // the first rank
new_pad_list_[3] = 0; // no need the right pad
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
new_pad_list_[2] = 0; // no need the left pad
} else { // the middle rank
new_pad_list_[2] = 0; // no need the left pad
new_pad_list_[3] = 0; // no need the right pad
}
MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_;
}
void Conv2DInfo::InferNewOperatorAttrs() {
// new_pad_list
InferNewPadList();
// send/recv flag
if (rank_bias_ == 0) { // the first rank if (rank_bias_ == 0) { // the first rank
left_need_send_ = false; left_need_send_ = false;
left_need_recv_ = false; left_need_recv_ = false;
right_need_send_ = (right_rank_overlap_left_size_ > 0); right_need_send_ = (right_rank_overlap_left_size_ > 0);
right_need_recv_ = (overlap_right_size_ > 0); right_need_recv_ = (overlap_right_size_ > 0); // no need the right pad
new_pad_list_[3] = 0; // no need the right pad
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank } else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
left_need_send_ = (left_rank_overlap_right_size_ > 0); left_need_send_ = (left_rank_overlap_right_size_ > 0);
left_need_recv_ = (overlap_left_size_ > 0); left_need_recv_ = (overlap_left_size_ > 0);
right_need_send_ = false; right_need_send_ = false;
right_need_recv_ = false; right_need_recv_ = false;
new_pad_list_[2] = 0; // no need the left pad } else { // the middle rank
} else { // the middle rank
left_need_send_ = (left_rank_overlap_right_size_ > 0); left_need_send_ = (left_rank_overlap_right_size_ > 0);
left_need_recv_ = (overlap_left_size_ > 0); left_need_recv_ = (overlap_left_size_ > 0);
right_need_send_ = (right_rank_overlap_left_size_ > 0); right_need_send_ = (right_rank_overlap_left_size_ > 0);
right_need_recv_ = (overlap_right_size_ > 0); right_need_recv_ = (overlap_right_size_ > 0);
new_pad_list_[2] = 0; // no need the left pad
new_pad_list_[3] = 0; // no need the right pad
} }
MS_LOG(INFO) << name_ << ": The left need send is " << left_need_send_ << ", the left need recv is " MS_LOG(INFO) << name_ << ": The left need send is " << left_need_send_ << ", the left need recv is "
<< left_need_recv_ << ", the right need send is " << right_need_send_ << ", the right need recv is " << left_need_recv_ << ", the right need send is " << right_need_send_ << ", the right need recv is "
<< right_need_recv_ << ", the new pad list is " << new_pad_list_; << right_need_recv_;
// the exchange rank ids // the exchange rank ids
if (left_need_send_ || left_need_recv_) { if (left_need_send_ || left_need_recv_) {
@ -598,6 +614,7 @@ Status Conv2DBackpropInputInfo::GetAttrs() {
} }
Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) { Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
need_exchange_overlap_ = false;
if (CheckStrategyBase(strategy) != SUCCESS) { if (CheckStrategyBase(strategy) != SUCCESS) {
return FAILED; return FAILED;
} }
@ -617,6 +634,10 @@ Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
} }
// kernel size larger than stride and the w dimension is split, need to exchange overlap
if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) {
need_exchange_overlap_ = true;
}
return SUCCESS; return SUCCESS;
} }
@ -654,6 +675,8 @@ Status Conv2DBackpropInputInfo::InferDevMatrixShape() {
out_slice_shape_[i] = out_slice_shape_[i] / out_strategy[i]; out_slice_shape_[i] = out_slice_shape_[i] / out_strategy[i];
} }
w_dimension_shard_num_ = stra[0][3];
input_slice_shape_ = GetSliceShape(inputs_shape_[0], stra[0]);
MS_LOG(INFO) << name_ << ": The output slice shape is " << out_slice_shape_; MS_LOG(INFO) << name_ << ": The output slice shape is " << out_slice_shape_;
return SUCCESS; return SUCCESS;
} }
@ -736,5 +759,135 @@ void Conv2DBackpropInputInfo::UpdateOutShape(const CNodePtr &cnode) {
(void)manager->Replace(cnode->input(3), val); (void)manager->Replace(cnode->input(3), val);
MS_LOG(INFO) << name_ << ": Update the output shape " << out_slice_shape_; MS_LOG(INFO) << name_ << ": Update the output shape " << out_slice_shape_;
} }
int64_t Conv2DBackpropInputInfo::ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) {
// 1. the first rank: 0
// 2. the last rank:
// size of origin data required by current rank: a = ceil((o/n + k - o + w*s - s - x)/s)
// data size of the current rank: b = w/n
// return a - b = ceil((o/n + k - o + w*s - s - x)/s) - w/n
// 3. the middle rank:
// r*w/n - ceil((r*o/n - k + x + 1)/s)
if (rank_bias == 0) { // the first rank
return 0;
}
int64_t w_output_shape = outputs_shape_[0][3];
int64_t w_input_shape = inputs_shape_[0][3];
int64_t w_kernel_size = kernel_size_[1];
int64_t w_stride = stride_[3];
int64_t left_pad = pad_list_[2];
if (rank_bias == w_dimension_shard_num_ - 1) { // the last rank
return DoubleToLong(std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + w_kernel_size -
w_output_shape + w_input_shape * w_stride - w_stride - left_pad) /
LongToDouble(w_stride))) -
w_input_shape / w_dimension_shard_num_;
}
// the middle rank
return rank_bias * w_input_shape / w_dimension_shard_num_ -
DoubleToLong(
std::ceil(LongToDouble(rank_bias * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1) /
LongToDouble(w_stride)));
}
int64_t Conv2DBackpropInputInfo::ComputeOverlapRightSizeByRankBias(int64_t rank_bias) {
// 1. the first rank: ceil((o/n + x)/s) - w/n
// 2. the last rank: 0
// 3. the middle rank: ceil((r*o/n + o/n + x)/s) - r*w/n - w/n
int64_t w_output_shape = outputs_shape_[0][3];
int64_t w_input_shape = inputs_shape_[0][3];
int64_t w_stride = stride_[3];
int64_t left_pad = pad_list_[2];
if (rank_bias == 0) { // the first rank
return DoubleToLong(
std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + left_pad) / LongToDouble(w_stride))) -
w_input_shape / w_dimension_shard_num_;
}
if (rank_bias == w_dimension_shard_num_ - 1) { // the last rank
return 0;
}
// the middle rank
return DoubleToLong(std::ceil(LongToDouble(rank_bias * w_output_shape / w_dimension_shard_num_ +
w_output_shape / w_dimension_shard_num_ + left_pad) /
LongToDouble(w_stride))) -
(rank_bias + 1) * w_input_shape / w_dimension_shard_num_;
}
void Conv2DBackpropInputInfo::InferNewPadList() {
// 1. compute the size of origin data required by current rank:
// 1) the first rank: ceil((o/n + x) / s)
// 2) the last rank: ceil((o/n + k - o + ws - s - x) / s)
// 3) the middle rank: ceil((r*o/n + o/n + x) / s) - ceil((r*o/n - k + x + 1) / s)
//
// 2. compute the real left pad
// 1) the first rank: k - x - 1
// 2) the last rank:
// if (o/n + k - o + ws - s - x) is divisible by s, real_left_pad = s - 1.
// otherwise, real_left_pad = (o/n + k - o + ws - s - x) % s - 1
// 3) the middle rank:
// if (r*on - k + x + 1) is divisible by s, real_left_pad = 0.
// otherwise, real_left_pad = s - (r*on - k + x + 1) % s
int64_t w_output_shape = outputs_shape_[0][3];
int64_t w_input_shape = inputs_shape_[0][3];
int64_t w_kernel_size = kernel_size_[1];
int64_t w_stride = stride_[3];
int64_t left_pad = pad_list_[2];
int64_t current_rank_required_size = 0;
int64_t real_left_pad = 0;
if (rank_bias_ == 0) { // the first rank
current_rank_required_size = DoubleToLong(
std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + left_pad) / LongToDouble(w_stride)));
real_left_pad = w_kernel_size - left_pad - 1;
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
current_rank_required_size =
DoubleToLong(std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + w_kernel_size - w_output_shape +
w_input_shape * w_stride - w_stride - left_pad) /
LongToDouble(w_stride)));
int64_t tmp = w_output_shape / w_dimension_shard_num_ + w_kernel_size - w_output_shape + w_input_shape * w_stride -
w_stride - left_pad;
if (tmp % w_stride == 0) {
real_left_pad = w_stride - 1;
} else {
real_left_pad = tmp % w_stride - 1;
}
} else { // the middle rank
current_rank_required_size =
DoubleToLong(std::ceil(LongToDouble(rank_bias_ * w_output_shape / w_dimension_shard_num_ +
w_output_shape / w_dimension_shard_num_ + left_pad) /
LongToDouble(w_stride))) -
DoubleToLong(
std::ceil(LongToDouble(rank_bias_ * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1) /
LongToDouble(w_stride)));
int64_t tmp = rank_bias_ * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1;
if (tmp % w_stride == 0) {
real_left_pad = 0;
} else {
real_left_pad = w_stride - tmp % w_stride;
}
}
// 3. compute the pad_add: (current_rank_required_size - 1) * s + k - o/n
int64_t pad_all =
(current_rank_required_size - 1) * w_stride + w_kernel_size - w_output_shape / w_dimension_shard_num_;
// 4. compute new left pad: k - real_left_pad - 1
new_pad_list_ = pad_list_;
new_pad_list_[2] = w_kernel_size - real_left_pad - 1;
// 5. compute new right pad: pad_all - new_left_pad
new_pad_list_[3] = pad_all - new_pad_list_[2];
MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_ << ", the required size of current rank is "
<< current_rank_required_size << ", new pad all is " << pad_all;
}
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore

View File

@ -95,10 +95,10 @@ class Conv2DInfo : public OperatorInfo {
std::vector<int64_t> exchange_rank_ids_; std::vector<int64_t> exchange_rank_ids_;
Shapes recv_shapes_; Shapes recv_shapes_;
private: virtual Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy); virtual void InferNewPadList();
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias); virtual int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias); virtual int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
}; };
class Conv2DBackpropInputInfo : public Conv2DInfo { class Conv2DBackpropInputInfo : public Conv2DInfo {
@ -117,8 +117,12 @@ class Conv2DBackpropInputInfo : public Conv2DInfo {
Status InferTensorMap() override; Status InferTensorMap() override;
Status InferMirrorOps() override; // can not use OperatorInfo::InferMirrorOps(), since the 'out_shape' is not tensor Status InferMirrorOps() override; // can not use OperatorInfo::InferMirrorOps(), since the 'out_shape' is not tensor
private:
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy); Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
void InferNewPadList();
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
private:
Shape out_shape_; Shape out_shape_;
Shape out_slice_shape_; Shape out_slice_shape_;
}; };

View File

@ -80,6 +80,13 @@ inline int FloatToInt(float u) {
return static_cast<int>(u); return static_cast<int>(u);
} }
inline int64_t DoubleToLong(double u) {
if (u > static_cast<double>((std::numeric_limits<int64_t>::max)())) {
MS_LOG(EXCEPTION) << "The double value(" << u << ") exceeds the maximum value of int64_t.";
}
return static_cast<int64_t>(u);
}
inline float SizeToFloat(size_t v) { return static_cast<float>(v); } inline float SizeToFloat(size_t v) { return static_cast<float>(v); }
inline double LongToDouble(int64_t v) { return static_cast<double>(v); } inline double LongToDouble(int64_t v) { return static_cast<double>(v); }