forked from mindspore-Ecosystem/mindspore
!35231 Tensor getitem & setitem support dynamic shape on ascend platform
Merge pull request !35231 from huoxinyou/0526tensorslice
This commit is contained in:
commit
d740fe74c2
|
@ -14,11 +14,12 @@
|
|||
# ============================================================================
|
||||
|
||||
"""constexpr util"""
|
||||
import numpy as np
|
||||
from . import _constexpr_utils as const_utils
|
||||
from ... import functional as F
|
||||
from ... import operations as P
|
||||
from ...composite import base
|
||||
from ...operations._inner_ops import TensorCopySlices, SliceGetItem
|
||||
from ...operations._inner_ops import TensorCopySlices, SliceGetItem, DynamicBroadcastTo
|
||||
from ....common import dtype as mstype
|
||||
from ....common._register_for_tensor import tensor_operator_registry
|
||||
from ....common.tensor import Tensor, CSRTensor
|
||||
|
@ -27,6 +28,7 @@ slice_get_item = SliceGetItem()
|
|||
hyper_map = base.HyperMap()
|
||||
stack = P.Stack(axis=-1)
|
||||
copy_slice = TensorCopySlices()
|
||||
dynamic_broadcast_to = DynamicBroadcastTo()
|
||||
|
||||
|
||||
def _tensor_getitem(self, index):
|
||||
|
@ -295,8 +297,9 @@ def tensor_index_by_slice(data, slice_index):
|
|||
or isinstance(slice_get_item(slice_index, "stop"), Tensor)
|
||||
or isinstance(slice_get_item(slice_index, "step"), Tensor))
|
||||
if is_dynamic:
|
||||
return tensor_index_by_dyn_slice(data, slice_index)
|
||||
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(data_shape, slice_index)
|
||||
begin_strides, end_strides, step_strides = get_stride_info_from_slice(data, slice_index)
|
||||
else:
|
||||
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(data_shape, slice_index)
|
||||
begin_mask = 1 if slice_get_item(slice_index, "start") is None else 0
|
||||
end_mask = 1 if slice_get_item(slice_index, "stop") is None else 0
|
||||
for i in range(1, len(data_shape)):
|
||||
|
@ -307,26 +310,18 @@ def tensor_index_by_slice(data, slice_index):
|
|||
return F.strided_slice(data, begin_strides, end_strides, step_strides)
|
||||
|
||||
|
||||
def tensor_index_by_dyn_slice(data, slice_index):
|
||||
"""Tensor getitem by a slice."""
|
||||
min_data_dim, max_data_dim = 1, 8
|
||||
data_dims = data.ndim
|
||||
const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)
|
||||
def get_stride_info_from_slice(data, slice_index):
|
||||
"""get the stride info from slice index"""
|
||||
data_shape = F.dyn_shape(data)
|
||||
begin_strides, end_strides, step_strides = [], [], []
|
||||
start, stop, step = get_slice_stride(slice_index, data_shape[0])
|
||||
begin_strides.append(start)
|
||||
end_strides.append(stop)
|
||||
step_strides.append(step)
|
||||
|
||||
for index in range(1, data_dims):
|
||||
begin_strides.append(const_utils.scalar_to_tensor(0))
|
||||
end_strides.append(data_shape[index])
|
||||
step_strides.append(const_utils.scalar_to_tensor(1))
|
||||
begin_tensor = stack(begin_strides)
|
||||
end_tensor = stack(end_strides)
|
||||
step_tensor = stack(step_strides)
|
||||
return F.strided_slice(data, begin_tensor, end_tensor, step_tensor)
|
||||
return begin_tensor, end_tensor, step_tensor
|
||||
|
||||
|
||||
def tensor_index_by_number(data, number_index):
|
||||
|
@ -349,14 +344,38 @@ def _tensor_index_by_bool(data, bool_value):
|
|||
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
|
||||
|
||||
|
||||
def check_range(x, dim_size):
|
||||
tensor_x = const_utils.scalar_to_tensor(x)
|
||||
if tensor_x >= dim_size or tensor_x < -dim_size:
|
||||
return tensor_x
|
||||
tensor_x = tensor_x % dim_size
|
||||
return tensor_x
|
||||
|
||||
|
||||
def get_stride_info_from_integer(tensor_int):
|
||||
begin_strides = [tensor_int]
|
||||
end_strides = [tensor_int + const_utils.scalar_to_tensor(1)]
|
||||
step_strides = [const_utils.scalar_to_tensor(1)]
|
||||
begin_tensor = stack(begin_strides)
|
||||
end_tensor = stack(end_strides)
|
||||
step_tensor = stack(step_strides)
|
||||
return begin_tensor, end_tensor, step_tensor
|
||||
|
||||
|
||||
def _tensor_index_by_integer(data, int_index):
|
||||
"""Tensor getitem by a single integer number"""
|
||||
if data.ndim < 1 or data.ndim > 8:
|
||||
const_utils.raise_value_error("Expect Tensor to have dimension between 1 and 8.")
|
||||
|
||||
data_shape = F.shape(data)
|
||||
transformed_number = const_utils.check_range(int_index, data_shape[0])
|
||||
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(data_shape, transformed_number)
|
||||
if -1 in F.shape(data):
|
||||
data_shape = F.dyn_shape(data)
|
||||
transformed_tensor = check_range(int_index, data_shape[0])
|
||||
begin_strides, end_strides, step_strides = get_stride_info_from_integer(transformed_tensor)
|
||||
else:
|
||||
data_shape = F.shape(data)
|
||||
transformed_number = const_utils.check_range(int_index, data_shape[0])
|
||||
begin_strides, end_strides, step_strides = \
|
||||
const_utils.get_stride_info_from_integer(data_shape, transformed_number)
|
||||
shrink_axis_mask = 1
|
||||
begin_mask = 0
|
||||
end_mask = 0
|
||||
|
@ -386,9 +405,9 @@ def tensor_index_by_list(data, list_index):
|
|||
data_shape = F.shape(data)
|
||||
indexes_types = hyper_map(F.typeof, list_index)
|
||||
if const_utils.judge_indexes_types(indexes_types, mstype.int_type + (mstype.bool_,)):
|
||||
if -1 in data_shape:
|
||||
if data_shape[0] == -1 and all(isinstance(i, bool) for i in list_index):
|
||||
const_utils.raise_unimplemented_error(
|
||||
"Not supported to take the subscript of dynamic shape tensor using integer or Boolean type")
|
||||
"Not supported to take the subscript of dynamic shape tensor using Boolean type")
|
||||
tensor_index = const_utils.sequence_to_index(list_index, data_shape[0])
|
||||
if tensor_index is False:
|
||||
const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")
|
||||
|
@ -483,23 +502,8 @@ def cal_tuple_slice_mask(data_shape, tuple_index):
|
|||
return begin_mask, end_mask
|
||||
|
||||
|
||||
def _tensor_getitem_by_tuple_slice(data, tuple_index):
|
||||
"""Tensor getitem by a tuple of slice"""
|
||||
data_shape = F.shape(data)
|
||||
is_dynamic = -1 in data_shape
|
||||
for item in tuple_index:
|
||||
if isinstance(item, slice):
|
||||
is_dynamic = is_dynamic or isinstance(slice_get_item(item, "start"), Tensor) \
|
||||
or isinstance(slice_get_item(item, "stop"), Tensor) \
|
||||
or isinstance(slice_get_item(item, "step"), Tensor)
|
||||
|
||||
if not is_dynamic:
|
||||
begin_strides, end_strides, step_strides, shrink_axis_mask = const_utils.get_stride_info_from_tuple(
|
||||
data_shape, tuple_index)
|
||||
begin_mask, end_mask = cal_tuple_slice_mask(data_shape, tuple_index)
|
||||
strided_slice_op = P.StridedSlice(begin_mask, end_mask, 0, 0, shrink_axis_mask)
|
||||
return strided_slice_op(data, begin_strides, end_strides, step_strides)
|
||||
|
||||
def _get_stride_info_from_tuple(data, tuple_index):
|
||||
"""get the stride info from tuple"""
|
||||
data_shape = F.dyn_shape(data)
|
||||
begin_strides, end_strides, step_strides = [], [], []
|
||||
tuple_index_len = len(tuple_index)
|
||||
|
@ -517,10 +521,11 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index):
|
|||
step_strides.append(step)
|
||||
index_count = index_count + 1
|
||||
elif isinstance(index, int):
|
||||
begin_strides.append(const_utils.scalar_to_tensor(index))
|
||||
end_strides.append(const_utils.scalar_to_tensor(index + 1))
|
||||
int_tensor = check_range(index, dim_size)
|
||||
begin_strides.append(int_tensor)
|
||||
end_strides.append(int_tensor + const_utils.scalar_to_tensor(1))
|
||||
step_strides.append(const_utils.scalar_to_tensor(1))
|
||||
shrink_axis = shrink_axis + (1 << index_count)
|
||||
shrink_axis = shrink_axis + (2 ** index_count)
|
||||
index_count = index_count + 1
|
||||
elif index is ...:
|
||||
ellipsis_count = ellipsis_count + 1
|
||||
|
@ -536,28 +541,56 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index):
|
|||
exp_msg = const_utils.gen_exception_msg("Not supported index data type, got {}, type is {}", index,
|
||||
type(index))
|
||||
const_utils.raise_index_error(exp_msg)
|
||||
for index in range(index_count, data_dim):
|
||||
begin_strides.append(const_utils.scalar_to_tensor(0))
|
||||
end_strides.append(data_shape[index])
|
||||
step_strides.append(const_utils.scalar_to_tensor(1))
|
||||
begin_tensor = stack(begin_strides)
|
||||
end_tensor = stack(end_strides)
|
||||
step_tensor = stack(step_strides)
|
||||
return P.StridedSlice(0, 0, 0, 0, shrink_axis)(data, begin_tensor, end_tensor, step_tensor)
|
||||
strides_v = {
|
||||
'begin': begin_tensor,
|
||||
'end': end_tensor,
|
||||
'step': step_tensor
|
||||
}
|
||||
return strides_v, shrink_axis
|
||||
|
||||
|
||||
def _tensor_getitem_by_tuple_slice(data, tuple_index):
|
||||
"""Tensor getitem by a tuple of slice"""
|
||||
data_shape = F.shape(data)
|
||||
is_dynamic = -1 in data_shape
|
||||
for item in tuple_index:
|
||||
if isinstance(item, slice):
|
||||
is_dynamic = is_dynamic or isinstance(slice_get_item(item, "start"), Tensor) \
|
||||
or isinstance(slice_get_item(item, "stop"), Tensor) \
|
||||
or isinstance(slice_get_item(item, "step"), Tensor)
|
||||
|
||||
strides_v = {}
|
||||
shrink_axis_mask = 0
|
||||
if not is_dynamic:
|
||||
strides_v, shrink_axis_mask = const_utils.get_stride_info_from_tuple(
|
||||
data_shape, tuple_index)
|
||||
else:
|
||||
strides_v, shrink_axis_mask = _get_stride_info_from_tuple(
|
||||
data, tuple_index)
|
||||
begin_mask, end_mask = cal_tuple_slice_mask(data_shape, tuple_index)
|
||||
begin_v = strides_v['begin']
|
||||
end_v = strides_v['end']
|
||||
step_v = strides_v['step']
|
||||
return P.StridedSlice(begin_mask, end_mask, 0, 0, shrink_axis_mask)(data, begin_v, end_v, step_v)
|
||||
|
||||
|
||||
def _tensor_getitem_by_tuple(data, tuple_index, op_name):
|
||||
"""Tensor getitem by a tuple of mixed tensor."""
|
||||
data_shape = F.shape(data)
|
||||
data_rank = len(data_shape)
|
||||
dyn_shape = F.dyn_shape(data)
|
||||
is_dynamic = -1 in data_shape
|
||||
data_rank = len(data_shape)
|
||||
slice_is_tensor = False
|
||||
for item in tuple_index:
|
||||
if isinstance(item, slice):
|
||||
is_dynamic = isinstance(slice_get_item(item, "start"), Tensor) \
|
||||
or isinstance(slice_get_item(item, "stop"), Tensor) \
|
||||
or isinstance(slice_get_item(item, "step"), Tensor)
|
||||
if is_dynamic:
|
||||
const_utils.raise_index_error("Not supported to get a dynamic shape tensor's or using a dynamic slice")
|
||||
slice_is_tensor = isinstance(slice_get_item(item, "start"), Tensor) \
|
||||
or isinstance(slice_get_item(item, "stop"), Tensor) \
|
||||
or isinstance(slice_get_item(item, "step"), Tensor)
|
||||
if slice_is_tensor:
|
||||
const_utils.raise_index_error("Not supported when slice has tensor")
|
||||
tuple_index_len = len(tuple_index)
|
||||
tensor_indexes, slice_indexes = [], []
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
|
@ -568,6 +601,9 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
|
|||
if i in int_positions:
|
||||
int_index = const_utils.check_range(index, dim_size)
|
||||
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
|
||||
if is_dynamic:
|
||||
tensor_index = check_range(index, dyn_shape[i])
|
||||
tensor_index = F.cast(tensor_index, mstype.int64)
|
||||
tuple_index_new += (tensor_index,)
|
||||
tensor_indexes.append(tensor_index)
|
||||
tensor_positions += (i,)
|
||||
|
@ -710,7 +746,7 @@ def sequence_to_tensor(value, dtype):
|
|||
else:
|
||||
new_value = ()
|
||||
for ele in value:
|
||||
ele = ele if isinstance(ele, Tensor) else const_utils.make_tensor(ele)
|
||||
ele = ele if isinstance(ele, Tensor) else const_utils.make_tensor(ele, dtype)
|
||||
new_value += (ele,)
|
||||
value = F.stack(new_value).astype(dtype)
|
||||
return value
|
||||
|
@ -727,7 +763,13 @@ def _generate_updates_from_sequence(data, index, value, op_type):
|
|||
def _generate_updates_from_tensor(data, index, value, op_type):
|
||||
"""Generate an updates tensor from a tensor."""
|
||||
value = value.astype(data.dtype)
|
||||
updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type)
|
||||
if -1 in F.shape(data):
|
||||
data_shape = F.dyn_shape(data)
|
||||
index_shape = F.dyn_shape(index)
|
||||
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type, True)
|
||||
updates = dynamic_broadcast_to(value, updates_shape)
|
||||
return updates
|
||||
updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type, False)
|
||||
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape)
|
||||
if need_broadcast:
|
||||
return _broadcast(updates_shape, value)
|
||||
|
@ -811,6 +853,9 @@ def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
|||
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
|
||||
if tensor_dtype == const_utils.INT_:
|
||||
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
||||
if -1 in F.shape(data):
|
||||
const_utils.raise_unimplemented_error(
|
||||
"Not supported to take the subscript of dynamic shape tensor using Boolean type")
|
||||
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
||||
|
||||
|
||||
|
@ -862,6 +907,9 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
|
|||
value = _broadcast(value_shape, value)
|
||||
return copy_slice(data, value.astype(data.dtype), (start,), (stop,), (step,))
|
||||
data_shape = F.shape(data)
|
||||
if -1 in data_shape:
|
||||
const_utils.raise_unimplemented_error(
|
||||
"Not supported to take the subscript of dynamic shape tensor slice setitem")
|
||||
indices = const_utils.slice2indices(input_slice, data_shape)
|
||||
if indices is False:
|
||||
return data
|
||||
|
@ -934,6 +982,9 @@ def tensor_setitem_by_number_with_sequence(data, index, value):
|
|||
|
||||
def tensor_setitem_by_number_with_tensor(data, index, value):
|
||||
"""Assigns the tensor by number with tensor value."""
|
||||
if -1 in F.shape(data):
|
||||
index = Tensor(np.array([index]), mstype.int32)
|
||||
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value)
|
||||
data_shape = F.shape(data)
|
||||
index = const_utils.int_to_index(index, data_shape)
|
||||
value_shape = const_utils.tuple_slice(F.shape(index), None, -1)
|
||||
|
@ -945,6 +996,9 @@ def tensor_setitem_by_ellipsis_with_number(data, value):
|
|||
"""Assigns the tensor by ellipsis with number value."""
|
||||
data_shape = F.shape(data)
|
||||
data_dtype = F.dtype(data)
|
||||
if -1 in data_shape:
|
||||
value = F.fill(F.dtype(data), (), value)
|
||||
return tensor_setitem_by_ellipsis_with_tensor(data, value)
|
||||
return F.fill(data_dtype, data_shape, value)
|
||||
|
||||
|
||||
|
@ -953,6 +1007,10 @@ def tensor_setitem_by_ellipsis_with_tensor(data, value):
|
|||
data_shape = F.shape(data)
|
||||
data_dtype = F.dtype(data)
|
||||
value = value.astype(data_dtype)
|
||||
if -1 in data_shape:
|
||||
data_shape = F.dyn_shape(data)
|
||||
data = dynamic_broadcast_to(value, data_shape)
|
||||
return data
|
||||
value_shape = F.shape(value)
|
||||
source_shape = const_utils.get_source_shape(data_shape, value_shape)
|
||||
value = F.reshape(value, source_shape)
|
||||
|
@ -978,6 +1036,10 @@ def tensor_setitem_by_bool(data, index, value):
|
|||
value = const_utils.make_tensor(value, mstype.int32)
|
||||
elif isinstance(value, float):
|
||||
value = const_utils.make_tensor(value, mstype.float32)
|
||||
if -1 in data_shape and index:
|
||||
data_shape = F.dyn_shape(data)
|
||||
data = dynamic_broadcast_to(value, data_shape)
|
||||
return data
|
||||
value_shape = F.shape(value)
|
||||
source_shape = const_utils.get_source_shape(data_shape, value_shape)
|
||||
if index:
|
||||
|
|
|
@ -28,6 +28,7 @@ from ....common.tensor import Tensor
|
|||
from ....common._register_for_tensor import tensor_operator_registry
|
||||
from ....ops import _utils as op_utils
|
||||
from ...._checkparam import Validator as validator
|
||||
from ... import operations as P
|
||||
|
||||
ALL_TENSOR = 0
|
||||
NO_TENSOR = 1
|
||||
|
@ -129,6 +130,8 @@ def _deep_tensor_to_nparray(array_like):
|
|||
|
||||
@constexpr
|
||||
def check_range(x, dim_size):
|
||||
if dim_size == -1:
|
||||
return x
|
||||
if isinstance(x, int) and not isinstance(x, bool):
|
||||
if x >= dim_size or x < -dim_size:
|
||||
raise IndexError(f'index {x} is out of bounds for dimension with size {dim_size}')
|
||||
|
@ -479,13 +482,18 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
|
|||
return Tensor(np.full(updates_shape, value), dtype=data_dtype)
|
||||
|
||||
|
||||
@constexpr
|
||||
def generate_updates_shape(data_shape, index_shape, op_type):
|
||||
def generate_updates_shape(data_shape, index_shape, op_type, is_dynamic):
|
||||
"""Generate updates shape for 'tensor setitem'."""
|
||||
if op_type == SET_ITEM_BY_ONE_TENSOR:
|
||||
updates_shape = index_shape + data_shape[1:]
|
||||
if is_dynamic:
|
||||
updates_shape = P.Concat(-1)((index_shape, data_shape[1:]))
|
||||
else:
|
||||
updates_shape = index_shape + data_shape[1:]
|
||||
else:
|
||||
updates_shape = index_shape[:-1] + data_shape[index_shape[-1]:]
|
||||
if is_dynamic:
|
||||
updates_shape = P.Concat(-1)((index_shape[:-1], data_shape[index_shape[-1]:]))
|
||||
else:
|
||||
updates_shape = index_shape[:-1] + data_shape[index_shape[-1]:]
|
||||
return updates_shape
|
||||
|
||||
|
||||
|
@ -641,7 +649,12 @@ def get_stride_info_from_tuple(data_shape, tuple_index):
|
|||
begin_strides.append(0)
|
||||
end_strides.append(data_shape[index])
|
||||
step_strides.append(1)
|
||||
return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis
|
||||
strides_v = {
|
||||
'begin': tuple(begin_strides),
|
||||
'end': tuple(end_strides),
|
||||
'step': tuple(step_strides)
|
||||
}
|
||||
return strides_v, shrink_axis
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -665,6 +678,8 @@ def normalize_start(start, dim_size):
|
|||
"""
|
||||
if start is None:
|
||||
return 0
|
||||
if dim_size == -1:
|
||||
return start
|
||||
if start < 0:
|
||||
return 0 if start < -dim_size else start % dim_size
|
||||
return start if start < dim_size else dim_size
|
||||
|
@ -676,8 +691,12 @@ def normalize_stop(stop, dim_size):
|
|||
Normalize `stop` according to the number of dimensions (`dim_size`).
|
||||
If the number of dimensions is not given, return the original input directly.
|
||||
"""
|
||||
if stop is None and dim_size == -1:
|
||||
raise IndexError("Not Support stop is None when dim is dynamic")
|
||||
if stop is None:
|
||||
return dim_size
|
||||
if dim_size == -1:
|
||||
return stop
|
||||
if stop < 0:
|
||||
return 0 if stop < -dim_size else stop % dim_size
|
||||
return stop if stop < dim_size else dim_size
|
||||
|
@ -753,6 +772,8 @@ def sequence_to_index(sequence, dim_size):
|
|||
if not sequence:
|
||||
return False
|
||||
if all(isinstance(i, bool) for i in sequence):
|
||||
if dim_size == -1:
|
||||
raise IndexError("Not supported to take the subscript of dynamic shape tensor using Boolean type")
|
||||
seq_size = len(sequence)
|
||||
if seq_size != dim_size:
|
||||
raise IndexError(f'dimension is {dim_size} but corresponding boolean dimension is {seq_size}')
|
||||
|
|
|
@ -0,0 +1,312 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class NumpyGetItem():
|
||||
def __init__(self, index1, index2):
|
||||
super(NumpyGetItem, self).__init__()
|
||||
self.index1 = index1
|
||||
self.index2 = index2
|
||||
|
||||
def __call__(self, tensor1, tensor2):
|
||||
return tensor1[self.index1], tensor2[self.index2]
|
||||
|
||||
|
||||
class TensorGetItem(nn.Cell):
|
||||
def __init__(self, index1, index2):
|
||||
super(TensorGetItem, self).__init__()
|
||||
self.index1 = index1
|
||||
self.index2 = index2
|
||||
|
||||
def construct(self, tensor1, tensor2):
|
||||
return tensor1[self.index1], tensor2[self.index2]
|
||||
|
||||
|
||||
def common_func(ms_net, np_net):
|
||||
x = Tensor(shape=[8, None, 32], dtype=mindspore.float32)
|
||||
y = Tensor(shape=[None, 32, 32], dtype=mindspore.float32)
|
||||
ms_net.set_inputs(x, y)
|
||||
input_np1 = np.arange(8 * 16 * 32).reshape(8, 16, 32).astype(np.float32)
|
||||
input_np2 = np.arange(16 * 32 * 32).reshape(16, 32, 32).astype(np.float32)
|
||||
out0, out1 = ms_net(Tensor(input_np1), Tensor(input_np2))
|
||||
out_np0, out_np1 = np_net(input_np1, input_np2)
|
||||
assert np.all(out0.asnumpy() == out_np0)
|
||||
assert np.all(out1.asnumpy() == out_np1)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_getitem_int_negative():
|
||||
"""
|
||||
Feature: Test Tensor slice for dynamic shape in feed mode.
|
||||
Description: The input shape is dynamic and the tensor index is negative int.
|
||||
Expectation: Assert the result is equal the numpy result.
|
||||
"""
|
||||
index1 = -2
|
||||
index2 = -1
|
||||
ms_net = TensorGetItem(index1, index2)
|
||||
np_net = NumpyGetItem(index1, index2)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_getitem_int():
|
||||
"""
|
||||
Feature: Test Tensor slice for dynamic shape in feed mode.
|
||||
Description: The input shape is dynamic and the tensor index is int.
|
||||
Expectation: Assert the result is equal the numpy result.
|
||||
"""
|
||||
index1 = 2
|
||||
index2 = 1
|
||||
ms_net = TensorGetItem(index1, index2)
|
||||
np_net = NumpyGetItem(index1, index2)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_getitem_tuple_basic():
|
||||
"""
|
||||
Feature: Test Tensor slice for dynamic shape in feed mode.
|
||||
Description: The input shape is dynamic and the tensor index is basic tuple.
|
||||
Expectation: Assert the result is equal the numpy result.
|
||||
"""
|
||||
index1 = (1, slice(0, 1, 1), ...)
|
||||
index2 = (slice(2, None, None), 1, slice(3, 4, None))
|
||||
ms_net = TensorGetItem(index1, index2)
|
||||
np_net = NumpyGetItem(index1, index2)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_getitem_tuple_basic_neg():
|
||||
"""
|
||||
Feature: Test Tensor slice for dynamic shape in feed mode.
|
||||
Description: The input shape is dynamic and the tensor index is basic tuple(int is negative).
|
||||
Expectation: Assert the result is equal the numpy result.
|
||||
"""
|
||||
index1 = (slice(0, 1, 1), ..., -1)
|
||||
index2 = (-2, slice(2, None, None), slice(3, 4, None))
|
||||
ms_net = TensorGetItem(index1, index2)
|
||||
np_net = NumpyGetItem(index1, index2)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_getitem_tuple():
|
||||
"""
|
||||
Feature: Test Tensor slice for dynamic shape in feed mode.
|
||||
Description: The input shape is dynamic and the tensor index is tuple.
|
||||
Expectation: Assert the result is equal the numpy result.
|
||||
"""
|
||||
tensor_index = Tensor(np.array([[1, 2, 1], [0, 3, 2]]), mindspore.int32)
|
||||
index1 = (slice(2, None, None), (0, 2, 1), tensor_index)
|
||||
index2 = (-1, slice(0, 1, None), tensor_index)
|
||||
ms_net = TensorGetItem(index1, index2)
|
||||
index3 = (slice(2, None, None), (0, 2, 1), tensor_index.asnumpy())
|
||||
index4 = (-1, slice(0, 1, None), tensor_index.asnumpy())
|
||||
np_net = NumpyGetItem(index3, index4)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_getitem_bool():
|
||||
"""
|
||||
Feature: Test Tensor slice for dynamic shape in feed mode.
|
||||
Description: The input shape is dynamic and the tensor index is bool.
|
||||
Expectation: Assert the result is equal the numpy result.
|
||||
"""
|
||||
index = True
|
||||
ms_net = TensorGetItem(index, index)
|
||||
np_net = NumpyGetItem(index, index)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_getitem_none():
|
||||
"""
|
||||
Feature: Test Tensor slice for dynamic shape in feed mode.
|
||||
Description: The input shape is dynamic and the tensor index is none.
|
||||
Expectation: Assert the result is equal the numpy result.
|
||||
"""
|
||||
index = None
|
||||
ms_net = TensorGetItem(index, index)
|
||||
np_net = NumpyGetItem(index, index)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_getitem_ellipsis():
|
||||
"""
|
||||
Feature: Test Tensor slice for dynamic shape in feed mode.
|
||||
Description: The input shape is dynamic and the tensor index is ellipsis.
|
||||
Expectation: Assert the result is equal the numpy result.
|
||||
"""
|
||||
index = ...
|
||||
ms_net = TensorGetItem(index, index)
|
||||
np_net = NumpyGetItem(index, index)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_getitem_slice():
|
||||
"""
|
||||
Feature: Test Tensor slice for dynamic shape in feed mode.
|
||||
Description: The input shape is dynamic and the tensor index is slice.
|
||||
Expectation: Assert the result is equal the numpy result.
|
||||
"""
|
||||
index1 = slice(1, 5, 1)
|
||||
index2 = slice(1, None, None)
|
||||
ms_net = TensorGetItem(index1, index2)
|
||||
np_net = NumpyGetItem(index1, index2)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_getitem_slice_neg():
|
||||
"""
|
||||
Feature: Test Tensor slice for dynamic shape in feed mode.
|
||||
Description: The input shape is dynamic and the tensor index is negative slice.
|
||||
Expectation: Assert the result is equal the numpy result.
|
||||
"""
|
||||
index1 = slice(-3, -1, 1)
|
||||
index2 = slice(-1, None, None)
|
||||
ms_net = TensorGetItem(index1, index2)
|
||||
np_net = NumpyGetItem(index1, index2)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_getitem_tensor():
|
||||
"""
|
||||
Feature: Test Tensor slice for dynamic shape in feed mode.
|
||||
Description: The input shape is dynamic and the tensor index is tensor.
|
||||
Expectation: Assert the result is equal the numpy result.
|
||||
"""
|
||||
index1 = Tensor(np.array([[1, 2], [0, 3]]), mindspore.int32)
|
||||
index2 = Tensor(np.array([[1, 2]]), mindspore.int32)
|
||||
ms_net = TensorGetItem(index1, index2)
|
||||
np_net = NumpyGetItem(index1.asnumpy(), index2.asnumpy())
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_getitem_list():
|
||||
"""
|
||||
Feature: Test Tensor slice for dynamic shape in feed mode.
|
||||
Description: The input shape is dynamic and the tensor index is list.
|
||||
Expectation: Assert the result is equal the numpy result.
|
||||
"""
|
||||
index1 = [True, 2, True]
|
||||
index2 = [1, 2, 0]
|
||||
ms_net = TensorGetItem(index1, index2)
|
||||
np_net = NumpyGetItem(index1, index2)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_getitem_slice_startoversize():
|
||||
"""
|
||||
Feature: Test Tensor slice for dynamic shape in feed mode.
|
||||
Description: The input shape is dynamic and the tensor index is slice and start is over size.
|
||||
Expectation: Assert the result is equal the numpy result.
|
||||
"""
|
||||
index1 = slice(8, None, 1)
|
||||
index2 = slice(30, None, None)
|
||||
ms_net = TensorGetItem(index1, index2)
|
||||
np_net = NumpyGetItem(index1, index2)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
common_func(ms_net, np_net)
|
Loading…
Reference in New Issue