forked from mindspore-Ecosystem/mindspore
!4356 Add validation for field split
Merge pull request !4356 from yangzhenzhang/update-field-split
This commit is contained in:
commit
2db0290c49
|
@ -44,14 +44,15 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
|
|||
auto device_arrangement = tensor_layout->device_arrangement().array();
|
||||
auto tensor_map = tensor_layout->tensor_map().array();
|
||||
auto slice_shape = tensor_layout->slice_shape().array();
|
||||
int32_t _field_size = tensor_layout->get_field_size();
|
||||
Shape field_size;
|
||||
if (_field_size != 0) {
|
||||
field_size.push_back(_field_size);
|
||||
Shape field_size = {tensor_layout->get_field_size()};
|
||||
Shape uniform_split;
|
||||
if (tensor_layout->uniform_split()) {
|
||||
uniform_split.push_back(1);
|
||||
} else {
|
||||
field_size = {0};
|
||||
uniform_split.push_back(0);
|
||||
}
|
||||
std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size};
|
||||
|
||||
std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size, uniform_split};
|
||||
dict[py::str(name)] = layout;
|
||||
MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
|
||||
}
|
||||
|
|
|
@ -27,6 +27,92 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
Status GatherV2PInfo::GetManualSplitWithoutOffsetAttr() {
|
||||
auto manual_split_without_offset_iter = attrs_.find("manual_split");
|
||||
if (manual_split_without_offset_iter != attrs_.end()) {
|
||||
manual_split_ = true;
|
||||
MS_EXCEPTION_IF_NULL(manual_split_without_offset_iter->second);
|
||||
if (manual_split_without_offset_iter->second->cast<ValueTuplePtr>() == nullptr) {
|
||||
MS_LOG(ERROR) << name_ << ": Manual split without offset strategy's format is wrong! Need ValueSequeue";
|
||||
return FAILED;
|
||||
}
|
||||
std::vector<ValuePtr> value_vector = manual_split_without_offset_iter->second->cast<ValueTuplePtr>()->value();
|
||||
MS_LOG(INFO) << name_ << ": manual split with offset is " << manual_split_without_offset_iter->second->ToString();
|
||||
|
||||
int64_t offset = 0;
|
||||
for (auto &ele : value_vector) {
|
||||
index_offsets_.push_back(offset);
|
||||
if (!ele->isa<Int32Imm>()) {
|
||||
MS_LOG(ERROR) << name_ << ": The element of manual split must be int";
|
||||
return FAILED;
|
||||
}
|
||||
int64_t param_split_shape = static_cast<int64_t>(GetValue<int>(ele));
|
||||
if (param_split_shape <= 0) {
|
||||
MS_LOG(ERROR) << name_ << ": The value of manual split must be positive, but got " << param_split_shape;
|
||||
return FAILED;
|
||||
}
|
||||
param_split_shapes_.push_back(param_split_shape);
|
||||
offset += param_split_shape;
|
||||
}
|
||||
if (param_split_shapes_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": Failed to extract param split's split info";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::GetManualSplitAttr() {
|
||||
auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
|
||||
if (manual_split_with_offset_iter != attrs_.end()) {
|
||||
manual_split_ = true;
|
||||
auto var = manual_split_with_offset_iter->second->cast<ValueTuplePtr>();
|
||||
if (var == nullptr) {
|
||||
MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << name_ << ": manual split with offset strategy " << var->ToString();
|
||||
for (auto &ele : var->value()) {
|
||||
if (!ele->isa<ValueSequeue>()) {
|
||||
MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
|
||||
return FAILED;
|
||||
}
|
||||
std::vector<ValuePtr> value_vector = ele->cast<ValueTuplePtr>()->value();
|
||||
if (value_vector.size() != 2) {
|
||||
MS_LOG(ERROR) << name_ << ": Size of manual split with offset's element must be 2";
|
||||
return FAILED;
|
||||
}
|
||||
int64_t param_split_row = static_cast<int64_t>(GetValue<int>(value_vector[0]));
|
||||
int64_t offset = static_cast<int64_t>(GetValue<int>(value_vector[1]));
|
||||
if ((param_split_row <= 0) || (offset < 0)) {
|
||||
MS_LOG(ERROR) << name_
|
||||
<< ": The value of param split shape must be positive, and the offset must larger or equal to 0";
|
||||
return FAILED;
|
||||
}
|
||||
param_split_shapes_.push_back(param_split_row);
|
||||
index_offsets_.push_back(offset);
|
||||
}
|
||||
|
||||
if (param_split_shapes_.empty()) {
|
||||
MS_LOG(ERROR) << name_ << ": Failed to extract param split with offset's split info";
|
||||
return FAILED;
|
||||
}
|
||||
if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
|
||||
MS_LOG(ERROR) << name_ << ": Index offset must not less than 0";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
if (GetManualSplitWithoutOffsetAttr() != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::GetAttrs() {
|
||||
// get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
|
||||
if (target_ != CPU) {
|
||||
|
@ -56,55 +142,73 @@ Status GatherV2PInfo::GetAttrs() {
|
|||
MS_LOG(ERROR) << name_ << ": The value of target is not a string.";
|
||||
}
|
||||
}
|
||||
auto manual_split_iter = attrs_.find("manual_split");
|
||||
if (manual_split_iter != attrs_.end()) {
|
||||
param_split_shapes_.clear();
|
||||
manual_split_ = true;
|
||||
auto var = manual_split_iter->second->cast<ValueTuplePtr>();
|
||||
MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString();
|
||||
|
||||
if (var->size() > 0) {
|
||||
std::vector<ValuePtr> elements = var->value();
|
||||
for (auto &ele : elements) {
|
||||
if (ele->isa<ValueSequeue>()) {
|
||||
auto value_tuple = ele->cast<ValueTuplePtr>();
|
||||
std::vector<ValuePtr> value_vector = value_tuple->value();
|
||||
if (value_vector.size() != 2) {
|
||||
MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2.";
|
||||
if (GetManualSplitAttr() != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
param_split_shapes_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[0])));
|
||||
index_offsets_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[1])));
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue";
|
||||
|
||||
if (manual_split_ && (axis_ != 0)) {
|
||||
MS_LOG(ERROR) << name_ << ": The axis or offset must be 0 if manual split, bug got " << axis_;
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
if (param_split_shapes_.empty()) {
|
||||
MS_LOG(ERROR) << "Failed to extract param split strategy.";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherV2PInfo::CheckManualSplit() {
|
||||
auto param_shape = inputs_shape_.at(0);
|
||||
Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
|
||||
if (strategy.size() != 2) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
|
||||
return FAILED;
|
||||
}
|
||||
Dimensions param_strategy = strategy[0];
|
||||
Dimensions indices_strategy = strategy[1];
|
||||
if (param_strategy.size() != 2 || indices_strategy.size() != 2) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of param strategy or indices strategy must be 2";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (indices_strategy[0] != 1) {
|
||||
MS_LOG(ERROR) << name_ << ": The indices_strategy[0] must be 1, bug got " << indices_strategy[0];
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (param_strategy[0] != indices_strategy[1]) {
|
||||
MS_LOG(ERROR) << name_ << ": The param_strategy[0] must be equal to indices_strategy[1]";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (indices_strategy[1] != SizeToInt(param_split_shapes_.size())) {
|
||||
MS_LOG(ERROR) << name_ << ": The indices_strategy[1] must be equal to manual split size";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
int64_t min_param_slice_row = inputs_shape_[1][1] / indices_strategy[1];
|
||||
bool invalid = std::any_of(param_split_shapes_.begin(), param_split_shapes_.end(),
|
||||
[&min_param_slice_row](int64_t v) { return v < min_param_slice_row; });
|
||||
if (invalid) {
|
||||
MS_LOG(ERROR) << name_ << ": The split value must be larger than or equal to indices slice's column num";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (inputs_shape_[0][0] < inputs_shape_[1][1]) {
|
||||
MS_LOG(ERROR) << name_ << ": The param's row smaller than indices' column";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// Don't support repeated calc
|
||||
CheckGlobalDeviceManager();
|
||||
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
|
||||
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
|
||||
if (IntToSize(product_p) < dev_num) {
|
||||
MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
|
||||
[](int64_t s, int64_t shape) { return s + shape; });
|
||||
if (split_shape_sum < param_shape.at(0)) {
|
||||
MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape.";
|
||||
if (split_shape_sum != inputs_shape_[0][0]) {
|
||||
MS_LOG(ERROR) << name_ << ": Sum of splited shapes must be equal to param_shape[0]";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
|
||||
MS_LOG(ERROR) << "Failure: Index offset must not less than 0.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -147,7 +251,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
}
|
||||
|
||||
if (manual_split_) {
|
||||
if (CheckManualSplit() != SUCCESS) {
|
||||
if (CheckManualSplit(strategy->GetInputDim()) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
// when using manual_split, no need to check belowings.
|
||||
|
@ -343,14 +447,15 @@ Status GatherV2PInfo::InferTensorInfo() {
|
|||
SUCCESS)) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (manual_split_) {
|
||||
input_tensor_layout.set_uniform_split(false);
|
||||
}
|
||||
// infer tensor info
|
||||
TensorInfo input_tensor_info(input_tensor_layout);
|
||||
TensorInfo input_index_info(input_index_layout);
|
||||
TensorInfo output_tensor_info(output_tensor_layout);
|
||||
|
||||
Shape slice_shape = input_tensor_info.slice_shape();
|
||||
MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape);
|
||||
|
||||
inputs_tensor_info_.push_back(input_tensor_info);
|
||||
inputs_tensor_info_.push_back(input_index_info);
|
||||
outputs_tensor_info_.push_back(output_tensor_info);
|
||||
|
@ -392,9 +497,17 @@ Status GatherV2PInfo::InferBias() {
|
|||
Status GatherV2PInfo::InferOffset() {
|
||||
CheckGlobalDeviceManager();
|
||||
size_t rank = g_device_manager->global_rank();
|
||||
if (rank < index_offsets_.size()) {
|
||||
index_offset_ = index_offsets_.at(rank);
|
||||
MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
|
||||
|
||||
MS_EXCEPTION_IF_NULL(strategy_);
|
||||
auto param_strategy = strategy_->GetInputDim()[0];
|
||||
if (param_strategy.size() != 2) {
|
||||
MS_LOG(ERROR) << "The size of param strategy must be 2";
|
||||
return FAILED;
|
||||
}
|
||||
size_t index = rank / param_strategy[1];
|
||||
if (index < index_offsets_.size()) {
|
||||
index_offset_ = index_offsets_[index];
|
||||
MS_LOG(INFO) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -524,8 +637,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|||
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
||||
if (manual_split_ && target_ != CPU) {
|
||||
if (ComputeReplaceGraph(cnode) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
|
||||
return nullptr;
|
||||
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
|
||||
}
|
||||
return replace_graph_;
|
||||
}
|
||||
|
@ -536,8 +648,7 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
|||
return nullptr;
|
||||
}
|
||||
if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
|
||||
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
|
||||
return nullptr;
|
||||
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
|
||||
}
|
||||
return replace_graph_;
|
||||
}
|
||||
|
@ -614,6 +725,13 @@ Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
|||
}
|
||||
|
||||
Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
|
||||
if (GetAttrs() != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
if (manual_split_) {
|
||||
MS_LOG(ERROR) << name_ << ": Manual split does not support to search strategy";
|
||||
return FAILED;
|
||||
}
|
||||
is_auto_parallel_ = true;
|
||||
Shape input0_split(inputs_shape_[0].size(), 1);
|
||||
Shape input1_split(inputs_shape_[1].size(), 1);
|
||||
|
@ -636,6 +754,12 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
|
|||
}
|
||||
|
||||
std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
|
||||
if (GetAttrs() != SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
|
||||
}
|
||||
if (manual_split_) {
|
||||
MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to generate batch strategy";
|
||||
}
|
||||
CheckGlobalDeviceManager();
|
||||
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
|
||||
Dimensions param_strategy(inputs_shape_[0].size(), 1);
|
||||
|
|
|
@ -59,7 +59,9 @@ class GatherV2PInfo : public OperatorInfo {
|
|||
Status GetAttrs() override;
|
||||
|
||||
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||
Status CheckManualSplit();
|
||||
Status CheckManualSplit(const Strategys &strategy);
|
||||
Status GetManualSplitAttr();
|
||||
Status GetManualSplitWithoutOffsetAttr();
|
||||
Status ComputeReplaceOp();
|
||||
Status InferBias();
|
||||
Status InferOffset();
|
||||
|
|
|
@ -48,6 +48,10 @@ class TensorLayout {
|
|||
|
||||
void set_field_size(int32_t field_size) { field_size_ = field_size; }
|
||||
|
||||
bool uniform_split() const { return uniform_split_; }
|
||||
|
||||
void set_uniform_split(bool flag) { uniform_split_ = flag; }
|
||||
|
||||
Arrangement device_arrangement() const { return device_arrangement_; }
|
||||
|
||||
Map tensor_map() const { return tensor_map_; }
|
||||
|
@ -104,6 +108,7 @@ class TensorLayout {
|
|||
Arrangement tensor_shape_;
|
||||
bool skip_redistribution_ = false;
|
||||
int32_t field_size_ = 0;
|
||||
bool uniform_split_ = true;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -229,10 +229,13 @@ def _load_tensor_by_layout(tensor, layout):
|
|||
"""
|
||||
if not isinstance(layout, list):
|
||||
raise TypeError("The layout should be list! layout is {}".format(layout))
|
||||
if len(layout) < 3:
|
||||
raise ValueError("The length of layout must be larger than 3! layout is {}".format(layout))
|
||||
if len(layout) < 5:
|
||||
raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
|
||||
dev_mat = layout[0]
|
||||
tensor_map = layout[1]
|
||||
uniform_split = layout[4]
|
||||
if uniform_split[0] == 0:
|
||||
raise RuntimeError("The load tensor only support uniform split now")
|
||||
if tensor.size() == 1:
|
||||
return tensor
|
||||
return _load_tensor(tensor, dev_mat, tensor_map)
|
||||
|
|
|
@ -49,8 +49,8 @@ def test_get_parameter_layout():
|
|||
net.set_auto_parallel()
|
||||
exe = me._executor
|
||||
exe.compile(net, x, phase='train', auto_parallel_mode=True)
|
||||
x_layout = [[2, 4], [1, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [1, -1]
|
||||
weight_layout = [[2, 4], [0, -1], [16, 32], [0]] # device_arrangement = [2, 4], tensor_map = [0, -1]
|
||||
x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1]
|
||||
weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1]
|
||||
expect_dict = {'x': x_layout, 'w1': weight_layout}
|
||||
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
|
||||
assert net.parameter_layout_dict == expect_dict
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.common.api import _executor
|
||||
|
@ -22,40 +23,170 @@ from mindspore.ops import operations as P
|
|||
from mindspore.common.initializer import initializer
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self, strategy1=None, strategy2=None, strategy3=None):
|
||||
def __init__(self,
|
||||
strategy1=None,
|
||||
strategy2=None,
|
||||
strategy3=None,
|
||||
axis=0,
|
||||
init_flag=True,
|
||||
split_tuple=(4, 4),
|
||||
split_string="manual_split",
|
||||
param_shape=(8, 8)):
|
||||
super().__init__()
|
||||
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
|
||||
self.gatherv2.add_prim_attr("manual_split", ((1, 0), (7, 1)))
|
||||
self.gatherv2.add_prim_attr(split_string, split_tuple)
|
||||
self.mul = P.Mul().set_strategy(strategy2)
|
||||
self.reshape = P.Reshape()
|
||||
self.matmul = P.MatMul().set_strategy(strategy3)
|
||||
self.matmul.add_prim_attr("forward_reduce_scatter", True)
|
||||
self.param = Parameter(initializer("ones", (8, 64), ms.float32), name="gatherv2_param")
|
||||
self.mul_weight = Parameter(initializer("ones", (2, 4, 64), ms.float32), name="mul_weight")
|
||||
self.matmul_weight = Parameter(initializer("ones", (256, 16), ms.float32), name="matmul_weight")
|
||||
if init_flag:
|
||||
self.param = Parameter(initializer("ones", param_shape, ms.float32), name="gatherv2_param")
|
||||
else:
|
||||
self.param = Parameter(Tensor(np.ones(param_shape), dtype=ms.float32), name="gatherv2_param")
|
||||
self.mul_weight = Parameter(initializer("ones", (8, 8, 8), ms.float32), name="mul_weight")
|
||||
self.matmul_weight = Parameter(initializer("ones", (64, 16), ms.float32), name="matmul_weight")
|
||||
self.axis = axis
|
||||
|
||||
def construct(self, x, b):
|
||||
out = self.gatherv2(self.param, x, 0)
|
||||
out = self.gatherv2(self.param, x, self.axis)
|
||||
out = self.mul(out, self.mul_weight)
|
||||
out = self.reshape(out, (2, 256))
|
||||
out = self.reshape(out, (8, 64))
|
||||
out = self.matmul(out, self.matmul_weight)
|
||||
return out
|
||||
|
||||
_x = Tensor(np.ones([2, 4]), dtype=ms.int32)
|
||||
|
||||
_x = Tensor(np.ones([8, 8]), dtype=ms.int32)
|
||||
_b = Tensor(np.ones([64, 8]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile_net(net):
|
||||
context.set_context(save_graphs=True)
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
train_net.set_auto_parallel()
|
||||
_executor.compile(train_net, _x, _b)
|
||||
_executor.compile(train_net, _x, _b, auto_parallel_mode=True)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
def test_neg_data_parallel():
|
||||
context.set_context(save_graphs=True)
|
||||
|
||||
def test_normal_split():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||
strategy1 = ((2, 1), (1, 2))
|
||||
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||
strategy3 = ((1, 2), (2, 1))
|
||||
net = Net(strategy1, strategy2, strategy3)
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_normal_split2():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
|
||||
strategy1 = ((4, 1), (1, 4))
|
||||
strategy2 = ((1, 4, 1), (1, 4, 1))
|
||||
strategy3 = ((1, 4), (4, 1))
|
||||
net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_normal_split3():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=17)
|
||||
strategy1 = ((4, 8), (1, 4))
|
||||
strategy2 = ((1, 4, 8), (1, 4, 8))
|
||||
strategy3 = ((1, 32), (32, 1))
|
||||
net = Net(strategy1, strategy2, strategy3, split_tuple=(10, 20, 30, 4), param_shape=(64, 8))
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_normal_split_with_offset():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||
strategy1 = ((2, 1), (1, 2))
|
||||
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||
strategy3 = ((1, 2), (2, 1))
|
||||
net = Net(strategy1, strategy2, strategy3, split_string="manual_split_with_offset", split_tuple=((4, 0), (4, 4)))
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_auto_parallel_error():
|
||||
context.set_context(save_graphs=True)
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
|
||||
net = Net()
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_axis_error():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||
strategy1 = ((2, 1), (1, 2))
|
||||
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||
strategy3 = ((1, 2), (2, 1))
|
||||
net = Net(strategy1, strategy2, strategy3, axis=1)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_strategy_error():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((4, 1), (8, 1))
|
||||
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||
strategy3 = ((1, 2), (2, 1))
|
||||
net = Net(strategy1, strategy2, strategy3)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_strategy_error2():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((4, 1), (1, 8))
|
||||
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||
strategy3 = ((1, 2), (2, 1))
|
||||
net = Net(strategy1, strategy2, strategy3)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_strategy_error3():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 1), (1, 2))
|
||||
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||
strategy3 = ((1, 2), (2, 1))
|
||||
net = Net(strategy1, strategy2, strategy3)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_strategy_error4():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||
strategy1 = ((2, 8), (1, 2))
|
||||
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||
strategy3 = ((1, 2), (2, 1))
|
||||
net = Net(strategy1, strategy2, strategy3)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_strategy_error5():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
|
||||
strategy1 = ((4, 1), (1, 4))
|
||||
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||
strategy3 = ((1, 2), (2, 1))
|
||||
net = Net(strategy1, strategy2, strategy3)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_split_tuple_error():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||
strategy1 = ((2, 1), (1, 2))
|
||||
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||
strategy3 = ((1, 2), (2, 1))
|
||||
net = Net(strategy1, strategy2, strategy3, split_tuple=((5, 0), (5, 5)))
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_parameter_use_tensor_error():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
|
||||
strategy1 = ((2, 1), (1, 2))
|
||||
strategy2 = ((1, 2, 1), (1, 2, 1))
|
||||
strategy3 = ((1, 2), (2, 1))
|
||||
net = Net(strategy1, strategy2, strategy3, init_flag=False)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
|
Loading…
Reference in New Issue