Keep 'isinstance' consistent between Graph mode and PyNative mode

buxue 2021-01-04 18:56:48 +08:00
parent f31dfa129a
commit 6d395c1d3f
14 changed files with 185 additions and 67 deletions
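In PyNative mode, `isinstance` executes as ordinary Python; in Graph mode it is compiled, so the two modes could disagree on which second arguments are accepted and what the check returns. The changes below map the Python classes bool, int, float, str, list, tuple and Tensor onto MindSpore's internal type system at compile time. A minimal sketch of the behavior being unified (CheckNet is a hypothetical cell; the expected outputs follow from the test added at the end of this commit):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

class CheckNet(nn.Cell):
    def construct(self, x):
        # Both checks should compile in Graph mode and match eager results.
        return isinstance(x, Tensor), isinstance(x, (int, float, tuple))

x = Tensor(np.arange(4))
context.set_context(mode=context.PYNATIVE_MODE)
print(CheckNet()(x))  # (True, False) under plain Python semantics
context.set_context(mode=context.GRAPH_MODE)
print(CheckNet()(x))  # after this commit, the compiled result should agree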

View File

@@ -17,14 +17,14 @@
"""Resources for ast tree parse."""
import ast
import math
from mindspore import RowTensor, SparseTensor
from mindspore.ops.composite import multitype_ops
from mindspore.ops import functional as F, composite as C
from mindspore.ops.composite import multitype_ops
from . import standard_method as M
from . import trope as T
from .namespace import CellNamespace
# namespace define
functional_ns = CellNamespace('mindspore.ops.functional')
composite_ns = CellNamespace('mindspore.ops.composite')
@@ -109,7 +109,7 @@ convert_object_map = {
    # system function
    T.len: M.ms_len,
    T.bool: M.bool_,
    T.bool_: M.bool_,
    T.map: C.Map(),
    T.partial: F.partial,
    T.zip: C.zip_operation,

View File

@@ -16,13 +16,15 @@
# ============================================================================
"""standard_method"""
from dataclasses import dataclass
from mindspore.common import dtype as mstype
from mindspore import Tensor
from mindspore import dtype as mstype
from ...ops import functional as F
from ...ops import operations as P
from ...ops.primitive import constexpr
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
zeros_like, ones_like
from ...ops.composite.base import _append
from ...ops.primitive import constexpr
__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
@@ -219,9 +221,23 @@ def while_cond(x):
@constexpr
def check_type_same(x_type, base_type):
    """Check whether x_type is the same as base_type."""
    if mstype.issubclass_(x_type, base_type):
        return True
    return False
    pytype_to_mstype = {
        bool: mstype.Bool,
        int: mstype.Int,
        float: mstype.Float,
        str: mstype.String,
        list: mstype.List,
        tuple: mstype.Tuple,
        Tensor: mstype.tensor_type
    }
    try:
        if isinstance(base_type, (tuple, list)):
            target_type = tuple(pytype_to_mstype[i] for i in base_type)
        else:
            target_type = pytype_to_mstype[base_type]
        return isinstance(x_type, target_type)
    except KeyError:
        raise TypeError(f"The type '{base_type}' is not supported for 'isinstance'")
@constexpr
@@ -235,7 +251,7 @@ def check_is_tensor(x):
@constexpr
def check_is_tuple_or_list_or_tensor(x, op_name, arg_name):
    """Check whether x is a list, tuple or tensor."""
    if isinstance(x, (mstype.list_type, mstype.tuple_type, mstype.tensor_type)):
    if isinstance(x, (mstype.List, mstype.Tuple, mstype.tensor_type)):
        return True
    raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.")

View File

@@ -95,3 +95,7 @@ def not_contains(x): # pragma: no cover
def while_cond(x):  # pragma: no cover
    """While-condition function."""
    raise RuntimeError('This operation is not meant to be called directly.')


def bool_(x):  # pragma: no cover
    """Bool conversion function."""
    raise RuntimeError('This operation is not meant to be called directly.')

View File

@@ -116,7 +116,7 @@ const char NAMED_PRIMITIVE_NEXT[] = "next";
const char NAMED_PRIMITIVE_GETITEM[] = "getitem";
const char NAMED_PRIMITIVE_SETITEM[] = "setitem";
const char NAMED_PRIMITIVE_HASNEXT[] = "hasnext";
const char NAMED_PRIMITIVE_BOOL[] = "bool"; // bool: P.identity
const char NAMED_PRIMITIVE_BOOL[] = "bool_"; // bool: P.identity
const char NAMED_PRIMITIVE_MAKETUPLE[] = "make_tuple";
const char NAMED_PRIMITIVE_MAKELIST[] = "make_list";
const char NAMED_PRIMITIVE_MAKESLICE[] = "make_slice";

View File

@@ -109,6 +109,7 @@ class ClassType : public PyObjectWrapper {
MS_DECLARE_PARENT(ClassType, PyObjectWrapper);
abstract::AbstractBasePtr ToAbstract() override;
};
using ClassTypePtr = std::shared_ptr<ClassType>;
// SymbolResolver class for resolving symbol extracted from AnfNode.
class SymbolResolver {

View File

@@ -280,24 +280,20 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
py::tuple max_shape_tuple(len);
auto dic = py::dict();
bool dyn_shape = false;
bool is_build_value = true;
bool dyn_value = false;
for (size_t i = 0; i < len; i++) {
auto arg = arg_tuple->elements()[i];
py::dict out = ConvertAbstractToPython(arg);
shape_tuple[i] = out[ATTR_SHAPE];
dtype_tuple[i] = out[ATTR_DTYPE];
value_tuple[i] = out[ATTR_VALUE];
// Elements in the tuple are tensor shape values.
if (out.contains(py::str(ATTR_MIN_VALUE)) && out.contains(py::str(ATTR_MAX_VALUE))) {
value_tuple[i] = out[ATTR_VALUE];
min_value_tuple[i] = out[ATTR_MIN_VALUE];
max_value_tuple[i] = out[ATTR_MAX_VALUE];
is_build_value = false;
} else {
value_tuple[i] = BuildValue(arg->BuildValue());
min_value_tuple[i] = value_tuple[i];
max_value_tuple[i] = value_tuple[i];
dyn_value = true;
}
// Elements in the tuple are tensors whose shape is dynamic.
@@ -305,21 +301,21 @@ py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
min_shape_tuple[i] = out[ATTR_MIN_SHAPE];
max_shape_tuple[i] = out[ATTR_MAX_SHAPE];
dyn_shape = true;
} else {
min_shape_tuple[i] = out[ATTR_SHAPE];
max_shape_tuple[i] = out[ATTR_SHAPE];
}
}
dic[ATTR_SHAPE] = shape_tuple;
dic[ATTR_DTYPE] = dtype_tuple;
if (is_build_value) {
dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue());
if (arg_tuple->BuildValue()->isa<AnyValue>()) {
dic[ATTR_VALUE] = py::none();
} else {
dic[ATTR_VALUE] = value_tuple;
}
if (dyn_value) {
dic[ATTR_MIN_VALUE] = min_value_tuple;
dic[ATTR_MAX_VALUE] = max_value_tuple;
}
if (dyn_shape) {
dic[ATTR_MIN_SHAPE] = min_shape_tuple;
dic[ATTR_MAX_SHAPE] = max_shape_tuple;
@@ -333,6 +329,7 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
size_t len = arg_list->size();
py::list shape_list(len);
py::list dtype_list(len);
py::list value_list(len);
py::list min_shape_list(len);
py::list max_shape_list(len);
auto dic = py::dict();
@@ -342,27 +339,29 @@ py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
shape_list[i] = out[ATTR_SHAPE];
dtype_list[i] = out[ATTR_DTYPE];
value_list[i] = out[ATTR_VALUE];
// Elements in the list are tensors whose shape is dynamic.
if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
min_shape_list[i] = out[ATTR_MIN_SHAPE];
max_shape_list[i] = out[ATTR_MAX_SHAPE];
dyn_shape = true;
} else {
min_shape_list[i] = out[ATTR_SHAPE];
max_shape_list[i] = out[ATTR_SHAPE];
}
}
dic[ATTR_SHAPE] = shape_list;
dic[ATTR_DTYPE] = dtype_list;
if (arg_list->BuildValue()->isa<AnyValue>()) {
dic[ATTR_VALUE] = py::none();
} else {
dic[ATTR_VALUE] = value_list;
}
if (dyn_shape) {
dic[ATTR_MIN_SHAPE] = min_shape_list;
dic[ATTR_MAX_SHAPE] = max_shape_list;
}
dic[ATTR_SHAPE] = shape_list;
dic[ATTR_DTYPE] = dtype_list;
dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue());
return dic;
}
} // end anonymous namespace
@@ -428,6 +427,16 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic[ATTR_SHAPE] = py::none();
dic[ATTR_DTYPE] = abs_base->BuildType();
dic[ATTR_VALUE] = py::none();
if (abs_base->isa<PartialAbstractClosure>()) {
AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args();
if (!args.empty()) {
auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>();
if (value != nullptr) {
dic[ATTR_DTYPE] = std::make_shared<TypeType>();
dic[ATTR_VALUE] = value->obj();
}
}
}
} else if (abs_base->isa<AbstractUndetermined>()) {
auto arg = dyn_cast<AbstractUndetermined>(abs_base);
dic[ATTR_SHAPE] = py::none();

View File

@@ -390,8 +390,8 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
}
return PyList2DynamicShapeTensor(shape_obj, type_obj, output);
} else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
py::tuple shape_tuple = shape_obj.cast<py::tuple>();
py::tuple typeid_tuple = type_obj.cast<py::tuple>();
auto shape_tuple = shape_obj.cast<py::tuple>();
auto typeid_tuple = type_obj.cast<py::tuple>();
AbstractBasePtrList ptr_list;
for (size_t it = 0; it < shape_tuple.size(); ++it) {
auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]);
@@ -400,8 +400,8 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
return tuple;
} else if (py::isinstance<py::list>(shape_obj) && py::isinstance<py::list>(type_obj)) {
py::list shape_list = shape_obj.cast<py::list>();
py::list typeid_list = type_obj.cast<py::list>();
auto shape_list = shape_obj.cast<py::list>();
auto typeid_list = type_obj.cast<py::list>();
AbstractBasePtrList ptr_list;
for (size_t it = 0; it < shape_list.size(); ++it) {
auto tensor_it = PyListDtype2AbstractTensor(shape_list[it], typeid_list[it]);

View File

@@ -78,36 +78,39 @@ single = float32
float64 = typing.Float(64)
double = float64
number = typing.Number()
int_ = typing.Int()
uint = typing.UInt()
float_ = typing.Float()
number = typing.Number()
string = typing.String()
list_ = typing.List()
tuple_ = typing.Tuple()
tensor = typing.TensorType()
function = typing.Function()
function_type = typing.Function
symbolic_key = typing.SymbolicKeyType()
env_type = typing.EnvType()
env_type_type = typing.EnvType
type_type = typing.TypeType()
type_none = typing.TypeNone()
type_bool = typing.Bool()
string = typing.String()
type_refkey = typing.RefKeyType()
tensor_type = typing.TensorType
anything_type = typing.TypeAnything
slice_type = typing.Slice
ellipsis_type = typing.TypeEllipsis
list_type = typing.List
tuple_type = typing.Tuple
tensor = typing.TensorType()
index_slices = typing.RowTensorType()
sparse_tensor = typing.SparseTensorType()
undetermined = typing.UndeterminedType()
function = typing.Function()
symbolic_key = typing.SymbolicKeyType()
env_type = typing.EnvType()
type_type = typing.TypeType()
type_refkey = typing.RefKeyType()
Int = typing.Int
bool_type = typing.Bool
Float = typing.Float
Bool = typing.Bool
String = typing.String
List = typing.List
Tuple = typing.Tuple
Slice = typing.Slice
function_type = typing.Function
Ellipsis_ = typing.TypeEllipsis
none_type = typing.TypeNone
env_type_type = typing.EnvType
tensor_type = typing.TensorType
anything_type = typing.TypeAnything
number_type = (int8,
int16,
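The regrouped block above separates concrete type instances (int_, float_, tensor, ...) from the typing classes, which are now exported under capitalized aliases (Int, Float, Bool, String, List, Tuple, Slice, Ellipsis_); the @constexpr helpers elsewhere in this commit match against those classes. A hypothetical user-side guard in the same style:

from mindspore import dtype as mstype
from mindspore.ops.primitive import constexpr

@constexpr
def _require_scalar_number(t, arg_name):
    """Compile-time guard; t is the mstype describing the argument."""
    if not isinstance(t, (mstype.Int, mstype.Float)):
        raise TypeError(f"'{arg_name}' should be an int or float, but got {t}")
    return True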

View File

@@ -86,6 +86,8 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom {
std::string ToString() const override { return "Prim: " + prim_->name(); }
ValuePtr RealBuildValue() const override { return prim_; }
private:
PrimitivePtr prim_;
// Store it as a weak_ptr to break the reference cycle.
@@ -183,6 +185,7 @@ class PartialAbstractClosure : public AbstractFuncAtom {
AbstractFunctionPtr fn() { return fn_; }
AbstractBasePtrList args() { return args_spec_list_; }
ValuePtr RealBuildValue() const override { return fn_->BuildValue(); }
AnfNodePtr node() { return node_.lock(); }
void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); }
AbstractFunctionPtr Copy() const override {
@@ -199,6 +202,7 @@ class PartialAbstractClosure : public AbstractFuncAtom {
// The CNode which this PartialAbstractClosure evaluated from.
AnfNodeWeakPtr node_;
};
using PartialAbstractClosurePtr = std::shared_ptr<PartialAbstractClosure>;
class JTransformedAbstractClosure : public AbstractFuncAtom {
public:

View File

@@ -339,13 +339,13 @@ def _cpu_not_support(name):
@constexpr
def _check_is_tuple(obj):
    """Check whether obj is a tuple"""
    return isinstance(obj, mstype.tuple_type)
    return isinstance(obj, mstype.Tuple)


@constexpr
def _check_is_list(obj):
    """Check whether obj is a list"""
    return isinstance(obj, mstype.list_type)
    return isinstance(obj, mstype.List)


@constexpr

View File

@@ -148,7 +148,7 @@ def _expand_data_dims_with_bool(data, tuple_index, op_name):
    bool_positions, tuple_index_without_bool = (), ()
    for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)):
        bool_type_tag = const_utils.judge_index_type(index_type, mstype.type_bool)
        bool_type_tag = const_utils.judge_index_type(index_type, mstype.bool_)
        if bool_type_tag:
            if index:
                tuple_index_without_bool += (const_utils.make_tensor([0], mstype.int64),)
@@ -653,6 +653,6 @@ def tensor_in_sequence(x, y):
"""Assigns whether a sequence contains the given tensor"""
result = const_utils.scalar_to_tensor(False)
for i in y:
if isinstance(i, mstype.tensor) and x.shape == i.shape and x.dtype == i.dtype:
if isinstance(i, Tensor) and x.shape == i.shape and x.dtype == i.dtype:
result = F.logical_or(F.equal(x, i).all(), result)
return result
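tensor_in_sequence backs the `in` operator when the left operand is a tensor; after the change it matches genuine Tensor elements rather than the mstype.tensor type instance. A hedged usage sketch (InNet is a hypothetical cell; the result is the scalar bool tensor built by the logical_or fold above):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE)

class InNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.seq = (Tensor(np.array([1, 2])), Tensor(np.array([3, 4])))

    def construct(self, x):
        return x in self.seq

print(InNet()(Tensor(np.array([1, 2]))))  # expected: True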

View File

@@ -171,15 +171,15 @@ def get_pos_of_indexes_types(indexes_types, op_name):
    slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
        sequence_positions = [], [], [], [], [], [], []
    for i, index_type in enumerate(indexes_types):
        if isinstance(index_type, mstype.slice_type):
        if isinstance(index_type, mstype.Slice):
            slice_positions.append(i)
        elif isinstance(index_type, mstype.ellipsis_type):
        elif isinstance(index_type, mstype.Ellipsis_):
            ellipsis_positions.append(i)
        elif isinstance(index_type, mstype.none_type):
            none_positions.append(i)
        elif isinstance(index_type, mstype.Int):
            int_positions.append(i)
        elif isinstance(index_type, mstype.bool_type):
        elif isinstance(index_type, mstype.Bool):
            bool_positions.append(i)
        elif isinstance(index_type, mstype.tensor_type):
            tensor_positions.append(i)
@@ -341,7 +341,7 @@ def tuple_index_int_cnt(types, op_name):
def tuple_index_type_cnt(types, op_name):
    """Count tensor and basic types among the tuple elements' types."""
    tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types)
    basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.ellipsis_type, mstype.slice_type)) for ele in types)
    basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types)
    if tensor_cnt == len(types):
        return ALL_TENSOR
    if basic_cnt == len(types):
@@ -614,7 +614,7 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, t
            indexes_info[pos] = tensor_indexes_shapes[tensor_count]
            index_tensors_info[pos] = tensor_indexes_shapes[tensor_count]
            tensor_count += 1
        elif isinstance(index_type, mstype.slice_type):
        elif isinstance(index_type, mstype.Slice):
            slice_obj = slice(slice_indexes[slice_count].start,
                              slice_indexes[slice_count].stop,
                              slice_indexes[slice_count].step)
@@ -680,7 +680,7 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te
    return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)
@ constexpr
@constexpr
def scalar_in_sequence(x, y):
    """Determine whether the scalar is in the sequence."""
    if x is None:

View File

@@ -63,7 +63,7 @@ TEST_F(TestData, test_build_value) {
// BuildValue(AbstractFunction) should return the wrapped primitive.
AbstractBasePtr abs_f1 = FromValue(prim::kPrimReturn, false);
ValuePtr abs_f1_built = abs_f1->BuildValue();
ASSERT_EQ(abs_f1_built, kAnyValue);
ASSERT_EQ(abs_f1_built, prim::kPrimReturn);
FuncGraphPtr fg1 = std::make_shared<FuncGraph>();
AbstractBasePtr abs_fg1 = FromValue(fg1, false);
@@ -74,17 +74,20 @@
AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false);
AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2}));
ValuePtr func_tuple_built = abs_func_tuple->BuildValue();
ASSERT_EQ(func_tuple_built, kAnyValue);
ASSERT_EQ(*func_tuple_built,
ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
// BuildValue(List(AbstractFunction)) should return the corresponding ValueList.
AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2}));
ValuePtr func_list_built = abs_func_list->BuildValue();
ASSERT_EQ(func_list_built, kAnyValue);
ASSERT_EQ(*func_list_built,
ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
// BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return the corresponding ValueTuple.
abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2}));
func_tuple_built = abs_func_tuple->BuildValue();
ASSERT_EQ(func_tuple_built, kAnyValue);
ASSERT_EQ(*func_tuple_built,
ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd}));
}
TEST_F(TestData, test_build_type) {

View File

@@ -0,0 +1,78 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test instance"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
def test_isinstance():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.int_member = 1
            self.float_member = 1.0
            self.bool_member = True
            self.string_member = "abcd"
            self.tensor_member = Tensor(np.arange(4))
            self.tuple_member = (1, 1.0, True, "abcd", self.tensor_member)
            self.list_member = list(self.tuple_member)

        def construct(self, x, y):
            is_int = isinstance(self.int_member, int)
            is_float = isinstance(self.float_member, float)
            is_bool = isinstance(self.bool_member, bool)
            is_string = isinstance(self.string_member, str)
            is_tensor_const = isinstance(self.tensor_member, Tensor)
            is_tensor_var = isinstance(x, Tensor)
            is_tuple_const = isinstance(self.tuple_member, tuple)
            is_tuple_var = isinstance((x, 1, 1.0, y), tuple)
            is_list_const = isinstance(self.list_member, list)
            is_list_var = isinstance([x, 1, 1.0, y], list)
            is_list_or_tensor = isinstance([x, y], (Tensor, list))
            is_int_or_float_or_tensor_or_tuple = isinstance(x, (Tensor, tuple, int, float))
            float_is_int = isinstance(self.float_member, int)
            bool_is_string = isinstance(self.bool_member, str)
            tensor_is_tuple = isinstance(x, tuple)
            tuple_is_list = isinstance(self.tuple_member, list)
            return is_int, is_float, is_bool, is_string, is_tensor_const, is_tensor_var, \
                is_tuple_const, is_tuple_var, is_list_const, is_list_var, \
                is_int_or_float_or_tensor_or_tuple, is_list_or_tensor, \
                float_is_int, bool_is_string, tensor_is_tuple, tuple_is_list

    net = Net()
    x = Tensor(np.arange(4))
    y = Tensor(np.arange(5))
    assert net(x, y) == (True,) * 12 + (False,) * 4
def test_isinstance_not_supported():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = (11, 22, 33, 44)

        def construct(self):
            return isinstance(self.value, None)

    net = Net()
    with pytest.raises(TypeError) as err:
        net()
    assert "The type 'None' is not supported for 'isinstance'" in str(err.value)