forked from mindspore-Ecosystem/mindspore
Enable cell to call bool in graph
This commit is contained in:
parent
bd2141ce00
commit
b1a4e60359
|
@ -2302,7 +2302,11 @@ def bool_func(*data):
|
||||||
return data != 0
|
return data != 0
|
||||||
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.")
|
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.")
|
||||||
if not F.isconstant(data):
|
if not F.isconstant(data):
|
||||||
return len(data) != 0
|
if hasattr(data, "__bool__"):
|
||||||
|
return data.__bool__()
|
||||||
|
if hasattr(data, "__len__"):
|
||||||
|
return len(data) != 0
|
||||||
|
return True
|
||||||
return cast_to_bool(data)
|
return cast_to_bool(data)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -151,6 +151,9 @@ class Cell(Cell_):
|
||||||
self.__dict__ = dict_
|
self.__dict__ = dict_
|
||||||
self._attr_synced = False
|
self._attr_synced = False
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _cell_tag(self):
|
def _cell_tag(self):
|
||||||
# `<class 'xxxxxxx'>` to `xxxxxxx`
|
# `<class 'xxxxxxx'>` to `xxxxxxx`
|
||||||
|
|
|
@ -233,6 +233,9 @@ class SequentialCell(Cell):
|
||||||
self._cells = temp_dict
|
self._cells = temp_dict
|
||||||
self.cell_list = list(self._cells.values())
|
self.cell_list = list(self._cells.values())
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return len(self._cells) != 0
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._cells)
|
return len(self._cells)
|
||||||
|
|
||||||
|
@ -390,6 +393,9 @@ class CellList(_CellListBase, Cell):
|
||||||
temp_dict[str(idx)] = cell
|
temp_dict[str(idx)] = cell
|
||||||
self._cells = temp_dict
|
self._cells = temp_dict
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return len(self._cells) != 0
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._cells)
|
return len(self._cells)
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
""" test graph fallback buildin python function bool"""
|
""" test graph fallback buildin python function bool"""
|
||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import mindspore.nn as nn
|
||||||
from mindspore import jit, context, Tensor
|
from mindspore import jit, context, Tensor
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
@ -199,3 +199,41 @@ def test_fallback_bool_with_input_tensor2():
|
||||||
return bool(x)
|
return bool(x)
|
||||||
|
|
||||||
assert foo()
|
assert foo()
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_bool_for_cell_object():
|
||||||
|
"""
|
||||||
|
Feature: Bool function.
|
||||||
|
Description: Test bool() for cell object input
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@jit
|
||||||
|
def foo():
|
||||||
|
net = Net()
|
||||||
|
return bool(net)
|
||||||
|
|
||||||
|
assert foo()
|
||||||
|
|
||||||
|
|
||||||
|
def test_bool_for_cell_object_2():
|
||||||
|
"""
|
||||||
|
Feature: Bool function.
|
||||||
|
Description: Test bool() for cell object input
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@jit
|
||||||
|
def foo():
|
||||||
|
net = Net()
|
||||||
|
if net:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
assert foo()
|
||||||
|
|
Loading…
Reference in New Issue