forked from mindspore-Ecosystem/mindspore
!31941 modify Strided_slice for master
Merge pull request !31941 from lilei/modify_stridedslice_for_master
This commit is contained in:
commit
72414ff8d9
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "frontend/parallel/ops_info/strided_slice_info.h"
|
||||
|
||||
#include <bitset>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
@ -51,6 +52,70 @@ Status StridedSliceInfo::GetMask(const std::string &mask_name, int64_t *mask_val
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
constexpr auto kStridedSliceMaxDims = 8;
|
||||
static std::vector<bool> Dec2Bin(int64_t mask) {
|
||||
auto mask_str = std::bitset<kStridedSliceMaxDims>(mask).to_string();
|
||||
int64_t dim_idx = 0;
|
||||
std::vector<bool> result(kStridedSliceMaxDims, false);
|
||||
for (int64_t i = mask_str.size() - 1; i >= 0; --i) {
|
||||
if (mask_str[i] == '1') {
|
||||
result[dim_idx] = true;
|
||||
}
|
||||
dim_idx++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void StridedSliceInfo::ComputeBeginMask(int64_t begin_mask_) {
|
||||
auto begin_mask = Dec2Bin(begin_mask_);
|
||||
for (size_t i = 0; i < begin_mask.size(); ++i) {
|
||||
if (i < kStridedSliceMaxDims && begin_mask[i]) {
|
||||
begin_[i] = strides_[i] < 0 ? SizeToLong(inputs_shape_[0][i]) - 1 : 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void StridedSliceInfo::ComputeEndMask(int64_t end_mask_) {
|
||||
auto end_mask = Dec2Bin(end_mask_);
|
||||
for (size_t j = 0; j < end_mask.size(); ++j) {
|
||||
if (j < kStridedSliceMaxDims && end_mask[j]) {
|
||||
end_[j] = strides_[j] < 0 ? -1 : SizeToLong(inputs_shape_[0][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void StridedSliceInfo::ComputeEllipsisMask(int64_t ellipsis_mask_) {
|
||||
auto ellipsis_mask = Dec2Bin(ellipsis_mask_);
|
||||
for (size_t k = 0; k < ellipsis_mask.size(); ++k) {
|
||||
if (k < kStridedSliceMaxDims && ellipsis_mask[k]) {
|
||||
begin_[k] = 0;
|
||||
end_[k] = SizeToLong(inputs_shape_[0][k]);
|
||||
strides_[k] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void StridedSliceInfo::ComputeNewAxisMask(int64_t new_axis_mask_) {
|
||||
auto new_axis_mask = Dec2Bin(new_axis_mask_);
|
||||
for (size_t l = 0; l < new_axis_mask.size(); ++l) {
|
||||
if (l < kStridedSliceMaxDims && new_axis_mask[l]) {
|
||||
begin_[l] = 0;
|
||||
end_[l] = SizeToLong(inputs_shape_[0][l]);
|
||||
strides_[l] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void StridedSliceInfo::ComputShrinkAxisMask(int64_t shrink_axis_mask_) {
|
||||
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_);
|
||||
for (size_t m = 0; m < shrink_axis_mask.size(); ++m) {
|
||||
if (m < kStridedSliceMaxDims && shrink_axis_mask[m]) {
|
||||
end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1;
|
||||
strides_[m] = end_[m] > begin_[m] ? 1 : -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status StridedSliceInfo::GetAttrs() {
|
||||
if (attrs_.size() < STRIDED_SLICE_ATTRS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of attrs small than " << STRIDED_SLICE_ATTRS_SIZE;
|
||||
|
@ -62,7 +127,6 @@ Status StridedSliceInfo::GetAttrs() {
|
|||
(GetMask(SHRINK_AXIS_MASK, &shrink_axis_mask_) != SUCCESS)) {
|
||||
return FAILED;
|
||||
}
|
||||
has_mask_ = ((ellipsis_mask_ != 0) || (new_axis_mask_ != 0) || (shrink_axis_mask_ != 0));
|
||||
|
||||
if (input_value_.size() != STRIDED_SLICE_INPUTS_SIZE) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of input value must be " << STRIDED_SLICE_INPUTS_SIZE << ", but got "
|
||||
|
@ -75,7 +139,11 @@ Status StridedSliceInfo::GetAttrs() {
|
|||
(TransValueSequeueToVector(input_value_[STRIDED_SLICE_STRIDES_INDEX], &strides_) != SUCCESS)) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
ComputeBeginMask(begin_mask_);
|
||||
ComputeEndMask(end_mask_);
|
||||
ComputeEllipsisMask(ellipsis_mask_);
|
||||
ComputeNewAxisMask(new_axis_mask_);
|
||||
ComputShrinkAxisMask(shrink_axis_mask_);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -93,12 +161,6 @@ Status StridedSliceInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
}
|
||||
|
||||
Dimensions strategy_value = stra[0];
|
||||
bool has_split = std::any_of(strategy_value.begin(), strategy_value.end(), [](int64_t v) { return v > 1; });
|
||||
if (has_split && has_mask_) {
|
||||
MS_LOG(ERROR) << name_ << ": When there is a mask, the input is not supported to be split";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (strategy_value.size() < strides_.size()) {
|
||||
MS_LOG(ERROR) << name_ << ": The size of strategy must be larger or equal to the size of strides";
|
||||
return FAILED;
|
||||
|
@ -153,6 +215,8 @@ Status StridedSliceInfo::InferTensorMap() {
|
|||
}
|
||||
|
||||
inputs_tensor_map_.push_back(tensor_map);
|
||||
if (new_axis_mask_ != 0) tensor_map.insert(tensor_map.begin() + (new_axis_mask_ - 1), -1);
|
||||
if (shrink_axis_mask_ != 0) tensor_map.erase(tensor_map.begin() + (shrink_axis_mask_ - 1));
|
||||
outputs_tensor_map_.push_back(tensor_map);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
@ -196,17 +260,11 @@ Status StridedSliceInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
|||
|
||||
std::vector<StrategyPtr> StridedSliceInfo::GenerateOpStrategies(int64_t stage_id) {
|
||||
Shape input_split(inputs_shape_[0].size(), 1);
|
||||
if (has_mask_) {
|
||||
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
|
||||
for (size_t i = 0; i < begin_.size(); ++i) {
|
||||
bool no_fully_fetch = ((begin_[i] != 0) || (end_[i] < inputs_shape_[0][i]));
|
||||
if (no_fully_fetch || (strides_[i] != 1)) {
|
||||
input_split[i] = 0;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < begin_.size(); ++i) {
|
||||
bool no_fully_fetch = ((begin_[i] != 0) || (end_[i] < inputs_shape_[0][i]));
|
||||
if (no_fully_fetch || (strides_[i] != 1)) {
|
||||
input_split[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
Shapes splittable_inputs = {input_split};
|
||||
|
||||
|
|
|
@ -40,6 +40,11 @@ class StridedSliceInfo : public OperatorInfo {
|
|||
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
|
||||
Status SetCostUnderStrategy(const StrategyPtr &) override;
|
||||
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
|
||||
void ComputeBeginMask(int64_t begin_mask_);
|
||||
void ComputeEndMask(int64_t end_mask_);
|
||||
void ComputeEllipsisMask(int64_t ellipsis_mask_);
|
||||
void ComputeNewAxisMask(int64_t new_axis_mask_);
|
||||
void ComputShrinkAxisMask(int64_t shrink_axis_mask_);
|
||||
|
||||
protected:
|
||||
Status GetAttrs() override;
|
||||
|
@ -59,7 +64,6 @@ class StridedSliceInfo : public OperatorInfo {
|
|||
int64_t ellipsis_mask_ = 0;
|
||||
int64_t new_axis_mask_ = 0;
|
||||
int64_t shrink_axis_mask_ = 0;
|
||||
bool has_mask_ = false;
|
||||
};
|
||||
|
||||
using StridedSliceInfoPtr = std::shared_ptr<StridedSliceInfo>;
|
||||
|
|
|
@ -24,7 +24,6 @@ from collections import Counter
|
|||
import numpy as np
|
||||
|
||||
from mindspore import log as logger
|
||||
from mindspore import context
|
||||
from mindspore.common.initializer import Zero
|
||||
from .. import signature as sig
|
||||
from .._utils import get_broadcast_shape, is_shape_unknown
|
||||
|
@ -3505,10 +3504,6 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
"""Initialize StridedSlice"""
|
||||
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
|
||||
|
||||
# auto parallel haven't support begin_mask and end_mask
|
||||
if context.get_auto_parallel_context("parallel_mode") in ["semi_auto_parallel", "auto_parallel"]:
|
||||
begin_mask = 0
|
||||
end_mask = 0
|
||||
validator.check_non_negative_int(begin_mask, 'begin_mask', self.name)
|
||||
validator.check_non_negative_int(end_mask, 'end_mask', self.name)
|
||||
validator.check_non_negative_int(ellipsis_mask, 'ellipsis_mask', self.name)
|
||||
|
|
|
@ -20,13 +20,18 @@ from mindspore import context, Tensor, Parameter
|
|||
from mindspore.common.api import _cell_graph_executor
|
||||
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
from parallel.utils.utils import ParallelValidator
|
||||
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, is_parameter=True, mask=0):
|
||||
def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, is_parameter=True,
|
||||
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0):
|
||||
super().__init__()
|
||||
self.mul = P.Mul().shard(strategy1)
|
||||
self.strided_slice = P.StridedSlice(begin_mask=mask).shard(strategy2)
|
||||
self.strided_slice = P.StridedSlice(begin_mask=begin_mask,
|
||||
end_mask=end_mask,
|
||||
ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask,
|
||||
shrink_axis_mask=shrink_axis_mask).shard(strategy2)
|
||||
if is_parameter:
|
||||
self.weight = Parameter(weight, "w1")
|
||||
else:
|
||||
|
@ -45,10 +50,14 @@ class Net(Cell):
|
|||
|
||||
|
||||
class Net2(Cell):
|
||||
def __init__(self, weight2, begin, end, strides, strategy1=None, strategy2=None):
|
||||
def __init__(self, weight2, begin, end, strides, strategy1=None, strategy2=None,
|
||||
begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0):
|
||||
super().__init__()
|
||||
self.mul = P.Mul().shard(strategy1)
|
||||
self.strided_slice = P.StridedSlice().shard(strategy2)
|
||||
self.strided_slice = P.StridedSlice(begin_mask=begin_mask,
|
||||
end_mask=end_mask,
|
||||
ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask,
|
||||
shrink_axis_mask=shrink_axis_mask).shard(strategy2)
|
||||
self.weight2 = Parameter(weight2, "w2")
|
||||
self.begin = begin
|
||||
self.end = end
|
||||
|
@ -60,105 +69,423 @@ class Net2(Cell):
|
|||
return out
|
||||
|
||||
|
||||
_x = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)
|
||||
_x1 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)
|
||||
_x2 = Tensor(np.ones([1, 64, 32, 32]), dtype=ms.float32)
|
||||
_x3 = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||
_w1 = Tensor(np.ones([256, 64, 32]), dtype=ms.float32)
|
||||
_w2 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)
|
||||
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
|
||||
_w3 = Tensor(np.ones([1, 64, 32, 32]), dtype=ms.float32)
|
||||
_b1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
|
||||
_b2 = Tensor(np.ones([1, 64, 32, 32]), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile_net(net):
|
||||
def compile_net(net, _x1, _b1):
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
train_net.set_auto_parallel()
|
||||
train_net.set_train()
|
||||
_cell_graph_executor.compile(train_net, _x, _b)
|
||||
_cell_graph_executor.compile(train_net, _x1, _b1)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def compile_net_utils(net: Cell, *inputs):
|
||||
net.set_auto_parallel()
|
||||
net.set_train()
|
||||
phase, _ = _cell_graph_executor.compile(net, *inputs, auto_parallel_mode=True)
|
||||
context.reset_auto_parallel_context()
|
||||
return phase
|
||||
|
||||
|
||||
def test_stridedslice_no_fully_fetch_split_error():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
strategy2 = ((2, 2, 2),)
|
||||
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_strides_no_1_split_error():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with strides no 1 split in semi auto parallel.
|
||||
Expectation: compile error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
strategy2 = ((1, 2, 2),)
|
||||
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 2), strategy1, strategy2, is_parameter=True)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
|
||||
|
||||
def test_stridedslice_mask_no_0_split_error():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((2, 2, 2), (2, 2, 2))
|
||||
strategy2 = ((1, 2, 2),)
|
||||
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True, mask=1)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_begin_size_smaller():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with begin size is smaller in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 1), (1, 4, 2))
|
||||
strategy2 = ((1, 4, 2),)
|
||||
net = Net(_w1, _w2, (0, 0), (128, 64), (1, 1), strategy1, strategy2, is_parameter=True)
|
||||
compile_net(net)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_parameter():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice of parameter in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 1), (1, 4, 2))
|
||||
strategy2 = ((1, 4, 2),)
|
||||
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True)
|
||||
compile_net(net)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_begin_mask_no_0_split_parameter():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with begin mask no 0 split in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 1), (1, 4, 2))
|
||||
strategy2 = ((1, 4, 2),)
|
||||
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True, begin_mask=1)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_end_mask_no_0_parameter():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with end mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 1), (1, 4, 2))
|
||||
strategy2 = ((1, 4, 2),)
|
||||
net = Net(_w1, _w2, (127, 0, 0), (128, 63, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
|
||||
begin_mask=1, end_mask=2)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_ellipsis_mask_no_0_parameter():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with ellipsis mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 1), (1, 4, 2))
|
||||
strategy2 = ((1, 4, 2),)
|
||||
net = Net(_w1, _w2, (127, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
|
||||
begin_mask=1, end_mask=2, ellipsis_mask=4)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_new_axis_mask_no_0_parameter():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with new axis mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 2, 1), (1, 4, 2, 1))
|
||||
strategy2 = ((1, 1, 4),)
|
||||
net = Net(_w1, _w3, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
|
||||
new_axis_mask=1)
|
||||
compile_net(net, _x2, _b2)
|
||||
|
||||
|
||||
def test_stridedslice_shrink_axis_mask_no_0_parameter():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with shrink axis mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 2), (1, 2))
|
||||
strategy2 = ((1, 4, 1),)
|
||||
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
|
||||
shrink_axis_mask=1)
|
||||
compile_net(net, _x3, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_tensor():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice of tensor in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 1), (1, 4, 2))
|
||||
strategy2 = ((1, 4, 2),)
|
||||
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False)
|
||||
compile_net(net)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_begin_mask_no_0_tensor():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with begin mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 1), (1, 4, 2))
|
||||
strategy2 = ((1, 4, 2),)
|
||||
net = Net(_w1, _w2, (127, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False, begin_mask=1)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_end_mask_no_0_tensor():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with end mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 1), (1, 4, 2))
|
||||
strategy2 = ((1, 4, 2),)
|
||||
net = Net(_w1, _w2, (0, 0, 0), (128, 63, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False, end_mask=2)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_ellipsis_mask_no_0_tensor():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with ellipsis mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 1), (1, 4, 2))
|
||||
strategy2 = ((1, 4, 2),)
|
||||
net = Net(_w1, _w2, (127, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False,
|
||||
begin_mask=1, end_mask=2, ellipsis_mask=4)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_new_axis_mask_no_0_tensor():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with new axis mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 2, 1), (1, 4, 2, 1))
|
||||
strategy2 = ((1, 1, 4),)
|
||||
net = Net(_w1, _w3, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False,
|
||||
new_axis_mask=1)
|
||||
compile_net(net, _x2, _b2)
|
||||
|
||||
|
||||
def test_stridedslice_shrink_axis_mask_no_0_tensor():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with shrink axis mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 2), (1, 2))
|
||||
strategy2 = ((1, 4, 1),)
|
||||
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False,
|
||||
shrink_axis_mask=1)
|
||||
compile_net(net, _x3, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_parameter_no_full_split():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with no full split in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 1), (1, 4, 2))
|
||||
strategy2 = ((1, 2, 2),)
|
||||
net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True)
|
||||
compile_net(net)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_output():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice of output in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 8, 1), (1, 8, 1))
|
||||
strategy2 = ((1, 8, 1),)
|
||||
net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2)
|
||||
compile_net(net)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_begin_mask_no_0_output():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with begin mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 8, 1), (1, 8, 1))
|
||||
strategy2 = ((1, 8, 1),)
|
||||
net = Net2(_w2, (61, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2, begin_mask=1)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_end_mask_no_0_output():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with end mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 8, 1), (1, 8, 1))
|
||||
strategy2 = ((1, 8, 1),)
|
||||
net = Net2(_w2, (0, 0, 0), (64, 63, 1), (1, 1, 1), strategy1, strategy2, end_mask=2)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_ellipsis_mask_no_0_output():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with ellipsis mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 8, 1), (1, 8, 1))
|
||||
strategy2 = ((1, 8, 1),)
|
||||
net = Net2(_w2, (63, 0, 0), (64, 63, 1), (1, 1, 1), strategy1, strategy2,
|
||||
begin_mask=1, end_mask=2, ellipsis_mask=4)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_new_axis_mask_no_0_output():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with new axis mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 8, 1), (1, 8, 1))
|
||||
strategy2 = ((8, 1, 1),)
|
||||
net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2, new_axis_mask=1)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_shrink_axis_mask_no_0_output():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with shrink axis mask no 0 in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 8, 1), (1, 8, 1))
|
||||
strategy2 = ((1, 8, 1),)
|
||||
net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2, shrink_axis_mask=1)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_output_no_full_split():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with no full split in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 8, 1), (1, 8, 1))
|
||||
strategy2 = ((1, 4, 1),)
|
||||
net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2)
|
||||
compile_net(net)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_no_strategy():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with no strategy in semi auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 8, 1), (1, 8, 1))
|
||||
strategy2 = None
|
||||
net = Net2(_w2, (0, 0, 0), (128, 64, 1), (1, 1, 1), strategy1, strategy2)
|
||||
compile_net(net)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_begin_mask_no_0_no_strategy():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with begin mask no 0 in auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 8, 1), (1, 8, 1))
|
||||
strategy2 = None
|
||||
net = Net2(_w2, (127, 0, 0), (128, 64, 1), (1, 1, 1), strategy1, strategy2, begin_mask=1)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_auto_parallel():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice in auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net2(_w2, (0, 0, 0), (32, 64, 1), (1, 1, 1))
|
||||
compile_net(net)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
|
||||
def test_stridedslice_begin_mask_no_0_auto_parallel():
|
||||
"""
|
||||
Feature: distribute operator stridedslice in auto parallel mode.
|
||||
Description: test stridedslice with begin mask no 0 in auto parallel.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
|
||||
net = Net2(_w2, (29, 0, 0), (32, 64, 1), (1, 1, 1), begin_mask=1)
|
||||
compile_net(net, _x1, _b1)
|
||||
|
||||
|
||||
def test_stridedslice_layout():
|
||||
"""
|
||||
Features: StridedSlice
|
||||
Description: validate layout and structure
|
||||
Expectation: No raise RuntimeError
|
||||
"""
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
|
||||
strategy1 = ((1, 4, 1), (1, 4, 2))
|
||||
strategy2 = ((1, 4, 2),)
|
||||
net = Net(_w1, _w2, (127, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True,
|
||||
begin_mask=1, end_mask=2, ellipsis_mask=4)
|
||||
phase = compile_net_utils(net, _x1, _b1)
|
||||
validator = ParallelValidator(net, phase)
|
||||
|
||||
# check layout
|
||||
features_expect_layout = ([4, 2], [-1, 1, 0], [256, 16, 16], 0, True, '')
|
||||
assert validator.check_parameter_layout('w1', features_expect_layout)
|
||||
|
||||
# check attrs
|
||||
roi_expect_attrs = {'begin_mask': 1, 'end_mask': 2, 'ellipsis_mask': 4}
|
||||
assert validator.check_node_attrs('StridedSlice-1', roi_expect_attrs)
|
||||
|
||||
# check inputs
|
||||
roi_expect_inputs = ['Load-0', 'out((127, 0, 0))', 'out((128, 64, 32))', 'out((1, 1, 1))']
|
||||
assert validator.check_node_inputs('StridedSlice-1', roi_expect_inputs)
|
||||
|
||||
# check sub_graph
|
||||
sub_graph = {
|
||||
'StridedSlice-1': ['Load-0', 'out((127, 0, 0))', 'out((128, 64, 32))', 'out((1, 1, 1))'],
|
||||
'Mul-0': ['Reshape-1', 'StridedSlice-1'],
|
||||
'AllGather-2': ['Reshape-2'],
|
||||
'Split-1': ['AllGather-2'],
|
||||
'TupleGetItem-3': ['Split-1', 0],
|
||||
'TupleGetItem-4': ['Split-1', 1],
|
||||
'TupleGetItem-5': ['Split-1', 2],
|
||||
'TupleGetItem-6': ['Split-1', 3],
|
||||
'MakeTuple-2': ['TupleGetItem-3', 'TupleGetItem-4', 'TupleGetItem-5', 'TupleGetItem-6'],
|
||||
'Concat-1': ['MakeTuple-2']
|
||||
}
|
||||
assert validator.check_graph_structure(sub_graph)
|
||||
|
|
Loading…
Reference in New Issue