diff --git a/mindspore/ccsrc/operator/composite/composite.cc b/mindspore/ccsrc/operator/composite/composite.cc
index b0edaa345c8..88d264974dd 100644
--- a/mindspore/ccsrc/operator/composite/composite.cc
+++ b/mindspore/ccsrc/operator/composite/composite.cc
@@ -1172,6 +1172,12 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, co
   return 1;
 }
 
+FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) {
+  auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
+  ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
+  return ret_graph;
+}
+
 FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
   // slice a tensor
   // args: tensor, slice or slice tuple
@@ -1229,12 +1235,6 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
   return ret_graph;
 }
 
-FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const {
-  auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
-  ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
-  return ret_graph;
-}
-
 FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
   // select indexed item
   // args: tuple of items, index
diff --git a/mindspore/ccsrc/operator/composite/composite.h b/mindspore/ccsrc/operator/composite/composite.h
index 7061eb7441f..c5166f5cc15 100644
--- a/mindspore/ccsrc/operator/composite/composite.h
+++ b/mindspore/ccsrc/operator/composite/composite.h
@@ -206,8 +206,6 @@ class TensorSlice : public MetaFuncGraph {
   MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
   friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
-
-  FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
 };
 using TensorSlicePtr = std::shared_ptr<TensorSlice>;
diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc
index cbc88552401..0b4aca2b99f 100644
--- a/mindspore/ccsrc/transform/convert.cc
+++ b/mindspore/ccsrc/transform/convert.cc
@@ -101,6 +101,7 @@ const char kNameReLU6[] = "ReLU6";
 const char kNameReLU6Grad[] = "ReLU6Grad";
 const char kNameElu[] = "Elu";
 const char kNameEluGrad[] = "EluGrad";
+const char kNameScatterUpdate[] = "ScatterUpdate";
 const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
 const char kNameScatterMax[] = "ScatterMax";
 const char kNameNMSWithMask[] = "NMSWithMask";
@@ -256,6 +257,7 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma
     {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
     {string(kNameZerosLike), ADPT_DESC(ZerosLike)},
     {string(kNameOnesLike), ADPT_DESC(OnesLike)},
+    {string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)},
    {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
    {string(kNameScatterMax), ADPT_DESC(ScatterMax)},
    {string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)},
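# Usage sketch (not part of the patch) for the ScatterUpdate op wired up by the
# adapter entries above. The Cell wrapper, parameter name, and shapes are
# illustrative assumptions; the shape contract updates.shape == indices.shape +
# var.shape[1:] is taken from the ScatterUpdate docstring added later in this diff.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P


class ScatterUpdateNet(nn.Cell):
    def __init__(self):
        super(ScatterUpdateNet, self).__init__()
        self.var = Parameter(Tensor(np.zeros((2, 3), np.float32)), name="var")
        self.scatter_update = P.ScatterUpdate()

    def construct(self, indices, updates):
        # Rows of var selected by indices are overwritten with updates.
        return self.scatter_update(self.var, indices, updates)


net = ScatterUpdateNet()
out = net(Tensor(np.array([0, 1], np.int32)), Tensor(np.ones((2, 3), np.float32)))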
diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc
index 7e7b3391b7e..92743b9e888 100644
--- a/mindspore/ccsrc/transform/op_declare.cc
+++ b/mindspore/ccsrc/transform/op_declare.cc
@@ -515,6 +515,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits())}, {"num", ATTR_DESC(num, AnyTraits())}};
 DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};
 
+// ScatterUpdate
+INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
+ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ScatterUpdate) = {{0, OUTPUT_DESC(var)}};
+
 // ScatterNdUpdate
 INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
 ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h
index 1ab8d4efd7d..2d7d0b159a9 100755
--- a/mindspore/ccsrc/transform/op_declare.h
+++ b/mindspore/ccsrc/transform/op_declare.h
@@ -132,6 +132,8 @@ DECLARE_OP_ADAPTER(ZerosLike)
 DECLARE_OP_USE_OUTPUT(ZerosLike)
 DECLARE_OP_ADAPTER(OnesLike)
 DECLARE_OP_USE_OUTPUT(OnesLike)
+DECLARE_OP_ADAPTER(ScatterUpdate)
+DECLARE_OP_USE_OUTPUT(ScatterUpdate)
 DECLARE_OP_ADAPTER(ScatterNdUpdate)
 DECLARE_OP_USE_OUTPUT(ScatterNdUpdate)
 DECLARE_OP_ADAPTER(ScatterMax)
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index fae47af1cb4..5756f7f2c3c 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -178,13 +178,14 @@ from .bounding_box_encode import _bounding_box_encode_tbe
 from .check_valid import _check_valid_tbe
 from .iou import _iou_tbe
 from .arg_max import _arg_max_tbe
-from .nms_with_mask import nms_with_mask_op_info
-from .random_choice_with_mask import random_choice_with_mask_op_info
-from .sgd import sgd_op_info
-from .lars_update import lars_update_op_info
+from .nms_with_mask import _nms_with_mask_tbe
+from .random_choice_with_mask import _random_choice_with_mask_tbe
+from .sgd import _sgd_tbe
+from .lars_update import _lars_update_tbe
 from .bn_training_update_v2 import _bn_training_update_v2_tbe
-from .square_sum_all import square_sum_all_op_info
+from .square_sum_all import _square_sum_all_tbe
 from .pack import _pack_tbe
 from .unpack import _unpack_tbe
+from .scatter_update import _scatter_update_tbe
 from .prelu import _prelu_tbe
 from .prelu_grad import _prelu_grad_tbe
diff --git a/mindspore/ops/_op_impl/tbe/scatter_update.py b/mindspore/ops/_op_impl/tbe/scatter_update.py
new file mode 100644
index 00000000000..3c330fe4353
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/scatter_update.py
@@ -0,0 +1,42 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""ScatterUpdate op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+scatter_update_op_info = TBERegOp("ScatterUpdate") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("scatter_update.so") \
+    .compute_cost(10) \
+    .kernel_name("scatter_update") \
+    .partial_flag(True) \
+    .attr("use_locking", "optional", "bool", "all") \
+    .input(0, "var", False, "required", "all") \
+    .input(1, "indices", False, "required", "all") \
+    .input(2, "updates", False, "required", "all") \
+    .output(0, "var", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
+    .get_op_info()
+
+
+@op_info_register(scatter_update_op_info)
+def _scatter_update_tbe():
+    """ScatterUpdate TBE register"""
+    return
diff --git a/mindspore/ops/_utils/__init__.py b/mindspore/ops/_utils/__init__.py
index 8fe11029684..35804095412 100644
--- a/mindspore/ops/_utils/__init__.py
+++ b/mindspore/ops/_utils/__init__.py
@@ -14,6 +14,6 @@
 # ============================================================================
 
 """ops utils."""
-from .utils import _get_broadcast_shape, _get_concat_offset
+from .utils import get_broadcast_shape, get_concat_offset
 
-__all__ = ['_get_broadcast_shape', '_get_concat_offset']
+__all__ = ['get_broadcast_shape', 'get_concat_offset']
diff --git a/mindspore/ops/_utils/utils.py b/mindspore/ops/_utils/utils.py
index 90496afc9bd..0e6850dcb12 100644
--- a/mindspore/ops/_utils/utils.py
+++ b/mindspore/ops/_utils/utils.py
@@ -19,7 +19,8 @@ from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
 
-def _get_broadcast_shape(x_shape, y_shape, prim_name):
+
+def get_broadcast_shape(x_shape, y_shape, prim_name):
     """
     Doing broadcast between tensor x and tensor y.
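# Quick cross-check (not part of the patch) of the right-aligned broadcast rule that
# get_broadcast_shape implements: trailing dimensions must be equal, or one of them
# must be 1. The shapes below are illustrative.
import numpy as np

assert np.broadcast(np.empty((2, 1, 3)), np.empty((4, 3))).shape == (2, 4, 3)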
@@ -37,7 +38,7 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name): Examples: >>> x_shape = [1, 2, 3] >>> y_shape = [1, 2] - >>> broadcast_shape = _get_broadcast_shape(x_shape, y_shape) + >>> broadcast_shape = get_broadcast_shape(x_shape, y_shape) """ if x_shape == y_shape: return x_shape @@ -54,15 +55,14 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name): elif x_shape[i] == y_shape[i]: broadcast_shape_back.append(x_shape[i]) else: - raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format( - prim_name, x_shape, y_shape)) + raise ValueError(f"For '{prim_name}', the x_shape {x_shape} and y_shape {y_shape} can not broadcast.") broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length] - broadcast_shape = broadcast_shape_front + broadcast_shape_back + broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back return broadcast_shape -def _get_concat_offset(x_shp, x_type, axis, prim_name): +def get_concat_offset(x_shp, x_type, axis, prim_name): """for concat and concatoffset check args and compute offset""" validator.check_value_type("shape", x_shp, [tuple], prim_name) validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name) @@ -73,7 +73,7 @@ def _get_concat_offset(x_shp, x_type, axis, prim_name): if axis < 0: axis = axis + rank_base all_shp = x_shp[0][axis] - offset = [0,] + offset = [0] for i in range(1, len(x_shp)): v = x_shp[i] validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name) diff --git a/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py b/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py deleted file mode 100644 index d008f96648a..00000000000 --- a/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py +++ /dev/null @@ -1,226 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""constexpr util""" - -from functools import reduce -import numpy as np -from ...primitive import constexpr -from ....common.tensor import Tensor -from ....common import dtype as mstype -from ...._extends.utils import Slice, Ellipsis_ - -@constexpr -def check_equal(param1, param2, msg="{},{}"): - """Checks whether the two parameters are equal or not.""" - if param1 != param2: - raise ValueError(msg.format(param1, param2)) - return param1 - - -@constexpr -def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): - """Checks the shape and size of the sensor and value.""" - if data_shape == value_shape or data_size == value_size or value_size == 1: - return True - raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape)) - - -@constexpr -def check_tensor_setitem_index(index, element_type=None): - """Checks tuple index type of tensor assignment.""" - if index is None: - raise IndexError("Tensor's index cannot be None.") - # eg. 
Tensor[Slice] = u - if isinstance(index, Slice): - return True - # eg. Tensor[tuple] = u - if isinstance(index, tuple): - if not index: - raise IndexError("Tensor's index cannot be empty.") - # eg. Tensor[tuple(Slice...)] = u - if isinstance(index[0], (Slice, Ellipsis_, int)): - return True - raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0]))) - # eg. Tensor[Tensor[dtype=bool]] = u - if index == mstype.tensor: - if element_type is None or element_type != mstype.bool_: - raise TypeError( - "The index of tensor should be a bool type tensor. " - "{} type is not supported yet.".format(element_type)) - return True - - raise IndexError("Index of type '{}' is not supported yet.".format(type(index))) - - -@constexpr -def is_same_type(inst, type_): - """ - Checks whether an object is an instance of a target type. - - Inputs: - inst (mindspore.dtype): Inspected type. - type_ (mindspore.dtype): Target type. - - Outputs: - bool, the check result. - """ - return inst == type_ - - -def slice_expand(input_slices, shape): - """ - Converts slice to indices. - - Inputs: - slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. - shape (tuple): The shape of a sensor is an integer element tuple. - - Outputs: - tuple[list], This is expressed as (begins, ends, strides). - """ - begin = [] - end = [] - strides = [] - index = 0 - slices = None - # Slice or tuple(Slice...) - if isinstance(input_slices, Slice): - slices = (input_slices,) - elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)): - is_have_ellipsis = False - for _, element in enumerate(input_slices): - if isinstance(element, Ellipsis_): - is_have_ellipsis = True - break - if is_have_ellipsis: - slices = ellipsis2slice(input_slices, shape) - else: - slices = input_slices - else: - raise IndexError("Tensor's index type is not supported yet.") - - for s in slices: - start = 0 if (s.start is None) else s.start - stop = shape[index] if (s.end is None) else s.end - step = 1 if (s.step is None) else s.step - begin.append(start) - end.append(stop) - strides.append(step) - index += 1 - while index < len(shape): - begin.append(0) - end.append(shape[index]) - strides.append(1) - index += 1 - return begin, end, strides - - -def ellipsis2slice(input_, shape): - """Converts ellipsis to slice.""" - input_slice = input_ - result = [] - if isinstance(input_, Ellipsis_): - input_slice = (input_,) - ell_count = 0 - for _, element in enumerate(input_slice): - if not isinstance(element, Ellipsis_): - result.append(element) - continue - ell_count += 1 - if ell_count > 1: - raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, " - "but it is currently {}".format(input_slice)) - for _ in range(len(shape) - len(input_slice) + 1): - result.append(Slice(None, None, None)) - return tuple(result) - - -@constexpr -def slice2indices(input_slices, shape): - """ - Converts slice to indices. - - Inputs: - slices (Union[Slice, tuple[Slice]]): Slice tuple or slice. - shape (tuple): The shape of a tensor is an integer element tuple. - - Outputs: - Tensor, the shape is (n, 1). 
- """ - begin, end, strides = slice_expand(input_slices, shape) - np_r = [] - for i, element in enumerate(shape): - s = begin[i] if (begin[i] >= 0) else (element + begin[i]) - e = end[i] if (end[i] >= 0) else (element + end[i]) - np_r.append(np.r_[s:e:strides[i]]) - # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape) - np_ix = np.ix_(*np_r) - ravel = np.ravel_multi_index(np_ix, shape) - ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32) - return ravel - -@constexpr -def check_indices(indices_size, index): - """Checks indices whether is empty.""" - if indices_size < 1: - raise IndexError("The tensor's index is unreasonable. index:{}".format(index)) - return indices_size - - -@constexpr -def check_indices_value_size(indices_size, value_size): - """Checks if the sizes are already matched.""" - if value_size < 1: - raise ValueError("The value assigned to tensor cannot be empty.") - if value_size > 1: - if value_size != indices_size: - raise ValueError( - "The value given to tensor does not match the index size," - " value size:{}, indics size:{}".format(value_size, indices_size)) - return value_size - -@constexpr -def integer_to_indices(index, shape): - """Converts int or tuple[int] to indices.""" - size = reduce(lambda x, y: x * y, shape) - range_ = np.arange(size).reshape(shape) - value = range_[index] - value = value.reshape(-1, 1) - return Tensor(value, dtype=mstype.int32) - -@constexpr -def tuple_element_is_slice(indexs): - """Judges tuple element type.""" - if not indexs: - raise IndexError("Tensor's index cannot be empty.") - if isinstance(indexs, tuple): - for _, ele in enumerate(indexs): - if not isinstance(ele, Slice): - return False - return True - return False - -@constexpr -def tuple_element_is_int(indexs): - """Judges tuple element type.""" - if not indexs: - raise IndexError("Tensor's index cannot be empty.") - if isinstance(indexs, tuple): - for _, ele in enumerate(indexs): - if not isinstance(ele, int): - return False - return True - return False diff --git a/mindspore/ops/composite/multitype_ops/_utils.py b/mindspore/ops/composite/multitype_ops/_utils.py new file mode 100644 index 00000000000..cff88dfdbba --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/_utils.py @@ -0,0 +1,487 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""constexpr util""" +from functools import reduce +import numpy as np +from ...primitive import constexpr +from ....common.tensor import Tensor +from ....common import dtype as mstype +from ...._extends.utils import Slice, Ellipsis_ +from ....ops import _utils as op_utils +from ...composite import base +from .... import log as logger +from ... import functional as F +from ... 
import operations as P
+
+hyper_map = base.HyperMap()
+pack = P.Pack(axis=-1)
+
+ALL_TENSOR = 0
+NO_TENSOR = 1
+CONTAIN_TENSOR = 2
+ALL_SCALAR = 3
+
+INT_ = 0
+BOOL_ = 1
+UNSUPPORTED_DTYPE = 2
+
+TENSOR_SETITEM = "tensor setitem"
+TENSOR_GETITEM = "tensor getitem"
+
+SET_ITEM_BY_ONE_TENSOR = 0
+SET_ITEM_BY_TUPLE_OF_TENSOR = 1
+
+
+@constexpr
+def check_equal(param1, param2, msg="{},{}"):
+    """Checks whether the two parameters are equal or not."""
+    if param1 != param2:
+        raise ValueError(msg.format(param1, param2))
+    return param1
+
+
+@constexpr
+def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
+    """Checks the shape and size of the tensor and value."""
+    if data_shape == value_shape or data_size == value_size or value_size == 1:
+        return True
+    raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape))
+
+
+@constexpr
+def check_tensor_setitem_index(index, element_type=None):
+    """Checks tuple index type of tensor assignment."""
+    if index is None:
+        raise IndexError("Tensor's index cannot be None.")
+    # eg. Tensor[Slice] = u
+    if isinstance(index, Slice):
+        return True
+    # eg. Tensor[tuple] = u
+    if isinstance(index, tuple):
+        if not index:
+            raise IndexError("Tensor's index cannot be empty.")
+        # eg. Tensor[tuple(Slice...)] = u
+        if isinstance(index[0], (Slice, Ellipsis_, int)):
+            return True
+        raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
+    # eg. Tensor[Tensor[dtype=bool]] = u
+    if isinstance(index, mstype.tensor_type):
+        if element_type is None or element_type != mstype.bool_:
+            raise TypeError(
+                "The index of tensor should be a bool type tensor. "
+                "{} type is not supported yet.".format(element_type))
+        return True
+
+    raise IndexError("Index of type '{}' is not supported yet.".format(type(index)))
+
+
+@constexpr
+def is_same_type(inst, type_):
+    """
+    Checks whether an object is an instance of a target type.
+
+    Inputs:
+        inst (mindspore.dtype): Inspected type.
+        type_ (mindspore.dtype): Target type.
+
+    Outputs:
+        bool, the check result.
+    """
+    return inst == type_
+
+
+def slice_expand(input_slices, shape):
+    """
+    Converts slice to indices.
+
+    Inputs:
+        slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
+        shape (tuple): The shape of the tensor is an integer element tuple.
+
+    Outputs:
+        tuple[list], This is expressed as (begins, ends, strides).
+    """
+    begin = []
+    end = []
+    strides = []
+    index = 0
+    slices = None
+    # Slice or tuple(Slice...)
+    if isinstance(input_slices, Slice):
+        slices = (input_slices,)
+    elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)):
+        is_have_ellipsis = False
+        for _, element in enumerate(input_slices):
+            if isinstance(element, Ellipsis_):
+                is_have_ellipsis = True
+                break
+        if is_have_ellipsis:
+            slices = ellipsis2slice(input_slices, shape)
+        else:
+            slices = input_slices
+    else:
+        raise IndexError("Tensor's index type is not supported yet.")
+
+    for s in slices:
+        start = 0 if (s.start is None) else s.start
+        stop = shape[index] if (s.end is None) else s.end
+        step = 1 if (s.step is None) else s.step
+        begin.append(start)
+        end.append(stop)
+        strides.append(step)
+        index += 1
+    while index < len(shape):
+        begin.append(0)
+        end.append(shape[index])
+        strides.append(1)
+        index += 1
+    return begin, end, strides
+
+
+def ellipsis2slice(input_, shape):
+    """Converts ellipsis to slice."""
+    input_slice = input_
+    result = []
+    if isinstance(input_, Ellipsis_):
+        input_slice = (input_,)
+    ell_count = 0
+    for _, element in enumerate(input_slice):
+        if not isinstance(element, Ellipsis_):
+            result.append(element)
+            continue
+        ell_count += 1
+        if ell_count > 1:
+            raise IndexError("There cannot be more than one ellipsis (...) in the index of the tensor, "
+                             "but it is currently {}".format(input_slice))
+        for _ in range(len(shape) - len(input_slice) + 1):
+            result.append(Slice(None, None, None))
+    return tuple(result)
+
+
+@constexpr
+def slice2indices(input_slices, shape):
+    """
+    Converts slice to indices.
+
+    Inputs:
+        slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
+        shape (tuple): The shape of a tensor is an integer element tuple.
+
+    Outputs:
+        Tensor, the shape is (n, 1).
+    """
+    begin, end, strides = slice_expand(input_slices, shape)
+    np_r = []
+    for i, element in enumerate(shape):
+        s = begin[i] if (begin[i] >= 0) else (element + begin[i])
+        e = end[i] if (end[i] >= 0) else (element + end[i])
+        np_r.append(np.r_[s:e:strides[i]])
+    # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
+    np_ix = np.ix_(*np_r)
+    ravel = np.ravel_multi_index(np_ix, shape)
+    ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
+    return ravel
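# Worked example (not part of the patch) of what slice2indices computes, using the
# same numpy recipe as the reference comment above. Shape and slice are illustrative:
# data[1:3, 0:4] on a (3, 4) tensor selects flat row-major indices 4..11.
import numpy as np

shape = (3, 4)
np_ix = np.ix_(np.r_[1:3:1], np.r_[0:4:1])
flat = np.ravel_multi_index(np_ix, shape).reshape(-1, 1)
print(flat.ravel())  # [ 4  5  6  7  8  9 10 11]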
index:{}".format(index)) + return indices_size + + +@constexpr +def check_indices_value_size(indices_size, value_size): + """Checks if the sizes are already matched.""" + if value_size < 1: + raise ValueError("The value assigned to tensor cannot be empty.") + if value_size > 1: + if value_size != indices_size: + raise ValueError( + "The value given to tensor does not match the index size," + " value size:{}, indics size:{}".format(value_size, indices_size)) + return value_size + + +@constexpr +def integer_to_indices(index, shape): + """Converts int or tuple[int] to indices.""" + size = reduce(lambda x, y: x * y, shape) + range_ = np.arange(size).reshape(shape) + value = range_[index] + value = value.reshape(-1, 1) + return Tensor(value, dtype=mstype.int32) + + +@constexpr +def tuple_element_is_slice(indexs): + """Judges tuple element type.""" + if not indexs: + raise IndexError("Tensor's index cannot be empty.") + if isinstance(indexs, tuple): + for _, ele in enumerate(indexs): + if not isinstance(ele, Slice): + return False + return True + return False + + +@constexpr +def tuple_element_is_int(indexs): + """Judges tuple element type.""" + if not indexs: + raise IndexError("Tensor's index cannot be empty.") + if isinstance(indexs, tuple): + for _, ele in enumerate(indexs): + if not isinstance(ele, int): + return False + return True + return False + + +@constexpr +def tuple_elements_type(types): + """Judges the type of all elements of the tuple.""" + tensors_number = 0 + for ele in types: + if isinstance(ele, mstype.tensor_type): + tensors_number += 1 + if tensors_number == len(types): + return ALL_TENSOR + if tensors_number == 0: + return NO_TENSOR + return CONTAIN_TENSOR + + +@constexpr +def check_value_elements(data_dtype, types): + """Judges the type of all elements of the tuple.""" + tensors_number = 0 + scalars_number = 0 + for i, ele in enumerate(types): + if isinstance(ele, mstype.tensor_type): + ele_dtype = ele.element_type() + if data_dtype == ele_dtype: + tensors_number += 1 + else: + raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' " + f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.") + elif mstype.issubclass_(ele, data_dtype): + scalars_number += 1 + else: + raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in " + f"value tuple is not consistent with origin tensor data type '{data_dtype}'.") + if tensors_number == len(types): + return ALL_TENSOR + if scalars_number == len(types): + return ALL_SCALAR + raise TypeError(f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.") + + +@constexpr +def get_index_tensor_dtype(dtype): + """Check a tuple of tensor data type.""" + if dtype == mstype.int32: + return INT_ + if dtype == mstype.bool_: + return BOOL_ + raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.") + + +@constexpr +def check_index_tensors_dtype(dtypes, op_name): + """Check a tuple of tensor data type.""" + if op_name == TENSOR_GETITEM: + valid_dtypes = (mstype.int32, mstype.int64) + elif op_name == TENSOR_SETITEM: + valid_dtypes = (mstype.int32,) + else: + raise ValueError("Unsupported operation.") + for ele in dtypes: + if ele in valid_dtypes and ele == dtypes[0]: + continue + raise TypeError(f"For '{op_name}', the index tensors data type must be same, " + f"and should be one of the following: {valid_dtypes}, but got {dtypes}.") + return True + + +@constexpr +def check_tensor_dtype_valid(dtype, 
+
+
+@constexpr
+def check_tensor_dtype_valid(dtype, valid_dtypes):
+    """Check a tensor data type."""
+    if dtype in valid_dtypes:
+        return True
+    raise TypeError(f"The index tensor data type must be one of "
+                    f"the following: {valid_dtypes}, but got {dtype}.")
+
+
+@constexpr
+def check_tensors_dtype_same(x_dtype, y_dtype, op_name):
+    """Check tensors data type same."""
+    if x_dtype == y_dtype:
+        return True
+    raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' "
+                    f"is not consistent with origin tensor data type {x_dtype}.")
+
+
+@constexpr
+def broadcast_shapes(shapes, op_name):
+    """Broadcasts a tuple of tensor shapes."""
+    broadcast_shape = shapes[0]
+    for i, shape in enumerate(shapes):
+        logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
+        broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
+    return tuple(broadcast_shape)
+
+
+@constexpr
+def check_two_shapes_need_broadcast(shape_x, shape_y):
+    """Check whether two shapes need broadcast."""
+    error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape "
+                       f"{shape_y} could not be broadcast to the required updates shape {shape_x}.")
+    if len(shape_y) > len(shape_x):
+        raise error
+    for i in range(-len(shape_y), 0):
+        if shape_y[i] > shape_x[i]:
+            raise error
+        if shape_y[i] < shape_x[i] and shape_y[i] != 1:
+            raise error
+    if shape_y == shape_x:
+        return False
+    return True
+
+
+@constexpr
+def compute_multiples(origin_shape, broadcast_shape):
+    """Compute multiples between broadcast_shape with origin_shape."""
+    len_gap = len(broadcast_shape) - len(origin_shape)
+    return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))
+
+
+def tile(broadcast_shape, x):
+    multiples = compute_multiples(F.shape(x), broadcast_shape)
+    return F.tile(x, multiples)
+
+
+@constexpr
+def check_shapes_same(value_shapes, op_name):
+    """Check if the shapes in the tuple are consistent."""
+    for i, shape in enumerate(value_shapes):
+        if shape != value_shapes[0]:
+            raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple "
+                             f"is not same as the first tensor shape.")
+    return True
+
+
+@constexpr
+def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
+    """Convert a scalar to a tensor."""
+    if op_type == SET_ITEM_BY_ONE_TENSOR:
+        updates_shape = indices_shape + data_shape[1:]
+    else:
+        updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
+    if isinstance(value, mstype.dtype_to_pytype(data_dtype)):
+        return Tensor(np.full(updates_shape, value), dtype=data_dtype)
+    raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'"
+                    f" is not consistent with tensor data type {data_dtype}.")
+
+
+@constexpr
+def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type):
+    """Convert a tuple of scalars to a tensor."""
+    updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
+    if len(value) != updates_shape[-1]:
+        raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements: {len(value)} in the updates tuple "
+                         f"does not meet the requirements: {updates_shape[-1]}.")
+    array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype))
+    reps = compute_multiples(updates_shape[-1:], updates_shape)
+    return Tensor(np.tile(array, reps))
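# Worked example (not part of the patch) of compute_multiples: the per-axis repeat
# counts tile() needs to expand origin_shape up to broadcast_shape. Values illustrative.
origin_shape, broadcast_shape = (1, 3), (2, 2, 3)
len_gap = len(broadcast_shape) - len(origin_shape)
multiples = broadcast_shape[:len_gap] + tuple(b // o for b, o in zip(broadcast_shape[len_gap:], origin_shape))
print(multiples)  # (2, 2, 1): prepend the missing axes, divide the matching ones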
+
+
+@constexpr
+def generate_updates_shape(data_shape, index_shape, op_type):
+    """Generate updates shape for 'tensor setitem'."""
+    if op_type == SET_ITEM_BY_ONE_TENSOR:
+        updates_shape = index_shape + data_shape[1:]
+    else:
+        updates_shape = index_shape[:-1] + data_shape[index_shape[-1]:]
+    return updates_shape
+
+
+@constexpr
+def check_number_of_index_tensor(data_shape, tuple_len, op_name):
+    """Check if the number of index tensor exceeds the dimension of the operated tensor."""
+    if tuple_len <= len(data_shape):
+        return True
+    raise IndexError(f"For '{op_name}', the number {tuple_len} of index tensor "
+                     f"is greater than the dimension {len(data_shape)} of the operated tensor.")
+
+
+def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
+    """Generate an indices tensor from a tuple of tensor."""
+    indices = None
+    check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
+    if check_index_tensor_number:
+        dtype_tuple = hyper_map(F.dtype, tuple_index)
+        check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name)
+        if check_dtypes:
+            shape_tuple = hyper_map(F.shape, tuple_index)
+            broadcast_shape = broadcast_shapes(shape_tuple, op_name)
+            broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index)
+            indices = pack(broadcast_tensors)
+    return indices
+
+
+def generate_updates_from_scalar(data, indices, value, op_type):
+    """Generate an updates tensor from a scalar."""
+    data_shape = F.shape(data)
+    indices_shape = F.shape(indices)
+    data_dtype = F.dtype(data)
+    return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)
+
+
+def generate_updates_from_tuple(data, index, value, op_type):
+    """Generate an updates tensor from a tuple."""
+    value_types = hyper_map(F.typeof, value)
+    data_dtype = F.dtype(data)
+    value_elements_type = check_value_elements(data_dtype, value_types)
+    if value_elements_type == ALL_TENSOR:
+        value_shapes = hyper_map(F.shape, value)
+        shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM)
+        if shapes_same:
+            value = F.pack(value)
+        return generate_updates_from_tensor(data, index, value, op_type)
+
+    data_shape = F.shape(data)
+    index_shape = F.shape(index)
+    return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
+
+
+def generate_updates_from_tensor(data, index, value, op_type):
+    """Generate an updates tensor from a tensor."""
+    data_shape = F.shape(data)
+    index_shape = F.shape(index)
+    value_shape = F.shape(value)
+    data_dtype = F.dtype(data)
+    value_dtype = F.dtype(value)
+    updates_shape = value_shape
+    check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM)
+    if check_dtype_same:
+        updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
+    need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape)
+    if need_broadcast:
+        return tile(updates_shape, value)
+    return value
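# Shape walk-through (not part of the patch) of the two updates-shape rules in
# generate_updates_shape; the concrete shapes are illustrative.
data_shape = (4, 5, 6)
index_shape = (2, 3, 2)  # last dim 2 == number of indexed axes for the tuple case
by_one_tensor = index_shape + data_shape[1:]                          # (2, 3, 2, 5, 6)
by_tuple_of_tensor = index_shape[:-1] + data_shape[index_shape[-1]:]  # (2, 3, 6)
print(by_one_tensor, by_tuple_of_tensor)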
diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py
index 3df117837b0..5e217ba1b90 100644
--- a/mindspore/ops/composite/multitype_ops/getitem_impl.py
+++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py
@@ -15,9 +15,10 @@
 
 """Implementation for getitem."""
 
-from ...composite import base
+from . import _utils as multi_utils
+from .. import base
 from ... import functional as F
-
+from ....common import dtype as mstype
 
 getitem = base.MultitypeFuncGraph('getitem')
 """
@@ -214,19 +215,45 @@ def _tensor_getitem_by_slice(data, slice_index):
     return _tensor_slice(data, slice_index)
 
 
+@getitem.register("Tensor", "Tensor")
+def _tensor_getitem_by_tensor(data, tensor_index):
+    """
+    Getting item of tensor by tensor index.
+
+    Inputs:
+        data (Tensor): A tensor.
+        tensor_index (Tensor): An index expressed by tensor.
+
+    Outputs:
+        Tensor, element type is same as the element type of data.
+    """
+    check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64))
+    result = None
+    if check_dtypes:
+        result = F.gather(data, tensor_index, 0)
+    return result
+
+
 @getitem.register("Tensor", "Tuple")
-def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
+def _tensor_getitem_by_tuple(data, tuple_index):
     """
     Getting item of tensor by slice tuple.
 
     Inputs:
         data (Tensor): A tensor.
-        slice_tuple_index (tuple): Index in tuple.
+        tuple_index (tuple): Index in tuple.
 
     Outputs:
         Tensor, element type is same as the element type of data.
     """
-    return _tensor_slice(data, slice_tuple_index)
+    index_types = multi_utils.hyper_map(F.typeof, tuple_index)
+    index_elements_type = multi_utils.tuple_elements_type(index_types)
+    result = None
+    if index_elements_type == multi_utils.NO_TENSOR:
+        result = _tensor_slice(data, tuple_index)
+    if index_elements_type == multi_utils.ALL_TENSOR:
+        result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
+    return result
 
 
 @getitem.register("Tensor", "Ellipsis")
@@ -242,3 +269,10 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
         Tensor, same as data.
     """
     return _tensor_slice(data, ellipsis_index)
+
+
+def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
+    """Tensor getitem by a tuple of tensor."""
+    indices = multi_utils.generate_indices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM)
+    result = F.gather_nd(data, indices)
+    return result
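# Numpy analogue (not part of the patch) of the two getitem paths added above;
# data and indices are illustrative.
import numpy as np

data = np.arange(12).reshape(3, 4)
idx = np.array([2, 0])
print(data[idx])                  # like F.gather(data, idx, 0): rows 2 and 0
rows, cols = np.array([0, 1]), np.array([1, 3])
print(data[rows, cols])           # like F.gather_nd on packed indices: elements (0,1), (1,3)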
""" - result = None index_dtype = F.dtype(index) - index_shape = F.shape(index) - check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype) - if check_result: - data_shape = F.shape(data) - data_shape = mult_util.check_equal(data_shape, index_shape, - "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") - size = F.size(value_tensor) - size = mult_util.check_equal(1, size, - "When assign value is a tensor, its size should be {}, but current size is {}.") - dtype = F.dtype(data) - u_cast = F.cast(value_tensor, dtype) - one_data = F.ones_like(data) - u = F.tensor_mul(one_data, u_cast) - result = F.select(index, u, data) - return result + tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype) + if tensor_dtype == multi_utils.INT_: + return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor) + return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor) @setitem.register("Tensor", "Tensor", "Number") -def _tensor_setitem_by_tensor_v2(data, index, value): +def _tensor_setitem_by_tensor_with_number(data, index, value): """ Tensor assignment. @@ -171,22 +160,128 @@ def _tensor_setitem_by_tensor_v2(data, index, value): Inputs: data (Tensor): Assigned tensor. index (Tensor): Tensor of bool type. - value_tensor (Number): Assignment value. + value (Number): Assignment value. Outputs: Tensor, element type and shape is same as data. """ - result = None index_dtype = F.dtype(index) - index_shape = F.shape(index) - check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype) - if check_result: - shape = F.shape(data) - shape = mult_util.check_equal( - shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.") - dtype = F.dtype(data) - u = F.fill(dtype, shape, value) - result = F.select(index, u, data) + tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype) + if tensor_dtype == multi_utils.BOOL_: + return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value) + return _tensor_setitem_by_int_tensor_with_scalar(data, index, value) + + +@setitem.register("Tensor", "Tuple", "Number") +def _tensor_setitem_by_tuple_with_number(data, tuple_index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[B, C, D] = u. + Restraint condition: 1) A is a Tensor, and B, C, D are index. + 2) u is a scalar. + + Inputs: + data (Tensor): Assigned tensor. + index (Tuple): An index tuple. + value (Number): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + index_types = multi_utils.hyper_map(F.typeof, tuple_index) + index_elements_type = multi_utils.tuple_elements_type(index_types) + result = None + if index_elements_type == multi_utils.NO_TENSOR: + result = _tensor_assgin_number(data, tuple_index, value) + if index_elements_type == multi_utils.ALL_TENSOR: + indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM) + updates = multi_utils.generate_updates_from_scalar(data, indices, value, + multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) + result = F.scatter_nd_update(data, indices, updates) + return result + + +@setitem.register("Tensor", "Tuple", "Tensor") +def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[B, C, D] = U. + Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors. + 2) U is a Tensor. + + Inputs: + data (Tensor): Assigned tensor. + index (Tuple): An index tuple. 
+
+
+@setitem.register("Tensor", "Tuple", "Tensor")
+def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
+    """
+    Tensor assignment.
+
+    Note:
+        Syntax support: A[B, C, D] = U.
+        Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
+                             2) U is a Tensor.
+
+    Inputs:
+        data (Tensor): Assigned tensor.
+        tuple_index (Tuple): An index tuple.
+        value (Tensor): Assignment tensor, should have the same data type as 'data'.
+
+    Outputs:
+        Tensor, element type and shape is same as data.
+    """
+    index_types = multi_utils.hyper_map(F.typeof, tuple_index)
+    index_elements_type = multi_utils.tuple_elements_type(index_types)
+    result = None
+    if index_elements_type == multi_utils.NO_TENSOR:
+        result = _tensor_assgin_tensor(data, tuple_index, value)
+    if index_elements_type == multi_utils.ALL_TENSOR:
+        indices = multi_utils.generate_indices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
+        updates = multi_utils.generate_updates_from_tensor(data, indices, value,
+                                                           multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
+        result = F.scatter_nd_update(data, indices, updates)
+    return result
+
+
+@setitem.register("Tensor", "Tuple", "Tuple")
+def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
+    """
+    Tensor assignment.
+
+    Note:
+        Syntax support: A[B, C, D] = U.
+        Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
+                             2) B, C and D could be broadcast.
+                             3) U is a Tuple.
+
+    Inputs:
+        data (Tensor): Assigned tensor.
+        tuple_index (Tuple): A tuple of tensors; these tensors could be broadcast.
+        value (Tuple): Assignment tuple, elements should have the same data type as 'data'.
+
+    Outputs:
+        Tensor, element type and shape is same as data.
+    """
+    index_types = multi_utils.hyper_map(F.typeof, tuple_index)
+    index_elements_type = multi_utils.tuple_elements_type(index_types)
+    result = None
+    if index_elements_type == multi_utils.ALL_TENSOR:
+        indices = multi_utils.generate_indices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
+        updates = multi_utils.generate_updates_from_tuple(data, indices, value,
+                                                          multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
+        result = F.scatter_nd_update(data, indices, updates)
+    return result
+
+
+@setitem.register("Tensor", "Tensor", "Tuple")
+def _tensor_setitem_by_tensor_v2(data, index, value):
+    """
+    Tensor assignment.
+
+    Inputs:
+        data (Tensor): Assigned tensor.
+        index (Tensor): Tensor of int type.
+        value (Tuple): Assignment value.
+
+    Outputs:
+        Tensor, element type and shape is same as data.
+    """
+    index_dtype = F.dtype(index)
+    check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64))
+    result = None
+    if check_dtype:
+        result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
     return result
 
 
@@ -212,66 +307,6 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value):
     return _tensor_assgin_tensor(data, input_slice, value)
 
 
-@setitem.register("Tensor", "Tuple", "Tensor")
-def _tensor_setitem_with_slice_v4(data, input_slice, value):
-    """
-    Tensor assignment.
-
-    Note:
-        Syntax support: A[tuple(Slice)] = U, and A[tuple(Number)] = U
-        Restraint condition: A is a Tensor
-                             Slice like "1:3, ::, :4:-1"
-                             U is a Tensor(size=1) or Tensor(size>1)
-
-    Inputs:
-        data (Tensor): Assigned tensor.
-        input_slice (Union[tuple[Slice], tuple[Number]]): Slice expression.
-        value (Number): Assignment value.
-
-    Outputs:
-        Tensor, element type and shape is same as data.
- """ - return _tensor_assgin_tensor(data, input_slice, value) - - -def _tensor_assgin_tensor(data, input_slice, value): - """Assigns a tensor value to the tensor by slice.""" - result = None - check_result = mult_util.check_tensor_setitem_index(input_slice) - if check_result: - data_shape = F.shape(data) - indices = mult_util.slice2indices(input_slice, data_shape) - is_tuple_int = mult_util.tuple_element_is_int(input_slice) - if is_tuple_int: - indices = mult_util.integer_to_indices(input_slice, data_shape) - result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) - return result - - -def _tensor_indices_tensor(data, data_shape, index, indices, value): - """Assigns a tensor value to the tensor.""" - data_size = F.size(data) - data_dtype = F.dtype(data) - indices_size = F.size(indices) - indices_size = mult_util.check_indices(indices_size, index) - update = F.fill(mstype.int32, (indices_size,), 1) - condition_1d = F.scatter_nd(indices, update, (data_size,)) - condition = F.reshape(condition_1d, data_shape) - condition = F.cast(condition, mstype.bool_) - value_fill = None - value_size = F.size(value) - - value_size = mult_util.check_indices_value_size(indices_size, value_size) - if value_size == 1: - value_fill = F.fill(data_dtype, (indices_size,), 1) - value = F.cast(value, data_dtype) - value_fill = F.tensor_mul(value_fill, value) - elif value_size > 1: - value_fill = F.reshape(value, (indices_size,)) - value_1d = F.scatter_nd(indices, value_fill, (data_size,)) - u = F.reshape(value_1d, data_shape) - return F.select(condition, u, data) - @setitem.register("Tensor", "Slice", "Number") def _tensor_setitem_with_slice_v1(data, input_slice, value): """ @@ -294,63 +329,25 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value): return _tensor_assgin_number(data, input_slice, value) -@setitem.register("Tensor", "Tuple", "Number") -def _tensor_setitem_with_slice_v2(data, input_slice, value): - """ - Tensor assignment. - - Note: - Syntax support: A[tuple(Slice)] = u, and A[tuple(Number)] = u - Restraint condition: A is a Tensor. - Slice like "1:3, ::, :4:-1" - u is a scalar - - Inputs: - data (Tensor): Assigned tensor. - input_slice (Union[tuple[Slice], tuple[Number]]): slice expression. - value (Number): Assignment value. - - Outputs: - Tensor, element type and shape is same as data. 
- """ - return _tensor_assgin_number(data, input_slice, value) - - def _tensor_assgin_number(data, input_slice, value): """Givens a scalar assign to tensor by slice""" - check_result = mult_util.check_tensor_setitem_index(input_slice) + check_result = multi_utils.check_tensor_setitem_index(input_slice) result = None if check_result: data_shape = F.shape(data) - indices = mult_util.slice2indices(input_slice, data_shape) - is_tuple_int = mult_util.tuple_element_is_int(input_slice) + indices = multi_utils.slice2indices(input_slice, data_shape) + is_tuple_int = multi_utils.tuple_element_is_int(input_slice) if is_tuple_int: - indices = mult_util.integer_to_indices(input_slice, data_shape) + indices = multi_utils.integer_to_indices(input_slice, data_shape) result = _tensor_indices_number(data, data_shape, input_slice, indices, value) return result -def _tensor_indices_number(data, data_shape, index, indices, value): - """Assigns a scalar value to the tensor.""" - data_size = F.size(data) - data_dtype = F.dtype(data) - indices_size = F.size(indices) - indices_size = mult_util.check_indices(indices_size, index) - update = F.fill(mstype.int32, (indices_size,), 1) - condition_1d = F.scatter_nd(indices, update, (data_size,)) - condition = F.reshape(condition_1d, data_shape) - condition = F.cast(condition, mstype.bool_) - value_fill = F.fill(data_dtype, (indices_size,), value) - value_1d = F.scatter_nd(indices, value_fill, (data_size,)) - u = F.reshape(value_1d, data_shape) - return F.select(condition, u, data) - - @setitem.register("Tensor", "Number", "Number") def _tensor_setitem_with_int_v1(data, index, value): """Syntax: A[1] = 3""" data_shape = F.shape(data) - indices = mult_util.integer_to_indices(index, data_shape) + indices = multi_utils.integer_to_indices(index, data_shape) return _tensor_indices_number(data, data_shape, index, indices, value) @@ -358,7 +355,7 @@ def _tensor_setitem_with_int_v1(data, index, value): def _tensor_setitem_with_int_v2(data, index, value): """Syntax: A[1] = Tensor""" data_shape = F.shape(data) - indices = mult_util.integer_to_indices(index, data_shape) + indices = multi_utils.integer_to_indices(index, data_shape) return _tensor_indices_tensor(data, data_shape, index, indices, value) @@ -379,7 +376,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value): data_size = F.size(data) value_shape = F.shape(value) value_size = F.size(value) - check_result = mult_util.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) + check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) if check_result: if data_size == value_size: result = F.reshape(value, data_shape) @@ -389,3 +386,108 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value): param2 = F.cast(value, data_dtype) result = F.tensor_mul(param1, param2) return result + + +def _tensor_assgin_tensor(data, input_slice, value): + """Assigns a tensor value to the tensor by slice.""" + result = None + check_result = multi_utils.check_tensor_setitem_index(input_slice) + if check_result: + data_shape = F.shape(data) + indices = multi_utils.slice2indices(input_slice, data_shape) + is_tuple_int = multi_utils.tuple_element_is_int(input_slice) + if is_tuple_int: + indices = multi_utils.integer_to_indices(input_slice, data_shape) + result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value) + return result + + +def _tensor_indices_tensor(data, data_shape, index, indices, value): + """Assigns a tensor value to the tensor.""" + data_size = 
F.size(data)
+    data_dtype = F.dtype(data)
+    indices_size = F.size(indices)
+    indices_size = multi_utils.check_indices(indices_size, index)
+    update = F.fill(mstype.int32, (indices_size,), 1)
+    condition_1d = F.scatter_nd(indices, update, (data_size,))
+    condition = F.reshape(condition_1d, data_shape)
+    condition = F.cast(condition, mstype.bool_)
+    value_fill = None
+    value_size = F.size(value)
+
+    value_size = multi_utils.check_indices_value_size(indices_size, value_size)
+    if value_size == 1:
+        value_fill = F.fill(data_dtype, (indices_size,), 1)
+        value = F.cast(value, data_dtype)
+        value_fill = F.tensor_mul(value_fill, value)
+    elif value_size > 1:
+        value_fill = F.reshape(value, (indices_size,))
+    value_1d = F.scatter_nd(indices, value_fill, (data_size,))
+    u = F.reshape(value_1d, data_shape)
+    return F.select(condition, u, data)
+
+
+def _tensor_indices_number(data, data_shape, index, indices, value):
+    """Assigns a scalar value to the tensor."""
+    data_size = F.size(data)
+    data_dtype = F.dtype(data)
+    indices_size = F.size(indices)
+    indices_size = multi_utils.check_indices(indices_size, index)
+    update = F.fill(mstype.int32, (indices_size,), 1)
+    condition_1d = F.scatter_nd(indices, update, (data_size,))
+    condition = F.reshape(condition_1d, data_shape)
+    condition = F.cast(condition, mstype.bool_)
+    value_fill = F.fill(data_dtype, (indices_size,), value)
+    value_1d = F.scatter_nd(indices, value_fill, (data_size,))
+    u = F.reshape(value_1d, data_shape)
+    return F.select(condition, u, data)
+
+
+def _tensor_setitem_by_tensor_with_tuple(data, index, value):
+    """Set a tensor item by a tensor with a tuple."""
+    updates = multi_utils.generate_updates_from_tuple(data, index, value,
+                                                      multi_utils.SET_ITEM_BY_ONE_TENSOR)
+    result = F.scatter_update(data, index, updates)
+    return result
+
+
+def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
+    """Set a tensor item by an int tensor with a scalar."""
+    updates = multi_utils.generate_updates_from_scalar(data, index, value,
+                                                       multi_utils.SET_ITEM_BY_ONE_TENSOR)
+    return F.scatter_update(data, index, updates)
+
+
+def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
+    """Set a tensor item by a bool tensor with a scalar."""
+    index_shape = F.shape(index)
+    shape = F.shape(data)
+    shape = multi_utils.check_equal(
+        shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
+    dtype = F.dtype(data)
+    u = F.fill(dtype, shape, value)
+    return F.select(index, u, data)
+
+
+def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
+    """Set a tensor item by an int tensor with a tensor."""
+    updates = multi_utils.generate_updates_from_tensor(data, index, value,
+                                                       multi_utils.SET_ITEM_BY_ONE_TENSOR)
+    return F.scatter_update(data, index, updates)
+
+
+def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
+    """Set a tensor item by a bool tensor with a tensor."""
+    index_shape = F.shape(index)
+    data_shape = F.shape(data)
+    data_shape = multi_utils.check_equal(data_shape, index_shape,
+                                         "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
+    size = F.size(value)
+    size = multi_utils.check_equal(1, size,
+                                   "When assign value is a tensor, its size should be {}, but current size is {}.")
+    dtype = F.dtype(data)
+    u_cast = F.cast(value, dtype)
+    one_data = F.ones_like(data)
+    u = F.tensor_mul(one_data, u_cast)
+    result = F.select(index, u, data)
+    return result
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index
f7e014d8a5b..6559d9b2ab4 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -31,6 +31,7 @@ dtype = P.DType() issubclass_ = P.IsSubClass() isinstance_ = P.IsInstance() fill = P.Fill() +tile = P.Tile() select = P.Select() size = P.Size() ones_like = P.OnesLike() @@ -70,6 +71,12 @@ scalar_cast = P.ScalarCast() print_ = P.Print() expand_dims = P.ExpandDims() scatter_nd = P.ScatterNd() +gather = P.GatherV2() +gather_nd = P.GatherNd() +scatter_update = P.ScatterUpdate() +scatter_nd_update = P.ScatterNdUpdate() +pack = P.Pack() + tuple_setitem = Primitive('tuple_setitem') tuple_getitem = Primitive('tuple_getitem') diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index e3e59cdfcbc..4adf38d3c02 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Fill, GatherNd, GatherV2, InvertPermutation, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, - SameTypeShape, ScatterMax, + SameTypeShape, ScatterMax, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, Shape, Size, Slice, Split, Squeeze, StridedSlice, Tile, @@ -193,6 +193,7 @@ __all__ = [ 'Pad', 'MirrorPad', 'GatherNd', + 'ScatterUpdate', 'ScatterNdUpdate', 'Floor', 'NMSWithMask', diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 6f16c152bc6..51bbac63215 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -19,7 +19,7 @@ from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_kind as sig_kind from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register from ..._checkparam import Validator as validator, Rel -from .._utils import _get_concat_offset +from .._utils import get_concat_offset from ...common import dtype as mstype @@ -136,7 +136,7 @@ class ConcatOffset(PrimitiveWithInfer): axis = self.axis x_shp = input_x['shape'] x_type = input_x['dtype'] - offset, _, axis = _get_concat_offset(x_shp, x_type, axis, self.name) + offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name) self.add_prim_attr('T', x_type[0].element_type()) offset_values = [] for i in range(len(x_shp)): diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index e52b2ac3da4..45e04b83f2a 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -24,16 +24,15 @@ import itertools import numbers import numpy as np -from ..._c_expression import signature_rw as sig_rw -from ..._c_expression import signature_kind as sig_kind from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype from ...common.tensor import Tensor from ..operations.math_ops import _infer_shape_reduce -from .._utils import _get_concat_offset +from .._utils import get_concat_offset from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register + def _check_infer_attr_reduce(axis, keep_dims, prim_name): validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) validator.check_value_type('axis', axis, [int, tuple], prim_name) @@ -931,7 +930,7 @@ class InvertPermutation(PrimitiveWithInfer): z = [x_value[i] for i in range(len(x_value))] z.sort() - y = [None]*len(x_value) + y = [None] * len(x_value) for i, value in enumerate(x_value): 
validator.check_value_type("input[%d]" % i, value, [int], self.name) validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name) @@ -1111,6 +1110,7 @@ class ArgMinWithValue(PrimitiveWithInfer): >>> input_x = Tensor(np.random.rand(5)) >>> index, output = P.ArgMinWithValue()(input_x) """ + @prim_attr_register def __init__(self, axis=0, keep_dims=False): """init ArgMinWithValue""" @@ -1352,7 +1352,7 @@ class Concat(PrimitiveWithInfer): axis = self.axis x_shp = input_x['shape'] x_type = input_x['dtype'] - _, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name) + _, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name) self.add_prim_attr('T', x_type[0].element_type()) self.add_prim_attr('inputNums', len(x_shp)) ret_shp = x_shp[0].copy() @@ -1376,15 +1376,13 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name): if axis < 0: axis = axis + rank_base + 1 for i in range(1, N): - v = x_shape[i] - validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base, Rel.EQ, prim_name) validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name, TypeError) - for j in range(rank_base): - if v[j] != x_shape[0][j]: - raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element") + if x_shape[i] != x_shape[0]: + raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element") out_shape.insert(axis, N) return out_shape + class Pack(PrimitiveWithInfer): r""" Packs a list of tensors in specified axis. @@ -1831,7 +1829,7 @@ class DiagPart(PrimitiveWithInfer): return x_type def infer_shape(self, x_shape): - if len(x_shape)%2 != 0 or \ + if len(x_shape) % 2 != 0 or \ not x_shape: raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, " f"with shapes {x_shape}") @@ -2004,6 +2002,49 @@ class GatherNd(PrimitiveWithInfer): return x_dtype +class ScatterUpdate(PrimitiveWithInfer): + """ + Update tensor value by using input indices and value. + + Using given values to update tensor value, along with the input indices. + + Args: + use_locking (bool): Whether protect the assignment by a lock. Default: True. + + Inputs: + - **input_x** (Parameter) - The target tensor, with data type of Parameter. + - **indices** (Tensor) - The index of input tensor. + - **update** (Tensor) - The tensor to update the input tensor, has the same type as input, + and update.shape = indices.shape + input_x.shape[1:]. + + Outputs: + Tensor, has the same shape and type as `input_x`. 
 class ScatterNdUpdate(PrimitiveWithInfer):
     """
     Update tensor value by using input indices and value.
@@ -2028,11 +2069,6 @@ class ScatterNdUpdate(PrimitiveWithInfer):
         >>> op = P.ScatterNdUpdate()
         >>> output = op(input_x, indices, update)
     """
-    __mindspore_signature__ = (
-        ('input_x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
-        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
-        ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
-    )

     @prim_attr_register
     def __init__(self, use_locking=True):
@@ -2142,10 +2178,10 @@ class SpaceToDepth(PrimitiveWithInfer):
         validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
         out_shape = copy.deepcopy(x_shape)
         for i in range(2):
-            if out_shape[i+2] % self.block_size != 0:
-                raise ValueError(f'For \'{self.name}\' input shape[{i+2}] {out_shape[i+2]} should be '
+            if out_shape[i + 2] % self.block_size != 0:
+                raise ValueError(f'For \'{self.name}\' input shape[{i + 2}] {out_shape[i + 2]} should be '
                                  f'fully divided by block_size {self.block_size}')
-            out_shape[i+2] //= self.block_size
+            out_shape[i + 2] //= self.block_size
         out_shape[1] *= self.block_size * self.block_size
         return out_shape
@@ -2199,9 +2235,10 @@ class DepthToSpace(PrimitiveWithInfer):
         validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
         out_shape = copy.deepcopy(x_shape)
         for i in range(2):
-            out_shape[i+2] *= self.block_size
+            out_shape[i + 2] *= self.block_size

-        validator.check_integer('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size),
+        validator.check_integer('x_shape[1] % (block_size*block_size)',
+                                x_shape[1] % (self.block_size * self.block_size),
                                 0, Rel.EQ, self.name)
         out_shape[1] //= self.block_size * self.block_size
         return out_shape
@@ -2251,6 +2288,7 @@ class SpaceToBatch(PrimitiveWithInfer):
         [[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]

     """
+
     @prim_attr_register
     def __init__(self, block_size, paddings):
         """Init SpaceToBatch"""
@@ -2271,12 +2309,12 @@ class SpaceToBatch(PrimitiveWithInfer):
         validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name)
         out_shape = copy.deepcopy(x_shape)
         for i in range(2):
-            padded = out_shape[i+2] + self.paddings[i][0] + \
+            padded = out_shape[i + 2] + self.paddings[i][0] + \
                      self.paddings[i][1]
             if padded % self.block_size != 0:
                 raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
                                  f'block_size {self.block_size}')
-            out_shape[i+2] = padded // self.block_size
+            out_shape[i + 2] = padded // self.block_size
         out_shape[0] *= self.block_size * self.block_size
         return out_shape
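
Since several of these hunks touch the spatial-shape arithmetic, here is a quick standalone check of the SpaceToBatch output-shape formula used above, a sketch with made-up sizes rather than library code:

def space_to_batch_shape(x_shape, block_size, paddings):
    # Pad the two spatial dims, divide them by block_size,
    # and fold block_size**2 into the batch dim.
    out = list(x_shape)
    for i in range(2):
        padded = out[i + 2] + paddings[i][0] + paddings[i][1]
        assert padded % block_size == 0
        out[i + 2] = padded // block_size
    out[0] *= block_size * block_size
    return out

assert space_to_batch_shape([1, 3, 2, 2], 2, [[0, 0], [0, 0]]) == [4, 3, 1, 1]
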
@@ -2319,6 +2357,7 @@ class BatchToSpace(PrimitiveWithInfer):
         [[[[1., 2.], [3., 4.]]]]

     """
+
     @prim_attr_register
     def __init__(self, block_size, crops):
         """Init BatchToSpace"""
@@ -2339,10 +2378,10 @@ class BatchToSpace(PrimitiveWithInfer):
         validator.check('rank of input_x', len(x_shape), '', 4)
         out_shape = copy.deepcopy(x_shape)
         for i in range(2):
-            x_block_prod = out_shape[i+2] * self.block_size
+            x_block_prod = out_shape[i + 2] * self.block_size
             crops_sum = self.crops[i][0] + self.crops[i][1]
             validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
-            out_shape[i+2] = x_block_prod - crops_sum
+            out_shape[i + 2] = x_block_prod - crops_sum
         block_size_prod = self.block_size * self.block_size
         if out_shape[0] % block_size_prod != 0:
             raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index 127bb83219f..5a116f64e05 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -24,7 +24,7 @@ from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
 from ...common.tensor import Tensor
-from .._utils import _get_broadcast_shape
+from .._utils import get_broadcast_shape
 from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
@@ -75,7 +75,7 @@ class _BinaryOp(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

     def infer_shape(self, x_shape, y_shape):
-        return _get_broadcast_shape(x_shape, y_shape, self.name)
+        return get_broadcast_shape(x_shape, y_shape, self.name)


 class _MathBinaryOp(_BinaryOp):
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 3509fb1d02d..846b6d79fa7 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -1297,7 +1297,7 @@ raise_set = [
     ('ScatterNdUpdate', {
         'block': (P.ScatterNdUpdate(), {'exception': TypeError}),
         'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
-                        Tensor(np.ones((2, 2), np.int32)),
+                        Tensor(np.ones((2, 2), np.float32)),
                         Tensor(np.ones((2,), np.float32))),
         'desc_bprop': [[2, 3]]}),
     ('Pack', {
diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py
index d07f1a35a94..6b4c84f14a3 100644
--- a/tests/ut/python/ops/test_tensor_slice.py
+++ b/tests/ut/python/ops/test_tensor_slice.py
@@ -16,13 +16,14 @@ import numpy as np
 import pytest

-from mindspore import Tensor
+from mindspore import Tensor, Parameter
 from mindspore import context
 from mindspore import dtype as mstype
 from mindspore.nn import Cell
 from ....mindspore_test_framework.mindspore_test import mindspore_test
 from ....mindspore_test_framework.pipeline.forward.compile_forward \
-    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
+    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
+    pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception


 class NetWorkSlicePositive(Cell):
@@ -145,6 +146,160 @@ class TensorAssignWithSlice(Cell):
         return z


+class TensorIndexByOneTensor(Cell):
+    def __init__(self):
+        super(TensorIndexByOneTensor, self).__init__()
+        self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32)
+
+    def construct(self, x, index):
+        ret = x[index] + self.const
+        return ret
+
+
+class TensorIndexByTwoTensors(Cell):
+    def __init__(self):
+        super(TensorIndexByTwoTensors, self).__init__()
+        self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32)
+
+    def construct(self, x, index_0, index_1):
+        ret = x[index_0, index_1] + self.const
+        return ret
+
+
+class TensorIndexByThreeTensors(Cell):
+    def __init__(self):
+        super(TensorIndexByThreeTensors, self).__init__()
+        self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32)
+
+    def construct(self, x, index_0, index_1, index_2):
+        ret = x[index_0, index_1, index_2] + self.const
+        return ret
+
+
+class TensorSetItemByOneTensorWithNumber(Cell):
+    def __init__(self, value):
+        super(TensorSetItemByOneTensorWithNumber, self).__init__()
+        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
+        self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
+        self.value = value
+
+    def construct(self, index):
+        self.param[index] = self.value
+        ret = self.param + self.const
+        return ret
+
+
+class TensorSetItemByOneTensorWithTensor(Cell):
+    def __init__(self):
+        super(TensorSetItemByOneTensorWithTensor, self).__init__()
+        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
+        self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
+
+    def construct(self, index, value):
+        self.param[index] = value
+        ret = self.param + self.const
+        return ret
+
+
+class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
+    def __init__(self, value):
+        super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
+        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
+        self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
+        self.value = value
+
+    def construct(self, index):
+        self.param[index] = self.value
+        ret = self.param + self.const
+        return ret
+
+
+class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
+    def __init__(self):
+        super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
+        self.const = Tensor(np.ones((6, 3, 8)), mstype.float32)
+        self.param = Parameter(Tensor(np.arange(6*3*8).reshape((6, 3, 8)), mstype.float32), name="x")
+
+    def construct(self, index, value_0, value_1, value_2):
+        self.param[index] = (value_0, value_1, value_2)
+        ret = self.param + self.const
+        return ret
+
+
+class TensorSetItemByTensorsWithNumber(Cell):
+    def __init__(self, value):
+        super(TensorSetItemByTensorsWithNumber, self).__init__()
+        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
+        self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
+        self.value = value
+
+    def construct(self, index_0, index_1, index_2):
+        self.param[index_0, index_1, index_2] = self.value
+        ret = self.param + self.const
+        return ret
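
The new Cells above lean on NumPy-style fancy indexing; a sketch of the result shapes they assume (pure NumPy, matching the test data used below):

import numpy as np

x = np.arange(6 * 7 * 8).reshape(6, 7, 8)

# x[index]: gather along axis 0; result shape = index.shape + x.shape[1:]
index = np.random.randint(6, size=(5, 4))
assert x[index].shape == (5, 4, 7, 8)

# x[i0, i1]: index arrays broadcast together; result = broadcast shape + x.shape[2:]
i0 = np.random.randint(6, size=(3, 4, 5))
i1 = np.random.randint(7, size=(4, 5))
assert x[i0, i1].shape == (3, 4, 5, 8)

# setitem with a scalar broadcasts the value over the selected elements
x[index] = 0
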
+
+
+class TensorSetItemByTensorsWithTensor(Cell):
+    def __init__(self):
+        super(TensorSetItemByTensorsWithTensor, self).__init__()
+        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
+        self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
+
+    def construct(self, index_0, index_1, index_2, value):
+        self.param[index_0, index_1, index_2] = value
+        ret = self.param + self.const
+        return ret
+
+
+class TensorSetItemByTensorsWithTensorNumberError(Cell):
+    def __init__(self):
+        super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
+        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
+        self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
+
+    def construct(self, index_0, index_1, index_2, index_3, value):
+        self.param[index_0, index_1, index_2, index_3] = value
+        ret = self.param + self.const
+        return ret
+
+
+class TensorSetItemByTensorsWithTupleOfNumber(Cell):
+    def __init__(self, value):
+        super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
+        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
+        self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
+        self.value = value
+
+    def construct(self, index_0, index_1, index_2):
+        self.param[index_0, index_1, index_2] = self.value
+        ret = self.param + self.const
+        return ret
+
+
+class TensorSetItemByTensorsWithTupleOfTensor(Cell):
+    def __init__(self):
+        super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
+        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
+        self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
+
+    def construct(self, index_0, index_1, index_2, value_0, value_1, value_2):
+        self.param[index_0, index_1, index_2] = (value_0, value_1, value_2)
+        ret = self.param + self.const
+        return ret
+
+
+class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
+    def __init__(self):
+        super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
+        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
+        self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
+
+    def construct(self, index_0, index_1, index_2, value_0, value_1):
+        self.param[index_0, index_1, index_2] = (value_0, value_1)
+        ret = self.param + self.const
+        return ret
+
+
 def test_tensor_assign():
     context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
     net = TensorAssignWithSlice()
@@ -441,15 +596,206 @@ test_cases = [
         'block': NetWorkSliceEllipsis(),
         'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
     }),
+    ('TensorIndexByOneTensor', {
+        'block': TensorIndexByOneTensor(),
+        'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
+                        Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)],
+    }),
+    ('TensorIndexByTwoTensors', {
+        'block': TensorIndexByTwoTensors(),
+        'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
+                        Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
+    }),
+    ('TensorIndexByThreeTensors', {
+        'block': TensorIndexByThreeTensors(),
+        'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
+                        Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
+                        Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
+    }),
+    ('TensorSetItemByOneTensorWithNumber', {
+        'block': TensorSetItemByOneTensorWithNumber(value=0.0),
+        'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
+    }),
+    ('TensorSetItemByOneTensorWithTensor', {
+        'block': TensorSetItemByOneTensorWithTensor(),
+        'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
+                        Tensor(np.zeros((4, 7, 8)), mstype.float32)],
+    }),
+    ('TensorSetItemByOneTensorWithTupleOfNumber', {
+        'block': TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)),
+        'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
+    }),
+    ('TensorSetItemByOneTensorWithTupleOfTensor', {
+        'block': TensorSetItemByOneTensorWithTupleOfTensor(),
+        'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
+                        Tensor(np.zeros((8,), np.float32)),
+                        Tensor(np.ones((8,), np.float32)),
+                        Tensor(np.ones((8,), np.float32) * 2)],
+    }),
+    ('TensorSetItemByTensorsWithNumber', {
+        'block': TensorSetItemByTensorsWithNumber(value=0.0),
+        'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
+                        Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
+    }),
+    ('TensorSetItemByTensorsWithTensor', {
+        'block': TensorSetItemByTensorsWithTensor(),
+        'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
+                        Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
+                        Tensor(np.zeros((4, 5)), mstype.float32)],
+    }),
+    ('TensorSetItemByTensorsWithTupleOfNumber', {
+        'block': TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4)),
+        'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
+                        Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
+    }),
+    ('TensorSetItemByTensorsWithTupleOfTensor', {
+        'block': TensorSetItemByTensorsWithTupleOfTensor(),
+        'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
+                        Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
+                        Tensor(np.zeros((4, 5)), mstype.float32),
+                        Tensor(np.ones((4, 5)), mstype.float32),
+                        Tensor(np.ones((4, 5)) * 2, mstype.float32)],
+    })
 ]
+
+raise_error_set = [
+    ('TensorIndexByOneTensorDtypeError', {
+        'block': (TensorIndexByOneTensor(), {'exception': TypeError}),
+        'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
+                        Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
+    }),
+    ('TensorIndexByTwoTensorsShapeError', {
+        'block': (TensorIndexByTwoTensors(), {'exception': ValueError}),
+        'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
+                        Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
+    }),
+    ('TensorIndexByTwoTensorsDtypeError', {
+        'block': (TensorIndexByTwoTensors(), {'exception': TypeError}),
+        'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
+                        Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
+    }),
+    ('TensorIndexByThreeTensorsShapeError', {
+        'block': (TensorIndexByThreeTensors(), {'exception': ValueError}),
+        'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
+                        Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
+    }),
+    ('TensorIndexByThreeTensorsDtypeError', {
+        'block': (TensorIndexByThreeTensors(), {'exception': TypeError}),
+        'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
+                        Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64),
+                        Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
+    }),
+    ('TensorSetItemByOneTensorWithNumberTypeError', {
+        'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
+        'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
+    }),
+    ('TensorSetItemByOneTensorWithTensorShapeError', {
+        'block': (TensorSetItemByOneTensorWithTensor(), {'exception': ValueError}),
+        'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
+                        Tensor(np.zeros((6, 7, 8)), mstype.float32)],
+    }),
+    ('TensorSetItemByOneTensorWithTensorDtypeError', {
+        'block': (TensorSetItemByOneTensorWithTensor(), {'exception': TypeError}),
+        'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
+                        Tensor(np.zeros((6, 7, 8)), mstype.int32)],
+    }),
+    ('TensorSetItemByOneTensorWithTupleOfNumberTypeError', {
+        'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0, 1, 2, 3, 4, 5, 6, 7)), {'exception': TypeError}),
+        'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
+    }),
+    ('TensorSetItemByOneTensorWithTupleOfNumberNumberError', {
+        'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2)), {'exception': ValueError}),
+        'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
+    }),
+    ('TensorSetItemByOneTensorWithTupleOfTensorDtypeError', {
+        'block': (TensorSetItemByOneTensorWithTupleOfTensor(), {'exception': TypeError}),
+        'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
+                        Tensor(np.zeros((8,), np.int32)),
+                        Tensor(np.ones((8,), np.int32)),
+                        Tensor(np.ones((8,), np.float32) * 2)],
+    }),
+    ('TensorSetItemByTensorsWithNumberTypeError', {
+        'block': (TensorSetItemByTensorsWithNumber(value=0), {'exception': TypeError}),
+        'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
+                        Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
+    }),
+    ('TensorSetItemByTensorsWithTensorShapeError', {
+        'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}),
+        'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
+                        Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
+                        Tensor(np.zeros((2, 5)), mstype.float32)],
+    }),
+    ('TensorSetItemByTensorsWithTensorTypeError', {
+        'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}),
+        'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
+                        Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
+                        Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
+                        Tensor(np.zeros((4, 5)), mstype.int32)],
+    }),
Tensor(np.zeros((2, 5)), mstype.float32)], + }), + ('TensorSetItemByTensorsWithTupleOfNumberTypeError', { + 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}), + 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], + }), + ('TensorSetItemByTensorsWithTupleOfNumberNumberError', { + 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}), + 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)], + }), + ('TensorSetItemByTensorsWithTupleOfTensorNumberError', { + 'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}), + 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), + Tensor(np.zeros((4, 5)), mstype.float32), + Tensor(np.ones((4, 5)), mstype.float32)], + }), + ('TensorSetItemByTensorsWithTupleOfTensorTypeError', { + 'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}), + 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32), + Tensor(np.random.randint(7, size=(4, 5)), mstype.int32), + Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32), + Tensor(np.zeros((4, 5)), mstype.float32), + Tensor(np.ones((4, 5)), mstype.int32), + Tensor(np.ones((4, 5)) * 2, mstype.int32)], + }) ] @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config) -def test_compile(): - context.set_context(mode=context.GRAPH_MODE) +def test_exec(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) return test_cases +@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) +def test_check_exception(): + return raise_error_set + + def test_tensor_slice_reduce_out_of_bounds_neg(): class NetWork(Cell): def __init__(self): diff --git a/tests/ut/python/optimizer/test_debug_location.py b/tests/ut/python/optimizer/test_debug_location.py index 78486c7a6c3..2848b4771be 100644 --- a/tests/ut/python/optimizer/test_debug_location.py +++ b/tests/ut/python/optimizer/test_debug_location.py @@ -26,7 +26,7 @@ from mindspore.ops import functional as F from mindspore.ops import operations as P from mindspore.ops._grad.grad_base import bprop_getters from mindspore.ops._grad.grad_math_ops import binop_grad_common -from mindspore.ops._utils import _get_broadcast_shape +from mindspore.ops._utils import get_broadcast_shape from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register from mindspore.train.loss_scale_manager import DynamicLossScaleManager @@ -54,7 +54,7 @@ class MockSub(PrimitiveWithInfer): self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) def infer_shape(self, x_shape, y_shape): - return _get_broadcast_shape(x_shape, y_shape) + return get_broadcast_shape(x_shape, y_shape) def infer_dtype(self, x_dtype, y_dtype): return x_dtype