!19142 tensor setitem improve performance with tensorcopyslices

Merge pull request !19142 from huangmengxi/setitem_dev
This commit is contained in:
i-robot 2021-07-01 09:56:55 +00:00 committed by Gitee
commit 89dc56045d
2 changed files with 31 additions and 0 deletions

View File

@ -18,12 +18,14 @@ from . import _constexpr_utils as const_utils
from ... import functional as F
from ... import operations as P
from ...composite import base
from ...operations._inner_ops import TensorCopySlices
from ....common.tensor import Tensor
from ....common import dtype as mstype
from ....common._register_for_tensor import tensor_operator_registry
hyper_map = base.HyperMap()
stack = P.Stack(axis=-1)
copy_slice = TensorCopySlices()
def _tensor_getitem(self, index):
@ -671,6 +673,14 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
result = None
check_result = const_utils.check_tensor_setitem_index(input_slice)
if check_result:
start, stop, step = const_utils.normalize_slice(input_slice, data.shape[0])
if step == 1:
dim0_size = stop - start
if dim0_size <= 0:
return data
value_shape = (dim0_size,) + const_utils.tuple_slice(data.shape, 1, None)
value = _broadcast(value_shape, value)
return copy_slice(data, value, (start,), (stop,), (step,))
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
if indices is False:
@ -691,6 +701,18 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
"""Assigns the tensor by tuple with tensor value."""
op_name = const_utils.TENSOR_SETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
if const_utils.use_copy_slice(tuple_index):
dim1_start, dim1_stop, _ = const_utils.normalize_slice(tuple_index[1], data.shape[1])
if dim1_stop - dim1_start <= 0:
return data
start = (tuple_index[0], dim1_start)
stop = (tuple_index[0] + 1, dim1_stop)
step = (1, 1)
value_shape = (dim1_stop - dim1_start,) + const_utils.tuple_slice(data.shape, 2, None)
value = _broadcast(value_shape, value)
return copy_slice(data, value, start, stop, step)
tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)
if tuple_index is False:

View File

@ -825,3 +825,12 @@ def infer_out_shape(*shapes):
raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
shape_out.appendleft(max_size)
return tuple(shape_out)
@constexpr
def use_copy_slice(tuple_index):
if tuple_index is not None and len(tuple_index) >= 2:
return (isinstance(tuple_index[0], int) and
isinstance(tuple_index[1], slice) and tuple_index[1].step in (1, None) and
all(x == slice(None, None, None) for x in tuple_index[2:]))
return False