!20490 update check strategy for conv2d
Merge pull request !20490 from yangzhenzhang/update-check-strategy-for-conv2d
This commit is contained in:
commit
6061194083
|
@ -124,7 +124,29 @@ Status Conv2DInfo::GetAttrsBase() {
|
|||
|
||||
Status Conv2DInfo::GetAttrs() { return GetAttrsBase(); }
|
||||
|
||||
Status Conv2DInfo::CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) {
|
||||
if (outputs_shape_[0][2] % h_strategy != 0) {
|
||||
MS_LOG(ERROR) << name_
|
||||
<< ": Do not support to split h dimension when out_shape of h dimension is not divisible by strategy "
|
||||
"of h dimension";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (outputs_shape_[0][3] % w_strategy != 0) {
|
||||
MS_LOG(ERROR) << name_
|
||||
<< ": Do not support to split w dimension when out_shape of w dimension is not divisible by strategy "
|
||||
"of w dimension";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
|
||||
if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (pad_mode_ == 0) { // 'pad' mode
|
||||
MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W";
|
||||
return FAILED;
|
||||
|
@ -642,6 +664,10 @@ Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
}
|
||||
|
||||
Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
|
||||
if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (pad_mode_ != 1) { // only support same mode
|
||||
MS_LOG(ERROR) << name_ << ": Do not support the pad mode " << pad_mode_ << " when split H or W dimension";
|
||||
return FAILED;
|
||||
|
@ -649,18 +675,18 @@ Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_st
|
|||
|
||||
if (h_strategy > 1) {
|
||||
if (inputs_shape_[0][2] * stride_[2] != outputs_shape_[0][2]) {
|
||||
MS_LOG(ERROR) << name_ << ": Do not support split h dimension when in_shape * stride != out_shape";
|
||||
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when in_shape * stride != out_shape";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (kernel_size_[0] > stride_[2]) {
|
||||
MS_LOG(ERROR) << name_ << ": Do not support split h dimension when kernel size larger than stride";
|
||||
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when kernel size larger than stride";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
if (w_strategy > 1 && inputs_shape_[0][3] * stride_[3] != outputs_shape_[0][3]) {
|
||||
MS_LOG(ERROR) << name_ << ": Do not support split w dimension when in_shape * stride != out_shape";
|
||||
MS_LOG(ERROR) << name_ << ": Do not support to split w dimension when in_shape * stride != out_shape";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
|
|
|
@ -46,6 +46,7 @@ class Conv2DInfo : public OperatorInfo {
|
|||
Status GetAttrsBase();
|
||||
Status GetAttrs() override;
|
||||
Status CheckStrategyBase(const StrategyPtr &strategy);
|
||||
Status CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy);
|
||||
Status CheckStrategy(const StrategyPtr &strategy) override;
|
||||
Status InferForwardCommunication() override;
|
||||
Status InferDevMatrixShape() override;
|
||||
|
@ -117,10 +118,10 @@ class Conv2DBackpropInputInfo : public Conv2DInfo {
|
|||
Status InferTensorMap() override;
|
||||
Status InferMirrorOps() override; // can not use OperatorInfo::InferMirrorOps(), since the 'out_shape' is not tensor
|
||||
|
||||
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
|
||||
void InferNewPadList();
|
||||
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
|
||||
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);
|
||||
Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) override;
|
||||
void InferNewPadList() override;
|
||||
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) override;
|
||||
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias) override;
|
||||
|
||||
private:
|
||||
Shape out_shape_;
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
|
@ -72,3 +73,12 @@ def test_conv2d_model_parallel2():
|
|||
strategy2 = ((32, 1, 1, 1),)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_conv2d_output_can_not_divisible_by_strategy():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
|
||||
strategy2 = ((1, 1, 1, 8),)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
|
Loading…
Reference in New Issue