fix api for tensor.item and tensor.itemset

This commit is contained in:
yanglf1121 2021-06-08 14:19:50 +08:00
parent 635c2b0adb
commit 8ef467f310
1 changed files with 51 additions and 4 deletions

View File

@ -344,19 +344,67 @@ class Tensor(Tensor_):
"""
Getitem from the Tensor with the index.
Note:
Tensor.item returns a Tensor scalar instead of a Python scalar.
Args:
index (Union[None, int, tuple(int)]): The index in Tensor. Default: None.
Returns:
Tensor getitem by index whose dtype is int or tuple with int. If index is None,
the `item` API can only convert an array of size 1 to a Python scalar.
A Tensor scalar, dtype is the same with the original Tensor.
Raises:
ValueError: If the length of the `index` is not euqal to self.ndim.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.array([[1,2,3],[4,5,6]], dtype=np.float32))
>>> x = x.item((0,1))
>>> print(x)
2.0
"""
output = tensor_operator_registry.get('item')(self, index)
return output
def itemset(self, *args):
"""Setitem from the Tensor with the index."""
"""
Insert scalar into a tensor (scalar is cast to tensors dtype, if possible).
There must be at least 1 argument, and define the last argument as item.
Then, tensor.itemset(*args) is equivalent to :math:`tensor[args] = item`.
Args:
args (Union[(Number), (int/tuple(int), Number)]): The arguments that
specify the index and value. If `args` contain one argument (a scalar),
it is only used in case tensor is of size 1. If `args` contain two
arguments, the last argument is the value to be set and must be a
scalar, the first argument specifies a single tensor element location.
It is either an int or a tuple.
Returns:
A new Tensor, with value set by :math:`tensor[args] = item`.
Raises:
ValueError: If the length of the first argument is not euqal to self.ndim.
IndexError: If only one argument is provided, and the original Tensor is not scalar.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> x = Tensor(np.array([[1,2,3],[4,5,6]], dtype=np.float32))
>>> x = x.itemset((0,1), 4)
>>> print(x)
[[1. 4. 3.]
[4. 5. 6.]]
"""
output = tensor_operator_registry.get('itemset')(self, *args)
return output
@ -1012,7 +1060,6 @@ class Tensor(Tensor_):
Raises:
TypeError: If input arguments have types not specified above.
ValueError: If `shape` has entries < 0.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``