forked from mindspore-Ecosystem/mindspore
compute overlap size for conv2d
This commit is contained in:
parent 96dc8f8eee
commit 651a9ffadf
@@ -216,6 +216,7 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
 }
 
 Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
+  need_exchange_overlap_ = false;
   if (CheckStrategyBase(strategy) != SUCCESS) {
     return FAILED;
   }
@@ -235,6 +236,11 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
     }
   }
 
+  // If the kernel size is larger than the stride and the w dimension is split, the overlap needs to be exchanged.
+  if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) {
+    need_exchange_overlap_ = true;
+  }
+
   return SUCCESS;
 }
 
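The check above only records that a halo exchange will be needed later. As a rough standalone illustration (toy values, not part of this commit), the same condition can be evaluated in isolation:

#include <cstdint>
#include <iostream>

int main() {
  // Toy configuration: kernel width 3, stride 1 along w, and the w dimension split over 2 devices.
  int64_t kernel_w = 3, stride_w = 1, w_shard_num = 2;
  bool need_exchange_overlap = (kernel_w > stride_w) && (w_shard_num > 1);
  std::cout << std::boolalpha << "need_exchange_overlap = " << need_exchange_overlap << std::endl;
  return 0;
}

With a kernel wider than the stride, the rightmost output column of each shard depends on input columns owned by the neighbouring rank, which is why the flag is set.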
@@ -250,6 +256,68 @@ Status Conv2DInfo::InferDevMatrixShape() {
 
   dev_matrix_shape_ = stra[0];
   dev_matrix_shape_.push_back(stra[1][0]);
+  w_dimension_shard_num_ = stra[0][3];
   return SUCCESS;
 }
 
+Status Conv2DInfo::InferRankBias() {
+  // The origin dev_matrix is [n, i, h, w, o].
+  // If there is repeated calculation and the repeated num is on the left of the dev matrix, the dev_matrix is
+  // [repeated_num, n, i, h, w, o]; if the repeated num is on the right, it is [n, i, h, w, o, repeated_num].
+  // rank_bias_ is the position of the current rank in the w dimension of the dev_matrix (the h dimension is not split).
+  if (!need_exchange_overlap_) {
+    MS_LOG(INFO) << name_ << ": No need to infer rank bias";
+    return SUCCESS;
+  }
+
+  uint64_t w_index_in_dev_matrix = 3;
+  if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
+    w_index_in_dev_matrix += 1;
+  }
+
+  CheckGlobalDeviceManager();
+  int64_t rank = g_device_manager->global_rank();
+  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
+  RankList group_devices;
+  if (dev_matrix.GetDevicesAlongDim(w_index_in_dev_matrix, &group_devices) != SUCCESS) {
+    return FAILED;
+  }
+
+  if (group_devices.size() <= 1) {
+    MS_LOG(INFO) << name_ << ": The devices' size of w dimension is " << group_devices.size()
+                 << ", no need to infer rank bias";
+    return SUCCESS;
+  }
+
+  std::vector<int64_t>::iterator result = std::find(group_devices.begin(), group_devices.end(), rank);
+  rank_bias_ = std::distance(group_devices.begin(), result);
+  MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
+               << ", the rank bias is " << rank_bias_;
+  return SUCCESS;
+}
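rank_bias_ is simply the index of the current rank within the device group along the w dimension. A standalone sketch of the same lookup, with a hypothetical device list and rank that are not taken from this commit:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> group_devices = {4, 5, 6, 7};  // assumed device group along the w dimension
  int64_t rank = 6;                                   // assumed current global rank
  auto it = std::find(group_devices.begin(), group_devices.end(), rank);
  int64_t rank_bias = std::distance(group_devices.begin(), it);
  std::cout << "rank bias = " << rank_bias << std::endl;  // prints 2
  return 0;
}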
+
+Status Conv2DInfo::InferOverlapSize() {
+  if (!need_exchange_overlap_) {
+    MS_LOG(INFO) << name_ << ": No need to infer overlap size";
+    return SUCCESS;
+  }
+
+  int64_t left_pad = pad_list_[2];
+  int64_t w_dimension_input_shape = inputs_shape_[0][3];
+  int64_t w_dimension_output_shape = outputs_shape_[0][3];
+  int64_t w_kernel_size = kernel_size_[1];
+  int64_t w_stride = stride_[3];
+
+  overlap_left_size_ =
+    left_pad + (w_dimension_input_shape - w_dimension_output_shape * w_stride) * rank_bias_ / w_dimension_shard_num_;
+
+  overlap_right_size_ =
+    (rank_bias_ + 1) * (w_dimension_output_shape * w_stride - w_dimension_input_shape) / w_dimension_shard_num_ +
+    w_kernel_size - w_stride - left_pad;
+
+  MS_LOG(INFO) << name_ << ": the left overlap size is " << overlap_left_size_ << ", the right overlap size is "
+               << overlap_right_size_;
+  return SUCCESS;
+}
+
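To make the two formulas concrete, here is a standalone sketch (toy shapes, not taken from this commit, with shortened variable names) that evaluates them for a "same" convolution: input and output width 32, kernel width 3, stride 1, left pad 1, the w dimension split over 4 devices, and the current rank one position from the left (rank_bias = 1):

#include <cstdint>
#include <iostream>

int main() {
  // Assumed toy configuration; the names loosely mirror those used in InferOverlapSize().
  int64_t left_pad = 1, w_input = 32, w_output = 32, w_kernel_size = 3, w_stride = 1;
  int64_t w_dimension_shard_num = 4, rank_bias = 1;

  int64_t overlap_left_size = left_pad + (w_input - w_output * w_stride) * rank_bias / w_dimension_shard_num;
  int64_t overlap_right_size =
    (rank_bias + 1) * (w_output * w_stride - w_input) / w_dimension_shard_num + w_kernel_size - w_stride - left_pad;

  // Both evaluate to 1: a kernel of width 3 with stride 1 needs one extra input column from each neighbour.
  std::cout << "left overlap = " << overlap_left_size << ", right overlap = " << overlap_right_size << std::endl;
  return 0;
}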
@@ -301,12 +369,23 @@ Status Conv2DInfo::InferForwardCommunication() {
 }
 
 ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
-  if (!out_channel_shard_) {
+  if (!need_exchange_overlap_) {
+    if (!out_channel_shard_) {
+      return nullptr;
+    }
+    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+    prim->set_attr(OUT_CHANNEL, MakeValue(new_out_channel_));
+    return nullptr;
+  }
+
+  if (InferRankBias() != SUCCESS) {
     return nullptr;
   }
+
+  if (InferOverlapSize() != SUCCESS) {
+    return nullptr;
+  }
+
   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
   prim->set_attr(OUT_CHANNEL, MakeValue(new_out_channel_));
   return nullptr;
 }

@@ -50,6 +50,8 @@ class Conv2DInfo : public OperatorInfo {
   Status InferForwardCommunication() override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
+  Status InferRankBias();
+  Status InferOverlapSize();
   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
 
   int64_t out_channel_ = 1;
@@ -63,6 +65,11 @@ class Conv2DInfo : public OperatorInfo {
   std::string format_;
   bool out_channel_shard_ = false;
   int64_t new_out_channel_ = 1;
+  bool need_exchange_overlap_ = false;
+  int64_t rank_bias_ = 0;
+  int64_t overlap_left_size_ = 0;
+  int64_t overlap_right_size_ = 0;
+  int64_t w_dimension_shard_num_ = 1;
 
  private:
   Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);