update field split

This commit is contained in:
yangzhenzhang 2020-08-12 19:16:18 +08:00
parent a7556d874d
commit 4a0e6ff7fc
7 changed files with 343 additions and 77 deletions

View File

@ -44,14 +44,15 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
auto device_arrangement = tensor_layout->device_arrangement().array();
auto tensor_map = tensor_layout->tensor_map().array();
auto slice_shape = tensor_layout->slice_shape().array();
int32_t _field_size = tensor_layout->get_field_size();
Shape field_size;
if (_field_size != 0) {
field_size.push_back(_field_size);
Shape field_size = {tensor_layout->get_field_size()};
Shape uniform_split;
if (tensor_layout->uniform_split()) {
uniform_split.push_back(1);
} else {
field_size = {0};
uniform_split.push_back(0);
}
std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size};
std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size, uniform_split};
dict[py::str(name)] = layout;
MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
}

View File

@ -27,6 +27,92 @@
namespace mindspore {
namespace parallel {
Status GatherV2PInfo::GetManualSplitWithoutOffsetAttr() {
auto manual_split_without_offset_iter = attrs_.find("manual_split");
if (manual_split_without_offset_iter != attrs_.end()) {
manual_split_ = true;
MS_EXCEPTION_IF_NULL(manual_split_without_offset_iter->second);
if (manual_split_without_offset_iter->second->cast<ValueTuplePtr>() == nullptr) {
MS_LOG(ERROR) << name_ << ": Manual split without offset strategy's format is wrong! Need ValueSequeue";
return FAILED;
}
std::vector<ValuePtr> value_vector = manual_split_without_offset_iter->second->cast<ValueTuplePtr>()->value();
MS_LOG(INFO) << name_ << ": manual split with offset is " << manual_split_without_offset_iter->second->ToString();
int64_t offset = 0;
for (auto &ele : value_vector) {
index_offsets_.push_back(offset);
if (!ele->isa<Int32Imm>()) {
MS_LOG(ERROR) << name_ << ": The element of manual split must be int";
return FAILED;
}
int64_t param_split_shape = static_cast<int64_t>(GetValue<int>(ele));
if (param_split_shape <= 0) {
MS_LOG(ERROR) << name_ << ": The value of manual split must be positive, but got " << param_split_shape;
return FAILED;
}
param_split_shapes_.push_back(param_split_shape);
offset += param_split_shape;
}
if (param_split_shapes_.empty()) {
MS_LOG(ERROR) << name_ << ": Failed to extract param split's split info";
return FAILED;
}
}
return SUCCESS;
}
Status GatherV2PInfo::GetManualSplitAttr() {
auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
if (manual_split_with_offset_iter != attrs_.end()) {
manual_split_ = true;
auto var = manual_split_with_offset_iter->second->cast<ValueTuplePtr>();
if (var == nullptr) {
MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
return FAILED;
}
MS_LOG(INFO) << name_ << ": manual split with offset strategy " << var->ToString();
for (auto &ele : var->value()) {
if (!ele->isa<ValueSequeue>()) {
MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
return FAILED;
}
std::vector<ValuePtr> value_vector = ele->cast<ValueTuplePtr>()->value();
if (value_vector.size() != 2) {
MS_LOG(ERROR) << name_ << ": Size of manual split with offset's element must be 2";
return FAILED;
}
int64_t param_split_row = static_cast<int64_t>(GetValue<int>(value_vector[0]));
int64_t offset = static_cast<int64_t>(GetValue<int>(value_vector[1]));
if ((param_split_row <= 0) || (offset < 0)) {
MS_LOG(ERROR) << name_
<< ": The value of param split shape must be positive, and the offset must larger or equal to 0";
return FAILED;
}
param_split_shapes_.push_back(param_split_row);
index_offsets_.push_back(offset);
}
if (param_split_shapes_.empty()) {
MS_LOG(ERROR) << name_ << ": Failed to extract param split with offset's split info";
return FAILED;
}
if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
MS_LOG(ERROR) << name_ << ": Index offset must not less than 0";
return FAILED;
}
return SUCCESS;
}
if (GetManualSplitWithoutOffsetAttr() != SUCCESS) {
return FAILED;
}
return SUCCESS;
}
Status GatherV2PInfo::GetAttrs() {
// get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
if (target_ != CPU) {
@ -53,58 +139,76 @@ Status GatherV2PInfo::GetAttrs() {
if (target_iter->second->isa<StringImm>()) {
target_ = target_iter->second->cast<StringImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
}
}
auto manual_split_iter = attrs_.find("manual_split");
if (manual_split_iter != attrs_.end()) {
param_split_shapes_.clear();
manual_split_ = true;
auto var = manual_split_iter->second->cast<ValueTuplePtr>();
MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString();
if (var->size() > 0) {
std::vector<ValuePtr> elements = var->value();
for (auto &ele : elements) {
if (ele->isa<ValueSequeue>()) {
auto value_tuple = ele->cast<ValueTuplePtr>();
std::vector<ValuePtr> value_vector = value_tuple->value();
if (value_vector.size() != 2) {
MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2.";
return FAILED;
}
param_split_shapes_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[0])));
index_offsets_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[1])));
} else {
MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue";
return FAILED;
MS_LOG(ERROR) << name_ << ": The value of target is not a string.";
}
}
if (param_split_shapes_.empty()) {
MS_LOG(ERROR) << "Failed to extract param split strategy.";
if (GetManualSplitAttr() != SUCCESS) {
return FAILED;
}
}
}
if (manual_split_ && (axis_ != 0)) {
MS_LOG(ERROR) << name_ << ": The axis or offset must be 0 if manual split, bug got " << axis_;
return FAILED;
}
return SUCCESS;
}
Status GatherV2PInfo::CheckManualSplit() {
auto param_shape = inputs_shape_.at(0);
Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
if (strategy.size() != 2) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
return FAILED;
}
Dimensions param_strategy = strategy[0];
Dimensions indices_strategy = strategy[1];
if (param_strategy.size() != 2 || indices_strategy.size() != 2) {
MS_LOG(ERROR) << name_ << ": The size of param strategy or indices strategy must be 2";
return FAILED;
}
if (indices_strategy[0] != 1) {
MS_LOG(ERROR) << name_ << ": The indices_strategy[0] must be 1, bug got " << indices_strategy[0];
return FAILED;
}
if (param_strategy[0] != indices_strategy[1]) {
MS_LOG(ERROR) << name_ << ": The param_strategy[0] must be equal to indices_strategy[1]";
return FAILED;
}
if (indices_strategy[1] != SizeToInt(param_split_shapes_.size())) {
MS_LOG(ERROR) << name_ << ": The indices_strategy[1] must be equal to manual split size";
return FAILED;
}
int64_t min_param_slice_row = inputs_shape_[1][1] / indices_strategy[1];
bool invalid = std::any_of(param_split_shapes_.begin(), param_split_shapes_.end(),
[&min_param_slice_row](int64_t v) { return v < min_param_slice_row; });
if (invalid) {
MS_LOG(ERROR) << name_ << ": The split value must be larger than or equal to indices slice's column num";
return FAILED;
}
if (inputs_shape_[0][0] < inputs_shape_[1][1]) {
MS_LOG(ERROR) << name_ << ": The param's row smaller than indices' column";
return FAILED;
}
// Don't support repeated calc
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
if (IntToSize(product_p) < dev_num) {
MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
return FAILED;
}
int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
[](int64_t s, int64_t shape) { return s + shape; });
if (split_shape_sum < param_shape.at(0)) {
MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape.";
if (split_shape_sum != inputs_shape_[0][0]) {
MS_LOG(ERROR) << name_ << ": Sum of splited shapes must be equal to param_shape[0]";
return FAILED;
}
if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
MS_LOG(ERROR) << "Failure: Index offset must not less than 0.";
return FAILED;
}
return SUCCESS;
}
@ -147,7 +251,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
}
if (manual_split_) {
if (CheckManualSplit() != SUCCESS) {
if (CheckManualSplit(strategy->GetInputDim()) != SUCCESS) {
return FAILED;
}
// when using manual_split, no need to check belowings.
@ -343,14 +447,15 @@ Status GatherV2PInfo::InferTensorInfo() {
SUCCESS)) {
return FAILED;
}
if (manual_split_) {
input_tensor_layout.set_uniform_split(false);
}
// infer tensor info
TensorInfo input_tensor_info(input_tensor_layout);
TensorInfo input_index_info(input_index_layout);
TensorInfo output_tensor_info(output_tensor_layout);
Shape slice_shape = input_tensor_info.slice_shape();
MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape);
inputs_tensor_info_.push_back(input_tensor_info);
inputs_tensor_info_.push_back(input_index_info);
outputs_tensor_info_.push_back(output_tensor_info);
@ -392,9 +497,17 @@ Status GatherV2PInfo::InferBias() {
Status GatherV2PInfo::InferOffset() {
CheckGlobalDeviceManager();
size_t rank = g_device_manager->global_rank();
if (rank < index_offsets_.size()) {
index_offset_ = index_offsets_.at(rank);
MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
MS_EXCEPTION_IF_NULL(strategy_);
auto param_strategy = strategy_->GetInputDim()[0];
if (param_strategy.size() != 2) {
MS_LOG(ERROR) << "The size of param strategy must be 2";
return FAILED;
}
size_t index = rank / param_strategy[1];
if (index < index_offsets_.size()) {
index_offset_ = index_offsets_[index];
MS_LOG(INFO) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
return SUCCESS;
}
@ -524,8 +637,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
if (manual_split_ && target_ != CPU) {
if (ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
return nullptr;
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
}
return replace_graph_;
}
@ -536,8 +648,7 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
return nullptr;
}
if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
return nullptr;
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
}
return replace_graph_;
}
@ -614,6 +725,13 @@ Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
}
Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
if (GetAttrs() != SUCCESS) {
return FAILED;
}
if (manual_split_) {
MS_LOG(ERROR) << name_ << ": Manual split does not support to search strategy";
return FAILED;
}
is_auto_parallel_ = true;
Shape input0_split(inputs_shape_[0].size(), 1);
Shape input1_split(inputs_shape_[1].size(), 1);
@ -621,14 +739,14 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed.";
MS_LOG(ERROR) << name_ << ": Generate strategies for independent inputs() failed.";
return FAILED;
}
size_t success = 0;
for (auto &sp : sp_vector) {
if (SetCostUnderStrategy(sp) == SUCCESS) {
success++;
MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy";
PrintStrategy(sp);
}
}
@ -636,6 +754,12 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
}
std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
}
if (manual_split_) {
MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to generate batch strategy";
}
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
Dimensions param_strategy(inputs_shape_[0].size(), 1);

View File

@ -59,7 +59,9 @@ class GatherV2PInfo : public OperatorInfo {
Status GetAttrs() override;
Status ComputeReplaceGraph(const CNodePtr &cnode);
Status CheckManualSplit();
Status CheckManualSplit(const Strategys &strategy);
Status GetManualSplitAttr();
Status GetManualSplitWithoutOffsetAttr();
Status ComputeReplaceOp();
Status InferBias();
Status InferOffset();

View File

@ -48,6 +48,10 @@ class TensorLayout {
void set_field_size(int32_t field_size) { field_size_ = field_size; }
bool uniform_split() const { return uniform_split_; }
void set_uniform_split(bool flag) { uniform_split_ = flag; }
Arrangement device_arrangement() const { return device_arrangement_; }
Map tensor_map() const { return tensor_map_; }
@ -104,6 +108,7 @@ class TensorLayout {
Arrangement tensor_shape_;
bool skip_redistribution_ = false;
int32_t field_size_ = 0;
bool uniform_split_ = true;
};
} // namespace parallel
} // namespace mindspore

View File

@ -229,10 +229,13 @@ def _load_tensor_by_layout(tensor, layout):
"""
if not isinstance(layout, list):
raise TypeError("The layout should be list! layout is {}".format(layout))
if len(layout) < 3:
raise ValueError("The length of layout must be larger than 3! layout is {}".format(layout))
if len(layout) < 5:
raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
dev_mat = layout[0]
tensor_map = layout[1]
uniform_split = layout[4]
if uniform_split[0] == 0:
raise RuntimeError("The load tensor only support uniform split now")
if tensor.size() == 1:
return tensor
return _load_tensor(tensor, dev_mat, tensor_map)

View File

@ -49,8 +49,8 @@ def test_get_parameter_layout():
net.set_auto_parallel()
exe = me._executor
exe.compile(net, x, phase='train', auto_parallel_mode=True)
x_layout = [[2, 4], [1, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [1, -1]
weight_layout = [[2, 4], [0, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [0, -1]
x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1]
weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1]
expect_dict = {'x': x_layout, 'w1': weight_layout}
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
assert net.parameter_layout_dict == expect_dict

View File

@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
@ -22,40 +23,170 @@ from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
class Net(Cell):
def __init__(self, strategy1=None, strategy2=None, strategy3=None):
def __init__(self,
strategy1=None,
strategy2=None,
strategy3=None,
axis=0,
init_flag=True,
split_tuple=(4, 4),
split_string="manual_split",
param_shape=(8, 8)):
super().__init__()
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
self.gatherv2.add_prim_attr("manual_split", ((1, 0), (7, 1)))
self.gatherv2.add_prim_attr(split_string, split_tuple)
self.mul = P.Mul().set_strategy(strategy2)
self.reshape = P.Reshape()
self.matmul = P.MatMul().set_strategy(strategy3)
self.matmul.add_prim_attr("forward_reduce_scatter", True)
self.param = Parameter(initializer("ones", (8, 64), ms.float32), name="gatherv2_param")
self.mul_weight = Parameter(initializer("ones", (2, 4, 64), ms.float32), name="mul_weight")
self.matmul_weight = Parameter(initializer("ones", (256, 16), ms.float32), name="matmul_weight")
if init_flag:
self.param = Parameter(initializer("ones", param_shape, ms.float32), name="gatherv2_param")
else:
self.param = Parameter(Tensor(np.ones(param_shape), dtype=ms.float32), name="gatherv2_param")
self.mul_weight = Parameter(initializer("ones", (8, 8, 8), ms.float32), name="mul_weight")
self.matmul_weight = Parameter(initializer("ones", (64, 16), ms.float32), name="matmul_weight")
self.axis = axis
def construct(self, x, b):
out = self.gatherv2(self.param, x, 0)
out = self.gatherv2(self.param, x, self.axis)
out = self.mul(out, self.mul_weight)
out = self.reshape(out, (2, 256))
out = self.reshape(out, (8, 64))
out = self.matmul(out, self.matmul_weight)
return out
_x = Tensor(np.ones([2, 4]), dtype=ms.int32)
_x = Tensor(np.ones([8, 8]), dtype=ms.int32)
_b = Tensor(np.ones([64, 8]), dtype=ms.float32)
def compile_net(net):
context.set_context(save_graphs=True)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
_executor.compile(train_net, _x, _b)
_executor.compile(train_net, _x, _b, auto_parallel_mode=True)
context.reset_auto_parallel_context()
def test_neg_data_parallel():
context.set_context(save_graphs=True)
def test_normal_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3)
compile_net(net)
def test_normal_split2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
strategy1 = ((4, 1), (1, 4))
strategy2 = ((1, 4, 1), (1, 4, 1))
strategy3 = ((1, 4), (4, 1))
net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
compile_net(net)
def test_normal_split3():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
strategy1 = ((4, 8), (1, 4))
strategy2 = ((1, 4, 8), (1, 4, 8))
strategy3 = ((1, 32), (32, 1))
net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
compile_net(net)
def test_normal_split_with_offset():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3, split_string="manual_split_with_offset", split_tuple=((4, 0), (4, 4)))
compile_net(net)
def test_auto_parallel_error():
context.set_context(save_graphs=True)
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
net = Net()
with pytest.raises(RuntimeError):
compile_net(net)
def test_axis_error():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3, axis=1)
with pytest.raises(RuntimeError):
compile_net(net)
def test_strategy_error():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((4, 1), (8, 1))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3)
with pytest.raises(RuntimeError):
compile_net(net)
def test_strategy_error2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((4, 1), (1, 8))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3)
with pytest.raises(RuntimeError):
compile_net(net)
def test_strategy_error3():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3)
with pytest.raises(RuntimeError):
compile_net(net)
def test_strategy_error4():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 8), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3)
with pytest.raises(RuntimeError):
compile_net(net)
def test_strategy_error5():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
strategy1 = ((4, 1), (1, 4))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3)
with pytest.raises(RuntimeError):
compile_net(net)
def test_split_tuple_error():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3, split_tuple=((5, 0), (5, 5)))
with pytest.raises(RuntimeError):
compile_net(net)
def test_parameter_use_tensor_error():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3, init_flag=False)
with pytest.raises(RuntimeError):
compile_net(net)