diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py index 9fb357597ed..c2c27166973 100644 --- a/mindspore/_extends/parse/resources.py +++ b/mindspore/_extends/parse/resources.py @@ -83,6 +83,7 @@ convert_object_map = { T.mul: multitype_ops.mul, T.truediv: multitype_ops.div, T.getitem: multitype_ops.getitem, + T.setitem: multitype_ops.setitem, T.floordiv: multitype_ops.floordiv, T.mod: multitype_ops.mod, T.pow: multitype_ops.pow_, @@ -118,7 +119,6 @@ convert_object_map = { T.iter: M.ms_iter, T.next: M.ms_next, T.hasnext: M.hasnext, - T.setitem: M.setitem, T.make_tuple: F.make_tuple, T.make_dict: F.make_dict, diff --git a/mindspore/ops/composite/multitype_ops/__init__.py b/mindspore/ops/composite/multitype_ops/__init__.py index 40bf71d49a6..b7f4f671b8d 100644 --- a/mindspore/ops/composite/multitype_ops/__init__.py +++ b/mindspore/ops/composite/multitype_ops/__init__.py @@ -23,6 +23,7 @@ from .pow_impl import pow_ from .floordiv_impl import floordiv from .mod_impl import mod from .getitem_impl import getitem +from .setitem_impl import setitem from .zeros_like_impl import zeros_like from .ones_like_impl import ones_like from .equal_impl import equal @@ -55,6 +56,7 @@ __all__ = [ 'greater_equal', 'negative', 'getitem', + 'setitem', 'logical_and', 'logical_or', 'logical_not' diff --git a/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py b/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py new file mode 100644 index 00000000000..b3687c553c9 --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py @@ -0,0 +1,45 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""constexpr util""" + +from ...primitive import constexpr + + +@constexpr +def is_same_type(inst, type_): + """ + Check whether an object is an instance of a target type. + + Inputs: + inst (mindspore.dtype): Inspected type. + type_ (mindspore.dtype): Target type. + + Outputs: + bool, the check result. + """ + return inst == type_ + + +@constexpr +def error_msg(msg="", format_values=""): + """ + Used to throw exception information. + + Inputs: + msg (str): information content. + """ + + raise ValueError(msg.format(*format_values)) diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py new file mode 100644 index 00000000000..31c96932c5e --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -0,0 +1,194 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Implementation for setitem.""" + +from ...composite import base +from ....common import dtype as mstype +from ... import functional as F +from . import _multitype_ops_util as mult_util + +setitem = base.MultitypeFuncGraph('setitem') + +@setitem.register("List", "Number", "String") +def _list_setitem_with_string(data, number_index, value): + """ + Assign value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (String): Value given. + + Outputs: + List, type is same as the element type of data. + """ + return F.list_setitem(data, number_index, value) + + +@setitem.register("List", "Number", "Number") +def _list_setitem_with_number(data, number_index, value): + """ + Assign value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (Number): Value given. + + Outputs: + List, type is same as the element type of data. + """ + return F.list_setitem(data, number_index, value) + + +@setitem.register("List", "Number", "Tensor") +def _list_setitem_with_Tensor(data, number_index, value): + """ + Assign value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (Tensor): Value given. + + Outputs: + List, type is same as the element type of data. + """ + return F.list_setitem(data, number_index, value) + + +@setitem.register("List", "Number", "List") +def _list_setitem_with_List(data, number_index, value): + """ + Assign value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (List): Value given. + + Outputs: + List, type is same as the element type of data. + """ + return F.list_setitem(data, number_index, value) + + +@setitem.register("Dictionary", "String", "Tensor") +def _dict_setitem_with_tensor(data, key, value): + """ + Assign value to dictionary. + + Inputs: + data (Dictionary): Data of type dict. + key (str): Key of the data. + value (Tensor): Value given. + + Outputs: + Dict, type is as same as the element type of data. + """ + return F.dict_setitem(data, key, value) + + +@setitem.register("Dictionary", "String", "Number") +def _dict_setitem_with_number(data, key, value): + """ + Assign value to dictionary. + + Inputs: + data (Dictionary): Data of type dict. + key (str): Key of the data. + value (Number): Value given. + + Outputs: + Dict, type is as same as the element type of data. + """ + return F.dict_setitem(data, key, value) + + +@setitem.register("Tensor", "Tensor", "Tensor") +def _tensor_setitem_by_tensor_v1(data, index, value_tensor): + """ + Tensor assignment. + + Note: + Syntax support: A[B] = U and A[A>n] = U. + Restraint condition: 1) A, U is a Tensor, and B is a bool Tensor. + 2) A.shape == B.shape + 3) U.size == 1 + 4) n is a number + + Inputs: + data (Tensor): Assigned tensor. + index (Tensor): Tensor of bool type. + value_tensor (Tensor): Tensor with size 1. + + Outputs: + Tensor, element type and shape is same as data. + """ + index_dtype = F.dtype(index) + index_shape = F.shape(index) + is_bool = mult_util.is_same_type(index_dtype, mstype.bool_) + if not is_bool: + return mult_util.error_msg( + "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,)) + data_shape = F.shape(data) + if index_shape != data_shape: + return mult_util.error_msg( + "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (data_shape, index_shape)) + size = F.size(value_tensor) + if size != 1: + return mult_util.error_msg( + "When assign value is a tensor, its size should be 1, but current size is {}.", (size,)) + dtype = F.dtype(data) + u_cast = F.cast(value_tensor, dtype) + one_data = F.ones_like(data) + u = F.tensor_mul(one_data, u_cast) + return F.select(index, u, data) + + +@setitem.register("Tensor", "Tensor", "Number") +def _tensor_setitem_by_tensor_v2(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[B] = u and A[A>n] = u. + Restraint condition: 1) A is a Tensor, and B is a bool Tensor. + 2) A.shape == B.shape + 3) u is a scalar + 4) n is a number + + Inputs: + data (Tensor): Assigned tensor. + index (Tensor): Tensor of bool type. + value_tensor (Number): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + index_dtype = F.dtype(index) + index_shape = F.shape(index) + is_bool = mult_util.is_same_type(index_dtype, mstype.bool_) + if not is_bool: + return mult_util.error_msg( + "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,)) + shape = F.shape(data) + if index_shape != shape: + return mult_util.error_msg( + "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (shape, index_shape)) + dtype = F.dtype(data) + u = F.fill(dtype, shape, value) + return F.select(index, u, data) diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 611c569553a..0ed750beb13 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -31,6 +31,9 @@ dtype = P.DType() issubclass_ = P.IsSubClass() isinstance_ = P.IsInstance() fill = P.Fill() +select = P.Select() +size = P.Size() +ones_like = P.OnesLike() shape = P.Shape() rank = P.Rank() reshape = P.Reshape() @@ -68,7 +71,9 @@ scalar_cast = P.ScalarCast() tuple_setitem = Primitive('tuple_setitem') tuple_getitem = Primitive('tuple_getitem') list_getitem = Primitive('list_getitem') +list_setitem = Primitive('list_setitem') dict_getitem = Primitive('dict_getitem') +dict_setitem = Primitive('dict_setitem') tuple_div = Primitive("tuple_div") tuple_len = Primitive("tuple_len") tuple_reversed = Primitive("tuple_reversed") diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index 6200d4e163d..a88a2d83222 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -18,6 +18,7 @@ import pytest from mindspore import Tensor from mindspore import context +from mindspore import dtype as mstype from mindspore.nn import Cell from ....mindspore_test_framework.mindspore_test import mindspore_test @@ -79,7 +80,102 @@ class NetWorkReduceToScalar(Cell): return ret +class TensorAssignWithBoolTensorIndex(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndex, self).__init__() + self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) + + def construct(self, a, b, c, u_tensor, _scalar): + a[c] = u_scalar + a[b] = u_tensor + z = a + self.t + return z + + +class TensorAssignWithBoolTensorIndexError(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndexError, self).__init__() + + def construct(self, a, b, c, u_tensor): + a[b][c] = u_tensor + return a + + +class TensorAssignWithBoolTensorIndex2(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndex2, self).__init__() + self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) + + def construct(self, a, u_tensor, _scalar): + a[a>8] = u_tensor + a[a>=6] = u_scalar + a[a<3] = u_scalar + a[a<=5] = u_tensor + a[a==5] = u_scalar + z = a + self.t + return z + + +class TensorAssignWithBoolTensorIndex2Error(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndex2Error, self).__init__() + + def construct(self, a, u_tensor): + a[a>8][a>5] = u_tensor + return a + + +a = np.random.uniform(1,10,[2,3]) +b = a > 5 +c = a < 3 +Ta = Tensor(a) +Tb = Tensor(b) +Tc = Tensor(c) +Td = Tensor([True, True]) +u_tensor = Tensor([1]) +u_tensor_error = Tensor([1, 2]) +u_scalar = 5 + + +def test_tensor_assign_bool_index(): + net1 = TensorAssignWithBoolTensorIndex() + net2 = TensorAssignWithBoolTensorIndex2() + + net1(Ta, Tb, Tc, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, Td, Tc, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, u_tensor, Tc, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, Tb, Td, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, Tb, Ta, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, Tb, Tc, u_tensor_error, u_scalar) + #net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar) + with pytest.raises(ValueError): + net2(Ta, u_tensor_error, u_scalar) + net3 = TensorAssignWithBoolTensorIndexError() + with pytest.raises(AttributeError): + net3(Ta, Tb, Tc, u_tensor) + with pytest.raises(AttributeError): + net3(Ta, Tb, Tc, u_scalar) + net4 = TensorAssignWithBoolTensorIndex2Error() + with pytest.raises(AttributeError): + net4(Ta, u_tensor) + with pytest.raises(AttributeError): + net4(Ta, u_scalar) + + test_cases = [ + ('TensorAssignWithBoolTensorIndex', { + 'block': TensorAssignWithBoolTensorIndex(), + 'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar], + }), + ('TensorAssignWithBoolTensorIndex2', { + 'block': TensorAssignWithBoolTensorIndex2(), + 'desc_inputs': [Ta, u_tensor, u_scalar], + }), ('SlicePositive', { 'block': NetWorkSlicePositive(), 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],