fix bug for conv3d gen strategy

This commit is contained in:
yangzhenzhang 2023-03-02 14:51:58 +08:00
parent ea4bdf86dd
commit 4b9b8435ea
4 changed files with 27 additions and 5 deletions

View File

@ -1015,8 +1015,14 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
Shape tmp_shape = inputs_shape_[0];
if (name_.find(CONV2D_INFO) != std::string::npos) { // conv2d: ((N, C-in, H, W), (C-out, C-in, k1, k2))
tmp_shape.push_back(inputs_shape_[1][0]); // the tmp shape is (N, C-in, H, W, C-out)
} else { // conv2d-transpose: ((N, C-out, H, W), (C-out, C-in, k1, k2))
} else if (name_.find(CONV2D_TRANSPOSE) !=
std::string::npos) { // conv2d-transpose: ((N, C-out, H, W), (C-out, C-in, k1, k2))
tmp_shape.push_back(inputs_shape_[1][1]); // the tmp shape is (N, C-out, H, W, C-in)
} else if (name_.find(CONV3D_INFO) != std::string::npos) { // conv3d
tmp_shape.pop_back();
tmp_shape.push_back(inputs_shape_[1][0]);
} else {
MS_LOG(EXCEPTION) << name_ << ": It does not support to generate strategies";
}
Shapes tmp_inputs_shape = {tmp_shape};
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {

View File

@ -938,7 +938,7 @@ std::vector<StrategyPtr> MaxPool3DInfo::GenerateOpStrategies(int64_t stage_id) {
return sp_vector;
}
Shapes splittable_input = {{1, 1, 1, 1, 1}};
Shapes splittable_input = {{1, 1, 1, 1, 0}};
Shapes tmp_inputs_shape = inputs_shape_;
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";

View File

@ -790,7 +790,7 @@ bool IsSplittableOperator(const std::string &op_name) {
L2_LOSS, LERP, ADDN, CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE,
CHECK_VALID, INVERT, SCATTER_ADD, SCATTER_DIV, SCATTER_MUL, SCATTER_MAX, SCATTER_MIN, SCATTER_SUB, UNIQUE_WITH_PAD,
POPULATION_COUNT, IDENTITY, BESSELI0, BESSELI1, BESSELJ0, BESSELJ1, CUM_MAX, CUM_MIN, HYPOT, IGAMMA, IGAMMAC,
LEFT_SHIFT, RIGHT_SHIFT, NEXT_AFTER, ZETA, REVERSEV2, LGAMMA, TRUNC, BETAINC, GCD, CHOLESKY, MAXPOOL_3D,
LEFT_SHIFT, RIGHT_SHIFT, NEXT_AFTER, ZETA, REVERSEV2, LGAMMA, TRUNC, BETAINC, GCD, CHOLESKY, CONV3D, MAXPOOL_3D,
AVGPOOL_3D};
// clang-format on

View File

@ -1,4 +1,4 @@
# Copyright 2022 Huawei Technologies Co., Ltd
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -144,3 +144,19 @@ def test_conv3d_valid_mode_output_shape_cannot_div_by_strategy():
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net, _x3, _b)
def test_conv3d_pad_mode_unet_3d_auto_rank0():
"""
Feature: test pad mode unet 3d
Description: sharding propagation
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,
search_mode="sharding_propagation")
strategy2 = ((1, 1, 2, 4, 1),)
net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="pad", stride=2, pad=1, strategy2=strategy2)
phase = compile_net(net, _x4, _b)
validator = ParallelValidator(net, phase)
assert validator.check_node_attrs('NeighborExchangeV2-0', {'send_lens': '[0, 1, 0, 1]'})
assert validator.check_node_attrs('NeighborExchangeV2-0', {'recv_lens': '[0, 0, 0, 0]'})