From fde373a3b7eadcc647534b83436dce8c02366a71 Mon Sep 17 00:00:00 2001 From: yepei6 Date: Fri, 15 Jan 2021 16:05:38 +0800 Subject: [PATCH] debug and add new supported scene --- .../composite/multitype_ops/_compile_utils.py | 69 +++++++------------ .../multitype_ops/_constexpr_utils.py | 35 +++++----- 2 files changed, 41 insertions(+), 63 deletions(-) diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index c2262af9a1a..9a82408a633 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -129,43 +129,24 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name): return tuple_index_new -def _expand_data_dims_with_none(data, tuple_index, op_name): - """expand the data's dim with 'None' in tuple_index""" +def _expand_data_dims(data, tuple_index, op_name): + """expand the data's dim with 'None' and 'Boolean' in tuple_index""" indexes_types = hyper_map(F.typeof, tuple_index) - none_positions, tuple_index_without_none = (), () + expand_positions, tuple_index_new = (), () for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)): - none_type_tag = const_utils.judge_index_type(index_type, mstype.type_none) - tuple_index_without_none += (const_utils.make_empty_slice(),) if none_type_tag else(index,) - none_positions += (i,) if none_type_tag else () - for dim in none_positions: - data = F.expand_dims(data, dim) - return data, tuple_index_without_none - - -def _expand_data_dims_with_bool(data, tuple_index, op_name): - """expand the data's dim with 'True/False' in tuple_index""" - indexes_types = hyper_map(F.typeof, tuple_index) - bool_positions, tuple_index_without_bool = (), () - - for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)): - bool_type_tag = const_utils.judge_index_type(index_type, mstype.bool_) - if bool_type_tag: - if index: - tuple_index_without_bool += (const_utils.make_tensor([0], mstype.int64),) - else: - # todo wait to complete the operations' support for zero dim-size, then could make 0 length tensor. - # to replace the 'False' - - return const_utils.raise_index_error("When tensor is indexed by a tuple which contains bool object, " - "the value only support 'True'.") + if const_utils.judge_index_type(index_type, mstype.type_none): + tuple_index_new += (const_utils.make_empty_slice(),) + expand_positions += (i,) + elif const_utils.judge_index_type(index_type, mstype.bool_): + tuple_index_new += (const_utils.make_tensor([0] if index else[], mstype.int64),) + expand_positions += (i,) else: - tuple_index_without_bool += (index,) - bool_positions += (i,) if bool_type_tag else () + tuple_index_new += (index,) - for dim in bool_positions: + for dim in expand_positions: data = F.expand_dims(data, dim) - return data, tuple_index_without_bool + return data, tuple_index_new def tensor_index_by_slice(data, slice_index): @@ -219,22 +200,22 @@ def tensor_index_by_tensor(data, tensor_index): def tensor_index_by_list(data, list_index): """Tensor getitem by list of int and bool""" data_shape = F.shape(data) - const_utils.check_sequence_index_type(list_index, const_utils.TENSOR_GETITEM) - sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM) - tensor_index = F.tuple_to_array(sub_tuple_index) - tensor_index = F.cast(tensor_index, mstype.int64) - return F.gather(data, tensor_index, 0) + indexes_types = hyper_map(F.typeof, list_index) + if const_utils.judge_indexes_types(indexes_types, mstype.int_type + (mstype.bool_,)): + sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM) + tensor_index = const_utils.make_tensor(sub_tuple_index, mstype.int64) + return F.gather(data, tensor_index, 0) + tuple_index_new = () + for index in list_index: + tuple_index_new += (index,) + return tensor_index_by_tuple(data, tuple_index_new) def tensor_index_by_tuple(data, tuple_index): """Tensor getitem by tuple of various types with None""" op_name = const_utils.TENSOR_GETITEM - if len(tuple_index) == 1: - return data[tuple_index[0]] - tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) - data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name) - data, tuple_index = _expand_data_dims_with_bool(data, tuple_index, op_name) + data, tuple_index = _expand_data_dims(data, tuple_index, op_name) indexes_types = hyper_map(F.typeof, tuple_index) contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name) @@ -502,7 +483,7 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value): return data op_name = const_utils.TENSOR_GETITEM tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) - data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name) + data, tuple_index = _expand_data_dims(data, tuple_index, op_name) indexes_types = hyper_map(F.typeof, tuple_index) contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) @@ -564,7 +545,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): return data op_name = const_utils.TENSOR_GETITEM tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) - data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name) + data, tuple_index = _expand_data_dims(data, tuple_index, op_name) indexes_types = hyper_map(F.typeof, tuple_index) contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) @@ -592,7 +573,7 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): return data op_name = const_utils.TENSOR_GETITEM tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) - data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name) + data, tuple_index = _expand_data_dims(data, tuple_index, op_name) indexes_types = hyper_map(F.typeof, tuple_index) contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index f6b6317c6ea..7b0e328ded2 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -74,7 +74,7 @@ def make_empty_slice(): @constexpr -def make_tensor(data, data_type, data_shape=None): +def make_tensor(data, data_type=mstype.int64, data_shape=None): if data_shape: return Tensor(np.zeros(data_shape), data_type) return Tensor(data, data_type) @@ -158,6 +158,15 @@ def check_index_type_valid(dtype, target_type, op_name): f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.") +@constexpr +def judge_indexes_types(dtypes, target_type): + """Check a tuple of tensor data type.""" + for dtype in dtypes: + if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): + return False + return True + + @constexpr def check_indexes_types_valid(dtypes, target_type, op_name): """Check a tuple of tensor data type.""" @@ -475,15 +484,6 @@ def compute_new_shape(origin_shape, indexes_shapes_info): return tuple(new_shape) -@constexpr -def check_sequence_index_type(sequence_index, op_name): - """check if the item's type of list_index is bool or int""" - for index in sequence_index: - if not isinstance(index, int): - raise IndexError(f"In the {op_name} operation, only support 'inter' or 'boolean' array(list/tuple), " - f"but got {type(index)} in array.") - - @constexpr def convert_int_to_slice(tuple_index): tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index) @@ -503,19 +503,16 @@ def check_and_transform_int_index(index, shape, op_name): @constexpr def transform_sequence_index(sequence_index, shape, op_name): """transform list or tuple with integer and boolean to tuple with integer index""" - bool_count = len( - list(filter(lambda index: isinstance(index, bool), sequence_index))) - int_count = len( - list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count - if int_count == 0: + bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index))) + int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count + if int_count == 0 and bool_count != 0: if bool_count == shape: - list_index = list( - filter(lambda i: sequence_index[i], range(bool_count))) + list_index = list(filter(lambda i: sequence_index[i], range(bool_count))) else: - raise IndexError( - "The boolean array should have the same length with the corresponding dimensiton") + raise IndexError("The boolean array should have the same length with the corresponding dimensiton") else: list_index = [int(index) for index in sequence_index] + for i, index in enumerate(list_index): list_index[i] = check_and_transform_int_index(index, shape, op_name) sub_tuple_index = tuple(list_index)