forked from mindspore-Ecosystem/mindspore
fix api for tensor.item and tensor.itemset
This commit is contained in:
parent
635c2b0adb
commit
8ef467f310
|
@ -344,19 +344,67 @@ class Tensor(Tensor_):
|
|||
"""
|
||||
Getitem from the Tensor with the index.
|
||||
|
||||
Note:
|
||||
Tensor.item returns a Tensor scalar instead of a Python scalar.
|
||||
|
||||
Args:
|
||||
index (Union[None, int, tuple(int)]): The index in Tensor. Default: None.
|
||||
|
||||
Returns:
|
||||
Tensor getitem by index whose dtype is int or tuple with int. If index is None,
|
||||
the `item` API can only convert an array of size 1 to a Python scalar.
|
||||
A Tensor scalar, dtype is the same with the original Tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: If the length of the `index` is not euqal to self.ndim.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor(np.array([[1,2,3],[4,5,6]], dtype=np.float32))
|
||||
>>> x = x.item((0,1))
|
||||
>>> print(x)
|
||||
2.0
|
||||
|
||||
"""
|
||||
output = tensor_operator_registry.get('item')(self, index)
|
||||
return output
|
||||
|
||||
def itemset(self, *args):
|
||||
"""Setitem from the Tensor with the index."""
|
||||
"""
|
||||
Insert scalar into a tensor (scalar is cast to tensor’s dtype, if possible).
|
||||
|
||||
There must be at least 1 argument, and define the last argument as item.
|
||||
Then, tensor.itemset(*args) is equivalent to :math:`tensor[args] = item`.
|
||||
|
||||
Args:
|
||||
args (Union[(Number), (int/tuple(int), Number)]): The arguments that
|
||||
specify the index and value. If `args` contain one argument (a scalar),
|
||||
it is only used in case tensor is of size 1. If `args` contain two
|
||||
arguments, the last argument is the value to be set and must be a
|
||||
scalar, the first argument specifies a single tensor element location.
|
||||
It is either an int or a tuple.
|
||||
|
||||
Returns:
|
||||
A new Tensor, with value set by :math:`tensor[args] = item`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the length of the first argument is not euqal to self.ndim.
|
||||
IndexError: If only one argument is provided, and the original Tensor is not scalar.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor(np.array([[1,2,3],[4,5,6]], dtype=np.float32))
|
||||
>>> x = x.itemset((0,1), 4)
|
||||
>>> print(x)
|
||||
[[1. 4. 3.]
|
||||
[4. 5. 6.]]
|
||||
"""
|
||||
output = tensor_operator_registry.get('itemset')(self, *args)
|
||||
return output
|
||||
|
||||
|
@ -1012,7 +1060,6 @@ class Tensor(Tensor_):
|
|||
|
||||
Raises:
|
||||
TypeError: If input arguments have types not specified above.
|
||||
ValueError: If `shape` has entries < 0.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
|
Loading…
Reference in New Issue