forked from mindspore-Ecosystem/mindspore
Tensor assign syntax:
1) A[B]=U 2) A[A>n]=U A.shape == B.shape U is a scalar or Tensor(size==1) B is Tensor(dtype=bool) n is a Number Signed-off-by: candanzg <zhangshucheng@huawei.com>
This commit is contained in:
parent
2d31ae97e8
commit
3f087dba1a
|
@ -83,6 +83,7 @@ convert_object_map = {
|
||||||
T.mul: multitype_ops.mul,
|
T.mul: multitype_ops.mul,
|
||||||
T.truediv: multitype_ops.div,
|
T.truediv: multitype_ops.div,
|
||||||
T.getitem: multitype_ops.getitem,
|
T.getitem: multitype_ops.getitem,
|
||||||
|
T.setitem: multitype_ops.setitem,
|
||||||
T.floordiv: multitype_ops.floordiv,
|
T.floordiv: multitype_ops.floordiv,
|
||||||
T.mod: multitype_ops.mod,
|
T.mod: multitype_ops.mod,
|
||||||
T.pow: multitype_ops.pow_,
|
T.pow: multitype_ops.pow_,
|
||||||
|
@ -118,7 +119,6 @@ convert_object_map = {
|
||||||
T.iter: M.ms_iter,
|
T.iter: M.ms_iter,
|
||||||
T.next: M.ms_next,
|
T.next: M.ms_next,
|
||||||
T.hasnext: M.hasnext,
|
T.hasnext: M.hasnext,
|
||||||
T.setitem: M.setitem,
|
|
||||||
|
|
||||||
T.make_tuple: F.make_tuple,
|
T.make_tuple: F.make_tuple,
|
||||||
T.make_dict: F.make_dict,
|
T.make_dict: F.make_dict,
|
||||||
|
|
|
@ -23,6 +23,7 @@ from .pow_impl import pow_
|
||||||
from .floordiv_impl import floordiv
|
from .floordiv_impl import floordiv
|
||||||
from .mod_impl import mod
|
from .mod_impl import mod
|
||||||
from .getitem_impl import getitem
|
from .getitem_impl import getitem
|
||||||
|
from .setitem_impl import setitem
|
||||||
from .zeros_like_impl import zeros_like
|
from .zeros_like_impl import zeros_like
|
||||||
from .ones_like_impl import ones_like
|
from .ones_like_impl import ones_like
|
||||||
from .equal_impl import equal
|
from .equal_impl import equal
|
||||||
|
@ -55,6 +56,7 @@ __all__ = [
|
||||||
'greater_equal',
|
'greater_equal',
|
||||||
'negative',
|
'negative',
|
||||||
'getitem',
|
'getitem',
|
||||||
|
'setitem',
|
||||||
'logical_and',
|
'logical_and',
|
||||||
'logical_or',
|
'logical_or',
|
||||||
'logical_not'
|
'logical_not'
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""constexpr util"""
|
||||||
|
|
||||||
|
from ...primitive import constexpr
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def is_same_type(inst, type_):
|
||||||
|
"""
|
||||||
|
Check whether an object is an instance of a target type.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
inst (mindspore.dtype): Inspected type.
|
||||||
|
type_ (mindspore.dtype): Target type.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
bool, the check result.
|
||||||
|
"""
|
||||||
|
return inst == type_
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def error_msg(msg="", format_values=""):
|
||||||
|
"""
|
||||||
|
Used to throw exception information.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
msg (str): information content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
raise ValueError(msg.format(*format_values))
|
|
@ -0,0 +1,194 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Implementation for setitem."""
|
||||||
|
|
||||||
|
from ...composite import base
|
||||||
|
from ....common import dtype as mstype
|
||||||
|
from ... import functional as F
|
||||||
|
from . import _multitype_ops_util as mult_util
|
||||||
|
|
||||||
|
setitem = base.MultitypeFuncGraph('setitem')
|
||||||
|
|
||||||
|
@setitem.register("List", "Number", "String")
|
||||||
|
def _list_setitem_with_string(data, number_index, value):
|
||||||
|
"""
|
||||||
|
Assign value to list.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (list): Data of type lis.
|
||||||
|
number_index (Number): Index of data.
|
||||||
|
value (String): Value given.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
List, type is same as the element type of data.
|
||||||
|
"""
|
||||||
|
return F.list_setitem(data, number_index, value)
|
||||||
|
|
||||||
|
|
||||||
|
@setitem.register("List", "Number", "Number")
|
||||||
|
def _list_setitem_with_number(data, number_index, value):
|
||||||
|
"""
|
||||||
|
Assign value to list.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (list): Data of type lis.
|
||||||
|
number_index (Number): Index of data.
|
||||||
|
value (Number): Value given.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
List, type is same as the element type of data.
|
||||||
|
"""
|
||||||
|
return F.list_setitem(data, number_index, value)
|
||||||
|
|
||||||
|
|
||||||
|
@setitem.register("List", "Number", "Tensor")
|
||||||
|
def _list_setitem_with_Tensor(data, number_index, value):
|
||||||
|
"""
|
||||||
|
Assign value to list.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (list): Data of type lis.
|
||||||
|
number_index (Number): Index of data.
|
||||||
|
value (Tensor): Value given.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
List, type is same as the element type of data.
|
||||||
|
"""
|
||||||
|
return F.list_setitem(data, number_index, value)
|
||||||
|
|
||||||
|
|
||||||
|
@setitem.register("List", "Number", "List")
|
||||||
|
def _list_setitem_with_List(data, number_index, value):
|
||||||
|
"""
|
||||||
|
Assign value to list.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (list): Data of type lis.
|
||||||
|
number_index (Number): Index of data.
|
||||||
|
value (List): Value given.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
List, type is same as the element type of data.
|
||||||
|
"""
|
||||||
|
return F.list_setitem(data, number_index, value)
|
||||||
|
|
||||||
|
|
||||||
|
@setitem.register("Dictionary", "String", "Tensor")
|
||||||
|
def _dict_setitem_with_tensor(data, key, value):
|
||||||
|
"""
|
||||||
|
Assign value to dictionary.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (Dictionary): Data of type dict.
|
||||||
|
key (str): Key of the data.
|
||||||
|
value (Tensor): Value given.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Dict, type is as same as the element type of data.
|
||||||
|
"""
|
||||||
|
return F.dict_setitem(data, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
@setitem.register("Dictionary", "String", "Number")
|
||||||
|
def _dict_setitem_with_number(data, key, value):
|
||||||
|
"""
|
||||||
|
Assign value to dictionary.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (Dictionary): Data of type dict.
|
||||||
|
key (str): Key of the data.
|
||||||
|
value (Number): Value given.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Dict, type is as same as the element type of data.
|
||||||
|
"""
|
||||||
|
return F.dict_setitem(data, key, value)
|
||||||
|
|
||||||
|
|
||||||
|
@setitem.register("Tensor", "Tensor", "Tensor")
|
||||||
|
def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
|
||||||
|
"""
|
||||||
|
Tensor assignment.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Syntax support: A[B] = U and A[A>n] = U.
|
||||||
|
Restraint condition: 1) A, U is a Tensor, and B is a bool Tensor.
|
||||||
|
2) A.shape == B.shape
|
||||||
|
3) U.size == 1
|
||||||
|
4) n is a number
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (Tensor): Assigned tensor.
|
||||||
|
index (Tensor): Tensor of bool type.
|
||||||
|
value_tensor (Tensor): Tensor with size 1.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, element type and shape is same as data.
|
||||||
|
"""
|
||||||
|
index_dtype = F.dtype(index)
|
||||||
|
index_shape = F.shape(index)
|
||||||
|
is_bool = mult_util.is_same_type(index_dtype, mstype.bool_)
|
||||||
|
if not is_bool:
|
||||||
|
return mult_util.error_msg(
|
||||||
|
"The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,))
|
||||||
|
data_shape = F.shape(data)
|
||||||
|
if index_shape != data_shape:
|
||||||
|
return mult_util.error_msg(
|
||||||
|
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (data_shape, index_shape))
|
||||||
|
size = F.size(value_tensor)
|
||||||
|
if size != 1:
|
||||||
|
return mult_util.error_msg(
|
||||||
|
"When assign value is a tensor, its size should be 1, but current size is {}.", (size,))
|
||||||
|
dtype = F.dtype(data)
|
||||||
|
u_cast = F.cast(value_tensor, dtype)
|
||||||
|
one_data = F.ones_like(data)
|
||||||
|
u = F.tensor_mul(one_data, u_cast)
|
||||||
|
return F.select(index, u, data)
|
||||||
|
|
||||||
|
|
||||||
|
@setitem.register("Tensor", "Tensor", "Number")
|
||||||
|
def _tensor_setitem_by_tensor_v2(data, index, value):
|
||||||
|
"""
|
||||||
|
Tensor assignment.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Syntax support: A[B] = u and A[A>n] = u.
|
||||||
|
Restraint condition: 1) A is a Tensor, and B is a bool Tensor.
|
||||||
|
2) A.shape == B.shape
|
||||||
|
3) u is a scalar
|
||||||
|
4) n is a number
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (Tensor): Assigned tensor.
|
||||||
|
index (Tensor): Tensor of bool type.
|
||||||
|
value_tensor (Number): Assignment value.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, element type and shape is same as data.
|
||||||
|
"""
|
||||||
|
index_dtype = F.dtype(index)
|
||||||
|
index_shape = F.shape(index)
|
||||||
|
is_bool = mult_util.is_same_type(index_dtype, mstype.bool_)
|
||||||
|
if not is_bool:
|
||||||
|
return mult_util.error_msg(
|
||||||
|
"The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,))
|
||||||
|
shape = F.shape(data)
|
||||||
|
if index_shape != shape:
|
||||||
|
return mult_util.error_msg(
|
||||||
|
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (shape, index_shape))
|
||||||
|
dtype = F.dtype(data)
|
||||||
|
u = F.fill(dtype, shape, value)
|
||||||
|
return F.select(index, u, data)
|
|
@ -31,6 +31,9 @@ dtype = P.DType()
|
||||||
issubclass_ = P.IsSubClass()
|
issubclass_ = P.IsSubClass()
|
||||||
isinstance_ = P.IsInstance()
|
isinstance_ = P.IsInstance()
|
||||||
fill = P.Fill()
|
fill = P.Fill()
|
||||||
|
select = P.Select()
|
||||||
|
size = P.Size()
|
||||||
|
ones_like = P.OnesLike()
|
||||||
shape = P.Shape()
|
shape = P.Shape()
|
||||||
rank = P.Rank()
|
rank = P.Rank()
|
||||||
reshape = P.Reshape()
|
reshape = P.Reshape()
|
||||||
|
@ -68,7 +71,9 @@ scalar_cast = P.ScalarCast()
|
||||||
tuple_setitem = Primitive('tuple_setitem')
|
tuple_setitem = Primitive('tuple_setitem')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive('tuple_getitem')
|
||||||
list_getitem = Primitive('list_getitem')
|
list_getitem = Primitive('list_getitem')
|
||||||
|
list_setitem = Primitive('list_setitem')
|
||||||
dict_getitem = Primitive('dict_getitem')
|
dict_getitem = Primitive('dict_getitem')
|
||||||
|
dict_setitem = Primitive('dict_setitem')
|
||||||
tuple_div = Primitive("tuple_div")
|
tuple_div = Primitive("tuple_div")
|
||||||
tuple_len = Primitive("tuple_len")
|
tuple_len = Primitive("tuple_len")
|
||||||
tuple_reversed = Primitive("tuple_reversed")
|
tuple_reversed = Primitive("tuple_reversed")
|
||||||
|
|
|
@ -18,6 +18,7 @@ import pytest
|
||||||
|
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
from mindspore import dtype as mstype
|
||||||
from mindspore.nn import Cell
|
from mindspore.nn import Cell
|
||||||
|
|
||||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||||
|
@ -79,7 +80,102 @@ class NetWorkReduceToScalar(Cell):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TensorAssignWithBoolTensorIndex(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorAssignWithBoolTensorIndex, self).__init__()
|
||||||
|
self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64)
|
||||||
|
|
||||||
|
def construct(self, a, b, c, u_tensor, _scalar):
|
||||||
|
a[c] = u_scalar
|
||||||
|
a[b] = u_tensor
|
||||||
|
z = a + self.t
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
class TensorAssignWithBoolTensorIndexError(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorAssignWithBoolTensorIndexError, self).__init__()
|
||||||
|
|
||||||
|
def construct(self, a, b, c, u_tensor):
|
||||||
|
a[b][c] = u_tensor
|
||||||
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
class TensorAssignWithBoolTensorIndex2(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorAssignWithBoolTensorIndex2, self).__init__()
|
||||||
|
self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64)
|
||||||
|
|
||||||
|
def construct(self, a, u_tensor, _scalar):
|
||||||
|
a[a>8] = u_tensor
|
||||||
|
a[a>=6] = u_scalar
|
||||||
|
a[a<3] = u_scalar
|
||||||
|
a[a<=5] = u_tensor
|
||||||
|
a[a==5] = u_scalar
|
||||||
|
z = a + self.t
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
class TensorAssignWithBoolTensorIndex2Error(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(TensorAssignWithBoolTensorIndex2Error, self).__init__()
|
||||||
|
|
||||||
|
def construct(self, a, u_tensor):
|
||||||
|
a[a>8][a>5] = u_tensor
|
||||||
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
a = np.random.uniform(1,10,[2,3])
|
||||||
|
b = a > 5
|
||||||
|
c = a < 3
|
||||||
|
Ta = Tensor(a)
|
||||||
|
Tb = Tensor(b)
|
||||||
|
Tc = Tensor(c)
|
||||||
|
Td = Tensor([True, True])
|
||||||
|
u_tensor = Tensor([1])
|
||||||
|
u_tensor_error = Tensor([1, 2])
|
||||||
|
u_scalar = 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_assign_bool_index():
|
||||||
|
net1 = TensorAssignWithBoolTensorIndex()
|
||||||
|
net2 = TensorAssignWithBoolTensorIndex2()
|
||||||
|
|
||||||
|
net1(Ta, Tb, Tc, u_tensor, u_scalar)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
net1(Ta, Td, Tc, u_tensor, u_scalar)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
net1(Ta, u_tensor, Tc, u_tensor, u_scalar)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
net1(Ta, Tb, Td, u_tensor, u_scalar)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
net1(Ta, Tb, Ta, u_tensor, u_scalar)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
net1(Ta, Tb, Tc, u_tensor_error, u_scalar)
|
||||||
|
#net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
net2(Ta, u_tensor_error, u_scalar)
|
||||||
|
net3 = TensorAssignWithBoolTensorIndexError()
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
net3(Ta, Tb, Tc, u_tensor)
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
net3(Ta, Tb, Tc, u_scalar)
|
||||||
|
net4 = TensorAssignWithBoolTensorIndex2Error()
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
net4(Ta, u_tensor)
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
net4(Ta, u_scalar)
|
||||||
|
|
||||||
|
|
||||||
test_cases = [
|
test_cases = [
|
||||||
|
('TensorAssignWithBoolTensorIndex', {
|
||||||
|
'block': TensorAssignWithBoolTensorIndex(),
|
||||||
|
'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar],
|
||||||
|
}),
|
||||||
|
('TensorAssignWithBoolTensorIndex2', {
|
||||||
|
'block': TensorAssignWithBoolTensorIndex2(),
|
||||||
|
'desc_inputs': [Ta, u_tensor, u_scalar],
|
||||||
|
}),
|
||||||
('SlicePositive', {
|
('SlicePositive', {
|
||||||
'block': NetWorkSlicePositive(),
|
'block': NetWorkSlicePositive(),
|
||||||
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
|
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
|
||||||
|
|
Loading…
Reference in New Issue