forked from mindspore-Ecosystem/mindspore
fix bug for conv3d gen strategy
This commit is contained in:
parent
ea4bdf86dd
commit
4b9b8435ea
|
@ -1015,8 +1015,14 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
|
|||
Shape tmp_shape = inputs_shape_[0];
|
||||
if (name_.find(CONV2D_INFO) != std::string::npos) { // conv2d: ((N, C-in, H, W), (C-out, C-in, k1, k2))
|
||||
tmp_shape.push_back(inputs_shape_[1][0]); // the tmp shape is (N, C-in, H, W, C-out)
|
||||
} else { // conv2d-transpose: ((N, C-out, H, W), (C-out, C-in, k1, k2))
|
||||
} else if (name_.find(CONV2D_TRANSPOSE) !=
|
||||
std::string::npos) { // conv2d-transpose: ((N, C-out, H, W), (C-out, C-in, k1, k2))
|
||||
tmp_shape.push_back(inputs_shape_[1][1]); // the tmp shape is (N, C-out, H, W, C-in)
|
||||
} else if (name_.find(CONV3D_INFO) != std::string::npos) { // conv3d
|
||||
tmp_shape.pop_back();
|
||||
tmp_shape.push_back(inputs_shape_[1][0]);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << name_ << ": It does not support to generate strategies";
|
||||
}
|
||||
Shapes tmp_inputs_shape = {tmp_shape};
|
||||
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
|
||||
|
|
|
@ -938,7 +938,7 @@ std::vector<StrategyPtr> MaxPool3DInfo::GenerateOpStrategies(int64_t stage_id) {
|
|||
return sp_vector;
|
||||
}
|
||||
|
||||
Shapes splittable_input = {{1, 1, 1, 1, 1}};
|
||||
Shapes splittable_input = {{1, 1, 1, 1, 0}};
|
||||
Shapes tmp_inputs_shape = inputs_shape_;
|
||||
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Generate strategies failed";
|
||||
|
|
|
@ -790,7 +790,7 @@ bool IsSplittableOperator(const std::string &op_name) {
|
|||
L2_LOSS, LERP, ADDN, CDIST, SQUARED_DIFFERENCE, ERFINV, MASKED_FILL, SPLITV, GAMMA, KLDIV_LOSS, LIN_SPACE,
|
||||
CHECK_VALID, INVERT, SCATTER_ADD, SCATTER_DIV, SCATTER_MUL, SCATTER_MAX, SCATTER_MIN, SCATTER_SUB, UNIQUE_WITH_PAD,
|
||||
POPULATION_COUNT, IDENTITY, BESSELI0, BESSELI1, BESSELJ0, BESSELJ1, CUM_MAX, CUM_MIN, HYPOT, IGAMMA, IGAMMAC,
|
||||
LEFT_SHIFT, RIGHT_SHIFT, NEXT_AFTER, ZETA, REVERSEV2, LGAMMA, TRUNC, BETAINC, GCD, CHOLESKY, MAXPOOL_3D,
|
||||
LEFT_SHIFT, RIGHT_SHIFT, NEXT_AFTER, ZETA, REVERSEV2, LGAMMA, TRUNC, BETAINC, GCD, CHOLESKY, CONV3D, MAXPOOL_3D,
|
||||
AVGPOOL_3D};
|
||||
// clang-format on
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -144,3 +144,19 @@ def test_conv3d_valid_mode_output_shape_cannot_div_by_strategy():
|
|||
strategy1=strategy1, strategy2=strategy2)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net, _x3, _b)
|
||||
|
||||
|
||||
def test_conv3d_pad_mode_unet_3d_auto_rank0():
|
||||
"""
|
||||
Feature: test pad mode unet 3d
|
||||
Description: sharding propagation
|
||||
Expectation: compile success
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,
|
||||
search_mode="sharding_propagation")
|
||||
strategy2 = ((1, 1, 2, 4, 1),)
|
||||
net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="pad", stride=2, pad=1, strategy2=strategy2)
|
||||
phase = compile_net(net, _x4, _b)
|
||||
validator = ParallelValidator(net, phase)
|
||||
assert validator.check_node_attrs('NeighborExchangeV2-0', {'send_lens': '[0, 1, 0, 1]'})
|
||||
assert validator.check_node_attrs('NeighborExchangeV2-0', {'recv_lens': '[0, 0, 0, 0]'})
|
||||
|
|
Loading…
Reference in New Issue