compute attrs for conv2d
This commit is contained in:
parent
1983ded03f
commit
03207aeeed
|
@ -257,6 +257,7 @@ Status Conv2DInfo::InferDevMatrixShape() {
|
|||
dev_matrix_shape_ = stra[0];
|
||||
dev_matrix_shape_.push_back(stra[1][0]);
|
||||
w_dimension_shard_num_ = stra[0][3];
|
||||
input_slice_shape_ = GetSliceShape(inputs_shape_[0], stra[0]);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -290,34 +291,89 @@ Status Conv2DInfo::InferRankBias() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<int64_t>::iterator result = std::find(group_devices.begin(), group_devices.end(), rank);
|
||||
rank_bias_ = std::distance(group_devices.begin(), result);
|
||||
if (group_devices.size() != LongToSize(w_dimension_shard_num_)) {
|
||||
MS_LOG(ERROR) << name_ << ": The devices' size of w dimension is " << group_devices.size()
|
||||
<< ", but the shard num of w dimension is " << w_dimension_shard_num_;
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
std::vector<int64_t>::iterator it = std::find(group_devices.begin(), group_devices.end(), rank);
|
||||
if (it == group_devices.end()) {
|
||||
MS_LOG(ERROR) << name_ << ": Can not find the current rank in device list of w dimension, the current rank is "
|
||||
<< rank << ", the device list is " << group_devices;
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
rank_bias_ = std::distance(group_devices.begin(), it);
|
||||
if (it == group_devices.begin()) {
|
||||
left_rank_bias_ = -1;
|
||||
right_rank_bias_ = rank_bias_ + 1;
|
||||
} else if (it == group_devices.end() - 1) {
|
||||
left_rank_bias_ = rank_bias_ - 1;
|
||||
right_rank_bias_ = -1;
|
||||
} else {
|
||||
left_rank_bias_ = rank_bias_ - 1;
|
||||
right_rank_bias_ = rank_bias_ + 1;
|
||||
}
|
||||
MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
|
||||
<< ", the rank bias is " << rank_bias_;
|
||||
<< ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_
|
||||
<< ", the right rank bias is " << right_rank_bias_;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
int64_t Conv2DInfo::ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) {
|
||||
int64_t left_pad = pad_list_[2];
|
||||
int64_t w_dimension_input_shape = inputs_shape_[0][3];
|
||||
int64_t w_dimension_output_shape = outputs_shape_[0][3];
|
||||
int64_t w_stride = stride_[3];
|
||||
|
||||
return left_pad +
|
||||
(w_dimension_input_shape - w_dimension_output_shape * w_stride) * rank_bias / w_dimension_shard_num_;
|
||||
}
|
||||
|
||||
int64_t Conv2DInfo::ComputeOverlapRightSizeByRankBias(int64_t rank_bias) {
|
||||
int64_t left_pad = pad_list_[2];
|
||||
int64_t w_dimension_input_shape = inputs_shape_[0][3];
|
||||
int64_t w_dimension_output_shape = outputs_shape_[0][3];
|
||||
int64_t w_kernel_size = kernel_size_[1];
|
||||
int64_t w_stride = stride_[3];
|
||||
|
||||
return (rank_bias + 1) * (w_dimension_output_shape * w_stride - w_dimension_input_shape) / w_dimension_shard_num_ +
|
||||
w_kernel_size - w_stride - left_pad;
|
||||
}
|
||||
|
||||
Status Conv2DInfo::InferOverlapSize() {
|
||||
if (!need_exchange_overlap_) {
|
||||
MS_LOG(INFO) << name_ << ": No need to infer overlap size";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
int64_t left_pad = pad_list_[2];
|
||||
int64_t w_dimension_input_shape = inputs_shape_[0][3];
|
||||
int64_t w_dimension_output_shape = outputs_shape_[0][3];
|
||||
int64_t w_kernel_size = kernel_size_[1];
|
||||
int64_t w_stride = stride_[3];
|
||||
overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(rank_bias_);
|
||||
overlap_right_size_ = ComputeOverlapRightSizeByRankBias(rank_bias_);
|
||||
|
||||
overlap_left_size_ =
|
||||
left_pad + (w_dimension_input_shape - w_dimension_output_shape * w_stride) * rank_bias_ / w_dimension_shard_num_;
|
||||
if (rank_bias_ == 0) { // it has not left rank
|
||||
left_rank_overlap_left_size_ = 0;
|
||||
left_rank_overlap_right_size_ = 0;
|
||||
right_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(right_rank_bias_);
|
||||
right_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(right_rank_bias_);
|
||||
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // it has not right rank
|
||||
left_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(left_rank_bias_);
|
||||
left_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(left_rank_bias_);
|
||||
right_rank_overlap_left_size_ = 0;
|
||||
right_rank_overlap_right_size_ = 0;
|
||||
} else { // it has left rank and right rank
|
||||
left_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(left_rank_bias_);
|
||||
left_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(left_rank_bias_);
|
||||
right_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(right_rank_bias_);
|
||||
right_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(right_rank_bias_);
|
||||
}
|
||||
|
||||
overlap_right_size_ =
|
||||
(rank_bias_ + 1) * (w_dimension_output_shape * w_stride - w_dimension_input_shape) / w_dimension_shard_num_ +
|
||||
w_kernel_size - w_stride - left_pad;
|
||||
|
||||
MS_LOG(INFO) << name_ << ": the left overlap size is " << overlap_left_size_ << ", the right overlap size is "
|
||||
<< overlap_right_size_;
|
||||
MS_LOG(INFO) << name_ << ": the left overlap size of current rank is " << overlap_left_size_
|
||||
<< ", the right overlap size of current rank is " << overlap_right_size_
|
||||
<< ", the left overlap size of left rank is " << left_rank_overlap_left_size_
|
||||
<< ", the right overlap size of left rank is " << left_rank_overlap_right_size_
|
||||
<< ", the left overlap size of right rank is " << right_rank_overlap_left_size_
|
||||
<< ", the right overlap size of right rank is " << right_rank_overlap_right_size_;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -368,6 +424,77 @@ Status Conv2DInfo::InferForwardCommunication() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
void Conv2DInfo::InferNewOperatorAttrs() {
|
||||
// send/recv flag and new_pad_list
|
||||
new_pad_list_ = pad_list_;
|
||||
if (rank_bias_ == 0) { // the first rank
|
||||
left_need_send_ = false;
|
||||
left_need_recv_ = false;
|
||||
right_need_send_ = (right_rank_overlap_left_size_ > 0);
|
||||
right_need_recv_ = (overlap_right_size_ > 0);
|
||||
new_pad_list_[3] = 0; // no need the right pad
|
||||
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
|
||||
left_need_send_ = (left_rank_overlap_right_size_ > 0);
|
||||
left_need_recv_ = (overlap_left_size_ > 0);
|
||||
right_need_send_ = false;
|
||||
right_need_recv_ = false;
|
||||
new_pad_list_[2] = 0; // no need the left pad
|
||||
} else { // the middle rank
|
||||
left_need_send_ = (left_rank_overlap_right_size_ > 0);
|
||||
left_need_recv_ = (overlap_left_size_ > 0);
|
||||
right_need_send_ = (right_rank_overlap_left_size_ > 0);
|
||||
right_need_recv_ = (overlap_right_size_ > 0);
|
||||
new_pad_list_[2] = 0; // no need the left pad
|
||||
new_pad_list_[3] = 0; // no need the right pad
|
||||
}
|
||||
MS_LOG(INFO) << name_ << ": The left need send is " << left_need_send_ << ", the left need recv is "
|
||||
<< left_need_recv_ << ", the right need send is " << right_need_send_ << ", the right need recv is "
|
||||
<< right_need_recv_ << ", the new pad list is " << new_pad_list_;
|
||||
|
||||
// the exchange rank ids
|
||||
if (left_need_send_ || left_need_recv_) {
|
||||
exchange_rank_ids_.push_back(left_rank_bias_);
|
||||
}
|
||||
|
||||
if (right_need_send_ || right_need_recv_) {
|
||||
exchange_rank_ids_.push_back(right_rank_bias_);
|
||||
}
|
||||
MS_LOG(INFO) << name_ << ": The exchange rank ids is " << exchange_rank_ids_;
|
||||
|
||||
// the recv reshapes
|
||||
if (left_need_recv_) {
|
||||
Shape left_recv_shape = input_slice_shape_;
|
||||
left_recv_shape[3] = overlap_left_size_;
|
||||
recv_shapes_.push_back(left_recv_shape);
|
||||
}
|
||||
|
||||
if (right_need_recv_) {
|
||||
Shape right_recv_shape = input_slice_shape_;
|
||||
right_recv_shape[3] = overlap_right_size_;
|
||||
recv_shapes_.push_back(right_recv_shape);
|
||||
}
|
||||
MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_;
|
||||
|
||||
// the begin, end and strides of StridedSlice
|
||||
if (left_need_send_) {
|
||||
left_strided_slice_begin_ = {0, 0, 0, 0};
|
||||
left_strided_slice_end_ = input_slice_shape_;
|
||||
left_strided_slice_end_[3] = left_rank_overlap_right_size_;
|
||||
left_strided_slice_strides_ = {1, 1, 1, 1};
|
||||
MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is "
|
||||
<< left_strided_slice_end_;
|
||||
}
|
||||
|
||||
if (right_need_send_) {
|
||||
right_strided_slice_begin_ = {0, 0, 0, 0};
|
||||
right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_;
|
||||
right_strided_slice_end_ = input_slice_shape_;
|
||||
right_strided_slice_strides_ = {1, 1, 1, 1};
|
||||
MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is "
|
||||
<< right_strided_slice_end_;
|
||||
}
|
||||
}
|
||||
|
||||
ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
|
||||
if (!need_exchange_overlap_) {
|
||||
if (!out_channel_shard_) {
|
||||
|
@ -386,6 +513,8 @@ ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
InferNewOperatorAttrs();
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -52,6 +52,7 @@ class Conv2DInfo : public OperatorInfo {
|
|||
Status InferTensorMap() override;
|
||||
Status InferRankBias();
|
||||
Status InferOverlapSize();
|
||||
void InferNewOperatorAttrs();
|
||||
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
|
||||
|
||||
int64_t out_channel_ = 1;
|
||||
|
@ -65,14 +66,39 @@ class Conv2DInfo : public OperatorInfo {
|
|||
std::string format_;
|
||||
bool out_channel_shard_ = false;
|
||||
int64_t new_out_channel_ = 1;
|
||||
std::vector<int64_t> new_pad_list_;
|
||||
|
||||
bool need_exchange_overlap_ = false;
|
||||
int64_t rank_bias_ = 0;
|
||||
int64_t left_rank_bias_ = -1;
|
||||
int64_t right_rank_bias_ = -1;
|
||||
int64_t overlap_left_size_ = 0;
|
||||
int64_t overlap_right_size_ = 0;
|
||||
int64_t left_rank_overlap_left_size_ = 0;
|
||||
int64_t left_rank_overlap_right_size_ = 0;
|
||||
int64_t right_rank_overlap_left_size_ = 0;
|
||||
int64_t right_rank_overlap_right_size_ = 0;
|
||||
int64_t w_dimension_shard_num_ = 1;
|
||||
Shape input_slice_shape_;
|
||||
|
||||
bool left_need_send_ = false;
|
||||
bool left_need_recv_ = false;
|
||||
bool right_need_send_ = false;
|
||||
bool right_need_recv_ = false;
|
||||
Shape left_strided_slice_begin_;
|
||||
Shape left_strided_slice_end_;
|
||||
Shape left_strided_slice_strides_;
|
||||
Shape right_strided_slice_begin_;
|
||||
Shape right_strided_slice_end_;
|
||||
Shape right_strided_slice_strides_;
|
||||
|
||||
std::vector<int64_t> exchange_rank_ids_;
|
||||
Shapes recv_shapes_;
|
||||
|
||||
private:
|
||||
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
|
||||
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
|
||||
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
|
||||
};
|
||||
|
||||
class Conv2DBackpropInputInfo : public Conv2DInfo {
|
||||
|
|
Loading…
Reference in New Issue