forked from mindspore-Ecosystem/mindspore
!4000 improve interface '__bool__' for tensor
Merge pull request !4000 from zhangbuxue/improve_bool_for_tensor
This commit is contained in:
commit
8dec74908a
|
@ -17,7 +17,6 @@
|
|||
"""standard_method"""
|
||||
from dataclasses import dataclass
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common._register_for_tensor import tensor_operator_registry
|
||||
from ...ops import functional as F
|
||||
from ...ops import operations as P
|
||||
from ...ops.primitive import constexpr
|
||||
|
@ -206,13 +205,11 @@ def const_tensor_to_bool(x):
|
|||
if x is None:
|
||||
raise ValueError("Only constant tensor bool can be converted to bool")
|
||||
x = x.asnumpy()
|
||||
if x.shape not in ((), (1,)):
|
||||
raise ValueError("The truth value of an array with several elements is ambiguous.")
|
||||
if x.shape == ():
|
||||
value = bool(x)
|
||||
else:
|
||||
value = bool(x[0])
|
||||
return value
|
||||
return bool(x)
|
||||
if x.shape == (1,):
|
||||
return bool(x[0])
|
||||
raise ValueError("The truth value of an array with several elements is ambiguous.")
|
||||
|
||||
|
||||
def tensor_bool(x):
|
||||
|
@ -349,6 +346,3 @@ def list_append(self_, item):
|
|||
def to_array(x):
|
||||
"""Implementation of `to_array`."""
|
||||
return x.__ms_to_array__()
|
||||
|
||||
|
||||
tensor_operator_registry.register('__bool__', tensor_bool)
|
||||
|
|
|
@ -73,7 +73,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
|
|||
namespace {
|
||||
bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
|
||||
MS_LOG(DEBUG) << "Converting python tuple";
|
||||
py::tuple tuple = obj.cast<py::tuple>();
|
||||
auto tuple = obj.cast<py::tuple>();
|
||||
std::vector<ValuePtr> value_list;
|
||||
for (size_t it = 0; it < tuple.size(); ++it) {
|
||||
ValuePtr out = nullptr;
|
||||
|
@ -91,7 +91,7 @@ bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signatur
|
|||
bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) {
|
||||
MS_LOG(DEBUG) << "Converting python list";
|
||||
|
||||
py::list list = obj.cast<py::list>();
|
||||
auto list = obj.cast<py::list>();
|
||||
std::vector<ValuePtr> value_list;
|
||||
for (size_t it = 0; it < list.size(); ++it) {
|
||||
ValuePtr out = nullptr;
|
||||
|
@ -124,7 +124,7 @@ bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signa
|
|||
bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
|
||||
MS_LOG(DEBUG) << "Converting python dict";
|
||||
|
||||
py::dict dict_values = obj.cast<py::dict>();
|
||||
auto dict_values = obj.cast<py::dict>();
|
||||
std::vector<std::pair<std::string, ValuePtr>> key_values;
|
||||
for (auto item : dict_values) {
|
||||
if (!py::isinstance<py::str>(item.first)) {
|
||||
|
@ -208,7 +208,7 @@ bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_
|
|||
bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
|
||||
MS_LOG(DEBUG) << "Converting slice object";
|
||||
|
||||
py::slice slice_obj = obj.cast<py::slice>();
|
||||
auto slice_obj = obj.cast<py::slice>();
|
||||
auto convert_func = [obj](std::string attr) -> ValuePtr {
|
||||
auto py_attr = py::getattr(obj, attr.c_str());
|
||||
if (py::isinstance<py::none>(py_attr)) {
|
||||
|
@ -335,7 +335,7 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
|
|||
} else if (py::isinstance<MetaTensor>(obj)) {
|
||||
converted = obj.cast<MetaTensorPtr>();
|
||||
} else if (py::isinstance<EnvInstance>(obj)) {
|
||||
std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
|
||||
auto env = obj.cast<std::shared_ptr<EnvInstance>>();
|
||||
converted = env;
|
||||
} else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
|
||||
converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
|
||||
|
@ -374,7 +374,7 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
|
|||
|
||||
data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
|
||||
data_converter::CacheObjectValue(obj_id, func_graph);
|
||||
if (obj_key != "") {
|
||||
if (!obj_key.empty()) {
|
||||
MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
|
||||
data_converter::SetObjGraphValue(obj_key, func_graph);
|
||||
}
|
||||
|
@ -440,7 +440,7 @@ bool IsCellInstance(const py::object &obj) {
|
|||
py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms) {
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
py::object obj;
|
||||
if (params.size() == 0) {
|
||||
if (params.empty()) {
|
||||
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type);
|
||||
} else {
|
||||
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params);
|
||||
|
@ -499,7 +499,7 @@ ClassPtr ParseDataClass(const py::object &cls_obj) {
|
|||
ClassAttrVector attributes;
|
||||
py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj);
|
||||
for (auto &item : names) {
|
||||
TypePtr type_value = item.second.cast<TypePtr>();
|
||||
auto type_value = item.second.cast<TypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(type_value);
|
||||
MS_LOG(DEBUG) << "(Name: " << py::cast<std::string>(item.first) << ", type: " << type_value->ToString() << ")";
|
||||
attributes.push_back(std::make_pair(py::cast<std::string>(item.first), type_value));
|
||||
|
@ -508,8 +508,8 @@ ClassPtr ParseDataClass(const py::object &cls_obj) {
|
|||
std::unordered_map<std::string, ValuePtr> methods_map;
|
||||
py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj);
|
||||
for (auto &item : methods) {
|
||||
std::string fun_name = item.first.cast<std::string>();
|
||||
py::object obj = py::cast<py::object>(item.second);
|
||||
auto fun_name = item.first.cast<std::string>();
|
||||
auto obj = py::cast<py::object>(item.second);
|
||||
std::shared_ptr<PyObjectWrapper> method_obj = std::make_shared<PyObjectWrapper>(obj, fun_name);
|
||||
methods_map[fun_name] = method_obj;
|
||||
}
|
||||
|
|
|
@ -108,8 +108,12 @@ class Tensor(Tensor_):
|
|||
return out
|
||||
|
||||
def __bool__(self):
|
||||
out = tensor_operator_registry.get('__bool__')(self)
|
||||
return out
|
||||
data = self.asnumpy()
|
||||
if data.shape == ():
|
||||
return bool(data)
|
||||
if data.shape == (1,):
|
||||
return bool(data[0])
|
||||
raise ValueError("The truth value of an array with several elements is ambiguous.")
|
||||
|
||||
def __pos__(self):
|
||||
return self
|
||||
|
|
|
@ -35,7 +35,6 @@ def test_dtype_and_shape_as_attr():
|
|||
dtype = x.dtype
|
||||
return shape, dtype
|
||||
|
||||
|
||||
net = Net()
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
ret = net(x)
|
||||
|
@ -55,7 +54,6 @@ def test_dtype_and_shape_as_attr_to_new_tensor():
|
|||
y = self.fill(dtype, shape, self.value)
|
||||
return y
|
||||
|
||||
|
||||
net = Net(2.2)
|
||||
x = Tensor(np.ones([1, 2, 3], np.float32))
|
||||
ret = net(x)
|
||||
|
@ -71,7 +69,6 @@ def test_type_not_have_the_attr():
|
|||
shape = x.shapes
|
||||
return shape
|
||||
|
||||
|
||||
net = Net()
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
|
@ -88,7 +85,6 @@ def test_type_not_have_the_method():
|
|||
shape = x.dtypes()
|
||||
return shape
|
||||
|
||||
|
||||
net = Net()
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
|
|
Loading…
Reference in New Issue