!19431 compute attrs of inserted operator for conv2d transpose
Merge pull request !19431 from yangzhenzhang/compute-attrs-of-inserted-operator-for-conv2d-transpose
This commit is contained in:
commit
2b2925f0d0
|
@ -17,6 +17,7 @@
|
||||||
#include "frontend/parallel/ops_info/conv2d_info.h"
|
#include "frontend/parallel/ops_info/conv2d_info.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -262,11 +263,15 @@ Status Conv2DInfo::InferDevMatrixShape() {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Conv2DInfo::InferRankBias() {
|
Status Conv2DInfo::InferRankBias() {
|
||||||
|
// the Conv2D operator:
|
||||||
// the origin dev_matrix is [n, i, h, w, o]
|
// the origin dev_matrix is [n, i, h, w, o]
|
||||||
// if repeated calculation and repeated num in the left of dev matrix, the dev_matrix is [repeated_num, n, i, h, w, o]
|
// if repeated calculation and repeated num in the left of dev matrix, the dev_matrix is [repeated_num, n, i, h, w, o]
|
||||||
// if repeated calculation and repeated num in the right of dev matrix, the dev_matrix is [n, i, h, w, o,
|
// if repeated calculation and repeated num in the right of dev matrix, the dev_matrix is [n, i, h, w, o,
|
||||||
// repeated_num] the rank_bias_ is the position of the current rank in the w dimension of the dev_matrix(have not
|
// repeated_num]
|
||||||
// split h dimension)
|
//
|
||||||
|
// the Conv2DBackpropInput's origin dev_matrix is [n, o, h, w, i], w dimension's relative position is the same as
|
||||||
|
// Conv2D, the rank_bias_ is the position of the current rank in the w dimension of the dev_matrix(have not split h
|
||||||
|
// dimension)
|
||||||
if (!need_exchange_overlap_) {
|
if (!need_exchange_overlap_) {
|
||||||
MS_LOG(INFO) << name_ << ": No need to infer rank bias";
|
MS_LOG(INFO) << name_ << ": No need to infer rank bias";
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
|
@ -424,32 +429,43 @@ Status Conv2DInfo::InferForwardCommunication() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Conv2DInfo::InferNewOperatorAttrs() {
|
void Conv2DInfo::InferNewPadList() {
|
||||||
// send/recv flag and new_pad_list
|
|
||||||
new_pad_list_ = pad_list_;
|
new_pad_list_ = pad_list_;
|
||||||
|
if (rank_bias_ == 0) { // the first rank
|
||||||
|
new_pad_list_[3] = 0; // no need the right pad
|
||||||
|
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
|
||||||
|
new_pad_list_[2] = 0; // no need the left pad
|
||||||
|
} else { // the middle rank
|
||||||
|
new_pad_list_[2] = 0; // no need the left pad
|
||||||
|
new_pad_list_[3] = 0; // no need the right pad
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Conv2DInfo::InferNewOperatorAttrs() {
|
||||||
|
// new_pad_list
|
||||||
|
InferNewPadList();
|
||||||
|
|
||||||
|
// send/recv flag
|
||||||
if (rank_bias_ == 0) { // the first rank
|
if (rank_bias_ == 0) { // the first rank
|
||||||
left_need_send_ = false;
|
left_need_send_ = false;
|
||||||
left_need_recv_ = false;
|
left_need_recv_ = false;
|
||||||
right_need_send_ = (right_rank_overlap_left_size_ > 0);
|
right_need_send_ = (right_rank_overlap_left_size_ > 0);
|
||||||
right_need_recv_ = (overlap_right_size_ > 0);
|
right_need_recv_ = (overlap_right_size_ > 0); // no need the right pad
|
||||||
new_pad_list_[3] = 0; // no need the right pad
|
|
||||||
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
|
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
|
||||||
left_need_send_ = (left_rank_overlap_right_size_ > 0);
|
left_need_send_ = (left_rank_overlap_right_size_ > 0);
|
||||||
left_need_recv_ = (overlap_left_size_ > 0);
|
left_need_recv_ = (overlap_left_size_ > 0);
|
||||||
right_need_send_ = false;
|
right_need_send_ = false;
|
||||||
right_need_recv_ = false;
|
right_need_recv_ = false;
|
||||||
new_pad_list_[2] = 0; // no need the left pad
|
} else { // the middle rank
|
||||||
} else { // the middle rank
|
|
||||||
left_need_send_ = (left_rank_overlap_right_size_ > 0);
|
left_need_send_ = (left_rank_overlap_right_size_ > 0);
|
||||||
left_need_recv_ = (overlap_left_size_ > 0);
|
left_need_recv_ = (overlap_left_size_ > 0);
|
||||||
right_need_send_ = (right_rank_overlap_left_size_ > 0);
|
right_need_send_ = (right_rank_overlap_left_size_ > 0);
|
||||||
right_need_recv_ = (overlap_right_size_ > 0);
|
right_need_recv_ = (overlap_right_size_ > 0);
|
||||||
new_pad_list_[2] = 0; // no need the left pad
|
|
||||||
new_pad_list_[3] = 0; // no need the right pad
|
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << name_ << ": The left need send is " << left_need_send_ << ", the left need recv is "
|
MS_LOG(INFO) << name_ << ": The left need send is " << left_need_send_ << ", the left need recv is "
|
||||||
<< left_need_recv_ << ", the right need send is " << right_need_send_ << ", the right need recv is "
|
<< left_need_recv_ << ", the right need send is " << right_need_send_ << ", the right need recv is "
|
||||||
<< right_need_recv_ << ", the new pad list is " << new_pad_list_;
|
<< right_need_recv_;
|
||||||
|
|
||||||
// the exchange rank ids
|
// the exchange rank ids
|
||||||
if (left_need_send_ || left_need_recv_) {
|
if (left_need_send_ || left_need_recv_) {
|
||||||
|
@ -598,6 +614,7 @@ Status Conv2DBackpropInputInfo::GetAttrs() {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
|
Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
|
need_exchange_overlap_ = false;
|
||||||
if (CheckStrategyBase(strategy) != SUCCESS) {
|
if (CheckStrategyBase(strategy) != SUCCESS) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
@ -617,6 +634,10 @@ Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// kernel size larger than stride and the w dimension is split, need to exchange overlap
|
||||||
|
if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) {
|
||||||
|
need_exchange_overlap_ = true;
|
||||||
|
}
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -654,6 +675,8 @@ Status Conv2DBackpropInputInfo::InferDevMatrixShape() {
|
||||||
out_slice_shape_[i] = out_slice_shape_[i] / out_strategy[i];
|
out_slice_shape_[i] = out_slice_shape_[i] / out_strategy[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
w_dimension_shard_num_ = stra[0][3];
|
||||||
|
input_slice_shape_ = GetSliceShape(inputs_shape_[0], stra[0]);
|
||||||
MS_LOG(INFO) << name_ << ": The output slice shape is " << out_slice_shape_;
|
MS_LOG(INFO) << name_ << ": The output slice shape is " << out_slice_shape_;
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
@ -736,5 +759,135 @@ void Conv2DBackpropInputInfo::UpdateOutShape(const CNodePtr &cnode) {
|
||||||
(void)manager->Replace(cnode->input(3), val);
|
(void)manager->Replace(cnode->input(3), val);
|
||||||
MS_LOG(INFO) << name_ << ": Update the output shape " << out_slice_shape_;
|
MS_LOG(INFO) << name_ << ": Update the output shape " << out_slice_shape_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64_t Conv2DBackpropInputInfo::ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) {
|
||||||
|
// 1. the first rank: 0
|
||||||
|
// 2. the last rank:
|
||||||
|
// size of origin data required by current rank: a = ceil((o/n + k - o + w*s - s - x)/s)
|
||||||
|
// data size of the current rank: b = w/n
|
||||||
|
// return a - b = ceil((o/n + k - o + w*s - s - x)/s) - w/n
|
||||||
|
// 3. the middle rank:
|
||||||
|
// r*w/n - ceil((r*o/n - k + x + 1)/s)
|
||||||
|
if (rank_bias == 0) { // the first rank
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t w_output_shape = outputs_shape_[0][3];
|
||||||
|
int64_t w_input_shape = inputs_shape_[0][3];
|
||||||
|
int64_t w_kernel_size = kernel_size_[1];
|
||||||
|
int64_t w_stride = stride_[3];
|
||||||
|
int64_t left_pad = pad_list_[2];
|
||||||
|
if (rank_bias == w_dimension_shard_num_ - 1) { // the last rank
|
||||||
|
return DoubleToLong(std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + w_kernel_size -
|
||||||
|
w_output_shape + w_input_shape * w_stride - w_stride - left_pad) /
|
||||||
|
LongToDouble(w_stride))) -
|
||||||
|
w_input_shape / w_dimension_shard_num_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// the middle rank
|
||||||
|
return rank_bias * w_input_shape / w_dimension_shard_num_ -
|
||||||
|
DoubleToLong(
|
||||||
|
std::ceil(LongToDouble(rank_bias * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1) /
|
||||||
|
LongToDouble(w_stride)));
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Conv2DBackpropInputInfo::ComputeOverlapRightSizeByRankBias(int64_t rank_bias) {
|
||||||
|
// 1. the first rank: ceil((o/n + x)/s) - w/n
|
||||||
|
// 2. the last rank: 0
|
||||||
|
// 3. the middle rank: ceil((r*o/n + o/n + x)/s) - r*w/n - w/n
|
||||||
|
int64_t w_output_shape = outputs_shape_[0][3];
|
||||||
|
int64_t w_input_shape = inputs_shape_[0][3];
|
||||||
|
int64_t w_stride = stride_[3];
|
||||||
|
int64_t left_pad = pad_list_[2];
|
||||||
|
|
||||||
|
if (rank_bias == 0) { // the first rank
|
||||||
|
return DoubleToLong(
|
||||||
|
std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + left_pad) / LongToDouble(w_stride))) -
|
||||||
|
w_input_shape / w_dimension_shard_num_;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rank_bias == w_dimension_shard_num_ - 1) { // the last rank
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// the middle rank
|
||||||
|
return DoubleToLong(std::ceil(LongToDouble(rank_bias * w_output_shape / w_dimension_shard_num_ +
|
||||||
|
w_output_shape / w_dimension_shard_num_ + left_pad) /
|
||||||
|
LongToDouble(w_stride))) -
|
||||||
|
(rank_bias + 1) * w_input_shape / w_dimension_shard_num_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Conv2DBackpropInputInfo::InferNewPadList() {
|
||||||
|
// 1. compute the size of origin data required by current rank:
|
||||||
|
// 1) the first rank: ceil((o/n + x) / s)
|
||||||
|
// 2) the last rank: ceil((o/n + k - o + ws - s - x) / s)
|
||||||
|
// 3) the middle rank: ceil((r*o/n + o/n + x) / s) - ceil((r*o/n - k + x + 1) / s)
|
||||||
|
//
|
||||||
|
// 2. compute the real left pad
|
||||||
|
// 1) the first rank: k - x - 1
|
||||||
|
// 2) the last rank:
|
||||||
|
// if (o/n + k - o + ws - s - x) is divisible by s, real_left_pad = s - 1.
|
||||||
|
// otherwise, real_left_pad = (o/n + k - o + ws - s - x) % s - 1
|
||||||
|
// 3) the middle rank:
|
||||||
|
// if (r*on - k + x + 1) is divisible by s, real_left_pad = 0.
|
||||||
|
// otherwise, real_left_pad = s - (r*on - k + x + 1) % s
|
||||||
|
int64_t w_output_shape = outputs_shape_[0][3];
|
||||||
|
int64_t w_input_shape = inputs_shape_[0][3];
|
||||||
|
int64_t w_kernel_size = kernel_size_[1];
|
||||||
|
int64_t w_stride = stride_[3];
|
||||||
|
int64_t left_pad = pad_list_[2];
|
||||||
|
int64_t current_rank_required_size = 0;
|
||||||
|
int64_t real_left_pad = 0;
|
||||||
|
|
||||||
|
if (rank_bias_ == 0) { // the first rank
|
||||||
|
current_rank_required_size = DoubleToLong(
|
||||||
|
std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + left_pad) / LongToDouble(w_stride)));
|
||||||
|
|
||||||
|
real_left_pad = w_kernel_size - left_pad - 1;
|
||||||
|
|
||||||
|
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
|
||||||
|
current_rank_required_size =
|
||||||
|
DoubleToLong(std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + w_kernel_size - w_output_shape +
|
||||||
|
w_input_shape * w_stride - w_stride - left_pad) /
|
||||||
|
LongToDouble(w_stride)));
|
||||||
|
|
||||||
|
int64_t tmp = w_output_shape / w_dimension_shard_num_ + w_kernel_size - w_output_shape + w_input_shape * w_stride -
|
||||||
|
w_stride - left_pad;
|
||||||
|
if (tmp % w_stride == 0) {
|
||||||
|
real_left_pad = w_stride - 1;
|
||||||
|
} else {
|
||||||
|
real_left_pad = tmp % w_stride - 1;
|
||||||
|
}
|
||||||
|
} else { // the middle rank
|
||||||
|
current_rank_required_size =
|
||||||
|
DoubleToLong(std::ceil(LongToDouble(rank_bias_ * w_output_shape / w_dimension_shard_num_ +
|
||||||
|
w_output_shape / w_dimension_shard_num_ + left_pad) /
|
||||||
|
LongToDouble(w_stride))) -
|
||||||
|
DoubleToLong(
|
||||||
|
std::ceil(LongToDouble(rank_bias_ * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1) /
|
||||||
|
LongToDouble(w_stride)));
|
||||||
|
|
||||||
|
int64_t tmp = rank_bias_ * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1;
|
||||||
|
if (tmp % w_stride == 0) {
|
||||||
|
real_left_pad = 0;
|
||||||
|
} else {
|
||||||
|
real_left_pad = w_stride - tmp % w_stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. compute the pad_add: (current_rank_required_size - 1) * s + k - o/n
|
||||||
|
int64_t pad_all =
|
||||||
|
(current_rank_required_size - 1) * w_stride + w_kernel_size - w_output_shape / w_dimension_shard_num_;
|
||||||
|
|
||||||
|
// 4. compute new left pad: k - real_left_pad - 1
|
||||||
|
new_pad_list_ = pad_list_;
|
||||||
|
new_pad_list_[2] = w_kernel_size - real_left_pad - 1;
|
||||||
|
|
||||||
|
// 5. compute new right pad: pad_all - new_left_pad
|
||||||
|
new_pad_list_[3] = pad_all - new_pad_list_[2];
|
||||||
|
|
||||||
|
MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_ << ", the required size of current rank is "
|
||||||
|
<< current_rank_required_size << ", new pad all is " << pad_all;
|
||||||
|
}
|
||||||
} // namespace parallel
|
} // namespace parallel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -95,10 +95,10 @@ class Conv2DInfo : public OperatorInfo {
|
||||||
std::vector<int64_t> exchange_rank_ids_;
|
std::vector<int64_t> exchange_rank_ids_;
|
||||||
Shapes recv_shapes_;
|
Shapes recv_shapes_;
|
||||||
|
|
||||||
private:
|
virtual Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
|
||||||
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
|
virtual void InferNewPadList();
|
||||||
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
|
virtual int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
|
||||||
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
|
virtual int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
|
||||||
};
|
};
|
||||||
|
|
||||||
class Conv2DBackpropInputInfo : public Conv2DInfo {
|
class Conv2DBackpropInputInfo : public Conv2DInfo {
|
||||||
|
@ -117,8 +117,12 @@ class Conv2DBackpropInputInfo : public Conv2DInfo {
|
||||||
Status InferTensorMap() override;
|
Status InferTensorMap() override;
|
||||||
Status InferMirrorOps() override; // can not use OperatorInfo::InferMirrorOps(), since the 'out_shape' is not tensor
|
Status InferMirrorOps() override; // can not use OperatorInfo::InferMirrorOps(), since the 'out_shape' is not tensor
|
||||||
|
|
||||||
private:
|
|
||||||
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
|
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
|
||||||
|
void InferNewPadList();
|
||||||
|
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
|
||||||
|
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
|
||||||
|
|
||||||
|
private:
|
||||||
Shape out_shape_;
|
Shape out_shape_;
|
||||||
Shape out_slice_shape_;
|
Shape out_slice_shape_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -80,6 +80,13 @@ inline int FloatToInt(float u) {
|
||||||
return static_cast<int>(u);
|
return static_cast<int>(u);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline int64_t DoubleToLong(double u) {
|
||||||
|
if (u > static_cast<double>((std::numeric_limits<int64_t>::max)())) {
|
||||||
|
MS_LOG(EXCEPTION) << "The double value(" << u << ") exceeds the maximum value of int64_t.";
|
||||||
|
}
|
||||||
|
return static_cast<int64_t>(u);
|
||||||
|
}
|
||||||
|
|
||||||
inline float SizeToFloat(size_t v) { return static_cast<float>(v); }
|
inline float SizeToFloat(size_t v) { return static_cast<float>(v); }
|
||||||
|
|
||||||
inline double LongToDouble(int64_t v) { return static_cast<double>(v); }
|
inline double LongToDouble(int64_t v) { return static_cast<double>(v); }
|
||||||
|
|
Loading…
Reference in New Issue