!29453 support pad mode for conv2d parallel operator

Merge pull request !29453 from yangzhenzhang/support-pad-mode-for-conv2d
i-robot 2022-02-07 06:58:00 +00:00 committed by Gitee
commit a55b1b5e05
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 163 additions and 66 deletions


@@ -144,53 +144,6 @@ Status Conv2DInfo::CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) c
return SUCCESS;
}
Status Conv2DInfo::CheckHWStrategySameModeByDimension(int64_t strategy, const std::string &dimension) {
int64_t h_or_w_input_shape = 0, h_or_w_slice_shape = 0, h_or_w_kernel_size = 0, h_or_w_stride = 0;
if (dimension == H_DIMENSION) {
h_or_w_input_shape = inputs_shape_[0][2];
h_or_w_slice_shape = h_or_w_input_shape / strategy;
h_or_w_kernel_size = kernel_size_use_dilation_[0];
h_or_w_stride = stride_[2];
} else {
h_or_w_input_shape = inputs_shape_[0][3];
h_or_w_slice_shape = h_or_w_input_shape / strategy;
h_or_w_kernel_size = kernel_size_use_dilation_[1];
h_or_w_stride = stride_[3];
}
if (strategy > 1 && (h_or_w_kernel_size <= h_or_w_stride && h_or_w_slice_shape % h_or_w_stride != 0)) {
MS_LOG(ERROR) << name_ << ": The 'same' mode does not support splitting " << dimension
<< " when kernel_size_use_dilation_ <= stride but the slice shape is not divisible by the stride";
return FAILED;
}
if (strategy > 1 && (h_or_w_kernel_size > h_or_w_stride)) {
if (h_or_w_input_shape % h_or_w_stride != 0) {
MS_LOG(ERROR) << name_ << ": The 'same' mode does not support splitting " << dimension
<< " when kernel_size_use_dilation_ > stride but the input shape is not divisible by the stride";
return FAILED;
}
if (h_or_w_slice_shape <= ((h_or_w_kernel_size - h_or_w_stride + 1) / 2)) {
MS_LOG(ERROR) << name_ << ": The 'same' mode does not support splitting " << dimension
<< " when kernel_size_use_dilation_ > stride but the slice shape <= (k - s + 1) / 2";
return FAILED;
}
}
return SUCCESS;
}
Status Conv2DInfo::CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy) {
if (CheckHWStrategySameModeByDimension(h_strategy, H_DIMENSION) != SUCCESS) {
return FAILED;
}
if (CheckHWStrategySameModeByDimension(w_strategy, W_DIMENSION) != SUCCESS) {
return FAILED;
}
return SUCCESS;
}
Status Conv2DInfo::CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy) {
int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
@@ -221,18 +174,58 @@ Status Conv2DInfo::CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strate
return SUCCESS;
}
Status Conv2DInfo::CheckHWStrategyPadModeByDimension(int64_t strategy, const std::string &dimension) {
if (strategy == 1) {
return SUCCESS;
}
int64_t h_or_w_input_shape = 0, h_or_w_output_shape = 0, h_or_w_kernel_size = 0, h_or_w_stride = 0, pad_all = 0;
if (dimension == H_DIMENSION) {
h_or_w_input_shape = inputs_shape_[0][2];
h_or_w_output_shape = outputs_shape_[0][2];
h_or_w_kernel_size = kernel_size_use_dilation_[0];
h_or_w_stride = stride_[2];
pad_all = pad_list_[0] + pad_list_[1];
} else {
h_or_w_input_shape = inputs_shape_[0][3];
h_or_w_output_shape = outputs_shape_[0][3];
h_or_w_kernel_size = kernel_size_use_dilation_[1];
h_or_w_stride = stride_[3];
pad_all = pad_list_[2] + pad_list_[3];
}
if ((h_or_w_input_shape + pad_all - h_or_w_kernel_size) % h_or_w_stride != 0) {
MS_LOG(ERROR) << name_ << ": The 'pad' or 'same' mode does not support splitting " << dimension
<< " when input_shape + pad_all - k is not divisible by the stride";
return FAILED;
}
if ((h_or_w_output_shape * h_or_w_stride - h_or_w_input_shape) % strategy != 0) {
MS_LOG(ERROR) << name_ << ": The 'pad' or 'same' mode does not support splitting " << dimension
<< " when output_shape * s - input_shape is not divisible by the strategy";
return FAILED;
}
return SUCCESS;
}
Status Conv2DInfo::CheckHWStrategyPadMode(int64_t h_strategy, int64_t w_strategy) {
if (CheckHWStrategyPadModeByDimension(h_strategy, H_DIMENSION) != SUCCESS) {
return FAILED;
}
if (CheckHWStrategyPadModeByDimension(w_strategy, W_DIMENSION) != SUCCESS) {
return FAILED;
}
return SUCCESS;
}
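
The two divisibility conditions above are easiest to see with concrete numbers. Below is a minimal Python sketch of the per-dimension check (illustrative names, not a MindSpore API), exercised with the configuration used by test_conv2d_pad_mode in the test file further down:

def pad_mode_dim_shardable(input_size, kernel, stride, pad_all, shards):
    """Sketch of CheckHWStrategyPadModeByDimension for one of H/W."""
    if shards == 1:
        return True
    # 1) the sliding window must tile the padded input exactly
    if (input_size + pad_all - kernel) % stride != 0:
        return False
    output_size = (input_size + pad_all - kernel) // stride + 1
    # 2) the extra rows contributed by padding/overlap must divide
    #    evenly across the shards
    return (output_size * stride - input_size) % shards == 0

# test_conv2d_pad_mode: 16x16 input, 3x3 kernel, stride 1,
# pad (3, 3, 3, 3), strategy (1, 1, 2, 4)
assert pad_mode_dim_shardable(16, 3, 1, 3 + 3, shards=2)  # H: (16+6-3)%1==0, (20-16)%2==0
assert pad_mode_dim_shardable(16, 3, 1, 3 + 3, shards=4)  # W: (20-16)%4==0
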
Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) {
return FAILED;
}
if (pad_mode_ == 0) { // 'pad' mode
MS_LOG(ERROR) << name_ << ": The 'pad' mode does not support splitting H or W";
return FAILED;
}
if (pad_mode_ == 1) { // 'same' mode
return CheckHWStrategySameMode(h_strategy, w_strategy);
if (pad_mode_ == 0 || pad_mode_ == 1) { // 'pad' mode or 'same' mode
return CheckHWStrategyPadMode(h_strategy, w_strategy);
}
if (pad_mode_ == 2) { // 'valid' mode
@@ -309,12 +302,12 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
}
}
// kernel size larger than stride and the h/w dimension is split, need to exchange overlap
if ((kernel_size_use_dilation_[0] > stride_[2]) && (input_strategy[2] > 1)) {
// if the h/w dimension is split, need to exchange overlap
if (input_strategy[2] > 1) {
h_dim_need_exchange_overlap_ = true;
}
if ((kernel_size_use_dilation_[1] > stride_[3]) && (input_strategy[3] > 1)) {
if (input_strategy[3] > 1) {
w_dim_need_exchange_overlap_ = true;
}
return SUCCESS;
@@ -563,6 +556,54 @@ void Conv2DInfo::InferOverlapSizeForWDim() {
}
}
void Conv2DInfo::CheckOverlapSizeNonNegative() {
// check h dimension
int64_t h_first_rank_bottom_size = ComputeOverlapBottomSizeByRankBias(0);
if (h_first_rank_bottom_size < 0) {
MS_LOG(EXCEPTION) << name_ << ": The bottom overlap size of h dimension rank bias 0 must be non-negative, but it is "
<< h_first_rank_bottom_size;
}
for (int64_t h_rank_bias = 1; h_rank_bias < h_dimension_shard_num_ - 1; ++h_rank_bias) {
auto top_size = ComputeOverlapTopSizeByRankBias(h_rank_bias);
auto bottom_size = ComputeOverlapBottomSizeByRankBias(h_rank_bias);
if (top_size < 0 || bottom_size < 0) {
MS_LOG(EXCEPTION) << name_ << ": The overlap size of h dimension rank bias " << h_rank_bias
<< " must be positive, but top overlap size is " << top_size << ", bottom overlap size is "
<< bottom_size;
}
}
int64_t h_last_rank_top_size = ComputeOverlapTopSizeByRankBias(h_dimension_shard_num_ - 1);
if (h_last_rank_top_size < 0) {
MS_LOG(EXCEPTION) << name_ << ": The top overlap size of h dimension last rank bias must be non-negative, but it is "
<< h_last_rank_top_size;
}
// check w dimension
int64_t w_first_rank_right_size = ComputeOverlapRightSizeByRankBias(0);
if (w_first_rank_right_size < 0) {
MS_LOG(EXCEPTION) << name_ << ": The right overlap size of w dimension rank bias 0 must be non-negative, but it is "
<< w_first_rank_right_size;
}
for (int64_t w_rank_bias = 1; w_rank_bias < w_dimension_shard_num_ - 1; ++w_rank_bias) {
auto left_size = ComputeOverlapLeftSizeByRankBias(w_rank_bias);
auto right_size = ComputeOverlapRightSizeByRankBias(w_rank_bias);
if (left_size < 0 || right_size < 0) {
MS_LOG(EXCEPTION) << name_ << ": The overlap size of w dimension rank bias " << w_rank_bias
<< " must be positive, but left overlap size is " << left_size << ", right overlap size is "
<< right_size;
}
}
int64_t w_last_rank_left_size = ComputeOverlapLeftSizeByRankBias(w_dimension_shard_num_ - 1);
if (w_last_rank_left_size < 0) {
MS_LOG(EXCEPTION) << name_ << ": The left overlap size of w dimension last rank bias must be non-negative, but it is "
<< w_last_rank_left_size;
}
}
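
Why can an overlap go negative? With a stride larger than the kernel, the input rows a rank actually needs can end before the slice it owns. The following sketch computes halo sizes generically for one dimension, assuming even slicing and the usual 'pad' output arithmetic; it is an illustration of the idea, not the operator's exact ComputeOverlapTopSizeByRankBias / ComputeOverlapBottomSizeByRankBias formulas:

def halo_sizes(input_size, kernel, stride, pad_before, pad_after, shards, rank):
    """Top/bottom halo of input rows `rank` must fetch from its neighbors."""
    output_size = (input_size + pad_before + pad_after - kernel) // stride + 1
    out_slice = output_size // shards   # output rows owned by each rank
    in_slice = input_size // shards     # input rows owned by each rank
    # input range (original coordinates) read by this rank's output slice;
    # rows outside [0, input_size) are covered by padding, not by neighbors
    need_start = max(rank * out_slice * stride - pad_before, 0)
    need_end = min(((rank + 1) * out_slice - 1) * stride + kernel - pad_before,
                   input_size)
    own_start, own_end = rank * in_slice, (rank + 1) * in_slice
    return own_start - need_start, need_end - own_end  # (top, bottom)

# test_conv2d_pad_mode_overlap_is_negative, H dimension: input 16, kernel 4,
# stride 5, pad (3, 0), 4 shards
print(halo_sizes(16, 4, 5, 3, 0, 4, rank=0))  # (0, -3): bottom halo is negative
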
void Conv2DInfo::InferOverlapSize() {
InferOverlapSizeForHDim();
InferOverlapSizeForWDim();
@@ -575,6 +616,8 @@ void Conv2DInfo::InferOverlapSize() {
<< ", the bottom overlap size of current rank is " << overlap_bottom_size_
<< ", the bottom overlap size of top rank is " << top_rank_overlap_bottom_size_
<< ", the top overlap size of bottom rank is " << bottom_rank_overlap_top_size_;
CheckOverlapSizeNonNegative();
}
Status Conv2DInfo::InferTensorMap() {
@@ -710,6 +753,18 @@ void Conv2DInfo::InferCommunicationAttrs() {
MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the send lens is " << send_lens_
<< ", the recv rank ids is " << recv_rank_ids_ << ", the recv lens is " << recv_lens_;
for (auto &send_len : send_lens_) {
if (send_len < 0) {
MS_LOG(EXCEPTION) << name_ << ": Send len less than 0 is not supported, but it is " << send_len;
}
}
for (auto &recv_len : recv_lens_) {
if (recv_len < 0) {
MS_LOG(EXCEPTION) << name_ << ": Recv len less than 0 is not supported, but it is " << recv_len;
}
}
int64_t h_slice_shape = input_slice_shape_[2];
if (send_top_len > h_slice_shape || send_bottom_len > h_slice_shape || recv_top_len > h_slice_shape ||
recv_bottom_len > h_slice_shape) {
@@ -834,6 +889,17 @@ ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
InferNewOperatorAttrs();
int64_t all_send_lens = std::accumulate(send_lens_.begin(), send_lens_.end(), 0);
int64_t all_recv_lens = std::accumulate(recv_lens_.begin(), recv_lens_.end(), 0);
if (all_send_lens + all_recv_lens == 0) {
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
prim->set_attr(OUT_CHANNEL, MakeValue(new_out_channel_));
prim->set_attr(PAD_MODE, MakeValue(PAD));
prim->set_attr(PAD, MakeValue(new_pad_list_));
MS_LOG(INFO) << name_ << ": the send lens and recv lens are all 0, no need to exchange data";
return nullptr;
}
ComputeReplaceGraph(cnode);
return replace_graph_;
}
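
The early return added above handles the degenerate case where no halo exchange is needed: every send/recv length is zero, so the Conv2D node is kept as-is and only its out_channel, pad_mode, and pad attributes are rewritten to the per-slice values. A 1x1 kernel with stride 1 is the simplest such case; in terms of the hypothetical halo_sizes sketch above:

# every rank's output slice reads exactly its own input slice, so all
# halos (and hence all send/recv lens) are zero and ComputeReplaceGraph
# can be skipped entirely
for rank in range(4):
    assert halo_sizes(16, 1, 1, 0, 0, 4, rank) == (0, 0)
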


@@ -53,6 +53,7 @@ class Conv2DInfo : public OperatorInfo {
void InferAdjacentRankInfo();
std::vector<int64_t> GetAdjacentRankIdsAndBiases(int64_t rank_id, const std::string &dimension);
void InferOverlapSize();
void CheckOverlapSizeNonNegative();
void InferOverlapSizeForHDim();
void InferOverlapSizeForWDim();
void InferNewOperatorAttrs();
@@ -142,9 +143,9 @@
virtual int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
private:
Status CheckHWStrategySameModeByDimension(int64_t strategy, const std::string &dimension);
Status CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy);
Status CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy);
Status CheckHWStrategyPadModeByDimension(int64_t strategy, const std::string &dimension);
Status CheckHWStrategyPadMode(int64_t h_strategy, int64_t w_strategy);
};
class Conv2DBackpropInputInfo : public Conv2DInfo {


@@ -23,11 +23,11 @@ from mindspore.ops import operations as P
class Net(Cell):
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, dilation=1, group=1,
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, dilation=1, group=1, pad=0,
strategy1=None, strategy2=None):
super().__init__()
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride, dilation=dilation, group=group).shard(strategy1)
self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size, pad_mode=pad_mode, pad=pad,
stride=stride, dilation=dilation, group=group).shard(strategy1)
self.neg = P.Neg().shard(strategy2)
self.conv2d_weight = Parameter(conv2d_weight, "w1")
@@ -39,11 +39,13 @@ class Net(Cell):
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_x2 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
_x3 = Tensor(np.ones([32, 16, 16, 16]), dtype=ms.float32)
_w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
_w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
_w4 = Tensor(np.ones([8, 8, 2, 2]), dtype=ms.float32)
_w5 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
@@ -69,11 +71,40 @@ def test_conv2d_data_parallel():
compile_net(net)
def test_conv2d_pad_mode_overlap_is_negative():
"""
Feature: test conv2d pad mode and overlap is negative
Description: shard h/w
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((1, 1, 4, 4), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 1),)
net = Net(_w5, out_channel=8, kernel_size=4, pad_mode="pad", stride=5, pad=(3, 0, 3, 0),
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net, _x3)
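
Note that this configuration passes both divisibility checks on H and W ((16 + 3 - 4) % 5 == 0 and (4 * 5 - 16) % 4 == 0), so compilation fails only later: as the halo sketch above shows, rank 0 needs input rows [0, 1) on H but owns [0, 4), i.e. a bottom overlap of -3, and CheckOverlapSizeNonNegative raises the RuntimeError this test expects.
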
def test_conv2d_pad_mode():
"""
Feature: test conv2d pad mode and overlap is non-negative
Description: shard h/w
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 2, 4), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 1),)
net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="pad", stride=1, pad=(3, 3, 3, 3),
strategy1=strategy1, strategy2=strategy2)
compile_net(net, _x3)
def test_conv2d_data_parallel_invalid_stride():
"""
Feature: test conv2d invalid stride
Description: the first two elements of stride must be 1, but set 2
Expectation: compile success
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
@@ -416,14 +447,13 @@ def test_conv2d_same_mode_overlap_size_equal_to_slice_shape():
"""
Feature: same mode, slice shape is equal to overlap shape
Description: split w
Expectation: compile failed
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
strategy2 = ((2, 1, 1, 4),)
net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
compile_net(net)
def test_kernel_size_larger_than_stride_and_left_pad_is_0():