compute attrs for conv2d

This commit is contained in:
yangzhenzhang 2021-06-28 16:39:04 +08:00
parent 1983ded03f
commit 03207aeeed
2 changed files with 171 additions and 16 deletions

View File

@ -257,6 +257,7 @@ Status Conv2DInfo::InferDevMatrixShape() {
dev_matrix_shape_ = stra[0];
dev_matrix_shape_.push_back(stra[1][0]);
w_dimension_shard_num_ = stra[0][3];
input_slice_shape_ = GetSliceShape(inputs_shape_[0], stra[0]);
return SUCCESS;
}
@ -290,34 +291,89 @@ Status Conv2DInfo::InferRankBias() {
return SUCCESS;
}
std::vector<int64_t>::iterator result = std::find(group_devices.begin(), group_devices.end(), rank);
rank_bias_ = std::distance(group_devices.begin(), result);
if (group_devices.size() != LongToSize(w_dimension_shard_num_)) {
MS_LOG(ERROR) << name_ << ": The devices' size of w dimension is " << group_devices.size()
<< ", but the shard num of w dimension is " << w_dimension_shard_num_;
return FAILED;
}
std::vector<int64_t>::iterator it = std::find(group_devices.begin(), group_devices.end(), rank);
if (it == group_devices.end()) {
MS_LOG(ERROR) << name_ << ": Can not find the current rank in device list of w dimension, the current rank is "
<< rank << ", the device list is " << group_devices;
return FAILED;
}
rank_bias_ = std::distance(group_devices.begin(), it);
if (it == group_devices.begin()) {
left_rank_bias_ = -1;
right_rank_bias_ = rank_bias_ + 1;
} else if (it == group_devices.end() - 1) {
left_rank_bias_ = rank_bias_ - 1;
right_rank_bias_ = -1;
} else {
left_rank_bias_ = rank_bias_ - 1;
right_rank_bias_ = rank_bias_ + 1;
}
MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
<< ", the rank bias is " << rank_bias_;
<< ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_
<< ", the right rank bias is " << right_rank_bias_;
return SUCCESS;
}
int64_t Conv2DInfo::ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) {
int64_t left_pad = pad_list_[2];
int64_t w_dimension_input_shape = inputs_shape_[0][3];
int64_t w_dimension_output_shape = outputs_shape_[0][3];
int64_t w_stride = stride_[3];
return left_pad +
(w_dimension_input_shape - w_dimension_output_shape * w_stride) * rank_bias / w_dimension_shard_num_;
}
int64_t Conv2DInfo::ComputeOverlapRightSizeByRankBias(int64_t rank_bias) {
int64_t left_pad = pad_list_[2];
int64_t w_dimension_input_shape = inputs_shape_[0][3];
int64_t w_dimension_output_shape = outputs_shape_[0][3];
int64_t w_kernel_size = kernel_size_[1];
int64_t w_stride = stride_[3];
return (rank_bias + 1) * (w_dimension_output_shape * w_stride - w_dimension_input_shape) / w_dimension_shard_num_ +
w_kernel_size - w_stride - left_pad;
}
Status Conv2DInfo::InferOverlapSize() {
if (!need_exchange_overlap_) {
MS_LOG(INFO) << name_ << ": No need to infer overlap size";
return SUCCESS;
}
int64_t left_pad = pad_list_[2];
int64_t w_dimension_input_shape = inputs_shape_[0][3];
int64_t w_dimension_output_shape = outputs_shape_[0][3];
int64_t w_kernel_size = kernel_size_[1];
int64_t w_stride = stride_[3];
overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(rank_bias_);
overlap_right_size_ = ComputeOverlapRightSizeByRankBias(rank_bias_);
overlap_left_size_ =
left_pad + (w_dimension_input_shape - w_dimension_output_shape * w_stride) * rank_bias_ / w_dimension_shard_num_;
if (rank_bias_ == 0) { // it has not left rank
left_rank_overlap_left_size_ = 0;
left_rank_overlap_right_size_ = 0;
right_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(right_rank_bias_);
right_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(right_rank_bias_);
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // it has not right rank
left_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(left_rank_bias_);
left_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(left_rank_bias_);
right_rank_overlap_left_size_ = 0;
right_rank_overlap_right_size_ = 0;
} else { // it has left rank and right rank
left_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(left_rank_bias_);
left_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(left_rank_bias_);
right_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(right_rank_bias_);
right_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(right_rank_bias_);
}
overlap_right_size_ =
(rank_bias_ + 1) * (w_dimension_output_shape * w_stride - w_dimension_input_shape) / w_dimension_shard_num_ +
w_kernel_size - w_stride - left_pad;
MS_LOG(INFO) << name_ << ": the left overlap size is " << overlap_left_size_ << ", the right overlap size is "
<< overlap_right_size_;
MS_LOG(INFO) << name_ << ": the left overlap size of current rank is " << overlap_left_size_
<< ", the right overlap size of current rank is " << overlap_right_size_
<< ", the left overlap size of left rank is " << left_rank_overlap_left_size_
<< ", the right overlap size of left rank is " << left_rank_overlap_right_size_
<< ", the left overlap size of right rank is " << right_rank_overlap_left_size_
<< ", the right overlap size of right rank is " << right_rank_overlap_right_size_;
return SUCCESS;
}
@ -368,6 +424,77 @@ Status Conv2DInfo::InferForwardCommunication() {
return SUCCESS;
}
void Conv2DInfo::InferNewOperatorAttrs() {
// send/recv flag and new_pad_list
new_pad_list_ = pad_list_;
if (rank_bias_ == 0) { // the first rank
left_need_send_ = false;
left_need_recv_ = false;
right_need_send_ = (right_rank_overlap_left_size_ > 0);
right_need_recv_ = (overlap_right_size_ > 0);
new_pad_list_[3] = 0; // no need the right pad
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
left_need_send_ = (left_rank_overlap_right_size_ > 0);
left_need_recv_ = (overlap_left_size_ > 0);
right_need_send_ = false;
right_need_recv_ = false;
new_pad_list_[2] = 0; // no need the left pad
} else { // the middle rank
left_need_send_ = (left_rank_overlap_right_size_ > 0);
left_need_recv_ = (overlap_left_size_ > 0);
right_need_send_ = (right_rank_overlap_left_size_ > 0);
right_need_recv_ = (overlap_right_size_ > 0);
new_pad_list_[2] = 0; // no need the left pad
new_pad_list_[3] = 0; // no need the right pad
}
MS_LOG(INFO) << name_ << ": The left need send is " << left_need_send_ << ", the left need recv is "
<< left_need_recv_ << ", the right need send is " << right_need_send_ << ", the right need recv is "
<< right_need_recv_ << ", the new pad list is " << new_pad_list_;
// the exchange rank ids
if (left_need_send_ || left_need_recv_) {
exchange_rank_ids_.push_back(left_rank_bias_);
}
if (right_need_send_ || right_need_recv_) {
exchange_rank_ids_.push_back(right_rank_bias_);
}
MS_LOG(INFO) << name_ << ": The exchange rank ids is " << exchange_rank_ids_;
// the recv reshapes
if (left_need_recv_) {
Shape left_recv_shape = input_slice_shape_;
left_recv_shape[3] = overlap_left_size_;
recv_shapes_.push_back(left_recv_shape);
}
if (right_need_recv_) {
Shape right_recv_shape = input_slice_shape_;
right_recv_shape[3] = overlap_right_size_;
recv_shapes_.push_back(right_recv_shape);
}
MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_;
// the begin, end and strides of StridedSlice
if (left_need_send_) {
left_strided_slice_begin_ = {0, 0, 0, 0};
left_strided_slice_end_ = input_slice_shape_;
left_strided_slice_end_[3] = left_rank_overlap_right_size_;
left_strided_slice_strides_ = {1, 1, 1, 1};
MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is "
<< left_strided_slice_end_;
}
if (right_need_send_) {
right_strided_slice_begin_ = {0, 0, 0, 0};
right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_;
right_strided_slice_end_ = input_slice_shape_;
right_strided_slice_strides_ = {1, 1, 1, 1};
MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is "
<< right_strided_slice_end_;
}
}
ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
if (!need_exchange_overlap_) {
if (!out_channel_shard_) {
@ -386,6 +513,8 @@ ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
return nullptr;
}
InferNewOperatorAttrs();
return nullptr;
}

View File

@ -52,6 +52,7 @@ class Conv2DInfo : public OperatorInfo {
Status InferTensorMap() override;
Status InferRankBias();
Status InferOverlapSize();
void InferNewOperatorAttrs();
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
int64_t out_channel_ = 1;
@ -65,14 +66,39 @@ class Conv2DInfo : public OperatorInfo {
std::string format_;
bool out_channel_shard_ = false;
int64_t new_out_channel_ = 1;
std::vector<int64_t> new_pad_list_;
bool need_exchange_overlap_ = false;
int64_t rank_bias_ = 0;
int64_t left_rank_bias_ = -1;
int64_t right_rank_bias_ = -1;
int64_t overlap_left_size_ = 0;
int64_t overlap_right_size_ = 0;
int64_t left_rank_overlap_left_size_ = 0;
int64_t left_rank_overlap_right_size_ = 0;
int64_t right_rank_overlap_left_size_ = 0;
int64_t right_rank_overlap_right_size_ = 0;
int64_t w_dimension_shard_num_ = 1;
Shape input_slice_shape_;
bool left_need_send_ = false;
bool left_need_recv_ = false;
bool right_need_send_ = false;
bool right_need_recv_ = false;
Shape left_strided_slice_begin_;
Shape left_strided_slice_end_;
Shape left_strided_slice_strides_;
Shape right_strided_slice_begin_;
Shape right_strided_slice_end_;
Shape right_strided_slice_strides_;
std::vector<int64_t> exchange_rank_ids_;
Shapes recv_shapes_;
private:
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
};
class Conv2DBackpropInputInfo : public Conv2DInfo {