forked from mindspore-Ecosystem/mindspore
keep consistent in Graph mode and PyNative mode for 'isinstance'
This commit is contained in:
parent
f31dfa129a
commit
6d395c1d3f
|
@ -17,14 +17,14 @@
|
|||
"""Resources for ast tree parse."""
|
||||
import ast
|
||||
import math
|
||||
|
||||
from mindspore import RowTensor, SparseTensor
|
||||
from mindspore.ops.composite import multitype_ops
|
||||
from mindspore.ops import functional as F, composite as C
|
||||
from mindspore.ops.composite import multitype_ops
|
||||
from . import standard_method as M
|
||||
from . import trope as T
|
||||
from .namespace import CellNamespace
|
||||
|
||||
|
||||
# namespace define
|
||||
functional_ns = CellNamespace('mindspore.ops.functional')
|
||||
composite_ns = CellNamespace('mindspore.ops.composite')
|
||||
|
@ -109,7 +109,7 @@ convert_object_map = {
|
|||
|
||||
# system function
|
||||
T.len: M.ms_len,
|
||||
T.bool: M.bool_,
|
||||
T.bool_: M.bool_,
|
||||
T.map: C.Map(),
|
||||
T.partial: F.partial,
|
||||
T.zip: C.zip_operation,
|
||||
|
|
|
@ -16,13 +16,15 @@
|
|||
# ============================================================================
|
||||
"""standard_method"""
|
||||
from dataclasses import dataclass
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype as mstype
|
||||
from ...ops import functional as F
|
||||
from ...ops import operations as P
|
||||
from ...ops.primitive import constexpr
|
||||
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
|
||||
zeros_like, ones_like
|
||||
from ...ops.composite.base import _append
|
||||
from ...ops.primitive import constexpr
|
||||
|
||||
__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
|
||||
|
||||
|
@ -219,9 +221,23 @@ def while_cond(x):
|
|||
@constexpr
|
||||
def check_type_same(x_type, base_type):
|
||||
"""Check x_type is same as base_type."""
|
||||
if mstype.issubclass_(x_type, base_type):
|
||||
return True
|
||||
return False
|
||||
pytype_to_mstype = {
|
||||
bool: mstype.Bool,
|
||||
int: mstype.Int,
|
||||
float: mstype.Float,
|
||||
str: mstype.String,
|
||||
list: mstype.List,
|
||||
tuple: mstype.Tuple,
|
||||
Tensor: mstype.tensor_type
|
||||
}
|
||||
try:
|
||||
if isinstance(base_type, (tuple, list)):
|
||||
target_type = tuple(pytype_to_mstype[i] for i in base_type)
|
||||
else:
|
||||
target_type = pytype_to_mstype[base_type]
|
||||
return isinstance(x_type, target_type)
|
||||
except KeyError:
|
||||
raise TypeError(f"The type '{base_type}' is not supported for 'isinstance'")
|
||||
|
||||
|
||||
@constexpr
|
||||
|
@ -235,7 +251,7 @@ def check_is_tensor(x):
|
|||
@constexpr
|
||||
def check_is_tuple_or_list_or_tensor(x, op_name, arg_name):
|
||||
"""check whether x is list or tuple or tensor."""
|
||||
if isinstance(x, (mstype.list_type, mstype.tuple_type, mstype.tensor_type)):
|
||||
if isinstance(x, (mstype.List, mstype.Tuple, mstype.tensor_type)):
|
||||
return True
|
||||
raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.")
|
||||
|
||||
|
|
|
@ -95,3 +95,7 @@ def not_contains(x): # pragma: no cover
|
|||
def while_cond(x): # pragma: no cover
|
||||
"""Not in function."""
|
||||
raise RuntimeError('This operation is not meant to be called directly.')
|
||||
|
||||
def bool_(x): # pragma: no cover
|
||||
"""judge true function."""
|
||||
raise RuntimeError('This operation is not meant to be called directly.')
|
||||
|
|
|
@ -116,7 +116,7 @@ const char NAMED_PRIMITIVE_NEXT[] = "next";
|
|||
const char NAMED_PRIMITIVE_GETITEM[] = "getitem";
|
||||
const char NAMED_PRIMITIVE_SETITEM[] = "setitem";
|
||||
const char NAMED_PRIMITIVE_HASNEXT[] = "hasnext";
|
||||
const char NAMED_PRIMITIVE_BOOL[] = "bool"; // bool: P.identity
|
||||
const char NAMED_PRIMITIVE_BOOL[] = "bool_"; // bool: P.identity
|
||||
const char NAMED_PRIMITIVE_MAKETUPLE[] = "make_tuple";
|
||||
const char NAMED_PRIMITIVE_MAKELIST[] = "make_list";
|
||||
const char NAMED_PRIMITIVE_MAKESLICE[] = "make_slice";
|
||||
|
|
|
@ -109,6 +109,7 @@ class ClassType : public PyObjectWrapper {
|
|||
MS_DECLARE_PARENT(ClassType, PyObjectWrapper);
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
};
|
||||
using ClassTypePtr = std::shared_ptr<ClassType>;
|
||||
|
||||
// SymbolResolver class for resolving symbol extracted from AnfNode.
|
||||
class SymbolResolver {
|
||||
|
|
|
@ -280,24 +280,20 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
|
|||
py::tuple max_shape_tuple(len);
|
||||
auto dic = py::dict();
|
||||
bool dyn_shape = false;
|
||||
bool is_build_value = true;
|
||||
bool dyn_value = false;
|
||||
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
auto arg = arg_tuple->elements()[i];
|
||||
py::dict out = ConvertAbstractToPython(arg);
|
||||
shape_tuple[i] = out[ATTR_SHAPE];
|
||||
dtype_tuple[i] = out[ATTR_DTYPE];
|
||||
value_tuple[i] = out[ATTR_VALUE];
|
||||
|
||||
// Elements in tuple is tensor shape value.
|
||||
if (out.contains(py::str(ATTR_MIN_VALUE)) && out.contains(py::str(ATTR_MAX_VALUE))) {
|
||||
value_tuple[i] = out[ATTR_VALUE];
|
||||
min_value_tuple[i] = out[ATTR_MIN_VALUE];
|
||||
max_value_tuple[i] = out[ATTR_MAX_VALUE];
|
||||
is_build_value = false;
|
||||
} else {
|
||||
value_tuple[i] = BuildValue(arg->BuildValue());
|
||||
min_value_tuple[i] = value_tuple[i];
|
||||
max_value_tuple[i] = value_tuple[i];
|
||||
dyn_value = true;
|
||||
}
|
||||
|
||||
// Elements in tuple is tensor, which shape is dynamic.
|
||||
|
@ -305,21 +301,21 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
|
|||
min_shape_tuple[i] = out[ATTR_MIN_SHAPE];
|
||||
max_shape_tuple[i] = out[ATTR_MAX_SHAPE];
|
||||
dyn_shape = true;
|
||||
} else {
|
||||
min_shape_tuple[i] = out[ATTR_SHAPE];
|
||||
max_shape_tuple[i] = out[ATTR_SHAPE];
|
||||
}
|
||||
}
|
||||
|
||||
dic[ATTR_SHAPE] = shape_tuple;
|
||||
dic[ATTR_DTYPE] = dtype_tuple;
|
||||
if (is_build_value) {
|
||||
dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue());
|
||||
if (arg_tuple->BuildValue()->isa<AnyValue>()) {
|
||||
dic[ATTR_VALUE] = py::none();
|
||||
} else {
|
||||
dic[ATTR_VALUE] = value_tuple;
|
||||
}
|
||||
|
||||
if (dyn_value) {
|
||||
dic[ATTR_MIN_VALUE] = min_value_tuple;
|
||||
dic[ATTR_MAX_VALUE] = max_value_tuple;
|
||||
}
|
||||
|
||||
if (dyn_shape) {
|
||||
dic[ATTR_MIN_SHAPE] = min_shape_tuple;
|
||||
dic[ATTR_MAX_SHAPE] = max_shape_tuple;
|
||||
|
@ -333,6 +329,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
|
|||
size_t len = arg_list->size();
|
||||
py::list shape_list(len);
|
||||
py::list dtype_list(len);
|
||||
py::list value_list(len);
|
||||
py::list min_shape_list(len);
|
||||
py::list max_shape_list(len);
|
||||
auto dic = py::dict();
|
||||
|
@ -342,27 +339,29 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
|
|||
py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
|
||||
shape_list[i] = out[ATTR_SHAPE];
|
||||
dtype_list[i] = out[ATTR_DTYPE];
|
||||
value_list[i] = out[ATTR_VALUE];
|
||||
|
||||
// Elements in list is tensor, which shape is dynamic.
|
||||
if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
|
||||
min_shape_list[i] = out[ATTR_MIN_SHAPE];
|
||||
max_shape_list[i] = out[ATTR_MAX_SHAPE];
|
||||
dyn_shape = true;
|
||||
} else {
|
||||
min_shape_list[i] = out[ATTR_SHAPE];
|
||||
max_shape_list[i] = out[ATTR_SHAPE];
|
||||
}
|
||||
}
|
||||
|
||||
dic[ATTR_SHAPE] = shape_list;
|
||||
dic[ATTR_DTYPE] = dtype_list;
|
||||
if (arg_list->BuildValue()->isa<AnyValue>()) {
|
||||
dic[ATTR_VALUE] = py::none();
|
||||
} else {
|
||||
dic[ATTR_VALUE] = value_list;
|
||||
}
|
||||
|
||||
if (dyn_shape) {
|
||||
dic[ATTR_MIN_SHAPE] = min_shape_list;
|
||||
dic[ATTR_MAX_SHAPE] = max_shape_list;
|
||||
}
|
||||
|
||||
dic[ATTR_SHAPE] = shape_list;
|
||||
dic[ATTR_DTYPE] = dtype_list;
|
||||
dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue());
|
||||
|
||||
return dic;
|
||||
}
|
||||
} // end anonymous namespace
|
||||
|
@ -428,6 +427,16 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
|
|||
dic[ATTR_SHAPE] = py::none();
|
||||
dic[ATTR_DTYPE] = abs_base->BuildType();
|
||||
dic[ATTR_VALUE] = py::none();
|
||||
if (abs_base->isa<PartialAbstractClosure>()) {
|
||||
AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args();
|
||||
if (!args.empty()) {
|
||||
auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>();
|
||||
if (value != nullptr) {
|
||||
dic[ATTR_DTYPE] = std::make_shared<TypeType>();
|
||||
dic[ATTR_VALUE] = value->obj();
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (abs_base->isa<AbstractUndetermined>()) {
|
||||
auto arg = dyn_cast<AbstractUndetermined>(abs_base);
|
||||
dic[ATTR_SHAPE] = py::none();
|
||||
|
|
|
@ -390,8 +390,8 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
|
|||
}
|
||||
return PyList2DynamicShapeTensor(shape_obj, type_obj, output);
|
||||
} else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
|
||||
py::tuple shape_tuple = shape_obj.cast<py::tuple>();
|
||||
py::tuple typeid_tuple = type_obj.cast<py::tuple>();
|
||||
auto shape_tuple = shape_obj.cast<py::tuple>();
|
||||
auto typeid_tuple = type_obj.cast<py::tuple>();
|
||||
AbstractBasePtrList ptr_list;
|
||||
for (size_t it = 0; it < shape_tuple.size(); ++it) {
|
||||
auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]);
|
||||
|
@ -400,8 +400,8 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
|
|||
auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
|
||||
return tuple;
|
||||
} else if (py::isinstance<py::list>(shape_obj) && py::isinstance<py::list>(type_obj)) {
|
||||
py::list shape_list = shape_obj.cast<py::list>();
|
||||
py::list typeid_list = type_obj.cast<py::list>();
|
||||
auto shape_list = shape_obj.cast<py::list>();
|
||||
auto typeid_list = type_obj.cast<py::list>();
|
||||
AbstractBasePtrList ptr_list;
|
||||
for (size_t it = 0; it < shape_list.size(); ++it) {
|
||||
auto tensor_it = PyListDtype2AbstractTensor(shape_list[it], typeid_list[it]);
|
||||
|
|
|
@ -78,36 +78,39 @@ single = float32
|
|||
float64 = typing.Float(64)
|
||||
double = float64
|
||||
|
||||
number = typing.Number()
|
||||
int_ = typing.Int()
|
||||
uint = typing.UInt()
|
||||
float_ = typing.Float()
|
||||
number = typing.Number()
|
||||
|
||||
string = typing.String()
|
||||
list_ = typing.List()
|
||||
tuple_ = typing.Tuple()
|
||||
tensor = typing.TensorType()
|
||||
function = typing.Function()
|
||||
function_type = typing.Function
|
||||
symbolic_key = typing.SymbolicKeyType()
|
||||
env_type = typing.EnvType()
|
||||
env_type_type = typing.EnvType
|
||||
type_type = typing.TypeType()
|
||||
type_none = typing.TypeNone()
|
||||
type_bool = typing.Bool()
|
||||
string = typing.String()
|
||||
type_refkey = typing.RefKeyType()
|
||||
tensor_type = typing.TensorType
|
||||
anything_type = typing.TypeAnything
|
||||
slice_type = typing.Slice
|
||||
ellipsis_type = typing.TypeEllipsis
|
||||
list_type = typing.List
|
||||
tuple_type = typing.Tuple
|
||||
|
||||
tensor = typing.TensorType()
|
||||
index_slices = typing.RowTensorType()
|
||||
sparse_tensor = typing.SparseTensorType()
|
||||
undetermined = typing.UndeterminedType()
|
||||
|
||||
function = typing.Function()
|
||||
symbolic_key = typing.SymbolicKeyType()
|
||||
env_type = typing.EnvType()
|
||||
type_type = typing.TypeType()
|
||||
type_refkey = typing.RefKeyType()
|
||||
|
||||
Int = typing.Int
|
||||
bool_type = typing.Bool
|
||||
Float = typing.Float
|
||||
Bool = typing.Bool
|
||||
String = typing.String
|
||||
List = typing.List
|
||||
Tuple = typing.Tuple
|
||||
Slice = typing.Slice
|
||||
function_type = typing.Function
|
||||
Ellipsis_ = typing.TypeEllipsis
|
||||
none_type = typing.TypeNone
|
||||
env_type_type = typing.EnvType
|
||||
tensor_type = typing.TensorType
|
||||
anything_type = typing.TypeAnything
|
||||
|
||||
number_type = (int8,
|
||||
int16,
|
||||
|
|
|
@ -86,6 +86,8 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom {
|
|||
|
||||
std::string ToString() const override { return "Prim: " + prim_->name(); }
|
||||
|
||||
ValuePtr RealBuildValue() const override { return prim_; }
|
||||
|
||||
private:
|
||||
PrimitivePtr prim_;
|
||||
// store it as weak_ptr to break reference cycle.
|
||||
|
@ -183,6 +185,7 @@ class PartialAbstractClosure : public AbstractFuncAtom {
|
|||
|
||||
AbstractFunctionPtr fn() { return fn_; }
|
||||
AbstractBasePtrList args() { return args_spec_list_; }
|
||||
ValuePtr RealBuildValue() const override { return fn_->BuildValue(); }
|
||||
AnfNodePtr node() { return node_.lock(); }
|
||||
void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); }
|
||||
AbstractFunctionPtr Copy() const override {
|
||||
|
@ -199,6 +202,7 @@ class PartialAbstractClosure : public AbstractFuncAtom {
|
|||
// The CNode which this PartialAbstractClosure evaluated from.
|
||||
AnfNodeWeakPtr node_;
|
||||
};
|
||||
using PartialAbstractClosurePtr = std::shared_ptr<PartialAbstractClosure>;
|
||||
|
||||
class JTransformedAbstractClosure : public AbstractFuncAtom {
|
||||
public:
|
||||
|
|
|
@ -339,13 +339,13 @@ def _cpu_not_support(name):
|
|||
@constexpr
|
||||
def _check_is_tuple(obj):
|
||||
"""Check whether obj is a tuple"""
|
||||
return isinstance(obj, mstype.tuple_type)
|
||||
return isinstance(obj, mstype.Tuple)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_is_list(obj):
|
||||
"""Check whether obj is a list"""
|
||||
return isinstance(obj, mstype.list_type)
|
||||
return isinstance(obj, mstype.List)
|
||||
|
||||
|
||||
@constexpr
|
||||
|
|
|
@ -148,7 +148,7 @@ def _expand_data_dims_with_bool(data, tuple_index, op_name):
|
|||
bool_positions, tuple_index_without_bool = (), ()
|
||||
|
||||
for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)):
|
||||
bool_type_tag = const_utils.judge_index_type(index_type, mstype.type_bool)
|
||||
bool_type_tag = const_utils.judge_index_type(index_type, mstype.bool_)
|
||||
if bool_type_tag:
|
||||
if index:
|
||||
tuple_index_without_bool += (const_utils.make_tensor([0], mstype.int64),)
|
||||
|
@ -653,6 +653,6 @@ def tensor_in_sequence(x, y):
|
|||
"""Assigns whether a sequence contains the given tensor"""
|
||||
result = const_utils.scalar_to_tensor(False)
|
||||
for i in y:
|
||||
if isinstance(i, mstype.tensor) and x.shape == i.shape and x.dtype == i.dtype:
|
||||
if isinstance(i, Tensor) and x.shape == i.shape and x.dtype == i.dtype:
|
||||
result = F.logical_or(F.equal(x, i).all(), result)
|
||||
return result
|
||||
|
|
|
@ -171,15 +171,15 @@ def get_pos_of_indexes_types(indexes_types, op_name):
|
|||
slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
|
||||
sequence_positions = [], [], [], [], [], [], []
|
||||
for i, index_type in enumerate(indexes_types):
|
||||
if isinstance(index_type, mstype.slice_type):
|
||||
if isinstance(index_type, mstype.Slice):
|
||||
slice_positions.append(i)
|
||||
elif isinstance(index_type, mstype.ellipsis_type):
|
||||
elif isinstance(index_type, mstype.Ellipsis_):
|
||||
ellipsis_positions.append(i)
|
||||
elif isinstance(index_type, mstype.none_type):
|
||||
none_positions.append(i)
|
||||
elif isinstance(index_type, mstype.Int):
|
||||
int_positions.append(i)
|
||||
elif isinstance(index_type, mstype.bool_type):
|
||||
elif isinstance(index_type, mstype.Bool):
|
||||
bool_positions.append(i)
|
||||
elif isinstance(index_type, mstype.tensor_type):
|
||||
tensor_positions.append(i)
|
||||
|
@ -341,7 +341,7 @@ def tuple_index_int_cnt(types, op_name):
|
|||
def tuple_index_type_cnt(types, op_name):
|
||||
"""count the tensor type of types which contains the tuple elements' type."""
|
||||
tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types)
|
||||
basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.ellipsis_type, mstype.slice_type)) for ele in types)
|
||||
basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types)
|
||||
if tensor_cnt == len(types):
|
||||
return ALL_TENSOR
|
||||
if basic_cnt == len(types):
|
||||
|
@ -614,7 +614,7 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, t
|
|||
indexes_info[pos] = tensor_indexes_shapes[tensor_count]
|
||||
index_tensors_info[pos] = tensor_indexes_shapes[tensor_count]
|
||||
tensor_count += 1
|
||||
elif isinstance(index_type, mstype.slice_type):
|
||||
elif isinstance(index_type, mstype.Slice):
|
||||
slice_obj = slice(slice_indexes[slice_count].start,
|
||||
slice_indexes[slice_count].stop,
|
||||
slice_indexes[slice_count].step)
|
||||
|
@ -680,7 +680,7 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te
|
|||
return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)
|
||||
|
||||
|
||||
@ constexpr
|
||||
@constexpr
|
||||
def scalar_in_sequence(x, y):
|
||||
"""Determine whether the scalar in the sequence."""
|
||||
if x is None:
|
||||
|
|
|
@ -63,7 +63,7 @@ TEST_F(TestData, test_build_value) {
|
|||
// BuildValue(AbstractFunction) should return kAnyValue.
|
||||
AbstractBasePtr abs_f1 = FromValue(prim::kPrimReturn, false);
|
||||
ValuePtr abs_f1_built = abs_f1->BuildValue();
|
||||
ASSERT_EQ(abs_f1_built, kAnyValue);
|
||||
ASSERT_EQ(abs_f1_built, prim::kPrimReturn);
|
||||
|
||||
FuncGraphPtr fg1 = std::make_shared<FuncGraph>();
|
||||
AbstractBasePtr abs_fg1 = FromValue(fg1, false);
|
||||
|
@ -74,17 +74,20 @@ TEST_F(TestData, test_build_value) {
|
|||
AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false);
|
||||
AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2}));
|
||||
ValuePtr func_tuple_built = abs_func_tuple->BuildValue();
|
||||
ASSERT_EQ(func_tuple_built, kAnyValue);
|
||||
ASSERT_EQ(*func_tuple_built,
|
||||
ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
|
||||
|
||||
// BuildValue(List(AbstractFunction)) should return kAnyValue;
|
||||
AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2}));
|
||||
ValuePtr func_list_built = abs_func_list->BuildValue();
|
||||
ASSERT_EQ(func_list_built, kAnyValue);
|
||||
ASSERT_EQ(*func_list_built,
|
||||
ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
|
||||
|
||||
// BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue
|
||||
abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2}));
|
||||
func_tuple_built = abs_func_tuple->BuildValue();
|
||||
ASSERT_EQ(func_tuple_built, kAnyValue);
|
||||
ASSERT_EQ(*func_tuple_built,
|
||||
ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd}));
|
||||
}
|
||||
|
||||
TEST_F(TestData, test_build_type) {
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test instance"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
|
||||
|
||||
def test_isinstance():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.int_member = 1
|
||||
self.float_member = 1.0
|
||||
self.bool_member = True
|
||||
self.string_member = "abcd"
|
||||
self.tensor_member = Tensor(np.arange(4))
|
||||
self.tuple_member = (1, 1.0, True, "abcd", self.tensor_member)
|
||||
self.list_member = list(self.tuple_member)
|
||||
|
||||
def construct(self, x, y):
|
||||
is_int = isinstance(self.int_member, int)
|
||||
is_float = isinstance(self.float_member, float)
|
||||
is_bool = isinstance(self.bool_member, bool)
|
||||
is_string = isinstance(self.string_member, str)
|
||||
is_tensor_const = isinstance(self.tensor_member, Tensor)
|
||||
is_tensor_var = isinstance(x, Tensor)
|
||||
is_tuple_const = isinstance(self.tuple_member, tuple)
|
||||
is_tuple_var = isinstance((x, 1, 1.0, y), tuple)
|
||||
is_list_const = isinstance(self.list_member, list)
|
||||
is_list_var = isinstance([x, 1, 1.0, y], list)
|
||||
is_list_or_tensor = isinstance([x, y], (Tensor, list))
|
||||
is_int_or_float_or_tensor_or_tuple = isinstance(x, (Tensor, tuple, int, float))
|
||||
float_is_int = isinstance(self.float_member, int)
|
||||
bool_is_string = isinstance(self.bool_member, str)
|
||||
tensor_is_tuple = isinstance(x, tuple)
|
||||
tuple_is_list = isinstance(self.tuple_member, list)
|
||||
return is_int, is_float, is_bool, is_string, is_tensor_const, is_tensor_var, \
|
||||
is_tuple_const, is_tuple_var, is_list_const, is_list_var, \
|
||||
is_int_or_float_or_tensor_or_tuple, is_list_or_tensor, \
|
||||
float_is_int, bool_is_string, tensor_is_tuple, tuple_is_list
|
||||
|
||||
net = Net()
|
||||
x = Tensor(np.arange(4))
|
||||
y = Tensor(np.arange(5))
|
||||
assert net(x, y) == (True,) * 12 + (False,) * 4
|
||||
|
||||
|
||||
def test_isinstance_not_supported():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.value = (11, 22, 33, 44)
|
||||
|
||||
def construct(self):
|
||||
return isinstance(self.value, None)
|
||||
|
||||
net = Net()
|
||||
with pytest.raises(TypeError) as err:
|
||||
net()
|
||||
assert "The type 'None' is not supported for 'isinstance'" in str(err.value)
|
Loading…
Reference in New Issue