forked from mindspore-Ecosystem/mindspore
!1920 SupportPynativeIndexing
Merge pull request !1920 from amongo/SupportPynativeIndexing
This commit is contained in:
commit f859dfecc8
@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
                      get_dataclass_attributes, get_dataclass_methods, get_obj_id,
                      get_module_namespace, get_obj_type, get_object_key,
                      get_default_input, get_parse_method_of_class, get_scope_name,
-                     is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj)
+                     is_class_member, parse_cb, resolve_symbol)
 from .serialize import *

 __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
            'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
            'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
            'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
-           'create_slice_obj', 'create_ellipsis_obj']
+           'create_slice_obj']

@@ -29,7 +29,6 @@ from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.api import _MindSporeFunction
 from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
 from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
-from ..utils import Slice, Ellipsis_

 # define return value
 RET_SUCCESS = 0
@@ -70,14 +69,9 @@ parse_expr_statement_white_list = (
     "append",
 )

-def create_ellipsis_obj():
-    """Create Slice object"""
-    return Ellipsis_()
-
-
 def create_slice_obj(start, end, step):
-    """Create Slice object"""
-    return Slice(start, end, step)
+    """Create slice object"""
+    return slice(start, end, step)


 def parse_cb(func, parse_method=None):

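With this change the parser hands indexing objects around as Python's builtin slice (and bare `...`) instead of the custom `Slice`/`Ellipsis_` dataclasses removed below. A minimal illustration, not part of the patch, of what the builtin type already provides:

# The builtin slice carries the same three fields the old Slice dataclass
# did, plus the indices() helper for clamping to a dimension length.
s = slice(1, 10, 2)
print(s.start, s.stop, s.step)   # 1 10 2
print(s.indices(8))              # (1, 8, 2), clamped to a length-8 axis
print(list(range(10))[s])        # [1, 3, 5, 7, 9]
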
@@ -19,7 +19,6 @@ import logging
 import os
 import inspect
 from functools import wraps
-from dataclasses import dataclass


 def cal_sha256(file_path):
@@ -100,20 +99,3 @@ def cell_attr_register(fn=None, attrs=None):
     if fn is not None:
         return wrap_cell(fn)
     return wrap_cell
-
-
-@dataclass
-class Slice:
-    """
-    Slice class
-    """
-    start: int
-    end: int
-    step: int
-
-
-@dataclass
-class Ellipsis_:
-    """
-    Ellipsis class
-    """

@@ -932,206 +932,6 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
   return ret;
 }

-int ConvertBinaryToDecimal(const std::vector<unsigned int> &number_bin) {
-  unsigned int number_dec = 0;
-  for (size_t index = 0; index < number_bin.size(); index++) {
-    number_dec |= number_bin[index] << index;
-  }
-  return static_cast<int>(number_dec);
-}
-
-void ParseSlice(const AbstractSlicePtr &slice, std::vector<int> *begin, std::vector<int> *end,
-                std::vector<int> *strides, int length) {
-  MS_EXCEPTION_IF_NULL(slice);
-  MS_EXCEPTION_IF_NULL(begin);
-  MS_EXCEPTION_IF_NULL(end);
-  MS_EXCEPTION_IF_NULL(strides);
-  if (length <= 0) {
-    MS_LOG(EXCEPTION) << "Could not slice a dim when it's length less than 1";
-  }
-
-  int start_default = 0;
-  int stop_default = length;
-  int step_default = 1;
-  int step_value = CheckSliceMember(slice->step(), step_default, "step");
-  if (step_value < 0) {
-    start_default = -1;
-    stop_default = -(length + 1);
-  }
-
-  begin->push_back(CheckSliceMember(slice->start(), start_default, "begin"));
-  end->push_back(CheckSliceMember(slice->stop(), stop_default, "stop"));
-  strides->push_back(step_value);
-}
-
-int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector<int> &shape,
-                                            std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
-  MS_EXCEPTION_IF_NULL(slice_tuple);
-  MS_EXCEPTION_IF_NULL(begin);
-  MS_EXCEPTION_IF_NULL(end);
-  MS_EXCEPTION_IF_NULL(strides);
-
-  size_t slice_tuple_size = slice_tuple->size();
-  size_t shape_size = shape.size();
-  if (slice_tuple_size > shape_size) {
-    MS_LOG(EXCEPTION) << "The number of slice data to slice tensor should be less than the rank of tensor,"
-                         "when the rank of tensor is "
-                      << shape_size << ", the number of slice is " << slice_tuple_size;
-  }
-
-  std::vector<unsigned int> shrink;
-  auto slice_tuple_eles = slice_tuple->elements();
-  size_t ellipsis_num = 0;
-
-  for (size_t index = 0; index < slice_tuple_size; index++) {
-    if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
-      AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
-      ParseSlice(slice, begin, end, strides, shape[index]);
-      shrink.push_back(0);
-      continue;
-    }
-
-    if (slice_tuple_eles[index]->isa<AbstractScalar>()) {
-      int ele_index = GetArgScalarValue(dyn_cast<AbstractScalar>(slice_tuple_eles[index]), "slice_tuple");
-      begin->push_back(ele_index);
-      end->push_back(ele_index + 1);
-      strides->push_back(1);
-      shrink.push_back(1);
-      continue;
-    }
-
-    if (slice_tuple_eles[index]->isa<AbstractEllipsis>()) {
-      ellipsis_num++;
-      if (ellipsis_num > 1) {
-        MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis";
-      }
-      size_t ellipsis_len = shape_size - (slice_tuple_size - 1);
-      begin->insert(begin->end(), ellipsis_len, 0);
-      end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len);
-      strides->insert(strides->end(), ellipsis_len, 1);
-      shrink.insert(shrink.end(), ellipsis_len, 0);
-      continue;
-    }
-
-    MS_LOG(EXCEPTION) << "Slice tuple only could contain slice, int number or ellipsis, but got "
-                      << slice_tuple_eles[index]->ToString();
-  }
-
-  if (ellipsis_num == 0) {
-    for (size_t index = slice_tuple_size; index < shape_size; index++) {
-      begin->push_back(0);
-      end->push_back(shape[index]);
-      strides->push_back(1);
-    }
-  }
-  return ConvertBinaryToDecimal(shrink);
-}
-
-int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector<int> &shape,
-                                            std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
-  MS_EXCEPTION_IF_NULL(begin);
-  MS_EXCEPTION_IF_NULL(end);
-  MS_EXCEPTION_IF_NULL(strides);
-  size_t shape_size = shape.size();
-  if (shape_size == 0) {
-    MS_LOG(EXCEPTION) << "Could slice a scalar tensor";
-  }
-
-  ParseSlice(slice, begin, end, strides, shape[0]);
-
-  for (size_t index = 1; index < shape_size; index++) {
-    begin->push_back(0);
-    end->push_back(shape[index]);
-    strides->push_back(1);
-  }
-
-  return 0;
-}
-
-int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector<int> &shape,
-                                             std::vector<int> *begin, std::vector<int> *end,
-                                             std::vector<int> *strides) {
-  MS_EXCEPTION_IF_NULL(begin);
-  MS_EXCEPTION_IF_NULL(end);
-  MS_EXCEPTION_IF_NULL(strides);
-  int ele_index = GetArgScalarValue(scalar, "slice_tuple");
-
-  begin->push_back(ele_index);
-  end->push_back(ele_index + 1);
-  strides->push_back(1);
-
-  for (size_t index = 1; index < shape.size(); index++) {
-    begin->push_back(0);
-    end->push_back(shape[index]);
-    strides->push_back(1);
-  }
-
-  return 1;
-}
-
-FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) {
-  auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
-  ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
-  return ret_graph;
-}
-
-FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
-  // slice a tensor
-  // args: tensor, slice or slice tuple
-  const std::string op_name = std::string("TensorSlice");
-  abstract::CheckArgsSize(op_name, args_spec_list, 2);
-  AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-
-  FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
-  ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
-  AnfNodePtr tensor_node = ret_graph->add_parameter();
-  (void)ret_graph->add_parameter();
-
-  auto shape = tensorPtr->shape()->shape();
-  std::vector<int> begin;
-  std::vector<int> end;
-  std::vector<int> strides;
-  int shrink_axis_mask;
-
-  if (args_spec_list[1]->isa<AbstractTuple>()) {
-    AbstractTuplePtr tuple_ptr = dyn_cast<AbstractTuple>(args_spec_list[1]);
-    shrink_axis_mask = GenerateStridedSliceParametersFromTuple(tuple_ptr, shape, &begin, &end, &strides);
-  } else if (args_spec_list[1]->isa<AbstractSlice>()) {
-    AbstractSlicePtr slice_ptr = dyn_cast<AbstractSlice>(args_spec_list[1]);
-    shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides);
-  } else if (args_spec_list[1]->isa<AbstractScalar>()) {
-    AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]);
-    if (scalar_ptr->BuildValue()->isa<BoolImm>()) {
-      if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
-        return ExpandADim(ret_graph, tensor_node);
-      }
-      MS_LOG(EXCEPTION) << "TensorSlice not support the index is False.";
-    }
-    shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
-  } else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
-    ret_graph->set_output(tensor_node);
-    return ret_graph;
-  } else if (args_spec_list[1]->isa<AbstractNone>()) {
-    return ExpandADim(ret_graph, tensor_node);
-  } else {
-    std::ostringstream args_info;
-    for (const auto &arg : args_spec_list) {
-      MS_EXCEPTION_IF_NULL(arg);
-      args_info << arg->ToString() << "\n";
-    }
-    MS_LOG(EXCEPTION)
-      << "TensorSlice requires the input should be one of [slice, ellipsis, int number, bool, none, tuple] , but got "
-      << args_info.str();
-  }
-
-  auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations");
-  auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0),
-                                               NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)});
-  ret_graph->set_output(ret_graph->NewCNode(
-    {PrimStridedSlice, tensor_node, NewValueNode(begin), NewValueNode(end), NewValueNode(strides)}));
-  return ret_graph;
-}
-
 FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
   // select indexed item
   // args: tuple of items, index
@@ -1162,11 +962,6 @@ REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
                            .def(py::init<std::string &>());
                        }));

-REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
-                         (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
-                           .def(py::init<std::string &>());
-                       }));
-
 REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
                          (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
                            *m, "TupleGetItemTensor_")

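For orientation (illustration only, not part of the patch): both the removed C++ path above and the new Python path later in this commit reduce an index to StridedSlice parameters (begin, end, strides) plus a shrink_axis_mask whose bit i marks dimension i for removal. A worked sketch in plain NumPy terms:

import numpy as np

# Index x[2] on a (6, 8, 10) tensor becomes a strided slice that keeps
# rows [2, 3) on axis 0 and everything on the remaining axes.
x = np.arange(6 * 8 * 10).reshape(6, 8, 10)
begin, end, strides = (2, 0, 0), (3, 8, 10), (1, 1, 1)
sliced = x[2:3, 0:8, 0:10]          # shape (1, 8, 10)
shrink_axis_mask = 0b001            # bit 0 set: drop axis 0 afterwards
squeezed = sliced.squeeze(axis=0)   # shape (8, 10), same as x[2]
assert (squeezed == x[2]).all()
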
@@ -175,16 +175,6 @@ class TupleSlice : public MetaFuncGraph {
 };
 using TupleSlicePtr = std::shared_ptr<TupleSlice>;

-class TensorSlice : public MetaFuncGraph {
- public:
-  explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {}
-  ~TensorSlice() override = default;
-  MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
-  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
-  friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
-};
-using TensorSlicePtr = std::shared_ptr<TensorSlice>;
-
 class TupleGetItemTensor : public MetaFuncGraph {
  public:
   explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}

@@ -209,6 +209,28 @@ bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
   return true;
 }

+bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
+  MS_LOG(DEBUG) << "Converting slice object";
+
+  py::slice slice_obj = obj.cast<py::slice>();
+  auto convert_func = [obj](std::string attr) -> ValuePtr {
+    auto py_attr = py::getattr(obj, attr.c_str());
+    if (py::isinstance<py::none>(py_attr)) {
+      return kNone;
+    } else if (py::isinstance<py::int_>(py_attr)) {
+      int value = py::cast<int>(py_attr);
+      return MakeValue(value);
+    } else {
+      MS_LOG(EXCEPTION) << "Slice should contain only int or none";
+    }
+  };
+  ValuePtr start = convert_func("start");
+  ValuePtr stop = convert_func("stop");
+  ValuePtr step = convert_func("step");
+  *data = std::make_shared<ValueSlice>(start, stop, step);
+  return true;
+}
+
 FuncGraphPtr ConvertToBpropCut(py::object obj) {
   std::vector<std::string> results = data_converter::GetObjKey(obj);
   std::string obj_key = results[0];
@@ -321,6 +343,10 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
     converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
   } else if (py::isinstance<py::dict>(obj)) {
     ret = ConvertDict(obj, &converted, use_signature);
+  } else if (py::isinstance<py::slice>(obj)) {
+    ret = ConvertSlice(obj, &converted);
+  } else if (py::isinstance<py::ellipsis>(obj)) {
+    converted = kEllipsis;
   } else if (py::isinstance<py::tuple>(obj)) {
     ret = ConvertTuple(obj, &converted, use_signature);
   } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {

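ConvertSlice above relies on the fact that a Python slice stores whatever appeared between the colons, with omitted parts kept as None. A small illustration (not part of the patch):

s = slice(1, None, 2)            # what the converter sees for x[1::2]
print(s.start, s.stop, s.step)   # 1 None 2
# The C++ side maps None -> kNone and int -> an integer value,
# then wraps the three fields in a ValueSlice.
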
@@ -353,11 +353,9 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     auto value = abs_base->cast<AbstractRefPtr>()->ref();
     dic = ConvertAbstractToPython(value);
   } else if (abs_base->isa<AbstractEllipsis>()) {
-    auto arg_slice = dyn_cast<AbstractEllipsis>(abs_base);
-    std::vector<int> shape;
-    dic["shape"] = shape;
-    dic["dtype"] = arg_slice->BuildType();
-    dic["value"] = BuildValue(arg_slice->BuildValue());
+    dic["shape"] = py::none();
+    dic["dtype"] = py::ellipsis();
+    dic["value"] = py::ellipsis();
   } else if (abs_base->isa<AbstractTuple>()) {
     auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
     size_t len = arg_tuple->size();

@@ -106,7 +106,7 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
     }
     ret = rets;
   } else if (value->isa<EllipsisObj>()) {
-    ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_ELLIPSIS);
+    ret = py::ellipsis();
   } else if (value->isa<ValueSlice>()) {
     auto slice = value->cast<ValueSlicePtr>();
     auto start = ValuePtrToPyData(slice->start());

@@ -206,6 +206,9 @@ class Parameter:
             res.default_input = res.default_input / other
         return res

+    def __setitem__(self, index, value):
+        return self
+
     def set_parameter_data(self, data):
         """Set `default_input` of current `Parameter`."""
         if isinstance(data, bool):

@@ -144,6 +144,13 @@ class Tensor(Tensor_):
         out = tensor_operator_registry.get('__le__')(self, other)
         return out

+    def __getitem__(self, index):
+        out = tensor_operator_registry.get('__getitem__')(self, index)
+        return out
+
+    def __setitem__(self, index, value):
+        return self
+
     def __gt__(self, other):
         out = tensor_operator_registry.get('__gt__')(self, other)
         return out

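Tensor.__getitem__ does not implement indexing itself; it looks the handler up in tensor_operator_registry, which the composite layer fills in at import time (see the registration in the compile-utils hunk further down). A minimal, hypothetical sketch of that indirection, using a plain dict-backed registry rather than the real class:

# Hypothetical stand-in for tensor_operator_registry.
class Registry:
    def __init__(self):
        self._ops = {}
    def register(self, name, fn):
        self._ops[name] = fn
    def get(self, name):
        return self._ops[name]

registry = Registry()
registry.register('__getitem__', lambda data, index: ('dispatched', index))

class MiniTensor:
    def __getitem__(self, index):
        # Tensor delegates to whatever the composite layer registered.
        return registry.get('__getitem__')(self, index)

print(MiniTensor()[3])   # ('dispatched', 3)
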
@@ -19,7 +19,7 @@
 from functools import partial

 from mindspore import context
-from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
+from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
     TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
 from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec, _wrap_func
@@ -27,7 +27,7 @@ from .. import functional as F
 from ...common.parameter import Parameter


-__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
+__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]


 def add_flags(fn, **flags):

@@ -18,7 +18,9 @@ from . import _constexpr_utils as const_utils
 from ... import functional as F
 from ... import operations as P
 from ...composite import base
+from ....common.tensor import Tensor
 from ....common import dtype as mstype
+from ....common._register_for_tensor import tensor_operator_registry

 hyper_map = base.HyperMap()
 pack = P.Pack(axis=-1)
@@ -152,3 +154,101 @@ def generate_updates_from_tensor(data, index, value, op_type):
     if need_broadcast:
         return broadcast(updates_shape, value)
     return value
+
+
+def tensor_getitem(self, index):
+    """Handle tensor getitem"""
+    if isinstance(index, Tensor):
+        return tensor_index_by_tensor(self, index)
+    if isinstance(index, tuple):
+        return tensor_index_by_tuple(self, index)
+    if isinstance(index, int):
+        return tensor_index_by_number(self, index)
+    if isinstance(index, slice):
+        return tensor_index_by_slice(self, index)
+    if isinstance(index, bool):
+        return tensor_index_by_bool(self, index)
+    if index is ...:
+        return self
+    raise IndexError("Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, "
+                     "got {} with type {}".format(index, type(index)))
+
+
+tensor_operator_registry.register("__getitem__", tensor_getitem)
+
+
+def tensor_getitem_by_tuple_of_tensor(data, tuple_index):
+    """Tensor getitem by a tuple of tensor."""
+    indices = generate_indices_from_tuple_of_tensor(data,
+                                                    tuple_index,
+                                                    const_utils.TENSOR_GETITEM)
+    result = F.gather_nd(data, indices)
+    return result
+
+
+def tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
+    """Tensor getitem by a tuple of mixed tensor."""
+    indices = generate_indices_from_tuple_of_mixed_tensors(data,
+                                                           tuple_index,
+                                                           const_utils.TENSOR_GETITEM)
+    result = F.gather_nd(data, indices)
+    return result
+
+
+def tensor_index_by_slice(data, slice_index):
+    """Tensor getitem by a single slice"""
+    begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(F.shape(data), slice_index)
+    return F.strided_slice(data, begin_strides, end_strides, step_strides)
+
+
+def tensor_index_by_integer(data, number):
+    """Tensor getitem by a single integer number"""
+    begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(F.shape(data), number)
+    shrink_axis_mask = 1
+    return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
+
+
+def tensor_index_by_bool(data, bool_value):
+    """Tensor getitem by a single bool value"""
+    if bool_value:
+        return F.expand_dims(data, 0)
+    return const_utils.raise_index_error("bool value as index, False is not supported")
+
+
+def tensor_index_by_number(data, number):
+    """Tensor getitem by a Number which may be integer/float/bool value"""
+    number_type = const_utils.check_number_index_type(number)
+    if number_type == const_utils.BOOL_:
+        return tensor_index_by_bool(data, number)
+    if number_type == const_utils.INT_:
+        return tensor_index_by_integer(data, number)
+    return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool")
+
+
+def tensor_index_by_tensor(data, tensor_index):
+    """Tensor getitem by a single tensor"""
+    dtype_valid = const_utils.check_index_tensor_dtype(F.dtype(tensor_index),
+                                                       const_utils.TENSOR_GETITEM)
+    if dtype_valid:
+        return F.gather(data, tensor_index, 0)
+    return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool")
+
+
+def tensor_index_by_tuple_slice(data, t):
+    """Tensor getitem by a tuple of slice"""
+    begin_strides, end_strides, step_strides, shrink_axis_mask = \
+        const_utils.get_stride_info_from_tuple(F.shape(data), t)
+    return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
+
+
+def tensor_index_by_tuple(data, tuple_index):
+    """Tensor getitem by tuple of various types"""
+    indexes_types = hyper_map(F.typeof, tuple_index)
+    index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
+    if index_elements_type == const_utils.NO_TENSOR:
+        return tensor_index_by_tuple_slice(data, tuple_index)
+    if index_elements_type == const_utils.ALL_TENSOR:
+        return tensor_getitem_by_tuple_of_tensor(data, tuple_index)
+    return tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)

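Usage-wise, the tensor_getitem dispatcher above is what makes PyNative-mode tensors accept the same index forms the graph-mode composites handle. A hedged usage sketch (shapes follow NumPy slicing rules; device/mode setup mirrors the test file added at the end of this commit):

import numpy as np
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.arange(6 * 8 * 10).reshape(6, 8, 10).astype(np.int32))

y0 = x[2]              # int index: tensor_index_by_number, shape (8, 10)
y1 = x[1:5:2]          # slice: tensor_index_by_slice, shape (2, 8, 10)
y2 = x[...]            # ellipsis: returns x unchanged
y3 = x[True]           # bool True: expand_dims, shape (1, 6, 8, 10)
y4 = x[0:2, 1, ::2]    # tuple of slices/ints: strided slice plus shrink
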
@@ -20,7 +20,6 @@ import numpy as np

 from ...primitive import constexpr
 from .... import log as logger
-from ...._extends.utils import Slice, Ellipsis_
 from ....common import dtype as mstype
 from ....common.tensor import Tensor
 from ....ops import _utils as op_utils
@@ -41,6 +40,11 @@ SET_ITEM_BY_ONE_TENSOR = 0
 SET_ITEM_BY_TUPLE_OF_TENSOR = 1


+@constexpr
+def raise_index_error(msg):
+    raise IndexError(msg)
+
+
 @constexpr
 def check_equal(param1, param2, msg="{},{}"):
     """Checks whether the two parameters are equal or not."""
@@ -54,7 +58,8 @@ def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
     """Checks the shape and size of the tensor and value."""
     if data_shape == value_shape or data_size == value_size or value_size == 1:
         return True
-    raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape))
+    raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(
+        value_shape, data_shape))


 @constexpr
@@ -63,16 +68,18 @@ def check_tensor_setitem_index(index, element_type=None):
     if index is None:
         raise IndexError("Tensor's index cannot be None.")
     # eg. Tensor[Slice] = u
-    if isinstance(index, Slice):
+    if isinstance(index, slice):
         return True
     # eg. Tensor[tuple] = u
     if isinstance(index, tuple):
         if not index:
             raise IndexError("Tensor's index cannot be empty.")
-        # eg. Tensor[tuple(Slice...)] = u
-        if isinstance(index[0], (Slice, Ellipsis_, int)):
-            return True
-        raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
+        # eg. Tensor[tuple(Slice,...)] = u
+        for item in index:
+            if not isinstance(item, (slice, type(...), int)):
+                raise IndexError(
+                    "Index of type '{}' is not supported yet.".format(type(item)))
+        return True
     # eg. Tensor[Tensor[dtype=bool]] = u
     if isinstance(index, mstype.tensor_type):
         if element_type is None or element_type != mstype.bool_:
@@ -81,7 +88,8 @@ def check_tensor_setitem_index(index, element_type=None):
                 "{} type is not supported yet.".format(element_type))
         return True

-    raise IndexError("Index of type '{}' is not supported yet.".format(type(index)))
+    raise IndexError(
+        "Index of type '{}' is not supported yet.".format(type(index)))


 @constexpr
@@ -116,12 +124,12 @@ def slice_expand(input_slices, shape):
     index = 0
     slices = None
     # Slice or tuple(Slice...)
-    if isinstance(input_slices, Slice):
+    if isinstance(input_slices, slice):
         slices = (input_slices,)
-    elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)):
+    elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (slice, type(...))):
         is_have_ellipsis = False
         for _, element in enumerate(input_slices):
-            if isinstance(element, Ellipsis_):
+            if isinstance(element, type(...)):
                 is_have_ellipsis = True
                 break
         if is_have_ellipsis:
@@ -130,10 +138,9 @@ def slice_expand(input_slices, shape):
         slices = input_slices
     else:
         raise IndexError("Tensor's index type is not supported yet.")
-
     for s in slices:
         start = 0 if (s.start is None) else s.start
-        stop = shape[index] if (s.end is None) else s.end
+        stop = shape[index] if (s.stop is None) else s.stop
         step = 1 if (s.step is None) else s.step
         begin.append(start)
         end.append(stop)
@@ -151,11 +158,11 @@ def ellipsis2slice(input_, shape):
     """Converts ellipsis to slice."""
     input_slice = input_
     result = []
-    if isinstance(input_, Ellipsis_):
+    if isinstance(input_, type(...)):
        input_slice = (input_,)
     ell_count = 0
     for _, element in enumerate(input_slice):
-        if not isinstance(element, Ellipsis_):
+        if not isinstance(element, type(...)):
             result.append(element)
             continue
         ell_count += 1
@@ -163,7 +170,7 @@ def ellipsis2slice(input_, shape):
         raise IndexError("There cannot be more than one ellipsis (...) in the index of the tensor, "
                          "but it is currently {}".format(input_slice))
     for _ in range(len(shape) - len(input_slice) + 1):
-        result.append(Slice(None, None, None))
+        result.append(slice(None, None, None))
     return tuple(result)

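ellipsis2slice pads the index so that `...` stands for all remaining dimensions. A worked example (illustration only, not part of the patch) of the equivalent expansion:

# For a rank-4 shape, the index (0, ...) must expand so that the ellipsis
# covers the trailing axes: len(shape) - len(index) + 1 = 3 full slices.
shape = (6, 7, 8, 9)
index = (0, ...)
expanded = []
for element in index:
    if not isinstance(element, type(...)):
        expanded.append(element)
    else:
        expanded.extend([slice(None, None, None)] * (len(shape) - len(index) + 1))
print(tuple(expanded))   # (0, slice(None, None, None), slice(None, None, None), slice(None, None, None))
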
@@ -196,7 +203,8 @@ def slice2indices(input_slices, shape):
 def check_indices(indices_size, index):
     """Checks whether indices is empty."""
     if indices_size < 1:
-        raise IndexError("The tensor's index is unreasonable. index:{}".format(index))
+        raise IndexError(
+            "The tensor's index is unreasonable. index:{}".format(index))
     return indices_size


@@ -230,7 +238,7 @@ def tuple_element_is_slice(indexs):
         raise IndexError("Tensor's index cannot be empty.")
     if isinstance(indexs, tuple):
         for _, ele in enumerate(indexs):
-            if not isinstance(ele, Slice):
+            if not isinstance(ele, slice):
                 return False
         return True
     return False
@@ -285,7 +293,8 @@ def check_value_elements(data_dtype, types):
         return ALL_TENSOR
     if scalars_number == len(types):
         return ALL_SCALAR
-    raise TypeError(f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.")
+    raise TypeError(
+        f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.")


 @constexpr
@@ -295,7 +304,8 @@ def get_index_tensor_dtype(dtype):
         return INT_
     if dtype == mstype.bool_:
         return BOOL_
-    raise IndexError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
+    raise IndexError(
+        f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")


 @constexpr
@@ -313,7 +323,8 @@ def check_index_tensor_dtype(dtype, op_name):
     """Check a tensor data type."""
     if dtype == mstype.int32:
         return True
-    raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.")
+    raise IndexError(
+        f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.")


 @constexpr
@@ -332,7 +343,8 @@ def generate_broadcast_shape(shapes, op_name):
     for i, shape in enumerate(shapes):
         logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
         try:
-            broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
+            broadcast_shape = op_utils.get_broadcast_shape(
+                broadcast_shape, shape, op_name)
         except ValueError as ex:
             raise IndexError(ex)
     return tuple(broadcast_shape)
@@ -398,7 +410,8 @@ def convert_ellipsis_to_tensors(slice_number,
         if isinstance(ele, tuple):
             shape.extend([1] * len(ele))
     if array is None:
-        raise ValueError(f"For '{op_name}', generate tensors from ellipsis failed.")
+        raise ValueError(
+            f"For '{op_name}', generate tensors from ellipsis failed.")
     array = np.reshape(array, shape)
     reps = compute_multiples(shape, final_shape)
     tensor = Tensor(np.tile(array, reps))
@@ -428,7 +441,8 @@ def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_n
     else:
         shape.append(1)
     if array is None:
-        raise ValueError(f"For '{op_name}', generate tensor from 'slice' failed.")
+        raise ValueError(
+            f"For '{op_name}', generate tensor from 'slice' failed.")
     array = np.reshape(array, shape)
     reps = compute_multiples(shape, final_shape)
     tensor = Tensor(np.tile(array, reps))
@@ -523,14 +537,15 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
             tensor_count += 1
         elif isinstance(ele_type, mstype.slice_type):
             slice_obj = slice(slice_indexes[slice_count].start,
-                              slice_indexes[slice_count].end,
+                              slice_indexes[slice_count].stop,
                               slice_indexes[slice_count].step)
             # Use list to represent slicing result.
             indexes_info[pos] = list(range(data_shape[pos]))[slice_obj]
             slice_count += 1
         elif isinstance(ele_type, mstype.ellipsis_type):
             if ellipsis_num != 0:
-                raise IndexError(f"For '{op_name}', the index could only contain one ellipsis.")
+                raise IndexError(
+                    f"For '{op_name}', the index could only contain one ellipsis.")
             ellipsis_occupied_dims = data_rank - indexes_size + 1
             for j in range(pos, pos + ellipsis_occupied_dims):
                 # Use list to represent slicing result.
@@ -540,7 +555,8 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
             raise IndexError(f"For '{op_name}', the index elements only support "
                              f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.")
     broadcast_shape, final_shape, indexes_shapes_info = \
-        _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name)
+        _derive_result_shape_info_from_tuple_of_mixed_tensors(
+            indexes_info, index_tensors_info, op_name)
     return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims


@@ -556,10 +572,12 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te
     """Derive the resulting shape information from a tuple index of mixed tensors."""
     index_tensor_info_key = list(index_tensors_info.keys())
     index_tensor_info_value = list(index_tensors_info.values())
-    broadcast_shape = generate_broadcast_shape(index_tensor_info_value, op_name)
+    broadcast_shape = generate_broadcast_shape(
+        index_tensor_info_value, op_name)
     final_shape = []
     indexes_shapes_info = []
-    mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key)
+    mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(
+        index_tensor_info_key)
     if mixed_tensors_continuous:
         tensor_shape_dealt = False
         for ele in indexes_info.values():

@@ -638,3 +656,98 @@ def get_np_eps(input_dtype):
     nptype = mstype.dtype_to_nptype(input_dtype)
     eps = np.finfo(nptype).eps
     return float(eps)
+
+
+@constexpr
+def check_number_index_type(number):
+    """Check if it is int or bool number"""
+    if isinstance(number, bool):
+        return BOOL_
+    if isinstance(number, int):
+        return INT_
+    raise IndexError("Only support integers, slices(`:`), ellipsis(`...`), None and bool, got {0} of type {1}"
+                     .format(number, type(number)))
+
+
+@constexpr
+def get_stride_info_from_slice(data_shape, slice_index):
+    """Get stride info from a python slice"""
+    begin, end, step = get_slice_stride(data_shape[0], slice_index)
+    begin_strides = [begin]
+    end_strides = [end]
+    step_strides = [step]
+    for end in data_shape[1:]:
+        begin_strides.append(0)
+        end_strides.append(end)
+        step_strides.append(1)
+    return tuple(begin_strides), tuple(end_strides), tuple(step_strides)
+
+
+@constexpr
+def get_stride_info_from_integer(data_shape, number):
+    """Get stride info from an integer"""
+    begin_strides = [number]
+    end_strides = [number + 1]
+    step_strides = [1]
+    for end in data_shape[1:]:
+        begin_strides.append(0)
+        end_strides.append(end)
+        step_strides.append(1)
+    return tuple(begin_strides), tuple(end_strides), tuple(step_strides)
+
+
+def get_slice_stride(dim_size, index_slice):
+    """Get slice stride info"""
+    step = 1 if index_slice.step is None else index_slice.step
+    start_default = 0
+    stop_default = dim_size
+    if step < 0:
+        start_default = -1
+        stop_default = -(dim_size + 1)
+    start = start_default if index_slice.start is None else index_slice.start
+    stop = stop_default if index_slice.stop is None else index_slice.stop
+    return start, stop, step
+
+
+@constexpr
+def get_stride_info_from_tuple(data_shape, index_tuple):
+    """Get stride info from a tuple"""
+    begin_strides = []
+    end_strides = []
+    step_strides = []
+    index_size = len(index_tuple)
+    data_shape_size = len(data_shape)
+    shrink_axis = 0
+    index_count = 0
+    ellipsis_count = 0
+    for idx, item in enumerate(index_tuple):
+        if isinstance(item, slice):
+            start, stop, step = get_slice_stride(data_shape[idx], item)
+            begin_strides.append(start)
+            end_strides.append(stop)
+            step_strides.append(step)
+            index_count = index_count + 1
+        elif isinstance(item, int):
+            begin_strides.append(item)
+            end_strides.append(item + 1)
+            step_strides.append(1)
+            shrink_axis = shrink_axis + (1 << index_count)
+            index_count = index_count + 1
+        elif item is ...:
+            ellipsis_count = ellipsis_count + 1
+            if ellipsis_count > 1:
+                raise IndexError("An index can have only one ellipsis (...)")
+            ellipsis_range_size = data_shape_size - (index_size - 1)
+            begin_strides.extend([0] * ellipsis_range_size)
+            end_strides.extend(
+                [i for i in data_shape[index_count: index_count + ellipsis_range_size]])
+            step_strides.extend([1] * ellipsis_range_size)
+            index_count = index_count + ellipsis_range_size
+        else:
+            raise IndexError("Not supported index data type, got ",
+                             item, " type is ", type(item))
+    for item in range(index_count, data_shape_size):
+        begin_strides.append(0)
+        end_strides.append(data_shape[item])
+        step_strides.append(1)
+    return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis

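Two details above are worth a worked example (illustration only, not part of the patch): get_slice_stride flips the defaults when step is negative, and each plain integer in a tuple index sets bit index_count of shrink_axis so StridedSlice can drop that axis:

# Negative step: defaults become start=-1, stop=-(dim_size + 1), so x[::-1]
# on a length-6 axis turns into begin=-1, end=-7, stride=-1.
dim_size = 6
print((-1, -(dim_size + 1), -1))   # (-1, -7, -1)

# Shrink mask for the index (2, slice(0, 8), 5) on shape (6, 8, 10):
# positions 0 and 2 are ints, so bits 0 and 2 are set.
shrink_axis = (1 << 0) | (1 << 2)
print(shrink_axis)                 # 5, i.e. binary 101
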
@@ -15,7 +15,6 @@

 """Implementation for getitem."""
+from . import _compile_utils as compile_utils
 from . import _constexpr_utils as const_utils
 from .. import base
 from ... import functional as F
-
-
@@ -50,29 +49,6 @@ _tuple_slice = _TupleSlice('tuple_slice')
 """_tuple_slice is a metafuncgraph object which will slice a tuple."""


-class _TensorSlice(base.TensorSlice_):
-    """
-    Slices a tensor.
-
-    Inputs:
-        data (Tensor): A tensor to be sliced.
-        s (slice): The index to slice tuple data.
-
-    Outputs:
-        Tensor, consists of some elements of data.
-    """
-
-    def __init__(self, name):
-        base.TensorSlice_.__init__(self, name)
-
-    def __call__(self, *args):
-        pass
-
-
-_tensor_slice = _TensorSlice('tensor_slice')
-"""_tensor_slice is a metafuncgraph object which will slice a tensor."""
-
-
 class _TupleGetItemTensor(base.TupleGetItemTensor_):
     """
     Getting item of tuple by tensor index.
@@ -182,13 +158,13 @@ def _tensor_getitem_by_number(data, number_index):
     Outputs:
         Tensor, element type is as same as the element type of data.
     """
-    return _tensor_slice(data, number_index)
+    return compile_utils.tensor_index_by_number(data, number_index)


 @getitem.register("Tensor", "None")
 def _tensor_getitem_by_none(data, index):
     """
     Getting item of tensor by None.
+    For none indexing, expand data with one dim.

     Inputs:
         data (Tensor): A tensor.
@@ -197,7 +173,7 @@ def _tensor_getitem_by_none(data, index):
     Outputs:
         Tensor, element type is as same as the element type of data.
     """
-    return _tensor_slice(data, index)
+    return F.expand_dims(data, 0)


 @getitem.register("Tensor", "Slice")
@@ -212,13 +188,13 @@ def _tensor_getitem_by_slice(data, slice_index):
     Outputs:
         Tensor, element type is same as the element type of data.
     """
-    return _tensor_slice(data, slice_index)
+    return compile_utils.tensor_index_by_slice(data, slice_index)


 @getitem.register("Tensor", "Tensor")
 def _tensor_getitem_by_tensor(data, tensor_index):
     """
-    Getting item of tensor by slice.
+    Getting item of tensor by tensor index.

     Inputs:
         data (Tensor): A tensor.
@@ -227,18 +203,13 @@ def _tensor_getitem_by_tensor(data, tensor_index):
     Outputs:
         Tensor, element type is same as the element type of data.
     """
-    check_dtypes = const_utils.check_index_tensor_dtype(F.dtype(tensor_index),
-                                                        const_utils.TENSOR_GETITEM)
-    result = None
-    if check_dtypes:
-        result = F.gather(data, tensor_index, 0)
-    return result
+    return compile_utils.tensor_index_by_tensor(data, tensor_index)


 @getitem.register("Tensor", "Tuple")
 def _tensor_getitem_by_tuple(data, tuple_index):
     """
-    Getting item of tensor by slice tuple.
+    Getting item of tensor by tuple.

     Inputs:
         data (Tensor): A tensor.
@@ -247,13 +218,7 @@ def _tensor_getitem_by_tuple(data, tuple_index):
     Outputs:
         Tensor, element type is same as the element type of data.
     """
-    indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
-    index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
-    if index_elements_type == const_utils.NO_TENSOR:
-        return _tensor_slice(data, tuple_index)
-    if index_elements_type == const_utils.ALL_TENSOR:
-        return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
-    return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)
+    return compile_utils.tensor_index_by_tuple(data, tuple_index)


 @getitem.register("Tensor", "Ellipsis")
@@ -268,22 +233,4 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
     Outputs:
         Tensor, same as data.
     """
-    return _tensor_slice(data, ellipsis_index)
-
-
-def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
-    """Tensor getitem by a tuple of tensor."""
-    indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
-                                                                  tuple_index,
-                                                                  const_utils.TENSOR_GETITEM)
-    result = F.gather_nd(data, indices)
-    return result
-
-
-def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
-    """Tensor getitem by a tuple of mixed tensor."""
-    indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
-                                                                         tuple_index,
-                                                                         const_utils.TENSOR_GETITEM)
-    result = F.gather_nd(data, indices)
-    return result
+    return data

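The `getitem.register("Tensor", <type>)` decorators above belong to MultitypeFuncGraph dispatch: the compiler picks the overload whose type signature matches the operands, so each branch stays a thin forwarder into compile_utils. A hedged, framework-free sketch of the same dispatch idea (simplified stand-in, not the real base.MultitypeFuncGraph):

# Simplified stand-in that dispatches on the operands' type names.
class MultiDispatch:
    def __init__(self, name):
        self.name, self._impls = name, {}
    def register(self, *type_names):
        def deco(fn):
            self._impls[type_names] = fn
            return fn
        return deco
    def __call__(self, *args):
        key = tuple(type(a).__name__ for a in args)
        return self._impls[key](*args)

getitem = MultiDispatch('getitem')

@getitem.register('list', 'int')
def _list_by_int(data, idx):
    return data[idx]

@getitem.register('list', 'slice')
def _list_by_slice(data, s):
    return data[s]

print(getitem([10, 20, 30], 1))         # 20
print(getitem([10, 20, 30], slice(2)))  # [10, 20]
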
@@ -0,0 +1,741 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test_tensor_slice """
+import numpy as np
+import pytest
+
+from mindspore import Tensor, Parameter
+from mindspore import context
+from mindspore import dtype as mstype
+from mindspore.nn import Cell
+
+
+def setup_module():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+
+
+class NetWorkSlicePositive(Cell):
+    def __init__(self):
+        super(NetWorkSlicePositive, self).__init__()
+        self.tensor_ret0 = Tensor(np.ones([1, 2, 3], np.int32))
+        self.tensor_ret1 = Tensor(np.ones([4, 8, 10], np.int32))
+        self.tensor_ret2 = Tensor(np.ones([6, 8, 10], np.int32))
+        self.tensor_ret3 = Tensor(np.ones([3, 8, 10], np.int32))
+
+    def construct(self, tensor):
+        ret0 = tensor[3:4:1, 1:5:2, 3:6:1] + self.tensor_ret0
+        ret1 = tensor[-6:4:1, 0:8:1, ::1] + self.tensor_ret1
+        ret2 = tensor[::, ::, ::] + self.tensor_ret2
+        ret3 = tensor[::2] + self.tensor_ret3
+        return ret0, ret1, ret2, ret3
+
+
+def test_slice_positive():
+    net = NetWorkSlicePositive()
+    input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
+    input_0 = Tensor(input_np)
+    output0, output1, output2, output3 = net(input_0)
+    assert np.all(output0.asnumpy() == input_np[3:4:1, 1:5:2, 3:6:1] + np.ones([1, 2, 3]))
+    assert np.all(output1.asnumpy() == input_np[-6:4:1, 0:8:1, ::1] + np.ones([4, 8, 10]))
+    assert np.all(output2.asnumpy() == input_np[::, ::, ::] + np.ones([6, 8, 10]))
+    assert np.all(output3.asnumpy() == input_np[::2] + np.ones([3, 8, 10]))
+
+
+class NetWorkSliceEllipsis(Cell):
+    def __init__(self):
+        super(NetWorkSliceEllipsis, self).__init__()
+        self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32))
+        self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32))
+        self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32))
+
+    def construct(self, tensor):
+        ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
+        ret1 = tensor[...] + self.tensor_ret1
+        ret2 = tensor[None] + self.tensor_ret2
+        ret3 = tensor[True] + self.tensor_ret2
+        return ret0, ret1, ret2, ret3
+
+
+def Xtest_slice_ellipsis():
+    net = NetWorkSliceEllipsis()
+    input_np = np.arange(6*7*8*9).reshape(6, 7, 8, 9).astype(np.int32)
+    input_0 = Tensor(input_np)
+    output0, output1, output2, output3 = net(input_0)
+    assert np.all(output0.asnumpy() == input_np[0:4:2, ..., 1] + np.ones([1, 2, 3]))
+    assert np.all(output1.asnumpy() == input_np[...] + np.ones([6, 7, 8, 9]))
+    assert np.all(output2.asnumpy() == input_np[None] + np.ones([6, 7, 8, 9]))
+    assert np.all(output3.asnumpy() == input_np[True] + np.ones([1, 6, 7, 8, 9]))
+
+
+class NetWorkReduceDimension(Cell):
+    def __init__(self):
+        super(NetWorkReduceDimension, self).__init__()
+        self.tensor_ret1 = Tensor(np.ones([3, 10], np.int32))
+        self.tensor_ret2 = Tensor(np.ones([6, 8], np.int32))
+        self.tensor_ret3 = Tensor(np.array(8, np.int32))
+        self.tensor_ret4 = Tensor(np.ones([8, 10], np.int32))
+
+    def construct(self, tensor):
+        ret1 = tensor[::2, 1, ::1] + self.tensor_ret1
+        ret2 = tensor[::, ::, 0] + self.tensor_ret2
+        ret3 = tensor[3, 2, 5] + self.tensor_ret3
+        ret4 = tensor[1] + self.tensor_ret4
+        return ret1, ret2, ret3, ret4
+
+
+def Xtest_reduce_dimension():
+    net = NetWorkReduceDimension()
+    input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
+    input_0 = Tensor(input_np)
+    output1, output2, output3, output4 = net(input_0)
+    assert np.all(output1.asnumpy() == input_np[::2, 1, ::1] + np.ones([3, 10]))
+    assert np.all(output2.asnumpy() == input_np[::, ::, 0] + np.ones([6, 8]))
+    assert np.all(output3.asnumpy() == input_np[3, 2, 5] + np.array(8, np.int32))
+    assert np.all(output4.asnumpy() == input_np[1] + np.ones([8, 10]))
+
+
+class NetWorkSliceStep(Cell):
+    def __init__(self):
+        super(NetWorkSliceStep, self).__init__()
+        self.tensor_ret1 = Tensor(np.ones([6, 5, 10], np.int32))
+        self.tensor_ret2 = Tensor(np.ones([3, 5, 5], np.int32))
+
+    def construct(self, tensor):
+        ret1 = tensor[::1, -5::, ::-1] + self.tensor_ret1
+        ret2 = tensor[::2, -5::, ::2] + self.tensor_ret2
+        return ret1, ret2
+
+
+def Xtest_step_negative():
+    net = NetWorkSliceEllipsis()
+    input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
+    input_0 = Tensor(input_np)
+    output1, output2 = net(input_0)
+    assert np.all(output1.asnumpy() == input_np[::1, -5::, ::-1] + np.ones([6, 8, 10]))
+    assert np.all(output2.asnumpy() == input_np[::2, -5::, ::2] + np.ones([3, 5, 5]))
+
+
+class TensorGetItemByThreeTensors(Cell):
+    def __init__(self):
+        super(TensorGetItemByThreeTensors, self).__init__()
+        self.const0 = Tensor(np.ones((4, 5, 8, 10)), mstype.int32)
+        self.const1 = Tensor(np.ones((3, 4, 5, 10)), mstype.int32)
+        self.const2 = Tensor(np.ones((5, 3, 4, 5)), mstype.int32)
+
+    def construct(self, x, index_0, index_1, index_2):
+        ret0 = x[index_0] + self.const0
+        ret1 = x[index_0, index_1] + self.const1
+        ret2 = x[index_0, index_1, index_2] + self.const2
+        return ret0, ret1, ret2
+
+
+def Xtest_getitem_by_tensors():
+    net = TensorGetItemByThreeTensors()
+    input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
+    index_0 = np.random.randint(6, size=(3, 4, 5)).astype(np.int32)
+    index_1 = np.random.randint(6, size=(4, 5)).astype(np.int32)
+    index_2 = np.random.randint(6, size=(5, 3, 4, 5)).astype(np.int32)
+    input_x_ms = Tensor(input_x)
+    index_0_ms = Tensor(index_0)
+    index_1_ms = Tensor(index_1)
+    input_2_ms = Tensor(index_2)
+    output0, output1, output2 = net(input_x_ms, index_0_ms, index_1_ms, input_2_ms)
+    assert np.all(output0.asnumpy() == input_x[index_0] + np.ones([4, 5, 8, 10]))
+    assert np.all(output1.asnumpy() == input_x[index_0, index_1] + np.ones([3, 4, 5, 10]))
+    assert np.all(output2.asnumpy() == input_x[index_0, index_1, index_2] + np.ones([5, 3, 4, 5]))
+
+
+class TensorGetItemByMixedTensors_0(Cell):
+    def __init__(self):
+        super(TensorGetItemByMixedTensors_0, self).__init__()
+        self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32))
+
+    def construct(self, tensor, index_0, index_1):
+        ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const
+        return ret
+
+
+class TensorGetItemByMixedTensors_1(Cell):
+    def __init__(self):
+        super(TensorGetItemByMixedTensors_1, self).__init__()
+        self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32))
+
+    def construct(self, tensor, index_0, index_1):
+        ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const
+        return ret
+
+
+class TensorGetItemByMixedTensors_2(Cell):
+    def __init__(self):
+        super(TensorGetItemByMixedTensors_2, self).__init__()
+        self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32))
+
+    def construct(self, tensor, index_0, index_1):
+        ret = tensor[0, index_0, index_1, ..., 3] + self.const
+        return ret
+
+
+class TensorGetItemByMixedTensors_3(Cell):
+    def __init__(self):
+        super(TensorGetItemByMixedTensors_3, self).__init__()
+        self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32))
+
+    def construct(self, tensor, index_0, index_1):
+        ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const
+        return ret
+
+
+class TensorGetItemByMixedTensors_4(Cell):
+    def __init__(self):
+        super(TensorGetItemByMixedTensors_4, self).__init__()
+        self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32))
+
+    def construct(self, tensor, index_0, index_1, index_2):
+        ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const
+        return ret
+
+
+class TensorGetItemByMixedTensors_5(Cell):
+    def __init__(self):
+        super(TensorGetItemByMixedTensors_5, self).__init__()
+        self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32))
+
+    def construct(self, tensor, index_0, index_1, index_2):
+        ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const
+        return ret
+
+
+class TensorGetItemByMixedTensors_6(Cell):
+    def __init__(self):
+        super(TensorGetItemByMixedTensors_6, self).__init__()
+        self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))
+
+    def construct(self, tensor, index_0, index_1, index_2):
+        ret = tensor[..., index_0, index_1, index_2, 3] + self.const
+        return ret
+
+
class TensorSetItemByMixedTensors_0(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByMixedTensors_0, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32))
|
||||
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)),
|
||||
mstype.float32),
|
||||
name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index_0, index_1, index_2):
|
||||
self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByMixedTensors_1(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByMixedTensors_1, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
|
||||
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
|
||||
name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index_0, index_1, index_2):
|
||||
self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByMixedTensors_2(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByMixedTensors_2, self).__init__()
|
||||
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float16))
|
||||
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float16),
|
||||
name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index_0, index_1, index_2):
|
||||
self.param[..., index_0, index_1, index_2, 3] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensorsTypeError(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensorsTypeError, self).__init__()
|
||||
|
||||
def construct(self, x, index_0, index_1):
|
||||
ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]]
|
||||
return ret
|
||||
|
||||
|
||||
class TensorGetItemByMixedTensorsNumberError(Cell):
|
||||
def __init__(self):
|
||||
super(TensorGetItemByMixedTensorsNumberError, self).__init__()
|
||||
|
||||
def construct(self, x, index_0, index_1):
|
||||
ret = x[index_0, index_1, 0:3, ..., index_1, index_0]
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByOneTensorWithNumber(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByOneTensorWithNumber, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index):
|
||||
self.param[index] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByOneTensorWithTensor(Cell):
|
||||
def __init__(self):
|
||||
super(TensorSetItemByOneTensorWithTensor, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||
|
||||
def construct(self, index, value):
|
||||
self.param[index] = value
|
||||
ret = self.param + self.const
|
||||
return ret
|
||||
|
||||
|
||||
class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
|
||||
def __init__(self, value):
|
||||
super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
|
||||
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
|
||||
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
|
||||
self.value = value
|
||||
|
||||
def construct(self, index):
|
||||
self.param[index] = self.value
|
||||
ret = self.param + self.const
|
||||
return ret


class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
    def __init__(self):
        super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
        self.const = Tensor(np.ones((6, 3, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 3 * 8).reshape((6, 3, 8)), mstype.float32), name="x")

    def construct(self, index, value_0, value_1, value_2):
        self.param[index] = (value_0, value_1, value_2)
        ret = self.param + self.const
        return ret


class TensorSetItemByTensorsWithNumber(Cell):
    def __init__(self, value):
        super(TensorSetItemByTensorsWithNumber, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
        self.value = value

    def construct(self, index_0, index_1, index_2):
        self.param[index_0, index_1, index_2] = self.value
        ret = self.param + self.const
        return ret


class TensorSetItemByTensorsWithTensor(Cell):
    def __init__(self):
        super(TensorSetItemByTensorsWithTensor, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

    def construct(self, index_0, index_1, index_2, value):
        self.param[index_0, index_1, index_2] = value
        ret = self.param + self.const
        return ret


class TensorSetItemByTensorsWithTensorNumberError(Cell):
    def __init__(self):
        super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

    def construct(self, index_0, index_1, index_2, index_3, value):
        self.param[index_0, index_1, index_2, index_3] = value
        ret = self.param + self.const
        return ret


class TensorSetItemByTensorsWithTupleOfNumber(Cell):
    def __init__(self, value):
        super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
        self.value = value

    def construct(self, index_0, index_1, index_2):
        self.param[index_0, index_1, index_2] = self.value
        ret = self.param + self.const
        return ret


class TensorSetItemByTensorsWithTupleOfTensor(Cell):
    def __init__(self):
        super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

    def construct(self, index_0, index_1, index_2, value_0, value_1, value_2):
        self.param[index_0, index_1, index_2] = (value_0, value_1, value_2)
        ret = self.param + self.const
        return ret


class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
    def __init__(self):
        super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

    def construct(self, index_0, index_1, index_2, value_0, value_1):
        self.param[index_0, index_1, index_2] = (value_0, value_1)
        ret = self.param + self.const
        return ret


class TensorSetItemByMixedTensors(Cell):
    def __init__(self):
        super(TensorSetItemByMixedTensors, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
        self.value = 99.0

    def construct(self, index_0, index_1):
        self.param[index_0, index_1, 0:6] = self.value
        ret = self.param + self.const
        return ret


class TensorAssignWithSliceError1(Cell):
    def __init__(self):
        super(TensorAssignWithSliceError1, self).__init__()

    def construct(self, a, b):
        a[1:3:-1, ::] = b
        return a


class TensorAssignWithSliceError2(Cell):
    def __init__(self):
        super(TensorAssignWithSliceError2, self).__init__()

    def construct(self, a, b):
        a[1:3:-1] = b
        return a


class TensorAssignWithSlice2(Cell):
    def __init__(self):
        super(TensorAssignWithSlice2, self).__init__()

    def construct(self, a, b, ck):
        a[1:5] = b
        a[3:4] = 5
        a[-1:1:-1] = b
        a[-1:3:-1] = 5
        a[::] = b
        a[::] = 9
        z = a + ck
        return z


class TensorAssignWithSlice(Cell):
    def __init__(self):
        super(TensorAssignWithSlice, self).__init__()
        self.c = 2

    def construct(self, a, b, ck):
        a[1:3, ::] = b
        a[2:3:, 3:] = b
        a[::] = b
        a[::] = self.c
        a[::, ::] = b
        a[::, ::] = self.c
        a[2:3:, 0:, 4:1:-1] = b
        a[2:3:, 0:, 4:1:-1] = self.c
        z = a + ck
        return z
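

# A minimal NumPy sketch of the slice-assignment semantics the two cells above
# exercise, assuming MindSpore follows NumPy broadcasting rules here; the shapes
# below are illustrative stand-ins, not taken from the tests.
def _np_slice_assign_sketch():
    import numpy as np
    a = np.arange(60, dtype=np.float32).reshape(3, 4, 5)
    a[1:3, ::] = np.float32(1.0)   # a scalar broadcasts over the selected block
    a[2:3:, 0:, 4:1:-1] = 2.0      # a negative step walks indices 4, 3, 2
    return a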


def test_tensor_assign():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    net = TensorAssignWithSlice()
    net2 = TensorAssignWithSlice2()
    net_e1 = TensorAssignWithSliceError1()
    net_e2 = TensorAssignWithSliceError2()
    a = np.arange(60).reshape(3, 4, 5)
    ck = np.arange(60).reshape(3, 4, 5)
    b = Tensor([1], dtype=mstype.float32)
    Ta = Tensor(a, dtype=mstype.float32)
    Tck = Tensor(ck, dtype=mstype.float32)
    Ta4d = Tensor(a.reshape(1, 3, 4, 5), dtype=mstype.float32)
    Ta4d_ck = Tensor(ck.reshape(1, 3, 4, 5), dtype=mstype.float32)
    Tb = Tensor([1, 3], dtype=mstype.float32)
    Tc = Tensor([], dtype=mstype.float32)
    t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
    tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
    net(Ta, b, Tck)
    net2(t, b, tck)
    # Error for A[Slice] = Number
    # 1. A[Slice] = Number, Slice error
    with pytest.raises(IndexError):
        net_e2(t, 2)
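
    # Under Python slicing rules t[1:3:-1] selects nothing (with a negative step
    # the start must lie after the stop), which is rejected here as a slice error.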

    # Error for A[Slice] = U, U is a Tensor
    # 1. A[Slice] = U, u.size mismatch
    with pytest.raises(ValueError):
        net2(t, Tb, tck)
    # 2. A[Slice] = U, U is empty
    with pytest.raises(ValueError):
        net2(t, Tc, tck)
    # 3. A[Slice] = U, U.size mismatch
    with pytest.raises(ValueError):
        net2(t, Tb, tck)

    # Error for A[Tuple(Slice...)] = Tensor
    # 1. A[Tuple(Slice...)] = U, U is empty
    with pytest.raises(ValueError):
        net(Ta, Tc, Tck)
    # 2. A[Tuple(Slice...)] = U, U.size mismatch
    with pytest.raises(ValueError):
        net(Ta, Tb, Tck)
    # 3. A[Tuple(Slice...)] = U, Slice error
    with pytest.raises(IndexError):
        net_e1(Ta, b)

    # Error for A[Tuple(Slice...)] = Number
    # 1. A[Tuple(Slice...)] = Number, Slice error
    with pytest.raises(IndexError):
        net_e1(Ta, 2)

    net = TensorAssignWithInteger()
    # Error for A[Number] = scalar/Tensor
    # 1. A[Number] = U, U is a Tensor, u.size does not match
    with pytest.raises(ValueError):
        net(Ta, Tb, Tck)
    with pytest.raises(ValueError):
        net(Ta, Tc, Tck)
    # 2. A[Number] = U, the number index is out of range
    with pytest.raises(IndexError):
        net(Ta4d, b, Ta4d_ck)

    # Error for A[(n,m)] = scalar/Tensor
    # 1. A[(n,m)] = U, U is a tensor, u.size does not match
    net = TensorAssignWithTupleInteger()
    with pytest.raises(ValueError):
        net(Ta, Tc, Tck)
    with pytest.raises(ValueError):
        net(Ta, Tb, Tck)
    # 2. A[(n,m)] = U, the number index is out of range
    with pytest.raises(IndexError):
        net(Ta4d, b, Ta4d_ck)

    # Error for A[...] = U or A[1:, ...] = u
    # 1. A[...] = scalar/tensor
    net = TensorAssignWithEllipsis()
    net(Ta, Ta4d)
    with pytest.raises(ValueError):
        net(Ta, Tc)
    with pytest.raises(ValueError):
        net(Ta, Tb)
    # 2. A[::, 1:, ...] = scalar/tensor
    net = TensorAssignWithTupleEllipsis()
    net(Ta, b)
    Tc = Tensor(1, mstype.float32)
    with pytest.raises(ValueError):
        net(Ta, Tc)
    with pytest.raises(ValueError):
        net(Ta, Tb)


class TensorAssignWithTupleEllipsis2(Cell):
    def __init__(self):
        super(TensorAssignWithTupleEllipsis2, self).__init__()

    def construct(self, a, b):
        a[1:, ..., ::] = b
        return a


class TensorAssignWithTupleEllipsis(Cell):
    def __init__(self):
        super(TensorAssignWithTupleEllipsis, self).__init__()

    def construct(self, a, b):
        a[:2, ...] = 1
        a[1:, ...] = b
        return a


class TensorAssignWithEllipsis(Cell):
    def __init__(self):
        super(TensorAssignWithEllipsis, self).__init__()

    def construct(self, a, b):
        a[...] = 1
        a[...] = b
        return a
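

# `a[...] = v` addresses the whole tensor: the ellipsis expands to as many full
# slices as the operand has dimensions, so for a 3-D tensor `a[...]` is the same
# selection as `a[:, :, :]`.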


class TensorAssignWithInteger(Cell):
    def __init__(self):
        super(TensorAssignWithInteger, self).__init__()

    def construct(self, a, b, ck):
        a[1] = 1
        a[0] = b
        z = a + ck
        return z


class TensorAssignWithTupleInteger(Cell):
    def __init__(self):
        super(TensorAssignWithTupleInteger, self).__init__()

    def construct(self, a, b, ck):
        a[(1)] = 1
        a[(1)] = b
        a[(1, 1)] = b
        a[(1, 1)] = 1
        z = a + ck
        return z
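

# Note that `a[(1)]` is identical to `a[1]`, since `(1)` is a parenthesized int
# rather than a tuple; only `a[(1, 1)]` actually indexes with a tuple.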


class TensorAssignWithBoolTensorIndex(Cell):
    def __init__(self):
        super(TensorAssignWithBoolTensorIndex, self).__init__()
        self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
        self.u_scalar = 5

    def construct(self, a, b, c, u_tensor):
        a[c] = self.u_scalar
        a[b] = u_tensor
        z = a + self.t
        return z


class TensorAssignWithBoolTensorIndexError(Cell):
    def __init__(self):
        super(TensorAssignWithBoolTensorIndexError, self).__init__()

    def construct(self, a, b, c, u_tensor):
        a[b][c] = u_tensor
        return a


class TensorAssignWithBoolTensorIndex2(Cell):
    def __init__(self):
        super(TensorAssignWithBoolTensorIndex2, self).__init__()
        self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
        self.u_scalar = 5

    def construct(self, a, u_tensor):
        a[a > 8] = u_tensor
        a[a >= 6] = self.u_scalar
        a[a < 3] = self.u_scalar
        a[a <= 5] = u_tensor
        a[a == 5] = self.u_scalar
        z = a + self.t
        return z
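

# A NumPy reference sketch of the boolean-mask assignments above, assuming
# MindSpore mirrors NumPy masked setitem: the comparison yields a mask of the
# same shape, and the assigned value broadcasts across the selected elements.
def _np_bool_mask_assign_sketch():
    import numpy as np
    a = np.arange(60, dtype=np.float32).reshape(3, 4, 5)
    a[a > 8] = np.float32(1.0)   # every element greater than 8 becomes 1.0
    a[a == 5] = np.float32(5.0)  # the scalar broadcasts over the masked positions
    return a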


class TensorAssignWithBoolTensorIndex2Error(Cell):
    def __init__(self):
        super(TensorAssignWithBoolTensorIndex2Error, self).__init__()

    def construct(self, a, u_tensor):
        a[a > 8][a > 5] = u_tensor
        return a


def Xtest_tensor_assign_bool_index():
    a = np.arange(60).reshape(3, 4, 5)
    b = a > 5
    c = a < 3
    Ta = Tensor(a, dtype=mstype.float32)
    Tb = Tensor(b)
    Tc = Tensor(c)
    Td = Tensor([True, True])
    u_tensor = Tensor([1], dtype=mstype.float32)
    u_tensor_error = Tensor([1, 2], dtype=mstype.float32)
    u_scalar = 5
    net1 = TensorAssignWithBoolTensorIndex()
    net2 = TensorAssignWithBoolTensorIndex2()
    net1(Ta, Tb, Tc, u_tensor)
    net1(Ta, Tb, Tc, u_tensor)
    with pytest.raises(ValueError):
        net1(Ta, Td, Tc, u_tensor)
    with pytest.raises(IndexError):
        net1(Ta, u_tensor, Tc, u_tensor)
    with pytest.raises(ValueError):
        net1(Ta, Tb, Td, u_tensor)
    with pytest.raises(IndexError):
        net1(Ta, Tb, Ta, u_tensor)
    with pytest.raises(ValueError):
        net1(Ta, Tb, Tc, u_tensor_error)
    # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
    with pytest.raises(ValueError):
        net2(Ta, u_tensor_error)
    net3 = TensorAssignWithBoolTensorIndexError()
    with pytest.raises(AttributeError):
        net3(Ta, Tb, Tc, u_tensor)
    with pytest.raises(AttributeError):
        net3(Ta, Tb, Tc, u_scalar)
    net4 = TensorAssignWithBoolTensorIndex2Error()
    with pytest.raises(AttributeError):
        net4(Ta, u_tensor)
    with pytest.raises(AttributeError):
        net4(Ta, u_scalar)


def Xtest_tensor_slice_reduce_out_of_bounds_neg():
    class NetWork(Cell):
        def __init__(self):
            super(NetWork, self).__init__()
            self.tensor_ret = Tensor(np.array(9, np.int32))

        def construct(self, tensor):
            ret = tensor[-7, 3, 4]
            return ret
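
    # For an axis of size 6 the valid negative indices run from -6 to -1, so the
    # -7 above is expected to be rejected as out of bounds.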

    input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
    net = NetWork()
    with pytest.raises(ValueError) as ex:
        net(input_tensor)
    assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(
        ex.value)


def Xtest_tensor_slice_reduce_out_of_bounds_positive():
    class NetWork(Cell):
        def __init__(self):
            super(NetWork, self).__init__()
            self.tensor_ret = Tensor(np.array(9, np.int32))

        def construct(self, tensor):
            ret = tensor[6, 3, 4]
            return ret
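
    # Valid positive indices for an axis of size 6 are 0 to 5, so the 6 above is
    # expected to be rejected as out of bounds.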

    input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
    net = NetWork()
    with pytest.raises(ValueError) as ex:
        net(input_tensor)
    assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)

@@ -240,156 +240,6 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
  ASSERT_EQ(real, expect);
}

TEST_F(TestComposite, test_TensorSliceBySlice) {
  MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  FuncGraphPtr tensorSlicePtrGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);

  AbstractBasePtrList eles;
  AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(1);
  AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(6);
  AbstractScalarPtr step = std::make_shared<AbstractScalar>(2);

  AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
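  // slice(1, 6, 2) selects indices {1, 3, 5} along the first axis of the
  // {6, 7, 8} tensor, so the expected output shape is {3, 7, 8}.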
  AbstractBasePtrList args_spec_list = {tensor, slice};

  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract array failed.";
  }
  AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({3, 7, 8});
  ASSERT_EQ(*ret, *expect);
}

TEST_F(TestComposite, test_TensorSliceBySliceTuple) {
  MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);

  AbstractBasePtrList eles;
  AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(0);
  AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(6);
  AbstractScalarPtr step = std::make_shared<AbstractScalar>(2);
  AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  eles.push_back(slice);

  start_index = std::make_shared<AbstractScalar>(1);
  stop_index = std::make_shared<AbstractScalar>(5);
  step = std::make_shared<AbstractScalar>(1);
  slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  eles.push_back(slice);

  start_index = std::make_shared<AbstractScalar>(2);
  stop_index = std::make_shared<AbstractScalar>(8);
  step = std::make_shared<AbstractScalar>(3);
  slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  eles.push_back(slice);

  AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
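  // The three slices keep indices {0, 2, 4}, {1, 2, 3, 4} and {2, 5} on their
  // respective axes, which gives the expected {3, 4, 2} shape checked below.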
  AbstractBasePtrList args_spec_list = {tensor, slice_tuple};

  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract array failed.";
  }
  AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({3, 4, 2});
  ASSERT_EQ(*ret, *expect);
}

TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) {
  MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);

  AbstractBasePtrList eles;
  AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(1);
  AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(5);
  AbstractScalarPtr step = std::make_shared<AbstractScalar>(2);
  AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  eles.push_back(slice);

  AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(1);
  eles.push_back(elem_index);

  start_index = std::make_shared<AbstractScalar>(2);
  stop_index = std::make_shared<AbstractScalar>(6);
  step = std::make_shared<AbstractScalar>(1);
  slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  eles.push_back(slice);

  AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
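  // slice(1, 5, 2) keeps indices {1, 3}; the scalar index 1 selects one element
  // and drops its axis; slice(2, 6, 1) keeps {2, 3, 4, 5}: hence shape {2, 4}.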
  AbstractBasePtrList args_spec_list = {tensor, slice_tuple};

  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract array failed.";
  }
  AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({2, 4});
  ASSERT_EQ(*ret, *expect);
}

TEST_F(TestComposite, test_TensorSliceByScalar) {
  MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);

  AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2);
  AbstractBasePtrList args_spec_list = {tensor, start_index};

  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract array failed.";
  }
  AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({7, 8});
  ASSERT_EQ(*ret, *expect);
}

TEST_F(TestComposite, test_TensorSliceByScalarTuple) {
  MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);

  AbstractBasePtrList eles;
  AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(1);
  eles.push_back(elem_index);
  elem_index = std::make_shared<AbstractScalar>(3);
  eles.push_back(elem_index);

  AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  AbstractBasePtrList args_spec_list = {tensor, slice_tuple};

  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract array failed.";
  }
  AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({8});
  ASSERT_EQ(*ret, *expect);
}

TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) {
  MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
  FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);

  AbstractBasePtrList eles;
  AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(3);
  eles.push_back(elem_index);
  elem_index = std::make_shared<AbstractScalar>(0);
  eles.push_back(elem_index);
  elem_index = std::make_shared<AbstractScalar>(6);
  eles.push_back(elem_index);

  AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
  AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  AbstractBasePtrList args_spec_list = {tensor, slice_tuple};

  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract array failed.";
  }
  AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({});
  ASSERT_EQ(*ret, *expect);
}

TEST_F(TestComposite, test_UnpackCall_3args) {
  MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
  FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
@@ -107,5 +107,5 @@ class TestUnsupportParam():

    def test_Sgd_init(self):
        with pytest.raises(TypeError):
-           paramsTensor = Tensor(np.zeros([1, 2, 3]))
+           paramsTensor = Parameter(Tensor(np.zeros([1, 2, 3])), "x")
            SGD(paramsTensor)

@@ -25,7 +25,6 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
    pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception


class NetWorkSlicePositive(Cell):
    def __init__(self):
        super(NetWorkSlicePositive, self).__init__()

@@ -528,6 +527,7 @@ def test_tensor_assign():
    # 2. A[::, 1:, ...] = scalar/tensor
    net = TensorAssignWithTupleEllipsis()
    net(Ta, b)
+   Tc = Tensor(1, mstype.float32)
    with pytest.raises(ValueError):
        net(Ta, Tc)
    with pytest.raises(ValueError):

@@ -168,7 +168,7 @@ def test_select_grad():
    sens = Tensor(np.ones_like(out.asnumpy()).astype(np.float32))
    args = [cond, x, y, sens]
    gout = gfn(*args)
-   expect_cond = np.zeros_like(cond)
+   expect_cond = np.zeros_like(cond.asnumpy())
    expect_x = np.array([[1, 0, 0], [0, 1, 1]])
    expect_y = np.array([[0, 1, 1], [1, 0, 0]])
    assert np.all(gout[0].asnumpy() == expect_cond)