Support mixed tensor indexing for tensor getitem and setitem.

buxue 2020-05-24 17:38:46 +08:00
parent b94949ea99
commit 7ae289a197
14 changed files with 1020 additions and 231 deletions

View File

@@ -105,7 +105,7 @@ convert_object_map = {
     T.ge: multitype_ops.greater_equal,
     T.is_: F.is_,
     T.is_not: F.is_not,
-    T.contains: F.in_dict,
+    T.contains: multitype_ops.in_,
     T.not_contains: F.not_in_dict,
     # system function

View File

@@ -474,6 +474,8 @@ REGISTER_PYBIND_DEFINE(
   (void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
   (void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
   (void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
+  (void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
+  (void)py::class_<Ellipsis, Type, std::shared_ptr<Ellipsis>>(m_sub, "Ellipsis").def(py::init());
 }));
 const TypePtr kTypeExternal = std::make_shared<External>();

View File

@@ -95,6 +95,8 @@ string = typing.String()
 type_refkey = typing.RefKeyType()
 tensor_type = typing.TensorType
 anything_type = typing.TypeAnything
+slice_type = typing.Slice
+ellipsis_type = typing.Ellipsis
 number_type = (int8,
                int16,

View File

@@ -37,6 +37,7 @@ from .logical_and_impl import logical_and
 from .logical_or_impl import logical_or
 from .logic_not_impl import logical_not
 from .uadd_impl import uadd
+from .in_impl import in_

 __all__ = [
     'add',
     'sub',

@@ -59,5 +60,6 @@ __all__ = [
     'setitem',
     'logical_and',
     'logical_or',
-    'logical_not'
+    'logical_not',
+    'in_'
 ]

View File

@@ -0,0 +1,154 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""constexpr util"""
from . import _constexpr_utils as const_utils
from ... import functional as F
from ... import operations as P
from ...composite import base
from ....common import dtype as mstype
hyper_map = base.HyperMap()
pack = P.Pack(axis=-1)
def broadcast(broadcast_shape, x):
    """Broadcast tensor to the required shape."""
    if F.shape(x) == broadcast_shape:
        return x
    multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape)
    if multiples:
        return F.tile(x, multiples)
    return x


def transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
    """Transform an indexing tensor to the required shape."""
    x = broadcast(broadcast_shape, x)
    return broadcast(final_shape, F.reshape(x, new_shape))
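For intuition, here is an illustrative NumPy sketch (not part of this commit) of the broadcast-by-tiling logic that broadcast() and compute_multiples() implement; np_broadcast is a hypothetical helper name:

import numpy as np

def np_broadcast(broadcast_shape, x):
    # Mirror compute_multiples(): prepend the missing dims, then tile each dim up.
    len_gap = len(broadcast_shape) - x.ndim
    multiples = broadcast_shape[:len_gap] + tuple(
        b // s for b, s in zip(broadcast_shape[len_gap:], x.shape))
    return np.tile(x, multiples)

assert np_broadcast((4, 2, 3), np.ones((1, 3))).shape == (4, 2, 3)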
def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
    """Generate an indices tensor from a tuple of tensor."""
    indices = None
    check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
    if check_index_tensor_number:
        dtype_tuple = hyper_map(F.dtype, tuple_index)
        check_dtypes = const_utils.check_index_tensors_dtype(dtype_tuple, op_name)
        if check_dtypes:
            shape_tuple = hyper_map(F.shape, tuple_index)
            broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name)
            broadcast_tensors = hyper_map(F.partial(broadcast, broadcast_shape), tuple_index)
            indices = pack(broadcast_tensors)
    return indices
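As a sanity check, the pack-then-gather scheme matches NumPy advanced indexing; an illustrative sketch (not part of the commit):

import numpy as np

data = np.arange(12).reshape(3, 4)
i0 = np.array([[0, 1], [2, 0]])
i1 = np.array([[1, 3], [0, 2]])
indices = np.stack([i0, i1], axis=-1)              # what P.Pack(axis=-1) builds
gathered = data[indices[..., 0], indices[..., 1]]  # what F.gather_nd(data, indices) returns
assert (gathered == data[i0, i1]).all()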
def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
    """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
    indexes_types = hyper_map(F.typeof, tuple_index)
    int_positions = const_utils.get_pos_of_int_index(indexes_types)
    for i in int_positions:
        tuple_index = F.tuple_setitem(tuple_index, i, F.scalar_to_tensor(tuple_index[i], mstype.int32))
    indexes_types = hyper_map(F.typeof, tuple_index)
    tensor_positions, slice_positions, ellipsis_position = \
        const_utils.separate_mixed_tensors_index(indexes_types, op_name)
    tensor_indexes = []
    slice_indexes = []
    for i in tensor_positions:
        tensor_indexes.append(tuple_index[i])
    for j in slice_positions:
        slice_indexes.append(tuple_index[j])
    data_shape = F.shape(data)
    tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
    tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
    broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
        const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape,
                                                                    indexes_types,
                                                                    tensor_indexes_shapes,
                                                                    tensor_indexes_dtypes,
                                                                    slice_indexes,
                                                                    op_name)
    slice_number = 0
    final_index_tensors = []
    tuple_index_size = len(tuple_index)
    index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
    for i in range(tuple_index_size):
        if i in tensor_positions:
            transform_tensor = transform_indexing_tensor(broadcast_shape,
                                                         final_shape,
                                                         index_tensor_new_shape,
                                                         tuple_index[i])
            final_index_tensors.append(transform_tensor)
        if i in slice_positions:
            slice_tensor = const_utils.convert_slice_to_tensor(slice_number,
                                                               final_shape,
                                                               indexes_shapes_info,
                                                               op_name)
            final_index_tensors.append(slice_tensor)
            slice_number += 1
        if i == ellipsis_position:
            ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number,
                                                                       ellipsis_occupied_dims,
                                                                       final_shape,
                                                                       indexes_shapes_info,
                                                                       op_name)
            for ele in ellipsis_tensors:
                final_index_tensors.append(ele)
            slice_number += ellipsis_occupied_dims
    indices = pack(final_index_tensors)
    return indices
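Conceptually, this function rewrites a mixed index as a pure tensor index: ints become 0-d tensors, slices and the ellipsis are materialized as index tensors, and everything is broadcast to the final shape and packed. A NumPy sketch of the equivalence (illustrative values, not part of the commit):

import numpy as np

data = np.arange(24).reshape(4, 6)
idx = np.array([0, 2])                # the tensor part of the mixed index
rows = idx.reshape(2, 1)              # tensor dims kept, slice dims -> 1 (compute_new_shape)
cols = np.arange(1, 3).reshape(1, 2)  # the slice 1:3 materialized (convert_slice_to_tensor)
packed = np.stack(np.broadcast_arrays(rows, cols), axis=-1)
assert (data[packed[..., 0], packed[..., 1]] == data[idx, 1:3]).all()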
def generate_updates_from_scalar(data, indices, value, op_type):
    """Generate an updates tensor from a scalar."""
    data_shape = F.shape(data)
    indices_shape = F.shape(indices)
    data_dtype = F.dtype(data)
    return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)


def generate_updates_from_tuple(data, index, value, op_type):
    """Generate an updates tensor from a tuple."""
    value_types = hyper_map(F.typeof, value)
    data_dtype = F.dtype(data)
    value_elements_type = const_utils.check_value_elements(data_dtype, value_types)
    if value_elements_type == const_utils.ALL_TENSOR:
        value_shapes = hyper_map(F.shape, value)
        shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM)
        if shapes_same:
            value = F.pack(value)
        return generate_updates_from_tensor(data, index, value, op_type)
    data_shape = F.shape(data)
    index_shape = F.shape(index)
    return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)


def generate_updates_from_tensor(data, index, value, op_type):
    """Generate an updates tensor from a tensor."""
    data_shape = F.shape(data)
    index_shape = F.shape(index)
    value_shape = F.shape(value)
    data_dtype = F.dtype(data)
    value_dtype = F.dtype(value)
    updates_shape = value_shape
    check_dtype_same = const_utils.check_tensors_dtype_same(data_dtype, value_dtype, const_utils.TENSOR_SETITEM)
    if check_dtype_same:
        updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type)
    need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape)
    if need_broadcast:
        return broadcast(updates_shape, value)
    return value
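The updates built here feed F.scatter_nd_update; its effect, in NumPy terms (an illustrative sketch with hypothetical shapes):

import numpy as np

data = np.zeros((4, 3))
indices = np.array([[0], [2]])  # one (row) coordinate per update slice
updates = np.ones((2, 3))       # value broadcast up to the generated updates shape
data[indices[:, 0]] = updates   # rows 0 and 2 are overwritten, as scatter_nd_update does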

View File

@@ -15,19 +15,15 @@
 """constexpr util"""
 from functools import reduce
-import numpy as np
-from ...primitive import constexpr
-from ....common.tensor import Tensor
-from ....common import dtype as mstype
-from ...._extends.utils import Slice, Ellipsis_
-from ....ops import _utils as op_utils
-from ...composite import base
-from .... import log as logger
-from ... import functional as F
-from ... import operations as P
-hyper_map = base.HyperMap()
-pack = P.Pack(axis=-1)
+import numpy as np
+from ...primitive import constexpr
+from .... import log as logger
+from ...._extends.utils import Slice, Ellipsis_
+from ....common import dtype as mstype
+from ....common.tensor import Tensor
+from ....ops import _utils as op_utils

 ALL_TENSOR = 0
 NO_TENSOR = 1

@@ -264,7 +260,7 @@ def tuple_index_elements_type(types, op_name):
         return ALL_TENSOR
     if tensors_number == 0:
         return NO_TENSOR
-    raise IndexError(f"For '{op_name}', the index does not support mixed tensor.")
+    return CONTAIN_TENSOR


 @constexpr
@@ -279,12 +275,12 @@ def check_value_elements(data_dtype, types):
             tensors_number += 1
         else:
             raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' "
-                            f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.")
+                            f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.")
     elif mstype.issubclass_(ele, data_dtype):
         scalars_number += 1
     else:
         raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in "
-                        f"value tuple is not consistent with origin tensor data type '{data_dtype}'.")
+                        f"value tuple is not consistent with assigned tensor data type '{data_dtype}'.")
 if tensors_number == len(types):
     return ALL_TENSOR
 if scalars_number == len(types):
@@ -299,51 +295,46 @@ def get_index_tensor_dtype(dtype):
         return INT_
     if dtype == mstype.bool_:
         return BOOL_
-    raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
+    raise IndexError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")


 @constexpr
 def check_index_tensors_dtype(dtypes, op_name):
     """Check a tuple of tensor data type."""
-    if op_name == TENSOR_GETITEM:
-        valid_dtypes = (mstype.int32, mstype.int64)
-    elif op_name == TENSOR_SETITEM:
-        valid_dtypes = (mstype.int32,)
-    else:
-        raise ValueError("Unsupported operation.")
     for ele in dtypes:
-        if ele in valid_dtypes and ele == dtypes[0]:
-            continue
-        raise TypeError(f"For '{op_name}', the index tensors data type must be same, "
-                        f"and should be one of the following: {valid_dtypes}, but got {dtypes}.")
+        if not ele == mstype.int32:
+            raise IndexError(f"For '{op_name}', all index tensor "
+                             f"data types should be mstype.int32, but got {dtypes}.")
     return True


 @constexpr
-def check_tensor_dtype_valid(dtype, valid_dtypes):
+def check_index_tensor_dtype(dtype, op_name):
     """Check a tensor data type."""
-    if dtype in valid_dtypes:
+    if dtype == mstype.int32:
         return True
-    raise TypeError(f"The index tensor data type must be one of "
-                    f"the following: {valid_dtypes}, but got {dtype}.")
+    raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.")


 @constexpr
-def check_tensors_dtype_same(x_dtype, y_dtype, op_name):
+def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
     """Check tensors data type same."""
-    if x_dtype == y_dtype:
+    if value_dtype == data_dtype:
         return True
-    raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' "
-                    f"is not consistent with origin tensor data type {x_dtype}.")
+    raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' "
+                    f"is not consistent with assigned tensor data type {data_dtype}.")


 @constexpr
-def broadcast_shapes(shapes, op_name):
-    """Broadcasts a tuple of tensor."""
+def generate_broadcast_shape(shapes, op_name):
+    """Generate broadcast shape for a tuple of shape."""
     broadcast_shape = shapes[0]
     for i, shape in enumerate(shapes):
         logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
-        broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
+        try:
+            broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
+        except ValueError as ex:
+            raise IndexError(ex)
     return tuple(broadcast_shape)
@@ -366,14 +357,82 @@ def check_two_shapes_need_broadcast(shape_x, shape_y):

 @constexpr
 def compute_multiples(origin_shape, broadcast_shape):
-    """Compute multiples between broadcast_shape with origin_shape."""
+    """Compute the tile multiples from origin_shape to broadcast_shape."""
     len_gap = len(broadcast_shape) - len(origin_shape)
     return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))


-def tile(broadcast_shape, x):
-    multiples = compute_multiples(F.shape(x), broadcast_shape)
-    return F.tile(x, multiples)
+@constexpr
+def compute_new_shape(origin_shape, indexes_shapes_info):
+    """Compute the reshape target for an index tensor from origin_shape and indexes_shapes_info."""
+    new_shape = []
+    for i in indexes_shapes_info:
+        if i == origin_shape:
+            new_shape.extend(origin_shape)
+        else:
+            new_shape.append(1)
+    return tuple(new_shape)


+@constexpr
+def convert_ellipsis_to_tensors(slice_number,
+                                ellipsis_occupied_dims,
+                                final_shape,
+                                indexes_shapes_info,
+                                op_name):
+    """Convert an ellipsis to a list of tensor."""
+    tensor_list = []
+    dims_dealt_count = 0
+    while dims_dealt_count < ellipsis_occupied_dims:
+        shape = []
+        slice_count = 0
+        array = None
+        for ele in indexes_shapes_info:
+            if isinstance(ele, list):
+                if slice_count == slice_number:
+                    array = np.array(ele, np.int32)
+                    shape.append(len(ele))
+                else:
+                    shape.append(1)
+                slice_count += 1
+            if isinstance(ele, tuple):
+                shape.extend([1] * len(ele))
+        if array is None:
+            raise ValueError(f"For '{op_name}', generate tensors from ellipsis failed.")
+        array = np.reshape(array, shape)
+        reps = compute_multiples(shape, final_shape)
+        tensor = Tensor(np.tile(array, reps))
+        tensor_list.append(tensor)
+        slice_number += 1
+        dims_dealt_count += 1
+    return tensor_list


+@constexpr
+def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name):
+    """Convert a slice to a tensor."""
+    shape = []
+    count = 0
+    array = None
+    for ele in indexes_shapes_info:
+        if isinstance(ele, list):
+            if count == slice_number:
+                array = np.array(ele, np.int32)
+                shape.append(len(ele))
+            else:
+                # When this is not the slice being converted, its dim is filled with 1.
+                shape.append(1)
+            count += 1
+        elif isinstance(ele, tuple):
+            shape.extend([1] * len(ele))
+        else:
+            shape.append(1)
+    if array is None:
+        raise ValueError(f"For '{op_name}', generate tensor from 'slice' failed.")
+    array = np.reshape(array, shape)
+    reps = compute_multiples(shape, final_shape)
+    tensor = Tensor(np.tile(array, reps))
+    return tensor
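A worked example of the two constexpr helpers above, with hypothetical inputs: indexes_shapes_info holds one list per materialized slice and one tuple for the tensors' broadcast block.

import numpy as np

indexes_shapes_info = ([0, 1, 2], (2, 2), [0, 1])  # slice, tensor block, slice
broadcast_shape = (2, 2)
# compute_new_shape: slice dims collapse to 1, the broadcast block keeps its dims,
# giving (1, 2, 2, 1) as the reshape target for every index tensor.
# convert_slice_to_tensor for slice_number == 0: the slice values get their own axis.
array = np.array([0, 1, 2], np.int32).reshape(3, 1, 1, 1)
final_shape = (3, 2, 2, 2)
slice_tensor = np.tile(array, (1, 2, 2, 2))        # tiled up to final_shape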
 @constexpr

@@ -381,8 +440,8 @@ def check_shapes_same(value_shapes, op_name):
     """Check if the shapes in the tuple are consistent."""
     for i, shape in enumerate(value_shapes):
         if shape != value_shapes[0]:
-            raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple "
-                             f"is not same as the first tensor shape.")
+            raise ValueError(f"For '{op_name}', the {i}th tensor shape in "
+                             f"value tuple is not same as the first tensor shape.")
     return True

@@ -396,7 +455,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
     if isinstance(value, mstype.dtype_to_pytype(data_dtype)):
         return Tensor(np.full(updates_shape, value), dtype=data_dtype)
     raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'"
-                    f" is not consistent with tensor data type {data_dtype}.")
+                    f" is not consistent with the assigned tensor data type {data_dtype}.")


 @constexpr

@@ -404,8 +463,8 @@ def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value
     """Convert a tuple of scalar to a tensor."""
     updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
     if len(value) != updates_shape[-1]:
-        raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} in the updates tuple "
-                         f"does not meet the requirements: {updates_shape[-1]}.")
+        raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} "
+                         f"in the updates tuple does not meet the requirements: {updates_shape[-1]}.")
     array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype))
     reps = compute_multiples(updates_shape[-1:], updates_shape)
     return Tensor(np.tile(array, reps))
@@ -430,58 +489,145 @@ def check_number_of_index_tensor(data_shape, tuple_len, op_name):
                          f"is greater than the dimension {len(data_shape)} of the operated tensor.")


-def generate_indeices_from_tuple_of_tensor(data, tuple_index, op_name):
-    """Generate an indices tensor from a tuple of tensor."""
-    indices = None
-    check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
-    if check_index_tensor_number:
-        dtype_tuple = hyper_map(F.dtype, tuple_index)
-        check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name)
-        if check_dtypes:
-            shape_tuple = hyper_map(F.shape, tuple_index)
-            broadcast_shape = broadcast_shapes(shape_tuple, op_name)
-            broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index)
-            indices = pack(broadcast_tensors)
-    return indices
-
-
-def generate_updates_from_scalar(data, indices, value, op_type):
-    """Generate an updates tensor from a scalar."""
-    data_shape = F.shape(data)
-    indices_shape = F.shape(indices)
-    data_dtype = F.dtype(data)
-    return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)
-
-
-def generate_updates_from_tuple(data, index, value, op_type):
-    """Generate an updates tensor from a tuple."""
-    value_types = hyper_map(F.typeof, value)
-    data_dtype = F.dtype(data)
-    value_elements_type = check_value_elements(data_dtype, value_types)
-    if value_elements_type == ALL_TENSOR:
-        value_shapes = hyper_map(F.shape, value)
-        shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM)
-        if shapes_same:
-            value = F.pack(value)
-        return generate_updates_from_tensor(data, index, value, op_type)
-    data_shape = F.shape(data)
-    index_shape = F.shape(index)
-    return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
-
-
-def generate_updates_from_tensor(data, index, value, op_type):
-    """Generate an updates tensor from a tensor."""
-    data_shape = F.shape(data)
-    index_shape = F.shape(index)
-    value_shape = F.shape(value)
-    data_dtype = F.dtype(data)
-    value_dtype = F.dtype(value)
-    updates_shape = value_shape
-    check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM)
-    if check_dtype_same:
-        updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
-    need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape)
-    if need_broadcast:
-        return tile(updates_shape, value)
-    return value
+@constexpr
+def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
+                                                    indexes_types,
+                                                    tensor_indexes_shapes,
+                                                    tensor_indexes_dtypes,
+                                                    slice_indexes,
+                                                    op_name):
+    """
+    Generate index info, which contains the broadcast shape, the final shape,
+    the indexes shapes info and the ellipsis size, from a tuple of mixed tensors.
+    """
+    check_index_tensors_dtype(tensor_indexes_dtypes, op_name)
+    data_rank = len(data_shape)
+    indexes_size = len(indexes_types)
+    if indexes_size > data_rank:
+        raise IndexError(f"For '{op_name}', the number {indexes_size} of index elements "
+                         f"is greater than the dimension {len(data_shape)} of the operated tensor.")
+    indexes_info = {}
+    index_tensors_info = {}
+    ellipsis_num = 0
+    ellipsis_occupied_dims = 0
+    tensor_count = 0
+    slice_count = 0
+    for i, ele_type in enumerate(indexes_types):
+        if ellipsis_num == 0:
+            pos = i
+        else:
+            pos = i + ellipsis_occupied_dims - 1
+        if isinstance(ele_type, mstype.tensor_type):
+            indexes_info[pos] = tensor_indexes_shapes[tensor_count]
+            index_tensors_info[pos] = tensor_indexes_shapes[tensor_count]
+            tensor_count += 1
+        elif isinstance(ele_type, mstype.slice_type):
+            slice_obj = slice(slice_indexes[slice_count].start,
+                              slice_indexes[slice_count].end,
+                              slice_indexes[slice_count].step)
+            # Use a list to represent the slicing result.
+            indexes_info[pos] = list(range(data_shape[pos]))[slice_obj]
+            slice_count += 1
+        elif isinstance(ele_type, mstype.ellipsis_type):
+            if ellipsis_num != 0:
+                raise IndexError(f"For '{op_name}', the index could only contain one ellipsis.")
+            ellipsis_occupied_dims = data_rank - indexes_size + 1
+            for j in range(pos, pos + ellipsis_occupied_dims):
+                # Use a list to represent the slicing result.
+                indexes_info[j] = list(range(data_shape[j]))
+            ellipsis_num += 1
+        else:
+            raise IndexError(f"For '{op_name}', the index elements only support "
+                             f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.")
+    broadcast_shape, final_shape, indexes_shapes_info = \
+        _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name)
+    return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims


+def _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key: list):
+    """Determine whether the tensors in the index appear contiguously."""
+    for i in range(len(index_tensor_info_key) - 1):
+        if index_tensor_info_key[i + 1] != index_tensor_info_key[i] + 1:
+            return False
+    return True


+def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name):
+    """Derive the result shape information from a tuple index of mixed tensors."""
+    index_tensor_info_key = list(index_tensors_info.keys())
+    index_tensor_info_value = list(index_tensors_info.values())
+    broadcast_shape = generate_broadcast_shape(index_tensor_info_value, op_name)
+    final_shape = []
+    indexes_shapes_info = []
+    mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key)
+    if mixed_tensors_continuous:
+        tensor_shape_dealt = False
+        for ele in indexes_info.values():
+            if isinstance(ele, list):
+                final_shape.append(len(ele))
+                indexes_shapes_info.append(ele)
+            elif isinstance(ele, tuple):
+                if not tensor_shape_dealt:
+                    final_shape.extend(broadcast_shape)
+                    indexes_shapes_info.append(broadcast_shape)
+                    tensor_shape_dealt = True
+            else:
+                raise IndexError(f"For '{op_name}', the index elements only support "
+                                 f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.")
+    else:
+        final_shape.extend(broadcast_shape)
+        indexes_shapes_info.append(broadcast_shape)
+        for ele in indexes_info.values():
+            if isinstance(ele, list):
+                final_shape.append(len(ele))
+                indexes_shapes_info.append(ele)
+            elif isinstance(ele, tuple):
+                continue
+            else:
+                raise IndexError(f"For '{op_name}', the index elements only support "
+                                 f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.")
+    return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)


+@constexpr
+def get_pos_of_int_index(indexes_types):
+    """Get the int positions from a mixed tensors index which contains int, tensor, slice, and ellipsis."""
+    int_positions = []
+    for i, ele_type in enumerate(indexes_types):
+        if ele_type == mstype.int32:
+            int_positions.append(i)
+    return int_positions


+@constexpr
+def separate_mixed_tensors_index(indexes_types, op_name):
+    """Separate the tensor, slice and ellipsis positions in a mixed tensors index."""
+    tensor_positions = []
+    slice_positions = []
+    ellipsis_position = None
+    for i, ele_type in enumerate(indexes_types):
+        if isinstance(ele_type, mstype.tensor_type):
+            tensor_positions.append(i)
+        elif isinstance(ele_type, mstype.slice_type):
+            slice_positions.append(i)
+        elif isinstance(ele_type, mstype.ellipsis_type):
+            ellipsis_position = i
+        else:
+            raise IndexError(f"For '{op_name}', the index elements only support "
+                             f"'Tensor', 'int32', 'Slice', 'Ellipsis', but got {ele_type}.")
+    return tensor_positions, slice_positions, ellipsis_position


+@constexpr
+def scalar_in_sequence(x, y):
+    """Determine whether a scalar is in a sequence."""
+    if x is None:
+        raise ValueError("Judging scalar in tuple or list requires the scalar and the sequence to be constant, "
+                         "but the scalar is not.")
+    if y is None:
+        raise ValueError("Judging scalar in tuple or list requires the scalar and the sequence to be constant, "
+                         "but the sequence is not.")
+    return x in y
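The continuity check mirrors NumPy's rule for mixing advanced (tensor) indices with slices: when the tensor indices are contiguous, their broadcast block stays in place; otherwise it moves to the front of the result. An illustrative check:

import numpy as np

x = np.zeros((4, 5, 6))
i = np.zeros((2, 3), np.int32)
assert x[:, i, i].shape == (4, 2, 3)  # contiguous: block replaces the indexed dims
assert x[i, :, i].shape == (2, 3, 5)  # non-contiguous: block moves to the front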

View File

@@ -14,11 +14,11 @@
 # ============================================================================
 """Implementation for getitem."""
-from . import _utils as multi_utils
+from . import _compile_utils as compile_utils
+from . import _constexpr_utils as const_utils
 from .. import base
 from ... import functional as F
-from ....common import dtype as mstype

 getitem = base.MultitypeFuncGraph('getitem')
 """

@@ -227,7 +227,8 @@ def _tensor_getitem_by_tensor(data, tensor_index):
     Outputs:
         Tensor, element type is same as the element type of data.
     """
-    check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64))
+    check_dtypes = const_utils.check_index_tensor_dtype(F.dtype(tensor_index),
+                                                        const_utils.TENSOR_GETITEM)
     result = None
     if check_dtypes:
         result = F.gather(data, tensor_index, 0)

@@ -246,14 +247,13 @@ def _tensor_getitem_by_tuple(data, tuple_index):
     Outputs:
         Tensor, element type is same as the element type of data.
     """
-    index_types = multi_utils.hyper_map(F.typeof, tuple_index)
-    index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_GETITEM)
-    result = None
-    if index_elements_type == multi_utils.NO_TENSOR:
-        result = _tensor_slice(data, tuple_index)
-    if index_elements_type == multi_utils.ALL_TENSOR:
-        result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
-    return result
+    indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
+    index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
+    if index_elements_type == const_utils.NO_TENSOR:
+        return _tensor_slice(data, tuple_index)
+    if index_elements_type == const_utils.ALL_TENSOR:
+        return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
+    return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)


 @getitem.register("Tensor", "Ellipsis")

@@ -273,6 +273,17 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
 def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
     """Tensor getitem by a tuple of tensor."""
-    indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM)
+    indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
+                                                                  tuple_index,
+                                                                  const_utils.TENSOR_GETITEM)
+    result = F.gather_nd(data, indices)
+    return result


+def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
+    """Tensor getitem by a tuple of mixed tensor."""
+    indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
+                                                                         tuple_index,
+                                                                         const_utils.TENSOR_GETITEM)
     result = F.gather_nd(data, indices)
     return result
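End to end, the new getitem path lets a cell mix tensor, slice and ellipsis indices; a minimal usage sketch (assumed shapes, graph mode):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore import dtype as mstype

class MixedGetItem(nn.Cell):
    def construct(self, x, i):
        return x[i, 0:2, ...]  # tensor + slice + ellipsis

context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.arange(24).reshape(2, 3, 4), mstype.float32)
i = Tensor(np.array([0, 1]), mstype.int32)
out = MixedGetItem()(x, i)     # expected shape: (2, 2, 4)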

View File

@@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""in_impl"""
from . import _constexpr_utils as const_utils
from ... import functional as F
from ...composite import base
in_ = base.MultitypeFuncGraph("in")
"""
in_ is a metafuncgraph object which determines whether `a` is in `b`,
with implementations registered through the ".register" decorator.
"""


@in_.register("Number", "Tuple")
def _number_in_tuple(x, y):
    """
    Determine whether a number is in a tuple.

    Args:
        x (Number): x
        y (tuple): y

    Returns:
        bool, True if x is in y, otherwise False.
    """
    return const_utils.scalar_in_sequence(x, y)


@in_.register("Number", "List")
def _number_in_list(x, y):
    """
    Determine whether a number is in a list.

    Args:
        x (Number): x
        y (list): y

    Returns:
        bool, True if x is in y, otherwise False.
    """
    return const_utils.scalar_in_sequence(x, y)


@in_.register("String", "Tuple")
def _string_in_tuple(x, y):
    """
    Determine whether a string is in a tuple.

    Args:
        x (str): x
        y (tuple): y

    Returns:
        bool, True if x is in y, otherwise False.
    """
    return const_utils.scalar_in_sequence(x, y)


@in_.register("String", "List")
def _string_in_list(x, y):
    """
    Determine whether a string is in a list.

    Args:
        x (str): x
        y (list): y

    Returns:
        bool, True if x is in y, otherwise False.
    """
    return const_utils.scalar_in_sequence(x, y)


@in_.register("String", "Dictionary")
def _str_in_dict(x, y):
    """
    Determine whether a string is a key of a dict.

    Args:
        x (str): x
        y (dict): y

    Returns:
        bool, True if x is in y, otherwise False.
    """
    return F.in_dict(x, y)
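With in_ registered and T.contains remapped (first hunk of this commit), `in` now works on tuples and lists of constants inside construct; a minimal sketch mirroring the tests added later in this commit:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

class InNet(nn.Cell):
    def __init__(self):
        super(InNet, self).__init__()
        self.tuple_ = (1, 2, "ok")

    def construct(self, x):
        if 2 in self.tuple_:  # lowered through multitype_ops.in_
            x = x + x
        return x

context.set_context(mode=context.GRAPH_MODE)
out = InNet()(Tensor(np.ones((2, 2), np.float32)))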

View File

@@ -15,10 +15,11 @@
 """Implementation for setitem."""
+from . import _compile_utils as compile_utils
+from . import _constexpr_utils as const_utils
+from ... import functional as F
 from ...composite import base
 from ....common import dtype as mstype
-from ... import functional as F
-from . import _utils as multi_utils

 setitem = base.MultitypeFuncGraph('setitem')

@@ -139,8 +140,8 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
         Tensor, element type and shape is same as data.
     """
     index_dtype = F.dtype(index)
-    tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
-    if tensor_dtype == multi_utils.INT_:
+    tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
+    if tensor_dtype == const_utils.INT_:
         return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
     return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)

@@ -166,8 +167,8 @@ def _tensor_setitem_by_tensor_with_number(data, index, value):
         Tensor, element type and shape is same as data.
     """
     index_dtype = F.dtype(index)
-    tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
-    if tensor_dtype == multi_utils.BOOL_:
+    tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
+    if tensor_dtype == const_utils.BOOL_:
         return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
     return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)

@@ -190,17 +191,24 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
     Outputs:
         Tensor, element type and shape is same as data.
     """
-    index_types = multi_utils.hyper_map(F.typeof, tuple_index)
-    index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
-    result = None
-    if index_elements_type == multi_utils.NO_TENSOR:
-        result = _tensor_assgin_number(data, tuple_index, value)
-    if index_elements_type == multi_utils.ALL_TENSOR:
-        indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
-        updates = multi_utils.generate_updates_from_scalar(data, indices, value,
-                                                           multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
-        result = F.scatter_nd_update(data, indices, updates)
-    return result
+    indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
+    index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
+    if index_elements_type == const_utils.NO_TENSOR:
+        return _tensor_assgin_number(data, tuple_index, value)
+    if index_elements_type == const_utils.ALL_TENSOR:
+        indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
+                                                                      tuple_index,
+                                                                      const_utils.TENSOR_SETITEM)
+    else:
+        indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
+                                                                             tuple_index,
+                                                                             const_utils.TENSOR_SETITEM)
+    updates = compile_utils.generate_updates_from_scalar(data,
+                                                         indices,
+                                                         value,
+                                                         const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
+    return F.scatter_nd_update(data, indices, updates)


 @setitem.register("Tensor", "Tuple", "Tensor")

@@ -221,17 +229,24 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
     Outputs:
         Tensor, element type and shape is same as data.
     """
-    index_types = multi_utils.hyper_map(F.typeof, tuple_index)
-    index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
-    result = None
-    if index_elements_type == multi_utils.NO_TENSOR:
-        result = _tensor_assgin_tensor(data, tuple_index, value)
-    if index_elements_type == multi_utils.ALL_TENSOR:
-        indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
-        updates = multi_utils.generate_updates_from_tensor(data, indices, value,
-                                                           multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
-        result = F.scatter_nd_update(data, indices, updates)
-    return result
+    indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
+    index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
+    if index_elements_type == const_utils.NO_TENSOR:
+        return _tensor_assgin_tensor(data, tuple_index, value)
+    if index_elements_type == const_utils.ALL_TENSOR:
+        indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
+                                                                      tuple_index,
+                                                                      const_utils.TENSOR_SETITEM)
+    else:
+        indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
+                                                                             tuple_index,
+                                                                             const_utils.TENSOR_SETITEM)
+    updates = compile_utils.generate_updates_from_tensor(data,
+                                                         indices,
+                                                         value,
+                                                         const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
+    return F.scatter_nd_update(data, indices, updates)


 @setitem.register("Tensor", "Tuple", "Tuple")

@@ -253,15 +268,22 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
     Outputs:
         Tensor, element type and shape is same as data.
     """
-    index_types = multi_utils.hyper_map(F.typeof, tuple_index)
-    index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
-    result = None
-    if index_elements_type == multi_utils.ALL_TENSOR:
-        indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
-        updates = multi_utils.generate_updates_from_tuple(data, indices, value,
-                                                          multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
-        result = F.scatter_nd_update(data, indices, updates)
-    return result
+    indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
+    index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
+    if index_elements_type == const_utils.ALL_TENSOR:
+        indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
+                                                                      tuple_index,
+                                                                      const_utils.TENSOR_SETITEM)
+    else:
+        indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
+                                                                             tuple_index,
+                                                                             const_utils.TENSOR_SETITEM)
+    updates = compile_utils.generate_updates_from_tuple(data,
+                                                        indices,
+                                                        value,
+                                                        const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
+    return F.scatter_nd_update(data, indices, updates)
@setitem.register("Tensor", "Tensor", "Tuple") @setitem.register("Tensor", "Tensor", "Tuple")
@ -278,7 +300,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
Tensor, element type and shape is same as data. Tensor, element type and shape is same as data.
""" """
index_dtype = F.dtype(index) index_dtype = F.dtype(index)
check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64)) check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
result = None result = None
if check_dtype: if check_dtype:
result = _tensor_setitem_by_tensor_with_tuple(data, index, value) result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
@ -331,14 +353,14 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value):
def _tensor_assgin_number(data, input_slice, value): def _tensor_assgin_number(data, input_slice, value):
"""Givens a scalar assign to tensor by slice""" """Givens a scalar assign to tensor by slice"""
check_result = multi_utils.check_tensor_setitem_index(input_slice) check_result = const_utils.check_tensor_setitem_index(input_slice)
result = None result = None
if check_result: if check_result:
data_shape = F.shape(data) data_shape = F.shape(data)
indices = multi_utils.slice2indices(input_slice, data_shape) indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = multi_utils.tuple_element_is_int(input_slice) is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int: if is_tuple_int:
indices = multi_utils.integer_to_indices(input_slice, data_shape) indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_number(data, data_shape, input_slice, indices, value) result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
return result return result
@ -347,7 +369,7 @@ def _tensor_assgin_number(data, input_slice, value):
def _tensor_setitem_with_int_v1(data, index, value): def _tensor_setitem_with_int_v1(data, index, value):
"""Syntax: A[1] = 3""" """Syntax: A[1] = 3"""
data_shape = F.shape(data) data_shape = F.shape(data)
indices = multi_utils.integer_to_indices(index, data_shape) indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value) return _tensor_indices_number(data, data_shape, index, indices, value)
@ -355,7 +377,7 @@ def _tensor_setitem_with_int_v1(data, index, value):
def _tensor_setitem_with_int_v2(data, index, value): def _tensor_setitem_with_int_v2(data, index, value):
"""Syntax: A[1] = Tensor""" """Syntax: A[1] = Tensor"""
data_shape = F.shape(data) data_shape = F.shape(data)
indices = multi_utils.integer_to_indices(index, data_shape) indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value) return _tensor_indices_tensor(data, data_shape, index, indices, value)
@ -376,7 +398,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
data_size = F.size(data) data_size = F.size(data)
value_shape = F.shape(value) value_shape = F.shape(value)
value_size = F.size(value) value_size = F.size(value)
check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
if check_result: if check_result:
if data_size == value_size: if data_size == value_size:
result = F.reshape(value, data_shape) result = F.reshape(value, data_shape)
@ -391,13 +413,13 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
def _tensor_assgin_tensor(data, input_slice, value): def _tensor_assgin_tensor(data, input_slice, value):
"""Assigns a tensor value to the tensor by slice.""" """Assigns a tensor value to the tensor by slice."""
result = None result = None
check_result = multi_utils.check_tensor_setitem_index(input_slice) check_result = const_utils.check_tensor_setitem_index(input_slice)
if check_result: if check_result:
data_shape = F.shape(data) data_shape = F.shape(data)
indices = multi_utils.slice2indices(input_slice, data_shape) indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = multi_utils.tuple_element_is_int(input_slice) is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int: if is_tuple_int:
indices = multi_utils.integer_to_indices(input_slice, data_shape) indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
return result return result
@ -407,7 +429,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
data_size = F.size(data) data_size = F.size(data)
data_dtype = F.dtype(data) data_dtype = F.dtype(data)
indices_size = F.size(indices) indices_size = F.size(indices)
indices_size = multi_utils.check_indices(indices_size, index) indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1) update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,)) condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape) condition = F.reshape(condition_1d, data_shape)
@ -415,7 +437,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
value_fill = None value_fill = None
value_size = F.size(value) value_size = F.size(value)
value_size = multi_utils.check_indices_value_size(indices_size, value_size) value_size = const_utils.check_indices_value_size(indices_size, value_size)
if value_size == 1: if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1) value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype) value = F.cast(value, data_dtype)
@ -432,7 +454,7 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
data_size = F.size(data) data_size = F.size(data)
data_dtype = F.dtype(data) data_dtype = F.dtype(data)
indices_size = F.size(indices) indices_size = F.size(indices)
indices_size = multi_utils.check_indices(indices_size, index) indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1) update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,)) condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape) condition = F.reshape(condition_1d, data_shape)
@ -445,16 +467,16 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
def _tensor_setitem_by_tensor_with_tuple(data, index, value): def _tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Set a tensor item by a tensor with a tuple.""" """Set a tensor item by a tensor with a tuple."""
updates = multi_utils.generate_updates_from_tuple(data, index, value, updates = compile_utils.generate_updates_from_tuple(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR) const_utils.SET_ITEM_BY_ONE_TENSOR)
result = F.scatter_update(data, index, updates) result = F.scatter_update(data, index, updates)
return result return result
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar.""" """Set a tensor item by a int tensor with a scalar."""
updates = multi_utils.generate_updates_from_scalar(data, index, value, updates = compile_utils.generate_updates_from_scalar(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR) const_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates) return F.scatter_update(data, index, updates)
@ -462,7 +484,7 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
"""Set a tensor item by a bool tensor with a scalar.""" """Set a tensor item by a bool tensor with a scalar."""
index_shape = F.shape(index) index_shape = F.shape(index)
shape = F.shape(data) shape = F.shape(data)
shape = multi_utils.check_equal( shape = const_utils.check_equal(
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
dtype = F.dtype(data) dtype = F.dtype(data)
u = F.fill(dtype, shape, value) u = F.fill(dtype, shape, value)
@ -471,8 +493,8 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
"""Set a tensor item by a int tensor with a tensor.""" """Set a tensor item by a int tensor with a tensor."""
updates = multi_utils.generate_updates_from_tensor(data, index, value, updates = compile_utils.generate_updates_from_tensor(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR) const_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates) return F.scatter_update(data, index, updates)
@ -480,10 +502,10 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
"""Set a tensor item by a bool tensor with a tensor.""" """Set a tensor item by a bool tensor with a tensor."""
index_shape = F.shape(index) index_shape = F.shape(index)
data_shape = F.shape(data) data_shape = F.shape(data)
data_shape = multi_utils.check_equal(data_shape, index_shape, data_shape = const_utils.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.") "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value) size = F.size(value)
size = multi_utils.check_equal(1, size, size = const_utils.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.") "When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data) dtype = F.dtype(data)
u_cast = F.cast(value, dtype) u_cast = F.cast(value, dtype)
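A usage sketch of setitem with a mixed index on a Parameter (assumed shapes; this exercises the scatter_nd_update path added above):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore import dtype as mstype

class MixedSetItem(nn.Cell):
    def __init__(self):
        super(MixedSetItem, self).__init__()
        self.param = Parameter(Tensor(np.zeros((4, 5, 6)), mstype.float32), name="p")

    def construct(self, i, u):
        self.param[i, 0:3, ...] = u  # mixed index -> generate_indices_from_tuple_of_mixed_tensors
        return self.param

context.set_context(mode=context.GRAPH_MODE)
i = Tensor(np.array([0, 2]), mstype.int32)
u = Tensor(np.ones((2, 3, 6)), mstype.float32)
out = MixedSetItem()(i, u)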

View File

@@ -1425,7 +1425,6 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name):
     validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
     validator.check_integer("len of input_x", len(x_shape), 1, Rel.GT, prim_name)
     validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
-    validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT, prim_name)
     rank_base = len(x_shape[0])
     N = len(x_shape)
     out_shape = x_shape[0]

View File

@@ -33,9 +33,4 @@ class IdentityEC(IExectorComponent):
             keyword.desc_inputs: self.inputs[keyword.desc_inputs],
             keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs])
         }
-        print("buxue------------------------------------------------")
-        print("inputs")
-        print(ret[keyword.desc_inputs])
-        print("outputs")
-        print(ret[keyword.result])
         return ret

View File

@@ -19,9 +19,9 @@ import mindspore.nn as nn
 import mindspore.context as context
 from mindspore import Tensor
 from mindspore.ops import operations as P
-from ..ut_filter import non_graph_engine
-from ....mindspore_test_framework.mindspore_test import mindspore_test
-from ....mindspore_test_framework.pipeline.forward.compile_forward \
+from tests.ut.python.ut_filter import non_graph_engine
+from tests.mindspore_test_framework.mindspore_test import mindspore_test
+from tests.mindspore_test_framework.pipeline.forward.compile_forward \
     import pipeline_for_compile_forward_ge_graph_for_case_by_case_config

@@ -152,6 +152,20 @@ class ListOperate(nn.Cell):
         return x


+class InListNet(nn.Cell):
+    def __init__(self, ):
+        super(InListNet, self).__init__()
+        self.list_ = [1, 2, 3, 4, 5, "ok"]
+
+    def construct(self, x):
+        ret = x
+        if 2 in self.list_:
+            ret = x + x
+        if "ok" in self.list_:
+            ret = x - x
+        return ret
+
+
 class AxisListNet(nn.Cell):
     def __init__(self):
         super(AxisListNet, self).__init__()

@@ -204,10 +218,15 @@ test_case_ops = [
     ('AxisListDefault', {
         'block': AxisListDefaultNet(),
         'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
+    ('InList', {
+        'block': InListNet(),
+        'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
 ]

 test_case_lists = [test_case_ops]
 test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)

 # use -k to select certain testcase
 # pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm

View File

@@ -19,9 +19,9 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore import dtype as mstype
-from ..ut_filter import non_graph_engine
-from ....mindspore_test_framework.mindspore_test import mindspore_test
-from ....mindspore_test_framework.pipeline.forward.compile_forward \
+from tests.ut.python.ut_filter import non_graph_engine
+from tests.mindspore_test_framework.mindspore_test import mindspore_test
+from tests.mindspore_test_framework.pipeline.forward.compile_forward \
     import pipeline_for_compile_forward_ge_graph_for_case_by_case_config

 context.set_context(mode=context.GRAPH_MODE, save_graphs=True)

@@ -52,6 +52,20 @@ class NestTupleGraphNet(nn.Cell):
         return self.layers[0][1](x)


+class InTupleNet(nn.Cell):
+    def __init__(self, ):
+        super(InTupleNet, self).__init__()
+        self.tuple_ = (1, 2, 3, 4, 5, "ok")
+
+    def construct(self, x):
+        ret = x
+        if 2 in self.tuple_:
+            ret = x + x
+        if "ok" in self.tuple_:
+            ret = x - x
+        return ret
+
+
 test_case_ops = [
     ('TupleGraph', {
         'block': TupleGraphNet(),

@@ -59,6 +73,9 @@ test_case_ops = [
     ('NestTupleGraph', {
         'block': NestTupleGraphNet(),
         'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
+    ('InTuple', {
+        'block': InTupleNet(),
+        'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]})
 ]

 test_case_lists = [test_case_ops]

View File

@ -176,12 +176,134 @@ class TensorGetItemByThreeTensors(Cell):
return ret return ret
class TensorGetItemByMixedTensors(Cell): class TensorGetItemByMixedTensors_0(Cell):
def __init__(self): def __init__(self):
super(TensorGetItemByMixedTensors, self).__init__() super(TensorGetItemByMixedTensors_0, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32))
def construct(self, tensor, index_0, index_1):
ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const
return ret
class TensorGetItemByMixedTensors_1(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_1, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32))
def construct(self, tensor, index_0, index_1):
ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const
return ret
class TensorGetItemByMixedTensors_2(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_2, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32))
def construct(self, tensor, index_0, index_1):
ret = tensor[0, index_0, index_1, ..., 3] + self.const
return ret
class TensorGetItemByMixedTensors_3(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_3, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32))
def construct(self, tensor, index_0, index_1):
ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const
return ret
class TensorGetItemByMixedTensors_4(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_4, self).__init__()
self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32))
def construct(self, tensor, index_0, index_1, index_2):
ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const
return ret
class TensorGetItemByMixedTensors_5(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_5, self).__init__()
self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32))
def construct(self, tensor, index_0, index_1, index_2):
ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const
return ret
class TensorGetItemByMixedTensors_6(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_6, self).__init__()
self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))
def construct(self, tensor, index_0, index_1, index_2):
ret = tensor[..., index_0, index_1, index_2, 3] + self.const
return ret
class TensorSetItemByMixedTensors_0(Cell):
def __init__(self, value):
super(TensorSetItemByMixedTensors_0, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)),
mstype.float32),
name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value
ret = self.param + self.const
return ret
class TensorSetItemByMixedTensors_1(Cell):
def __init__(self, value):
super(TensorSetItemByMixedTensors_1, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value
ret = self.param + self.const
return ret
class TensorSetItemByMixedTensors_2(Cell):
def __init__(self, value):
super(TensorSetItemByMixedTensors_2, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
name="x")
self.value = value
def construct(self, index_0, index_1, index_2):
self.param[..., index_0, index_1, index_2, 3] = self.value
ret = self.param + self.const
return ret
class TensorGetItemByMixedTensorsTypeError(Cell):
def __init__(self):
super(TensorGetItemByMixedTensorsTypeError, self).__init__()
def construct(self, x, index_0, index_1):
ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]]
return ret
class TensorGetItemByMixedTensorsNumberError(Cell):
def __init__(self):
super(TensorGetItemByMixedTensorsNumberError, self).__init__()
def construct(self, x, index_0, index_1):
ret = x[index_0, index_1, 0:3, ..., index_1, index_0]
return ret return ret
@ -538,11 +660,11 @@ def test_tensor_assign_bool_index():
net1(Ta, Tb, Tc, u_tensor)
with pytest.raises(ValueError):
net1(Ta, Td, Tc, u_tensor)
with pytest.raises(IndexError):
net1(Ta, u_tensor, Tc, u_tensor)
with pytest.raises(ValueError):
net1(Ta, Tb, Td, u_tensor)
with pytest.raises(IndexError):
net1(Ta, Tb, Ta, u_tensor)
with pytest.raises(ValueError):
net1(Ta, Tb, Tc, u_tensor_error)
@@ -636,6 +758,51 @@ test_cases = [
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
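# Getitem with mixed tensor indices: rank-6 or rank-7 data indexed by
# broadcast-compatible index tensors of shapes (3, 4, 5), (4, 5) and (2, 1, 4, 5).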
('TensorGetItemByMixedTensors_0', {
'block': TensorGetItemByMixedTensors_0(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_1', {
'block': TensorGetItemByMixedTensors_1(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_2', {
'block': TensorGetItemByMixedTensors_2(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_3', {
'block': TensorGetItemByMixedTensors_3(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_4', {
'block': TensorGetItemByMixedTensors_4(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_5', {
'block': TensorGetItemByMixedTensors_5(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_6', {
'block': TensorGetItemByMixedTensors_6(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithNumber', {
'block': TensorSetItemByOneTensorWithNumber(value=0.0),
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
@@ -683,46 +850,143 @@ test_cases = [
Tensor(np.zeros((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)) * 2, mstype.float32)],
}),
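# Setitem with mixed tensor indices: the same index patterns, with the assigned
# value given as a number, a tensor, a tuple of numbers, or a tuple of tensors.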
('TensorSetItemByMixedTensorsWithNumber_0', {
'block': TensorSetItemByMixedTensors_0(value=88.0),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensor_0', {
'block': TensorSetItemByMixedTensors_0(value=Tensor(np.ones((4, 5, 3, 9), np.float32))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTupleOfNumber_0', {
'block': TensorSetItemByMixedTensors_0(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTupleOfTensor_0', {
'block': TensorSetItemByMixedTensors_0(value=(Tensor(np.ones((4, 5, 3, 9), np.float32)),
Tensor(np.zeros((4, 5, 3, 9), np.float32)),
Tensor(np.ones((4, 5, 3, 9), np.float32)))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithNumber_1', {
'block': TensorSetItemByMixedTensors_1(value=88.0),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensor_1', {
'block': TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTupleOfNumber_1', {
'block': TensorSetItemByMixedTensors_1(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTupleOfTensor_1', {
'block': TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.zeros((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.float32)))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithNumber_2', {
'block': TensorSetItemByMixedTensors_2(value=88.0),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensor_2', {
'block': TensorSetItemByMixedTensors_2(value=Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTupleOfNumber_2', {
'block': TensorSetItemByMixedTensors_2(value=(1.0, 2.0, 3.0, 4.0, 5.0)),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTupleOfTensor_2', {
'block': TensorSetItemByMixedTensors_2(value=(Tensor(np.ones((4, 5), np.float32)),
Tensor(np.zeros((4, 5), np.float32)),
Tensor(np.ones((4, 5), np.float32)))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
]
raise_error_set = [
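# Each entry pairs a block with its expected exception: index dtype or shape
# problems raise IndexError, bad value types raise TypeError, and mismatched
# value shapes raise ValueError.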
('TensorGetItemByOneTensorDtypeError', {
'block': (TensorGetItemByOneTensor(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
}),
('TensorGetItemByTwoTensorsShapeError', {
'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
}),
('TensorGetItemByTwoTensorsDtypeError', {
'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
}),
('TensorGetItemByThreeTensorsShapeError', {
'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
}),
('TensorGetItemByThreeTensorsDtypeError', {
'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int64),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsNumberError', {
'block': (TensorGetItemByMixedTensorsNumberError(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsTypeError', {
'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsDtypeError', {
'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.float32)],
}),
('TensorGetItemByMixedTensorsShapeError', {
'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(2, 4, 5)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithNumberTypeError', {
'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
@@ -782,7 +1046,7 @@ raise_error_set = [
Tensor(np.zeros((2, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTupleOfNumberTypeError', {
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1, 2, 3, 4)), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
@@ -810,10 +1074,65 @@ raise_error_set = [
Tensor(np.ones((4, 5)), mstype.int32),
Tensor(np.ones((4, 5)) * 2, mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithNumberValueTypeError', {
'block': (TensorSetItemByMixedTensors_1(value=88), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithNumberIndexTypeError', {
'block': (TensorSetItemByMixedTensors_1(value=88.0), {'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.float32)],
}),
('TensorSetItemByMixedTensorsWithTensorValueDtypeError', {
'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.int32))),
{'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensorValueShapeError', {
'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((3, 2, 6), np.float32))),
{'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensorIndexDtypeError', {
'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))),
{'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.float32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTupleOfNumberValueTypeError', {
'block': (TensorSetItemByMixedTensors_1(value=(1.0, 2, 3.0, 4.0, 5.0, 6.0)),
{'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTupleOfTensorValueDtypeError', {
'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.zeros((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.int32)))),
{'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTupleOfTensorIndexDtypeError', {
'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.zeros((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.int32)))),
{'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.float32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
})
]