forked from mindspore-Ecosystem/mindspore
!48152 [API] Add ops.is_tensor
Merge pull request !48152 from shaojunsong/feature/is_tensor
This commit is contained in:
commit
fd6338be75
|
@ -549,6 +549,7 @@ Array操作
|
|||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
mindspore.ops.is_tensor
|
||||
mindspore.ops.scalar_cast
|
||||
mindspore.ops.scalar_to_tensor
|
||||
mindspore.ops.tuple_to_array
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
mindspore.ops.is_tensor
|
||||
========================
|
||||
|
||||
.. py:function:: mindspore.ops.is_tensor(obj)
|
||||
|
||||
判断输入对象是否为 :class:`mindspore.Tensor`。
|
||||
|
||||
参数:
|
||||
- **obj** (Object) - 输入对象。
|
||||
|
||||
返回:
|
||||
Bool,如果 `obj` 是 :class:`mindspore.Tensor`,返回True,否则返回False。
|
|
@ -549,6 +549,7 @@ Type Cast
|
|||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
mindspore.ops.is_tensor
|
||||
mindspore.ops.scalar_cast
|
||||
mindspore.ops.scalar_to_tensor
|
||||
mindspore.ops.tuple_to_array
|
||||
|
|
|
@ -82,6 +82,7 @@ from .array_func import (
|
|||
gather_d,
|
||||
gather_elements,
|
||||
gather_nd,
|
||||
is_tensor,
|
||||
scalar_cast,
|
||||
masked_fill,
|
||||
narrow,
|
||||
|
|
|
@ -4502,6 +4502,28 @@ def population_count(input_x):
|
|||
##############################
|
||||
|
||||
|
||||
def is_tensor(obj):
|
||||
r"""
|
||||
Check whether the input object is a :class:`mindspore.Tensor` .
|
||||
|
||||
Args:
|
||||
obj (Object): input object.
|
||||
|
||||
Returns:
|
||||
Bool. Return True if `obj` is a Tensor, otherwise, return False.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import Tensor, ops
|
||||
>>> a = Tensor([1.9, 2.2, 3.1])
|
||||
>>> ops.is_tensor(a)
|
||||
True
|
||||
"""
|
||||
return isinstance(obj, Tensor)
|
||||
|
||||
|
||||
def scalar_cast(input_x, input_y):
|
||||
"""
|
||||
Casts the input scalar to another type.
|
||||
|
@ -6424,6 +6446,7 @@ __all__ = [
|
|||
'stack',
|
||||
'unbind',
|
||||
'unstack',
|
||||
'is_tensor',
|
||||
'scalar_cast',
|
||||
'scalar_to_tensor',
|
||||
'space_to_batch_nd',
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return ops.is_tensor(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_is_tensor(mode):
|
||||
"""
|
||||
Feature: ops.is_tensor
|
||||
Description: Verify the result of ops.is_tensor
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
a = Tensor([1, 2])
|
||||
output1 = net(a)
|
||||
assert output1
|
||||
b = [1, 2]
|
||||
output2 = net(b)
|
||||
assert not output2
|
Loading…
Reference in New Issue