diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc index fba348229c..0f9a9d4f8c 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc @@ -44,14 +44,15 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) { auto device_arrangement = tensor_layout->device_arrangement().array(); auto tensor_map = tensor_layout->tensor_map().array(); auto slice_shape = tensor_layout->slice_shape().array(); - int32_t _field_size = tensor_layout->get_field_size(); - Shape field_size; - if (_field_size != 0) { - field_size.push_back(_field_size); + Shape field_size = {tensor_layout->get_field_size()}; + Shape uniform_split; + if (tensor_layout->uniform_split()) { + uniform_split.push_back(1); } else { - field_size = {0}; + uniform_split.push_back(0); } - std::vector layout = {device_arrangement, tensor_map, slice_shape, field_size}; + + std::vector layout = {device_arrangement, tensor_map, slice_shape, field_size, uniform_split}; dict[py::str(name)] = layout; MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString(); } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc index b7e57dd1aa..02313e9402 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -27,6 +27,92 @@ namespace mindspore { namespace parallel { +Status GatherV2PInfo::GetManualSplitWithoutOffsetAttr() { + auto manual_split_without_offset_iter = attrs_.find("manual_split"); + if (manual_split_without_offset_iter != attrs_.end()) { + manual_split_ = true; + MS_EXCEPTION_IF_NULL(manual_split_without_offset_iter->second); + if (manual_split_without_offset_iter->second->cast() == nullptr) { + MS_LOG(ERROR) << name_ << ": Manual split without offset strategy's format is wrong! Need ValueSequeue"; + return FAILED; + } + std::vector value_vector = manual_split_without_offset_iter->second->cast()->value(); + MS_LOG(INFO) << name_ << ": manual split with offset is " << manual_split_without_offset_iter->second->ToString(); + + int64_t offset = 0; + for (auto &ele : value_vector) { + index_offsets_.push_back(offset); + if (!ele->isa()) { + MS_LOG(ERROR) << name_ << ": The element of manual split must be int"; + return FAILED; + } + int64_t param_split_shape = static_cast(GetValue(ele)); + if (param_split_shape <= 0) { + MS_LOG(ERROR) << name_ << ": The value of manual split must be positive, but got " << param_split_shape; + return FAILED; + } + param_split_shapes_.push_back(param_split_shape); + offset += param_split_shape; + } + if (param_split_shapes_.empty()) { + MS_LOG(ERROR) << name_ << ": Failed to extract param split's split info"; + return FAILED; + } + } + + return SUCCESS; +} + +Status GatherV2PInfo::GetManualSplitAttr() { + auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset"); + if (manual_split_with_offset_iter != attrs_.end()) { + manual_split_ = true; + auto var = manual_split_with_offset_iter->second->cast(); + if (var == nullptr) { + MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue"; + return FAILED; + } + + MS_LOG(INFO) << name_ << ": manual split with offset strategy " << var->ToString(); + for (auto &ele : var->value()) { + if (!ele->isa()) { + MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue"; + return FAILED; + } + std::vector value_vector = ele->cast()->value(); + if (value_vector.size() != 2) { + MS_LOG(ERROR) << name_ << ": Size of manual split with offset's element must be 2"; + return FAILED; + } + int64_t param_split_row = static_cast(GetValue(value_vector[0])); + int64_t offset = static_cast(GetValue(value_vector[1])); + if ((param_split_row <= 0) || (offset < 0)) { + MS_LOG(ERROR) << name_ + << ": The value of param split shape must be positive, and the offset must larger or equal to 0"; + return FAILED; + } + param_split_shapes_.push_back(param_split_row); + index_offsets_.push_back(offset); + } + + if (param_split_shapes_.empty()) { + MS_LOG(ERROR) << name_ << ": Failed to extract param split with offset's split info"; + return FAILED; + } + if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) { + MS_LOG(ERROR) << name_ << ": Index offset must not less than 0"; + return FAILED; + } + return SUCCESS; + } + + if (GetManualSplitWithoutOffsetAttr() != SUCCESS) { + return FAILED; + } + + return SUCCESS; +} + Status GatherV2PInfo::GetAttrs() { // get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis. if (target_ != CPU) { @@ -53,58 +139,76 @@ Status GatherV2PInfo::GetAttrs() { if (target_iter->second->isa()) { target_ = target_iter->second->cast()->value(); } else { - MS_LOG(ERROR) << name_ << " : The value of target is not a string."; - } - } - auto manual_split_iter = attrs_.find("manual_split"); - if (manual_split_iter != attrs_.end()) { - param_split_shapes_.clear(); - manual_split_ = true; - auto var = manual_split_iter->second->cast(); - MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString(); - - if (var->size() > 0) { - std::vector elements = var->value(); - for (auto &ele : elements) { - if (ele->isa()) { - auto value_tuple = ele->cast(); - std::vector value_vector = value_tuple->value(); - if (value_vector.size() != 2) { - MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2."; - return FAILED; - } - param_split_shapes_.push_back(static_cast(GetValue(value_vector[0]))); - index_offsets_.push_back(static_cast(GetValue(value_vector[1]))); - } else { - MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue"; - return FAILED; - } - } - - if (param_split_shapes_.empty()) { - MS_LOG(ERROR) << "Failed to extract param split strategy."; - return FAILED; - } + MS_LOG(ERROR) << name_ << ": The value of target is not a string."; } } + if (GetManualSplitAttr() != SUCCESS) { + return FAILED; + } + + if (manual_split_ && (axis_ != 0)) { + MS_LOG(ERROR) << name_ << ": The axis or offset must be 0 if manual split, bug got " << axis_; + return FAILED; + } return SUCCESS; } -Status GatherV2PInfo::CheckManualSplit() { - auto param_shape = inputs_shape_.at(0); +Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) { + if (strategy.size() != 2) { + MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size(); + return FAILED; + } + Dimensions param_strategy = strategy[0]; + Dimensions indices_strategy = strategy[1]; + if (param_strategy.size() != 2 || indices_strategy.size() != 2) { + MS_LOG(ERROR) << name_ << ": The size of param strategy or indices strategy must be 2"; + return FAILED; + } + + if (indices_strategy[0] != 1) { + MS_LOG(ERROR) << name_ << ": The indices_strategy[0] must be 1, bug got " << indices_strategy[0]; + return FAILED; + } + + if (param_strategy[0] != indices_strategy[1]) { + MS_LOG(ERROR) << name_ << ": The param_strategy[0] must be equal to indices_strategy[1]"; + return FAILED; + } + + if (indices_strategy[1] != SizeToInt(param_split_shapes_.size())) { + MS_LOG(ERROR) << name_ << ": The indices_strategy[1] must be equal to manual split size"; + return FAILED; + } + + int64_t min_param_slice_row = inputs_shape_[1][1] / indices_strategy[1]; + bool invalid = std::any_of(param_split_shapes_.begin(), param_split_shapes_.end(), + [&min_param_slice_row](int64_t v) { return v < min_param_slice_row; }); + if (invalid) { + MS_LOG(ERROR) << name_ << ": The split value must be larger than or equal to indices slice's column num"; + return FAILED; + } + + if (inputs_shape_[0][0] < inputs_shape_[1][1]) { + MS_LOG(ERROR) << name_ << ": The param's row smaller than indices' column"; + return FAILED; + } + + // Don't support repeated calc + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); + if (IntToSize(product_p) < dev_num) { + MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc"; + return FAILED; + } + int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0, [](int64_t s, int64_t shape) { return s + shape; }); - if (split_shape_sum < param_shape.at(0)) { - MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape."; + if (split_shape_sum != inputs_shape_[0][0]) { + MS_LOG(ERROR) << name_ << ": Sum of splited shapes must be equal to param_shape[0]"; return FAILED; } - - if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) { - MS_LOG(ERROR) << "Failure: Index offset must not less than 0."; - return FAILED; - } - return SUCCESS; } @@ -147,7 +251,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { } if (manual_split_) { - if (CheckManualSplit() != SUCCESS) { + if (CheckManualSplit(strategy->GetInputDim()) != SUCCESS) { return FAILED; } // when using manual_split, no need to check belowings. @@ -343,14 +447,15 @@ Status GatherV2PInfo::InferTensorInfo() { SUCCESS)) { return FAILED; } + + if (manual_split_) { + input_tensor_layout.set_uniform_split(false); + } // infer tensor info TensorInfo input_tensor_info(input_tensor_layout); TensorInfo input_index_info(input_index_layout); TensorInfo output_tensor_info(output_tensor_layout); - Shape slice_shape = input_tensor_info.slice_shape(); - MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape); - inputs_tensor_info_.push_back(input_tensor_info); inputs_tensor_info_.push_back(input_index_info); outputs_tensor_info_.push_back(output_tensor_info); @@ -392,9 +497,17 @@ Status GatherV2PInfo::InferBias() { Status GatherV2PInfo::InferOffset() { CheckGlobalDeviceManager(); size_t rank = g_device_manager->global_rank(); - if (rank < index_offsets_.size()) { - index_offset_ = index_offsets_.at(rank); - MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_; + + MS_EXCEPTION_IF_NULL(strategy_); + auto param_strategy = strategy_->GetInputDim()[0]; + if (param_strategy.size() != 2) { + MS_LOG(ERROR) << "The size of param strategy must be 2"; + return FAILED; + } + size_t index = rank / param_strategy[1]; + if (index < index_offsets_.size()) { + index_offset_ = index_offsets_[index]; + MS_LOG(INFO) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_; return SUCCESS; } @@ -524,8 +637,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { if (manual_split_ && target_ != CPU) { if (ComputeReplaceGraph(cnode) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; - return nullptr; + MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; } return replace_graph_; } @@ -536,8 +648,7 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { return nullptr; } if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; - return nullptr; + MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; } return replace_graph_; } @@ -614,6 +725,13 @@ Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { } Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { + if (GetAttrs() != SUCCESS) { + return FAILED; + } + if (manual_split_) { + MS_LOG(ERROR) << name_ << ": Manual split does not support to search strategy"; + return FAILED; + } is_auto_parallel_ = true; Shape input0_split(inputs_shape_[0].size(), 1); Shape input1_split(inputs_shape_[1].size(), 1); @@ -621,14 +739,14 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { std::vector sp_vector; if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; + MS_LOG(ERROR) << name_ << ": Generate strategies for independent inputs() failed."; return FAILED; } size_t success = 0; for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy"; PrintStrategy(sp); } } @@ -636,6 +754,12 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { } std::shared_ptr GatherV2PInfo::GenerateBatchStrategies() { + if (GetAttrs() != SUCCESS) { + MS_LOG(EXCEPTION) << name_ << ": Get attr failed"; + } + if (manual_split_) { + MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to generate batch strategy"; + } CheckGlobalDeviceManager(); size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); Dimensions param_strategy(inputs_shape_[0].size(), 1); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h index 42c8fee9e4..bfbb6b092c 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h @@ -59,7 +59,9 @@ class GatherV2PInfo : public OperatorInfo { Status GetAttrs() override; Status ComputeReplaceGraph(const CNodePtr &cnode); - Status CheckManualSplit(); + Status CheckManualSplit(const Strategys &strategy); + Status GetManualSplitAttr(); + Status GetManualSplitWithoutOffsetAttr(); Status ComputeReplaceOp(); Status InferBias(); Status InferOffset(); diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h index 9832964005..f8207b2d26 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h @@ -48,6 +48,10 @@ class TensorLayout { void set_field_size(int32_t field_size) { field_size_ = field_size; } + bool uniform_split() const { return uniform_split_; } + + void set_uniform_split(bool flag) { uniform_split_ = flag; } + Arrangement device_arrangement() const { return device_arrangement_; } Map tensor_map() const { return tensor_map_; } @@ -104,6 +108,7 @@ class TensorLayout { Arrangement tensor_shape_; bool skip_redistribution_ = false; int32_t field_size_ = 0; + bool uniform_split_ = true; }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/parallel/_tensor.py b/mindspore/parallel/_tensor.py index 598046f66a..3d65521b43 100644 --- a/mindspore/parallel/_tensor.py +++ b/mindspore/parallel/_tensor.py @@ -229,10 +229,13 @@ def _load_tensor_by_layout(tensor, layout): """ if not isinstance(layout, list): raise TypeError("The layout should be list! layout is {}".format(layout)) - if len(layout) < 3: - raise ValueError("The length of layout must be larger than 3! layout is {}".format(layout)) + if len(layout) < 5: + raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout)) dev_mat = layout[0] tensor_map = layout[1] + uniform_split = layout[4] + if uniform_split[0] == 0: + raise RuntimeError("The load tensor only support uniform split now") if tensor.size() == 1: return tensor return _load_tensor(tensor, dev_mat, tensor_map) diff --git a/tests/ut/python/parallel/test_get_parameter_layout.py b/tests/ut/python/parallel/test_get_parameter_layout.py index 23649b5f0c..01248d13ae 100644 --- a/tests/ut/python/parallel/test_get_parameter_layout.py +++ b/tests/ut/python/parallel/test_get_parameter_layout.py @@ -49,8 +49,8 @@ def test_get_parameter_layout(): net.set_auto_parallel() exe = me._executor exe.compile(net, x, phase='train', auto_parallel_mode=True) - x_layout = [[2, 4], [1, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [1, -1] - weight_layout = [[2, 4], [0, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [0, -1] + x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1] + weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1] expect_dict = {'x': x_layout, 'w1': weight_layout} # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut assert net.parameter_layout_dict == expect_dict diff --git a/tests/ut/python/parallel/test_manual_gatherv2.py b/tests/ut/python/parallel/test_manual_gatherv2.py index 21d25ae720..5e41109c8c 100644 --- a/tests/ut/python/parallel/test_manual_gatherv2.py +++ b/tests/ut/python/parallel/test_manual_gatherv2.py @@ -14,6 +14,7 @@ # ============================================================================ import numpy as np +import pytest import mindspore as ms from mindspore import context, Tensor, Parameter from mindspore.common.api import _executor @@ -22,40 +23,170 @@ from mindspore.ops import operations as P from mindspore.common.initializer import initializer class Net(Cell): - def __init__(self, strategy1=None, strategy2=None, strategy3=None): + def __init__(self, + strategy1=None, + strategy2=None, + strategy3=None, + axis=0, + init_flag=True, + split_tuple=(4, 4), + split_string="manual_split", + param_shape=(8, 8)): super().__init__() self.gatherv2 = P.GatherV2().set_strategy(strategy1) - self.gatherv2.add_prim_attr("manual_split", ((1, 0), (7, 1))) + self.gatherv2.add_prim_attr(split_string, split_tuple) self.mul = P.Mul().set_strategy(strategy2) self.reshape = P.Reshape() self.matmul = P.MatMul().set_strategy(strategy3) self.matmul.add_prim_attr("forward_reduce_scatter", True) - self.param = Parameter(initializer("ones", (8, 64), ms.float32), name="gatherv2_param") - self.mul_weight = Parameter(initializer("ones", (2, 4, 64), ms.float32), name="mul_weight") - self.matmul_weight = Parameter(initializer("ones", (256, 16), ms.float32), name="matmul_weight") + if init_flag: + self.param = Parameter(initializer("ones", param_shape, ms.float32), name="gatherv2_param") + else: + self.param = Parameter(Tensor(np.ones(param_shape), dtype=ms.float32), name="gatherv2_param") + self.mul_weight = Parameter(initializer("ones", (8, 8, 8), ms.float32), name="mul_weight") + self.matmul_weight = Parameter(initializer("ones", (64, 16), ms.float32), name="matmul_weight") + self.axis = axis def construct(self, x, b): - out = self.gatherv2(self.param, x, 0) + out = self.gatherv2(self.param, x, self.axis) out = self.mul(out, self.mul_weight) - out = self.reshape(out, (2, 256)) + out = self.reshape(out, (8, 64)) out = self.matmul(out, self.matmul_weight) return out -_x = Tensor(np.ones([2, 4]), dtype=ms.int32) + +_x = Tensor(np.ones([8, 8]), dtype=ms.int32) _b = Tensor(np.ones([64, 8]), dtype=ms.float32) + def compile_net(net): + context.set_context(save_graphs=True) optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) train_net = TrainOneStepCell(net, optimizer) train_net.set_auto_parallel() - _executor.compile(train_net, _x, _b) + _executor.compile(train_net, _x, _b, auto_parallel_mode=True) context.reset_auto_parallel_context() -def test_neg_data_parallel(): - context.set_context(save_graphs=True) + +def test_normal_split(): context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0) strategy1 = ((2, 1), (1, 2)) strategy2 = ((1, 2, 1), (1, 2, 1)) strategy3 = ((1, 2), (2, 1)) net = Net(strategy1, strategy2, strategy3) compile_net(net) + + +def test_normal_split2(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0) + strategy1 = ((4, 1), (1, 4)) + strategy2 = ((1, 4, 1), (1, 4, 1)) + strategy3 = ((1, 4), (4, 1)) + net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8)) + compile_net(net) + + +def test_normal_split3(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17) + strategy1 = ((4, 8), (1, 4)) + strategy2 = ((1, 4, 8), (1, 4, 8)) + strategy3 = ((1, 32), (32, 1)) + net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8)) + compile_net(net) + + +def test_normal_split_with_offset(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0) + strategy1 = ((2, 1), (1, 2)) + strategy2 = ((1, 2, 1), (1, 2, 1)) + strategy3 = ((1, 2), (2, 1)) + net = Net(strategy1, strategy2, strategy3, split_string="manual_split_with_offset", split_tuple=((4, 0), (4, 4))) + compile_net(net) + + +def test_auto_parallel_error(): + context.set_context(save_graphs=True) + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0) + net = Net() + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_axis_error(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0) + strategy1 = ((2, 1), (1, 2)) + strategy2 = ((1, 2, 1), (1, 2, 1)) + strategy3 = ((1, 2), (2, 1)) + net = Net(strategy1, strategy2, strategy3, axis=1) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_strategy_error(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((4, 1), (8, 1)) + strategy2 = ((1, 2, 1), (1, 2, 1)) + strategy3 = ((1, 2), (2, 1)) + net = Net(strategy1, strategy2, strategy3) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_strategy_error2(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((4, 1), (1, 8)) + strategy2 = ((1, 2, 1), (1, 2, 1)) + strategy3 = ((1, 2), (2, 1)) + net = Net(strategy1, strategy2, strategy3) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_strategy_error3(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0) + strategy1 = ((2, 1), (1, 2)) + strategy2 = ((1, 2, 1), (1, 2, 1)) + strategy3 = ((1, 2), (2, 1)) + net = Net(strategy1, strategy2, strategy3) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_strategy_error4(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0) + strategy1 = ((2, 8), (1, 2)) + strategy2 = ((1, 2, 1), (1, 2, 1)) + strategy3 = ((1, 2), (2, 1)) + net = Net(strategy1, strategy2, strategy3) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_strategy_error5(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0) + strategy1 = ((4, 1), (1, 4)) + strategy2 = ((1, 2, 1), (1, 2, 1)) + strategy3 = ((1, 2), (2, 1)) + net = Net(strategy1, strategy2, strategy3) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_split_tuple_error(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0) + strategy1 = ((2, 1), (1, 2)) + strategy2 = ((1, 2, 1), (1, 2, 1)) + strategy3 = ((1, 2), (2, 1)) + net = Net(strategy1, strategy2, strategy3, split_tuple=((5, 0), (5, 5))) + with pytest.raises(RuntimeError): + compile_net(net) + + +def test_parameter_use_tensor_error(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0) + strategy1 = ((2, 1), (1, 2)) + strategy2 = ((1, 2, 1), (1, 2, 1)) + strategy3 = ((1, 2), (2, 1)) + net = Net(strategy1, strategy2, strategy3, init_flag=False) + with pytest.raises(RuntimeError): + compile_net(net)