compute overlap size for conv2d

This commit is contained in:
yangzhenzhang 2021-06-26 17:14:14 +08:00
parent 96dc8f8eee
commit 651a9ffadf
2 changed files with 89 additions and 3 deletions

View File

@ -216,6 +216,7 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
}
Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
need_exchange_overlap_ = false;
if (CheckStrategyBase(strategy) != SUCCESS) {
return FAILED;
}
@ -235,6 +236,11 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
}
}
// kernel size larger than stride and the w dimension is split, need to exchange overlap
if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) {
need_exchange_overlap_ = true;
}
return SUCCESS;
}
@ -250,6 +256,68 @@ Status Conv2DInfo::InferDevMatrixShape() {
dev_matrix_shape_ = stra[0];
dev_matrix_shape_.push_back(stra[1][0]);
w_dimension_shard_num_ = stra[0][3];
return SUCCESS;
}
Status Conv2DInfo::InferRankBias() {
// the origin dev_matrix is [n, i, h, w, o]
// if repeated calculation and repeated num in the left of dev matrix, the dev_matrix is [repeated_num, n, i, h, w, o]
// if repeated calculation and repeated num in the right of dev matrix, the dev_matrix is [n, i, h, w, o,
// repeated_num] the rank_bias_ is the position of the current rank in the w dimension of the dev_matrix(have not
// split h dimension)
if (!need_exchange_overlap_) {
MS_LOG(INFO) << name_ << ": No need to infer rank bias";
return SUCCESS;
}
uint64_t w_index_in_dev_matrix = 3;
if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
w_index_in_dev_matrix += 1;
}
CheckGlobalDeviceManager();
int64_t rank = g_device_manager->global_rank();
DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
RankList group_devices;
if (dev_matrix.GetDevicesAlongDim(w_index_in_dev_matrix, &group_devices) != SUCCESS) {
return FAILED;
}
if (group_devices.size() <= 1) {
MS_LOG(INFO) << name_ << ": The devices' size of w dimension is " << group_devices.size()
<< ", no need to infer rank bias";
return SUCCESS;
}
std::vector<int64_t>::iterator result = std::find(group_devices.begin(), group_devices.end(), rank);
rank_bias_ = std::distance(group_devices.begin(), result);
MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
<< ", the rank bias is " << rank_bias_;
return SUCCESS;
}
Status Conv2DInfo::InferOverlapSize() {
if (!need_exchange_overlap_) {
MS_LOG(INFO) << name_ << ": No need to infer overlap size";
return SUCCESS;
}
int64_t left_pad = pad_list_[2];
int64_t w_dimension_input_shape = inputs_shape_[0][3];
int64_t w_dimension_output_shape = outputs_shape_[0][3];
int64_t w_kernel_size = kernel_size_[1];
int64_t w_stride = stride_[3];
overlap_left_size_ =
left_pad + (w_dimension_input_shape - w_dimension_output_shape * w_stride) * rank_bias_ / w_dimension_shard_num_;
overlap_right_size_ =
(rank_bias_ + 1) * (w_dimension_output_shape * w_stride - w_dimension_input_shape) / w_dimension_shard_num_ +
w_kernel_size - w_stride - left_pad;
MS_LOG(INFO) << name_ << ": the left overlap size is " << overlap_left_size_ << ", the right overlap size is "
<< overlap_right_size_;
return SUCCESS;
}
@ -301,12 +369,23 @@ Status Conv2DInfo::InferForwardCommunication() {
}
ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
if (!out_channel_shard_) {
if (!need_exchange_overlap_) {
if (!out_channel_shard_) {
return nullptr;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
prim->set_attr(OUT_CHANNEL, MakeValue(new_out_channel_));
return nullptr;
}
if (InferRankBias() != SUCCESS) {
return nullptr;
}
if (InferOverlapSize() != SUCCESS) {
return nullptr;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
prim->set_attr(OUT_CHANNEL, MakeValue(new_out_channel_));
return nullptr;
}

View File

@ -50,6 +50,8 @@ class Conv2DInfo : public OperatorInfo {
Status InferForwardCommunication() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferRankBias();
Status InferOverlapSize();
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
int64_t out_channel_ = 1;
@ -63,6 +65,11 @@ class Conv2DInfo : public OperatorInfo {
std::string format_;
bool out_channel_shard_ = false;
int64_t new_out_channel_ = 1;
bool need_exchange_overlap_ = false;
int64_t rank_bias_ = 0;
int64_t overlap_left_size_ = 0;
int64_t overlap_right_size_ = 0;
int64_t w_dimension_shard_num_ = 1;
private:
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);