!48152 [API] Add ops.is_tensor

Merge pull request !48152 from shaojunsong/feature/is_tensor
This commit is contained in:
i-robot 2023-01-31 02:47:29 +00:00 committed by Gitee
commit fd6338be75
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 88 additions and 0 deletions

View File

@ -549,6 +549,7 @@ Array操作
:nosignatures: :nosignatures:
:template: classtemplate.rst :template: classtemplate.rst
mindspore.ops.is_tensor
mindspore.ops.scalar_cast mindspore.ops.scalar_cast
mindspore.ops.scalar_to_tensor mindspore.ops.scalar_to_tensor
mindspore.ops.tuple_to_array mindspore.ops.tuple_to_array

View File

@ -0,0 +1,12 @@
mindspore.ops.is_tensor
========================
.. py:function:: mindspore.ops.is_tensor(obj)
判断输入对象是否为 :class:`mindspore.Tensor`
参数:
- **obj** (Object) - 输入对象。
返回:
Bool如果 `obj`:class:`mindspore.Tensor`返回True否则返回False。

View File

@ -549,6 +549,7 @@ Type Cast
:nosignatures: :nosignatures:
:template: classtemplate.rst :template: classtemplate.rst
mindspore.ops.is_tensor
mindspore.ops.scalar_cast mindspore.ops.scalar_cast
mindspore.ops.scalar_to_tensor mindspore.ops.scalar_to_tensor
mindspore.ops.tuple_to_array mindspore.ops.tuple_to_array

View File

@ -82,6 +82,7 @@ from .array_func import (
gather_d, gather_d,
gather_elements, gather_elements,
gather_nd, gather_nd,
is_tensor,
scalar_cast, scalar_cast,
masked_fill, masked_fill,
narrow, narrow,

View File

@ -4502,6 +4502,28 @@ def population_count(input_x):
############################## ##############################
def is_tensor(obj):
r"""
Check whether the input object is a :class:`mindspore.Tensor` .
Args:
obj (Object): input object.
Returns:
Bool. Return True if `obj` is a Tensor, otherwise, return False.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore import Tensor, ops
>>> a = Tensor([1.9, 2.2, 3.1])
>>> ops.is_tensor(a)
True
"""
return isinstance(obj, Tensor)
def scalar_cast(input_x, input_y): def scalar_cast(input_x, input_y):
""" """
Casts the input scalar to another type. Casts the input scalar to another type.
@ -6424,6 +6446,7 @@ __all__ = [
'stack', 'stack',
'unbind', 'unbind',
'unstack', 'unstack',
'is_tensor',
'scalar_cast', 'scalar_cast',
'scalar_to_tensor', 'scalar_to_tensor',
'space_to_batch_nd', 'space_to_batch_nd',

View File

@ -0,0 +1,50 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
class Net(nn.Cell):
def construct(self, x):
return ops.is_tensor(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_is_tensor(mode):
"""
Feature: ops.is_tensor
Description: Verify the result of ops.is_tensor
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
a = Tensor([1, 2])
output1 = net(a)
assert output1
b = [1, 2]
output2 = net(b)
assert not output2