!31041 add check for conv2d

Merge pull request !31041 from yangzhenzhang/add-check-for-conv2d
This commit is contained in:
i-robot 2022-03-11 01:20:52 +00:00 committed by Gitee
commit 216e7c6a92
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 37 additions and 6 deletions

View File

@@ -302,12 +302,12 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
}
}
// if the h/w dimension is split, need to exchange overlap
if (input_strategy[2] > 1) {
// if the h/w dimension is split, and the pad mode is not "valid", need to exchange overlap
if (input_strategy[2] > 1 && pad_mode_ != 2) {
h_dim_need_exchange_overlap_ = true;
}
if (input_strategy[3] > 1) {
if (input_strategy[3] > 1 && pad_mode_ != 2) {
w_dim_need_exchange_overlap_ = true;
}
return SUCCESS;
@@ -556,8 +556,12 @@ void Conv2DInfo::InferOverlapSizeForWDim() {
}
}
void Conv2DInfo::CheckOverlapSizeNonNegative() {
// check h dimension
void Conv2DInfo::CheckHDimensionOverlapSizeNonNegative() {
if (h_dimension_shard_num_ == 1) {
MS_LOG(INFO) << name_ << ": The h dimension is not shard";
return;
}
int64_t h_first_rank_bottom_size = ComputeOverlapBottomSizeByRankBias(0);
if (h_first_rank_bottom_size < 0) {
MS_LOG(EXCEPTION) << name_ << ": The bottom overlap size of h dimension rank bias 0 must be positive, but it is "
@@ -579,8 +583,13 @@ void Conv2DInfo::CheckOverlapSizeNonNegative() {
MS_LOG(EXCEPTION) << name_ << ": The top overlap size of h dimension last rank bias must be positive, but it is "
<< h_last_rank_top_size;
}
}
// check w dimension
void Conv2DInfo::CheckWDimensionOverlapSizeNonNegative() {
if (w_dimension_shard_num_ == 1) {
MS_LOG(INFO) << name_ << ": The w dimension is not shard";
return;
}
int64_t w_first_rank_right_size = ComputeOverlapRightSizeByRankBias(0);
if (w_first_rank_right_size < 0) {
MS_LOG(EXCEPTION) << name_ << ": The right overlap size of w dimension rank bias 0 must be positive, but it is "
@@ -604,6 +613,11 @@ void Conv2DInfo::CheckOverlapSizeNonNegative() {
}
}
// Validate that the inferred exchange-overlap sizes are non-negative for both
// sharded spatial dimensions. Each per-dimension check is a no-op when that
// dimension is not sharded, and reports a fatal error (MS_LOG(EXCEPTION)) when
// it finds a negative overlap size.
void Conv2DInfo::CheckOverlapSizeNonNegative() {
  // H is checked before W so the first diagnostic matches the original order.
  CheckHDimensionOverlapSizeNonNegative();
  CheckWDimensionOverlapSizeNonNegative();
}
void Conv2DInfo::InferOverlapSize() {
InferOverlapSizeForHDim();
InferOverlapSizeForWDim();

View File

@@ -53,6 +53,8 @@ class Conv2DInfo : public OperatorInfo {
void InferAdjacentRankInfo();
std::vector<int64_t> GetAdjacentRankIdsAndBiases(int64_t rank_id, const std::string &dimension);
void InferOverlapSize();
void CheckHDimensionOverlapSizeNonNegative();
void CheckWDimensionOverlapSizeNonNegative();
void CheckOverlapSizeNonNegative();
void InferOverlapSizeForHDim();
void InferOverlapSizeForWDim();

View File

@@ -100,6 +100,21 @@ def test_conv2d_pad_mode():
compile_net(net, _x3)
def test_conv2d_valid_mode_output_shape_cannot_div_by_strategy():
    """
    Feature: conv2d with "valid" pad mode whose output shape cannot be divided by the strategy.
    Description: shard only the w dimension of the input across all 8 devices.
    Expectation: compilation fails (RuntimeError raised by the shard check).
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    # Input is split 8 ways on the last (w) axis; the weight is not sharded.
    conv_strategy = ((1, 1, 1, 8), (1, 1, 1, 1))
    neg_strategy = ((1, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="valid", stride=4,
              strategy1=conv_strategy, strategy2=neg_strategy)
    # With kernel_size=2 / stride=4 under "valid" padding this split is invalid,
    # so compiling the net must raise.
    with pytest.raises(RuntimeError):
        compile_net(net, _x3)
def test_conv2d_data_parallel_invalid_stride():
"""
Feature: test conv2d invalid stride