forked from mindspore-Ecosystem/mindspore
!31041 add check for conv2d
Merge pull request !31041 from yangzhenzhang/add-check-for-conv2d
This commit is contained in:
commit
216e7c6a92
|
@ -302,12 +302,12 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
}
|
||||
}
|
||||
|
||||
// if the h/w dimension is split, need to exchange overlap
|
||||
if (input_strategy[2] > 1) {
|
||||
// if the h/w dimension is split, and the pad mode is not "valid", need to exchange overlap
|
||||
if (input_strategy[2] > 1 && pad_mode_ != 2) {
|
||||
h_dim_need_exchange_overlap_ = true;
|
||||
}
|
||||
|
||||
if (input_strategy[3] > 1) {
|
||||
if (input_strategy[3] > 1 && pad_mode_ != 2) {
|
||||
w_dim_need_exchange_overlap_ = true;
|
||||
}
|
||||
return SUCCESS;
|
||||
|
@ -556,8 +556,12 @@ void Conv2DInfo::InferOverlapSizeForWDim() {
|
|||
}
|
||||
}
|
||||
|
||||
void Conv2DInfo::CheckOverlapSizeNonNegative() {
|
||||
// check h dimension
|
||||
void Conv2DInfo::CheckHDimensionOverlapSizeNonNegative() {
|
||||
if (h_dimension_shard_num_ == 1) {
|
||||
MS_LOG(INFO) << name_ << ": The h dimension is not shard";
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t h_first_rank_bottom_size = ComputeOverlapBottomSizeByRankBias(0);
|
||||
if (h_first_rank_bottom_size < 0) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": The bottom overlap size of h dimension rank bias 0 must be positive, but it is "
|
||||
|
@ -579,8 +583,13 @@ void Conv2DInfo::CheckOverlapSizeNonNegative() {
|
|||
MS_LOG(EXCEPTION) << name_ << ": The top overlap size of h dimension last rank bias must be positive, but it is "
|
||||
<< h_last_rank_top_size;
|
||||
}
|
||||
}
|
||||
|
||||
// check w dimension
|
||||
void Conv2DInfo::CheckWDimensionOverlapSizeNonNegative() {
|
||||
if (w_dimension_shard_num_ == 1) {
|
||||
MS_LOG(INFO) << name_ << ": The w dimension is not shard";
|
||||
return;
|
||||
}
|
||||
int64_t w_first_rank_right_size = ComputeOverlapRightSizeByRankBias(0);
|
||||
if (w_first_rank_right_size < 0) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": The right overlap size of w dimension rank bias 0 must be positive, but it is "
|
||||
|
@ -604,6 +613,11 @@ void Conv2DInfo::CheckOverlapSizeNonNegative() {
|
|||
}
|
||||
}
|
||||
|
||||
void Conv2DInfo::CheckOverlapSizeNonNegative() {
|
||||
CheckHDimensionOverlapSizeNonNegative();
|
||||
CheckWDimensionOverlapSizeNonNegative();
|
||||
}
|
||||
|
||||
void Conv2DInfo::InferOverlapSize() {
|
||||
InferOverlapSizeForHDim();
|
||||
InferOverlapSizeForWDim();
|
||||
|
|
|
@ -53,6 +53,8 @@ class Conv2DInfo : public OperatorInfo {
|
|||
void InferAdjacentRankInfo();
|
||||
std::vector<int64_t> GetAdjacentRankIdsAndBiases(int64_t rank_id, const std::string &dimension);
|
||||
void InferOverlapSize();
|
||||
void CheckHDimensionOverlapSizeNonNegative();
|
||||
void CheckWDimensionOverlapSizeNonNegative();
|
||||
void CheckOverlapSizeNonNegative();
|
||||
void InferOverlapSizeForHDim();
|
||||
void InferOverlapSizeForWDim();
|
||||
|
|
|
@ -100,6 +100,21 @@ def test_conv2d_pad_mode():
|
|||
compile_net(net, _x3)
|
||||
|
||||
|
||||
def test_conv2d_valid_mode_output_shape_cannot_div_by_strategy():
|
||||
"""
|
||||
Feature: test conv2d valid mode, and output shape can not div by strategy
|
||||
Description: shard w
|
||||
Expectation: compile failed
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
|
||||
strategy2 = ((1, 1, 1, 1),)
|
||||
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="valid", stride=4,
|
||||
strategy1=strategy1, strategy2=strategy2)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net, _x3)
|
||||
|
||||
|
||||
def test_conv2d_data_parallel_invalid_stride():
|
||||
"""
|
||||
Feature: test conv2d invalid stride
|
||||
|
|
Loading…
Reference in New Issue