!20490 update check strategy for conv2d

Merge pull request !20490 from yangzhenzhang/update-check-strategy-for-conv2d
i-robot 2021-07-20 01:23:08 +00:00 committed by Gitee
commit 6061194083
3 changed files with 44 additions and 7 deletions
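The core of the change is a new shared check, CheckHWStrategyBase, which rejects a parallel strategy whose H or W split factor does not evenly divide the Conv2D output shape. A minimal plain-Python paraphrase of that rule, for illustration only (the helper name and shapes below are hypothetical, not MindSpore API):

# Illustrative sketch of the new divisibility rule; shapes are NCHW and
# out_shape stands in for outputs_shape_[0] in the C++ code below.
def check_hw_strategy_base(out_shape, h_strategy, w_strategy):
    """Return True only if the H/W split factors evenly divide output H/W."""
    out_h, out_w = out_shape[2], out_shape[3]
    if out_h % h_strategy != 0:
        return False  # splitting H rejected: out_h not divisible by h_strategy
    if out_w % w_strategy != 0:
        return False  # splitting W rejected: out_w not divisible by w_strategy
    return True

# e.g. an output of shape (32, 8, 8, 6) can be split 2-way along H (8 % 2 == 0)
# but not 4-way along W (6 % 4 != 0):
print(check_hw_strategy_base((32, 8, 8, 6), 2, 4))  # False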


@@ -124,7 +124,29 @@ Status Conv2DInfo::GetAttrsBase() {
Status Conv2DInfo::GetAttrs() { return GetAttrsBase(); }
Status Conv2DInfo::CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) {
if (outputs_shape_[0][2] % h_strategy != 0) {
MS_LOG(ERROR) << name_
<< ": Do not support to split h dimension when out_shape of h dimension is not divisible by strategy "
"of h dimension";
return FAILED;
}
if (outputs_shape_[0][3] % w_strategy != 0) {
MS_LOG(ERROR) << name_
<< ": Do not support to split w dimension when out_shape of w dimension is not divisible by strategy "
"of w dimension";
return FAILED;
}
return SUCCESS;
}
Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) {
return FAILED;
}
if (pad_mode_ == 0) { // 'pad' mode
MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W";
return FAILED;
@@ -642,6 +664,10 @@ Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
}
Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) {
return FAILED;
}
if (pad_mode_ != 1) { // only support same mode
MS_LOG(ERROR) << name_ << ": Do not support the pad mode " << pad_mode_ << " when split H or W dimension";
return FAILED;
@@ -649,18 +675,18 @@ Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_st
if (h_strategy > 1) {
if (inputs_shape_[0][2] * stride_[2] != outputs_shape_[0][2]) {
MS_LOG(ERROR) << name_ << ": Do not support split h dimension when in_shape * stride != out_shape";
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when in_shape * stride != out_shape";
return FAILED;
}
if (kernel_size_[0] > stride_[2]) {
MS_LOG(ERROR) << name_ << ": Do not support split h dimension when kernel size larger than stride";
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when kernel size larger than stride";
return FAILED;
}
}
if (w_strategy > 1 && inputs_shape_[0][3] * stride_[3] != outputs_shape_[0][3]) {
MS_LOG(ERROR) << name_ << ": Do not support split w dimension when in_shape * stride != out_shape";
MS_LOG(ERROR) << name_ << ": Do not support to split w dimension when in_shape * stride != out_shape";
return FAILED;
}
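Taken together, the Conv2DBackpropInputInfo path now allows an H/W split only under the pad-mode, stride and kernel conditions shown above, in addition to the shared divisibility check. A plain-Python paraphrase for illustration (hypothetical helper, not MindSpore API; assumes the CheckHWStrategyBase check already passed):

def check_backprop_input_hw(in_shape, out_shape, stride, kernel_size,
                            h_strategy, w_strategy, pad_mode):
    # Only the 'same' pad mode supports splitting H or W at all.
    if pad_mode != "same":
        return False
    if h_strategy > 1:
        if in_shape[2] * stride[2] != out_shape[2]:
            return False  # H split needs in_h * stride_h == out_h
        if kernel_size[0] > stride[2]:
            return False  # H split needs kernel_h <= stride_h
    if w_strategy > 1 and in_shape[3] * stride[3] != out_shape[3]:
        return False  # W split needs in_w * stride_w == out_w
    return True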


@@ -46,6 +46,7 @@ class Conv2DInfo : public OperatorInfo {
Status GetAttrsBase();
Status GetAttrs() override;
Status CheckStrategyBase(const StrategyPtr &strategy);
Status CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy);
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override;
Status InferDevMatrixShape() override;
@@ -117,10 +118,10 @@ class Conv2DBackpropInputInfo : public Conv2DInfo {
Status InferTensorMap() override;
Status InferMirrorOps() override; // can not use OperatorInfo::InferMirrorOps(), since the 'out_shape' is not tensor
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
void InferNewPadList();
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) override;
void InferNewPadList() override;
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) override;
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias) override;
private:
Shape out_shape_;


@@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
import pytest
import mindspore as ms
from mindspore import context, Tensor, Parameter
@@ -72,3 +73,12 @@ def test_conv2d_model_parallel2():
strategy2 = ((32, 1, 1, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_conv2d_output_can_not_divisible_by_strategy():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
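The new test expects compilation to raise RuntimeError because, per its name, the output spatial size is not divisible by the 8-way W split. As a user-side illustration of the same arithmetic, a hypothetical pre-flight check (not part of MindSpore; the shapes are assumed for the example) could look like:

# Hypothetical helper: do the H/W factors of the first input's strategy divide
# the expected output spatial size?
def strategy_fits_output(out_h, out_w, strategy_in0):
    _, _, h_factor, w_factor = strategy_in0
    return out_h % h_factor == 0 and out_w % w_factor == 0

# Assuming a 'same'-padded, stride-2 conv on an 8x8 input (output 4x4),
# an 8-way W split like the one in the test does not fit, while a 4-way split does:
print(strategy_fits_output(4, 4, (1, 1, 1, 8)))  # False -> compile raises RuntimeError
print(strategy_fits_output(4, 4, (1, 1, 1, 4)))  # True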