forked from mindspore-Ecosystem/mindspore
tensor assign with slice index
Signed-off-by: candanzg <zhangshucheng@huawei.com>
This commit is contained in:
parent
9edc69affe
commit
663d597330
|
@ -18,7 +18,7 @@ Interfaces for parser module in c++.
|
|||
|
||||
from .parser import (Parser, create_obj_instance, generate_scope,
|
||||
get_bprop_method_of_class, get_class_instance_type,
|
||||
get_class_member_namespace_symbol,
|
||||
get_class_member_namespace_symbol, create_slice_obj,
|
||||
get_dataclass_attributes, get_dataclass_methods,
|
||||
get_module_namespace, get_obj_type, get_object_key,
|
||||
get_parse_method_of_class, get_scope_name,
|
||||
|
@ -29,4 +29,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
|
|||
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type',
|
||||
'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol',
|
||||
'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj',
|
||||
'get_dataclass_methods', 'get_scope_name']
|
||||
'get_dataclass_methods', 'get_scope_name', 'create_slice_obj']
|
||||
|
|
|
@ -29,6 +29,7 @@ from mindspore.common.dtype import pytype_to_dtype
|
|||
from mindspore.common.api import _MindSporeFunction
|
||||
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
|
||||
from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
|
||||
from ..utils import Slice
|
||||
|
||||
# define return value
|
||||
RET_SUCCESS = 0
|
||||
|
@ -69,6 +70,10 @@ parse_expr_statement_white_list = (
|
|||
"append",
|
||||
)
|
||||
|
||||
def create_slice_obj(start, end, step):
|
||||
"""Create Slice object"""
|
||||
return Slice(start, end, step)
|
||||
|
||||
|
||||
def parse_cb(func, parse_method=None):
|
||||
"""Implements the function of parse."""
|
||||
|
|
|
@ -19,6 +19,7 @@ import logging
|
|||
import os
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def cal_sha256(file_path):
|
||||
|
@ -99,3 +100,13 @@ def cell_attr_register(fn=None, attrs=None):
|
|||
if fn is not None:
|
||||
return wrap_cell(fn)
|
||||
return wrap_cell
|
||||
|
||||
|
||||
@dataclass
|
||||
class Slice:
|
||||
"""
|
||||
Slice class
|
||||
"""
|
||||
start: int
|
||||
end: int
|
||||
step: int
|
||||
|
|
|
@ -123,6 +123,9 @@ class ValueSlice : public Value {
|
|||
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
std::string DumpText() const override { return ToString(); }
|
||||
ValuePtr start() const { return start_; }
|
||||
ValuePtr stop() const { return stop_; }
|
||||
ValuePtr step() const { return step_; }
|
||||
|
||||
private:
|
||||
ValuePtr start_;
|
||||
|
|
|
@ -79,6 +79,8 @@ const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement";
|
|||
const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
|
||||
const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
|
||||
|
||||
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
|
||||
|
||||
// define the common name
|
||||
const char NAMED_PRIMITIVE_ITER[] = "iter";
|
||||
const char NAMED_PRIMITIVE_NEXT[] = "next";
|
||||
|
|
|
@ -289,6 +289,13 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
|||
dic["shape"] = shape;
|
||||
dic["dtype"] = abs_base->BuildType();
|
||||
dic["value"] = BuildValue(abs_base->BuildValue());
|
||||
} else if (abs_base->isa<AbstractSlice>()) {
|
||||
auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
|
||||
std::vector<int> shape;
|
||||
dic["shape"] = shape;
|
||||
dic["dtype"] = arg_slice->BuildType();
|
||||
dic["value"] = BuildValue(arg_slice->BuildValue());
|
||||
|
||||
} else if (abs_base->isa<AbstractTuple>()) {
|
||||
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
|
||||
size_t len = arg_tuple->size();
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
|
||||
#include "ir/meta_tensor.h"
|
||||
#include "pipeline/parse/parse.h"
|
||||
#include "pipeline/parse/parse_base.h"
|
||||
#include "ir/value.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -97,6 +98,13 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
|
|||
i++;
|
||||
}
|
||||
ret = rets;
|
||||
} else if (value->isa<ValueSlice>()) {
|
||||
auto slice = value->cast<ValueSlicePtr>();
|
||||
auto start = ValuePtrToPyData(slice->start());
|
||||
auto end = ValuePtrToPyData(slice->stop());
|
||||
auto step = ValuePtrToPyData(slice->step());
|
||||
ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_SLICE, start, end,
|
||||
step);
|
||||
} else if (value->isa<Type>()) {
|
||||
py::tuple v(1);
|
||||
v[0] = value->cast<TypePtr>();
|
||||
|
|
|
@ -15,7 +15,43 @@
|
|||
|
||||
"""constexpr util"""
|
||||
|
||||
import numpy as np
|
||||
from ...primitive import constexpr
|
||||
from ....common.tensor import Tensor
|
||||
from ....common import dtype as mstype
|
||||
from ...._extends.utils import Slice
|
||||
|
||||
@constexpr
|
||||
def check_equal(param1, param2, msg="{},{}"):
|
||||
if param1 != param2:
|
||||
raise ValueError(msg.format(param1, param2))
|
||||
return param1
|
||||
|
||||
@constexpr
|
||||
def check_tensor_setitem_index(index, element_type=None):
|
||||
"""Check tuple index type of tensor assignment."""
|
||||
if index is None:
|
||||
raise ValueError("Tensor's index cannot be None.")
|
||||
# eg. Tensor[Slice] = u
|
||||
if isinstance(index, Slice):
|
||||
return True
|
||||
# eg. Tensor[Tuple] = u
|
||||
if isinstance(index, tuple):
|
||||
if not index:
|
||||
raise ValueError("Tensor's index cannot be empty.")
|
||||
# eg. Tensor[Tuple(Slice...)] = u
|
||||
if not isinstance(index[0], Slice):
|
||||
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
|
||||
return True
|
||||
# eg. Tensor[Tensor[dtype=bool]] = u
|
||||
if index == mstype.tensor:
|
||||
if element_type is None or element_type != mstype.bool_:
|
||||
raise ValueError(
|
||||
"The index of tensor should be a bool type tensor. \
|
||||
{} type is not supported yet.".format(element_type))
|
||||
return True
|
||||
|
||||
raise ValueError("Index of type '{}' is not supported yet.".format(type(index)))
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -43,3 +79,84 @@ def error_msg(msg="", format_values=""):
|
|||
"""
|
||||
|
||||
raise ValueError(msg.format(*format_values))
|
||||
|
||||
def slice_expand(input_slices, shape):
|
||||
"""
|
||||
Convert slice to indices.
|
||||
|
||||
Inputs:
|
||||
slices (List or Tuple(List, ...)): Slice tuple or slice.
|
||||
shape (Tuple): The shape of a sensor is an integer element tuple.
|
||||
|
||||
Outputs:
|
||||
(List, List, List), This is expressed as (begins, ends, strides).
|
||||
"""
|
||||
begin = []
|
||||
end = []
|
||||
strides = []
|
||||
index = 0
|
||||
slices = None
|
||||
# Slice or Tuple(Slice...)
|
||||
if isinstance(input_slices, Slice):
|
||||
slices = (input_slices,)
|
||||
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
|
||||
slices = input_slices
|
||||
else:
|
||||
raise ValueError("Tensor's index type is not supported yet.")
|
||||
|
||||
for s in slices:
|
||||
start = 0 if (s.start is None) else s.start
|
||||
stop = shape[index] if (s.end is None) else s.end
|
||||
step = 1 if (s.step is None) else s.step
|
||||
begin.append(start)
|
||||
end.append(stop)
|
||||
strides.append(step)
|
||||
index += 1
|
||||
while index < len(shape):
|
||||
begin.append(0)
|
||||
end.append(shape[index])
|
||||
strides.append(1)
|
||||
index += 1
|
||||
return begin, end, strides
|
||||
|
||||
@constexpr
|
||||
def slice2indices(input_slices, shape):
|
||||
"""
|
||||
Convert slice to indices.
|
||||
|
||||
Inputs:
|
||||
slices (List or Tuple(List, ...)): Slice tuple or slice.
|
||||
shape (Tuple): The shape of a sensor is an integer element tuple.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is (n, 1).
|
||||
"""
|
||||
begin, end, strides = slice_expand(input_slices, shape)
|
||||
np_r = []
|
||||
for i, element in enumerate(shape):
|
||||
s = begin[i] if (begin[i] >= 0) else (element + begin[i])
|
||||
e = end[i] if (end[i] >= 0) else (element + end[i])
|
||||
np_r.append(np.r_[s:e:strides[i]])
|
||||
# Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
|
||||
np_ix = np.ix_(*np_r)
|
||||
ravel = np.ravel_multi_index(np_ix, shape)
|
||||
ravel = Tensor(ravel.reshape(-1, 1), dtype=mstype.int32)
|
||||
return ravel
|
||||
|
||||
@constexpr
|
||||
def check_indices(indices_size, index):
|
||||
if indices_size < 1:
|
||||
raise ValueError("The tensor's index is unreasonable. index:{}".format(index))
|
||||
return indices_size
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_indices_value_size(indices_size, value_size):
|
||||
if value_size < 1:
|
||||
raise ValueError("The value assigned to tensor cannot be empty.")
|
||||
if value_size > 1:
|
||||
if value_size != indices_size:
|
||||
raise ValueError(
|
||||
"The value given to tensor does not match the index size. \
|
||||
value size:{}, indics size:{}".format(value_size, indices_size))
|
||||
return value_size
|
||||
|
|
|
@ -138,25 +138,23 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
result = None
|
||||
index_dtype = F.dtype(index)
|
||||
index_shape = F.shape(index)
|
||||
is_bool = mult_util.is_same_type(index_dtype, mstype.bool_)
|
||||
if not is_bool:
|
||||
return mult_util.error_msg(
|
||||
"The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,))
|
||||
data_shape = F.shape(data)
|
||||
if index_shape != data_shape:
|
||||
return mult_util.error_msg(
|
||||
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (data_shape, index_shape))
|
||||
size = F.size(value_tensor)
|
||||
if size != 1:
|
||||
return mult_util.error_msg(
|
||||
"When assign value is a tensor, its size should be 1, but current size is {}.", (size,))
|
||||
dtype = F.dtype(data)
|
||||
u_cast = F.cast(value_tensor, dtype)
|
||||
one_data = F.ones_like(data)
|
||||
u = F.tensor_mul(one_data, u_cast)
|
||||
return F.select(index, u, data)
|
||||
check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
|
||||
if check_result:
|
||||
data_shape = F.shape(data)
|
||||
data_shape = mult_util.check_equal(data_shape, index_shape,
|
||||
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
size = F.size(value_tensor)
|
||||
size = mult_util.check_equal(1, size,
|
||||
"When assign value is a tensor, its size should be {}, but current size is {}.")
|
||||
dtype = F.dtype(data)
|
||||
u_cast = F.cast(value_tensor, dtype)
|
||||
one_data = F.ones_like(data)
|
||||
u = F.tensor_mul(one_data, u_cast)
|
||||
result = F.select(index, u, data)
|
||||
return result
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tensor", "Number")
|
||||
|
@ -179,16 +177,162 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
|
|||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
result = None
|
||||
index_dtype = F.dtype(index)
|
||||
index_shape = F.shape(index)
|
||||
is_bool = mult_util.is_same_type(index_dtype, mstype.bool_)
|
||||
if not is_bool:
|
||||
return mult_util.error_msg(
|
||||
"The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,))
|
||||
shape = F.shape(data)
|
||||
if index_shape != shape:
|
||||
return mult_util.error_msg(
|
||||
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (shape, index_shape))
|
||||
dtype = F.dtype(data)
|
||||
u = F.fill(dtype, shape, value)
|
||||
return F.select(index, u, data)
|
||||
check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
|
||||
if check_result:
|
||||
shape = F.shape(data)
|
||||
shape = mult_util.check_equal(
|
||||
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
|
||||
dtype = F.dtype(data)
|
||||
u = F.fill(dtype, shape, value)
|
||||
result = F.select(index, u, data)
|
||||
return result
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Slice", "Tensor")
|
||||
def _tensor_setitem_with_slice_v3(data, input_slice, value):
|
||||
"""
|
||||
Tensor assignment.
|
||||
|
||||
Note:
|
||||
Syntax support: A[Slice] = U
|
||||
Restraint condition: A is a Tensor
|
||||
Slice like "1:3"
|
||||
U is a Tensor(size=1) or Tensor(size>1)
|
||||
|
||||
Inputs:
|
||||
data (Tensor): Assigned tensor.
|
||||
input_slice (Slice): Slice expression.
|
||||
value (Number): Assignment value.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
return _tensor_assgin_tensor(data, input_slice, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tuple", "Tensor")
|
||||
def _tensor_setitem_with_slice_v4(data, input_slice, value):
|
||||
"""
|
||||
Tensor assignment.
|
||||
|
||||
Note:
|
||||
Syntax support: A[Slice] = U
|
||||
Restraint condition: A is a Tensor
|
||||
Slice like "1:3, ::, :4:-1"
|
||||
U is a Tensor(size=1) or Tensor(size>1)
|
||||
|
||||
Inputs:
|
||||
data (Tensor): Assigned tensor.
|
||||
input_slice (Tuple(Slice)): Slice expression.
|
||||
value (Number): Assignment value.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
return _tensor_assgin_tensor(data, input_slice, value)
|
||||
|
||||
|
||||
def _tensor_assgin_tensor(data, input_slice, value):
|
||||
"""Given a tensor value assign to tensor by slice"""
|
||||
# 1. condition
|
||||
result = None
|
||||
check_result = mult_util.check_tensor_setitem_index(input_slice)
|
||||
if check_result:
|
||||
data_shape = F.shape(data)
|
||||
data_size = F.size(data)
|
||||
data_dtype = F.dtype(data)
|
||||
indices = mult_util.slice2indices(input_slice, data_shape)
|
||||
indices_size = F.size(indices)
|
||||
indices_size = mult_util.check_indices(indices_size, input_slice)
|
||||
update = F.fill(data_dtype, (indices_size,), 1)
|
||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||
condition_1d = F.cast(condition_1d, mstype.bool_)
|
||||
condition = F.reshape(condition_1d, data_shape)
|
||||
# 2. u
|
||||
value_fill = None
|
||||
value_size = F.size(value)
|
||||
|
||||
value_size = mult_util.check_indices_value_size(indices_size, value_size)
|
||||
if value_size == 1:
|
||||
value_fill = F.fill(data_dtype, (indices_size,), 1)
|
||||
value = F.cast(value, data_dtype)
|
||||
value_fill = F.tensor_mul(value_fill, value)
|
||||
elif value_size > 1:
|
||||
value_fill = F.reshape(value, (indices_size,))
|
||||
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
||||
u = F.reshape(value_1d, data_shape)
|
||||
# A[slice]= u -> A[B]=U -> select(B, U, A)
|
||||
result = F.select(condition, u, data)
|
||||
return result
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Slice", "Number")
|
||||
def _tensor_setitem_with_slice_v1(data, input_slice, value):
|
||||
"""
|
||||
Tensor assignment.
|
||||
|
||||
Note:
|
||||
Syntax support: A[Slice] = u
|
||||
Restraint condition: A is a Tensor.
|
||||
Slice like "1:3"
|
||||
u is a scalar
|
||||
|
||||
Inputs:
|
||||
data (Tensor): Assigned tensor.
|
||||
input_slice (Slice): slice expression.
|
||||
value (Number): Assignment value.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
return _tensor_assgin_number(data, input_slice, value)
|
||||
|
||||
|
||||
@setitem.register("Tensor", "Tuple", "Number")
|
||||
def _tensor_setitem_with_slice_v2(data, input_slice, value):
|
||||
"""
|
||||
Tensor assignment.
|
||||
|
||||
Note:
|
||||
Syntax support: A[Slice] = u
|
||||
Restraint condition: A is a Tensor.
|
||||
Slice like "1:3, ::, :4:-1"
|
||||
u is a scalar
|
||||
|
||||
Inputs:
|
||||
data (Tensor): Assigned tensor.
|
||||
input_slice (Tuple(Slice)): slice expression.
|
||||
value (Number): Assignment value.
|
||||
|
||||
Outputs:
|
||||
Tensor, element type and shape is same as data.
|
||||
"""
|
||||
return _tensor_assgin_number(data, input_slice, value)
|
||||
|
||||
|
||||
def _tensor_assgin_number(data, input_slice, value):
|
||||
"""Given a scalar assign to tensor by slice"""
|
||||
# 1. condition
|
||||
check_result = mult_util.check_tensor_setitem_index(input_slice)
|
||||
result = None
|
||||
if check_result:
|
||||
data_shape = F.shape(data)
|
||||
data_size = F.size(data)
|
||||
data_dtype = F.dtype(data)
|
||||
indices = mult_util.slice2indices(input_slice, data_shape)
|
||||
indices_size = F.size(indices)
|
||||
indices_size = mult_util.check_indices(indices_size, input_slice)
|
||||
update = F.fill(data_dtype, (indices_size,), 1)
|
||||
condition_1d = F.scatter_nd(indices, update, (data_size,))
|
||||
condition_1d = F.cast(condition_1d, mstype.bool_)
|
||||
condition = F.reshape(condition_1d, data_shape)
|
||||
# 2. u
|
||||
value_fill = F.fill(data_dtype, (indices_size,), value)
|
||||
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
|
||||
u = F.reshape(value_1d, data_shape)
|
||||
# A[slice]= u -> A[B]=U -> select(B, U, A)
|
||||
result = F.select(condition, u, data)
|
||||
return result
|
||||
|
|
|
@ -68,6 +68,7 @@ tuple_to_array = P.TupleToArray()
|
|||
scalar_cast = P.ScalarCast()
|
||||
print_ = P.Print()
|
||||
expand_dims = P.ExpandDims()
|
||||
scatter_nd = P.ScatterNd()
|
||||
|
||||
tuple_setitem = Primitive('tuple_setitem')
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
|
|
|
@ -94,10 +94,101 @@ class NetWorkReduceToScalar(Cell):
|
|||
return ret
|
||||
|
||||
|
||||
class TensorAssignWithSliceError1(Cell):
|
||||
def __init__(self):
|
||||
super(TensorAssignWithSliceError1, self).__init__()
|
||||
|
||||
def construct(self, a, b):
|
||||
a[1:3:-1,::] = b
|
||||
return a
|
||||
|
||||
class TensorAssignWithSliceError2(Cell):
|
||||
def __init__(self):
|
||||
super(TensorAssignWithSliceError2, self).__init__()
|
||||
|
||||
def construct(self, a, b):
|
||||
a[1:3:-1] = b
|
||||
return a
|
||||
class TensorAssignWithSlice2(Cell):
|
||||
def __init__(self):
|
||||
super(TensorAssignWithSlice2, self).__init__()
|
||||
|
||||
def construct(self, a, b):
|
||||
a[1:5] = b
|
||||
a[3:4] = 5
|
||||
a[-1:1:-1] = b
|
||||
a[-1:3:-1] = 5
|
||||
a[::] = b
|
||||
a[::] = 9
|
||||
return a
|
||||
class TensorAssignWithSlice(Cell):
|
||||
def __init__(self):
|
||||
super(TensorAssignWithSlice, self).__init__()
|
||||
self.c = 2
|
||||
|
||||
def construct(self, a, b):
|
||||
a[1:3,::] = b
|
||||
a[2:3:,3:] = b
|
||||
a[::] = b
|
||||
a[::] = self.c
|
||||
a[::,::] = b
|
||||
a[::,::] = self.c
|
||||
a[2:3:,0:, 4:1:-1] = b
|
||||
a[2:3:,0:, 4:1:-1] = self.c
|
||||
z = a
|
||||
return z
|
||||
|
||||
def test_tensor_assign_with_slice():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
net = TensorAssignWithSlice()
|
||||
net2= TensorAssignWithSlice2()
|
||||
net_e1 = TensorAssignWithSliceError1()
|
||||
net_e2 = TensorAssignWithSliceError2()
|
||||
a = np.arange(60).reshape(3,4,5)
|
||||
b = Tensor([1])
|
||||
Ta = Tensor(a)
|
||||
Tb= Tensor([1,3])
|
||||
Tc= Tensor([])
|
||||
t = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
net(Ta, b)
|
||||
net2(t, b)
|
||||
# Error for A[Slice] = Number
|
||||
# 1. A[Slice] = Number, Slice error
|
||||
with pytest.raises(ValueError):
|
||||
net_e2(t, 2)
|
||||
|
||||
# Error for A[Slice] = U, U is a Tensor
|
||||
# 1. A[Slice] = U, u.size is error
|
||||
with pytest.raises(ValueError):
|
||||
net2(t, Tb)
|
||||
# 2. A[Slice] = U, U is empty
|
||||
with pytest.raises(ValueError):
|
||||
net2(t, Tc)
|
||||
# 3. A[Slice] = U, U.size error
|
||||
with pytest.raises(ValueError):
|
||||
net2(t, Tb)
|
||||
|
||||
# Error for A[Tuple(Slice...)] = Tensor
|
||||
# 1. A[Tuple(Slice...)] = U, U is empty
|
||||
with pytest.raises(ValueError):
|
||||
net(Ta, Tc)
|
||||
# 2. A[Tuple(Slice...)] = U, U.size error
|
||||
with pytest.raises(ValueError):
|
||||
net(Ta, Tb)
|
||||
# 3. A[Tuple(Slice...)] = U, Slice error
|
||||
with pytest.raises(ValueError):
|
||||
net_e1(Ta, b)
|
||||
|
||||
# Error for A[Tuple(Slice...)] = Number
|
||||
# 1. A[Tuple(Slice...)] = Number, Slice error
|
||||
with pytest.raises(ValueError):
|
||||
net_e1(Ta, 2)
|
||||
|
||||
|
||||
class TensorAssignWithBoolTensorIndex(Cell):
|
||||
def __init__(self):
|
||||
super(TensorAssignWithBoolTensorIndex, self).__init__()
|
||||
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
|
||||
self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float64)
|
||||
|
||||
def construct(self, a, b, c, u_tensor, _scalar):
|
||||
a[c] = u_scalar
|
||||
|
@ -119,6 +210,7 @@ class TensorAssignWithBoolTensorIndex2(Cell):
|
|||
def __init__(self):
|
||||
super(TensorAssignWithBoolTensorIndex2, self).__init__()
|
||||
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
|
||||
self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float64)
|
||||
|
||||
def construct(self, a, u_tensor, _scalar):
|
||||
a[a > 8] = u_tensor
|
||||
|
@ -139,7 +231,7 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
|
|||
return a
|
||||
|
||||
|
||||
a = np.random.uniform(1, 10, [2, 3])
|
||||
a = np.random.uniform(1,10,[3,4,5])
|
||||
b = a > 5
|
||||
c = a < 3
|
||||
Ta = Tensor(a)
|
||||
|
@ -148,13 +240,13 @@ Tc = Tensor(c)
|
|||
Td = Tensor([True, True])
|
||||
u_tensor = Tensor([1])
|
||||
u_tensor_error = Tensor([1, 2])
|
||||
t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
|
||||
u_scalar = 5
|
||||
|
||||
|
||||
def test_tensor_assign_bool_index():
|
||||
net1 = TensorAssignWithBoolTensorIndex()
|
||||
net2 = TensorAssignWithBoolTensorIndex2()
|
||||
|
||||
net1(Ta, Tb, Tc, u_tensor, u_scalar)
|
||||
net1(Ta, Tb, Tc, u_tensor, u_scalar)
|
||||
with pytest.raises(ValueError):
|
||||
net1(Ta, Td, Tc, u_tensor, u_scalar)
|
||||
|
@ -180,8 +272,15 @@ def test_tensor_assign_bool_index():
|
|||
with pytest.raises(AttributeError):
|
||||
net4(Ta, u_scalar)
|
||||
|
||||
|
||||
test_cases = [
|
||||
('TensorAssignWithSlice', {
|
||||
'block': TensorAssignWithSlice(),
|
||||
'desc_inputs': [Ta, u_tensor],
|
||||
}),
|
||||
('TensorAssignWithSlice2', {
|
||||
'block': TensorAssignWithSlice2(),
|
||||
'desc_inputs': [t_1d, u_tensor],
|
||||
}),
|
||||
('TensorAssignWithBoolTensorIndex', {
|
||||
'block': TensorAssignWithBoolTensorIndex(),
|
||||
'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar],
|
||||
|
|
Loading…
Reference in New Issue