From 7ae289a1973cf4e2c25ad63b13257b105ad2fe27 Mon Sep 17 00:00:00 2001 From: buxue Date: Sun, 24 May 2020 17:38:46 +0800 Subject: [PATCH] Support mixed tensor for tensor getitem or setitem. --- mindspore/_extends/parse/resources.py | 2 +- mindspore/ccsrc/ir/dtype_extends.cc | 2 + mindspore/common/dtype.py | 2 + .../ops/composite/multitype_ops/__init__.py | 4 +- .../composite/multitype_ops/_compile_utils.py | 154 +++++++ .../{_utils.py => _constexpr_utils.py} | 336 ++++++++++---- .../composite/multitype_ops/getitem_impl.py | 39 +- .../ops/composite/multitype_ops/in_impl.py | 101 +++++ .../composite/multitype_ops/setitem_impl.py | 144 +++--- mindspore/ops/operations/array_ops.py | 1 - .../components/executor/exec_forward.py | 5 - tests/ut/python/{ops => dtype}/test_list.py | 27 +- tests/ut/python/{ops => dtype}/test_tuple.py | 23 +- tests/ut/python/ops/test_tensor_slice.py | 411 ++++++++++++++++-- 14 files changed, 1020 insertions(+), 231 deletions(-) create mode 100644 mindspore/ops/composite/multitype_ops/_compile_utils.py rename mindspore/ops/composite/multitype_ops/{_utils.py => _constexpr_utils.py} (54%) create mode 100644 mindspore/ops/composite/multitype_ops/in_impl.py rename tests/ut/python/{ops => dtype}/test_list.py (90%) rename tests/ut/python/{ops => dtype}/test_tuple.py (79%) diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py index 7178cd26349..60847c43384 100644 --- a/mindspore/_extends/parse/resources.py +++ b/mindspore/_extends/parse/resources.py @@ -105,7 +105,7 @@ convert_object_map = { T.ge: multitype_ops.greater_equal, T.is_: F.is_, T.is_not: F.is_not, - T.contains: F.in_dict, + T.contains: multitype_ops.in_, T.not_contains: F.not_in_dict, # system function diff --git a/mindspore/ccsrc/ir/dtype_extends.cc b/mindspore/ccsrc/ir/dtype_extends.cc index 20c3c401e19..4e5adeaf4bc 100644 --- a/mindspore/ccsrc/ir/dtype_extends.cc +++ b/mindspore/ccsrc/ir/dtype_extends.cc @@ -474,6 +474,8 @@ REGISTER_PYBIND_DEFINE( (void)py::class_>(m_sub, "RefKeyType").def(py::init()); (void)py::class_>(m_sub, "RefType").def(py::init()); (void)py::class_>(m_sub, "TypeAnything").def(py::init()); + (void)py::class_>(m_sub, "Slice").def(py::init()); + (void)py::class_>(m_sub, "Ellipsis").def(py::init()); })); const TypePtr kTypeExternal = std::make_shared(); diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py index e6b9779f39d..ae2b111eb89 100644 --- a/mindspore/common/dtype.py +++ b/mindspore/common/dtype.py @@ -95,6 +95,8 @@ string = typing.String() type_refkey = typing.RefKeyType() tensor_type = typing.TensorType anything_type = typing.TypeAnything +slice_type = typing.Slice +ellipsis_type = typing.Ellipsis number_type = (int8, int16, diff --git a/mindspore/ops/composite/multitype_ops/__init__.py b/mindspore/ops/composite/multitype_ops/__init__.py index b7f4f671b8d..7bbebbbba3d 100644 --- a/mindspore/ops/composite/multitype_ops/__init__.py +++ b/mindspore/ops/composite/multitype_ops/__init__.py @@ -37,6 +37,7 @@ from .logical_and_impl import logical_and from .logical_or_impl import logical_or from .logic_not_impl import logical_not from .uadd_impl import uadd +from .in_impl import in_ __all__ = [ 'add', 'sub', @@ -59,5 +60,6 @@ __all__ = [ 'setitem', 'logical_and', 'logical_or', - 'logical_not' + 'logical_not', + 'in_' ] diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py new file mode 100644 index 00000000000..8954470b76c --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -0,0 +1,154 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""constexpr util""" +from . import _constexpr_utils as const_utils +from ... import functional as F +from ... import operations as P +from ...composite import base +from ....common import dtype as mstype + +hyper_map = base.HyperMap() +pack = P.Pack(axis=-1) + + +def broadcast(broadcast_shape, x): + """Broadcast tensor to the required shape.""" + if F.shape(x) == broadcast_shape: + return x + multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape) + if multiples: + return F.tile(x, multiples) + return x + + +def transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x): + """Transform indexing tensor to the required.""" + x = broadcast(broadcast_shape, x) + return broadcast(final_shape, F.reshape(x, new_shape)) + + +def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): + """Generate an indices tensor from a tuple of tensor.""" + indices = None + check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name) + if check_index_tensor_number: + dtype_tuple = hyper_map(F.dtype, tuple_index) + check_dtypes = const_utils.check_index_tensors_dtype(dtype_tuple, op_name) + if check_dtypes: + shape_tuple = hyper_map(F.shape, tuple_index) + broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name) + broadcast_tensors = hyper_map(F.partial(broadcast, broadcast_shape), tuple_index) + indices = pack(broadcast_tensors) + return indices + + +def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): + """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.""" + indexes_types = hyper_map(F.typeof, tuple_index) + int_positions = const_utils.get_pos_of_int_index(indexes_types) + for i in int_positions: + tuple_index = F.tuple_setitem(tuple_index, i, F.scalar_to_tensor(tuple_index[i], mstype.int32)) + indexes_types = hyper_map(F.typeof, tuple_index) + tensor_positions, slice_positions, ellipsis_position = \ + const_utils.separate_mixed_tensors_index(indexes_types, op_name) + tensor_indexes = [] + slice_indexes = [] + for i in tensor_positions: + tensor_indexes.append(tuple_index[i]) + for j in slice_positions: + slice_indexes.append(tuple_index[j]) + data_shape = F.shape(data) + tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) + tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) + broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \ + const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape, + indexes_types, + tensor_indexes_shapes, + tensor_indexes_dtypes, + slice_indexes, + op_name) + + slice_number = 0 + final_index_tensors = [] + tuple_index_size = len(tuple_index) + index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) + for i in range(tuple_index_size): + if i in tensor_positions: + transform_tensor = transform_indexing_tensor(broadcast_shape, + final_shape, + index_tensor_new_shape, + tuple_index[i]) + final_index_tensors.append(transform_tensor) + if i in slice_positions: + slice_tensor = const_utils.convert_slice_to_tensor(slice_number, + final_shape, + indexes_shapes_info, + op_name) + final_index_tensors.append(slice_tensor) + slice_number += 1 + if i == ellipsis_position: + ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number, + ellipsis_occupied_dims, + final_shape, + indexes_shapes_info, + op_name) + for ele in ellipsis_tensors: + final_index_tensors.append(ele) + slice_number += ellipsis_occupied_dims + indices = pack(final_index_tensors) + return indices + + +def generate_updates_from_scalar(data, indices, value, op_type): + """Generate an updates tensor from a scalar.""" + data_shape = F.shape(data) + indices_shape = F.shape(indices) + data_dtype = F.dtype(data) + return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) + + +def generate_updates_from_tuple(data, index, value, op_type): + """Generate an updates tensor from a tuple.""" + value_types = hyper_map(F.typeof, value) + data_dtype = F.dtype(data) + value_elements_type = const_utils.check_value_elements(data_dtype, value_types) + if value_elements_type == const_utils.ALL_TENSOR: + value_shapes = hyper_map(F.shape, value) + shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM) + if shapes_same: + value = F.pack(value) + return generate_updates_from_tensor(data, index, value, op_type) + + data_shape = F.shape(data) + index_shape = F.shape(index) + return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type) + + +def generate_updates_from_tensor(data, index, value, op_type): + """Generate an updates tensor from a tensor.""" + data_shape = F.shape(data) + index_shape = F.shape(index) + value_shape = F.shape(value) + data_dtype = F.dtype(data) + value_dtype = F.dtype(value) + updates_shape = value_shape + check_dtype_same = const_utils.check_tensors_dtype_same(data_dtype, value_dtype, const_utils.TENSOR_SETITEM) + if check_dtype_same: + updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type) + need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape) + if need_broadcast: + return broadcast(updates_shape, value) + return value diff --git a/mindspore/ops/composite/multitype_ops/_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py similarity index 54% rename from mindspore/ops/composite/multitype_ops/_utils.py rename to mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 9717ddc15e6..7a33c9ceed0 100644 --- a/mindspore/ops/composite/multitype_ops/_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -15,19 +15,15 @@ """constexpr util""" from functools import reduce -import numpy as np -from ...primitive import constexpr -from ....common.tensor import Tensor -from ....common import dtype as mstype -from ...._extends.utils import Slice, Ellipsis_ -from ....ops import _utils as op_utils -from ...composite import base -from .... import log as logger -from ... import functional as F -from ... import operations as P -hyper_map = base.HyperMap() -pack = P.Pack(axis=-1) +import numpy as np + +from ...primitive import constexpr +from .... import log as logger +from ...._extends.utils import Slice, Ellipsis_ +from ....common import dtype as mstype +from ....common.tensor import Tensor +from ....ops import _utils as op_utils ALL_TENSOR = 0 NO_TENSOR = 1 @@ -264,7 +260,7 @@ def tuple_index_elements_type(types, op_name): return ALL_TENSOR if tensors_number == 0: return NO_TENSOR - raise IndexError(f"For '{op_name}', the index does not support mixed tensor.") + return CONTAIN_TENSOR @constexpr @@ -279,12 +275,12 @@ def check_value_elements(data_dtype, types): tensors_number += 1 else: raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' " - f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.") + f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.") elif mstype.issubclass_(ele, data_dtype): scalars_number += 1 else: raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in " - f"value tuple is not consistent with origin tensor data type '{data_dtype}'.") + f"value tuple is not consistent with assigned tensor data type '{data_dtype}'.") if tensors_number == len(types): return ALL_TENSOR if scalars_number == len(types): @@ -299,51 +295,46 @@ def get_index_tensor_dtype(dtype): return INT_ if dtype == mstype.bool_: return BOOL_ - raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.") + raise IndexError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.") @constexpr def check_index_tensors_dtype(dtypes, op_name): """Check a tuple of tensor data type.""" - if op_name == TENSOR_GETITEM: - valid_dtypes = (mstype.int32, mstype.int64) - elif op_name == TENSOR_SETITEM: - valid_dtypes = (mstype.int32,) - else: - raise ValueError("Unsupported operation.") for ele in dtypes: - if ele in valid_dtypes and ele == dtypes[0]: - continue - raise TypeError(f"For '{op_name}', the index tensors data type must be same, " - f"and should be one of the following: {valid_dtypes}, but got {dtypes}.") + if not ele == mstype.int32: + raise IndexError(f"For '{op_name}', the all index tensor " + f"data types should be mstype.int32, but got {dtypes}.") return True @constexpr -def check_tensor_dtype_valid(dtype, valid_dtypes): +def check_index_tensor_dtype(dtype, op_name): """Check a tensor data type.""" - if dtype in valid_dtypes: + if dtype == mstype.int32: return True - raise TypeError(f"The index tensor data type must be one of " - f"the following: {valid_dtypes}, but got {dtype}.") + raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.") @constexpr -def check_tensors_dtype_same(x_dtype, y_dtype, op_name): +def check_tensors_dtype_same(data_dtype, value_dtype, op_name): """Check tensors data type same.""" - if x_dtype == y_dtype: + if value_dtype == data_dtype: return True - raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' " - f"is not consistent with origin tensor data type {x_dtype}.") + raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' " + f"is not consistent with assigned tensor data type {data_dtype}.") @constexpr -def broadcast_shapes(shapes, op_name): - """Broadcasts a tuple of tensor.""" +def generate_broadcast_shape(shapes, op_name): + """Generate broadcast shape for a tuple of shape.""" broadcast_shape = shapes[0] for i, shape in enumerate(shapes): logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") - broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name) + try: + broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name) + except ValueError as ex: + raise IndexError(ex) return tuple(broadcast_shape) @@ -366,14 +357,82 @@ def check_two_shapes_need_broadcast(shape_x, shape_y): @constexpr def compute_multiples(origin_shape, broadcast_shape): - """Compute multiples between broadcast_shape with origin_shape.""" + """Compute multiples between origin shape with broadcast shape.""" len_gap = len(broadcast_shape) - len(origin_shape) return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape)) -def tile(broadcast_shape, x): - multiples = compute_multiples(F.shape(x), broadcast_shape) - return F.tile(x, multiples) +@constexpr +def compute_new_shape(origin_shape, indexes_shapes_info): + """Compute new shape between origin shape with final shape.""" + new_shape = [] + for i in indexes_shapes_info: + if i == origin_shape: + new_shape.extend(origin_shape) + else: + new_shape.append(1) + return tuple(new_shape) + + +@constexpr +def convert_ellipsis_to_tensors(slice_number, + ellipsis_occupied_dims, + final_shape, + indexes_shapes_info, + op_name): + """Convert an ellipsis to a list of tensor.""" + tensor_list = [] + dims_dealt_count = 0 + while dims_dealt_count < ellipsis_occupied_dims: + shape = [] + slice_count = 0 + array = None + for ele in indexes_shapes_info: + if isinstance(ele, list): + if slice_count == slice_number: + array = np.array(ele, np.int32) + shape.append(len(ele)) + else: + shape.append(1) + slice_count += 1 + if isinstance(ele, tuple): + shape.extend([1] * len(ele)) + if array is None: + raise ValueError(f"For '{op_name}', generate tensors from ellipsis failed.") + array = np.reshape(array, shape) + reps = compute_multiples(shape, final_shape) + tensor = Tensor(np.tile(array, reps)) + tensor_list.append(tensor) + slice_number += 1 + dims_dealt_count += 1 + return tensor_list + + +@constexpr +def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name): + """Convert a slice to a tensor.""" + shape = [] + count = 0 + array = None + for ele in indexes_shapes_info: + if isinstance(ele, list): + if count == slice_number: + array = np.array(ele, np.int32) + shape.append(len(ele)) + else: + # When the slice is not the slice looking for, the shape is filled with 1. + shape.append(1) + count += 1 + elif isinstance(ele, tuple): + shape.extend([1] * len(ele)) + else: + shape.append(1) + if array is None: + raise ValueError(f"For '{op_name}', generate tensor from 'slice' failed.") + array = np.reshape(array, shape) + reps = compute_multiples(shape, final_shape) + tensor = Tensor(np.tile(array, reps)) + return tensor @constexpr @@ -381,8 +440,8 @@ def check_shapes_same(value_shapes, op_name): """Check if the shapes in the tuple are consistent.""" for i, shape in enumerate(value_shapes): if shape != value_shapes[0]: - raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple " - f"is not same as the first tensor shape.") + raise ValueError(f"For '{op_name}', the {i}th tensor shape in " + f"value tuple is not same as the first tensor shape.") return True @@ -396,7 +455,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty if isinstance(value, mstype.dtype_to_pytype(data_dtype)): return Tensor(np.full(updates_shape, value), dtype=data_dtype) raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'" - f" is not consistent with tensor data type {data_dtype}.") + f" is not consistent with the assigned tensor data type {data_dtype}.") @constexpr @@ -404,8 +463,8 @@ def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value """Convert a tuple of scalar to a tensor.""" updates_shape = generate_updates_shape(data_shape, index_shape, op_type) if len(value) != updates_shape[-1]: - raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} in the updates tuple " - f"does not meet the requirements: {updates_shape[-1]}.") + raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} " + f"in the updates tuple does not meet the requirements: {updates_shape[-1]}.") array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype)) reps = compute_multiples(updates_shape[-1:], updates_shape) return Tensor(np.tile(array, reps)) @@ -430,58 +489,145 @@ def check_number_of_index_tensor(data_shape, tuple_len, op_name): f"is greater than the dimension {len(data_shape)} of the operated tensor.") -def generate_indeices_from_tuple_of_tensor(data, tuple_index, op_name): - """Generate an indices tensor from a tuple of tensor.""" - indices = None - check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name) - if check_index_tensor_number: - dtype_tuple = hyper_map(F.dtype, tuple_index) - check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name) - if check_dtypes: - shape_tuple = hyper_map(F.shape, tuple_index) - broadcast_shape = broadcast_shapes(shape_tuple, op_name) - broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index) - indices = pack(broadcast_tensors) - return indices +@constexpr +def generate_index_info_from_tuple_of_mixed_tensors(data_shape, + indexes_types, + tensor_indexes_shapes, + tensor_indexes_dtypes, + slice_indexes, + op_name): + """ + Generate index info which contain broadcast shape, final shape, + indexes shapes info, ellipsis size from a tuple of mixed tensors. + """ + check_index_tensors_dtype(tensor_indexes_dtypes, op_name) + data_rank = len(data_shape) + indexes_size = len(indexes_types) + if indexes_size > data_rank: + raise IndexError(f"For '{op_name}', the number {indexes_size} of index elements " + f"is greater than the dimension {len(data_shape)} of the operated tensor.") + indexes_info = {} + index_tensors_info = {} + ellipsis_num = 0 + ellipsis_occupied_dims = 0 + tensor_count = 0 + slice_count = 0 + for i, ele_type in enumerate(indexes_types): + if ellipsis_num == 0: + pos = i + else: + pos = i + ellipsis_occupied_dims - 1 + if isinstance(ele_type, mstype.tensor_type): + indexes_info[pos] = tensor_indexes_shapes[tensor_count] + index_tensors_info[pos] = tensor_indexes_shapes[tensor_count] + tensor_count += 1 + elif isinstance(ele_type, mstype.slice_type): + slice_obj = slice(slice_indexes[slice_count].start, + slice_indexes[slice_count].end, + slice_indexes[slice_count].step) + # Use list to represent slicing result. + indexes_info[pos] = list(range(data_shape[pos]))[slice_obj] + slice_count += 1 + elif isinstance(ele_type, mstype.ellipsis_type): + if ellipsis_num != 0: + raise IndexError(f"For '{op_name}', the index could only contain one ellipsis.") + ellipsis_occupied_dims = data_rank - indexes_size + 1 + for j in range(pos, pos + ellipsis_occupied_dims): + # Use list to represent slicing result. + indexes_info[j] = list(range(data_shape[j])) + ellipsis_num += 1 + else: + raise IndexError(f"For '{op_name}', the index elements only support " + f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.") + broadcast_shape, final_shape, indexes_shapes_info = \ + _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name) + return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims -def generate_updates_from_scalar(data, indices, value, op_type): - """Generate an updates tensor from a scalar.""" - data_shape = F.shape(data) - indices_shape = F.shape(indices) - data_dtype = F.dtype(data) - return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) +def _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key: list): + """Determine whether the tensor in the index appears continuously.""" + for i in range(len(index_tensor_info_key) - 1): + if index_tensor_info_key[i + 1] != index_tensor_info_key[i] + 1: + return False + return True -def generate_updates_from_tuple(data, index, value, op_type): - """Generate an updates tensor from a tuple.""" - value_types = hyper_map(F.typeof, value) - data_dtype = F.dtype(data) - value_elements_type = check_value_elements(data_dtype, value_types) - if value_elements_type == ALL_TENSOR: - value_shapes = hyper_map(F.shape, value) - shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM) - if shapes_same: - value = F.pack(value) - return generate_updates_from_tensor(data, index, value, op_type) - - data_shape = F.shape(data) - index_shape = F.shape(index) - return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type) +def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name): + """Derive the resulting shape information from the a tuple index of mixed tensors.""" + index_tensor_info_key = list(index_tensors_info.keys()) + index_tensor_info_value = list(index_tensors_info.values()) + broadcast_shape = generate_broadcast_shape(index_tensor_info_value, op_name) + final_shape = [] + indexes_shapes_info = [] + mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key) + if mixed_tensors_continuous: + tensor_shape_dealt = False + for ele in indexes_info.values(): + if isinstance(ele, list): + final_shape.append(len(ele)) + indexes_shapes_info.append(ele) + elif isinstance(ele, tuple): + if not tensor_shape_dealt: + final_shape.extend(broadcast_shape) + indexes_shapes_info.append(broadcast_shape) + tensor_shape_dealt = True + else: + raise IndexError(f"For '{op_name}', the index elements only support " + f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.") + else: + final_shape.extend(broadcast_shape) + indexes_shapes_info.append(broadcast_shape) + for ele in indexes_info.values(): + if isinstance(ele, list): + final_shape.append(len(ele)) + indexes_shapes_info.append(ele) + elif isinstance(ele, tuple): + continue + else: + raise IndexError(f"For '{op_name}', the index elements only support " + f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.") + return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info) -def generate_updates_from_tensor(data, index, value, op_type): - """Generate an updates tensor from a tensor.""" - data_shape = F.shape(data) - index_shape = F.shape(index) - value_shape = F.shape(value) - data_dtype = F.dtype(data) - value_dtype = F.dtype(value) - updates_shape = value_shape - check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM) - if check_dtype_same: - updates_shape = generate_updates_shape(data_shape, index_shape, op_type) - need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape) - if need_broadcast: - return tile(updates_shape, value) - return value +@constexpr +def get_pos_of_int_index(indexes_types): + """Get int index positions from the mixed tensors index which contains int, tensor, slice, and ellipsis.""" + int_positions = [] + for i, ele_type in enumerate(indexes_types): + if ele_type == mstype.int32: + int_positions.append(i) + return int_positions + + +@constexpr +def separate_mixed_tensors_index(indexes_types, op_name): + """Separate the position information of tensor and slice and ellipsis from the mixed tensors index.""" + tensor_positions = [] + slice_positions = [] + ellipsis_position = None + for i, ele_type in enumerate(indexes_types): + if isinstance(ele_type, mstype.tensor_type): + tensor_positions.append(i) + elif isinstance(ele_type, mstype.slice_type): + slice_positions.append(i) + elif isinstance(ele_type, mstype.ellipsis_type): + ellipsis_position = i + else: + raise IndexError(f"For '{op_name}', the index elements only support " + f"'Tensor', 'int32', 'Slice', 'Ellipsis', but got {ele_type}.") + + return tensor_positions, slice_positions, ellipsis_position + + +@constexpr +def scalar_in_sequence(x, y): + """Determine whether the scalar in the sequence.""" + if x is None: + raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, " + "but the scalar is not.") + if y is None: + raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, " + "but the sequence is not.") + if x in y: + return True + return False diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py index affbb192906..1295aba87e2 100644 --- a/mindspore/ops/composite/multitype_ops/getitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py @@ -14,11 +14,11 @@ # ============================================================================ """Implementation for getitem.""" - -from . import _utils as multi_utils -from ..import base +from . import _compile_utils as compile_utils +from . import _constexpr_utils as const_utils +from .. import base from ... import functional as F -from ....common import dtype as mstype + getitem = base.MultitypeFuncGraph('getitem') """ @@ -227,7 +227,8 @@ def _tensor_getitem_by_tensor(data, tensor_index): Outputs: Tensor, element type is same as the element type of data. """ - check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64)) + check_dtypes = const_utils.check_index_tensor_dtype(F.dtype(tensor_index), + const_utils.TENSOR_GETITEM) result = None if check_dtypes: result = F.gather(data, tensor_index, 0) @@ -246,14 +247,13 @@ def _tensor_getitem_by_tuple(data, tuple_index): Outputs: Tensor, element type is same as the element type of data. """ - index_types = multi_utils.hyper_map(F.typeof, tuple_index) - index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_GETITEM) - result = None - if index_elements_type == multi_utils.NO_TENSOR: - result = _tensor_slice(data, tuple_index) - if index_elements_type == multi_utils.ALL_TENSOR: - result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index) - return result + indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) + index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) + if index_elements_type == const_utils.NO_TENSOR: + return _tensor_slice(data, tuple_index) + if index_elements_type == const_utils.ALL_TENSOR: + return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) + return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index) @getitem.register("Tensor", "Ellipsis") @@ -273,6 +273,17 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index): def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): """Tensor getitem by a tuple of tensor.""" - indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM) + indices = compile_utils.generate_indices_from_tuple_of_tensor(data, + tuple_index, + const_utils.TENSOR_GETITEM) + result = F.gather_nd(data, indices) + return result + + +def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): + """Tensor getitem by a tuple of mixed tensor.""" + indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, + tuple_index, + const_utils.TENSOR_GETITEM) result = F.gather_nd(data, indices) return result diff --git a/mindspore/ops/composite/multitype_ops/in_impl.py b/mindspore/ops/composite/multitype_ops/in_impl.py new file mode 100644 index 00000000000..26f7bce4374 --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/in_impl.py @@ -0,0 +1,101 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""in_impl""" + +from . import _constexpr_utils as const_utils +from ... import functional as F +from ...composite import base + +in_ = base.MultitypeFuncGraph("in") +""" +in_ is a metafuncgraph object which will determine if a in b +using ".register" decorator +""" + + +@in_.register("Number", "Tuple") +def _number_in_tuple(x, y): + """ + Determine if a number in tuple. + + Args: + x (Number): x + y (tuple): y + + Returns: + bool, if x in y return true, x not in y return false. + """ + return const_utils.scalar_in_sequence(x, y) + + +@in_.register("Number", "List") +def _number_in_list(x, y): + """ + Determine if a number in list. + + Args: + x (Number): x + y (list): y + + Returns: + bool, if x in y return true, x not in y return false. + """ + return const_utils.scalar_in_sequence(x, y) + + +@in_.register("String", "Tuple") +def _string_in_tuple(x, y): + """ + Determine if a str in a tuple. + + Args: + x (str): x + y (tuple): y + + Returns: + bool, if x in y return true, x not in y return false. + """ + return const_utils.scalar_in_sequence(x, y) + + +@in_.register("String", "List") +def _string_in_list(x, y): + """ + Determine if a str in a list. + + Args: + x (str): x + y (list): y + + Returns: + bool, if x in y return true, x not in y return false. + """ + return const_utils.scalar_in_sequence(x, y) + + +@in_.register("String", "Dictionary") +def _str_in_dict(x, y): + """ + Determine if a str in dict. + + Args: + x: str + y: dict + + Returns: + bool, if x in y return true, x not in y return false. + """ + return F.in_dict(x, y) diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py index f51dc12c273..53659c62055 100644 --- a/mindspore/ops/composite/multitype_ops/setitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -15,10 +15,11 @@ """Implementation for setitem.""" +from . import _compile_utils as compile_utils +from . import _constexpr_utils as const_utils +from ... import functional as F from ...composite import base from ....common import dtype as mstype -from ... import functional as F -from . import _utils as multi_utils setitem = base.MultitypeFuncGraph('setitem') @@ -139,8 +140,8 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor): Tensor, element type and shape is same as data. """ index_dtype = F.dtype(index) - tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype) - if tensor_dtype == multi_utils.INT_: + tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) + if tensor_dtype == const_utils.INT_: return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor) return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor) @@ -166,8 +167,8 @@ def _tensor_setitem_by_tensor_with_number(data, index, value): Tensor, element type and shape is same as data. """ index_dtype = F.dtype(index) - tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype) - if tensor_dtype == multi_utils.BOOL_: + tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype) + if tensor_dtype == const_utils.BOOL_: return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value) return _tensor_setitem_by_int_tensor_with_scalar(data, index, value) @@ -190,17 +191,24 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value): Outputs: Tensor, element type and shape is same as data. """ - index_types = multi_utils.hyper_map(F.typeof, tuple_index) - index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM) - result = None - if index_elements_type == multi_utils.NO_TENSOR: - result = _tensor_assgin_number(data, tuple_index, value) - if index_elements_type == multi_utils.ALL_TENSOR: - indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) - updates = multi_utils.generate_updates_from_scalar(data, indices, value, - multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) - result = F.scatter_nd_update(data, indices, updates) - return result + indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) + index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) + + if index_elements_type == const_utils.NO_TENSOR: + return _tensor_assgin_number(data, tuple_index, value) + if index_elements_type == const_utils.ALL_TENSOR: + indices = compile_utils.generate_indices_from_tuple_of_tensor(data, + tuple_index, + const_utils.TENSOR_SETITEM) + else: + indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, + tuple_index, + const_utils.TENSOR_SETITEM) + updates = compile_utils.generate_updates_from_scalar(data, + indices, + value, + const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) + return F.scatter_nd_update(data, indices, updates) @setitem.register("Tensor", "Tuple", "Tensor") @@ -221,17 +229,24 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): Outputs: Tensor, element type and shape is same as data. """ - index_types = multi_utils.hyper_map(F.typeof, tuple_index) - index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM) - result = None - if index_elements_type == multi_utils.NO_TENSOR: - result = _tensor_assgin_tensor(data, tuple_index, value) - if index_elements_type == multi_utils.ALL_TENSOR: - indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) - updates = multi_utils.generate_updates_from_tensor(data, indices, value, - multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) - result = F.scatter_nd_update(data, indices, updates) - return result + indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) + index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) + + if index_elements_type == const_utils.NO_TENSOR: + return _tensor_assgin_tensor(data, tuple_index, value) + if index_elements_type == const_utils.ALL_TENSOR: + indices = compile_utils.generate_indices_from_tuple_of_tensor(data, + tuple_index, + const_utils.TENSOR_SETITEM) + else: + indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, + tuple_index, + const_utils.TENSOR_SETITEM) + updates = compile_utils.generate_updates_from_tensor(data, + indices, + value, + const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) + return F.scatter_nd_update(data, indices, updates) @setitem.register("Tensor", "Tuple", "Tuple") @@ -253,15 +268,22 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): Outputs: Tensor, element type and shape is same as data. """ - index_types = multi_utils.hyper_map(F.typeof, tuple_index) - index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM) - result = None - if index_elements_type == multi_utils.ALL_TENSOR: - indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) - updates = multi_utils.generate_updates_from_tuple(data, indices, value, - multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) - result = F.scatter_nd_update(data, indices, updates) - return result + indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) + index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) + + if index_elements_type == const_utils.ALL_TENSOR: + indices = compile_utils.generate_indices_from_tuple_of_tensor(data, + tuple_index, + const_utils.TENSOR_SETITEM) + else: + indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, + tuple_index, + const_utils.TENSOR_SETITEM) + updates = compile_utils.generate_updates_from_tuple(data, + indices, + value, + const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) + return F.scatter_nd_update(data, indices, updates) @setitem.register("Tensor", "Tensor", "Tuple") @@ -278,7 +300,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value): Tensor, element type and shape is same as data. """ index_dtype = F.dtype(index) - check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64)) + check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM) result = None if check_dtype: result = _tensor_setitem_by_tensor_with_tuple(data, index, value) @@ -331,14 +353,14 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value): def _tensor_assgin_number(data, input_slice, value): """Givens a scalar assign to tensor by slice""" - check_result = multi_utils.check_tensor_setitem_index(input_slice) + check_result = const_utils.check_tensor_setitem_index(input_slice) result = None if check_result: data_shape = F.shape(data) - indices = multi_utils.slice2indices(input_slice, data_shape) - is_tuple_int = multi_utils.tuple_element_is_int(input_slice) + indices = const_utils.slice2indices(input_slice, data_shape) + is_tuple_int = const_utils.tuple_element_is_int(input_slice) if is_tuple_int: - indices = multi_utils.integer_to_indices(input_slice, data_shape) + indices = const_utils.integer_to_indices(input_slice, data_shape) result = _tensor_indices_number(data, data_shape, input_slice, indices, value) return result @@ -347,7 +369,7 @@ def _tensor_assgin_number(data, input_slice, value): def _tensor_setitem_with_int_v1(data, index, value): """Syntax: A[1] = 3""" data_shape = F.shape(data) - indices = multi_utils.integer_to_indices(index, data_shape) + indices = const_utils.integer_to_indices(index, data_shape) return _tensor_indices_number(data, data_shape, index, indices, value) @@ -355,7 +377,7 @@ def _tensor_setitem_with_int_v1(data, index, value): def _tensor_setitem_with_int_v2(data, index, value): """Syntax: A[1] = Tensor""" data_shape = F.shape(data) - indices = multi_utils.integer_to_indices(index, data_shape) + indices = const_utils.integer_to_indices(index, data_shape) return _tensor_indices_tensor(data, data_shape, index, indices, value) @@ -376,7 +398,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value): data_size = F.size(data) value_shape = F.shape(value) value_size = F.size(value) - check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) + check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) if check_result: if data_size == value_size: result = F.reshape(value, data_shape) @@ -391,13 +413,13 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value): def _tensor_assgin_tensor(data, input_slice, value): """Assigns a tensor value to the tensor by slice.""" result = None - check_result = multi_utils.check_tensor_setitem_index(input_slice) + check_result = const_utils.check_tensor_setitem_index(input_slice) if check_result: data_shape = F.shape(data) - indices = multi_utils.slice2indices(input_slice, data_shape) - is_tuple_int = multi_utils.tuple_element_is_int(input_slice) + indices = const_utils.slice2indices(input_slice, data_shape) + is_tuple_int = const_utils.tuple_element_is_int(input_slice) if is_tuple_int: - indices = multi_utils.integer_to_indices(input_slice, data_shape) + indices = const_utils.integer_to_indices(input_slice, data_shape) result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) return result @@ -407,7 +429,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value): data_size = F.size(data) data_dtype = F.dtype(data) indices_size = F.size(indices) - indices_size = multi_utils.check_indices(indices_size, index) + indices_size = const_utils.check_indices(indices_size, index) update = F.fill(mstype.int32, (indices_size,), 1) condition_1d = F.scatter_nd(indices, update, (data_size,)) condition = F.reshape(condition_1d, data_shape) @@ -415,7 +437,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value): value_fill = None value_size = F.size(value) - value_size = multi_utils.check_indices_value_size(indices_size, value_size) + value_size = const_utils.check_indices_value_size(indices_size, value_size) if value_size == 1: value_fill = F.fill(data_dtype, (indices_size,), 1) value = F.cast(value, data_dtype) @@ -432,7 +454,7 @@ def _tensor_indices_number(data, data_shape, index, indices, value): data_size = F.size(data) data_dtype = F.dtype(data) indices_size = F.size(indices) - indices_size = multi_utils.check_indices(indices_size, index) + indices_size = const_utils.check_indices(indices_size, index) update = F.fill(mstype.int32, (indices_size,), 1) condition_1d = F.scatter_nd(indices, update, (data_size,)) condition = F.reshape(condition_1d, data_shape) @@ -445,16 +467,16 @@ def _tensor_indices_number(data, data_shape, index, indices, value): def _tensor_setitem_by_tensor_with_tuple(data, index, value): """Set a tensor item by a tensor with a tuple.""" - updates = multi_utils.generate_updates_from_tuple(data, index, value, - multi_utils.SET_ITEM_BY_ONE_TENSOR) + updates = compile_utils.generate_updates_from_tuple(data, index, value, + const_utils.SET_ITEM_BY_ONE_TENSOR) result = F.scatter_update(data, index, updates) return result def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): """Set a tensor item by a int tensor with a scalar.""" - updates = multi_utils.generate_updates_from_scalar(data, index, value, - multi_utils.SET_ITEM_BY_ONE_TENSOR) + updates = compile_utils.generate_updates_from_scalar(data, index, value, + const_utils.SET_ITEM_BY_ONE_TENSOR) return F.scatter_update(data, index, updates) @@ -462,7 +484,7 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): """Set a tensor item by a bool tensor with a scalar.""" index_shape = F.shape(index) shape = F.shape(data) - shape = multi_utils.check_equal( + shape = const_utils.check_equal( shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") dtype = F.dtype(data) u = F.fill(dtype, shape, value) @@ -471,8 +493,8 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): """Set a tensor item by a int tensor with a tensor.""" - updates = multi_utils.generate_updates_from_tensor(data, index, value, - multi_utils.SET_ITEM_BY_ONE_TENSOR) + updates = compile_utils.generate_updates_from_tensor(data, index, value, + const_utils.SET_ITEM_BY_ONE_TENSOR) return F.scatter_update(data, index, updates) @@ -480,10 +502,10 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value): """Set a tensor item by a bool tensor with a tensor.""" index_shape = F.shape(index) data_shape = F.shape(data) - data_shape = multi_utils.check_equal(data_shape, index_shape, + data_shape = const_utils.check_equal(data_shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") size = F.size(value) - size = multi_utils.check_equal(1, size, + size = const_utils.check_equal(1, size, "When assign value is a tensor, its size should be {}, but current size is {}.") dtype = F.dtype(data) u_cast = F.cast(value, dtype) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 757335e2931..6745adfbd9c 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1425,7 +1425,6 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name): validator.check_value_type("shape", x_shape, [tuple, list], prim_name) validator.check_integer("len of input_x", len(x_shape), 1, Rel.GT, prim_name) validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name) - validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT, prim_name) rank_base = len(x_shape[0]) N = len(x_shape) out_shape = x_shape[0] diff --git a/tests/mindspore_test_framework/components/executor/exec_forward.py b/tests/mindspore_test_framework/components/executor/exec_forward.py index 3dcbc36e97c..3404bf8f7e4 100644 --- a/tests/mindspore_test_framework/components/executor/exec_forward.py +++ b/tests/mindspore_test_framework/components/executor/exec_forward.py @@ -33,9 +33,4 @@ class IdentityEC(IExectorComponent): keyword.desc_inputs: self.inputs[keyword.desc_inputs], keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs]) } - print("buxue------------------------------------------------") - print("inputs") - print(ret[keyword.desc_inputs]) - print("outputs") - print(ret[keyword.result]) return ret diff --git a/tests/ut/python/ops/test_list.py b/tests/ut/python/dtype/test_list.py similarity index 90% rename from tests/ut/python/ops/test_list.py rename to tests/ut/python/dtype/test_list.py index f5f919b998b..d5e316bed1b 100644 --- a/tests/ut/python/ops/test_list.py +++ b/tests/ut/python/dtype/test_list.py @@ -19,9 +19,9 @@ import mindspore.nn as nn import mindspore.context as context from mindspore import Tensor from mindspore.ops import operations as P -from ..ut_filter import non_graph_engine -from ....mindspore_test_framework.mindspore_test import mindspore_test -from ....mindspore_test_framework.pipeline.forward.compile_forward \ +from tests.ut.python.ut_filter import non_graph_engine +from tests.mindspore_test_framework.mindspore_test import mindspore_test +from tests.mindspore_test_framework.pipeline.forward.compile_forward \ import pipeline_for_compile_forward_ge_graph_for_case_by_case_config @@ -133,7 +133,7 @@ def test_list_append_2(): class ListOperate(nn.Cell): - def __init__(self,): + def __init__(self, ): super(ListOperate, self).__init__() def construct(self, t, l): @@ -152,6 +152,20 @@ class ListOperate(nn.Cell): return x +class InListNet(nn.Cell): + def __init__(self, ): + super(InListNet, self).__init__() + self.list_ = [1, 2, 3, 4, 5, "ok"] + + def construct(self, x): + ret = x + if 2 in self.list_: + ret = x + x + if "ok" in self.list_: + ret = x - x + return ret + + class AxisListNet(nn.Cell): def __init__(self): super(AxisListNet, self).__init__() @@ -204,10 +218,15 @@ test_case_ops = [ ('AxisListDefault', { 'block': AxisListDefaultNet(), 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), + ('InList', { + 'block': InListNet(), + 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), ] test_case_lists = [test_case_ops] test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists) + + # use -k to select certain testcast # pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm diff --git a/tests/ut/python/ops/test_tuple.py b/tests/ut/python/dtype/test_tuple.py similarity index 79% rename from tests/ut/python/ops/test_tuple.py rename to tests/ut/python/dtype/test_tuple.py index eafaaede916..4e20bef25da 100644 --- a/tests/ut/python/ops/test_tuple.py +++ b/tests/ut/python/dtype/test_tuple.py @@ -19,9 +19,9 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore import dtype as mstype -from ..ut_filter import non_graph_engine -from ....mindspore_test_framework.mindspore_test import mindspore_test -from ....mindspore_test_framework.pipeline.forward.compile_forward \ +from tests.ut.python.ut_filter import non_graph_engine +from tests.mindspore_test_framework.mindspore_test import mindspore_test +from tests.mindspore_test_framework.pipeline.forward.compile_forward \ import pipeline_for_compile_forward_ge_graph_for_case_by_case_config context.set_context(mode=context.GRAPH_MODE, save_graphs=True) @@ -52,6 +52,20 @@ class NestTupleGraphNet(nn.Cell): return self.layers[0][1](x) +class InTupleNet(nn.Cell): + def __init__(self, ): + super(InTupleNet, self).__init__() + self.tuple_ = (1, 2, 3, 4, 5, "ok") + + def construct(self, x): + ret = x + if 2 in self.tuple_: + ret = x + x + if "ok" in self.tuple_: + ret = x - x + return ret + + test_case_ops = [ ('TupleGraph', { 'block': TupleGraphNet(), @@ -59,6 +73,9 @@ test_case_ops = [ ('NestTupleGraph', { 'block': NestTupleGraphNet(), 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), + ('InTuple', { + 'block': InTupleNet(), + 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}) ] test_case_lists = [test_case_ops] diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index 776c43b784f..b6a261d2927 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -176,12 +176,134 @@ class TensorGetItemByThreeTensors(Cell): return ret -class TensorGetItemByMixedTensors(Cell): +class TensorGetItemByMixedTensors_0(Cell): def __init__(self): - super(TensorGetItemByMixedTensors, self).__init__() + super(TensorGetItemByMixedTensors_0, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32)) + + def construct(self, tensor, index_0, index_1): + ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const + return ret + + +class TensorGetItemByMixedTensors_1(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors_1, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32)) + + def construct(self, tensor, index_0, index_1): + ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const + return ret + + +class TensorGetItemByMixedTensors_2(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors_2, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32)) + + def construct(self, tensor, index_0, index_1): + ret = tensor[0, index_0, index_1, ..., 3] + self.const + return ret + + +class TensorGetItemByMixedTensors_3(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors_3, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32)) + + def construct(self, tensor, index_0, index_1): + ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const + return ret + + +class TensorGetItemByMixedTensors_4(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors_4, self).__init__() + self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32)) + + def construct(self, tensor, index_0, index_1, index_2): + ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const + return ret + + +class TensorGetItemByMixedTensors_5(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors_5, self).__init__() + self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32)) + + def construct(self, tensor, index_0, index_1, index_2): + ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const + return ret + + +class TensorGetItemByMixedTensors_6(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors_6, self).__init__() + self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32)) + + def construct(self, tensor, index_0, index_1, index_2): + ret = tensor[..., index_0, index_1, index_2, 3] + self.const + return ret + + +class TensorSetItemByMixedTensors_0(Cell): + def __init__(self, value): + super(TensorSetItemByMixedTensors_0, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32)) + self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), + mstype.float32), + name="x") + self.value = value + + def construct(self, index_0, index_1, index_2): + self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value + ret = self.param + self.const + return ret + + +class TensorSetItemByMixedTensors_1(Cell): + def __init__(self, value): + super(TensorSetItemByMixedTensors_1, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32)) + self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), + name="x") + self.value = value + + def construct(self, index_0, index_1, index_2): + self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value + ret = self.param + self.const + return ret + + +class TensorSetItemByMixedTensors_2(Cell): + def __init__(self, value): + super(TensorSetItemByMixedTensors_2, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32)) + self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), + name="x") + self.value = value + + def construct(self, index_0, index_1, index_2): + self.param[..., index_0, index_1, index_2, 3] = self.value + ret = self.param + self.const + return ret + + +class TensorGetItemByMixedTensorsTypeError(Cell): + def __init__(self): + super(TensorGetItemByMixedTensorsTypeError, self).__init__() def construct(self, x, index_0, index_1): - ret = x[index_0, index_1, 0:6] + ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]] + return ret + + +class TensorGetItemByMixedTensorsNumberError(Cell): + def __init__(self): + super(TensorGetItemByMixedTensorsNumberError, self).__init__() + + def construct(self, x, index_0, index_1): + ret = x[index_0, index_1, 0:3, ..., index_1, index_0] return ret @@ -189,7 +311,7 @@ class TensorSetItemByOneTensorWithNumber(Cell): def __init__(self, value): super(TensorSetItemByOneTensorWithNumber, self).__init__() self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) - self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") self.value = value def construct(self, index): @@ -202,7 +324,7 @@ class TensorSetItemByOneTensorWithTensor(Cell): def __init__(self): super(TensorSetItemByOneTensorWithTensor, self).__init__() self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) - self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") def construct(self, index, value): self.param[index] = value @@ -214,7 +336,7 @@ class TensorSetItemByOneTensorWithTupleOfNumber(Cell): def __init__(self, value): super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__() self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) - self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") self.value = value def construct(self, index): @@ -227,7 +349,7 @@ class TensorSetItemByOneTensorWithTupleOfTensor(Cell): def __init__(self): super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__() self.const = Tensor(np.ones((6, 3, 8)), mstype.float32) - self.param = Parameter(Tensor(np.arange(6*3*8).reshape((6, 3, 8)), mstype.float32), name="x") + self.param = Parameter(Tensor(np.arange(6 * 3 * 8).reshape((6, 3, 8)), mstype.float32), name="x") def construct(self, index, value_0, value_1, value_2): self.param[index] = (value_0, value_1, value_2) @@ -239,7 +361,7 @@ class TensorSetItemByTensorsWithNumber(Cell): def __init__(self, value): super(TensorSetItemByTensorsWithNumber, self).__init__() self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) - self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") self.value = value def construct(self, index_0, index_1, index_2): @@ -252,7 +374,7 @@ class TensorSetItemByTensorsWithTensor(Cell): def __init__(self): super(TensorSetItemByTensorsWithTensor, self).__init__() self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) - self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") def construct(self, index_0, index_1, index_2, value): self.param[index_0, index_1, index_2] = value @@ -264,7 +386,7 @@ class TensorSetItemByTensorsWithTensorNumberError(Cell): def __init__(self): super(TensorSetItemByTensorsWithTensorNumberError, self).__init__() self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) - self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") def construct(self, index_0, index_1, index_2, index_3, value): self.param[index_0, index_1, index_2, index_3] = value @@ -276,7 +398,7 @@ class TensorSetItemByTensorsWithTupleOfNumber(Cell): def __init__(self, value): super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__() self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) - self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") self.value = value def construct(self, index_0, index_1, index_2): @@ -289,7 +411,7 @@ class TensorSetItemByTensorsWithTupleOfTensor(Cell): def __init__(self): super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__() self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) - self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") def construct(self, index_0, index_1, index_2, value_0, value_1, value_2): self.param[index_0, index_1, index_2] = (value_0, value_1, value_2) @@ -301,7 +423,7 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): def __init__(self): super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__() self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) - self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") def construct(self, index_0, index_1, index_2, value_0, value_1): self.param[index_0, index_1, index_2] = (value_0, value_1) @@ -313,7 +435,7 @@ class TensorSetItemByMixedTensors(Cell): def __init__(self): super(TensorSetItemByMixedTensors, self).__init__() self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) - self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x") + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") self.value = 99.0 def construct(self, index_0, index_1): @@ -538,11 +660,11 @@ def test_tensor_assign_bool_index(): net1(Ta, Tb, Tc, u_tensor) with pytest.raises(ValueError): net1(Ta, Td, Tc, u_tensor) - with pytest.raises(TypeError): + with pytest.raises(IndexError): net1(Ta, u_tensor, Tc, u_tensor) with pytest.raises(ValueError): net1(Ta, Tb, Td, u_tensor) - with pytest.raises(TypeError): + with pytest.raises(IndexError): net1(Ta, Tb, Ta, u_tensor) with pytest.raises(ValueError): net1(Ta, Tb, Tc, u_tensor_error) @@ -620,22 +742,67 @@ test_cases = [ }), ('TensorGetItemByOneTensor', { 'block': TensorGetItemByOneTensor(), - 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), + 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)], }), ('TensorGetItemByTwoTensors', { 'block': TensorGetItemByTwoTensors(), - 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), + 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)], }), ('TensorGetItemByThreeTensors', { 'block': TensorGetItemByThreeTensors(), - 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), + 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], }), + ('TensorGetItemByMixedTensors_0', { + 'block': TensorGetItemByMixedTensors_0(), + 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), + Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensors_1', { + 'block': TensorGetItemByMixedTensors_1(), + 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), + Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensors_2', { + 'block': TensorGetItemByMixedTensors_2(), + 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), + Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensors_3', { + 'block': TensorGetItemByMixedTensors_3(), + 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), + Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensors_4', { + 'block': TensorGetItemByMixedTensors_4(), + 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.float32), + Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensors_5', { + 'block': TensorGetItemByMixedTensors_5(), + 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), + Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensors_6', { + 'block': TensorGetItemByMixedTensors_6(), + 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), + Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), ('TensorSetItemByOneTensorWithNumber', { 'block': TensorSetItemByOneTensorWithNumber(value=0.0), 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)], @@ -683,46 +850,143 @@ test_cases = [ Tensor(np.zeros((4, 5)), mstype.float32), Tensor(np.ones((4, 5)), mstype.float32), Tensor(np.ones((4, 5)) * 2, mstype.float32)], - }) + }), + ('TensorSetItemByMixedTensorsWithNumber_0', { + 'block': TensorSetItemByMixedTensors_0(value=88.0), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorSetItemByMixedTensorsWithTensor_0', { + 'block': TensorSetItemByMixedTensors_0(value=Tensor(np.ones((4, 5, 3, 9), np.float32))), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensorsWithTupleOfNumber_0', { + 'block': TensorSetItemByMixedTensors_0(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensorsWithTupleOfTensor_0', { + 'block': TensorSetItemByMixedTensors_0(value=(Tensor(np.ones((4, 5, 3, 9), np.float32)), + Tensor(np.zeros((4, 5, 3, 9), np.float32)), + Tensor(np.ones((4, 5, 3, 9), np.float32)))), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorSetItemByMixedTensorsWithNumber_1', { + 'block': TensorSetItemByMixedTensors_1(value=88.0), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorSetItemByMixedTensorsWithTensor_1', { + 'block': TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensorsWithTupleOfNumber_1', { + 'block': TensorSetItemByMixedTensors_1(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensorsWithTupleOfTensor_1', { + 'block': TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)), + Tensor(np.zeros((5, 2, 6), np.float32)), + Tensor(np.ones((5, 2, 6), np.float32)), + Tensor(np.ones((5, 2, 6), np.float32)))), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorSetItemByMixedTensorsWithNumber_2', { + 'block': TensorSetItemByMixedTensors_2(value=88.0), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorSetItemByMixedTensorsWithTensor_2', { + 'block': TensorSetItemByMixedTensors_2(value=Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensorsWithTupleOfNumber_2', { + 'block': TensorSetItemByMixedTensors_2(value=(1.0, 2.0, 3.0, 4.0, 5.0)), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensorsWithTupleOfTensor_2', { + 'block': TensorSetItemByMixedTensors_2(value=(Tensor(np.ones((4, 5), np.float32)), + Tensor(np.zeros((4, 5), np.float32)), + Tensor(np.ones((4, 5), np.float32)))), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), ] raise_error_set = [ ('TensorGetItemByOneTensorDtypeError', { - 'block': (TensorGetItemByOneTensor(), {'exception': TypeError}), - 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), + 'block': (TensorGetItemByOneTensor(), {'exception': IndexError}), + 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)], }), ('TensorGetItemByTwoTensorsShapeError', { - 'block': (TensorGetItemByTwoTensors(), {'exception': ValueError}), - 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), + 'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}), + 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)], }), ('TensorGetItemByTwoTensorsDtypeError', { - 'block': (TensorGetItemByTwoTensors(), {'exception': TypeError}), - 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), + 'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}), + 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)], }), ('TensorGetItemByThreeTensorsShapeError', { - 'block': (TensorGetItemByThreeTensors(), {'exception': ValueError}), - 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), + 'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}), + 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)], }), ('TensorGetItemByThreeTensorsDtypeError', { - 'block': (TensorGetItemByThreeTensors(), {'exception': TypeError}), - 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), + 'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}), + 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), - Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64), + Tensor(np.random.randint(7, size=(4, 5)), mstype.int64), Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], }), - ('TensorGetItemByMixedTensors', { - 'block': (TensorGetItemByMixedTensors(), {'exception': IndexError}), - 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32), + ('TensorGetItemByMixedTensorsNumberError', { + 'block': (TensorGetItemByMixedTensorsNumberError(), {'exception': IndexError}), + 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32), Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), - Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64)], + Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensorsTypeError', { + 'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': TypeError}), + 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32), + Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensorsDtypeError', { + 'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}), + 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32), + Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.float32)], + }), + ('TensorGetItemByMixedTensorsShapeError', { + 'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}), + 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32), + Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(2, 4, 5)), mstype.int32)], }), ('TensorSetItemByOneTensorWithNumberTypeError', { 'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}), @@ -760,21 +1024,21 @@ raise_error_set = [ Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], }), ('TensorSetItemByTensorsWithTensorShapeError', { - 'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}), + 'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}), 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), Tensor(np.zeros((2, 5)), mstype.float32)], }), ('TensorSetItemByTensorsWithTensorTypeError', { - 'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}), + 'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}), 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), Tensor(np.zeros((4, 5)), mstype.int32)], }), ('TensorSetItemByTensorsWithTensorNumberError', { - 'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}), + 'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}), 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), @@ -782,19 +1046,19 @@ raise_error_set = [ Tensor(np.zeros((2, 5)), mstype.float32)], }), ('TensorSetItemByTensorsWithTupleOfNumberTypeError', { - 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}), + 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1, 2, 3, 4)), {'exception': TypeError}), 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], }), ('TensorSetItemByTensorsWithTupleOfNumberNumberError', { - 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}), + 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}), 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], }), ('TensorSetItemByTensorsWithTupleOfTensorNumberError', { - 'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}), + 'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}), 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), @@ -802,7 +1066,7 @@ raise_error_set = [ Tensor(np.ones((4, 5)), mstype.float32)], }), ('TensorSetItemByTensorsWithTupleOfTensorTypeError', { - 'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}), + 'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}), 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), @@ -810,10 +1074,65 @@ raise_error_set = [ Tensor(np.ones((4, 5)), mstype.int32), Tensor(np.ones((4, 5)) * 2, mstype.int32)], }), - ('TensorSetItemByMixedTensors', { - 'block': (TensorSetItemByMixedTensors(), {'exception': IndexError}), - 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), - Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)], + ('TensorSetItemByMixedTensorsWithNumberValueTypeError', { + 'block': (TensorSetItemByMixedTensors_1(value=88), {'exception': TypeError}), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorSetItemByMixedTensorsWithNumberIndexTypeError', { + 'block': (TensorSetItemByMixedTensors_1(value=88.0), {'exception': IndexError}), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.float32)], + }), + ('TensorSetItemByMixedTensorsWithTensorValueDtypeError', { + 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.int32))), + {'exception': TypeError}), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorSetItemByMixedTensorsWithTensorValueShapeError', { + 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((3, 2, 6), np.float32))), + {'exception': ValueError}), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorSetItemByMixedTensorsWithTensorIndexDtypeError', { + 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))), + {'exception': IndexError}), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.float32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensorsWithTupleOfNumberValueTypeError', { + 'block': (TensorSetItemByMixedTensors_1(value=(1.0, 2, 3.0, 4.0, 5.0, 6.0)), + {'exception': TypeError}), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensorsWithTupleOfTensorValueDtypeError', { + 'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)), + Tensor(np.zeros((5, 2, 6), np.float32)), + Tensor(np.ones((5, 2, 6), np.float32)), + Tensor(np.ones((5, 2, 6), np.int32)))), + {'exception': TypeError}), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], + }), + ('TensorGetItemByMixedTensorsWithTupleOfTensorIndexDtypeError', { + 'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)), + Tensor(np.zeros((5, 2, 6), np.float32)), + Tensor(np.ones((5, 2, 6), np.float32)), + Tensor(np.ones((5, 2, 6), np.int32)))), + {'exception': IndexError}), + 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.float32), + Tensor(np.random.randint(4, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)], }) ]