!19142 tensor setitem improve performance with tensorcopyslices

Merge pull request !19142 from huangmengxi/setitem_dev

commit 89dc56045d
@@ -18,12 +18,14 @@ from . import _constexpr_utils as const_utils
 from ... import functional as F
 from ... import operations as P
 from ...composite import base
+from ...operations._inner_ops import TensorCopySlices
 from ....common.tensor import Tensor
 from ....common import dtype as mstype
 from ....common._register_for_tensor import tensor_operator_registry

 hyper_map = base.HyperMap()
 stack = P.Stack(axis=-1)
+copy_slice = TensorCopySlices()


 def _tensor_getitem(self, index):
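TensorCopySlices is an internal op, so the diff itself is the only contract shown here; from the call sites below, copy_slice(data, value, begin, end, strides) appears to write value into the strided slice data[begin:end:strides] and return the updated tensor. A minimal NumPy sketch of those assumed semantics (copy_slice_reference is a hypothetical helper, not part of the change):

import numpy as np

def copy_slice_reference(data, value, begin, end, strides):
    # Assumed semantics of TensorCopySlices, inferred from its call sites:
    # write `value` into data[begin:end:strides] and return the result.
    out = data.copy()
    index = tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))
    out[index] = value
    return out

x = np.zeros((4, 3), np.float32)
v = np.ones((2, 3), np.float32)
print(copy_slice_reference(x, v, (1,), (3,), (1,)))  # rows 1..2 become ones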
@@ -671,6 +673,14 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
     result = None
     check_result = const_utils.check_tensor_setitem_index(input_slice)
     if check_result:
+        start, stop, step = const_utils.normalize_slice(input_slice, data.shape[0])
+        if step == 1:
+            dim0_size = stop - start
+            if dim0_size <= 0:
+                return data
+            value_shape = (dim0_size,) + const_utils.tuple_slice(data.shape, 1, None)
+            value = _broadcast(value_shape, value)
+            return copy_slice(data, value, (start,), (stop,), (step,))
         data_shape = F.shape(data)
         indices = const_utils.slice2indices(input_slice, data_shape)
         if indices is False:
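The new branch only fires for a contiguous (step == 1) slice over dimension 0; strided or empty slices still take the original slice2indices route. A hedged usage sketch, assuming PyNative-mode tensor setitem (shapes are illustrative):

import numpy as np
import mindspore as ms

x = ms.Tensor(np.zeros((5, 3), np.float32))
v = ms.Tensor(np.ones((2, 3), np.float32))

x[1:3] = v    # step == 1: handled by the new TensorCopySlices fast path
x[0:4:2] = v  # step != 1: falls back to the original slice2indices path
x[3:3] = v    # empty slice (dim0_size <= 0): data is returned unchanged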
@@ -691,6 +701,18 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
     """Assigns the tensor by tuple with tensor value."""
     op_name = const_utils.TENSOR_SETITEM
     tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
+
+    if const_utils.use_copy_slice(tuple_index):
+        dim1_start, dim1_stop, _ = const_utils.normalize_slice(tuple_index[1], data.shape[1])
+        if dim1_stop - dim1_start <= 0:
+            return data
+        start = (tuple_index[0], dim1_start)
+        stop = (tuple_index[0] + 1, dim1_stop)
+        step = (1, 1)
+        value_shape = (dim1_stop - dim1_start,) + const_utils.tuple_slice(data.shape, 2, None)
+        value = _broadcast(value_shape, value)
+        return copy_slice(data, value, start, stop, step)
+
     tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)

     if tuple_index is False:
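This branch targets the common pattern of an integer index on dimension 0 followed by a contiguous slice on dimension 1 (any further indices must be full slices); the int/slice pair is rewritten as 2-D begin/end/strides tuples for copy_slice. A hedged sketch of which index patterns qualify, with shapes chosen for illustration:

import numpy as np
import mindspore as ms

x = ms.Tensor(np.zeros((5, 6, 3), np.float32))
v = ms.Tensor(np.ones((3, 3), np.float32))

x[2, 1:4] = v      # int + contiguous slice: rewritten to copy_slice
x[2, 1:4, :] = v   # trailing full slices still qualify
x[1:3, 2:5] = 0.0  # dim-0 index is not an int: general tuple path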
@@ -825,3 +825,12 @@ def infer_out_shape(*shapes):
             raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
         shape_out.appendleft(max_size)
     return tuple(shape_out)
+
+
+@constexpr
+def use_copy_slice(tuple_index):
+    if tuple_index is not None and len(tuple_index) >= 2:
+        return (isinstance(tuple_index[0], int) and
+                isinstance(tuple_index[1], slice) and tuple_index[1].step in (1, None) and
+                all(x == slice(None, None, None) for x in tuple_index[2:]))
+    return False
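Since use_copy_slice is a plain predicate over a compile-time-constant index tuple, its behaviour can be checked outside the graph; the snippet below re-runs the same logic without the @constexpr decorator:

def use_copy_slice(tuple_index):
    if tuple_index is not None and len(tuple_index) >= 2:
        return (isinstance(tuple_index[0], int) and
                isinstance(tuple_index[1], slice) and tuple_index[1].step in (1, None) and
                all(x == slice(None, None, None) for x in tuple_index[2:]))
    return False

assert use_copy_slice((2, slice(1, 4)))                # x[2, 1:4]
assert use_copy_slice((2, slice(1, 4), slice(None)))   # x[2, 1:4, :]
assert not use_copy_slice((2, slice(1, 4, 2)))         # strided dim-1 slice
assert not use_copy_slice((slice(1, 3), slice(2, 5)))  # dim-0 index is a slice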