!4356 Add validation for field split

Merge pull request !4356 from yangzhenzhang/update-field-split
mindspore-ci-bot 2020-08-13 16:43:44 +08:00 committed by Gitee
commit 2db0290c49
7 changed files with 343 additions and 77 deletions

View File

@@ -44,14 +44,15 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
     auto device_arrangement = tensor_layout->device_arrangement().array();
     auto tensor_map = tensor_layout->tensor_map().array();
     auto slice_shape = tensor_layout->slice_shape().array();
-    int32_t _field_size = tensor_layout->get_field_size();
-    Shape field_size;
-    if (_field_size != 0) {
-      field_size.push_back(_field_size);
-    } else {
-      field_size = {0};
-    }
-    std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size};
+    Shape field_size = {tensor_layout->get_field_size()};
+    Shape uniform_split;
+    if (tensor_layout->uniform_split()) {
+      uniform_split.push_back(1);
+    } else {
+      uniform_split.push_back(0);
+    }
+    std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size, uniform_split};
     dict[py::str(name)] = layout;
     MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
   }
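
For reference, the layout list that GetParameterLayout now hands back to Python has five entries instead of four. A minimal sketch of unpacking one entry of net.parameter_layout_dict (the values mirror the updated test_get_parameter_layout expectations further down; not part of this diff):

    # Sketch only: example values taken from the updated unit test below.
    layout = [[2, 4], [1, -1], [16, 32], [0], [1]]
    device_arrangement, tensor_map, slice_shape, field_size, uniform_split = layout
    assert uniform_split[0] == 1  # 1 marks a uniformly split parameter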

View File

@@ -27,6 +27,92 @@
 namespace mindspore {
 namespace parallel {
+Status GatherV2PInfo::GetManualSplitWithoutOffsetAttr() {
+  auto manual_split_without_offset_iter = attrs_.find("manual_split");
+  if (manual_split_without_offset_iter != attrs_.end()) {
+    manual_split_ = true;
+    MS_EXCEPTION_IF_NULL(manual_split_without_offset_iter->second);
+    if (manual_split_without_offset_iter->second->cast<ValueTuplePtr>() == nullptr) {
+      MS_LOG(ERROR) << name_ << ": Manual split without offset strategy's format is wrong! Need ValueSequeue";
+      return FAILED;
+    }
+    std::vector<ValuePtr> value_vector = manual_split_without_offset_iter->second->cast<ValueTuplePtr>()->value();
+    MS_LOG(INFO) << name_ << ": manual split without offset is " << manual_split_without_offset_iter->second->ToString();
+
+    int64_t offset = 0;
+    for (auto &ele : value_vector) {
+      index_offsets_.push_back(offset);
+      if (!ele->isa<Int32Imm>()) {
+        MS_LOG(ERROR) << name_ << ": The element of manual split must be int";
+        return FAILED;
+      }
+      int64_t param_split_shape = static_cast<int64_t>(GetValue<int>(ele));
+      if (param_split_shape <= 0) {
+        MS_LOG(ERROR) << name_ << ": The value of manual split must be positive, but got " << param_split_shape;
+        return FAILED;
+      }
+      param_split_shapes_.push_back(param_split_shape);
+      offset += param_split_shape;
+    }
+
+    if (param_split_shapes_.empty()) {
+      MS_LOG(ERROR) << name_ << ": Failed to extract param split's split info";
+      return FAILED;
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status GatherV2PInfo::GetManualSplitAttr() {
+  auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
+  if (manual_split_with_offset_iter != attrs_.end()) {
+    manual_split_ = true;
+    auto var = manual_split_with_offset_iter->second->cast<ValueTuplePtr>();
+    if (var == nullptr) {
+      MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
+      return FAILED;
+    }
+
+    MS_LOG(INFO) << name_ << ": manual split with offset strategy " << var->ToString();
+    for (auto &ele : var->value()) {
+      if (!ele->isa<ValueSequeue>()) {
+        MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
+        return FAILED;
+      }
+      std::vector<ValuePtr> value_vector = ele->cast<ValueTuplePtr>()->value();
+      if (value_vector.size() != 2) {
+        MS_LOG(ERROR) << name_ << ": Size of manual split with offset's element must be 2";
+        return FAILED;
+      }
+      int64_t param_split_row = static_cast<int64_t>(GetValue<int>(value_vector[0]));
+      int64_t offset = static_cast<int64_t>(GetValue<int>(value_vector[1]));
+      if ((param_split_row <= 0) || (offset < 0)) {
+        MS_LOG(ERROR) << name_
+                      << ": The value of param split shape must be positive, and the offset must be larger than or equal to 0";
+        return FAILED;
+      }
+      param_split_shapes_.push_back(param_split_row);
+      index_offsets_.push_back(offset);
+    }
+
+    if (param_split_shapes_.empty()) {
+      MS_LOG(ERROR) << name_ << ": Failed to extract param split with offset's split info";
+      return FAILED;
+    }
+    if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
+      MS_LOG(ERROR) << name_ << ": Index offset must not be less than 0";
+      return FAILED;
+    }
+    return SUCCESS;
+  }
+
+  if (GetManualSplitWithoutOffsetAttr() != SUCCESS) {
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
 Status GatherV2PInfo::GetAttrs() {
   // get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
   if (target_ != CPU) {
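
Read together with the tests at the bottom of this PR, the two attributes parsed above take the following shapes on the Python side (a sketch with illustrative values, not part of this diff):

    # Sketch only: the two accepted attribute formats; values are illustrative.
    from mindspore.ops import operations as P

    gather = P.GatherV2()
    # "manual_split": one positive row count per slice; offsets are
    # accumulated in order, so (4, 4) means rows [0, 4) and [4, 8).
    gather.add_prim_attr("manual_split", (4, 4))
    # "manual_split_with_offset": explicit (row_count, offset) pairs;
    # ((4, 0), (4, 4)) pins slice 0 at row 0 and slice 1 at row 4.
    # gather.add_prim_attr("manual_split_with_offset", ((4, 0), (4, 4)))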
@@ -53,58 +139,76 @@ Status GatherV2PInfo::GetAttrs() {
     if (target_iter->second->isa<StringImm>()) {
       target_ = target_iter->second->cast<StringImmPtr>()->value();
     } else {
-      MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
-    }
-  }
-
-  auto manual_split_iter = attrs_.find("manual_split");
-  if (manual_split_iter != attrs_.end()) {
-    param_split_shapes_.clear();
-    manual_split_ = true;
-    auto var = manual_split_iter->second->cast<ValueTuplePtr>();
-    MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString();
-    if (var->size() > 0) {
-      std::vector<ValuePtr> elements = var->value();
-      for (auto &ele : elements) {
-        if (ele->isa<ValueSequeue>()) {
-          auto value_tuple = ele->cast<ValueTuplePtr>();
-          std::vector<ValuePtr> value_vector = value_tuple->value();
-          if (value_vector.size() != 2) {
-            MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2.";
-            return FAILED;
-          }
-          param_split_shapes_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[0])));
-          index_offsets_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[1])));
-        } else {
-          MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue";
-          return FAILED;
-        }
-      }
-      if (param_split_shapes_.empty()) {
-        MS_LOG(ERROR) << "Failed to extract param split strategy.";
-        return FAILED;
-      }
-    }
+      MS_LOG(ERROR) << name_ << ": The value of target is not a string.";
+    }
   }
+
+  if (GetManualSplitAttr() != SUCCESS) {
+    return FAILED;
+  }
+
+  if (manual_split_ && (axis_ != 0)) {
+    MS_LOG(ERROR) << name_ << ": The axis or offset must be 0 if manual split, but got " << axis_;
+    return FAILED;
+  }
+
   return SUCCESS;
 }
-Status GatherV2PInfo::CheckManualSplit() {
-  auto param_shape = inputs_shape_.at(0);
+Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
+  if (strategy.size() != 2) {
+    MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
+    return FAILED;
+  }
+  Dimensions param_strategy = strategy[0];
+  Dimensions indices_strategy = strategy[1];
+  if (param_strategy.size() != 2 || indices_strategy.size() != 2) {
+    MS_LOG(ERROR) << name_ << ": The size of param strategy or indices strategy must be 2";
+    return FAILED;
+  }
+
+  if (indices_strategy[0] != 1) {
+    MS_LOG(ERROR) << name_ << ": The indices_strategy[0] must be 1, but got " << indices_strategy[0];
+    return FAILED;
+  }
+
+  if (param_strategy[0] != indices_strategy[1]) {
+    MS_LOG(ERROR) << name_ << ": The param_strategy[0] must be equal to indices_strategy[1]";
+    return FAILED;
+  }
+
+  if (indices_strategy[1] != SizeToInt(param_split_shapes_.size())) {
+    MS_LOG(ERROR) << name_ << ": The indices_strategy[1] must be equal to manual split size";
+    return FAILED;
+  }
+
+  int64_t min_param_slice_row = inputs_shape_[1][1] / indices_strategy[1];
+  bool invalid = std::any_of(param_split_shapes_.begin(), param_split_shapes_.end(),
+                             [&min_param_slice_row](int64_t v) { return v < min_param_slice_row; });
+  if (invalid) {
+    MS_LOG(ERROR) << name_ << ": The split value must be larger than or equal to indices slice's column num";
+    return FAILED;
+  }
+
+  if (inputs_shape_[0][0] < inputs_shape_[1][1]) {
+    MS_LOG(ERROR) << name_ << ": The param's row is smaller than indices' column";
+    return FAILED;
+  }
+
+  // Don't support repeated calc
+  CheckGlobalDeviceManager();
+  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+  auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
+  if (IntToSize(product_p) < dev_num) {
+    MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
+    return FAILED;
+  }
+
   int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
                                             [](int64_t s, int64_t shape) { return s + shape; });
-  if (split_shape_sum < param_shape.at(0)) {
-    MS_LOG(ERROR) << "Failure: Sum of split shapes should not be smaller than param_shape.";
+  if (split_shape_sum != inputs_shape_[0][0]) {
+    MS_LOG(ERROR) << name_ << ": Sum of split shapes must be equal to param_shape[0]";
     return FAILED;
   }
-
-  if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
-    MS_LOG(ERROR) << "Failure: Index offset must not be less than 0.";
-    return FAILED;
-  }
   return SUCCESS;
 }
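
To make the new checks concrete, here is the arithmetic for the configuration used by test_normal_split2 below (a sketch of the rules, not the C++ implementation):

    # Sketch only: CheckManualSplit's rules with test_normal_split2's numbers.
    param_shape, indices_shape = (64, 8), (8, 8)
    param_strategy, indices_strategy = (4, 1), (1, 4)
    split_tuple = (10, 20, 30, 4)
    dev_num = 4

    assert indices_strategy[0] == 1
    assert param_strategy[0] == indices_strategy[1]
    assert indices_strategy[1] == len(split_tuple)
    # every split must cover at least one indices-slice column
    min_rows = indices_shape[1] // indices_strategy[1]
    assert all(v >= min_rows for v in split_tuple)
    # repeated calculation is rejected: the param strategy must fill all devices
    assert param_strategy[0] * param_strategy[1] >= dev_num
    # the splits must add up to the full parameter row count
    assert sum(split_tuple) == param_shape[0]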
@@ -147,7 +251,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   }
 
   if (manual_split_) {
-    if (CheckManualSplit() != SUCCESS) {
+    if (CheckManualSplit(strategy->GetInputDim()) != SUCCESS) {
       return FAILED;
     }
     // when using manual_split, no need to check the following.
@@ -343,14 +447,15 @@ Status GatherV2PInfo::InferTensorInfo() {
                SUCCESS)) {
     return FAILED;
   }
+  if (manual_split_) {
+    input_tensor_layout.set_uniform_split(false);
+  }
 
   // infer tensor info
   TensorInfo input_tensor_info(input_tensor_layout);
   TensorInfo input_index_info(input_index_layout);
   TensorInfo output_tensor_info(output_tensor_layout);
-  Shape slice_shape = input_tensor_info.slice_shape();
-  MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape);
 
   inputs_tensor_info_.push_back(input_tensor_info);
   inputs_tensor_info_.push_back(input_index_info);
   outputs_tensor_info_.push_back(output_tensor_info);
@@ -392,9 +497,17 @@ Status GatherV2PInfo::InferBias() {
 Status GatherV2PInfo::InferOffset() {
   CheckGlobalDeviceManager();
   size_t rank = g_device_manager->global_rank();
-  if (rank < index_offsets_.size()) {
-    index_offset_ = index_offsets_.at(rank);
-    MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
+
+  MS_EXCEPTION_IF_NULL(strategy_);
+  auto param_strategy = strategy_->GetInputDim()[0];
+  if (param_strategy.size() != 2) {
+    MS_LOG(ERROR) << "The size of param strategy must be 2";
+    return FAILED;
+  }
+  size_t index = rank / param_strategy[1];
+
+  if (index < index_offsets_.size()) {
+    index_offset_ = index_offsets_[index];
+    MS_LOG(INFO) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
     return SUCCESS;
   }
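
The rank-to-offset arithmetic above, with the numbers from test_normal_split3 below (32 devices, param strategy (4, 8), split tuple (10, 20, 30, 4)); a sketch, not part of this diff:

    # Sketch only: InferOffset's index computation for global rank 17.
    split_tuple = (10, 20, 30, 4)
    param_strategy = (4, 8)
    index_offsets = [0, 10, 30, 60]  # accumulated from split_tuple
    rank = 17
    index = rank // param_strategy[1]  # 17 // 8 == 2
    assert index_offsets[index] == 30  # rank 17 gathers with offset 30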
@@ -524,8 +637,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
   if (manual_split_ && target_ != CPU) {
     if (ComputeReplaceGraph(cnode) != SUCCESS) {
-      MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
-      return nullptr;
+      MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
     }
     return replace_graph_;
   }
@@ -536,8 +648,7 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
     return nullptr;
   }
   if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
-    return nullptr;
+    MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
   }
   return replace_graph_;
 }
@@ -614,6 +725,13 @@ Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
 }
 
 Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
+  if (GetAttrs() != SUCCESS) {
+    return FAILED;
+  }
+  if (manual_split_) {
+    MS_LOG(ERROR) << name_ << ": Manual split does not support searching strategies";
+    return FAILED;
+  }
   is_auto_parallel_ = true;
   Shape input0_split(inputs_shape_[0].size(), 1);
   Shape input1_split(inputs_shape_[1].size(), 1);
@@ -621,14 +739,14 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
   std::vector<StrategyPtr> sp_vector;
   if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed.";
+    MS_LOG(ERROR) << name_ << ": Generate strategies for independent inputs() failed.";
     return FAILED;
   }
   size_t success = 0;
   for (auto &sp : sp_vector) {
     if (SetCostUnderStrategy(sp) == SUCCESS) {
       success++;
-      MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
+      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy";
       PrintStrategy(sp);
     }
   }
@@ -636,6 +754,12 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
 }
 
 std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
+  if (GetAttrs() != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
+  }
+  if (manual_split_) {
+    MS_LOG(EXCEPTION) << name_ << ": Manual split does not support generating batch strategies";
+  }
   CheckGlobalDeviceManager();
   size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   Dimensions param_strategy(inputs_shape_[0].size(), 1);

View File

@@ -59,7 +59,9 @@ class GatherV2PInfo : public OperatorInfo {
   Status GetAttrs() override;
 
   Status ComputeReplaceGraph(const CNodePtr &cnode);
-  Status CheckManualSplit();
+  Status CheckManualSplit(const Strategys &strategy);
+  Status GetManualSplitAttr();
+  Status GetManualSplitWithoutOffsetAttr();
   Status ComputeReplaceOp();
   Status InferBias();
   Status InferOffset();

View File

@@ -48,6 +48,10 @@ class TensorLayout {
   void set_field_size(int32_t field_size) { field_size_ = field_size; }
 
+  bool uniform_split() const { return uniform_split_; }
+
+  void set_uniform_split(bool flag) { uniform_split_ = flag; }
+
   Arrangement device_arrangement() const { return device_arrangement_; }
 
   Map tensor_map() const { return tensor_map_; }
@@ -104,6 +108,7 @@ class TensorLayout {
   Arrangement tensor_shape_;
   bool skip_redistribution_ = false;
   int32_t field_size_ = 0;
+  bool uniform_split_ = true;
 };
 }  // namespace parallel
 }  // namespace mindspore

View File

@@ -229,10 +229,13 @@ def _load_tensor_by_layout(tensor, layout):
     """
     if not isinstance(layout, list):
         raise TypeError("The layout should be list! layout is {}".format(layout))
-    if len(layout) < 3:
-        raise ValueError("The length of layout must be larger than 3! layout is {}".format(layout))
+    if len(layout) < 5:
+        raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
     dev_mat = layout[0]
     tensor_map = layout[1]
+    uniform_split = layout[4]
+    if uniform_split[0] == 0:
+        raise RuntimeError("The load tensor only support uniform split now")
     if tensor.size() == 1:
         return tensor
     return _load_tensor(tensor, dev_mat, tensor_map)
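
The net effect: checkpoint slices for manually split (non-uniform) parameters are rejected at load time. A sketch of the failure path, assuming the helper lives in mindspore.parallel._tensor (values illustrative, not part of this diff):

    # Sketch only: a layout whose uniform_split flag (layout[4]) is 0.
    import numpy as np
    import mindspore as ms
    from mindspore import Tensor
    from mindspore.parallel._tensor import _load_tensor_by_layout

    tensor = Tensor(np.ones([16, 32]), dtype=ms.float32)
    layout = [[2, 4], [1, -1], [16, 32], [0], [0]]
    try:
        _load_tensor_by_layout(tensor, layout)
    except RuntimeError as e:
        print(e)  # "The load tensor only support uniform split now"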

View File

@@ -49,8 +49,8 @@ def test_get_parameter_layout():
     net.set_auto_parallel()
     exe = me._executor
     exe.compile(net, x, phase='train', auto_parallel_mode=True)
-    x_layout = [[2, 4], [1, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
-    weight_layout = [[2, 4], [0, -1], [16, 32], [0]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
+    x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]]  # device_arrangement = [2, 4], tensor_map = [1, -1]
+    weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]]  # device_arrangement = [2, 4], tensor_map = [0, -1]
     expect_dict = {'x': x_layout, 'w1': weight_layout}
     # to be resolved: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
     assert net.parameter_layout_dict == expect_dict

View File

@@ -14,6 +14,7 @@
 # ============================================================================
 
 import numpy as np
+import pytest
 import mindspore as ms
 from mindspore import context, Tensor, Parameter
 from mindspore.common.api import _executor
@@ -22,40 +23,170 @@ from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
 
 
 class Net(Cell):
-    def __init__(self, strategy1=None, strategy2=None, strategy3=None):
+    def __init__(self,
+                 strategy1=None,
+                 strategy2=None,
+                 strategy3=None,
+                 axis=0,
+                 init_flag=True,
+                 split_tuple=(4, 4),
+                 split_string="manual_split",
+                 param_shape=(8, 8)):
         super().__init__()
         self.gatherv2 = P.GatherV2().set_strategy(strategy1)
-        self.gatherv2.add_prim_attr("manual_split", ((1, 0), (7, 1)))
+        self.gatherv2.add_prim_attr(split_string, split_tuple)
         self.mul = P.Mul().set_strategy(strategy2)
         self.reshape = P.Reshape()
         self.matmul = P.MatMul().set_strategy(strategy3)
         self.matmul.add_prim_attr("forward_reduce_scatter", True)
-        self.param = Parameter(initializer("ones", (8, 64), ms.float32), name="gatherv2_param")
-        self.mul_weight = Parameter(initializer("ones", (2, 4, 64), ms.float32), name="mul_weight")
-        self.matmul_weight = Parameter(initializer("ones", (256, 16), ms.float32), name="matmul_weight")
+        if init_flag:
+            self.param = Parameter(initializer("ones", param_shape, ms.float32), name="gatherv2_param")
+        else:
+            self.param = Parameter(Tensor(np.ones(param_shape), dtype=ms.float32), name="gatherv2_param")
+        self.mul_weight = Parameter(initializer("ones", (8, 8, 8), ms.float32), name="mul_weight")
+        self.matmul_weight = Parameter(initializer("ones", (64, 16), ms.float32), name="matmul_weight")
+        self.axis = axis
 
     def construct(self, x, b):
-        out = self.gatherv2(self.param, x, 0)
+        out = self.gatherv2(self.param, x, self.axis)
         out = self.mul(out, self.mul_weight)
-        out = self.reshape(out, (2, 256))
+        out = self.reshape(out, (8, 64))
         out = self.matmul(out, self.matmul_weight)
         return out
 
-_x = Tensor(np.ones([2, 4]), dtype=ms.int32)
+
+_x = Tensor(np.ones([8, 8]), dtype=ms.int32)
 _b = Tensor(np.ones([64, 8]), dtype=ms.float32)
 
 
 def compile_net(net):
-    context.set_context(save_graphs=True)
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
     train_net = TrainOneStepCell(net, optimizer)
     train_net.set_auto_parallel()
-    _executor.compile(train_net, _x, _b)
+    _executor.compile(train_net, _x, _b, auto_parallel_mode=True)
     context.reset_auto_parallel_context()
 
 
-def test_neg_data_parallel():
-    context.set_context(save_graphs=True)
+def test_normal_split():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
     strategy1 = ((2, 1), (1, 2))
     strategy2 = ((1, 2, 1), (1, 2, 1))
     strategy3 = ((1, 2), (2, 1))
     net = Net(strategy1, strategy2, strategy3)
     compile_net(net)
+
+
+def test_normal_split2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
+    strategy1 = ((4, 1), (1, 4))
+    strategy2 = ((1, 4, 1), (1, 4, 1))
+    strategy3 = ((1, 4), (4, 1))
+    net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
+    compile_net(net)
+
+
+def test_normal_split3():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
+    strategy1 = ((4, 8), (1, 4))
+    strategy2 = ((1, 4, 8), (1, 4, 8))
+    strategy3 = ((1, 32), (32, 1))
+    net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
+    compile_net(net)
+
+
+def test_normal_split_with_offset():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, split_string="manual_split_with_offset", split_tuple=((4, 0), (4, 4)))
+    compile_net(net)
+
+
+def test_auto_parallel_error():
+    context.set_context(save_graphs=True)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
+    net = Net()
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_axis_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, axis=1)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((4, 1), (8, 1))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((4, 1), (1, 8))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error3():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error4():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 8), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_strategy_error5():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
+    strategy1 = ((4, 1), (1, 4))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_split_tuple_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, split_tuple=((5, 0), (5, 5)))
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
+def test_parameter_use_tensor_error():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
+    strategy1 = ((2, 1), (1, 2))
+    strategy2 = ((1, 2, 1), (1, 2, 1))
+    strategy3 = ((1, 2), (2, 1))
+    net = Net(strategy1, strategy2, strategy3, init_flag=False)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
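
For orientation, the slices exercised by test_normal_split_with_offset above: with two devices and split tuple ((4, 0), (4, 4)), device group 0 owns parameter rows [0, 4) and group 1 owns rows [4, 8). A sketch (not part of this diff):

    # Sketch only: rows owned per device group under manual_split_with_offset.
    split_with_offset = ((4, 0), (4, 4))  # (row_count, offset) pairs
    for group, (rows, offset) in enumerate(split_with_offset):
        print("group {}: rows [{}, {})".format(group, offset, offset + rows))
    # group 0: rows [0, 4)
    # group 1: rows [4, 8)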