!31941 modify Strided_slice for master

Merge pull request !31941 from lilei/modify_stridedslice_for_master
This commit is contained in:
i-robot 2022-03-31 09:11:24 +00:00 committed by Gitee
commit 72414ff8d9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 434 additions and 50 deletions

View File

@ -16,6 +16,7 @@
#include "frontend/parallel/ops_info/strided_slice_info.h"
#include <bitset>
#include <algorithm>
#include <memory>
#include <utility>
@ -51,6 +52,70 @@ Status StridedSliceInfo::GetMask(const std::string &mask_name, int64_t *mask_val
return SUCCESS;
}
constexpr auto kStridedSliceMaxDims = 8;
static std::vector<bool> Dec2Bin(int64_t mask) {
auto mask_str = std::bitset<kStridedSliceMaxDims>(mask).to_string();
int64_t dim_idx = 0;
std::vector<bool> result(kStridedSliceMaxDims, false);
for (int64_t i = mask_str.size() - 1; i >= 0; --i) {
if (mask_str[i] == '1') {
result[dim_idx] = true;
}
dim_idx++;
}
return result;
}
void StridedSliceInfo::ComputeBeginMask(int64_t begin_mask_) {
auto begin_mask = Dec2Bin(begin_mask_);
for (size_t i = 0; i < begin_mask.size(); ++i) {
if (i < kStridedSliceMaxDims && begin_mask[i]) {
begin_[i] = strides_[i] < 0 ? SizeToLong(inputs_shape_[0][i]) - 1 : 0;
}
}
}
void StridedSliceInfo::ComputeEndMask(int64_t end_mask_) {
auto end_mask = Dec2Bin(end_mask_);
for (size_t j = 0; j < end_mask.size(); ++j) {
if (j < kStridedSliceMaxDims && end_mask[j]) {
end_[j] = strides_[j] < 0 ? -1 : SizeToLong(inputs_shape_[0][j]);
}
}
}
void StridedSliceInfo::ComputeEllipsisMask(int64_t ellipsis_mask_) {
auto ellipsis_mask = Dec2Bin(ellipsis_mask_);
for (size_t k = 0; k < ellipsis_mask.size(); ++k) {
if (k < kStridedSliceMaxDims && ellipsis_mask[k]) {
begin_[k] = 0;
end_[k] = SizeToLong(inputs_shape_[0][k]);
strides_[k] = 1;
}
}
}
void StridedSliceInfo::ComputeNewAxisMask(int64_t new_axis_mask_) {
auto new_axis_mask = Dec2Bin(new_axis_mask_);
for (size_t l = 0; l < new_axis_mask.size(); ++l) {
if (l < kStridedSliceMaxDims && new_axis_mask[l]) {
begin_[l] = 0;
end_[l] = SizeToLong(inputs_shape_[0][l]);
strides_[l] = 1;
}
}
}
void StridedSliceInfo::ComputShrinkAxisMask(int64_t shrink_axis_mask_) {
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_);
for (size_t m = 0; m < shrink_axis_mask.size(); ++m) {
if (m < kStridedSliceMaxDims && shrink_axis_mask[m]) {
end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1;
strides_[m] = end_[m] > begin_[m] ? 1 : -1;
}
}
}
Status StridedSliceInfo::GetAttrs() {
if (attrs_.size() < STRIDED_SLICE_ATTRS_SIZE) {
MS_LOG(ERROR) << name_ << ": The size of attrs small than " << STRIDED_SLICE_ATTRS_SIZE;
@ -62,7 +127,6 @@ Status StridedSliceInfo::GetAttrs() {
(GetMask(SHRINK_AXIS_MASK, &shrink_axis_mask_) != SUCCESS)) {
return FAILED;
}
has_mask_ = ((ellipsis_mask_ != 0) || (new_axis_mask_ != 0) || (shrink_axis_mask_ != 0));
if (input_value_.size() != STRIDED_SLICE_INPUTS_SIZE) {
MS_LOG(ERROR) << name_ << ": The size of input value must be " << STRIDED_SLICE_INPUTS_SIZE << ", but got "
@ -75,7 +139,11 @@ Status StridedSliceInfo::GetAttrs() {
(TransValueSequeueToVector(input_value_[STRIDED_SLICE_STRIDES_INDEX], &strides_) != SUCCESS)) {
return FAILED;
}
ComputeBeginMask(begin_mask_);
ComputeEndMask(end_mask_);
ComputeEllipsisMask(ellipsis_mask_);
ComputeNewAxisMask(new_axis_mask_);
ComputShrinkAxisMask(shrink_axis_mask_);
return SUCCESS;
}
@ -93,12 +161,6 @@ Status StridedSliceInfo::CheckStrategy(const StrategyPtr &strategy) {
}
Dimensions strategy_value = stra[0];
bool has_split = std::any_of(strategy_value.begin(), strategy_value.end(), [](int64_t v) { return v > 1; });
if (has_split && has_mask_) {
MS_LOG(ERROR) << name_ << ": When there is a mask, the input is not supported to be split";
return FAILED;
}
if (strategy_value.size() < strides_.size()) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be larger or equal to the size of strides";
return FAILED;
@ -153,6 +215,8 @@ Status StridedSliceInfo::InferTensorMap() {
}
inputs_tensor_map_.push_back(tensor_map);
if (new_axis_mask_ != 0) tensor_map.insert(tensor_map.begin() + (new_axis_mask_ - 1), -1);
if (shrink_axis_mask_ != 0) tensor_map.erase(tensor_map.begin() + (shrink_axis_mask_ - 1));
outputs_tensor_map_.push_back(tensor_map);
return SUCCESS;
}
@ -196,17 +260,11 @@ Status StridedSliceInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
std::vector<StrategyPtr> StridedSliceInfo::GenerateOpStrategies(int64_t stage_id) {
Shape input_split(inputs_shape_[0].size(), 1);
if (has_mask_) {
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
for (size_t i = 0; i < begin_.size(); ++i) {
bool no_fully_fetch = ((begin_[i] != 0) || (end_[i] < inputs_shape_[0][i]));
if (no_fully_fetch || (strides_[i] != 1)) {
input_split[i] = 0;
}
} else {
for (size_t i = 0; i < begin_.size(); ++i) {
bool no_fully_fetch = ((begin_[i] != 0) || (end_[i] < inputs_shape_[0][i]));
if (no_fully_fetch || (strides_[i] != 1)) {
input_split[i] = 0;
}
}
}
Shapes splittable_inputs = {input_split};

View File

@ -40,6 +40,11 @@ class StridedSliceInfo : public OperatorInfo {
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
void ComputeBeginMask(int64_t begin_mask_);
void ComputeEndMask(int64_t end_mask_);
void ComputeEllipsisMask(int64_t ellipsis_mask_);
void ComputeNewAxisMask(int64_t new_axis_mask_);
void ComputShrinkAxisMask(int64_t shrink_axis_mask_);
protected:
Status GetAttrs() override;
@ -59,7 +64,6 @@ class StridedSliceInfo : public OperatorInfo {
int64_t ellipsis_mask_ = 0;
int64_t new_axis_mask_ = 0;
int64_t shrink_axis_mask_ = 0;
bool has_mask_ = false;
};
using StridedSliceInfoPtr = std::shared_ptr<StridedSliceInfo>;

View File

@ -24,7 +24,6 @@ from collections import Counter
import numpy as np
from mindspore import log as logger
from mindspore import context
from mindspore.common.initializer import Zero
from .. import signature as sig
from .._utils import get_broadcast_shape, is_shape_unknown
@ -3505,10 +3504,6 @@ class StridedSlice(PrimitiveWithInfer):
"""Initialize StridedSlice"""
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
# auto parallel haven't support begin_mask and end_mask
if context.get_auto_parallel_context("parallel_mode") in ["semi_auto_parallel", "auto_parallel"]:
begin_mask = 0
end_mask = 0
validator.check_non_negative_int(begin_mask, 'begin_mask', self.name)
validator.check_non_negative_int(end_mask, 'end_mask', self.name)
validator.check_non_negative_int(ellipsis_mask, 'ellipsis_mask', self.name)

View File

@ -20,13 +20,18 @@ from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
from parallel.utils.utils import ParallelValidator
class Net(Cell):
def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, is_parameter=True, mask=0):
def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, is_parameter=True,
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0):
super().__init__()
self.mul = P.Mul().shard(strategy1)
self.strided_slice = P.StridedSlice(begin_mask=mask).shard(strategy2)
self.strided_slice = P.StridedSlice(begin_mask=begin_mask,
end_mask=end_mask,
ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask).shard(strategy2)
if is_parameter:
self.weight = Parameter(weight, "w1")
else:
@ -45,10 +50,14 @@ class Net(Cell):
class Net2(Cell):
def __init__(self, weight2, begin, end, strides, strategy1=None, strategy2=None):
def __init__(self, weight2, begin, end, strides, strategy1=None, strategy2=None,
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0):
super().__init__()
self.mul = P.Mul().shard(strategy1)
self.strided_slice = P.StridedSlice().shard(strategy2)
self.strided_slice = P.StridedSlice(begin_mask=begin_mask,
end_mask=end_mask,
ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask,
shrink_axis_mask=shrink_axis_mask).shard(strategy2)
self.weight2 = Parameter(weight2, "w2")
self.begin = begin
self.end = end
@ -60,105 +69,423 @@ class Net2(Cell):
return out
_x = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)
_x1 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)
_x2 = Tensor(np.ones([1, 64, 32, 32]), dtype=ms.float32)
_x3 = Tensor(np.ones([64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([256, 64, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w3 = Tensor(np.ones([1, 64, 32, 32]), dtype=ms.float32)
_b1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_b2 = Tensor(np.ones([1, 64, 32, 32]), dtype=ms.float32)
def compile_net(net):
def compile_net(net, _x1, _b1):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_cell_graph_executor.compile(train_net, _x, _b)
_cell_graph_executor.compile(train_net, _x1, _b1)
context.reset_auto_parallel_context()
def compile_net_utils(net: Cell, *inputs):
net.set_auto_parallel()
net.set_train()
phase, _ = _cell_graph_executor.compile(net, *inputs, auto_parallel_mode=True)
context.reset_auto_parallel_context()
return phase
def test_stridedslice_no_fully_fetch_split_error():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 2), (2, 2, 2))
strategy2 = ((2, 2, 2),)
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True)
with pytest.raises(RuntimeError):
compile_net(net)
compile_net(net, _x1, _b1)
def test_stridedslice_strides_no_1_split_error():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with strides no 1 split in semi auto parallel.
Expectation: compile error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 2), (2, 2, 2))
strategy2 = ((1, 2, 2),)
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 2), strategy1, strategy2, is_parameter=True)
with pytest.raises(RuntimeError):
compile_net(net)
def test_stridedslice_mask_no_0_split_error():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 2), (2, 2, 2))
strategy2 = ((1, 2, 2),)
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True, mask=1)
with pytest.raises(RuntimeError):
compile_net(net)
compile_net(net, _x1, _b1)
def test_stridedslice_begin_size_smaller():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with begin size is smaller in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 1), (1, 4, 2))
strategy2 = ((1, 4, 2),)
net = Net(_w1, _w2, (0, 0), (128, 64), (1, 1), strategy1, strategy2, is_parameter=True)
compile_net(net)
compile_net(net, _x1, _b1)
def test_stridedslice_parameter():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice of parameter in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 1), (1, 4, 2))
strategy2 = ((1, 4, 2),)
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True)
compile_net(net)
compile_net(net, _x1, _b1)
def test_stridedslice_begin_mask_no_0_split_parameter():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with begin mask no 0 split in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 1), (1, 4, 2))
strategy2 = ((1, 4, 2),)
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True, begin_mask=1)
compile_net(net, _x1, _b1)
def test_stridedslice_end_mask_no_0_parameter():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with end mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 1), (1, 4, 2))
strategy2 = ((1, 4, 2),)
net = Net(_w1, _w2, (127, 0, 0), (128, 63, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
begin_mask=1, end_mask=2)
compile_net(net, _x1, _b1)
def test_stridedslice_ellipsis_mask_no_0_parameter():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with ellipsis mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 1), (1, 4, 2))
strategy2 = ((1, 4, 2),)
net = Net(_w1, _w2, (127, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
begin_mask=1, end_mask=2, ellipsis_mask=4)
compile_net(net, _x1, _b1)
def test_stridedslice_new_axis_mask_no_0_parameter():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with new axis mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 2, 1), (1, 4, 2, 1))
strategy2 = ((1, 1, 4),)
net = Net(_w1, _w3, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
new_axis_mask=1)
compile_net(net, _x2, _b2)
def test_stridedslice_shrink_axis_mask_no_0_parameter():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with shrink axis mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 2), (1, 2))
strategy2 = ((1, 4, 1),)
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
shrink_axis_mask=1)
compile_net(net, _x3, _b1)
def test_stridedslice_tensor():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice of tensor in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 1), (1, 4, 2))
strategy2 = ((1, 4, 2),)
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False)
compile_net(net)
compile_net(net, _x1, _b1)
def test_stridedslice_begin_mask_no_0_tensor():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with begin mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 1), (1, 4, 2))
strategy2 = ((1, 4, 2),)
net = Net(_w1, _w2, (127, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False, begin_mask=1)
compile_net(net, _x1, _b1)
def test_stridedslice_end_mask_no_0_tensor():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with end mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 1), (1, 4, 2))
strategy2 = ((1, 4, 2),)
net = Net(_w1, _w2, (0, 0, 0), (128, 63, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False, end_mask=2)
compile_net(net, _x1, _b1)
def test_stridedslice_ellipsis_mask_no_0_tensor():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with ellipsis mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 1), (1, 4, 2))
strategy2 = ((1, 4, 2),)
net = Net(_w1, _w2, (127, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False,
begin_mask=1, end_mask=2, ellipsis_mask=4)
compile_net(net, _x1, _b1)
def test_stridedslice_new_axis_mask_no_0_tensor():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with new axis mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 2, 1), (1, 4, 2, 1))
strategy2 = ((1, 1, 4),)
net = Net(_w1, _w3, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False,
new_axis_mask=1)
compile_net(net, _x2, _b2)
def test_stridedslice_shrink_axis_mask_no_0_tensor():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with shrink axis mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 2), (1, 2))
strategy2 = ((1, 4, 1),)
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False,
shrink_axis_mask=1)
compile_net(net, _x3, _b1)
def test_stridedslice_parameter_no_full_split():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with no full split in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 1), (1, 4, 2))
strategy2 = ((1, 2, 2),)
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True)
compile_net(net)
compile_net(net, _x1, _b1)
def test_stridedslice_output():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice of output in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8, 1), (1, 8, 1))
strategy2 = ((1, 8, 1),)
net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2)
compile_net(net)
compile_net(net, _x1, _b1)
def test_stridedslice_begin_mask_no_0_output():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with begin mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8, 1), (1, 8, 1))
strategy2 = ((1, 8, 1),)
net = Net2(_w2, (61, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2, begin_mask=1)
compile_net(net, _x1, _b1)
def test_stridedslice_end_mask_no_0_output():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with end mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8, 1), (1, 8, 1))
strategy2 = ((1, 8, 1),)
net = Net2(_w2, (0, 0, 0), (64, 63, 1), (1, 1, 1), strategy1, strategy2, end_mask=2)
compile_net(net, _x1, _b1)
def test_stridedslice_ellipsis_mask_no_0_output():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with ellipsis mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8, 1), (1, 8, 1))
strategy2 = ((1, 8, 1),)
net = Net2(_w2, (63, 0, 0), (64, 63, 1), (1, 1, 1), strategy1, strategy2,
begin_mask=1, end_mask=2, ellipsis_mask=4)
compile_net(net, _x1, _b1)
def test_stridedslice_new_axis_mask_no_0_output():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with new axis mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8, 1), (1, 8, 1))
strategy2 = ((8, 1, 1),)
net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2, new_axis_mask=1)
compile_net(net, _x1, _b1)
def test_stridedslice_shrink_axis_mask_no_0_output():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with shrink axis mask no 0 in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8, 1), (1, 8, 1))
strategy2 = ((1, 8, 1),)
net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2, shrink_axis_mask=1)
compile_net(net, _x1, _b1)
def test_stridedslice_output_no_full_split():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with no full split in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8, 1), (1, 8, 1))
strategy2 = ((1, 4, 1),)
net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2)
compile_net(net)
compile_net(net, _x1, _b1)
def test_stridedslice_no_strategy():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with no strategy in semi auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8, 1), (1, 8, 1))
strategy2 = None
net = Net2(_w2, (0, 0, 0), (128, 64, 1), (1, 1, 1), strategy1, strategy2)
compile_net(net)
compile_net(net, _x1, _b1)
def test_stridedslice_begin_mask_no_0_no_strategy():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with begin mask no 0 in auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 8, 1), (1, 8, 1))
strategy2 = None
net = Net2(_w2, (127, 0, 0), (128, 64, 1), (1, 1, 1), strategy1, strategy2, begin_mask=1)
compile_net(net, _x1, _b1)
def test_stridedslice_auto_parallel():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice in auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w2, (0, 0, 0), (32, 64, 1), (1, 1, 1))
compile_net(net)
compile_net(net, _x1, _b1)
def test_stridedslice_begin_mask_no_0_auto_parallel():
"""
Feature: distribute operator stridedslice in auto parallel mode.
Description: test stridedslice with begin mask no 0 in auto parallel.
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w2, (29, 0, 0), (32, 64, 1), (1, 1, 1), begin_mask=1)
compile_net(net, _x1, _b1)
def test_stridedslice_layout():
"""
Features: StridedSlice
Description: validate layout and structure
Expectation: No raise RuntimeError
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 1), (1, 4, 2))
strategy2 = ((1, 4, 2),)
net = Net(_w1, _w2, (127, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
begin_mask=1, end_mask=2, ellipsis_mask=4)
phase = compile_net_utils(net, _x1, _b1)
validator = ParallelValidator(net, phase)
# check layout
features_expect_layout = ([4, 2], [-1, 1, 0], [256, 16, 16], 0, True, '')
assert validator.check_parameter_layout('w1', features_expect_layout)
# check attrs
roi_expect_attrs = {'begin_mask': 1, 'end_mask': 2, 'ellipsis_mask': 4}
assert validator.check_node_attrs('StridedSlice-1', roi_expect_attrs)
# check inputs
roi_expect_inputs = ['Load-0', 'out((127, 0, 0))', 'out((128, 64, 32))', 'out((1, 1, 1))']
assert validator.check_node_inputs('StridedSlice-1', roi_expect_inputs)
# check sub_graph
sub_graph = {
'StridedSlice-1': ['Load-0', 'out((127, 0, 0))', 'out((128, 64, 32))', 'out((1, 1, 1))'],
'Mul-0': ['Reshape-1', 'StridedSlice-1'],
'AllGather-2': ['Reshape-2'],
'Split-1': ['AllGather-2'],
'TupleGetItem-3': ['Split-1', 0],
'TupleGetItem-4': ['Split-1', 1],
'TupleGetItem-5': ['Split-1', 2],
'TupleGetItem-6': ['Split-1', 3],
'MakeTuple-2': ['TupleGetItem-3', 'TupleGetItem-4', 'TupleGetItem-5', 'TupleGetItem-6'],
'Concat-1': ['MakeTuple-2']
}
assert validator.check_graph_structure(sub_graph)