!44592 remove useless api

Merge pull request !44592 from lianliguang/remove-useless-api
This commit is contained in:
i-robot 2022-11-09 12:50:26 +00:00 committed by Gitee
commit c085dc8bc2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
38 changed files with 301 additions and 406 deletions

View File

@ -267,7 +267,6 @@ Reduction函数
mindspore.ops.maximum
mindspore.ops.minimum
mindspore.ops.ne
mindspore.ops.same_type_shape
线性代数函数
^^^^^^^^^^^^^

View File

@ -314,7 +314,6 @@ Reduction算子
:template: classtemplate.rst
mindspore.ops.ApproximateEqual
mindspore.ops.CheckBprop
mindspore.ops.Equal
mindspore.ops.EqualCount
mindspore.ops.Greater
@ -322,15 +321,12 @@ Reduction算子
mindspore.ops.InTopK
mindspore.ops.IsFinite
mindspore.ops.IsInf
mindspore.ops.IsInstance
mindspore.ops.IsNan
mindspore.ops.IsSubClass
mindspore.ops.Less
mindspore.ops.LessEqual
mindspore.ops.Maximum
mindspore.ops.Minimum
mindspore.ops.NotEqual
mindspore.ops.SameTypeShape
mindspore.ops.TopK
线性代数算子

View File

@ -30,7 +30,6 @@ mindspore
mindspore.dtype
mindspore.dtype_to_nptype
mindspore.issubclass_
mindspore.dtype_to_pytype
mindspore.pytype_to_dtype
mindspore.get_py_obj_dtype

View File

@ -1,13 +0,0 @@
mindspore.issubclass\_
=======================
.. py:function:: mindspore.issubclass_(type_, dtype)
判断 `type_` 是否为 `dtype` 的子类。
参数:
- **type_** (mindspore.dtype) - MindSpore中的目标dtype。
- **dtype** (mindspore.dtype) - dtype的比较对象。
返回:
boolTrue或False。

View File

@ -1,16 +0,0 @@
mindspore.ops.IsInstance
=========================
.. py:class:: mindspore.ops.IsInstance
检查输入对象是否为目标类型的实例。
输入:
- **inst** (Any Object) - 要检查的实例。只允许为常量。
- **type_** (mindspore.dtype) - 目标类型。只允许为常量。
输出:
bool检查结果。
异常:
- **TypeError** - 如果 `type_` 不是一种类型。

View File

@ -1,16 +0,0 @@
mindspore.ops.IsSubClass
=========================
.. py:class:: mindspore.ops.IsSubClass
检查输入类型是否为其他类型的子类。
输入:
- **sub_type** (mindspore.dtype) - 要检查的类型。只允许为常量。
- **type_** (mindspore.dtype) - 目标类型。只允许为常量。
输出:
bool检查结果。
异常:
- **TypeError** - 如果 `sub_type``type_` 不是一种类型。

View File

@ -267,7 +267,6 @@ Comparison Functions
mindspore.ops.maximum
mindspore.ops.minimum
mindspore.ops.ne
mindspore.ops.same_type_shape
Linear Algebraic Functions
^^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@ -314,7 +314,6 @@ Comparison Operator
:template: classtemplate.rst
mindspore.ops.ApproximateEqual
mindspore.ops.CheckBprop
mindspore.ops.Equal
mindspore.ops.EqualCount
mindspore.ops.Greater
@ -322,15 +321,12 @@ Comparison Operator
mindspore.ops.InTopK
mindspore.ops.IsFinite
mindspore.ops.IsInf
mindspore.ops.IsInstance
mindspore.ops.IsNan
mindspore.ops.IsSubClass
mindspore.ops.Less
mindspore.ops.LessEqual
mindspore.ops.Maximum
mindspore.ops.Minimum
mindspore.ops.NotEqual
mindspore.ops.SameTypeShape
mindspore.ops.TopK
Linear Algebraic Operator

View File

@ -137,7 +137,6 @@ DataType
:template: classtemplate.rst
mindspore.dtype_to_nptype
mindspore.issubclass_
mindspore.dtype_to_pytype
mindspore.pytype_to_dtype
mindspore.get_py_obj_dtype

View File

@ -443,7 +443,7 @@ void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check
}
// bprop_fg has been checked in caller
auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops");
auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations._inner_ops");
MS_EXCEPTION_IF_NULL(check_bprop_class);
auto check_bprop =
bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared<StringImm>(prim_to_check))});

View File

@ -35,7 +35,7 @@ from ...ops.composite import tail, MultitypeFuncGraph, env_get, hyper_add, \
from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
from ...ops.operations.math_ops import Median
from ...ops.operations._inner_ops import Format
from ...ops.operations._inner_ops import Format, issubclass_
from ...ops.operations import _csr_ops
from ...ops.operations import _map_tensor_ops
from ...ops.primitive import constexpr
@ -2643,7 +2643,7 @@ def ge(x, y):
def while_cond(x):
"""For while condition, if the condition is a tensor, the loop will not be unrolled"""
if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
if issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
is_cond = check_is_tensor_bool_cond(F.shape(x))
if is_cond:
return F.cast(x, mstype.bool_)

View File

@ -20,7 +20,7 @@ from mindspore.common.dtype import Type, int8, byte, int16, short, int32, intc,
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
uint, number, tensor, string, type_none, tensor_type, Int, \
complex64, complex128, dtype_to_nptype, issubclass_, _null, _null_type, \
complex64, complex128, dtype_to_nptype, _null, _null_type, \
dtype_to_pytype, pytype_to_dtype, get_py_obj_dtype
from mindspore.common.dump import set_dump
from mindspore.common.parameter import Parameter, ParameterTuple
@ -53,7 +53,7 @@ __all__ = [
"Type", "Int", "_null_type",
"complex64", "complex128",
# __method__ from dtype
"dtype_to_nptype", "issubclass_", "dtype_to_pytype",
"dtype_to_nptype", "dtype_to_pytype",
"pytype_to_dtype", "get_py_obj_dtype"
]

View File

@ -47,7 +47,7 @@ __dtype__ = [
]
__method__ = [
"dtype_to_nptype", "issubclass_", "dtype_to_pytype",
"dtype_to_nptype", "dtype_to_pytype",
"pytype_to_dtype", "get_py_obj_dtype"
]
@ -297,20 +297,6 @@ def _issubclass_(type_, dtype):
return typing.is_subclass(type_, dtype)
def issubclass_(type_, dtype):
"""
Determine whether `type_` is a subclass of `dtype`.
Args:
type_ (:class:`mindspore.dtype`): Target MindSpore dtype.
dtype (:class:`mindspore.dtype`): Compare MindSpore dtype.
Returns:
bool, True or False.
"""
logger.warning("'issubclass_' will be deprecated and removed in a future version.")
return _issubclass_(type_, dtype)
def type_size_in_bytes(dtype):
"""

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from mindspore.nn.cell import Cell
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.primitive import Primitive
from mindspore.common import dtype as mstype
from mindspore.common.api import jit
@ -80,7 +81,7 @@ class Jvp(Cell):
self.first_grad_single_value = _FirstGradSingleValue(fn)
self.first_grad_single_value.add_flags(enable_tuple_grad_first=True)
self.second_grad_op = C.GradOperation(sens_param=True)
self.issubclass_ = P.IsSubClass()
self.issubclass_ = inner.IsSubClass()
self.typeof = Primitive('typeof')
self.make_tuple = Primitive('MakeTuple')
self.tuple_len = Primitive("tuple_len")
@ -122,7 +123,7 @@ class _JvpInner(Cell):
self.first_grad_single_value = _JvpFirstGradSingleValue()
self.first_grad_single_value.add_flags(enable_tuple_grad_first=True)
self.second_grad_op = C.GradOperation(sens_param=True)
self.issubclass_ = P.IsSubClass()
self.issubclass_ = inner.IsSubClass()
self.typeof = Primitive('typeof')
self.make_tuple = Primitive('MakeTuple')
self.tuple_len = Primitive("tuple_len")
@ -182,7 +183,7 @@ class Vjp(Cell):
self.fn = fn
self.grad = C.GradOperation(get_all=True, sens_param=True)
self.grad_single_value = C.GradOperation(sens_param=True)
self.issubclass_ = P.IsSubClass()
self.issubclass_ = inner.IsSubClass()
self.typeof = Primitive('typeof')
self.tuple_len = Primitive("tuple_len")

View File

@ -23,6 +23,7 @@ import mindspore.common.dtype as mstype
import mindspore.ops as ops
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore._checkparam import Rel, Validator as validator
@ -281,7 +282,7 @@ class SSIM(Cell):
def construct(self, img1, img2):
_check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name)
_check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name)
P.SameTypeShape()(img1, img2)
inner.SameTypeShape()(img1, img2)
dtype_max_val = _get_dtype_max(F.dtype(img1))
max_val = F.scalar_cast(self.max_val, F.dtype(img1))
max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
@ -390,7 +391,7 @@ class MSSSIM(Cell):
_check_input_4d(F.shape(img2), "img2", self.cls_name)
valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.uint8]
_check_input_dtype(F.dtype(img1), 'img1', valid_type, self.cls_name)
P.SameTypeShape()(img1, img2)
inner.SameTypeShape()(img1, img2)
dtype_max_val = _get_dtype_max(F.dtype(img1))
max_val = F.scalar_cast(self.max_val, F.dtype(img1))
max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
@ -467,7 +468,7 @@ class PSNR(Cell):
def construct(self, img1, img2):
_check_input_4d(F.shape(img1), "img1", self.cls_name)
_check_input_4d(F.shape(img2), "img2", self.cls_name)
P.SameTypeShape()(img1, img2)
inner.SameTypeShape()(img1, img2)
dtype_max_val = _get_dtype_max(F.dtype(img1))
max_val = F.scalar_cast(self.max_val, F.dtype(img1))
max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)

View File

@ -23,6 +23,7 @@ from mindspore import log
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations.nn_ops import MultiMarginLoss as MultiMarginLossOp
from mindspore.ops.operations.nn_ops import MultilabelMarginLoss as MultilabelMarginLossOp
from mindspore.ops.operations.nn_ops import TripletMarginLoss as TripletMarginLossOp
@ -1387,7 +1388,7 @@ class CosineEmbeddingLoss(LossBase):
_check_is_tensor('logits_x1', logits_x1, self.cls_name)
_check_is_tensor('logits_x2', logits_x2, self.cls_name)
_check_is_tensor('labels', labels, self.cls_name)
F.same_type_shape(logits_x1, logits_x2)
inner.same_type_shape_(logits_x1, logits_x2)
_check_reduced_shape_valid(F.shape(logits_x1), F.shape(labels), (1,), self.cls_name, "logits_x1", "labels")
# if labels > 0, 1-cosine(logits_x1, logits_x2)
# else, max(0, cosine(logits_x1, logits_x2)-margin)

View File

@ -16,6 +16,7 @@
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
@ -96,8 +97,8 @@ class Bijector(Cell):
self.dtype_base = P.DType()
self.shape_base = P.Shape()
self.fill_base = P.Fill()
self.sametypeshape_base = P.SameTypeShape()
self.issubclass_base = P.IsSubClass()
self.sametypeshape_base = inner.SameTypeShape()
self.issubclass_base = inner.IsSubClass()
@property
def name(self):

View File

@ -15,6 +15,7 @@
"""Utility functions to help distribution class."""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common import dtype as mstype
@ -26,7 +27,7 @@ def exp_generic(input_x):
exp = P.Exp()
cast = P.Cast()
dtype = P.DType()
checktype = P.IsSubClass()
checktype = inner.IsSubClass()
if not checktype(dtype(input_x), mstype.float_):
input_x = cast(input_x, mstype.float32)
@ -48,7 +49,7 @@ def log_generic(input_x):
dtype = P.DType()
shape = P.Shape()
select = P.Select()
checktype = P.IsSubClass()
checktype = inner.IsSubClass()
if not checktype(dtype(input_x), mstype.float_):
input_x = cast(input_x, mstype.float32)

View File

@ -18,6 +18,7 @@ from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops.functional import stop_gradient
from mindspore.ops.operations import _inner_ops as inner
from mindspore._checkparam import Validator
import mindspore.nn as nn
from mindspore.common import dtype as mstype
@ -149,7 +150,7 @@ class Categorical(Distribution):
self.fill = P.Fill()
self.gather = P.GatherNd()
self.greater = P.Greater()
self.issubclass = P.IsSubClass()
self.issubclass = inner.IsSubClass()
self.less = P.Less()
# when the graph kernel mode is enable
# use Log directly as akg will handle the corner cases

View File

@ -17,6 +17,7 @@ from mindspore import context
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.ops.primitive import constexpr
from mindspore.ops.operations import _inner_ops as inner
from mindspore._checkparam import Validator as validator
from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\
raise_not_implemented_util
@ -115,7 +116,7 @@ class Distribution(Cell):
self.exp_base = exp_generic
self.fill_base = P.Fill()
self.log_base = log_generic
self.sametypeshape_base = P.SameTypeShape()
self.sametypeshape_base = inner.SameTypeShape()
self.sq_base = P.Square()
self.sqrt_base = P.Sqrt()
self.shape_base = P.Shape()

View File

@ -15,6 +15,7 @@
"""Geometric Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops import composite as C
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
@ -161,7 +162,7 @@ class Geometric(Distribution):
self.dtypeop = P.DType()
self.fill = P.Fill()
self.floor = P.Floor()
self.issubclass = P.IsSubClass()
self.issubclass = inner.IsSubClass()
self.less = P.Less()
self.pow = P.Pow()
self.select = P.Select()

View File

@ -34,6 +34,7 @@ from mindspore.ops._grad.grad_base import dyn_rank, convert_to_tensor, dyn_inver
dyn_fill
from mindspore.ops._grad.grad_base import sum_grad_reduce_axis
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs
from ..operations._inner_ops import DynamicBroadcastGradientArgs, IsSubClass
reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
@ -44,7 +45,7 @@ reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()
is_sub_class = P.IsSubClass()
is_sub_class = IsSubClass()
@bprop_getters.register(P.Fill)

View File

@ -22,7 +22,7 @@ from mindspore.ops import functional as F
from mindspore.communication import get_rank, get_group_size
from mindspore.parallel._utils import _get_enable_parallel_optimizer, _get_grad_accumulation_shard
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive
from mindspore.ops.operations._inner_ops import Send, Receive, issubclass_
from mindspore.common.sparse_tensor import RowTensorInner
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce,
@ -61,7 +61,7 @@ def get_bprop_all_reduce(self):
elif self.op == ReduceOp.SUM:
def bprop(x, out, dout):
if F.issubclass_(F.typeof(dout), mstype.tensor):
if issubclass_(F.typeof(dout), mstype.tensor):
dx = all_reduce_grad(dout)
else:
indices = all_gather(dout.indices)
@ -71,7 +71,7 @@ def get_bprop_all_reduce(self):
else:
def bprop(x, out, dout):
if F.issubclass_(F.typeof(dout), mstype.tensor):
if issubclass_(F.typeof(dout), mstype.tensor):
dx = all_reduce_grad(dout)
z = equal(x, out)
z = cast(z, dtype(dx))
@ -203,14 +203,14 @@ def get_bprop_mirror_micro_step_operator(self):
real_grad = z
assign_out = dout
if mean_flag:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if issubclass_(F.typeof(dout), mstype.tensor):
z = F.depend(z, dout)
real_grad = all_reduce(z)
real_grad = F.tensor_mul(real_grad, scale)
assign(z, real_grad)
assign_out = z
else:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if issubclass_(F.typeof(dout), mstype.tensor):
z = F.depend(z, dout)
real_grad = all_reduce(z)
assign(z, real_grad)
@ -471,7 +471,7 @@ def get_bprop_mirror_operator(self):
if dev_num == 1:
return (dout,)
if mean_flag:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if issubclass_(F.typeof(dout), mstype.tensor):
dx = all_reduce(dout)
float_one = F.scalar_cast(1.0, F.dtype(dx))
num = F.scalar_cast(dev_num, F.dtype(dx))
@ -484,7 +484,7 @@ def get_bprop_mirror_operator(self):
grad = mul(grad, cast(F.scalar_to_tensor(float_one/num), F.dtype(grad)))
dx = RowTensorInner(indices, grad, dout.dense_shape)
else:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if issubclass_(F.typeof(dout), mstype.tensor):
dx = all_reduce(dout)
else:
indices = all_gather(dout.indices)
@ -522,7 +522,7 @@ def get_bprop_mirror_mini_step_operator(self):
def bprop(x, z, out, dout):
if mean_flag:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if issubclass_(F.typeof(dout), mstype.tensor):
if do_mirror:
z = F.depend(z, F.assign_add(z, dout))
real_grad = all_reduce(z)
@ -535,7 +535,7 @@ def get_bprop_mirror_mini_step_operator(self):
else:
dx = zeros_like(x) # The grad accumulation do not support row tensor now
else:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if issubclass_(F.typeof(dout), mstype.tensor):
if do_mirror:
z = F.depend(z, F.assign_add(z, dout))
real_grad = all_reduce(z)
@ -558,14 +558,14 @@ def get_bprop_virtual_div_operator(self):
dtype = P.DType()
def bprop(x, out, dout):
if F.issubclass_(F.typeof(dout), mstype.tensor):
if F.issubclass_(F.dtype(dout), mstype.bool_) or F.issubclass_(F.dtype(dout), mstype.int32) \
or F.issubclass_(F.dtype(dout), mstype.int16):
if issubclass_(F.typeof(dout), mstype.tensor):
if issubclass_(F.dtype(dout), mstype.bool_) or issubclass_(F.dtype(dout), mstype.int32) \
or issubclass_(F.dtype(dout), mstype.int16):
return (dout,)
dx = op(dout, cast(F.scalar_to_tensor(divisor), dtype(dout)))
return (dx,)
if F.issubclass_(F.typeof(dout), mstype.tuple_):
if issubclass_(F.typeof(dout), mstype.tuple_):
dx = ()
input_nums = F.tuple_len(dout)
for i in range(input_nums):

View File

@ -32,7 +32,7 @@ from mindspore.ops._grad.grad_base import sum_grad_reduce_axis, dyn_fill, dyn_ra
from mindspore.ops._grad.grad_base import dyn_ones, dyn_rank_1d
from mindspore.ops.primitive import constexpr
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs, DynamicBroadcastTo
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs, DynamicBroadcastTo, IsSubClass
from mindspore.ops._utils.utils import is_shape_unknown, is_dim_unknown
shape_op = P.Shape()
@ -41,7 +41,7 @@ reduce_prod = P.ReduceProd()
reduce_sum = P.ReduceSum()
reshape = P.Reshape()
tile = P.Tile()
is_sub_class = P.IsSubClass()
is_sub_class = IsSubClass()
to_array = P.TupleToArray()
real_div = P.RealDiv()

View File

@ -48,6 +48,7 @@ from mindspore.ops.operations.array_ops import Col2Im
from mindspore.ops.operations.array_ops import StridedSliceV2
from mindspore.ops.operations._grad_ops import StridedSliceV2Grad
from mindspore.ops.operations.random_ops import LogNormalReverse
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops._utils.utils import is_shape_unknown
@ -105,7 +106,7 @@ def get_bprop_masked_select(self):
"""Generate bprop for MaskedFill"""
mul_op = P.Mul()
sum_op = P.ReduceSum()
is_instance_op = P.IsInstance()
is_instance_op = inner.IsInstance()
def bprop(input_data, mask, value, out, dout):
mask = F.cast(mask, mstype.float32)

View File

@ -22,6 +22,7 @@ from mindspore.nn import LGamma
from mindspore.ops import functional as F
from mindspore.ops.functional import broadcast_gradient_args
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations.math_ops import Trace, Bernoulli, Renorm
from mindspore import nn, ops, Tensor
from mindspore.ops.operations.math_ops import Real, Imag, Complex, Angle
@ -155,7 +156,7 @@ def get_bprop_index_lerp(self):
"""Generate bprop for Lerp"""
mul_op = P.Mul()
sub_op = P.Sub()
is_instance_op = P.IsInstance()
is_instance_op = inner.IsInstance()
def bprop(start, end, weight, out, dout):
dout = F.cast(dout, mstype.float32)

View File

@ -21,7 +21,8 @@ from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.composite import base
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, DynamicBroadcastTo, TopTypeof
from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, DynamicBroadcastTo, \
TopTypeof, issubclass_
from mindspore.common import dtype as mstype
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.common import Tensor, CSRTensor, COOTensor
@ -505,13 +506,13 @@ def get_slice_stride(slice_index, dim_size):
if step is None:
step = const_utils.make_tensor(1)
if F.issubclass_(F.typeof(start), mstype.number):
if issubclass_(F.typeof(start), mstype.number):
start = const_utils.make_tensor(start)
if F.issubclass_(F.typeof(stop), mstype.number):
if issubclass_(F.typeof(stop), mstype.number):
stop = const_utils.make_tensor(stop)
if F.issubclass_(F.typeof(step), mstype.number):
if issubclass_(F.typeof(step), mstype.number):
step = const_utils.make_tensor(step)
return start, stop, step

View File

@ -183,7 +183,6 @@ from .math_func import (
isclose,
hypot,
heaviside,
same_type_shape,
gcd,
log,
log_matrix_determinant,

View File

@ -20,6 +20,7 @@ from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.function.math_func import _check_input_dtype, _check_attr_dtype
from mindspore._c_expression import Tensor as Tensor_
@ -143,9 +144,9 @@ def pinv(x, *, atol=None, rtol=None, hermitian=False):
if rtol is None:
rtol = max(ops.shape(x)) * ops.Eps()(Tensor(1.0, x.dtype))
if not ops.IsInstance()(rtol, mstype.tensor):
if not inner.IsInstance()(rtol, mstype.tensor):
rtol = Tensor(rtol, mstype.float32)
if not ops.IsInstance()(atol, mstype.tensor):
if not inner.IsInstance()(atol, mstype.tensor):
atol = Tensor(atol, mstype.float32)
if not hermitian:

View File

@ -152,7 +152,6 @@ bessel_k1e_ = BesselK1e()
equal_ = P.Equal()
isfinite_ = P.IsFinite()
isnan_ = P.IsNan()
same_type_shape_ = P.SameTypeShape()
maximum_ = P.Maximum()
minimum_ = P.Minimum()
lerp_ = P.Lerp()
@ -3203,36 +3202,6 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
return _nan_to_num(x)
def same_type_shape(input_x, input_y):
"""
Checks whether the data type and shape of two tensors are the same.
Args:
input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
input_y (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_S)`.
Returns:
Tensor, the shape of tensor is :math:`(x_1, x_2, ..., x_R)`,
if data type and shape of `input_x` and `input_y` are the same.
Raises:
TypeError: If the data types of `input_x` and `input_y` are not the same.
ValueError: If the shapes of `input_x` and `input_y` are not the same.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
>>> input_y = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
>>> output = ops.same_type_shape(input_x, input_y)
>>> print(output)
[[2. 2.]
[2. 2.]]
"""
return same_type_shape_(input_x, input_y)
def maximum(x, y):
"""
Computes the maximum of input tensors element-wise.
@ -7143,7 +7112,6 @@ __all__ = [
'linspace',
'matrix_solve',
'std',
'same_type_shape',
'maximum',
'minimum',
'median',

View File

@ -20,6 +20,7 @@ from math import pi
import mindspore.ops as ops
from mindspore.ops.primitive import constexpr
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations import nn_ops as NN_OPS
from mindspore.ops.operations import image_ops as IMG
from mindspore.ops._utils import is_shape_unknown
@ -3020,8 +3021,8 @@ def margin_ranking_loss(input1, input2, target, margin=0.0, reduction='mean'):
_check_is_tensor('input2', input2, "margin_ranking_loss")
_check_is_tensor('target', target, "margin_ranking_loss")
maximum = P.Maximum()
ops.same_type_shape(input1, input2)
ops.same_type_shape(target, input1)
inner.same_type_shape_(input1, input2)
inner.same_type_shape_(target, input1)
x = maximum(0, -target * (input1 - input2) + margin)
return _get_loss(x, reduction, "margin_ranking_loss")

View File

@ -38,11 +38,8 @@ cast = P.Cast()
dtype = P.DType()
isconstant = Primitive('is_constant')
isconstant.set_const_prim(True)
issubclass_ = P.IsSubClass()
isinstance_ = P.IsInstance()
merge = P.Merge()
geswitch = P.GeSwitch()
check_bprop = P.CheckBprop()
reduce_sum = P.ReduceSum()
reduce_max = P.ReduceMax()
reduce_min = P.ReduceMin()

View File

@ -32,10 +32,10 @@ from ._ms_kernel import (ms_kernel, kernel)
from .array_ops import (ArgMaxWithValue, ArgMinWithValue, Argmax, Argmin, BatchToSpace, BatchToSpaceND,
BatchToSpaceNDV2, BroadcastTo, Cast, Coalesce, Concat, Cummax, DType, DepthToSpace, Diag,
DiagPart, DynamicShape, EditDistance, EmbeddingLookup, ExpandDims, ExtractVolumePatches,
Eye, Fill, Gather, GatherD, GatherNd, GatherV2, Identity, Im2Col, InvertPermutation, IsInstance,
IsSubClass, LowerBound, Lstsq, MaskedFill, MaskedSelect, Meshgrid, Mvlgamma, Ones, OnesLike,
Eye, Fill, Gather, GatherD, GatherNd, GatherV2, Identity, Im2Col, InvertPermutation,
LowerBound, Lstsq, MaskedFill, MaskedSelect, Meshgrid, Mvlgamma, Ones, OnesLike,
Pack, Padding, ParallelConcat, PopulationCount, Range, Rank, Reshape, ResizeNearestNeighbor,
ReverseSequence, ReverseV2, Rint, SameTypeShape, ScalarToTensor, ScatterAdd,
ReverseSequence, ReverseV2, Rint, ScalarToTensor, ScatterAdd,
ScatterDiv, ScatterMax, ScatterMin, ScatterMul, ScatterNd, ScatterNdAdd, ScatterNdDiv,
ScatterNdMax, ScatterNdMin, ScatterNdSub, ScatterNdUpdate, ScatterNonAliasingAdd, ScatterSub,
ScatterUpdate, SearchSorted, Select, Shape, Size, Slice, Sort, SpaceToBatch, SpaceToBatchND,
@ -98,7 +98,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
ApplyAdamWithAmsgrad)
from .other_ops import (Assign, IOU, BartlettWindow, BlackmanWindow, BoundingBoxDecode, BoundingBoxEncode,
ConfusionMatrix, UpdateState, Load,
CheckValid, Partial, Depend, identity, CheckBprop, Push, Pull, PyFunc, _DynamicLossScale)
CheckValid, Partial, Depend, identity, Push, Pull, PyFunc, _DynamicLossScale)
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, RandomGamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
LogUniformCandidateSampler, TruncatedNormal)
@ -183,8 +183,6 @@ __all__ = [
'Einsum',
'Renorm',
'Cast',
'IsSubClass',
'IsInstance',
'Reshape',
'Squeeze',
'Transpose',
@ -306,8 +304,6 @@ __all__ = [
'TupleToArray',
'GeSwitch',
'Merge',
'SameTypeShape',
'CheckBprop',
'CheckValid',
'BartlettWindow',
'BlackmanWindow',

View File

@ -2178,3 +2178,228 @@ class MixedPrecisionCast(Primitive):
return data
return self.hyper_map(cast_inner, x)
class CheckBprop(PrimitiveWithInfer):
"""
Checks whether the data type and the shape of corresponding elements from tuples x and y are the same.
Args:
prim_to_check (str): The name of the primitive being checked. Default: ''.
Inputs:
- **input_x** (tuple[Tensor]) - The `input_x` contains the outputs of bprop to be checked.
- **input_y** (tuple[Tensor]) - The `input_y` contains the inputs of bprop to check against.
Outputs:
Tuple[Tensor], the `input_x`,
if data type and shape of corresponding elements from `input_x` and `input_y` are the same.
Raises:
TypeError: If `input_x` or `input_y` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.op = ops.CheckBprop()
... def construct(self, x, y):
... return self.op(x, y)
...
>>> net = Net()
>>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
>>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
>>> output = net(input_x, input_y)
>>> print(output)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 2.00000000e+00, 2.00000000e+00],
[ 2.00000000e+00, 2.00000000e+00]]),)
"""
@prim_attr_register
def __init__(self, prim_to_check=""):
"""Initialize CheckBprop"""
self.prim_to_check = prim_to_check
def infer_shape(self, xshapes, yshapes):
"""infer shape"""
tips = f"user defined method 'bprop'"
validator.check_value_type('grads', xshapes, (tuple,), tips)
validator.check_value_type('params', yshapes, (tuple,), tips)
if not len(xshapes) == len(yshapes):
raise ValueError(f"For {tips} the number of return values(gradients) must be equal to "
f"the number of input arguments except 'out' and 'dout', "
f"which is:{len(yshapes)} but got {len(xshapes)}.")
checking_range = len(yshapes)
for i in range(checking_range):
xshape = xshapes[i]
yshape = yshapes[i]
if not xshape or not yshape:
continue
if xshape != yshape:
raise ValueError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
f"should have the same shape as the {i}th argument, "
f"which is:{yshape}, but got: {xshape}.")
return xshapes
def infer_dtype(self, xdtypes, ydtypes):
"""infer dtype"""
tips = f"user defined method 'bprop'"
validator.check_value_type('grads', xdtypes, (tuple,), tips)
validator.check_value_type('params', ydtypes, (tuple,), tips)
if not len(xdtypes) == len(ydtypes):
raise ValueError(f"For {tips}, the number of return values(gradients) must be equal to "
f"the number of input arguments except 'out' and 'dout', "
f"which is:{len(ydtypes)} but got {len(xdtypes)}.")
checking_range = len(ydtypes)
for i in range(checking_range):
xdtype = xdtypes[i]
ydtype = ydtypes[i]
if isinstance(xdtype, mstype.anything_type) or isinstance(ydtype, mstype.anything_type):
continue
if isinstance(ydtype, mstype.function_type):
if not isinstance(xdtype, mstype.env_type_type):
raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) type "
f"should be {mstype.env_type_type}, but got {xdtype}.")
if xdtype != ydtype:
raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
f"should have the same dtype as the {i}th argument, "
f"which is:{ydtype}, but got: {xdtype}.")
return xdtypes
check_bprop = CheckBprop()
class SameTypeShape(PrimitiveWithInfer):
"""
Checks whether the data type and shape of two tensors are the same.
Refer to :func:`mindspore.ops.same_type_shape` for more detail.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
>>> input_y = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
>>> output = ops.SameTypeShape()(input_x, input_y)
>>> print(output)
[[2. 2.]
[2. 2.]]
"""
@prim_attr_register
def __init__(self):
"""Initialize Same"""
def __call__(self, x, y):
"""run in PyNative mode"""
validator.check_value_type('x', x, Tensor, self.name)
validator.check_value_type('y', y, Tensor, self.name)
validator.check('x dtype', x.dtype, 'y dtype', y.dtype, Rel.EQ, self.name, TypeError)
validator.check('x shape', x.shape, 'y shape', y.shape, Rel.EQ, self.name)
return x
def __infer__(self, x, y):
validator.check_subclass('x', x['dtype'], mstype.tensor, self.name)
validator.check_subclass('y', y['dtype'], mstype.tensor, self.name)
validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], Rel.EQ, self.name, TypeError)
validator.check('x shape', x['shape'], 'y shape', y['shape'], Rel.EQ, self.name)
return x
same_type_shape_ = SameTypeShape()
class IsSubClass(PrimitiveWithInfer):
"""
Checks whether this type is a sub-class of another type.
Inputs:
- **sub_type** (mindspore.dtype) - The type to be checked. Only constant value is allowed.
- **type_** (mindspore.dtype) - The target type. Only constant value is allowed.
Outputs:
bool, the check result.
Raises:
TypeError: If `sub_type` or `type_` is not a Type.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = ops.IsSubClass()(mindspore.int32, mindspore.intc)
>>> print(output)
True
"""
@prim_attr_register
def __init__(self):
pass
def __infer__(self, sub_type, type_):
sub_type_t = sub_type['value']
type_v = type_['value']
validator.check_value_type("sub_type", sub_type_t, [mstype.Type], self.name)
validator.check_value_type("type_", type_v, [mstype.Type], self.name)
value = mstype._issubclass_(sub_type_t, type_v) # pylint: disable=W0212
out = {'shape': (),
'dtype': mstype.type_type,
'value': value}
return out
issubclass_ = IsSubClass()
class IsInstance(PrimitiveWithInfer):
"""
Checks whether an object is an instance of a target type.
Inputs:
- **inst** (Any Object) - The instance to be checked. Only constant value is allowed.
- **type_** (mindspore.dtype) - The target type. Only constant value is allowed.
Outputs:
bool, the check result.
Raises:
TypeError: If `type_` is not a Type.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> inst = 1
>>> output = ops.IsInstance()(inst, mindspore.int32)
>>> print(output)
False
"""
@prim_attr_register
def __init__(self):
pass
def __infer__(self, inst, type_):
sub_type_t = inst['dtype']
type_v = type_['value']
validator.check_value_type("type_", type_v, [mstype.Type], self.name)
if type_v == mstype.list_:
value = isinstance(sub_type_t, list)
elif type_v == mstype.tuple_:
value = isinstance(sub_type_t, tuple)
else:
value = mstype._issubclass_(sub_type_t, type_v) # pylint: disable=W0212
out = {'shape': (),
'dtype': mstype.type_type,
'value': value}
return out

View File

@ -285,44 +285,6 @@ class DType(Primitive):
return x.dtype
class SameTypeShape(PrimitiveWithInfer):
"""
Checks whether the data type and shape of two tensors are the same.
Refer to :func:`mindspore.ops.same_type_shape` for more detail.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
>>> input_y = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
>>> output = ops.SameTypeShape()(input_x, input_y)
>>> print(output)
[[2. 2.]
[2. 2.]]
"""
@prim_attr_register
def __init__(self):
"""Initialize Same"""
def __call__(self, x, y):
"""run in PyNative mode"""
validator.check_value_type('x', x, Tensor, self.name)
validator.check_value_type('y', y, Tensor, self.name)
validator.check('x dtype', x.dtype, 'y dtype', y.dtype, Rel.EQ, self.name, TypeError)
validator.check('x shape', x.shape, 'y shape', y.shape, Rel.EQ, self.name)
return x
def __infer__(self, x, y):
validator.check_subclass('x', x['dtype'], mstype.tensor, self.name)
validator.check_subclass('y', y['dtype'], mstype.tensor, self.name)
validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], Rel.EQ, self.name, TypeError)
validator.check('x shape', x['shape'], 'y shape', y['shape'], Rel.EQ, self.name)
return x
class CheckNumerics(Primitive):
"""
Checks a tensor for NaN and Inf values.
@ -437,95 +399,6 @@ class Cast(PrimitiveWithInfer):
return out
class IsSubClass(PrimitiveWithInfer):
"""
Checks whether this type is a sub-class of another type.
Inputs:
- **sub_type** (mindspore.dtype) - The type to be checked. Only constant value is allowed.
- **type_** (mindspore.dtype) - The target type. Only constant value is allowed.
Outputs:
bool, the check result.
Raises:
TypeError: If `sub_type` or `type_` is not a Type.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> output = ops.IsSubClass()(mindspore.int32, mindspore.intc)
>>> print(output)
True
"""
@prim_attr_register
def __init__(self):
pass
def __infer__(self, sub_type, type_):
sub_type_t = sub_type['value']
type_v = type_['value']
validator.check_value_type("sub_type", sub_type_t, [mstype.Type], self.name)
validator.check_value_type("type_", type_v, [mstype.Type], self.name)
value = mstype._issubclass_(sub_type_t, type_v) # pylint: disable=W0212
out = {'shape': (),
'dtype': mstype.type_type,
'value': value}
return out
class IsInstance(PrimitiveWithInfer):
"""
Checks whether an object is an instance of a target type.
Inputs:
- **inst** (Any Object) - The instance to be checked. Only constant value is allowed.
- **type_** (mindspore.dtype) - The target type. Only constant value is allowed.
Outputs:
bool, the check result.
Raises:
TypeError: If `type_` is not a Type.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> inst = 1
>>> output = ops.IsInstance()(inst, mindspore.int32)
>>> print(output)
False
"""
@prim_attr_register
def __init__(self):
pass
def __infer__(self, inst, type_):
sub_type_t = inst['dtype']
type_v = type_['value']
validator.check_value_type("type_", type_v, [mstype.Type], self.name)
if type_v == mstype.list_:
value = isinstance(sub_type_t, list)
elif type_v == mstype.tuple_:
value = isinstance(sub_type_t, tuple)
else:
value = mstype._issubclass_(sub_type_t, type_v) # pylint: disable=W0212
out = {'shape': (),
'dtype': mstype.type_type,
'value': value}
return out
class Im2Col(Primitive):
r"""
Extracts sliding local blocks from a batched input tensor.

View File

@ -603,98 +603,6 @@ class UpdateState(Primitive):
return state
class CheckBprop(PrimitiveWithInfer):
"""
Checks whether the data type and the shape of corresponding elements from tuples x and y are the same.
Args:
prim_to_check (str): The name of the primitive being checked. Default: ''.
Inputs:
- **input_x** (tuple[Tensor]) - The `input_x` contains the outputs of bprop to be checked.
- **input_y** (tuple[Tensor]) - The `input_y` contains the inputs of bprop to check against.
Outputs:
Tuple[Tensor], the `input_x`,
if data type and shape of corresponding elements from `input_x` and `input_y` are the same.
Raises:
TypeError: If `input_x` or `input_y` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.op = ops.CheckBprop()
... def construct(self, x, y):
... return self.op(x, y)
...
>>> net = Net()
>>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
>>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
>>> output = net(input_x, input_y)
>>> print(output)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 2.00000000e+00, 2.00000000e+00],
[ 2.00000000e+00, 2.00000000e+00]]),)
"""
@prim_attr_register
def __init__(self, prim_to_check=""):
"""Initialize CheckBprop"""
self.prim_to_check = prim_to_check
def infer_shape(self, xshapes, yshapes):
"""infer shape"""
tips = f"user defined method 'bprop'"
validator.check_value_type('grads', xshapes, (tuple,), tips)
validator.check_value_type('params', yshapes, (tuple,), tips)
if not len(xshapes) == len(yshapes):
raise ValueError(f"For {tips} the number of return values(gradients) must be equal to "
f"the number of input arguments except 'out' and 'dout', "
f"which is:{len(yshapes)} but got {len(xshapes)}.")
checking_range = len(yshapes)
for i in range(checking_range):
xshape = xshapes[i]
yshape = yshapes[i]
if not xshape or not yshape:
continue
if xshape != yshape:
raise ValueError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
f"should have the same shape as the {i}th argument, "
f"which is:{yshape}, but got: {xshape}.")
return xshapes
def infer_dtype(self, xdtypes, ydtypes):
"""infer dtype"""
tips = f"user defined method 'bprop'"
validator.check_value_type('grads', xdtypes, (tuple,), tips)
validator.check_value_type('params', ydtypes, (tuple,), tips)
if not len(xdtypes) == len(ydtypes):
raise ValueError(f"For {tips}, the number of return values(gradients) must be equal to "
f"the number of input arguments except 'out' and 'dout', "
f"which is:{len(ydtypes)} but got {len(xdtypes)}.")
checking_range = len(ydtypes)
for i in range(checking_range):
xdtype = xdtypes[i]
ydtype = ydtypes[i]
if isinstance(xdtype, mstype.anything_type) or isinstance(ydtype, mstype.anything_type):
continue
if isinstance(ydtype, mstype.function_type):
if not isinstance(xdtype, mstype.env_type_type):
raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) type "
f"should be {mstype.env_type_type}, but got {xdtype}.")
continue
if xdtype != ydtype:
raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
f"should have the same dtype as the {i}th argument, "
f"which is:{ydtype}, but got: {xdtype}.")
return xdtypes
class ConfusionMatrix(PrimitiveWithInfer):
r"""
Calculates the confusion matrix from labels and predictions.

View File

@ -19,6 +19,7 @@ import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
@ -36,15 +37,6 @@ class ExpandDimsNet(nn.Cell):
return self.op(x, self.axis)
class IsInstanceNet(nn.Cell):
def __init__(self, inst):
super(IsInstanceNet, self).__init__()
self.inst = inst
self.op = P.IsInstance()
def construct(self, t):
return self.op(self.inst, t)
class ReshapeNet(nn.Cell):
def __init__(self, shape):
@ -86,42 +78,36 @@ raise_set = [
# input x scala, not Tensor
('SameTypeShape0', {
'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
'block': (inner.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input y scala, not Tensor
('SameTypeShape1', {
'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
'block': (inner.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), 5.0],
'skip': ['backward']}),
# type of x and y not match
('SameTypeShape2', {
'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
'block': (inner.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.int32))],
'skip': ['backward']}),
# shape of x and y not match
('SameTypeShape3', {
'block': (P.SameTypeShape(), {'exception': ValueError, 'error_keywords': ['SameTypeShape']}),
'block': (inner.SameTypeShape(), {'exception': ValueError, 'error_keywords': ['SameTypeShape']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 3]).astype(np.float32))],
'skip': ['backward']}),
# sub_type is None
('IsSubClass0', {
'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
'block': (inner.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
'desc_inputs': [None, mstype.number],
'skip': ['backward']}),
# type_ is None
('IsSubClass1', {
'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
'block': (inner.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
'desc_inputs': [mstype.number, None],
'skip': ['backward']}),
# t is not mstype.Type
('IsInstance1', {
'block': (IsInstanceNet(5.0), {'exception': TypeError, 'error_keywords': ['IsInstance']}),
'desc_inputs': [None],
'skip': ['backward']}),
# input x is scalar, not Tensor
('Reshape0', {
'block': (P.Reshape(), {'exception': TypeError, 'error_keywords': ['Reshape']}),

View File

@ -14,11 +14,11 @@
# ============================================================================
""" test_multitype """
import mindspore as ms
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
def test_isinstance():
assert P.IsInstance()([1, 2, 3], ms.list_) is True
assert P.IsInstance()((1, 2, 3), ms.tuple_) is True
assert P.IsInstance()(1.0, ms.float_) is True
assert P.IsInstance()(1, ms.int_) is True
assert inner.IsInstance()([1, 2, 3], ms.list_) is True
assert inner.IsInstance()((1, 2, 3), ms.tuple_) is True
assert inner.IsInstance()(1.0, ms.float_) is True
assert inner.IsInstance()(1, ms.int_) is True