improve interface '__bool__' for tensor

This commit is contained in:
buxue 2020-08-05 16:06:57 +08:00
parent 5cf20dc0da
commit ace34525cd
4 changed files with 20 additions and 26 deletions

View File

@ -17,7 +17,6 @@
"""standard_method"""
from dataclasses import dataclass
from mindspore.common import dtype as mstype
from mindspore.common._register_for_tensor import tensor_operator_registry
from ...ops import functional as F
from ...ops import operations as P
from ...ops.primitive import constexpr
@ -206,13 +205,11 @@ def const_tensor_to_bool(x):
if x is None:
raise ValueError("Only constant tensor bool can be converted to bool")
x = x.asnumpy()
if x.shape not in ((), (1,)):
raise ValueError("The truth value of an array with several elements is ambiguous.")
if x.shape == ():
value = bool(x)
else:
value = bool(x[0])
return value
return bool(x)
if x.shape == (1,):
return bool(x[0])
raise ValueError("The truth value of an array with several elements is ambiguous.")
def tensor_bool(x):
@ -349,6 +346,3 @@ def list_append(self_, item):
def to_array(x):
"""Implementation of `to_array`."""
return x.__ms_to_array__()
tensor_operator_registry.register('__bool__', tensor_bool)

View File

@ -73,7 +73,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
namespace {
bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
MS_LOG(DEBUG) << "Converting python tuple";
py::tuple tuple = obj.cast<py::tuple>();
auto tuple = obj.cast<py::tuple>();
std::vector<ValuePtr> value_list;
for (size_t it = 0; it < tuple.size(); ++it) {
ValuePtr out = nullptr;
@ -91,7 +91,7 @@ bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signatur
bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) {
MS_LOG(DEBUG) << "Converting python list";
py::list list = obj.cast<py::list>();
auto list = obj.cast<py::list>();
std::vector<ValuePtr> value_list;
for (size_t it = 0; it < list.size(); ++it) {
ValuePtr out = nullptr;
@ -124,7 +124,7 @@ bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signa
bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
MS_LOG(DEBUG) << "Converting python dict";
py::dict dict_values = obj.cast<py::dict>();
auto dict_values = obj.cast<py::dict>();
std::vector<std::pair<std::string, ValuePtr>> key_values;
for (auto item : dict_values) {
if (!py::isinstance<py::str>(item.first)) {
@ -208,7 +208,7 @@ bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_
bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting slice object";
py::slice slice_obj = obj.cast<py::slice>();
auto slice_obj = obj.cast<py::slice>();
auto convert_func = [obj](std::string attr) -> ValuePtr {
auto py_attr = py::getattr(obj, attr.c_str());
if (py::isinstance<py::none>(py_attr)) {
@ -335,7 +335,7 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
} else if (py::isinstance<MetaTensor>(obj)) {
converted = obj.cast<MetaTensorPtr>();
} else if (py::isinstance<EnvInstance>(obj)) {
std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
auto env = obj.cast<std::shared_ptr<EnvInstance>>();
converted = env;
} else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
@ -374,7 +374,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
data_converter::CacheObjectValue(obj_id, func_graph);
if (obj_key != "") {
if (!obj_key.empty()) {
MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
data_converter::SetObjGraphValue(obj_key, func_graph);
}
@ -440,7 +440,7 @@ bool IsCellInstance(const py::object &obj) {
py::object CreatePythonObject(const py::object &type, const py::tuple &params) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object obj;
if (params.size() == 0) {
if (params.empty()) {
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type);
} else {
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params);
@ -499,7 +499,7 @@ ClassPtr ParseDataClass(const py::object &cls_obj) {
ClassAttrVector attributes;
py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj);
for (auto &item : names) {
TypePtr type_value = item.second.cast<TypePtr>();
auto type_value = item.second.cast<TypePtr>();
MS_EXCEPTION_IF_NULL(type_value);
MS_LOG(DEBUG) << "(Name: " << py::cast<std::string>(item.first) << ", type: " << type_value->ToString() << ")";
attributes.push_back(std::make_pair(py::cast<std::string>(item.first), type_value));
@ -508,8 +508,8 @@ ClassPtr ParseDataClass(const py::object &cls_obj) {
std::unordered_map<std::string, ValuePtr> methods_map;
py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj);
for (auto &item : methods) {
std::string fun_name = item.first.cast<std::string>();
py::object obj = py::cast<py::object>(item.second);
auto fun_name = item.first.cast<std::string>();
auto obj = py::cast<py::object>(item.second);
std::shared_ptr<PyObjectWrapper> method_obj = std::make_shared<PyObjectWrapper>(obj, fun_name);
methods_map[fun_name] = method_obj;
}

View File

@ -108,8 +108,12 @@ class Tensor(Tensor_):
return out
def __bool__(self):
out = tensor_operator_registry.get('__bool__')(self)
return out
data = self.asnumpy()
if data.shape == ():
return bool(data)
if data.shape == (1,):
return bool(data[0])
raise ValueError("The truth value of an array with several elements is ambiguous.")
def __pos__(self):
return self

View File

@ -35,7 +35,6 @@ def test_dtype_and_shape_as_attr():
dtype = x.dtype
return shape, dtype
net = Net()
x = Tensor(np.ones([1, 2, 3], np.int32))
ret = net(x)
@ -55,7 +54,6 @@ def test_dtype_and_shape_as_attr_to_new_tensor():
y = self.fill(dtype, shape, self.value)
return y
net = Net(2.2)
x = Tensor(np.ones([1, 2, 3], np.float32))
ret = net(x)
@ -71,7 +69,6 @@ def test_type_not_have_the_attr():
shape = x.shapes
return shape
net = Net()
x = Tensor(np.ones([1, 2, 3], np.int32))
with pytest.raises(RuntimeError) as ex:
@ -88,7 +85,6 @@ def test_type_not_have_the_method():
shape = x.dtypes()
return shape
net = Net()
x = Tensor(np.ones([1, 2, 3], np.int32))
with pytest.raises(RuntimeError) as ex: