forked from mindspore-Ecosystem/mindspore
!1407 support mixed tensor index for tensor get item and set item and support in operator.
Merge pull request !1407 from zhangbuxue/support_mixed_tensor_for_tensor_get_item_and_tensor_set_item
This commit is contained in:
commit
ad279e90fd
|
@ -105,7 +105,7 @@ convert_object_map = {
|
|||
T.ge: multitype_ops.greater_equal,
|
||||
T.is_: F.is_,
|
||||
T.is_not: F.is_not,
|
||||
T.contains: F.in_dict,
|
||||
T.contains: multitype_ops.in_,
|
||||
T.not_contains: F.not_in_dict,
|
||||
|
||||
# system function
|
||||
|
|
|
@ -474,6 +474,8 @@ REGISTER_PYBIND_DEFINE(
|
|||
(void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
|
||||
(void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
|
||||
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
|
||||
(void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
|
||||
(void)py::class_<Ellipsis, Type, std::shared_ptr<Ellipsis>>(m_sub, "Ellipsis").def(py::init());
|
||||
}));
|
||||
|
||||
const TypePtr kTypeExternal = std::make_shared<External>();
|
||||
|
|
|
@ -95,6 +95,8 @@ string = typing.String()
|
|||
type_refkey = typing.RefKeyType()
|
||||
tensor_type = typing.TensorType
|
||||
anything_type = typing.TypeAnything
|
||||
slice_type = typing.Slice
|
||||
ellipsis_type = typing.Ellipsis
|
||||
|
||||
number_type = (int8,
|
||||
int16,
|
||||
|
|
|
@ -37,6 +37,7 @@ from .logical_and_impl import logical_and
|
|||
from .logical_or_impl import logical_or
|
||||
from .logic_not_impl import logical_not
|
||||
from .uadd_impl import uadd
|
||||
from .in_impl import in_
|
||||
__all__ = [
|
||||
'add',
|
||||
'sub',
|
||||
|
@ -59,5 +60,6 @@ __all__ = [
|
|||
'setitem',
|
||||
'logical_and',
|
||||
'logical_or',
|
||||
'logical_not'
|
||||
'logical_not',
|
||||
'in_'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""constexpr util"""
|
||||
from . import _constexpr_utils as const_utils
|
||||
from ... import functional as F
|
||||
from ... import operations as P
|
||||
from ...composite import base
|
||||
from ....common import dtype as mstype
|
||||
|
||||
hyper_map = base.HyperMap()
|
||||
pack = P.Pack(axis=-1)
|
||||
|
||||
|
||||
def broadcast(broadcast_shape, x):
|
||||
"""Broadcast tensor to the required shape."""
|
||||
if F.shape(x) == broadcast_shape:
|
||||
return x
|
||||
multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape)
|
||||
if multiples:
|
||||
return F.tile(x, multiples)
|
||||
return x
|
||||
|
||||
|
||||
def transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
|
||||
"""Transform indexing tensor to the required."""
|
||||
x = broadcast(broadcast_shape, x)
|
||||
return broadcast(final_shape, F.reshape(x, new_shape))
|
||||
|
||||
|
||||
def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
|
||||
"""Generate an indices tensor from a tuple of tensor."""
|
||||
indices = None
|
||||
check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
|
||||
if check_index_tensor_number:
|
||||
dtype_tuple = hyper_map(F.dtype, tuple_index)
|
||||
check_dtypes = const_utils.check_index_tensors_dtype(dtype_tuple, op_name)
|
||||
if check_dtypes:
|
||||
shape_tuple = hyper_map(F.shape, tuple_index)
|
||||
broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name)
|
||||
broadcast_tensors = hyper_map(F.partial(broadcast, broadcast_shape), tuple_index)
|
||||
indices = pack(broadcast_tensors)
|
||||
return indices
|
||||
|
||||
|
||||
def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
||||
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
int_positions = const_utils.get_pos_of_int_index(indexes_types)
|
||||
for i in int_positions:
|
||||
tuple_index = F.tuple_setitem(tuple_index, i, F.scalar_to_tensor(tuple_index[i], mstype.int32))
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
tensor_positions, slice_positions, ellipsis_position = \
|
||||
const_utils.separate_mixed_tensors_index(indexes_types, op_name)
|
||||
tensor_indexes = []
|
||||
slice_indexes = []
|
||||
for i in tensor_positions:
|
||||
tensor_indexes.append(tuple_index[i])
|
||||
for j in slice_positions:
|
||||
slice_indexes.append(tuple_index[j])
|
||||
data_shape = F.shape(data)
|
||||
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
|
||||
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
|
||||
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
|
||||
const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape,
|
||||
indexes_types,
|
||||
tensor_indexes_shapes,
|
||||
tensor_indexes_dtypes,
|
||||
slice_indexes,
|
||||
op_name)
|
||||
|
||||
slice_number = 0
|
||||
final_index_tensors = []
|
||||
tuple_index_size = len(tuple_index)
|
||||
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
|
||||
for i in range(tuple_index_size):
|
||||
if i in tensor_positions:
|
||||
transform_tensor = transform_indexing_tensor(broadcast_shape,
|
||||
final_shape,
|
||||
index_tensor_new_shape,
|
||||
tuple_index[i])
|
||||
final_index_tensors.append(transform_tensor)
|
||||
if i in slice_positions:
|
||||
slice_tensor = const_utils.convert_slice_to_tensor(slice_number,
|
||||
final_shape,
|
||||
indexes_shapes_info,
|
||||
op_name)
|
||||
final_index_tensors.append(slice_tensor)
|
||||
slice_number += 1
|
||||
if i == ellipsis_position:
|
||||
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number,
|
||||
ellipsis_occupied_dims,
|
||||
final_shape,
|
||||
indexes_shapes_info,
|
||||
op_name)
|
||||
for ele in ellipsis_tensors:
|
||||
final_index_tensors.append(ele)
|
||||
slice_number += ellipsis_occupied_dims
|
||||
indices = pack(final_index_tensors)
|
||||
return indices
|
||||
|
||||
|
||||
def generate_updates_from_scalar(data, indices, value, op_type):
|
||||
"""Generate an updates tensor from a scalar."""
|
||||
data_shape = F.shape(data)
|
||||
indices_shape = F.shape(indices)
|
||||
data_dtype = F.dtype(data)
|
||||
return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)
|
||||
|
||||
|
||||
def generate_updates_from_tuple(data, index, value, op_type):
|
||||
"""Generate an updates tensor from a tuple."""
|
||||
value_types = hyper_map(F.typeof, value)
|
||||
data_dtype = F.dtype(data)
|
||||
value_elements_type = const_utils.check_value_elements(data_dtype, value_types)
|
||||
if value_elements_type == const_utils.ALL_TENSOR:
|
||||
value_shapes = hyper_map(F.shape, value)
|
||||
shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM)
|
||||
if shapes_same:
|
||||
value = F.pack(value)
|
||||
return generate_updates_from_tensor(data, index, value, op_type)
|
||||
|
||||
data_shape = F.shape(data)
|
||||
index_shape = F.shape(index)
|
||||
return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
|
||||
|
||||
|
||||
def generate_updates_from_tensor(data, index, value, op_type):
|
||||
"""Generate an updates tensor from a tensor."""
|
||||
data_shape = F.shape(data)
|
||||
index_shape = F.shape(index)
|
||||
value_shape = F.shape(value)
|
||||
data_dtype = F.dtype(data)
|
||||
value_dtype = F.dtype(value)
|
||||
updates_shape = value_shape
|
||||
check_dtype_same = const_utils.check_tensors_dtype_same(data_dtype, value_dtype, const_utils.TENSOR_SETITEM)
|
||||
if check_dtype_same:
|
||||
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type)
|
||||
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape)
|
||||
if need_broadcast:
|
||||
return broadcast(updates_shape, value)
|
||||
return value
|
|
@ -15,19 +15,15 @@
|
|||
|
||||
"""constexpr util"""
|
||||
from functools import reduce
|
||||
import numpy as np
|
||||
from ...primitive import constexpr
|
||||
from ....common.tensor import Tensor
|
||||
from ....common import dtype as mstype
|
||||
from ...._extends.utils import Slice, Ellipsis_
|
||||
from ....ops import _utils as op_utils
|
||||
from ...composite import base
|
||||
from .... import log as logger
|
||||
from ... import functional as F
|
||||
from ... import operations as P
|
||||
|
||||
hyper_map = base.HyperMap()
|
||||
pack = P.Pack(axis=-1)
|
||||
import numpy as np
|
||||
|
||||
from ...primitive import constexpr
|
||||
from .... import log as logger
|
||||
from ...._extends.utils import Slice, Ellipsis_
|
||||
from ....common import dtype as mstype
|
||||
from ....common.tensor import Tensor
|
||||
from ....ops import _utils as op_utils
|
||||
|
||||
ALL_TENSOR = 0
|
||||
NO_TENSOR = 1
|
||||
|
@ -264,7 +260,7 @@ def tuple_index_elements_type(types, op_name):
|
|||
return ALL_TENSOR
|
||||
if tensors_number == 0:
|
||||
return NO_TENSOR
|
||||
raise IndexError(f"For '{op_name}', the index does not support mixed tensor.")
|
||||
return CONTAIN_TENSOR
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -279,12 +275,12 @@ def check_value_elements(data_dtype, types):
|
|||
tensors_number += 1
|
||||
else:
|
||||
raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' "
|
||||
f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.")
|
||||
f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.")
|
||||
elif mstype.issubclass_(ele, data_dtype):
|
||||
scalars_number += 1
|
||||
else:
|
||||
raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in "
|
||||
f"value tuple is not consistent with origin tensor data type '{data_dtype}'.")
|
||||
f"value tuple is not consistent with assigned tensor data type '{data_dtype}'.")
|
||||
if tensors_number == len(types):
|
||||
return ALL_TENSOR
|
||||
if scalars_number == len(types):
|
||||
|
@ -299,51 +295,46 @@ def get_index_tensor_dtype(dtype):
|
|||
return INT_
|
||||
if dtype == mstype.bool_:
|
||||
return BOOL_
|
||||
raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
|
||||
raise IndexError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_index_tensors_dtype(dtypes, op_name):
|
||||
"""Check a tuple of tensor data type."""
|
||||
if op_name == TENSOR_GETITEM:
|
||||
valid_dtypes = (mstype.int32, mstype.int64)
|
||||
elif op_name == TENSOR_SETITEM:
|
||||
valid_dtypes = (mstype.int32,)
|
||||
else:
|
||||
raise ValueError("Unsupported operation.")
|
||||
for ele in dtypes:
|
||||
if ele in valid_dtypes and ele == dtypes[0]:
|
||||
continue
|
||||
raise TypeError(f"For '{op_name}', the index tensors data type must be same, "
|
||||
f"and should be one of the following: {valid_dtypes}, but got {dtypes}.")
|
||||
if not ele == mstype.int32:
|
||||
raise IndexError(f"For '{op_name}', the all index tensor "
|
||||
f"data types should be mstype.int32, but got {dtypes}.")
|
||||
return True
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_tensor_dtype_valid(dtype, valid_dtypes):
|
||||
def check_index_tensor_dtype(dtype, op_name):
|
||||
"""Check a tensor data type."""
|
||||
if dtype in valid_dtypes:
|
||||
if dtype == mstype.int32:
|
||||
return True
|
||||
raise TypeError(f"The index tensor data type must be one of "
|
||||
f"the following: {valid_dtypes}, but got {dtype}.")
|
||||
raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_tensors_dtype_same(x_dtype, y_dtype, op_name):
|
||||
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
|
||||
"""Check tensors data type same."""
|
||||
if x_dtype == y_dtype:
|
||||
if value_dtype == data_dtype:
|
||||
return True
|
||||
raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' "
|
||||
f"is not consistent with origin tensor data type {x_dtype}.")
|
||||
raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' "
|
||||
f"is not consistent with assigned tensor data type {data_dtype}.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def broadcast_shapes(shapes, op_name):
|
||||
"""Broadcasts a tuple of tensor."""
|
||||
def generate_broadcast_shape(shapes, op_name):
|
||||
"""Generate broadcast shape for a tuple of shape."""
|
||||
broadcast_shape = shapes[0]
|
||||
for i, shape in enumerate(shapes):
|
||||
logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
|
||||
try:
|
||||
broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
|
||||
except ValueError as ex:
|
||||
raise IndexError(ex)
|
||||
return tuple(broadcast_shape)
|
||||
|
||||
|
||||
|
@ -366,14 +357,82 @@ def check_two_shapes_need_broadcast(shape_x, shape_y):
|
|||
|
||||
@constexpr
|
||||
def compute_multiples(origin_shape, broadcast_shape):
|
||||
"""Compute multiples between broadcast_shape with origin_shape."""
|
||||
"""Compute multiples between origin shape with broadcast shape."""
|
||||
len_gap = len(broadcast_shape) - len(origin_shape)
|
||||
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))
|
||||
|
||||
|
||||
def tile(broadcast_shape, x):
|
||||
multiples = compute_multiples(F.shape(x), broadcast_shape)
|
||||
return F.tile(x, multiples)
|
||||
@constexpr
|
||||
def compute_new_shape(origin_shape, indexes_shapes_info):
|
||||
"""Compute new shape between origin shape with final shape."""
|
||||
new_shape = []
|
||||
for i in indexes_shapes_info:
|
||||
if i == origin_shape:
|
||||
new_shape.extend(origin_shape)
|
||||
else:
|
||||
new_shape.append(1)
|
||||
return tuple(new_shape)
|
||||
|
||||
|
||||
@constexpr
|
||||
def convert_ellipsis_to_tensors(slice_number,
|
||||
ellipsis_occupied_dims,
|
||||
final_shape,
|
||||
indexes_shapes_info,
|
||||
op_name):
|
||||
"""Convert an ellipsis to a list of tensor."""
|
||||
tensor_list = []
|
||||
dims_dealt_count = 0
|
||||
while dims_dealt_count < ellipsis_occupied_dims:
|
||||
shape = []
|
||||
slice_count = 0
|
||||
array = None
|
||||
for ele in indexes_shapes_info:
|
||||
if isinstance(ele, list):
|
||||
if slice_count == slice_number:
|
||||
array = np.array(ele, np.int32)
|
||||
shape.append(len(ele))
|
||||
else:
|
||||
shape.append(1)
|
||||
slice_count += 1
|
||||
if isinstance(ele, tuple):
|
||||
shape.extend([1] * len(ele))
|
||||
if array is None:
|
||||
raise ValueError(f"For '{op_name}', generate tensors from ellipsis failed.")
|
||||
array = np.reshape(array, shape)
|
||||
reps = compute_multiples(shape, final_shape)
|
||||
tensor = Tensor(np.tile(array, reps))
|
||||
tensor_list.append(tensor)
|
||||
slice_number += 1
|
||||
dims_dealt_count += 1
|
||||
return tensor_list
|
||||
|
||||
|
||||
@constexpr
|
||||
def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name):
|
||||
"""Convert a slice to a tensor."""
|
||||
shape = []
|
||||
count = 0
|
||||
array = None
|
||||
for ele in indexes_shapes_info:
|
||||
if isinstance(ele, list):
|
||||
if count == slice_number:
|
||||
array = np.array(ele, np.int32)
|
||||
shape.append(len(ele))
|
||||
else:
|
||||
# When the slice is not the slice looking for, the shape is filled with 1.
|
||||
shape.append(1)
|
||||
count += 1
|
||||
elif isinstance(ele, tuple):
|
||||
shape.extend([1] * len(ele))
|
||||
else:
|
||||
shape.append(1)
|
||||
if array is None:
|
||||
raise ValueError(f"For '{op_name}', generate tensor from 'slice' failed.")
|
||||
array = np.reshape(array, shape)
|
||||
reps = compute_multiples(shape, final_shape)
|
||||
tensor = Tensor(np.tile(array, reps))
|
||||
return tensor
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -381,8 +440,8 @@ def check_shapes_same(value_shapes, op_name):
|
|||
"""Check if the shapes in the tuple are consistent."""
|
||||
for i, shape in enumerate(value_shapes):
|
||||
if shape != value_shapes[0]:
|
||||
raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple "
|
||||
f"is not same as the first tensor shape.")
|
||||
raise ValueError(f"For '{op_name}', the {i}th tensor shape in "
|
||||
f"value tuple is not same as the first tensor shape.")
|
||||
return True
|
||||
|
||||
|
||||
|
@ -396,7 +455,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
|
|||
if isinstance(value, mstype.dtype_to_pytype(data_dtype)):
|
||||
return Tensor(np.full(updates_shape, value), dtype=data_dtype)
|
||||
raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'"
|
||||
f" is not consistent with tensor data type {data_dtype}.")
|
||||
f" is not consistent with the assigned tensor data type {data_dtype}.")
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -404,8 +463,8 @@ def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value
|
|||
"""Convert a tuple of scalar to a tensor."""
|
||||
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
|
||||
if len(value) != updates_shape[-1]:
|
||||
raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} in the updates tuple "
|
||||
f"does not meet the requirements: {updates_shape[-1]}.")
|
||||
raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} "
|
||||
f"in the updates tuple does not meet the requirements: {updates_shape[-1]}.")
|
||||
array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype))
|
||||
reps = compute_multiples(updates_shape[-1:], updates_shape)
|
||||
return Tensor(np.tile(array, reps))
|
||||
|
@ -430,58 +489,145 @@ def check_number_of_index_tensor(data_shape, tuple_len, op_name):
|
|||
f"is greater than the dimension {len(data_shape)} of the operated tensor.")
|
||||
|
||||
|
||||
def generate_indeices_from_tuple_of_tensor(data, tuple_index, op_name):
|
||||
"""Generate an indices tensor from a tuple of tensor."""
|
||||
indices = None
|
||||
check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
|
||||
if check_index_tensor_number:
|
||||
dtype_tuple = hyper_map(F.dtype, tuple_index)
|
||||
check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name)
|
||||
if check_dtypes:
|
||||
shape_tuple = hyper_map(F.shape, tuple_index)
|
||||
broadcast_shape = broadcast_shapes(shape_tuple, op_name)
|
||||
broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index)
|
||||
indices = pack(broadcast_tensors)
|
||||
return indices
|
||||
@constexpr
|
||||
def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
|
||||
indexes_types,
|
||||
tensor_indexes_shapes,
|
||||
tensor_indexes_dtypes,
|
||||
slice_indexes,
|
||||
op_name):
|
||||
"""
|
||||
Generate index info which contain broadcast shape, final shape,
|
||||
indexes shapes info, ellipsis size from a tuple of mixed tensors.
|
||||
"""
|
||||
check_index_tensors_dtype(tensor_indexes_dtypes, op_name)
|
||||
data_rank = len(data_shape)
|
||||
indexes_size = len(indexes_types)
|
||||
if indexes_size > data_rank:
|
||||
raise IndexError(f"For '{op_name}', the number {indexes_size} of index elements "
|
||||
f"is greater than the dimension {len(data_shape)} of the operated tensor.")
|
||||
indexes_info = {}
|
||||
index_tensors_info = {}
|
||||
ellipsis_num = 0
|
||||
ellipsis_occupied_dims = 0
|
||||
tensor_count = 0
|
||||
slice_count = 0
|
||||
for i, ele_type in enumerate(indexes_types):
|
||||
if ellipsis_num == 0:
|
||||
pos = i
|
||||
else:
|
||||
pos = i + ellipsis_occupied_dims - 1
|
||||
if isinstance(ele_type, mstype.tensor_type):
|
||||
indexes_info[pos] = tensor_indexes_shapes[tensor_count]
|
||||
index_tensors_info[pos] = tensor_indexes_shapes[tensor_count]
|
||||
tensor_count += 1
|
||||
elif isinstance(ele_type, mstype.slice_type):
|
||||
slice_obj = slice(slice_indexes[slice_count].start,
|
||||
slice_indexes[slice_count].end,
|
||||
slice_indexes[slice_count].step)
|
||||
# Use list to represent slicing result.
|
||||
indexes_info[pos] = list(range(data_shape[pos]))[slice_obj]
|
||||
slice_count += 1
|
||||
elif isinstance(ele_type, mstype.ellipsis_type):
|
||||
if ellipsis_num != 0:
|
||||
raise IndexError(f"For '{op_name}', the index could only contain one ellipsis.")
|
||||
ellipsis_occupied_dims = data_rank - indexes_size + 1
|
||||
for j in range(pos, pos + ellipsis_occupied_dims):
|
||||
# Use list to represent slicing result.
|
||||
indexes_info[j] = list(range(data_shape[j]))
|
||||
ellipsis_num += 1
|
||||
else:
|
||||
raise IndexError(f"For '{op_name}', the index elements only support "
|
||||
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.")
|
||||
broadcast_shape, final_shape, indexes_shapes_info = \
|
||||
_derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name)
|
||||
return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims
|
||||
|
||||
|
||||
def generate_updates_from_scalar(data, indices, value, op_type):
|
||||
"""Generate an updates tensor from a scalar."""
|
||||
data_shape = F.shape(data)
|
||||
indices_shape = F.shape(indices)
|
||||
data_dtype = F.dtype(data)
|
||||
return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)
|
||||
def _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key: list):
|
||||
"""Determine whether the tensor in the index appears continuously."""
|
||||
for i in range(len(index_tensor_info_key) - 1):
|
||||
if index_tensor_info_key[i + 1] != index_tensor_info_key[i] + 1:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def generate_updates_from_tuple(data, index, value, op_type):
|
||||
"""Generate an updates tensor from a tuple."""
|
||||
value_types = hyper_map(F.typeof, value)
|
||||
data_dtype = F.dtype(data)
|
||||
value_elements_type = check_value_elements(data_dtype, value_types)
|
||||
if value_elements_type == ALL_TENSOR:
|
||||
value_shapes = hyper_map(F.shape, value)
|
||||
shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM)
|
||||
if shapes_same:
|
||||
value = F.pack(value)
|
||||
return generate_updates_from_tensor(data, index, value, op_type)
|
||||
|
||||
data_shape = F.shape(data)
|
||||
index_shape = F.shape(index)
|
||||
return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
|
||||
def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name):
|
||||
"""Derive the resulting shape information from the a tuple index of mixed tensors."""
|
||||
index_tensor_info_key = list(index_tensors_info.keys())
|
||||
index_tensor_info_value = list(index_tensors_info.values())
|
||||
broadcast_shape = generate_broadcast_shape(index_tensor_info_value, op_name)
|
||||
final_shape = []
|
||||
indexes_shapes_info = []
|
||||
mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key)
|
||||
if mixed_tensors_continuous:
|
||||
tensor_shape_dealt = False
|
||||
for ele in indexes_info.values():
|
||||
if isinstance(ele, list):
|
||||
final_shape.append(len(ele))
|
||||
indexes_shapes_info.append(ele)
|
||||
elif isinstance(ele, tuple):
|
||||
if not tensor_shape_dealt:
|
||||
final_shape.extend(broadcast_shape)
|
||||
indexes_shapes_info.append(broadcast_shape)
|
||||
tensor_shape_dealt = True
|
||||
else:
|
||||
raise IndexError(f"For '{op_name}', the index elements only support "
|
||||
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.")
|
||||
else:
|
||||
final_shape.extend(broadcast_shape)
|
||||
indexes_shapes_info.append(broadcast_shape)
|
||||
for ele in indexes_info.values():
|
||||
if isinstance(ele, list):
|
||||
final_shape.append(len(ele))
|
||||
indexes_shapes_info.append(ele)
|
||||
elif isinstance(ele, tuple):
|
||||
continue
|
||||
else:
|
||||
raise IndexError(f"For '{op_name}', the index elements only support "
|
||||
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.")
|
||||
return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)
|
||||
|
||||
|
||||
def generate_updates_from_tensor(data, index, value, op_type):
|
||||
"""Generate an updates tensor from a tensor."""
|
||||
data_shape = F.shape(data)
|
||||
index_shape = F.shape(index)
|
||||
value_shape = F.shape(value)
|
||||
data_dtype = F.dtype(data)
|
||||
value_dtype = F.dtype(value)
|
||||
updates_shape = value_shape
|
||||
check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM)
|
||||
if check_dtype_same:
|
||||
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
|
||||
need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape)
|
||||
if need_broadcast:
|
||||
return tile(updates_shape, value)
|
||||
return value
|
||||
@constexpr
|
||||
def get_pos_of_int_index(indexes_types):
|
||||
"""Get int index positions from the mixed tensors index which contains int, tensor, slice, and ellipsis."""
|
||||
int_positions = []
|
||||
for i, ele_type in enumerate(indexes_types):
|
||||
if ele_type == mstype.int32:
|
||||
int_positions.append(i)
|
||||
return int_positions
|
||||
|
||||
|
||||
@constexpr
|
||||
def separate_mixed_tensors_index(indexes_types, op_name):
|
||||
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
|
||||
tensor_positions = []
|
||||
slice_positions = []
|
||||
ellipsis_position = None
|
||||
for i, ele_type in enumerate(indexes_types):
|
||||
if isinstance(ele_type, mstype.tensor_type):
|
||||
tensor_positions.append(i)
|
||||
elif isinstance(ele_type, mstype.slice_type):
|
||||
slice_positions.append(i)
|
||||
elif isinstance(ele_type, mstype.ellipsis_type):
|
||||
ellipsis_position = i
|
||||
else:
|
||||
raise IndexError(f"For '{op_name}', the index elements only support "
|
||||
f"'Tensor', 'int32', 'Slice', 'Ellipsis', but got {ele_type}.")
|
||||
|
||||
return tensor_positions, slice_positions, ellipsis_position
|
||||
|
||||
|
||||
@constexpr
|
||||
def scalar_in_sequence(x, y):
|
||||
"""Determine whether the scalar in the sequence."""
|
||||
if x is None:
|
||||
raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, "
|
||||
"but the scalar is not.")
|
||||
if y is None:
|
||||
raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, "
|
||||
"but the sequence is not.")
|
||||
if x in y:
|
||||
return True
|
||||
return False
|
|
@ -14,11 +14,11 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Implementation for getitem."""
|
||||
|
||||
from . import _utils as multi_utils
|
||||
from . import _compile_utils as compile_utils
|
||||
from . import _constexpr_utils as const_utils
|
||||
from .. import base
|
||||
from ... import functional as F
|
||||
from ....common import dtype as mstype
|
||||
|
||||
|
||||
getitem = base.MultitypeFuncGraph('getitem')
|
||||
"""
|
||||
|
@ -227,7 +227,8 @@ def _tensor_getitem_by_tensor(data, tensor_index):
|
|||
Outputs:
|
||||
Tensor, element type is same as the element type of data.
|
||||
"""
|
||||
check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64))
|
||||
check_dtypes = const_utils.check_index_tensor_dtype(F.dtype(tensor_index),
|
||||
const_utils.TENSOR_GETITEM)
|
||||
result = None
|
||||
if check_dtypes:
|
||||
result = F.gather(data, tensor_index, 0)
|
||||
|
@ -246,14 +247,13 @@ def _tensor_getitem_by_tuple(data, tuple_index):
|
|||
Outputs:
|
||||
Tensor, element type is same as the element type of data.
|
||||
"""
|
||||
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_GETITEM)
|
||||
result = None
|
||||
if index_elements_type == multi_utils.NO_TENSOR:
|
||||
result = _tensor_slice(data, tuple_index)
|
||||
if index_elements_type == multi_utils.ALL_TENSOR:
|
||||
result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
|
||||
return result
|
||||
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
|
||||
if index_elements_type == const_utils.NO_TENSOR:
|
||||
return _tensor_slice(data, tuple_index)
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
|
||||
return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)
|
||||
|
||||
|
||||
@getitem.register("Tensor", "Ellipsis")
|
||||
|
@ -273,6 +273,17 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
|
|||
|
||||
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
|
||||
"""Tensor getitem by a tuple of tensor."""
|
||||
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM)
|
||||
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_GETITEM)
|
||||
result = F.gather_nd(data, indices)
|
||||
return result
|
||||
|
||||
|
||||
def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
|
||||
"""Tensor getitem by a tuple of mixed tensor."""
|
||||
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_GETITEM)
|
||||
result = F.gather_nd(data, indices)
|
||||
return result
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""in_impl"""
|
||||
|
||||
from . import _constexpr_utils as const_utils
|
||||
from ... import functional as F
|
||||
from ...composite import base
|
||||
|
||||
in_ = base.MultitypeFuncGraph("in")
|
||||
"""
|
||||
in_ is a metafuncgraph object which will determine if a in b
|
||||
using ".register" decorator
|
||||
"""
|
||||
|
||||
|
||||
@in_.register("Number", "Tuple")
|
||||
def _number_in_tuple(x, y):
|
||||
"""
|
||||
Determine if a number in tuple.
|
||||
|
||||
Args:
|
||||
x (Number): x
|
||||
y (tuple): y
|
||||
|
||||
Returns:
|
||||
bool, if x in y return true, x not in y return false.
|
||||
"""
|
||||
return const_utils.scalar_in_sequence(x, y)
|
||||
|
||||
|
||||
@in_.register("Number", "List")
|
||||
def _number_in_list(x, y):
|
||||
"""
|
||||
Determine if a number in list.
|
||||
|
||||
Args:
|
||||
x (Number): x
|
||||
y (list): y
|
||||
|
||||
Returns:
|
||||
bool, if x in y return true, x not in y return false.
|
||||
"""
|
||||
return const_utils.scalar_in_sequence(x, y)
|
||||
|
||||
|
||||
@in_.register("String", "Tuple")
|
||||
def _string_in_tuple(x, y):
|
||||
"""
|
||||
Determine if a str in a tuple.
|
||||
|
||||
Args:
|
||||
x (str): x
|
||||
y (tuple): y
|
||||
|
||||
Returns:
|
||||
bool, if x in y return true, x not in y return false.
|
||||
"""
|
||||
return const_utils.scalar_in_sequence(x, y)
|
||||
|
||||
|
||||
@in_.register("String", "List")
|
||||
def _string_in_list(x, y):
|
||||
"""
|
||||
Determine if a str in a list.
|
||||
|
||||
Args:
|
||||
x (str): x
|
||||
y (list): y
|
||||
|
||||
Returns:
|
||||
bool, if x in y return true, x not in y return false.
|
||||
"""
|
||||
return const_utils.scalar_in_sequence(x, y)
|
||||
|
||||
|
||||
@in_.register("String", "Dictionary")
|
||||
def _str_in_dict(x, y):
|
||||
"""
|
||||
Determine if a str in dict.
|
||||
|
||||
Args:
|
||||
x: str
|
||||
y: dict
|
||||
|
||||
Returns:
|
||||
bool, if x in y return true, x not in y return false.
|
||||
"""
|
||||
return F.in_dict(x, y)
|
|
@ -15,10 +15,11 @@
|
|||
|
||||
"""Implementation for setitem."""
|
||||
|
||||
from . import _compile_utils as compile_utils
|
||||
from . import _constexpr_utils as const_utils
|
||||
from ... import functional as F
|
||||
from ...composite import base
|
||||
from ....common import dtype as mstype
|
||||
from ... import functional as F
|
||||
from . import _utils as multi_utils
|
||||
|
||||
setitem = base.MultitypeFuncGraph('setitem')
|
||||
|
||||
|
@ -139,8 +140,8 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
|
|||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
index_dtype = F.dtype(index)
|
||||
tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
|
||||
if tensor_dtype == multi_utils.INT_:
|
||||
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
|
||||
if tensor_dtype == const_utils.INT_:
|
||||
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
|
||||
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
|
||||
|
||||
|
@ -166,8 +167,8 @@ def _tensor_setitem_by_tensor_with_number(data, index, value):
|
|||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
index_dtype = F.dtype(index)
|
||||
tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
|
||||
if tensor_dtype == multi_utils.BOOL_:
|
||||
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
|
||||
if tensor_dtype == const_utils.BOOL_:
|
||||
return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
|
||||
return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)
|
||||
|
||||
|
@ -190,17 +191,24 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
|
||||
result = None
|
||||
if index_elements_type == multi_utils.NO_TENSOR:
|
||||
result = _tensor_assgin_number(data, tuple_index, value)
|
||||
if index_elements_type == multi_utils.ALL_TENSOR:
|
||||
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
|
||||
updates = multi_utils.generate_updates_from_scalar(data, indices, value,
|
||||
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
result = F.scatter_nd_update(data, indices, updates)
|
||||
return result
|
||||
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if index_elements_type == const_utils.NO_TENSOR:
|
||||
return _tensor_assgin_number(data, tuple_index, value)
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
else:
|
||||
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
updates = compile_utils.generate_updates_from_scalar(data,
|
||||
indices,
|
||||
value,
|
||||
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
return F.scatter_nd_update(data, indices, updates)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tuple", "Tensor")
|
||||
|
@ -221,17 +229,24 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
|
||||
result = None
|
||||
if index_elements_type == multi_utils.NO_TENSOR:
|
||||
result = _tensor_assgin_tensor(data, tuple_index, value)
|
||||
if index_elements_type == multi_utils.ALL_TENSOR:
|
||||
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
|
||||
updates = multi_utils.generate_updates_from_tensor(data, indices, value,
|
||||
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
result = F.scatter_nd_update(data, indices, updates)
|
||||
return result
|
||||
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if index_elements_type == const_utils.NO_TENSOR:
|
||||
return _tensor_assgin_tensor(data, tuple_index, value)
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
else:
|
||||
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
updates = compile_utils.generate_updates_from_tensor(data,
|
||||
indices,
|
||||
value,
|
||||
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
return F.scatter_nd_update(data, indices, updates)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tuple", "Tuple")
|
||||
|
@ -253,15 +268,22 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
|
||||
result = None
|
||||
if index_elements_type == multi_utils.ALL_TENSOR:
|
||||
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
|
||||
updates = multi_utils.generate_updates_from_tuple(data, indices, value,
|
||||
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
result = F.scatter_nd_update(data, indices, updates)
|
||||
return result
|
||||
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
|
||||
|
||||
if index_elements_type == const_utils.ALL_TENSOR:
|
||||
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
else:
|
||||
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
|
||||
tuple_index,
|
||||
const_utils.TENSOR_SETITEM)
|
||||
updates = compile_utils.generate_updates_from_tuple(data,
|
||||
indices,
|
||||
value,
|
||||
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
|
||||
return F.scatter_nd_update(data, indices, updates)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tensor", "Tuple")
|
||||
|
@ -278,7 +300,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
|
|||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
index_dtype = F.dtype(index)
|
||||
check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64))
|
||||
check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
|
||||
result = None
|
||||
if check_dtype:
|
||||
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
|
||||
|
@ -331,14 +353,14 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value):
|
|||
|
||||
def _tensor_assgin_number(data, input_slice, value):
|
||||
"""Givens a scalar assign to tensor by slice"""
|
||||
check_result = multi_utils.check_tensor_setitem_index(input_slice)
|
||||
check_result = const_utils.check_tensor_setitem_index(input_slice)
|
||||
result = None
|
||||
if check_result:
|
||||
data_shape = F.shape(data)
|
||||
indices = multi_utils.slice2indices(input_slice, data_shape)
|
||||
is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
|
||||
indices = const_utils.slice2indices(input_slice, data_shape)
|
||||
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
|
||||
if is_tuple_int:
|
||||
indices = multi_utils.integer_to_indices(input_slice, data_shape)
|
||||
indices = const_utils.integer_to_indices(input_slice, data_shape)
|
||||
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
|
||||
return result
|
||||
|
||||
|
@ -347,7 +369,7 @@ def _tensor_assgin_number(data, input_slice, value):
|
|||
def _tensor_setitem_with_int_v1(data, index, value):
|
||||
"""Syntax: A[1] = 3"""
|
||||
data_shape = F.shape(data)
|
||||
indices = multi_utils.integer_to_indices(index, data_shape)
|
||||
indices = const_utils.integer_to_indices(index, data_shape)
|
||||
return _tensor_indices_number(data, data_shape, index, indices, value)
|
||||
|
||||
|
||||
|
@ -355,7 +377,7 @@ def _tensor_setitem_with_int_v1(data, index, value):
|
|||
def _tensor_setitem_with_int_v2(data, index, value):
|
||||
"""Syntax: A[1] = Tensor"""
|
||||
data_shape = F.shape(data)
|
||||
indices = multi_utils.integer_to_indices(index, data_shape)
|
||||
indices = const_utils.integer_to_indices(index, data_shape)
|
||||
return _tensor_indices_tensor(data, data_shape, index, indices, value)
|
||||
|
||||
|
||||
|
@ -376,7 +398,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
|
|||
data_size = F.size(data)
|
||||
value_shape = F.shape(value)
|
||||
value_size = F.size(value)
|
||||
check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
|
||||
check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
|
||||
if check_result:
|
||||
if data_size == value_size:
|
||||
result = F.reshape(value, data_shape)
|
||||
|
@ -391,13 +413,13 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
|
|||
def _tensor_assgin_tensor(data, input_slice, value):
|
||||
"""Assigns a tensor value to the tensor by slice."""
|
||||
result = None
|
||||
check_result = multi_utils.check_tensor_setitem_index(input_slice)
|
||||
check_result = const_utils.check_tensor_setitem_index(input_slice)
|
||||
if check_result:
|
||||
data_shape = F.shape(data)
|
||||
indices = multi_utils.slice2indices(input_slice, data_shape)
|
||||
is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
|
||||
indices = const_utils.slice2indices(input_slice, data_shape)
|
||||
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
|
||||
if is_tuple_int:
|
||||
indices = multi_utils.integer_to_indices(input_slice, data_shape)
|
||||
indices = const_utils.integer_to_indices(input_slice, data_shape)
|
||||
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
|
||||
return result
|
||||
|
||||
|
@ -407,7 +429,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
|
|||
data_size = F.size(data)
|
||||
data_dtype = F.dtype(data)
|
||||
indices_size = F.size(indices)
|
||||
indices_size = multi_utils.check_indices(indices_size, index)
|
||||
indices_size = const_utils.check_indices(indices_size, index)
|
||||
update = F.fill(mstype.int32, (indices_size,), 1)
|
||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||
condition = F.reshape(condition_1d, data_shape)
|
||||
|
@ -415,7 +437,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
|
|||
value_fill = None
|
||||
value_size = F.size(value)
|
||||
|
||||
value_size = multi_utils.check_indices_value_size(indices_size, value_size)
|
||||
value_size = const_utils.check_indices_value_size(indices_size, value_size)
|
||||
if value_size == 1:
|
||||
value_fill = F.fill(data_dtype, (indices_size,), 1)
|
||||
value = F.cast(value, data_dtype)
|
||||
|
@ -432,7 +454,7 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
|
|||
data_size = F.size(data)
|
||||
data_dtype = F.dtype(data)
|
||||
indices_size = F.size(indices)
|
||||
indices_size = multi_utils.check_indices(indices_size, index)
|
||||
indices_size = const_utils.check_indices(indices_size, index)
|
||||
update = F.fill(mstype.int32, (indices_size,), 1)
|
||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||
condition = F.reshape(condition_1d, data_shape)
|
||||
|
@ -445,16 +467,16 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
|
|||
|
||||
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
|
||||
"""Set a tensor item by a tensor with a tuple."""
|
||||
updates = multi_utils.generate_updates_from_tuple(data, index, value,
|
||||
multi_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
updates = compile_utils.generate_updates_from_tuple(data, index, value,
|
||||
const_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
result = F.scatter_update(data, index, updates)
|
||||
return result
|
||||
|
||||
|
||||
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
|
||||
"""Set a tensor item by a int tensor with a scalar."""
|
||||
updates = multi_utils.generate_updates_from_scalar(data, index, value,
|
||||
multi_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
updates = compile_utils.generate_updates_from_scalar(data, index, value,
|
||||
const_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
return F.scatter_update(data, index, updates)
|
||||
|
||||
|
||||
|
@ -462,7 +484,7 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
|
|||
"""Set a tensor item by a bool tensor with a scalar."""
|
||||
index_shape = F.shape(index)
|
||||
shape = F.shape(data)
|
||||
shape = multi_utils.check_equal(
|
||||
shape = const_utils.check_equal(
|
||||
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
dtype = F.dtype(data)
|
||||
u = F.fill(dtype, shape, value)
|
||||
|
@ -471,8 +493,8 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
|
|||
|
||||
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
|
||||
"""Set a tensor item by a int tensor with a tensor."""
|
||||
updates = multi_utils.generate_updates_from_tensor(data, index, value,
|
||||
multi_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
updates = compile_utils.generate_updates_from_tensor(data, index, value,
|
||||
const_utils.SET_ITEM_BY_ONE_TENSOR)
|
||||
return F.scatter_update(data, index, updates)
|
||||
|
||||
|
||||
|
@ -480,10 +502,10 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
|
|||
"""Set a tensor item by a bool tensor with a tensor."""
|
||||
index_shape = F.shape(index)
|
||||
data_shape = F.shape(data)
|
||||
data_shape = multi_utils.check_equal(data_shape, index_shape,
|
||||
data_shape = const_utils.check_equal(data_shape, index_shape,
|
||||
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
size = F.size(value)
|
||||
size = multi_utils.check_equal(1, size,
|
||||
size = const_utils.check_equal(1, size,
|
||||
"When assign value is a tensor, its size should be {}, but current size is {}.")
|
||||
dtype = F.dtype(data)
|
||||
u_cast = F.cast(value, dtype)
|
||||
|
|
|
@ -1419,7 +1419,6 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name):
|
|||
validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
|
||||
validator.check_integer("len of input_x", len(x_shape), 1, Rel.GT, prim_name)
|
||||
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
|
||||
validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT, prim_name)
|
||||
rank_base = len(x_shape[0])
|
||||
N = len(x_shape)
|
||||
out_shape = x_shape[0]
|
||||
|
|
|
@ -33,9 +33,4 @@ class IdentityEC(IExectorComponent):
|
|||
keyword.desc_inputs: self.inputs[keyword.desc_inputs],
|
||||
keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs])
|
||||
}
|
||||
print("buxue------------------------------------------------")
|
||||
print("inputs")
|
||||
print(ret[keyword.desc_inputs])
|
||||
print("outputs")
|
||||
print(ret[keyword.result])
|
||||
return ret
|
||||
|
|
|
@ -19,9 +19,9 @@ import mindspore.nn as nn
|
|||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from ..ut_filter import non_graph_engine
|
||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
||||
from tests.ut.python.ut_filter import non_graph_engine
|
||||
from tests.mindspore_test_framework.mindspore_test import mindspore_test
|
||||
from tests.mindspore_test_framework.pipeline.forward.compile_forward \
|
||||
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
|
||||
|
||||
|
||||
|
@ -152,6 +152,20 @@ class ListOperate(nn.Cell):
|
|||
return x
|
||||
|
||||
|
||||
class InListNet(nn.Cell):
|
||||
def __init__(self, ):
|
||||
super(InListNet, self).__init__()
|
||||
self.list_ = [1, 2, 3, 4, 5, "ok"]
|
||||
|
||||
def construct(self, x):
|
||||
ret = x
|
||||
if 2 in self.list_:
|
||||
ret = x + x
|
||||
if "ok" in self.list_:
|
||||
ret = x - x
|
||||
return ret
|
||||
|
||||
|
||||
class AxisListNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(AxisListNet, self).__init__()
|
||||
|
@ -204,10 +218,15 @@ test_case_ops = [
|
|||
('AxisListDefault', {
|
||||
'block': AxisListDefaultNet(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
|
||||
('InList', {
|
||||
'block': InListNet(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
|
||||
]
|
||||
|
||||
test_case_lists = [test_case_ops]
|
||||
test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
|
||||
|
||||
|
||||
# use -k to select certain testcast
|
||||
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm
|
||||
|
|
@ -19,9 +19,9 @@ import mindspore.context as context
|
|||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype as mstype
|
||||
from ..ut_filter import non_graph_engine
|
||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
||||
from tests.ut.python.ut_filter import non_graph_engine
|
||||
from tests.mindspore_test_framework.mindspore_test import mindspore_test
|
||||
from tests.mindspore_test_framework.pipeline.forward.compile_forward \
|
||||
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
|
@ -52,6 +52,20 @@ class NestTupleGraphNet(nn.Cell):
|
|||
return self.layers[0][1](x)
|
||||
|
||||
|
||||
class InTupleNet(nn.Cell):
|
||||
def __init__(self, ):
|
||||
super(InTupleNet, self).__init__()
|
||||
self.tuple_ = (1, 2, 3, 4, 5, "ok")
|
||||
|
||||
def construct(self, x):
|
||||
ret = x
|
||||
if 2 in self.tuple_:
|
||||
ret = x + x
|
||||
if "ok" in self.tuple_:
|
||||
ret = x - x
|
||||
return ret
|
||||
|
||||
|
||||
test_case_ops = [
|
||||
('TupleGraph', {
|
||||
'block': TupleGraphNet(),
|
||||
|
@ -59,6 +73,9 @@ test_case_ops = [
|
|||
('NestTupleGraph', {
|
||||
'block': NestTupleGraphNet(),
|
||||
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
|
||||
('InTuple', {
|
||||
'block': InTupleNet(),
|
||||
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]})
|
||||
]
|
||||
|
||||
test_case_lists = [test_case_ops]
|
|
@ -176,12 +176,134 @@ class TensorGetItemByThreeTensors(Cell):
|
|||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors(Cell):
|
||||
class TensorGetItemByMixedTensors_0(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors, self).__init__()
|
||||
super(TensorGetItemByMixedTensors_0, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32))
|
||||
|
||||
def construct(self, tensor, index_0, index_1):
|
||||
ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors_1(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors_1, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32))
|
||||
|
||||
def construct(self, tensor, index_0, index_1):
|
||||
ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors_2(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors_2, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32))
|
||||
|
||||
def construct(self, tensor, index_0, index_1):
|
||||
ret = tensor[0, index_0, index_1, ..., 3] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors_3(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors_3, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32))
|
||||
|
||||
def construct(self, tensor, index_0, index_1):
|
||||
ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors_4(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors_4, self).__init__()
|
||||
self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32))
|
||||
|
||||
def construct(self, tensor, index_0, index_1, index_2):
|
||||
ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors_5(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors_5, self).__init__()
|
||||
self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32))
|
||||
|
||||
def construct(self, tensor, index_0, index_1, index_2):
|
||||
ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensors_6(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensors_6, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))
|
||||
|
||||
def construct(self, tensor, index_0, index_1, index_2):
|
||||
ret = tensor[..., index_0, index_1, index_2, 3] + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByMixedTensors_0(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByMixedTensors_0, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32))
|
||||
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)),
|
||||
mstype.float32),
|
||||
name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index_0, index_1, index_2):
|
||||
self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByMixedTensors_1(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByMixedTensors_1, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
|
||||
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
|
||||
name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index_0, index_1, index_2):
|
||||
self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByMixedTensors_2(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByMixedTensors_2, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
|
||||
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
|
||||
name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index_0, index_1, index_2):
|
||||
self.param[..., index_0, index_1, index_2, 3] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensorsTypeError(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensorsTypeError, self).__init__()
|
||||
|
||||
def construct(self, x, index_0, index_1):
|
||||
ret = x[index_0, index_1, 0:6]
|
||||
ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]]
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensorsNumberError(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensorsNumberError, self).__init__()
|
||||
|
||||
def construct(self, x, index_0, index_1):
|
||||
ret = x[index_0, index_1, 0:3, ..., index_1, index_0]
|
||||
return ret
|
||||
|
||||
|
||||
|
@ -538,11 +660,11 @@ def test_tensor_assign_bool_index():
|
|||
net1(Ta, Tb, Tc, u_tensor)
|
||||
with pytest.raises(ValueError):
|
||||
net1(Ta, Td, Tc, u_tensor)
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(IndexError):
|
||||
net1(Ta, u_tensor, Tc, u_tensor)
|
||||
with pytest.raises(ValueError):
|
||||
net1(Ta, Tb, Td, u_tensor)
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(IndexError):
|
||||
net1(Ta, Tb, Ta, u_tensor)
|
||||
with pytest.raises(ValueError):
|
||||
net1(Ta, Tb, Tc, u_tensor_error)
|
||||
|
@ -636,6 +758,51 @@ test_cases = [
|
|||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensors_0', {
|
||||
'block': TensorGetItemByMixedTensors_0(),
|
||||
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
|
||||
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensors_1', {
|
||||
'block': TensorGetItemByMixedTensors_1(),
|
||||
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
|
||||
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensors_2', {
|
||||
'block': TensorGetItemByMixedTensors_2(),
|
||||
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
|
||||
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensors_3', {
|
||||
'block': TensorGetItemByMixedTensors_3(),
|
||||
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
|
||||
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensors_4', {
|
||||
'block': TensorGetItemByMixedTensors_4(),
|
||||
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.float32),
|
||||
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensors_5', {
|
||||
'block': TensorGetItemByMixedTensors_5(),
|
||||
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
|
||||
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensors_6', {
|
||||
'block': TensorGetItemByMixedTensors_6(),
|
||||
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
|
||||
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByOneTensorWithNumber', {
|
||||
'block': TensorSetItemByOneTensorWithNumber(value=0.0),
|
||||
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
|
||||
|
@ -683,46 +850,143 @@ test_cases = [
|
|||
Tensor(np.zeros((4, 5)), mstype.float32),
|
||||
Tensor(np.ones((4, 5)), mstype.float32),
|
||||
Tensor(np.ones((4, 5)) * 2, mstype.float32)],
|
||||
})
|
||||
}),
|
||||
('TensorSetItemByMixedTensorsWithNumber_0', {
|
||||
'block': TensorSetItemByMixedTensors_0(value=88.0),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByMixedTensorsWithTensor_0', {
|
||||
'block': TensorSetItemByMixedTensors_0(value=Tensor(np.ones((4, 5, 3, 9), np.float32))),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensorsWithTupleOfNumber_0', {
|
||||
'block': TensorSetItemByMixedTensors_0(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensorsWithTupleOfTensor_0', {
|
||||
'block': TensorSetItemByMixedTensors_0(value=(Tensor(np.ones((4, 5, 3, 9), np.float32)),
|
||||
Tensor(np.zeros((4, 5, 3, 9), np.float32)),
|
||||
Tensor(np.ones((4, 5, 3, 9), np.float32)))),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByMixedTensorsWithNumber_1', {
|
||||
'block': TensorSetItemByMixedTensors_1(value=88.0),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByMixedTensorsWithTensor_1', {
|
||||
'block': TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensorsWithTupleOfNumber_1', {
|
||||
'block': TensorSetItemByMixedTensors_1(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensorsWithTupleOfTensor_1', {
|
||||
'block': TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
|
||||
Tensor(np.zeros((5, 2, 6), np.float32)),
|
||||
Tensor(np.ones((5, 2, 6), np.float32)),
|
||||
Tensor(np.ones((5, 2, 6), np.float32)))),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByMixedTensorsWithNumber_2', {
|
||||
'block': TensorSetItemByMixedTensors_2(value=88.0),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByMixedTensorsWithTensor_2', {
|
||||
'block': TensorSetItemByMixedTensors_2(value=Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensorsWithTupleOfNumber_2', {
|
||||
'block': TensorSetItemByMixedTensors_2(value=(1.0, 2.0, 3.0, 4.0, 5.0)),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensorsWithTupleOfTensor_2', {
|
||||
'block': TensorSetItemByMixedTensors_2(value=(Tensor(np.ones((4, 5), np.float32)),
|
||||
Tensor(np.zeros((4, 5), np.float32)),
|
||||
Tensor(np.ones((4, 5), np.float32)))),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
]
|
||||
|
||||
raise_error_set = [
|
||||
('TensorGetItemByOneTensorDtypeError', {
|
||||
'block': (TensorGetItemByOneTensor(), {'exception': TypeError}),
|
||||
'block': (TensorGetItemByOneTensor(), {'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
|
||||
}),
|
||||
('TensorGetItemByTwoTensorsShapeError', {
|
||||
'block': (TensorGetItemByTwoTensors(), {'exception': ValueError}),
|
||||
'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByTwoTensorsDtypeError', {
|
||||
'block': (TensorGetItemByTwoTensors(), {'exception': TypeError}),
|
||||
'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
|
||||
}),
|
||||
('TensorGetItemByThreeTensorsShapeError', {
|
||||
'block': (TensorGetItemByThreeTensors(), {'exception': ValueError}),
|
||||
'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByThreeTensorsDtypeError', {
|
||||
'block': (TensorGetItemByThreeTensors(), {'exception': TypeError}),
|
||||
'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int64),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensors', {
|
||||
'block': (TensorGetItemByMixedTensors(), {'exception': IndexError}),
|
||||
('TensorGetItemByMixedTensorsNumberError', {
|
||||
'block': (TensorGetItemByMixedTensorsNumberError(), {'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
|
||||
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64)],
|
||||
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensorsTypeError', {
|
||||
'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensorsDtypeError', {
|
||||
'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.float32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensorsShapeError', {
|
||||
'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(2, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByOneTensorWithNumberTypeError', {
|
||||
'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
|
||||
|
@ -782,7 +1046,7 @@ raise_error_set = [
|
|||
Tensor(np.zeros((2, 5)), mstype.float32)],
|
||||
}),
|
||||
('TensorSetItemByTensorsWithTupleOfNumberTypeError', {
|
||||
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}),
|
||||
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1, 2, 3, 4)), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
|
||||
|
@ -810,10 +1074,65 @@ raise_error_set = [
|
|||
Tensor(np.ones((4, 5)), mstype.int32),
|
||||
Tensor(np.ones((4, 5)) * 2, mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByMixedTensors', {
|
||||
'block': (TensorSetItemByMixedTensors(), {'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
|
||||
('TensorSetItemByMixedTensorsWithNumberValueTypeError', {
|
||||
'block': (TensorSetItemByMixedTensors_1(value=88), {'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByMixedTensorsWithNumberIndexTypeError', {
|
||||
'block': (TensorSetItemByMixedTensors_1(value=88.0), {'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.float32)],
|
||||
}),
|
||||
('TensorSetItemByMixedTensorsWithTensorValueDtypeError', {
|
||||
'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.int32))),
|
||||
{'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByMixedTensorsWithTensorValueShapeError', {
|
||||
'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((3, 2, 6), np.float32))),
|
||||
{'exception': ValueError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorSetItemByMixedTensorsWithTensorIndexDtypeError', {
|
||||
'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))),
|
||||
{'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.float32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensorsWithTupleOfNumberValueTypeError', {
|
||||
'block': (TensorSetItemByMixedTensors_1(value=(1.0, 2, 3.0, 4.0, 5.0, 6.0)),
|
||||
{'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensorsWithTupleOfTensorValueDtypeError', {
|
||||
'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
|
||||
Tensor(np.zeros((5, 2, 6), np.float32)),
|
||||
Tensor(np.ones((5, 2, 6), np.float32)),
|
||||
Tensor(np.ones((5, 2, 6), np.int32)))),
|
||||
{'exception': TypeError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
}),
|
||||
('TensorGetItemByMixedTensorsWithTupleOfTensorIndexDtypeError', {
|
||||
'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
|
||||
Tensor(np.zeros((5, 2, 6), np.float32)),
|
||||
Tensor(np.ones((5, 2, 6), np.float32)),
|
||||
Tensor(np.ones((5, 2, 6), np.int32)))),
|
||||
{'exception': IndexError}),
|
||||
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.float32),
|
||||
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
|
||||
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
|
||||
})
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue