diff --git a/mindspore/_extends/parse/__init__.py b/mindspore/_extends/parse/__init__.py index 7b8dccee9f2..323932560ae 100644 --- a/mindspore/_extends/parse/__init__.py +++ b/mindspore/_extends/parse/__init__.py @@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope, get_dataclass_attributes, get_dataclass_methods, get_obj_id, get_module_namespace, get_obj_type, get_object_key, get_default_input, get_parse_method_of_class, get_scope_name, - is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj) + is_class_member, parse_cb, resolve_symbol) from .serialize import * __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', @@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name', - 'create_slice_obj', 'create_ellipsis_obj'] + 'create_slice_obj'] diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py index f556ce0d64e..2a1c9e0943a 100644 --- a/mindspore/_extends/parse/parser.py +++ b/mindspore/_extends/parse/parser.py @@ -29,7 +29,6 @@ from mindspore.common.dtype import pytype_to_dtype from mindspore.common.api import _MindSporeFunction from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT -from ..utils import Slice, Ellipsis_ # define return value RET_SUCCESS = 0 @@ -70,14 +69,9 @@ parse_expr_statement_white_list = ( "append", ) -def create_ellipsis_obj(): - """Create Slice object""" - return Ellipsis_() - - def create_slice_obj(start, end, step): - """Create Slice object""" - return Slice(start, end, step) + """Create slice object""" + return slice(start, end, step) def parse_cb(func, parse_method=None): diff --git a/mindspore/_extends/utils.py b/mindspore/_extends/utils.py index fecbf546f58..8469ddda8b8 100644 --- a/mindspore/_extends/utils.py +++ b/mindspore/_extends/utils.py @@ -19,7 +19,6 @@ import logging import os import inspect from functools import wraps -from dataclasses import dataclass def cal_sha256(file_path): @@ -100,20 +99,3 @@ def cell_attr_register(fn=None, attrs=None): if fn is not None: return wrap_cell(fn) return wrap_cell - - -@dataclass -class Slice: - """ - Slice class - """ - start: int - end: int - step: int - - -@dataclass -class Ellipsis_: - """ - Ellipsis class - """ diff --git a/mindspore/ccsrc/operator/composite/composite.cc b/mindspore/ccsrc/operator/composite/composite.cc index 88c326436d5..221d2b9aac8 100644 --- a/mindspore/ccsrc/operator/composite/composite.cc +++ b/mindspore/ccsrc/operator/composite/composite.cc @@ -932,206 +932,6 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_ return ret; } -int ConvertBinaryToDecimal(const std::vector &number_bin) { - unsigned int number_dec = 0; - for (size_t index = 0; index < number_bin.size(); index++) { - number_dec |= number_bin[index] << index; - } - return static_cast(number_dec); -} - -void ParseSlice(const AbstractSlicePtr &slice, std::vector *begin, std::vector *end, - std::vector *strides, int length) { - MS_EXCEPTION_IF_NULL(slice); - MS_EXCEPTION_IF_NULL(begin); - MS_EXCEPTION_IF_NULL(end); - MS_EXCEPTION_IF_NULL(strides); - if (length <= 0) { - MS_LOG(EXCEPTION) << "Could not slice a dim when it's length less than 1"; - } - - int start_default = 0; - int stop_default = length; - int step_default = 1; - int step_value = CheckSliceMember(slice->step(), step_default, "step"); - if (step_value < 0) { - start_default = -1; - stop_default = -(length + 1); - } - - begin->push_back(CheckSliceMember(slice->start(), start_default, "begin")); - end->push_back(CheckSliceMember(slice->stop(), stop_default, "stop")); - strides->push_back(step_value); -} - -int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector &shape, - std::vector *begin, std::vector *end, std::vector *strides) { - MS_EXCEPTION_IF_NULL(slice_tuple); - MS_EXCEPTION_IF_NULL(begin); - MS_EXCEPTION_IF_NULL(end); - MS_EXCEPTION_IF_NULL(strides); - - size_t slice_tuple_size = slice_tuple->size(); - size_t shape_size = shape.size(); - if (slice_tuple_size > shape_size) { - MS_LOG(EXCEPTION) << "The number of slice data to slice tensor should be less than the rank of tensor," - "when the rank of tensor is " - << shape_size << ", the number of slice is " << slice_tuple_size; - } - - std::vector shrink; - auto slice_tuple_eles = slice_tuple->elements(); - size_t ellipsis_num = 0; - - for (size_t index = 0; index < slice_tuple_size; index++) { - if (slice_tuple_eles[index]->isa()) { - AbstractSlicePtr slice = dyn_cast(slice_tuple_eles[index]); - ParseSlice(slice, begin, end, strides, shape[index]); - shrink.push_back(0); - continue; - } - - if (slice_tuple_eles[index]->isa()) { - int ele_index = GetArgScalarValue(dyn_cast(slice_tuple_eles[index]), "slice_tuple"); - begin->push_back(ele_index); - end->push_back(ele_index + 1); - strides->push_back(1); - shrink.push_back(1); - continue; - } - - if (slice_tuple_eles[index]->isa()) { - ellipsis_num++; - if (ellipsis_num > 1) { - MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis"; - } - size_t ellipsis_len = shape_size - (slice_tuple_size - 1); - begin->insert(begin->end(), ellipsis_len, 0); - end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len); - strides->insert(strides->end(), ellipsis_len, 1); - shrink.insert(shrink.end(), ellipsis_len, 0); - continue; - } - - MS_LOG(EXCEPTION) << "Slice tuple only could contain slice, int number or ellipsis, but got " - << slice_tuple_eles[index]->ToString(); - } - - if (ellipsis_num == 0) { - for (size_t index = slice_tuple_size; index < shape_size; index++) { - begin->push_back(0); - end->push_back(shape[index]); - strides->push_back(1); - } - } - return ConvertBinaryToDecimal(shrink); -} - -int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector &shape, - std::vector *begin, std::vector *end, std::vector *strides) { - MS_EXCEPTION_IF_NULL(begin); - MS_EXCEPTION_IF_NULL(end); - MS_EXCEPTION_IF_NULL(strides); - size_t shape_size = shape.size(); - if (shape_size == 0) { - MS_LOG(EXCEPTION) << "Could slice a scalar tensor"; - } - - ParseSlice(slice, begin, end, strides, shape[0]); - - for (size_t index = 1; index < shape_size; index++) { - begin->push_back(0); - end->push_back(shape[index]); - strides->push_back(1); - } - - return 0; -} - -int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector &shape, - std::vector *begin, std::vector *end, - std::vector *strides) { - MS_EXCEPTION_IF_NULL(begin); - MS_EXCEPTION_IF_NULL(end); - MS_EXCEPTION_IF_NULL(strides); - int ele_index = GetArgScalarValue(scalar, "slice_tuple"); - - begin->push_back(ele_index); - end->push_back(ele_index + 1); - strides->push_back(1); - - for (size_t index = 1; index < shape.size(); index++) { - begin->push_back(0); - end->push_back(shape[index]); - strides->push_back(1); - } - - return 1; -} - -FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) { - auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional"); - ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph)); - return ret_graph; -} - -FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - // slice a tensor - // args: tensor, slice or slice tuple - const std::string op_name = std::string("TensorSlice"); - abstract::CheckArgsSize(op_name, args_spec_list, 2); - AbstractTensorPtr tensorPtr = abstract::CheckArg(op_name, args_spec_list, 0); - - FuncGraphPtr ret_graph = std::make_shared(); - ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); - AnfNodePtr tensor_node = ret_graph->add_parameter(); - (void)ret_graph->add_parameter(); - - auto shape = tensorPtr->shape()->shape(); - std::vector begin; - std::vector end; - std::vector strides; - int shrink_axis_mask; - - if (args_spec_list[1]->isa()) { - AbstractTuplePtr tuple_ptr = dyn_cast(args_spec_list[1]); - shrink_axis_mask = GenerateStridedSliceParametersFromTuple(tuple_ptr, shape, &begin, &end, &strides); - } else if (args_spec_list[1]->isa()) { - AbstractSlicePtr slice_ptr = dyn_cast(args_spec_list[1]); - shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides); - } else if (args_spec_list[1]->isa()) { - AbstractScalarPtr scalar_ptr = dyn_cast(args_spec_list[1]); - if (scalar_ptr->BuildValue()->isa()) { - if (scalar_ptr->BuildValue()->cast()->value()) { - return ExpandADim(ret_graph, tensor_node); - } - MS_LOG(EXCEPTION) << "TensorSlice not support the index is False."; - } - shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides); - } else if (args_spec_list[1]->isa()) { - ret_graph->set_output(tensor_node); - return ret_graph; - } else if (args_spec_list[1]->isa()) { - return ExpandADim(ret_graph, tensor_node); - } else { - std::ostringstream args_info; - for (const auto &arg : args_spec_list) { - MS_EXCEPTION_IF_NULL(arg); - args_info << arg->ToString() << "\n"; - } - MS_LOG(EXCEPTION) - << "TensorSlice requires the input should be one of [slice, ellipsis, int number, bool, none, tuple] , but got " - << args_info.str(); - } - - auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations"); - auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0), - NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)}); - ret_graph->set_output(ret_graph->NewCNode( - {PrimStridedSlice, tensor_node, NewValueNode(begin), NewValueNode(end), NewValueNode(strides)})); - return ret_graph; -} - FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // select indexed item // args: tuple of items, index @@ -1162,11 +962,6 @@ REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) { .def(py::init()); })); -REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) { - (void)py::class_>(*m, "TensorSlice_") - .def(py::init()); - })); - REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) { (void)py::class_>( *m, "TupleGetItemTensor_") diff --git a/mindspore/ccsrc/operator/composite/composite.h b/mindspore/ccsrc/operator/composite/composite.h index acfb7abc879..5944c81fb01 100644 --- a/mindspore/ccsrc/operator/composite/composite.h +++ b/mindspore/ccsrc/operator/composite/composite.h @@ -175,16 +175,6 @@ class TupleSlice : public MetaFuncGraph { }; using TupleSlicePtr = std::shared_ptr; -class TensorSlice : public MetaFuncGraph { - public: - explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {} - ~TensorSlice() override = default; - MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } -}; -using TensorSlicePtr = std::shared_ptr; - class TupleGetItemTensor : public MetaFuncGraph { public: explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {} diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.cc b/mindspore/ccsrc/pipeline/parse/data_converter.cc index 2c0213d4aa7..c36ef2452d4 100644 --- a/mindspore/ccsrc/pipeline/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/parse/data_converter.cc @@ -209,6 +209,28 @@ bool ConvertTensor(const py::object &obj, ValuePtr *const data) { return true; } +bool ConvertSlice(const py::object &obj, ValuePtr *const data) { + MS_LOG(DEBUG) << "Converting slice object"; + + py::slice slice_obj = obj.cast(); + auto convert_func = [obj](std::string attr) -> ValuePtr { + auto py_attr = py::getattr(obj, attr.c_str()); + if (py::isinstance(py_attr)) { + return kNone; + } else if (py::isinstance(py_attr)) { + int value = py::cast(py_attr); + return MakeValue(value); + } else { + MS_LOG(EXCEPTION) << "Slice should contain only int or none"; + } + }; + ValuePtr start = convert_func("start"); + ValuePtr stop = convert_func("stop"); + ValuePtr step = convert_func("step"); + *data = std::make_shared(start, stop, step); + return true; +} + FuncGraphPtr ConvertToBpropCut(py::object obj) { std::vector results = data_converter::GetObjKey(obj); std::string obj_key = results[0]; @@ -321,6 +343,10 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature converted = std::make_shared(py::cast(obj)); } else if (py::isinstance(obj)) { ret = ConvertDict(obj, &converted, use_signature); + } else if (py::isinstance(obj)) { + ret = ConvertSlice(obj, &converted); + } else if (py::isinstance(obj)) { + converted = kEllipsis; } else if (py::isinstance(obj)) { ret = ConvertTuple(obj, &converted, use_signature); } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) { diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index f6c78f0cd21..6d93b6a5a5c 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -353,11 +353,9 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { auto value = abs_base->cast()->ref(); dic = ConvertAbstractToPython(value); } else if (abs_base->isa()) { - auto arg_slice = dyn_cast(abs_base); - std::vector shape; - dic["shape"] = shape; - dic["dtype"] = arg_slice->BuildType(); - dic["value"] = BuildValue(arg_slice->BuildValue()); + dic["shape"] = py::none(); + dic["dtype"] = py::ellipsis(); + dic["value"] = py::ellipsis(); } else if (abs_base->isa()) { auto arg_tuple = dyn_cast(abs_base); size_t len = arg_tuple->size(); diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc index 36a62e6ef92..6e28e38ed13 100644 --- a/mindspore/ccsrc/utils/convert_utils.cc +++ b/mindspore/ccsrc/utils/convert_utils.cc @@ -106,7 +106,7 @@ py::object ValuePtrToPyData(const ValuePtr &value) { } ret = rets; } else if (value->isa()) { - ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_ELLIPSIS); + ret = py::ellipsis(); } else if (value->isa()) { auto slice = value->cast(); auto start = ValuePtrToPyData(slice->start()); diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index e760e23536d..3be5472402b 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -206,6 +206,9 @@ class Parameter: res.default_input = res.default_input / other return res + def __setitem__(self, index, value): + return self + def set_parameter_data(self, data): """Set `default_input` of current `Parameter`.""" if isinstance(data, bool): diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index c87c0a2e007..9c0f7e20947 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -144,6 +144,13 @@ class Tensor(Tensor_): out = tensor_operator_registry.get('__le__')(self, other) return out + def __getitem__(self, index): + out = tensor_operator_registry.get('__getitem__')(self, index) + return out + + def __setitem__(self, index, value): + return self + def __gt__(self, other): out = tensor_operator_registry.get('__gt__')(self, other) return out diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 612906ffc71..8b774b80bb0 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -19,7 +19,7 @@ from functools import partial from mindspore import context -from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, TensorSlice_, \ +from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \ TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_ from ...common import dtype as mstype from ...common.api import ms_function, _pynative_exec, _wrap_func @@ -27,7 +27,7 @@ from .. import functional as F from ...common.parameter import Parameter -__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] +__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] def add_flags(fn, **flags): diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 8954470b76c..f99fec64d06 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -18,7 +18,9 @@ from . import _constexpr_utils as const_utils from ... import functional as F from ... import operations as P from ...composite import base +from ....common.tensor import Tensor from ....common import dtype as mstype +from ....common._register_for_tensor import tensor_operator_registry hyper_map = base.HyperMap() pack = P.Pack(axis=-1) @@ -152,3 +154,101 @@ def generate_updates_from_tensor(data, index, value, op_type): if need_broadcast: return broadcast(updates_shape, value) return value + + + +def tensor_getitem(self, index): + """Handle tensor getitem""" + if isinstance(index, Tensor): + return tensor_index_by_tensor(self, index) + if isinstance(index, tuple): + return tensor_index_by_tuple(self, index) + if isinstance(index, int): + return tensor_index_by_number(self, index) + if isinstance(index, slice): + return tensor_index_by_slice(self, index) + if isinstance(index, bool): + return tensor_index_by_bool(self, index) + if index is ...: + return self + raise IndexError("Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32,\ + got {} with type{}".format(index, type(index))) + + + +tensor_operator_registry.register("__getitem__", tensor_getitem) + + +def tensor_getitem_by_tuple_of_tensor(data, tuple_index): + """Tensor getitem by a tuple of tensor.""" + indices = generate_indices_from_tuple_of_tensor(data, + tuple_index, + const_utils.TENSOR_GETITEM) + result = F.gather_nd(data, indices) + return result + + +def tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): + """Tensor getitem by a tuple of mixed tensor.""" + indices = generate_indices_from_tuple_of_mixed_tensors(data, + tuple_index, + const_utils.TENSOR_GETITEM) + result = F.gather_nd(data, indices) + return result + + +def tensor_index_by_slice(data, slice_index): + """Tensor getitem by a single slice""" + begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(F.shape(data), slice_index) + return F.strided_slice(data, begin_strides, end_strides, step_strides) + + +def tensor_index_by_integer(data, number): + """Tensor getitem by a single integer number""" + begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(F.shape(data), number) + shrink_axis_mask = 1 + return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) + + +def tensor_index_by_bool(data, bool_value): + """Tensor getitem by a single bool value""" + if bool_value: + return F.expand_dims(data, 0) + return const_utils.raise_index_error("bool value as indexing ,false is not supported") + + +def tensor_index_by_number(data, number): + """Tensor getitem by a Number which may be integer/float/bool value""" + number_type = const_utils.check_number_index_type(number) + if number_type == const_utils.BOOL_: + return tensor_index_by_bool(data, number) + if number_type == const_utils.INT_: + return tensor_index_by_integer(data, number) + return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool") + + +def tensor_index_by_tensor(data, tensor_index): + """Tensor getitem by a single tensor""" + dtype_valid = const_utils.check_index_tensor_dtype(F.dtype(tensor_index), + const_utils.TENSOR_GETITEM) + if dtype_valid: + return F.gather(data, tensor_index, 0) + return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool") + + +def tensor_index_by_tuple_slice(data, t): + """Tensor getitem by a tuple of slice""" + begin_strides, end_strides, step_strides, shrink_axis_mask = \ + const_utils.get_stride_info_from_tuple(F.shape(data), t) + return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) + + +def tensor_index_by_tuple(data, tuple_index): + """Tensor getitem by tuple of various types""" + indexes_types = hyper_map(F.typeof, tuple_index) + index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) + if index_elements_type == const_utils.NO_TENSOR: + return tensor_index_by_tuple_slice(data, tuple_index) + if index_elements_type == const_utils.ALL_TENSOR: + return tensor_getitem_by_tuple_of_tensor(data, tuple_index) + return tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index) diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index e4d42aed033..02756ffe561 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -20,7 +20,6 @@ import numpy as np from ...primitive import constexpr from .... import log as logger -from ...._extends.utils import Slice, Ellipsis_ from ....common import dtype as mstype from ....common.tensor import Tensor from ....ops import _utils as op_utils @@ -41,6 +40,11 @@ SET_ITEM_BY_ONE_TENSOR = 0 SET_ITEM_BY_TUPLE_OF_TENSOR = 1 +@constexpr +def raise_index_error(msg): + raise IndexError(msg) + + @constexpr def check_equal(param1, param2, msg="{},{}"): """Checks whether the two parameters are equal or not.""" @@ -54,7 +58,8 @@ def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): """Checks the shape and size of the sensor and value.""" if data_shape == value_shape or data_size == value_size or value_size == 1: return True - raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape)) + raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format( + value_shape, data_shape)) @constexpr @@ -63,16 +68,18 @@ def check_tensor_setitem_index(index, element_type=None): if index is None: raise IndexError("Tensor's index cannot be None.") # eg. Tensor[Slice] = u - if isinstance(index, Slice): + if isinstance(index, slice): return True # eg. Tensor[tuple] = u if isinstance(index, tuple): if not index: raise IndexError("Tensor's index cannot be empty.") - # eg. Tensor[tuple(Slice...)] = u - if isinstance(index[0], (Slice, Ellipsis_, int)): - return True - raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0]))) + # eg. Tensor[tuple(Slice,...)] = u + for item in index: + if not isinstance(item, (slice, type(...), int)): + raise IndexError( + "Index of type '{}' is not supported yet.".format(type(item))) + return True # eg. Tensor[Tensor[dtype=bool]] = u if isinstance(index, mstype.tensor_type): if element_type is None or element_type != mstype.bool_: @@ -81,7 +88,8 @@ def check_tensor_setitem_index(index, element_type=None): "{} type is not supported yet.".format(element_type)) return True - raise IndexError("Index of type '{}' is not supported yet.".format(type(index))) + raise IndexError( + "Index of type '{}' is not supported yet.".format(type(index))) @constexpr @@ -116,12 +124,12 @@ def slice_expand(input_slices, shape): index = 0 slices = None # Slice or tuple(Slice...) - if isinstance(input_slices, Slice): + if isinstance(input_slices, slice): slices = (input_slices,) - elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)): + elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (slice, type(...))): is_have_ellipsis = False for _, element in enumerate(input_slices): - if isinstance(element, Ellipsis_): + if isinstance(element, type(...)): is_have_ellipsis = True break if is_have_ellipsis: @@ -130,10 +138,9 @@ def slice_expand(input_slices, shape): slices = input_slices else: raise IndexError("Tensor's index type is not supported yet.") - for s in slices: start = 0 if (s.start is None) else s.start - stop = shape[index] if (s.end is None) else s.end + stop = shape[index] if (s.stop is None) else s.stop step = 1 if (s.step is None) else s.step begin.append(start) end.append(stop) @@ -151,11 +158,11 @@ def ellipsis2slice(input_, shape): """Converts ellipsis to slice.""" input_slice = input_ result = [] - if isinstance(input_, Ellipsis_): + if isinstance(input_, type(...)): input_slice = (input_,) ell_count = 0 for _, element in enumerate(input_slice): - if not isinstance(element, Ellipsis_): + if not isinstance(element, type(...)): result.append(element) continue ell_count += 1 @@ -163,7 +170,7 @@ def ellipsis2slice(input_, shape): raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, " "but it is currently {}".format(input_slice)) for _ in range(len(shape) - len(input_slice) + 1): - result.append(Slice(None, None, None)) + result.append(slice(None, None, None)) return tuple(result) @@ -196,7 +203,8 @@ def slice2indices(input_slices, shape): def check_indices(indices_size, index): """Checks indices whether is empty.""" if indices_size < 1: - raise IndexError("The tensor's index is unreasonable. index:{}".format(index)) + raise IndexError( + "The tensor's index is unreasonable. index:{}".format(index)) return indices_size @@ -230,7 +238,7 @@ def tuple_element_is_slice(indexs): raise IndexError("Tensor's index cannot be empty.") if isinstance(indexs, tuple): for _, ele in enumerate(indexs): - if not isinstance(ele, Slice): + if not isinstance(ele, slice): return False return True return False @@ -285,7 +293,8 @@ def check_value_elements(data_dtype, types): return ALL_TENSOR if scalars_number == len(types): return ALL_SCALAR - raise TypeError(f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.") + raise TypeError( + f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.") @constexpr @@ -295,7 +304,8 @@ def get_index_tensor_dtype(dtype): return INT_ if dtype == mstype.bool_: return BOOL_ - raise IndexError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.") + raise IndexError( + f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.") @constexpr @@ -313,7 +323,8 @@ def check_index_tensor_dtype(dtype, op_name): """Check a tensor data type.""" if dtype == mstype.int32: return True - raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.") + raise IndexError( + f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.") @constexpr @@ -332,7 +343,8 @@ def generate_broadcast_shape(shapes, op_name): for i, shape in enumerate(shapes): logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") try: - broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name) + broadcast_shape = op_utils.get_broadcast_shape( + broadcast_shape, shape, op_name) except ValueError as ex: raise IndexError(ex) return tuple(broadcast_shape) @@ -398,7 +410,8 @@ def convert_ellipsis_to_tensors(slice_number, if isinstance(ele, tuple): shape.extend([1] * len(ele)) if array is None: - raise ValueError(f"For '{op_name}', generate tensors from ellipsis failed.") + raise ValueError( + f"For '{op_name}', generate tensors from ellipsis failed.") array = np.reshape(array, shape) reps = compute_multiples(shape, final_shape) tensor = Tensor(np.tile(array, reps)) @@ -428,7 +441,8 @@ def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_n else: shape.append(1) if array is None: - raise ValueError(f"For '{op_name}', generate tensor from 'slice' failed.") + raise ValueError( + f"For '{op_name}', generate tensor from 'slice' failed.") array = np.reshape(array, shape) reps = compute_multiples(shape, final_shape) tensor = Tensor(np.tile(array, reps)) @@ -523,14 +537,15 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape, tensor_count += 1 elif isinstance(ele_type, mstype.slice_type): slice_obj = slice(slice_indexes[slice_count].start, - slice_indexes[slice_count].end, + slice_indexes[slice_count].stop, slice_indexes[slice_count].step) # Use list to represent slicing result. indexes_info[pos] = list(range(data_shape[pos]))[slice_obj] slice_count += 1 elif isinstance(ele_type, mstype.ellipsis_type): if ellipsis_num != 0: - raise IndexError(f"For '{op_name}', the index could only contain one ellipsis.") + raise IndexError( + f"For '{op_name}', the index could only contain one ellipsis.") ellipsis_occupied_dims = data_rank - indexes_size + 1 for j in range(pos, pos + ellipsis_occupied_dims): # Use list to represent slicing result. @@ -540,7 +555,8 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape, raise IndexError(f"For '{op_name}', the index elements only support " f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.") broadcast_shape, final_shape, indexes_shapes_info = \ - _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name) + _derive_result_shape_info_from_tuple_of_mixed_tensors( + indexes_info, index_tensors_info, op_name) return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims @@ -556,10 +572,12 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te """Derive the resulting shape information from the a tuple index of mixed tensors.""" index_tensor_info_key = list(index_tensors_info.keys()) index_tensor_info_value = list(index_tensors_info.values()) - broadcast_shape = generate_broadcast_shape(index_tensor_info_value, op_name) + broadcast_shape = generate_broadcast_shape( + index_tensor_info_value, op_name) final_shape = [] indexes_shapes_info = [] - mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key) + mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous( + index_tensor_info_key) if mixed_tensors_continuous: tensor_shape_dealt = False for ele in indexes_info.values(): @@ -638,3 +656,98 @@ def get_np_eps(input_dtype): nptype = mstype.dtype_to_nptype(input_dtype) eps = np.finfo(nptype).eps return float(eps) + + +@constexpr +def check_number_index_type(number): + """Check if it is int or bool number""" + if isinstance(number, bool): + return BOOL_ + if isinstance(number, int): + return INT_ + raise IndexError("Only support integers, slices(`:`), ellipsis(`...`), None and bool, got {0} type is {1} " + .format(number, type(number))) + + +@constexpr +def get_stride_info_from_slice(data_shape, slice_index): + """Get stride info from a python slice""" + begin, end, step = get_slice_stride(data_shape[0], slice_index) + begin_strides = [begin] + end_strides = [end] + step_strides = [step] + for end in data_shape[1:]: + begin_strides.append(0) + end_strides.append(end) + step_strides.append(1) + return tuple(begin_strides), tuple(end_strides), tuple(step_strides) + + +@constexpr +def get_stride_info_from_integer(data_shape, number): + """Get stride info from a integer""" + begin_strides = [number] + end_strides = [number+1] + step_strides = [1] + for end in data_shape[1:]: + begin_strides.append(0) + end_strides.append(end) + step_strides.append(1) + return tuple(begin_strides), tuple(end_strides), tuple(step_strides) + + +def get_slice_stride(dim_size, index_slice): + """Get slice stride info""" + step = 1 if index_slice.step is None else index_slice.step + start_default = 0 + stop_default = dim_size + if step < 0: + start_default = -1 + stop_default = -(dim_size+1) + start = start_default if index_slice.start is None else index_slice.start + stop = stop_default if index_slice.stop is None else index_slice.stop + return start, stop, step + + +@constexpr +def get_stride_info_from_tuple(data_shape, index_tuple): + """Get stride info from a tuple""" + begin_strides = [] + end_strides = [] + step_strides = [] + index_size = len(index_tuple) + data_shape_size = len(data_shape) + shrink_axis = 0 + index_count = 0 + ellipsis_count = 0 + for idx, item in enumerate(index_tuple): + if isinstance(item, slice): + start, stop, step = get_slice_stride(data_shape[idx], item) + begin_strides.append(start) + end_strides.append(stop) + step_strides.append(step) + index_count = index_count + 1 + elif isinstance(item, int): + begin_strides.append(item) + end_strides.append(item + 1) + step_strides.append(1) + shrink_axis = shrink_axis + (1 << index_count) + index_count = index_count + 1 + elif item is ...: + ellipsis_count = ellipsis_count + 1 + if ellipsis_count > 1: + raise IndexError("An index can have only one ellipsis (...)") + ellipsis_range_size = data_shape_size - (index_size - 1) + begin_strides.extend([0] * (ellipsis_range_size)) + end_strides.extend( + [i for i in data_shape[index_count: index_count + (ellipsis_range_size)]]) + step_strides.extend([1] * (ellipsis_range_size)) + index_count = index_count + ellipsis_range_size + else: + raise IndexError("Not supported index data type, got ", + item, " type is ", type(item)) + for item in range(index_count, data_shape_size): + begin_strides.append(0) + end_strides.append(data_shape[item]) + step_strides.append(1) + return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py index 1295aba87e2..7bc9fd28053 100644 --- a/mindspore/ops/composite/multitype_ops/getitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py @@ -15,7 +15,6 @@ """Implementation for getitem.""" from . import _compile_utils as compile_utils -from . import _constexpr_utils as const_utils from .. import base from ... import functional as F @@ -50,29 +49,6 @@ _tuple_slice = _TupleSlice('tuple_slice') """_tuple_slice is an metafuncgraph object which will slice a tuple.""" -class _TensorSlice(base.TensorSlice_): - """ - Slices a tensor. - - Inputs: - data (Tensor): A tensor to be sliced. - s (slice): The index to slice tuple data. - - Outputs: - Tensor, consists of some elements of data. - """ - - def __init__(self, name): - base.TensorSlice_.__init__(self, name) - - def __call__(self, *args): - pass - - -_tensor_slice = _TensorSlice('tensor_slice') -"""_tensor_slice is an metafuncgraph object which will slice a tensor.""" - - class _TupleGetItemTensor(base.TupleGetItemTensor_): """ Getting item of tuple by tensor index. @@ -182,13 +158,13 @@ def _tensor_getitem_by_number(data, number_index): Outputs: Tensor, element type is as same as the element type of data. """ - return _tensor_slice(data, number_index) + return compile_utils.tensor_index_by_number(data, number_index) @getitem.register("Tensor", "None") def _tensor_getitem_by_none(data, index): """ - Getting item of tensor by None. + For none indexing , expand data with one dim Inputs: data (Tensor): A tensor. @@ -197,7 +173,7 @@ def _tensor_getitem_by_none(data, index): Outputs: Tensor, element type is as same as the element type of data. """ - return _tensor_slice(data, index) + return F.expand_dims(data, 0) @getitem.register("Tensor", "Slice") @@ -212,13 +188,13 @@ def _tensor_getitem_by_slice(data, slice_index): Outputs: Tensor, element type is same as the element type of data. """ - return _tensor_slice(data, slice_index) + return compile_utils.tensor_index_by_slice(data, slice_index) @getitem.register("Tensor", "Tensor") def _tensor_getitem_by_tensor(data, tensor_index): """ - Getting item of tensor by slice. + Getting item of tensor by tensor indice. Inputs: data (Tensor): A tensor. @@ -227,18 +203,13 @@ def _tensor_getitem_by_tensor(data, tensor_index): Outputs: Tensor, element type is same as the element type of data. """ - check_dtypes = const_utils.check_index_tensor_dtype(F.dtype(tensor_index), - const_utils.TENSOR_GETITEM) - result = None - if check_dtypes: - result = F.gather(data, tensor_index, 0) - return result + return compile_utils.tensor_index_by_tensor(data, tensor_index) @getitem.register("Tensor", "Tuple") def _tensor_getitem_by_tuple(data, tuple_index): """ - Getting item of tensor by slice tuple. + Getting item of tensor by tuple. Inputs: data (Tensor): A tensor. @@ -247,13 +218,7 @@ def _tensor_getitem_by_tuple(data, tuple_index): Outputs: Tensor, element type is same as the element type of data. """ - indexes_types = compile_utils.hyper_map(F.typeof, tuple_index) - index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) - if index_elements_type == const_utils.NO_TENSOR: - return _tensor_slice(data, tuple_index) - if index_elements_type == const_utils.ALL_TENSOR: - return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) - return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index) + return compile_utils.tensor_index_by_tuple(data, tuple_index) @getitem.register("Tensor", "Ellipsis") @@ -268,22 +233,4 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index): Outputs: Tensor, same as data. """ - return _tensor_slice(data, ellipsis_index) - - -def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): - """Tensor getitem by a tuple of tensor.""" - indices = compile_utils.generate_indices_from_tuple_of_tensor(data, - tuple_index, - const_utils.TENSOR_GETITEM) - result = F.gather_nd(data, indices) - return result - - -def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): - """Tensor getitem by a tuple of mixed tensor.""" - indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data, - tuple_index, - const_utils.TENSOR_GETITEM) - result = F.gather_nd(data, indices) - return result + return data diff --git a/tests/st/pynative/test_tensor_index.py b/tests/st/pynative/test_tensor_index.py new file mode 100644 index 00000000000..879b4f4c2e5 --- /dev/null +++ b/tests/st/pynative/test_tensor_index.py @@ -0,0 +1,741 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_tensor_slice """ +import numpy as np +import pytest + +from mindspore import Tensor, Parameter +from mindspore import context +from mindspore import dtype as mstype +from mindspore.nn import Cell + +def setup_module(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + +class NetWorkSlicePositive(Cell): + def __init__(self): + super(NetWorkSlicePositive, self).__init__() + self.tensor_ret0 = Tensor(np.ones([1, 2, 3], np.int32)) + self.tensor_ret1 = Tensor(np.ones([4, 8, 10], np.int32)) + self.tensor_ret2 = Tensor(np.ones([6, 8, 10], np.int32)) + self.tensor_ret3 = Tensor(np.ones([3, 8, 10], np.int32)) + + def construct(self, tensor): + ret0 = tensor[3:4:1, 1:5:2, 3:6:1] + self.tensor_ret0 + ret1 = tensor[-6:4:1, 0:8:1, ::1] + self.tensor_ret1 + ret2 = tensor[::, ::, ::] + self.tensor_ret2 + ret3 = tensor[::2] + self.tensor_ret3 + return ret0, ret1, ret2, ret3 + + +def test_slice_positive(): + net = NetWorkSlicePositive() + input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32) + input_0 = Tensor(input_np) + output0, output1, output2, output3 = net(input_0) + assert np.all(output0.asnumpy() == input_np[3:4:1, 1:5:2, 3:6:1] + np.ones([1, 2, 3])) + assert np.all(output1.asnumpy() == input_np[-6:4:1, 0:8:1, ::1] + np.ones([4, 8, 10])) + assert np.all(output2.asnumpy() == input_np[::, ::, ::] + np.ones([6, 8, 10])) + assert np.all(output3.asnumpy() == input_np[::2] + np.ones([3, 8, 10])) + + +class NetWorkSliceEllipsis(Cell): + def __init__(self): + super(NetWorkSliceEllipsis, self).__init__() + self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32)) + self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32)) + self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32)) + + def construct(self, tensor): + ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0 + ret1 = tensor[...] + self.tensor_ret1 + ret2 = tensor[None] + self.tensor_ret2 + ret3 = tensor[True] + self.tensor_ret2 + return ret0, ret1, ret2, ret3 + + +def Xtest_slice_ellipsis(): + net = NetWorkSliceEllipsis() + input_np = np.arange(6*7*8*9).reshape(6, 7, 8, 9).astype(np.int32) + input_0 = Tensor(input_np) + output0, output1, output2, output3 = net(input_0) + assert np.all(output0.asnumpy() == input_np[0:4:2, ..., 1] + np.ones([1, 2, 3])) + assert np.all(output1.asnumpy() == input_np[...] + np.ones([6, 7, 8, 9])) + assert np.all(output2.asnumpy() == input_np[None] + np.ones([6, 7, 8, 9])) + assert np.all(output3.asnumpy() == input_np[True] + np.ones([1, 6, 7, 8, 9])) + + +class NetWorkReduceDimension(Cell): + def __init__(self): + super(NetWorkReduceDimension, self).__init__() + self.tensor_ret1 = Tensor(np.ones([3, 10], np.int32)) + self.tensor_ret2 = Tensor(np.ones([6, 8], np.int32)) + self.tensor_ret3 = Tensor(np.array(8, np.int32)) + self.tensor_ret4 = Tensor(np.ones([8, 10], np.int32)) + + def construct(self, tensor): + ret1 = tensor[::2, 1, ::1] + self.tensor_ret1 + ret2 = tensor[::, ::, 0] + self.tensor_ret2 + ret3 = tensor[3, 2, 5] + self.tensor_ret3 + ret4 = tensor[1] + self.tensor_ret4 + return ret1, ret2, ret3, ret4 + + +def Xtest_reduce_dimension(): + net = NetWorkReduceDimension() + input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32) + input_0 = Tensor(input_np) + output1, output2, output3, output4 = net(input_0) + assert np.all(output1.asnumpy() == input_np[::2, 1, ::1] + np.ones([3, 10])) + assert np.all(output2.asnumpy() == input_np[::, ::, 0] + np.ones([6, 8])) + assert np.all(output3.asnumpy() == input_np[3, 2, 5] + np.array(8, np.int32)) + assert np.all(output4.asnumpy() == input_np[1] + np.ones([8, 10])) + + +class NetWorkSliceStep(Cell): + def __init__(self): + super(NetWorkSliceStep, self).__init__() + self.tensor_ret1 = Tensor(np.ones([6, 5, 10], np.int32)) + self.tensor_ret2 = Tensor(np.ones([3, 5, 5], np.int32)) + + def construct(self, tensor): + ret1 = tensor[::1, -5::, ::-1] + self.tensor_ret1 + ret2 = tensor[::2, -5::, ::2] + self.tensor_ret2 + return ret1, ret2 + + +def Xtest_step_negative(): + net = NetWorkSliceEllipsis() + input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32) + input_0 = Tensor(input_np) + output1, output2 = net(input_0) + assert np.all(output1.asnumpy() == input_np[::1, -5::, ::-1] + np.ones([6, 8, 10])) + assert np.all(output2.asnumpy() == input_np[::2, -5::, ::2] + np.ones([3, 5, 5])) + + +class TensorGetItemByThreeTensors(Cell): + def __init__(self): + super(TensorGetItemByThreeTensors, self).__init__() + self.const0 = Tensor(np.ones((4, 5, 8, 10)), mstype.int32) + self.const1 = Tensor(np.ones((3, 4, 5, 10)), mstype.int32) + self.const2 = Tensor(np.ones((5, 3, 4, 5)), mstype.int32) + + def construct(self, x, index_0, index_1, index_2): + ret0 = x[index_0] + self.const0 + ret1 = x[index_0, index_1] + self.const1 + ret2 = x[index_0, index_1, index_2] + self.const2 + return ret0, ret1, ret2 + + +def Xtest_getitem_by_tensors(): + net = TensorGetItemByThreeTensors() + input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32) + index_0 = np.random.randint(6, size=(3, 4, 5)).astype(np.int32) + index_1 = np.random.randint(6, size=(4, 5)).astype(np.int32) + index_2 = np.random.randint(6, size=(5, 3, 4, 5)).astype(np.int32) + input_x_ms = Tensor(input_x) + index_0_ms = Tensor(index_0) + index_1_ms = Tensor(index_1) + input_2_ms = Tensor(index_2) + output0, output1, output2 = net(input_x_ms, index_0_ms, index_1_ms, input_2_ms) + assert np.all(output0.asnumpy() == input_x[index_0] + np.ones([4, 5, 8, 10])) + assert np.all(output1.asnumpy() == input_x[index_0, index_1] + np.ones([3, 4, 5, 10])) + assert np.all(output2.asnumpy() == input_x[index_0, index_1, index_2] + np.ones([5, 3, 4, 5])) + + +class TensorGetItemByMixedTensors_0(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors_0, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32)) + + def construct(self, tensor, index_0, index_1): + ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const + return ret + + +class TensorGetItemByMixedTensors_1(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors_1, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32)) + + def construct(self, tensor, index_0, index_1): + ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const + return ret + + +class TensorGetItemByMixedTensors_2(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors_2, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32)) + + def construct(self, tensor, index_0, index_1): + ret = tensor[0, index_0, index_1, ..., 3] + self.const + return ret + + +class TensorGetItemByMixedTensors_3(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors_3, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32)) + + def construct(self, tensor, index_0, index_1): + ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const + return ret + + +class TensorGetItemByMixedTensors_4(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors_4, self).__init__() + self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32)) + + def construct(self, tensor, index_0, index_1, index_2): + ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const + return ret + + +class TensorGetItemByMixedTensors_5(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors_5, self).__init__() + self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32)) + + def construct(self, tensor, index_0, index_1, index_2): + ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const + return ret + + +class TensorGetItemByMixedTensors_6(Cell): + def __init__(self): + super(TensorGetItemByMixedTensors_6, self).__init__() + self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32)) + + def construct(self, tensor, index_0, index_1, index_2): + ret = tensor[..., index_0, index_1, index_2, 3] + self.const + return ret + + +class TensorSetItemByMixedTensors_0(Cell): + def __init__(self, value): + super(TensorSetItemByMixedTensors_0, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32)) + self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), + mstype.float32), + name="x") + self.value = value + + def construct(self, index_0, index_1, index_2): + self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value + ret = self.param + self.const + return ret + + +class TensorSetItemByMixedTensors_1(Cell): + def __init__(self, value): + super(TensorSetItemByMixedTensors_1, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32)) + self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32), + name="x") + self.value = value + + def construct(self, index_0, index_1, index_2): + self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value + ret = self.param + self.const + return ret + + +class TensorSetItemByMixedTensors_2(Cell): + def __init__(self, value): + super(TensorSetItemByMixedTensors_2, self).__init__() + self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float16)) + self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float16), + name="x") + self.value = value + + def construct(self, index_0, index_1, index_2): + self.param[..., index_0, index_1, index_2, 3] = self.value + ret = self.param + self.const + return ret + + +class TensorGetItemByMixedTensorsTypeError(Cell): + def __init__(self): + super(TensorGetItemByMixedTensorsTypeError, self).__init__() + + def construct(self, x, index_0, index_1): + ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]] + return ret + + +class TensorGetItemByMixedTensorsNumberError(Cell): + def __init__(self): + super(TensorGetItemByMixedTensorsNumberError, self).__init__() + + def construct(self, x, index_0, index_1): + ret = x[index_0, index_1, 0:3, ..., index_1, index_0] + return ret + + +class TensorSetItemByOneTensorWithNumber(Cell): + def __init__(self, value): + super(TensorSetItemByOneTensorWithNumber, self).__init__() + self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") + self.value = value + + def construct(self, index): + self.param[index] = self.value + ret = self.param + self.const + return ret + + +class TensorSetItemByOneTensorWithTensor(Cell): + def __init__(self): + super(TensorSetItemByOneTensorWithTensor, self).__init__() + self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") + + def construct(self, index, value): + self.param[index] = value + ret = self.param + self.const + return ret + + +class TensorSetItemByOneTensorWithTupleOfNumber(Cell): + def __init__(self, value): + super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__() + self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") + self.value = value + + def construct(self, index): + self.param[index] = self.value + ret = self.param + self.const + return ret + + +class TensorSetItemByOneTensorWithTupleOfTensor(Cell): + def __init__(self): + super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__() + self.const = Tensor(np.ones((6, 3, 8)), mstype.float32) + self.param = Parameter(Tensor(np.arange(6 * 3 * 8).reshape((6, 3, 8)), mstype.float32), name="x") + + def construct(self, index, value_0, value_1, value_2): + self.param[index] = (value_0, value_1, value_2) + ret = self.param + self.const + return ret + + +class TensorSetItemByTensorsWithNumber(Cell): + def __init__(self, value): + super(TensorSetItemByTensorsWithNumber, self).__init__() + self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") + self.value = value + + def construct(self, index_0, index_1, index_2): + self.param[index_0, index_1, index_2] = self.value + ret = self.param + self.const + return ret + + +class TensorSetItemByTensorsWithTensor(Cell): + def __init__(self): + super(TensorSetItemByTensorsWithTensor, self).__init__() + self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") + + def construct(self, index_0, index_1, index_2, value): + self.param[index_0, index_1, index_2] = value + ret = self.param + self.const + return ret + + +class TensorSetItemByTensorsWithTensorNumberError(Cell): + def __init__(self): + super(TensorSetItemByTensorsWithTensorNumberError, self).__init__() + self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") + + def construct(self, index_0, index_1, index_2, index_3, value): + self.param[index_0, index_1, index_2, index_3] = value + ret = self.param + self.const + return ret + + +class TensorSetItemByTensorsWithTupleOfNumber(Cell): + def __init__(self, value): + super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__() + self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") + self.value = value + + def construct(self, index_0, index_1, index_2): + self.param[index_0, index_1, index_2] = self.value + ret = self.param + self.const + return ret + + +class TensorSetItemByTensorsWithTupleOfTensor(Cell): + def __init__(self): + super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__() + self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") + + def construct(self, index_0, index_1, index_2, value_0, value_1, value_2): + self.param[index_0, index_1, index_2] = (value_0, value_1, value_2) + ret = self.param + self.const + return ret + + +class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell): + def __init__(self): + super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__() + self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") + + def construct(self, index_0, index_1, index_2, value_0, value_1): + self.param[index_0, index_1, index_2] = (value_0, value_1) + ret = self.param + self.const + return ret + + +class TensorSetItemByMixedTensors(Cell): + def __init__(self): + super(TensorSetItemByMixedTensors, self).__init__() + self.const = Tensor(np.ones((6, 7, 8)), mstype.float32) + self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x") + self.value = 99.0 + + def construct(self, index_0, index_1): + self.param[index_0, index_1, 0:6] = self.value + ret = self.param + self.const + return ret + + +class TensorAssignWithSliceError1(Cell): + def __init__(self): + super(TensorAssignWithSliceError1, self).__init__() + + def construct(self, a, b): + a[1:3:-1, ::] = b + return a + + +class TensorAssignWithSliceError2(Cell): + def __init__(self): + super(TensorAssignWithSliceError2, self).__init__() + + def construct(self, a, b): + a[1:3:-1] = b + return a + + +class TensorAssignWithSlice2(Cell): + def __init__(self): + super(TensorAssignWithSlice2, self).__init__() + + def construct(self, a, b, ck): + a[1:5] = b + a[3:4] = 5 + a[-1:1:-1] = b + a[-1:3:-1] = 5 + a[::] = b + a[::] = 9 + z = a + ck + return z + + +class TensorAssignWithSlice(Cell): + def __init__(self): + super(TensorAssignWithSlice, self).__init__() + self.c = 2 + + def construct(self, a, b, ck): + a[1:3, ::] = b + a[2:3:, 3:] = b + a[::] = b + a[::] = self.c + a[::, ::] = b + a[::, ::] = self.c + a[2:3:, 0:, 4:1:-1] = b + a[2:3:, 0:, 4:1:-1] = self.c + z = a + ck + return z + + +def test_tensor_assign(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + net = TensorAssignWithSlice() + net2 = TensorAssignWithSlice2() + net_e1 = TensorAssignWithSliceError1() + net_e2 = TensorAssignWithSliceError2() + a = np.arange(60).reshape(3, 4, 5) + ck = np.arange(60).reshape(3, 4, 5) + b = Tensor([1], dtype=mstype.float32) + Ta = Tensor(a, dtype=mstype.float32) + Tck = Tensor(ck, dtype=mstype.float32) + Ta4d = Tensor(a.reshape(1, 3, 4, 5), dtype=mstype.float32) + Ta4d_ck = Tensor(ck.reshape(1, 3, 4, 5), dtype=mstype.float32) + Tb = Tensor([1, 3], dtype=mstype.float32) + Tc = Tensor([], dtype=mstype.float32) + t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32) + tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32) + net(Ta, b, Tck) + net2(t, b, tck) + # Error for A[Slice] = Number + # 1. A[Slice] = Number, Slice error + with pytest.raises(IndexError): + net_e2(t, 2) + + # Error for A[Slice] = U, U is a Tensor + # 1. A[Slice] = U, u.size is error + with pytest.raises(ValueError): + net2(t, Tb, tck) + # 2. A[Slice] = U, U is empty + with pytest.raises(ValueError): + net2(t, Tc, tck) + # 3. A[Slice] = U, U.size error + with pytest.raises(ValueError): + net2(t, Tb, tck) + + # Error for A[Tuple(Slice...)] = Tensor + # 1. A[Tuple(Slice...)] = U, U is empty + with pytest.raises(ValueError): + net(Ta, Tc, Tck) + # 2. A[Tuple(Slice...)] = U, U.size error + with pytest.raises(ValueError): + net(Ta, Tb, Tck) + # 3. A[Tuple(Slice...)] = U, Slice error + with pytest.raises(IndexError): + net_e1(Ta, b) + + # Error for A[Tuple(Slice...)] = Number + # 1. A[Tuple(Slice...)] = Number, Slice error + with pytest.raises(IndexError): + net_e1(Ta, 2) + + net = TensorAssignWithInteger() + # Error for A[Number] = scalar/Tensor + # 1. A[Number] = U, U is a Tensor, u.size not match + with pytest.raises(ValueError): + net(Ta, Tb, Tck) + with pytest.raises(ValueError): + net(Ta, Tc, Tck) + # 2. A[Number] = U, the number index error + with pytest.raises(IndexError): + net(Ta4d, b, Ta4d_ck) + + # Error for A[(n,m)] = scalar/Tensor + # 1. A[(n,m)] = U, U is a tensor. u.size not match + net = TensorAssignWithTupleInteger() + with pytest.raises(ValueError): + net(Ta, Tc, Tck) + with pytest.raises(ValueError): + net(Ta, Tb, Tck) + # 2. A[(n,m)] = U, the number index error + with pytest.raises(IndexError): + net(Ta4d, b, Ta4d_ck) + + # Error for A[...] = U or A[1:, ...] = u + # 1. A[...] = scalar/tensor + net = TensorAssignWithEllipsis() + net(Ta, Ta4d) + with pytest.raises(ValueError): + net(Ta, Tc) + with pytest.raises(ValueError): + net(Ta, Tb) + # 2. A[::, 1:, ...] = scalar/tensor + net = TensorAssignWithTupleEllipsis() + net(Ta, b) + Tc = Tensor(1, mstype.float32) + with pytest.raises(ValueError): + net(Ta, Tc) + with pytest.raises(ValueError): + net(Ta, Tb) + + +class TensorAssignWithTupleEllipsis2(Cell): + def __init__(self): + super(TensorAssignWithTupleEllipsis2, self).__init__() + + def construct(self, a, b): + a[1:, ..., ::] = b + return a + + +class TensorAssignWithTupleEllipsis(Cell): + def __init__(self): + super(TensorAssignWithTupleEllipsis, self).__init__() + + def construct(self, a, b): + a[:2, ...] = 1 + a[1:, ...] = b + return a + + +class TensorAssignWithEllipsis(Cell): + def __init__(self): + super(TensorAssignWithEllipsis, self).__init__() + + def construct(self, a, b): + a[...] = 1 + a[...] = b + return a + + +class TensorAssignWithInteger(Cell): + def __init__(self): + super(TensorAssignWithInteger, self).__init__() + + def construct(self, a, b, ck): + a[1] = 1 + a[0] = b + z = a + ck + return z + + +class TensorAssignWithTupleInteger(Cell): + def __init__(self): + super(TensorAssignWithTupleInteger, self).__init__() + + def construct(self, a, b, ck): + a[(1)] = 1 + a[(1)] = b + a[(1, 1)] = b + a[(1, 1)] = 1 + z = a + ck + return z + + +class TensorAssignWithBoolTensorIndex(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndex, self).__init__() + self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32) + self.u_scalar = 5 + + def construct(self, a, b, c, u_tensor): + a[c] = self.u_scalar + a[b] = u_tensor + z = a + self.t + return z + + +class TensorAssignWithBoolTensorIndexError(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndexError, self).__init__() + + def construct(self, a, b, c, u_tensor): + a[b][c] = u_tensor + return a + + +class TensorAssignWithBoolTensorIndex2(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndex2, self).__init__() + self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32) + self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32) + self.u_scalar = 5 + + def construct(self, a, u_tensor): + a[a > 8] = u_tensor + a[a >= 6] = self.u_scalar + a[a < 3] = self.u_scalar + a[a <= 5] = u_tensor + a[a == 5] = self.u_scalar + z = a + self.t + return z + + +class TensorAssignWithBoolTensorIndex2Error(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndex2Error, self).__init__() + + def construct(self, a, u_tensor): + a[a > 8][a > 5] = u_tensor + return a + + +def Xtest_tensor_assign_bool_index(): + a = np.arange(60).reshape(3, 4, 5) + b = a > 5 + c = a < 3 + Ta = Tensor(a, dtype=mstype.float32) + Tb = Tensor(b) + Tc = Tensor(c) + Td = Tensor([True, True]) + u_tensor = Tensor([1], dtype=mstype.float32) + u_tensor_error = Tensor([1, 2], dtype=mstype.float32) + u_scalar = 5 + net1 = TensorAssignWithBoolTensorIndex() + net2 = TensorAssignWithBoolTensorIndex2() + net1(Ta, Tb, Tc, u_tensor) + net1(Ta, Tb, Tc, u_tensor) + with pytest.raises(ValueError): + net1(Ta, Td, Tc, u_tensor) + with pytest.raises(IndexError): + net1(Ta, u_tensor, Tc, u_tensor) + with pytest.raises(ValueError): + net1(Ta, Tb, Td, u_tensor) + with pytest.raises(IndexError): + net1(Ta, Tb, Ta, u_tensor) + with pytest.raises(ValueError): + net1(Ta, Tb, Tc, u_tensor_error) + # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar) + with pytest.raises(ValueError): + net2(Ta, u_tensor_error) + net3 = TensorAssignWithBoolTensorIndexError() + with pytest.raises(AttributeError): + net3(Ta, Tb, Tc, u_tensor) + with pytest.raises(AttributeError): + net3(Ta, Tb, Tc, u_scalar) + net4 = TensorAssignWithBoolTensorIndex2Error() + with pytest.raises(AttributeError): + net4(Ta, u_tensor) + with pytest.raises(AttributeError): + net4(Ta, u_scalar) + + +def Xtest_tensor_slice_reduce_out_of_bounds_neg(): + class NetWork(Cell): + def __init__(self): + super(NetWork, self).__init__() + self.tensor_ret = Tensor(np.array(9, np.int32)) + + def construct(self, tensor): + ret = tensor[-7, 3, 4] + return ret + + input_tensor = Tensor(np.ones([6, 8, 10], np.int32)) + net = NetWork() + with pytest.raises(ValueError) as ex: + net(input_tensor) + assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str( + ex.value) + + +def Xtest_tensor_slice_reduce_out_of_bounds_positive(): + class NetWork(Cell): + def __init__(self): + super(NetWork, self).__init__() + self.tensor_ret = Tensor(np.array(9, np.int32)) + + def construct(self, tensor): + ret = tensor[6, 3, 4] + return ret + + input_tensor = Tensor(np.ones([6, 8, 10], np.int32)) + net = NetWork() + with pytest.raises(ValueError) as ex: + net(input_tensor) + assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value) diff --git a/tests/ut/cpp/operator/composite_test.cc b/tests/ut/cpp/operator/composite_test.cc index 84e9fda9d2e..8ca318300a8 100644 --- a/tests/ut/cpp/operator/composite_test.cc +++ b/tests/ut/cpp/operator/composite_test.cc @@ -240,156 +240,6 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) { ASSERT_EQ(real, expect); } -TEST_F(TestComposite, test_TensorSliceBySlice) { - MetaFuncGraphPtr tensorSlicePtr = std::make_shared("tensor_slice"); - FuncGraphPtr tensorSlicePtrGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2); - - AbstractBasePtrList eles; - AbstractScalarPtr start_index = std::make_shared(1); - AbstractScalarPtr stop_index = std::make_shared(6); - AbstractScalarPtr step = std::make_shared(2); - - AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8}); - AbstractSlicePtr slice = std::make_shared(start_index, stop_index, step); - AbstractBasePtrList args_spec_list = {tensor, slice}; - - AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred->abstract()); - if (ret == nullptr) { - FAIL() << "Cast ret to abstract array failed."; - } - AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({3, 7, 8}); - ASSERT_EQ(*ret, *expect); -} - -TEST_F(TestComposite, test_TensorSliceBySliceTuple) { - MetaFuncGraphPtr tensorSlicePtr = std::make_shared("tensor_slice"); - FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2); - - AbstractBasePtrList eles; - AbstractScalarPtr start_index = std::make_shared(0); - AbstractScalarPtr stop_index = std::make_shared(6); - AbstractScalarPtr step = std::make_shared(2); - AbstractSlicePtr slice = std::make_shared(start_index, stop_index, step); - eles.push_back(slice); - - start_index = std::make_shared(1); - stop_index = std::make_shared(5); - step = std::make_shared(1); - slice = std::make_shared(start_index, stop_index, step); - eles.push_back(slice); - - start_index = std::make_shared(2); - stop_index = std::make_shared(8); - step = std::make_shared(3); - slice = std::make_shared(start_index, stop_index, step); - eles.push_back(slice); - - AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8}); - AbstractTuplePtr slice_tuple = std::make_shared(eles); - AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; - - AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); - if (ret == nullptr) { - FAIL() << "Cast ret to abstract array failed."; - } - AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({3, 4, 2}); - ASSERT_EQ(*ret, *expect); -} - -TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) { - MetaFuncGraphPtr tensorSlicePtr = std::make_shared("tensor_slice"); - FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2); - - AbstractBasePtrList eles; - AbstractScalarPtr start_index = std::make_shared(1); - AbstractScalarPtr stop_index = std::make_shared(5); - AbstractScalarPtr step = std::make_shared(2); - AbstractSlicePtr slice = std::make_shared(start_index, stop_index, step); - eles.push_back(slice); - - AbstractScalarPtr elem_index = std::make_shared(1); - eles.push_back(elem_index); - - start_index = std::make_shared(2); - stop_index = std::make_shared(6); - step = std::make_shared(1); - slice = std::make_shared(start_index, stop_index, step); - eles.push_back(slice); - - AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8}); - AbstractTuplePtr slice_tuple = std::make_shared(eles); - AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; - - AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); - if (ret == nullptr) { - FAIL() << "Cast ret to abstract array failed."; - } - AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({2, 4}); - ASSERT_EQ(*ret, *expect); -} - -TEST_F(TestComposite, test_TensorSliceByScalar) { - MetaFuncGraphPtr tensorSlicePtr = std::make_shared("tensor_slice"); - FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2); - - AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8}); - AbstractScalarPtr start_index = std::make_shared(2); - AbstractBasePtrList args_spec_list = {tensor, start_index}; - - AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); - if (ret == nullptr) { - FAIL() << "Cast ret to abstract array failed."; - } - AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({7, 8}); - ASSERT_EQ(*ret, *expect); -} - -TEST_F(TestComposite, test_TensorSliceByScalarTuple) { - MetaFuncGraphPtr tensorSlicePtr = std::make_shared("tensor_slice"); - FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2); - - AbstractBasePtrList eles; - AbstractScalarPtr elem_index = std::make_shared(1); - eles.push_back(elem_index); - elem_index = std::make_shared(3); - eles.push_back(elem_index); - - AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8}); - AbstractTuplePtr slice_tuple = std::make_shared(eles); - AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; - - AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); - if (ret == nullptr) { - FAIL() << "Cast ret to abstract array failed."; - } - AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({8}); - ASSERT_EQ(*ret, *expect); -} - -TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) { - MetaFuncGraphPtr tensorSlicePtr = std::make_shared("tensor_slice"); - FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2); - - AbstractBasePtrList eles; - AbstractScalarPtr elem_index = std::make_shared(3); - eles.push_back(elem_index); - elem_index = std::make_shared(0); - eles.push_back(elem_index); - elem_index = std::make_shared(6); - eles.push_back(elem_index); - - AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8}); - AbstractTuplePtr slice_tuple = std::make_shared(eles); - AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; - - AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); - if (ret == nullptr) { - FAIL() << "Cast ret to abstract array failed."; - } - AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({}); - ASSERT_EQ(*ret, *expect); -} - TEST_F(TestComposite, test_UnpackCall_3args) { MetaFuncGraphPtr unPackCallPtr = std::make_shared("UnPackCall"); FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3); diff --git a/tests/ut/python/nn/optim/test_optimizer.py b/tests/ut/python/nn/optim/test_optimizer.py index 548094840e5..0a12f113651 100644 --- a/tests/ut/python/nn/optim/test_optimizer.py +++ b/tests/ut/python/nn/optim/test_optimizer.py @@ -107,5 +107,5 @@ class TestUnsupportParam(): def test_Sgd_init(self): with pytest.raises(TypeError): - paramsTensor = Tensor(np.zeros([1, 2, 3])) + paramsTensor = Parameter(Tensor(np.zeros([1, 2, 3])), "x") SGD(paramsTensor) diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index 7985b9c2324..55159f19ae0 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -25,7 +25,6 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \ pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception - class NetWorkSlicePositive(Cell): def __init__(self): super(NetWorkSlicePositive, self).__init__() @@ -528,6 +527,7 @@ def test_tensor_assign(): # 2. A[::, 1:, ...] = scalar/tensor net = TensorAssignWithTupleEllipsis() net(Ta, b) + Tc = Tensor(1, mstype.float32) with pytest.raises(ValueError): net(Ta, Tc) with pytest.raises(ValueError): diff --git a/tests/ut/python/pynative_mode/ops/test_grad.py b/tests/ut/python/pynative_mode/ops/test_grad.py index c2e14129ea6..8d880a86d9b 100644 --- a/tests/ut/python/pynative_mode/ops/test_grad.py +++ b/tests/ut/python/pynative_mode/ops/test_grad.py @@ -168,7 +168,7 @@ def test_select_grad(): sens = Tensor(np.ones_like(out.asnumpy()).astype(np.float32)) args = [cond, x, y, sens] gout = gfn(*args) - expect_cond = np.zeros_like(cond) + expect_cond = np.zeros_like(cond.asnumpy()) expect_x = np.array([[1, 0, 0], [0, 1, 1]]) expect_y = np.array([[0, 1, 1], [1, 0, 0]]) assert np.all(gout[0].asnumpy() == expect_cond)