Support mixed tensor for tensor getitem or setitem.

This commit is contained in:
buxue 2020-05-24 17:38:46 +08:00
parent b94949ea99
commit 7ae289a197
14 changed files with 1020 additions and 231 deletions

View File

@ -105,7 +105,7 @@ convert_object_map = {
T.ge: multitype_ops.greater_equal,
T.is_: F.is_,
T.is_not: F.is_not,
T.contains: F.in_dict,
T.contains: multitype_ops.in_,
T.not_contains: F.not_in_dict,
# system function

View File

@ -474,6 +474,8 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
(void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
(void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
(void)py::class_<Ellipsis, Type, std::shared_ptr<Ellipsis>>(m_sub, "Ellipsis").def(py::init());
}));
const TypePtr kTypeExternal = std::make_shared<External>();

View File

@ -95,6 +95,8 @@ string = typing.String()
type_refkey = typing.RefKeyType()
tensor_type = typing.TensorType
anything_type = typing.TypeAnything
slice_type = typing.Slice
ellipsis_type = typing.Ellipsis
number_type = (int8,
int16,

View File

@ -37,6 +37,7 @@ from .logical_and_impl import logical_and
from .logical_or_impl import logical_or
from .logic_not_impl import logical_not
from .uadd_impl import uadd
from .in_impl import in_
__all__ = [
'add',
'sub',
@ -59,5 +60,6 @@ __all__ = [
'setitem',
'logical_and',
'logical_or',
'logical_not'
'logical_not',
'in_'
]

View File

@ -0,0 +1,154 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""constexpr util"""
from . import _constexpr_utils as const_utils
from ... import functional as F
from ... import operations as P
from ...composite import base
from ....common import dtype as mstype
hyper_map = base.HyperMap()
pack = P.Pack(axis=-1)
def broadcast(broadcast_shape, x):
"""Broadcast tensor to the required shape."""
if F.shape(x) == broadcast_shape:
return x
multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape)
if multiples:
return F.tile(x, multiples)
return x
def transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
"""Transform indexing tensor to the required."""
x = broadcast(broadcast_shape, x)
return broadcast(final_shape, F.reshape(x, new_shape))
def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor."""
indices = None
check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
if check_index_tensor_number:
dtype_tuple = hyper_map(F.dtype, tuple_index)
check_dtypes = const_utils.check_index_tensors_dtype(dtype_tuple, op_name)
if check_dtypes:
shape_tuple = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name)
broadcast_tensors = hyper_map(F.partial(broadcast, broadcast_shape), tuple_index)
indices = pack(broadcast_tensors)
return indices
def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
indexes_types = hyper_map(F.typeof, tuple_index)
int_positions = const_utils.get_pos_of_int_index(indexes_types)
for i in int_positions:
tuple_index = F.tuple_setitem(tuple_index, i, F.scalar_to_tensor(tuple_index[i], mstype.int32))
indexes_types = hyper_map(F.typeof, tuple_index)
tensor_positions, slice_positions, ellipsis_position = \
const_utils.separate_mixed_tensors_index(indexes_types, op_name)
tensor_indexes = []
slice_indexes = []
for i in tensor_positions:
tensor_indexes.append(tuple_index[i])
for j in slice_positions:
slice_indexes.append(tuple_index[j])
data_shape = F.shape(data)
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape,
indexes_types,
tensor_indexes_shapes,
tensor_indexes_dtypes,
slice_indexes,
op_name)
slice_number = 0
final_index_tensors = []
tuple_index_size = len(tuple_index)
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_size):
if i in tensor_positions:
transform_tensor = transform_indexing_tensor(broadcast_shape,
final_shape,
index_tensor_new_shape,
tuple_index[i])
final_index_tensors.append(transform_tensor)
if i in slice_positions:
slice_tensor = const_utils.convert_slice_to_tensor(slice_number,
final_shape,
indexes_shapes_info,
op_name)
final_index_tensors.append(slice_tensor)
slice_number += 1
if i == ellipsis_position:
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number,
ellipsis_occupied_dims,
final_shape,
indexes_shapes_info,
op_name)
for ele in ellipsis_tensors:
final_index_tensors.append(ele)
slice_number += ellipsis_occupied_dims
indices = pack(final_index_tensors)
return indices
def generate_updates_from_scalar(data, indices, value, op_type):
"""Generate an updates tensor from a scalar."""
data_shape = F.shape(data)
indices_shape = F.shape(indices)
data_dtype = F.dtype(data)
return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)
def generate_updates_from_tuple(data, index, value, op_type):
"""Generate an updates tensor from a tuple."""
value_types = hyper_map(F.typeof, value)
data_dtype = F.dtype(data)
value_elements_type = const_utils.check_value_elements(data_dtype, value_types)
if value_elements_type == const_utils.ALL_TENSOR:
value_shapes = hyper_map(F.shape, value)
shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM)
if shapes_same:
value = F.pack(value)
return generate_updates_from_tensor(data, index, value, op_type)
data_shape = F.shape(data)
index_shape = F.shape(index)
return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
def generate_updates_from_tensor(data, index, value, op_type):
"""Generate an updates tensor from a tensor."""
data_shape = F.shape(data)
index_shape = F.shape(index)
value_shape = F.shape(value)
data_dtype = F.dtype(data)
value_dtype = F.dtype(value)
updates_shape = value_shape
check_dtype_same = const_utils.check_tensors_dtype_same(data_dtype, value_dtype, const_utils.TENSOR_SETITEM)
if check_dtype_same:
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type)
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape)
if need_broadcast:
return broadcast(updates_shape, value)
return value

View File

@ -15,19 +15,15 @@
"""constexpr util"""
from functools import reduce
import numpy as np
from ...primitive import constexpr
from ....common.tensor import Tensor
from ....common import dtype as mstype
from ...._extends.utils import Slice, Ellipsis_
from ....ops import _utils as op_utils
from ...composite import base
from .... import log as logger
from ... import functional as F
from ... import operations as P
hyper_map = base.HyperMap()
pack = P.Pack(axis=-1)
import numpy as np
from ...primitive import constexpr
from .... import log as logger
from ...._extends.utils import Slice, Ellipsis_
from ....common import dtype as mstype
from ....common.tensor import Tensor
from ....ops import _utils as op_utils
ALL_TENSOR = 0
NO_TENSOR = 1
@ -264,7 +260,7 @@ def tuple_index_elements_type(types, op_name):
return ALL_TENSOR
if tensors_number == 0:
return NO_TENSOR
raise IndexError(f"For '{op_name}', the index does not support mixed tensor.")
return CONTAIN_TENSOR
@constexpr
@ -279,12 +275,12 @@ def check_value_elements(data_dtype, types):
tensors_number += 1
else:
raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' "
f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.")
f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.")
elif mstype.issubclass_(ele, data_dtype):
scalars_number += 1
else:
raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in "
f"value tuple is not consistent with origin tensor data type '{data_dtype}'.")
f"value tuple is not consistent with assigned tensor data type '{data_dtype}'.")
if tensors_number == len(types):
return ALL_TENSOR
if scalars_number == len(types):
@ -299,51 +295,46 @@ def get_index_tensor_dtype(dtype):
return INT_
if dtype == mstype.bool_:
return BOOL_
raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
raise IndexError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
@constexpr
def check_index_tensors_dtype(dtypes, op_name):
"""Check a tuple of tensor data type."""
if op_name == TENSOR_GETITEM:
valid_dtypes = (mstype.int32, mstype.int64)
elif op_name == TENSOR_SETITEM:
valid_dtypes = (mstype.int32,)
else:
raise ValueError("Unsupported operation.")
for ele in dtypes:
if ele in valid_dtypes and ele == dtypes[0]:
continue
raise TypeError(f"For '{op_name}', the index tensors data type must be same, "
f"and should be one of the following: {valid_dtypes}, but got {dtypes}.")
if not ele == mstype.int32:
raise IndexError(f"For '{op_name}', the all index tensor "
f"data types should be mstype.int32, but got {dtypes}.")
return True
@constexpr
def check_tensor_dtype_valid(dtype, valid_dtypes):
def check_index_tensor_dtype(dtype, op_name):
"""Check a tensor data type."""
if dtype in valid_dtypes:
if dtype == mstype.int32:
return True
raise TypeError(f"The index tensor data type must be one of "
f"the following: {valid_dtypes}, but got {dtype}.")
raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.")
@constexpr
def check_tensors_dtype_same(x_dtype, y_dtype, op_name):
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
"""Check tensors data type same."""
if x_dtype == y_dtype:
if value_dtype == data_dtype:
return True
raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' "
f"is not consistent with origin tensor data type {x_dtype}.")
raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' "
f"is not consistent with assigned tensor data type {data_dtype}.")
@constexpr
def broadcast_shapes(shapes, op_name):
"""Broadcasts a tuple of tensor."""
def generate_broadcast_shape(shapes, op_name):
"""Generate broadcast shape for a tuple of shape."""
broadcast_shape = shapes[0]
for i, shape in enumerate(shapes):
logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
try:
broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
except ValueError as ex:
raise IndexError(ex)
return tuple(broadcast_shape)
@ -366,14 +357,82 @@ def check_two_shapes_need_broadcast(shape_x, shape_y):
@constexpr
def compute_multiples(origin_shape, broadcast_shape):
"""Compute multiples between broadcast_shape with origin_shape."""
"""Compute multiples between origin shape with broadcast shape."""
len_gap = len(broadcast_shape) - len(origin_shape)
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))
def tile(broadcast_shape, x):
multiples = compute_multiples(F.shape(x), broadcast_shape)
return F.tile(x, multiples)
@constexpr
def compute_new_shape(origin_shape, indexes_shapes_info):
"""Compute new shape between origin shape with final shape."""
new_shape = []
for i in indexes_shapes_info:
if i == origin_shape:
new_shape.extend(origin_shape)
else:
new_shape.append(1)
return tuple(new_shape)
@constexpr
def convert_ellipsis_to_tensors(slice_number,
ellipsis_occupied_dims,
final_shape,
indexes_shapes_info,
op_name):
"""Convert an ellipsis to a list of tensor."""
tensor_list = []
dims_dealt_count = 0
while dims_dealt_count < ellipsis_occupied_dims:
shape = []
slice_count = 0
array = None
for ele in indexes_shapes_info:
if isinstance(ele, list):
if slice_count == slice_number:
array = np.array(ele, np.int32)
shape.append(len(ele))
else:
shape.append(1)
slice_count += 1
if isinstance(ele, tuple):
shape.extend([1] * len(ele))
if array is None:
raise ValueError(f"For '{op_name}', generate tensors from ellipsis failed.")
array = np.reshape(array, shape)
reps = compute_multiples(shape, final_shape)
tensor = Tensor(np.tile(array, reps))
tensor_list.append(tensor)
slice_number += 1
dims_dealt_count += 1
return tensor_list
@constexpr
def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name):
"""Convert a slice to a tensor."""
shape = []
count = 0
array = None
for ele in indexes_shapes_info:
if isinstance(ele, list):
if count == slice_number:
array = np.array(ele, np.int32)
shape.append(len(ele))
else:
# When the slice is not the slice looking for, the shape is filled with 1.
shape.append(1)
count += 1
elif isinstance(ele, tuple):
shape.extend([1] * len(ele))
else:
shape.append(1)
if array is None:
raise ValueError(f"For '{op_name}', generate tensor from 'slice' failed.")
array = np.reshape(array, shape)
reps = compute_multiples(shape, final_shape)
tensor = Tensor(np.tile(array, reps))
return tensor
@constexpr
@ -381,8 +440,8 @@ def check_shapes_same(value_shapes, op_name):
"""Check if the shapes in the tuple are consistent."""
for i, shape in enumerate(value_shapes):
if shape != value_shapes[0]:
raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple "
f"is not same as the first tensor shape.")
raise ValueError(f"For '{op_name}', the {i}th tensor shape in "
f"value tuple is not same as the first tensor shape.")
return True
@ -396,7 +455,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
if isinstance(value, mstype.dtype_to_pytype(data_dtype)):
return Tensor(np.full(updates_shape, value), dtype=data_dtype)
raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'"
f" is not consistent with tensor data type {data_dtype}.")
f" is not consistent with the assigned tensor data type {data_dtype}.")
@constexpr
@ -404,8 +463,8 @@ def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value
"""Convert a tuple of scalar to a tensor."""
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
if len(value) != updates_shape[-1]:
raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} in the updates tuple "
f"does not meet the requirements: {updates_shape[-1]}.")
raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} "
f"in the updates tuple does not meet the requirements: {updates_shape[-1]}.")
array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype))
reps = compute_multiples(updates_shape[-1:], updates_shape)
return Tensor(np.tile(array, reps))
@ -430,58 +489,145 @@ def check_number_of_index_tensor(data_shape, tuple_len, op_name):
f"is greater than the dimension {len(data_shape)} of the operated tensor.")
def generate_indeices_from_tuple_of_tensor(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor."""
indices = None
check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
if check_index_tensor_number:
dtype_tuple = hyper_map(F.dtype, tuple_index)
check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name)
if check_dtypes:
shape_tuple = hyper_map(F.shape, tuple_index)
broadcast_shape = broadcast_shapes(shape_tuple, op_name)
broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index)
indices = pack(broadcast_tensors)
return indices
@constexpr
def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
indexes_types,
tensor_indexes_shapes,
tensor_indexes_dtypes,
slice_indexes,
op_name):
"""
Generate index info which contain broadcast shape, final shape,
indexes shapes info, ellipsis size from a tuple of mixed tensors.
"""
check_index_tensors_dtype(tensor_indexes_dtypes, op_name)
data_rank = len(data_shape)
indexes_size = len(indexes_types)
if indexes_size > data_rank:
raise IndexError(f"For '{op_name}', the number {indexes_size} of index elements "
f"is greater than the dimension {len(data_shape)} of the operated tensor.")
indexes_info = {}
index_tensors_info = {}
ellipsis_num = 0
ellipsis_occupied_dims = 0
tensor_count = 0
slice_count = 0
for i, ele_type in enumerate(indexes_types):
if ellipsis_num == 0:
pos = i
else:
pos = i + ellipsis_occupied_dims - 1
if isinstance(ele_type, mstype.tensor_type):
indexes_info[pos] = tensor_indexes_shapes[tensor_count]
index_tensors_info[pos] = tensor_indexes_shapes[tensor_count]
tensor_count += 1
elif isinstance(ele_type, mstype.slice_type):
slice_obj = slice(slice_indexes[slice_count].start,
slice_indexes[slice_count].end,
slice_indexes[slice_count].step)
# Use list to represent slicing result.
indexes_info[pos] = list(range(data_shape[pos]))[slice_obj]
slice_count += 1
elif isinstance(ele_type, mstype.ellipsis_type):
if ellipsis_num != 0:
raise IndexError(f"For '{op_name}', the index could only contain one ellipsis.")
ellipsis_occupied_dims = data_rank - indexes_size + 1
for j in range(pos, pos + ellipsis_occupied_dims):
# Use list to represent slicing result.
indexes_info[j] = list(range(data_shape[j]))
ellipsis_num += 1
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.")
broadcast_shape, final_shape, indexes_shapes_info = \
_derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name)
return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims
def generate_updates_from_scalar(data, indices, value, op_type):
"""Generate an updates tensor from a scalar."""
data_shape = F.shape(data)
indices_shape = F.shape(indices)
data_dtype = F.dtype(data)
return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)
def _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key: list):
"""Determine whether the tensor in the index appears continuously."""
for i in range(len(index_tensor_info_key) - 1):
if index_tensor_info_key[i + 1] != index_tensor_info_key[i] + 1:
return False
return True
def generate_updates_from_tuple(data, index, value, op_type):
"""Generate an updates tensor from a tuple."""
value_types = hyper_map(F.typeof, value)
data_dtype = F.dtype(data)
value_elements_type = check_value_elements(data_dtype, value_types)
if value_elements_type == ALL_TENSOR:
value_shapes = hyper_map(F.shape, value)
shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM)
if shapes_same:
value = F.pack(value)
return generate_updates_from_tensor(data, index, value, op_type)
data_shape = F.shape(data)
index_shape = F.shape(index)
return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name):
"""Derive the resulting shape information from the a tuple index of mixed tensors."""
index_tensor_info_key = list(index_tensors_info.keys())
index_tensor_info_value = list(index_tensors_info.values())
broadcast_shape = generate_broadcast_shape(index_tensor_info_value, op_name)
final_shape = []
indexes_shapes_info = []
mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key)
if mixed_tensors_continuous:
tensor_shape_dealt = False
for ele in indexes_info.values():
if isinstance(ele, list):
final_shape.append(len(ele))
indexes_shapes_info.append(ele)
elif isinstance(ele, tuple):
if not tensor_shape_dealt:
final_shape.extend(broadcast_shape)
indexes_shapes_info.append(broadcast_shape)
tensor_shape_dealt = True
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.")
else:
final_shape.extend(broadcast_shape)
indexes_shapes_info.append(broadcast_shape)
for ele in indexes_info.values():
if isinstance(ele, list):
final_shape.append(len(ele))
indexes_shapes_info.append(ele)
elif isinstance(ele, tuple):
continue
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.")
return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)
def generate_updates_from_tensor(data, index, value, op_type):
"""Generate an updates tensor from a tensor."""
data_shape = F.shape(data)
index_shape = F.shape(index)
value_shape = F.shape(value)
data_dtype = F.dtype(data)
value_dtype = F.dtype(value)
updates_shape = value_shape
check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM)
if check_dtype_same:
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape)
if need_broadcast:
return tile(updates_shape, value)
return value
@constexpr
def get_pos_of_int_index(indexes_types):
"""Get int index positions from the mixed tensors index which contains int, tensor, slice, and ellipsis."""
int_positions = []
for i, ele_type in enumerate(indexes_types):
if ele_type == mstype.int32:
int_positions.append(i)
return int_positions
@constexpr
def separate_mixed_tensors_index(indexes_types, op_name):
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
tensor_positions = []
slice_positions = []
ellipsis_position = None
for i, ele_type in enumerate(indexes_types):
if isinstance(ele_type, mstype.tensor_type):
tensor_positions.append(i)
elif isinstance(ele_type, mstype.slice_type):
slice_positions.append(i)
elif isinstance(ele_type, mstype.ellipsis_type):
ellipsis_position = i
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int32', 'Slice', 'Ellipsis', but got {ele_type}.")
return tensor_positions, slice_positions, ellipsis_position
@constexpr
def scalar_in_sequence(x, y):
"""Determine whether the scalar in the sequence."""
if x is None:
raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, "
"but the scalar is not.")
if y is None:
raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, "
"but the sequence is not.")
if x in y:
return True
return False

View File

@ -14,11 +14,11 @@
# ============================================================================
"""Implementation for getitem."""
from . import _utils as multi_utils
from ..import base
from . import _compile_utils as compile_utils
from . import _constexpr_utils as const_utils
from .. import base
from ... import functional as F
from ....common import dtype as mstype
getitem = base.MultitypeFuncGraph('getitem')
"""
@ -227,7 +227,8 @@ def _tensor_getitem_by_tensor(data, tensor_index):
Outputs:
Tensor, element type is same as the element type of data.
"""
check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64))
check_dtypes = const_utils.check_index_tensor_dtype(F.dtype(tensor_index),
const_utils.TENSOR_GETITEM)
result = None
if check_dtypes:
result = F.gather(data, tensor_index, 0)
@ -246,14 +247,13 @@ def _tensor_getitem_by_tuple(data, tuple_index):
Outputs:
Tensor, element type is same as the element type of data.
"""
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_GETITEM)
result = None
if index_elements_type == multi_utils.NO_TENSOR:
result = _tensor_slice(data, tuple_index)
if index_elements_type == multi_utils.ALL_TENSOR:
result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return result
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_slice(data, tuple_index)
if index_elements_type == const_utils.ALL_TENSOR:
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)
@getitem.register("Tensor", "Ellipsis")
@ -273,6 +273,17 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
"""Tensor getitem by a tuple of tensor."""
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM)
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result
def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
"""Tensor getitem by a tuple of mixed tensor."""
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result

View File

@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""in_impl"""
from . import _constexpr_utils as const_utils
from ... import functional as F
from ...composite import base
in_ = base.MultitypeFuncGraph("in")
"""
in_ is a metafuncgraph object which will determine if a in b
using ".register" decorator
"""
@in_.register("Number", "Tuple")
def _number_in_tuple(x, y):
"""
Determine if a number in tuple.
Args:
x (Number): x
y (tuple): y
Returns:
bool, if x in y return true, x not in y return false.
"""
return const_utils.scalar_in_sequence(x, y)
@in_.register("Number", "List")
def _number_in_list(x, y):
"""
Determine if a number in list.
Args:
x (Number): x
y (list): y
Returns:
bool, if x in y return true, x not in y return false.
"""
return const_utils.scalar_in_sequence(x, y)
@in_.register("String", "Tuple")
def _string_in_tuple(x, y):
"""
Determine if a str in a tuple.
Args:
x (str): x
y (tuple): y
Returns:
bool, if x in y return true, x not in y return false.
"""
return const_utils.scalar_in_sequence(x, y)
@in_.register("String", "List")
def _string_in_list(x, y):
"""
Determine if a str in a list.
Args:
x (str): x
y (list): y
Returns:
bool, if x in y return true, x not in y return false.
"""
return const_utils.scalar_in_sequence(x, y)
@in_.register("String", "Dictionary")
def _str_in_dict(x, y):
"""
Determine if a str in dict.
Args:
x: str
y: dict
Returns:
bool, if x in y return true, x not in y return false.
"""
return F.in_dict(x, y)

View File

@ -15,10 +15,11 @@
"""Implementation for setitem."""
from . import _compile_utils as compile_utils
from . import _constexpr_utils as const_utils
from ... import functional as F
from ...composite import base
from ....common import dtype as mstype
from ... import functional as F
from . import _utils as multi_utils
setitem = base.MultitypeFuncGraph('setitem')
@ -139,8 +140,8 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == multi_utils.INT_:
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == const_utils.INT_:
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
@ -166,8 +167,8 @@ def _tensor_setitem_by_tensor_with_number(data, index, value):
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == multi_utils.BOOL_:
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == const_utils.BOOL_:
return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)
@ -190,17 +191,24 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
result = None
if index_elements_type == multi_utils.NO_TENSOR:
result = _tensor_assgin_number(data, tuple_index, value)
if index_elements_type == multi_utils.ALL_TENSOR:
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
updates = multi_utils.generate_updates_from_scalar(data, indices, value,
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
result = F.scatter_nd_update(data, indices, updates)
return result
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_assgin_number(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_scalar(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)
@setitem.register("Tensor", "Tuple", "Tensor")
@ -221,17 +229,24 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
result = None
if index_elements_type == multi_utils.NO_TENSOR:
result = _tensor_assgin_tensor(data, tuple_index, value)
if index_elements_type == multi_utils.ALL_TENSOR:
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
updates = multi_utils.generate_updates_from_tensor(data, indices, value,
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
result = F.scatter_nd_update(data, indices, updates)
return result
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_assgin_tensor(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_tensor(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)
@setitem.register("Tensor", "Tuple", "Tuple")
@ -253,15 +268,22 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
result = None
if index_elements_type == multi_utils.ALL_TENSOR:
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
updates = multi_utils.generate_updates_from_tuple(data, indices, value,
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
result = F.scatter_nd_update(data, indices, updates)
return result
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_tuple(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)
@setitem.register("Tensor", "Tensor", "Tuple")
@ -278,7 +300,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64))
check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
result = None
if check_dtype:
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
@ -331,14 +353,14 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value):
def _tensor_assgin_number(data, input_slice, value):
"""Givens a scalar assign to tensor by slice"""
check_result = multi_utils.check_tensor_setitem_index(input_slice)
check_result = const_utils.check_tensor_setitem_index(input_slice)
result = None
if check_result:
data_shape = F.shape(data)
indices = multi_utils.slice2indices(input_slice, data_shape)
is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = multi_utils.integer_to_indices(input_slice, data_shape)
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
return result
@ -347,7 +369,7 @@ def _tensor_assgin_number(data, input_slice, value):
def _tensor_setitem_with_int_v1(data, index, value):
"""Syntax: A[1] = 3"""
data_shape = F.shape(data)
indices = multi_utils.integer_to_indices(index, data_shape)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value)
@ -355,7 +377,7 @@ def _tensor_setitem_with_int_v1(data, index, value):
def _tensor_setitem_with_int_v2(data, index, value):
"""Syntax: A[1] = Tensor"""
data_shape = F.shape(data)
indices = multi_utils.integer_to_indices(index, data_shape)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value)
@ -376,7 +398,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
data_size = F.size(data)
value_shape = F.shape(value)
value_size = F.size(value)
check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
if check_result:
if data_size == value_size:
result = F.reshape(value, data_shape)
@ -391,13 +413,13 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
def _tensor_assgin_tensor(data, input_slice, value):
"""Assigns a tensor value to the tensor by slice."""
result = None
check_result = multi_utils.check_tensor_setitem_index(input_slice)
check_result = const_utils.check_tensor_setitem_index(input_slice)
if check_result:
data_shape = F.shape(data)
indices = multi_utils.slice2indices(input_slice, data_shape)
is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = multi_utils.integer_to_indices(input_slice, data_shape)
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
return result
@ -407,7 +429,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = multi_utils.check_indices(indices_size, index)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
@ -415,7 +437,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
value_fill = None
value_size = F.size(value)
value_size = multi_utils.check_indices_value_size(indices_size, value_size)
value_size = const_utils.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
@ -432,7 +454,7 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = multi_utils.check_indices(indices_size, index)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
@ -445,16 +467,16 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Set a tensor item by a tensor with a tuple."""
updates = multi_utils.generate_updates_from_tuple(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR)
updates = compile_utils.generate_updates_from_tuple(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
result = F.scatter_update(data, index, updates)
return result
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar."""
updates = multi_utils.generate_updates_from_scalar(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR)
updates = compile_utils.generate_updates_from_scalar(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)
@ -462,7 +484,7 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
"""Set a tensor item by a bool tensor with a scalar."""
index_shape = F.shape(index)
shape = F.shape(data)
shape = multi_utils.check_equal(
shape = const_utils.check_equal(
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
dtype = F.dtype(data)
u = F.fill(dtype, shape, value)
@ -471,8 +493,8 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
"""Set a tensor item by a int tensor with a tensor."""
updates = multi_utils.generate_updates_from_tensor(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR)
updates = compile_utils.generate_updates_from_tensor(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)
@ -480,10 +502,10 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
"""Set a tensor item by a bool tensor with a tensor."""
index_shape = F.shape(index)
data_shape = F.shape(data)
data_shape = multi_utils.check_equal(data_shape, index_shape,
data_shape = const_utils.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value)
size = multi_utils.check_equal(1, size,
size = const_utils.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
u_cast = F.cast(value, dtype)

View File

@ -1425,7 +1425,6 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name):
validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
validator.check_integer("len of input_x", len(x_shape), 1, Rel.GT, prim_name)
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT, prim_name)
rank_base = len(x_shape[0])
N = len(x_shape)
out_shape = x_shape[0]

View File

@ -33,9 +33,4 @@ class IdentityEC(IExectorComponent):
keyword.desc_inputs: self.inputs[keyword.desc_inputs],
keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs])
}
print("buxue------------------------------------------------")
print("inputs")
print(ret[keyword.desc_inputs])
print("outputs")
print(ret[keyword.result])
return ret

View File

@ -19,9 +19,9 @@ import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
from tests.ut.python.ut_filter import non_graph_engine
from tests.mindspore_test_framework.mindspore_test import mindspore_test
from tests.mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
@ -133,7 +133,7 @@ def test_list_append_2():
class ListOperate(nn.Cell):
def __init__(self,):
def __init__(self, ):
super(ListOperate, self).__init__()
def construct(self, t, l):
@ -152,6 +152,20 @@ class ListOperate(nn.Cell):
return x
class InListNet(nn.Cell):
def __init__(self, ):
super(InListNet, self).__init__()
self.list_ = [1, 2, 3, 4, 5, "ok"]
def construct(self, x):
ret = x
if 2 in self.list_:
ret = x + x
if "ok" in self.list_:
ret = x - x
return ret
class AxisListNet(nn.Cell):
def __init__(self):
super(AxisListNet, self).__init__()
@ -204,10 +218,15 @@ test_case_ops = [
('AxisListDefault', {
'block': AxisListDefaultNet(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
('InList', {
'block': InListNet(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
]
test_case_lists = [test_case_ops]
test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
# use -k to select certain testcast
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm

View File

@ -19,9 +19,9 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
from tests.ut.python.ut_filter import non_graph_engine
from tests.mindspore_test_framework.mindspore_test import mindspore_test
from tests.mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
@ -52,6 +52,20 @@ class NestTupleGraphNet(nn.Cell):
return self.layers[0][1](x)
class InTupleNet(nn.Cell):
def __init__(self, ):
super(InTupleNet, self).__init__()
self.tuple_ = (1, 2, 3, 4, 5, "ok")
def construct(self, x):
ret = x
if 2 in self.tuple_:
ret = x + x
if "ok" in self.tuple_:
ret = x - x
return ret
test_case_ops = [
('TupleGraph', {
'block': TupleGraphNet(),
@ -59,6 +73,9 @@ test_case_ops = [
('NestTupleGraph', {
'block': NestTupleGraphNet(),
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
('InTuple', {
'block': InTupleNet(),
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]})
]
test_case_lists = [test_case_ops]

View File

@ -176,12 +176,134 @@ class TensorGetItemByThreeTensors(Cell):
return ret
class TensorGetItemByMixedTensors(Cell):
class TensorGetItemByMixedTensors_0(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors, self).__init__()
super(TensorGetItemByMixedTensors_0, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32))
def construct(self, tensor, index_0, index_1):
ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const
return ret
class TensorGetItemByMixedTensors_1(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_1, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32))
def construct(self, tensor, index_0, index_1):
ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const
return ret
class TensorGetItemByMixedTensors_2(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_2, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32))
def construct(self, tensor, index_0, index_1):
ret = tensor[0, index_0, index_1, ..., 3] + self.const
return ret
class TensorGetItemByMixedTensors_3(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_3, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32))
def construct(self, tensor, index_0, index_1):
ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const
return ret
class TensorGetItemByMixedTensors_4(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_4, self).__init__()
self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32))
def construct(self, tensor, index_0, index_1, index_2):
ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const
return ret
class TensorGetItemByMixedTensors_5(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_5, self).__init__()
self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32))
def construct(self, tensor, index_0, index_1, index_2):
ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const
return ret
class TensorGetItemByMixedTensors_6(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_6, self).__init__()
self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))
def construct(self, tensor, index_0, index_1, index_2):
ret = tensor[..., index_0, index_1, index_2, 3] + self.const
return ret
class TensorSetItemByMixedTensors_0(Cell):
def __init__(self, value):
super(TensorSetItemByMixedTensors_0, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)),
mstype.float32),
name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value
ret = self.param + self.const
return ret
class TensorSetItemByMixedTensors_1(Cell):
def __init__(self, value):
super(TensorSetItemByMixedTensors_1, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value
ret = self.param + self.const
return ret
class TensorSetItemByMixedTensors_2(Cell):
def __init__(self, value):
super(TensorSetItemByMixedTensors_2, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
self.param[..., index_0, index_1, index_2, 3] = self.value
ret = self.param + self.const
return ret
class TensorGetItemByMixedTensorsTypeError(Cell):
def __init__(self):
super(TensorGetItemByMixedTensorsTypeError, self).__init__()
def construct(self, x, index_0, index_1):
ret = x[index_0, index_1, 0:6]
ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]]
return ret
class TensorGetItemByMixedTensorsNumberError(Cell):
def __init__(self):
super(TensorGetItemByMixedTensorsNumberError, self).__init__()
def construct(self, x, index_0, index_1):
ret = x[index_0, index_1, 0:3, ..., index_1, index_0]
return ret
@ -189,7 +311,7 @@ class TensorSetItemByOneTensorWithNumber(Cell):
def __init__(self, value):
super(TensorSetItemByOneTensorWithNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value
def construct(self, index):
@ -202,7 +324,7 @@ class TensorSetItemByOneTensorWithTensor(Cell):
def __init__(self):
super(TensorSetItemByOneTensorWithTensor, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index, value):
self.param[index] = value
@ -214,7 +336,7 @@ class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
def __init__(self, value):
super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value
def construct(self, index):
@ -227,7 +349,7 @@ class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
def __init__(self):
super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
self.const = Tensor(np.ones((6, 3, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*3*8).reshape((6, 3, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 3 * 8).reshape((6, 3, 8)), mstype.float32), name="x")
def construct(self, index, value_0, value_1, value_2):
self.param[index] = (value_0, value_1, value_2)
@ -239,7 +361,7 @@ class TensorSetItemByTensorsWithNumber(Cell):
def __init__(self, value):
super(TensorSetItemByTensorsWithNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
@ -252,7 +374,7 @@ class TensorSetItemByTensorsWithTensor(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTensor, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index_0, index_1, index_2, value):
self.param[index_0, index_1, index_2] = value
@ -264,7 +386,7 @@ class TensorSetItemByTensorsWithTensorNumberError(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index_0, index_1, index_2, index_3, value):
self.param[index_0, index_1, index_2, index_3] = value
@ -276,7 +398,7 @@ class TensorSetItemByTensorsWithTupleOfNumber(Cell):
def __init__(self, value):
super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
@ -289,7 +411,7 @@ class TensorSetItemByTensorsWithTupleOfTensor(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index_0, index_1, index_2, value_0, value_1, value_2):
self.param[index_0, index_1, index_2] = (value_0, value_1, value_2)
@ -301,7 +423,7 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
def construct(self, index_0, index_1, index_2, value_0, value_1):
self.param[index_0, index_1, index_2] = (value_0, value_1)
@ -313,7 +435,7 @@ class TensorSetItemByMixedTensors(Cell):
def __init__(self):
super(TensorSetItemByMixedTensors, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = 99.0
def construct(self, index_0, index_1):
@ -538,11 +660,11 @@ def test_tensor_assign_bool_index():
net1(Ta, Tb, Tc, u_tensor)
with pytest.raises(ValueError):
net1(Ta, Td, Tc, u_tensor)
with pytest.raises(TypeError):
with pytest.raises(IndexError):
net1(Ta, u_tensor, Tc, u_tensor)
with pytest.raises(ValueError):
net1(Ta, Tb, Td, u_tensor)
with pytest.raises(TypeError):
with pytest.raises(IndexError):
net1(Ta, Tb, Ta, u_tensor)
with pytest.raises(ValueError):
net1(Ta, Tb, Tc, u_tensor_error)
@ -620,22 +742,67 @@ test_cases = [
}),
('TensorGetItemByOneTensor', {
'block': TensorGetItemByOneTensor(),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)],
}),
('TensorGetItemByTwoTensors', {
'block': TensorGetItemByTwoTensors(),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByThreeTensors', {
'block': TensorGetItemByThreeTensors(),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_0', {
'block': TensorGetItemByMixedTensors_0(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_1', {
'block': TensorGetItemByMixedTensors_1(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_2', {
'block': TensorGetItemByMixedTensors_2(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_3', {
'block': TensorGetItemByMixedTensors_3(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_4', {
'block': TensorGetItemByMixedTensors_4(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_5', {
'block': TensorGetItemByMixedTensors_5(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_6', {
'block': TensorGetItemByMixedTensors_6(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithNumber', {
'block': TensorSetItemByOneTensorWithNumber(value=0.0),
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
@ -683,46 +850,143 @@ test_cases = [
Tensor(np.zeros((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)) * 2, mstype.float32)],
})
}),
('TensorSetItemByMixedTensorsWithNumber_0', {
'block': TensorSetItemByMixedTensors_0(value=88.0),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensor_0', {
'block': TensorSetItemByMixedTensors_0(value=Tensor(np.ones((4, 5, 3, 9), np.float32))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfNumber_0', {
'block': TensorSetItemByMixedTensors_0(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfTensor_0', {
'block': TensorSetItemByMixedTensors_0(value=(Tensor(np.ones((4, 5, 3, 9), np.float32)),
Tensor(np.zeros((4, 5, 3, 9), np.float32)),
Tensor(np.ones((4, 5, 3, 9), np.float32)))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithNumber_1', {
'block': TensorSetItemByMixedTensors_1(value=88.0),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensor_1', {
'block': TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfNumber_1', {
'block': TensorSetItemByMixedTensors_1(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfTensor_1', {
'block': TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.zeros((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.float32)))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithNumber_2', {
'block': TensorSetItemByMixedTensors_2(value=88.0),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensor_2', {
'block': TensorSetItemByMixedTensors_2(value=Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfNumber_2', {
'block': TensorSetItemByMixedTensors_2(value=(1.0, 2.0, 3.0, 4.0, 5.0)),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfTensor_2', {
'block': TensorSetItemByMixedTensors_2(value=(Tensor(np.ones((4, 5), np.float32)),
Tensor(np.zeros((4, 5), np.float32)),
Tensor(np.ones((4, 5), np.float32)))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
]
raise_error_set = [
('TensorGetItemByOneTensorDtypeError', {
'block': (TensorGetItemByOneTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'block': (TensorGetItemByOneTensor(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
}),
('TensorGetItemByTwoTensorsShapeError', {
'block': (TensorGetItemByTwoTensors(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
}),
('TensorGetItemByTwoTensorsDtypeError', {
'block': (TensorGetItemByTwoTensors(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
}),
('TensorGetItemByThreeTensorsShapeError', {
'block': (TensorGetItemByThreeTensors(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
}),
('TensorGetItemByThreeTensorsDtypeError', {
'block': (TensorGetItemByThreeTensors(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int64),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors', {
'block': (TensorGetItemByMixedTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
('TensorGetItemByMixedTensorsNumberError', {
'block': (TensorGetItemByMixedTensorsNumberError(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64)],
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsTypeError', {
'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsDtypeError', {
'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.float32)],
}),
('TensorGetItemByMixedTensorsShapeError', {
'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(2, 4, 5)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithNumberTypeError', {
'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
@ -760,21 +1024,21 @@ raise_error_set = [
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTensorShapeError', {
'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}),
'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((2, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTensorTypeError', {
'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}),
'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTensorNumberError', {
'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}),
'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
@ -782,19 +1046,19 @@ raise_error_set = [
Tensor(np.zeros((2, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTupleOfNumberTypeError', {
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}),
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1, 2, 3, 4)), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTupleOfNumberNumberError', {
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}),
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTupleOfTensorNumberError', {
'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}),
'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
@ -802,7 +1066,7 @@ raise_error_set = [
Tensor(np.ones((4, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTupleOfTensorTypeError', {
'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}),
'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
@ -810,10 +1074,65 @@ raise_error_set = [
Tensor(np.ones((4, 5)), mstype.int32),
Tensor(np.ones((4, 5)) * 2, mstype.int32)],
}),
('TensorSetItemByMixedTensors', {
'block': (TensorSetItemByMixedTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
('TensorSetItemByMixedTensorsWithNumberValueTypeError', {
'block': (TensorSetItemByMixedTensors_1(value=88), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithNumberIndexTypeError', {
'block': (TensorSetItemByMixedTensors_1(value=88.0), {'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.float32)],
}),
('TensorSetItemByMixedTensorsWithTensorValueDtypeError', {
'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.int32))),
{'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensorValueShapeError', {
'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((3, 2, 6), np.float32))),
{'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensorIndexDtypeError', {
'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))),
{'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.float32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfNumberValueTypeError', {
'block': (TensorSetItemByMixedTensors_1(value=(1.0, 2, 3.0, 4.0, 5.0, 6.0)),
{'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfTensorValueDtypeError', {
'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.zeros((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.int32)))),
{'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfTensorIndexDtypeError', {
'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.zeros((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.int32)))),
{'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.float32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
})
]