Enable cell to call bool in graph

This commit is contained in:
liangzhibo 2023-01-05 09:49:24 +08:00
parent bd2141ce00
commit b1a4e60359
4 changed files with 53 additions and 2 deletions

View File

@ -2302,7 +2302,11 @@ def bool_func(*data):
return data != 0 return data != 0
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.") const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.")
if not F.isconstant(data): if not F.isconstant(data):
return len(data) != 0 if hasattr(data, "__bool__"):
return data.__bool__()
if hasattr(data, "__len__"):
return len(data) != 0
return True
return cast_to_bool(data) return cast_to_bool(data)

View File

@ -151,6 +151,9 @@ class Cell(Cell_):
self.__dict__ = dict_ self.__dict__ = dict_
self._attr_synced = False self._attr_synced = False
def __bool__(self):
return True
@property @property
def _cell_tag(self): def _cell_tag(self):
# `<class 'xxxxxxx'>` to `xxxxxxx` # `<class 'xxxxxxx'>` to `xxxxxxx`

View File

@ -233,6 +233,9 @@ class SequentialCell(Cell):
self._cells = temp_dict self._cells = temp_dict
self.cell_list = list(self._cells.values()) self.cell_list = list(self._cells.values())
def __bool__(self):
return len(self._cells) != 0
def __len__(self): def __len__(self):
return len(self._cells) return len(self._cells)
@ -390,6 +393,9 @@ class CellList(_CellListBase, Cell):
temp_dict[str(idx)] = cell temp_dict[str(idx)] = cell
self._cells = temp_dict self._cells = temp_dict
def __bool__(self):
return len(self._cells) != 0
def __len__(self): def __len__(self):
return len(self._cells) return len(self._cells)

View File

@ -15,8 +15,8 @@
""" test graph fallback buildin python function bool""" """ test graph fallback buildin python function bool"""
import pytest import pytest
import numpy as np import numpy as np
import mindspore.nn as nn
from mindspore import jit, context, Tensor from mindspore import jit, context, Tensor
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
@ -199,3 +199,41 @@ def test_fallback_bool_with_input_tensor2():
return bool(x) return bool(x)
assert foo() assert foo()
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self):
return 1
def test_bool_for_cell_object():
"""
Feature: Bool function.
Description: Test bool() for cell object input
Expectation: No exception.
"""
@jit
def foo():
net = Net()
return bool(net)
assert foo()
def test_bool_for_cell_object_2():
"""
Feature: Bool function.
Description: Test bool() for cell object input
Expectation: No exception.
"""
@jit
def foo():
net = Net()
if net:
return True
return False
assert foo()