From ea74bb6407a45040fb372c01e11d8416514d1aa7 Mon Sep 17 00:00:00 2001 From: Payne Date: Tue, 12 Jan 2021 20:08:29 +0800 Subject: [PATCH] fix the getitem bug --- .../multitype_ops/_constexpr_utils.py | 104 ++++++++++-------- 1 file changed, 57 insertions(+), 47 deletions(-) diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index cc918e2aff3..f6b6317c6ea 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -147,13 +147,15 @@ def judge_index_type(index_type, target_type): @constexpr def check_type_valid(dtype, target_type, op_name): if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): - raise TypeError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.") + raise TypeError( + f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.") @constexpr def check_index_type_valid(dtype, target_type, op_name): if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): - raise IndexError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.") + raise IndexError( + f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.") @constexpr @@ -189,7 +191,8 @@ def get_pos_of_indexes_types(indexes_types, op_name): raise IndexError(f"For '{op_name}', the index elements only support " f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.") if len(ellipsis_positions) > 1: - raise IndexError(f"For '{op_name}, an index can only have a single ellipsis('...')") + raise IndexError( + f"For '{op_name}, an index can only have a single ellipsis('...')") return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \ tensor_positions, sequence_positions @@ -260,7 +263,7 @@ def ellipsis2slice(input_, shape): return tuple(result) -@ constexpr +@constexpr def slice2indices(input_slices, shape): """ Converts slice to indices. @@ -285,7 +288,7 @@ def slice2indices(input_slices, shape): return ravel -@ constexpr +@constexpr def check_indices(indices_size, index): """Checks indices whether is empty.""" if indices_size < 1: @@ -294,7 +297,7 @@ def check_indices(indices_size, index): return indices_size -@ constexpr +@constexpr def check_indices_value_size(indices_size, value_size): """Checks if the sizes are already matched.""" if value_size < 1: @@ -307,7 +310,7 @@ def check_indices_value_size(indices_size, value_size): return value_size -@ constexpr +@constexpr def integer_to_indices(index, shape): """Converts int or tuple[int] to indices.""" size = reduce(lambda x, y: x * y, shape) @@ -317,7 +320,7 @@ def integer_to_indices(index, shape): return Tensor(value, dtype=mstype.int32) -@ constexpr +@constexpr def tuple_element_is_int(indexs): """Judges tuple element type.""" if not indexs: @@ -330,18 +333,19 @@ def tuple_element_is_int(indexs): return False -@ constexpr +@constexpr def tuple_index_int_cnt(types, op_name): """count the int type of types which contains the tuple elements' type.""" int_cnt = sum(isinstance(ele, mstype.Int) for ele in types) return ALL_INT if int_cnt == len(types) else NO_INT if int_cnt == 0 else CONTAIN_INT -@ constexpr +@constexpr def tuple_index_type_cnt(types, op_name): """count the tensor type of types which contains the tuple elements' type.""" tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types) - basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types) + basic_cnt = sum(isinstance( + ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types) if tensor_cnt == len(types): return ALL_TENSOR if basic_cnt == len(types): @@ -349,7 +353,7 @@ def tuple_index_type_cnt(types, op_name): return MIXED -@ constexpr +@constexpr def check_value_elements(data_dtype, types): """Judges the type of all elements of the tuple.""" tensors_number = 0 @@ -377,10 +381,10 @@ def check_value_elements(data_dtype, types): # TODO to del -@ constexpr +@constexpr def get_index_tensor_dtype(dtype): """Check a tuple of tensor data type.""" - if dtype == mstype.int32: + if dtype in mstype.int_type: return INT_ if dtype == mstype.bool_: return BOOL_ @@ -389,7 +393,7 @@ def get_index_tensor_dtype(dtype): # TODO to del -@ constexpr +@constexpr def check_index_tensors_dtype(indexes_types, op_name): """Check a tuple of tensor data type.""" for index_type in indexes_types: @@ -400,7 +404,7 @@ def check_index_tensors_dtype(indexes_types, op_name): # TODO to del -@ constexpr +@constexpr def check_index_tensor_dtype(index_type, op_name): """Check a tensor data type.""" if index_type in (mstype.int32, mstype.int64): @@ -410,7 +414,7 @@ def check_index_tensor_dtype(index_type, op_name): # TODO to del -@ constexpr +@constexpr def check_tensors_dtype_same(data_dtype, value_dtype, op_name): """Check tensors data type same.""" if value_dtype == data_dtype: @@ -419,7 +423,7 @@ def check_tensors_dtype_same(data_dtype, value_dtype, op_name): f"is not consistent with assigned tensor data type {data_dtype}.") -@ constexpr +@constexpr def generate_broadcast_shape(shapes, op_name): """Generate broadcast shape for a tuple of shape.""" if not shapes: @@ -428,13 +432,14 @@ def generate_broadcast_shape(shapes, op_name): for i, shape in enumerate(shapes): logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") try: - broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name) + broadcast_shape = op_utils.get_broadcast_shape( + broadcast_shape, shape, op_name) except ValueError as ex: raise IndexError(ex) return tuple(broadcast_shape) -@ constexpr +@constexpr def check_two_shapes_need_broadcast(shape_x, shape_y): """Check two shapes need broadcast.""" error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape " @@ -451,14 +456,14 @@ def check_two_shapes_need_broadcast(shape_x, shape_y): return True -@ constexpr +@constexpr def compute_multiples(origin_shape, broadcast_shape): """Compute multiples between origin shape with broadcast shape.""" len_gap = len(broadcast_shape) - len(origin_shape) return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape)) -@ constexpr +@constexpr def compute_new_shape(origin_shape, indexes_shapes_info): """Compute new shape between origin shape with final shape.""" new_shape = [] @@ -470,21 +475,22 @@ def compute_new_shape(origin_shape, indexes_shapes_info): return tuple(new_shape) -@ constexpr +@constexpr def check_sequence_index_type(sequence_index, op_name): """check if the item's type of list_index is bool or int""" - if not all([isinstance(index, (int, bool)) for index in sequence_index]): - raise IndexError(f"In the {op_name} operation, only support 'integer' or 'boolean' array(list/tuple), " - f"but got {type(index)} in array") + for index in sequence_index: + if not isinstance(index, int): + raise IndexError(f"In the {op_name} operation, only support 'inter' or 'boolean' array(list/tuple), " + f"but got {type(index)} in array.") -@ constexpr +@constexpr def convert_int_to_slice(tuple_index): tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index) return tuple_index_new -@ constexpr +@constexpr def check_and_transform_int_index(index, shape, op_name): if index < -shape or index >= shape: raise IndexError(f"In the \"{op_name}\", the index should in the range [-{shape}, {shape-1}] to fit " @@ -494,16 +500,20 @@ def check_and_transform_int_index(index, shape, op_name): return index -@ constexpr +@constexpr def transform_sequence_index(sequence_index, shape, op_name): """transform list or tuple with integer and boolean to tuple with integer index""" - bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index))) - int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count + bool_count = len( + list(filter(lambda index: isinstance(index, bool), sequence_index))) + int_count = len( + list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count if int_count == 0: if bool_count == shape: - list_index = list(filter(lambda i: sequence_index[i], range(bool_count))) + list_index = list( + filter(lambda i: sequence_index[i], range(bool_count))) else: - raise IndexError("The boolean array should have the same length with the corresponding dimensiton") + raise IndexError( + "The boolean array should have the same length with the corresponding dimensiton") else: list_index = [int(index) for index in sequence_index] for i, index in enumerate(list_index): @@ -512,7 +522,7 @@ def transform_sequence_index(sequence_index, shape, op_name): return sub_tuple_index -@ constexpr +@constexpr def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name): """Convert a slice to a tensor.""" shape = [] @@ -540,7 +550,7 @@ def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_n return tensor -@ constexpr +@constexpr def check_shapes_same(value_shapes, op_name): """Check if the shapes in the tuple are consistent.""" for i, shape in enumerate(value_shapes): @@ -550,7 +560,7 @@ def check_shapes_same(value_shapes, op_name): return True -@ constexpr +@constexpr def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type): """Convert a scalar to a tensor.""" if op_type == SET_ITEM_BY_ONE_TENSOR: @@ -563,7 +573,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty f" is not consistent with the assigned tensor data type {data_dtype}.") -@ constexpr +@constexpr def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type): """Convert a tuple of scalar to a tensor.""" updates_shape = generate_updates_shape(data_shape, index_shape, op_type) @@ -575,7 +585,7 @@ def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value return Tensor(np.tile(array, reps)) -@ constexpr +@constexpr def generate_updates_shape(data_shape, index_shape, op_type): """Generate updates shape for 'tensor setitem'.""" if op_type == SET_ITEM_BY_ONE_TENSOR: @@ -585,7 +595,7 @@ def generate_updates_shape(data_shape, index_shape, op_type): return updates_shape -@ constexpr +@constexpr def check_tuple_index_len(data_rank, tuple_index_len, op_name): """Check if the number of index tensor exceeds the dimension of the operated tensor.""" if tuple_index_len <= data_rank: @@ -594,7 +604,7 @@ def check_tuple_index_len(data_rank, tuple_index_len, op_name): f"is greater than the dimension {data_rank} of the operated tensor.") -@ constexpr +@constexpr def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name): """ @@ -694,14 +704,14 @@ def scalar_in_sequence(x, y): return False -@ constexpr +@constexpr def get_np_eps(input_dtype): nptype = mstype.dtype_to_nptype(input_dtype) eps = np.finfo(nptype).eps return float(eps) -@ constexpr +@constexpr def check_number_index_type(number): """Check if it is int or bool number""" if isinstance(number, bool): @@ -712,7 +722,7 @@ def check_number_index_type(number): .format(number, type(number))) -@ constexpr +@constexpr def get_stride_info_from_slice(data_shape, slice_index): """Get stride info from a python slice""" begin, end, step = get_slice_stride(data_shape[0], slice_index) @@ -726,7 +736,7 @@ def get_stride_info_from_slice(data_shape, slice_index): return tuple(begin_strides), tuple(end_strides), tuple(step_strides) -@ constexpr +@constexpr def get_stride_info_from_integer(data_shape, number): """Get stride info from a integer""" begin_strides = [number] @@ -752,7 +762,7 @@ def get_slice_stride(dim_size, index_slice): return start, stop, step -@ constexpr +@constexpr def get_stride_info_from_tuple(data_shape, tuple_index): """Get stride info from a tuple""" begin_strides, end_strides, step_strides = [], [], [] @@ -792,14 +802,14 @@ def get_stride_info_from_tuple(data_shape, tuple_index): return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis -@ constexpr +@constexpr def mstype_eq(x, y): if x == y: return True return False -@ constexpr +@constexpr def scalar_to_tensor(x): """Convert a scalar to a tensor""" return Tensor(x)