forked from mindspore-Ecosystem/mindspore
fix api for tensor.item and tensor.itemset
This commit is contained in:
parent
635c2b0adb
commit
8ef467f310
|
@ -344,19 +344,67 @@ class Tensor(Tensor_):
|
||||||
"""
|
"""
|
||||||
Getitem from the Tensor with the index.
|
Getitem from the Tensor with the index.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Tensor.item returns a Tensor scalar instead of a Python scalar.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
index (Union[None, int, tuple(int)]): The index in Tensor. Default: None.
|
index (Union[None, int, tuple(int)]): The index in Tensor. Default: None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor getitem by index whose dtype is int or tuple with int. If index is None,
|
A Tensor scalar, dtype is the same with the original Tensor.
|
||||||
the `item` API can only convert an array of size 1 to a Python scalar.
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the length of the `index` is not euqal to self.ndim.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import numpy as np
|
||||||
|
>>> from mindspore import Tensor
|
||||||
|
>>> x = Tensor(np.array([[1,2,3],[4,5,6]], dtype=np.float32))
|
||||||
|
>>> x = x.item((0,1))
|
||||||
|
>>> print(x)
|
||||||
|
2.0
|
||||||
|
|
||||||
"""
|
"""
|
||||||
output = tensor_operator_registry.get('item')(self, index)
|
output = tensor_operator_registry.get('item')(self, index)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def itemset(self, *args):
|
def itemset(self, *args):
|
||||||
"""Setitem from the Tensor with the index."""
|
"""
|
||||||
|
Insert scalar into a tensor (scalar is cast to tensor’s dtype, if possible).
|
||||||
|
|
||||||
|
There must be at least 1 argument, and define the last argument as item.
|
||||||
|
Then, tensor.itemset(*args) is equivalent to :math:`tensor[args] = item`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (Union[(Number), (int/tuple(int), Number)]): The arguments that
|
||||||
|
specify the index and value. If `args` contain one argument (a scalar),
|
||||||
|
it is only used in case tensor is of size 1. If `args` contain two
|
||||||
|
arguments, the last argument is the value to be set and must be a
|
||||||
|
scalar, the first argument specifies a single tensor element location.
|
||||||
|
It is either an int or a tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new Tensor, with value set by :math:`tensor[args] = item`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the length of the first argument is not euqal to self.ndim.
|
||||||
|
IndexError: If only one argument is provided, and the original Tensor is not scalar.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import numpy as np
|
||||||
|
>>> from mindspore import Tensor
|
||||||
|
>>> x = Tensor(np.array([[1,2,3],[4,5,6]], dtype=np.float32))
|
||||||
|
>>> x = x.itemset((0,1), 4)
|
||||||
|
>>> print(x)
|
||||||
|
[[1. 4. 3.]
|
||||||
|
[4. 5. 6.]]
|
||||||
|
"""
|
||||||
output = tensor_operator_registry.get('itemset')(self, *args)
|
output = tensor_operator_registry.get('itemset')(self, *args)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -1012,7 +1060,6 @@ class Tensor(Tensor_):
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If input arguments have types not specified above.
|
TypeError: If input arguments have types not specified above.
|
||||||
ValueError: If `shape` has entries < 0.
|
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``GPU`` ``CPU``
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
Loading…
Reference in New Issue