forked from mindspore-Ecosystem/mindspore
!48152 [API] Add ops.is_tensor
Merge pull request !48152 from shaojunsong/feature/is_tensor
This commit is contained in:
commit
fd6338be75
|
@ -549,6 +549,7 @@ Array操作
|
||||||
:nosignatures:
|
:nosignatures:
|
||||||
:template: classtemplate.rst
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
mindspore.ops.is_tensor
|
||||||
mindspore.ops.scalar_cast
|
mindspore.ops.scalar_cast
|
||||||
mindspore.ops.scalar_to_tensor
|
mindspore.ops.scalar_to_tensor
|
||||||
mindspore.ops.tuple_to_array
|
mindspore.ops.tuple_to_array
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
mindspore.ops.is_tensor
|
||||||
|
========================
|
||||||
|
|
||||||
|
.. py:function:: mindspore.ops.is_tensor(obj)
|
||||||
|
|
||||||
|
判断输入对象是否为 :class:`mindspore.Tensor`。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- **obj** (Object) - 输入对象。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Bool,如果 `obj` 是 :class:`mindspore.Tensor`,返回True,否则返回False。
|
|
@ -549,6 +549,7 @@ Type Cast
|
||||||
:nosignatures:
|
:nosignatures:
|
||||||
:template: classtemplate.rst
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
mindspore.ops.is_tensor
|
||||||
mindspore.ops.scalar_cast
|
mindspore.ops.scalar_cast
|
||||||
mindspore.ops.scalar_to_tensor
|
mindspore.ops.scalar_to_tensor
|
||||||
mindspore.ops.tuple_to_array
|
mindspore.ops.tuple_to_array
|
||||||
|
|
|
@ -82,6 +82,7 @@ from .array_func import (
|
||||||
gather_d,
|
gather_d,
|
||||||
gather_elements,
|
gather_elements,
|
||||||
gather_nd,
|
gather_nd,
|
||||||
|
is_tensor,
|
||||||
scalar_cast,
|
scalar_cast,
|
||||||
masked_fill,
|
masked_fill,
|
||||||
narrow,
|
narrow,
|
||||||
|
|
|
@ -4502,6 +4502,28 @@ def population_count(input_x):
|
||||||
##############################
|
##############################
|
||||||
|
|
||||||
|
|
||||||
|
def is_tensor(obj):
|
||||||
|
r"""
|
||||||
|
Check whether the input object is a :class:`mindspore.Tensor` .
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj (Object): input object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Bool. Return True if `obj` is a Tensor, otherwise, return False.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from mindspore import Tensor, ops
|
||||||
|
>>> a = Tensor([1.9, 2.2, 3.1])
|
||||||
|
>>> ops.is_tensor(a)
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
return isinstance(obj, Tensor)
|
||||||
|
|
||||||
|
|
||||||
def scalar_cast(input_x, input_y):
|
def scalar_cast(input_x, input_y):
|
||||||
"""
|
"""
|
||||||
Casts the input scalar to another type.
|
Casts the input scalar to another type.
|
||||||
|
@ -6424,6 +6446,7 @@ __all__ = [
|
||||||
'stack',
|
'stack',
|
||||||
'unbind',
|
'unbind',
|
||||||
'unstack',
|
'unstack',
|
||||||
|
'is_tensor',
|
||||||
'scalar_cast',
|
'scalar_cast',
|
||||||
'scalar_to_tensor',
|
'scalar_to_tensor',
|
||||||
'space_to_batch_nd',
|
'space_to_batch_nd',
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.ops as ops
|
||||||
|
from mindspore import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def construct(self, x):
|
||||||
|
return ops.is_tensor(x)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.platform_arm_cpu
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||||
|
def test_is_tensor(mode):
|
||||||
|
"""
|
||||||
|
Feature: ops.is_tensor
|
||||||
|
Description: Verify the result of ops.is_tensor
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
ms.set_context(mode=mode)
|
||||||
|
net = Net()
|
||||||
|
a = Tensor([1, 2])
|
||||||
|
output1 = net(a)
|
||||||
|
assert output1
|
||||||
|
b = [1, 2]
|
||||||
|
output2 = net(b)
|
||||||
|
assert not output2
|
Loading…
Reference in New Issue