diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py index 4668c2a18f4..b750ef034e4 100644 --- a/mindspore/python/mindspore/_extends/parse/standard_method.py +++ b/mindspore/python/mindspore/_extends/parse/standard_method.py @@ -2302,7 +2302,11 @@ def bool_func(*data): return data != 0 const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.") if not F.isconstant(data): - return len(data) != 0 + if hasattr(data, "__bool__"): + return data.__bool__() + if hasattr(data, "__len__"): + return len(data) != 0 + return True return cast_to_bool(data) diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py index d72617b6c62..4897c74d35e 100755 --- a/mindspore/python/mindspore/nn/cell.py +++ b/mindspore/python/mindspore/nn/cell.py @@ -151,6 +151,9 @@ class Cell(Cell_): self.__dict__ = dict_ self._attr_synced = False + def __bool__(self): + return True + @property def _cell_tag(self): # `` to `xxxxxxx` diff --git a/mindspore/python/mindspore/nn/layer/container.py b/mindspore/python/mindspore/nn/layer/container.py index bc8c96aedc2..2073382cb99 100644 --- a/mindspore/python/mindspore/nn/layer/container.py +++ b/mindspore/python/mindspore/nn/layer/container.py @@ -233,6 +233,9 @@ class SequentialCell(Cell): self._cells = temp_dict self.cell_list = list(self._cells.values()) + def __bool__(self): + return len(self._cells) != 0 + def __len__(self): return len(self._cells) @@ -390,6 +393,9 @@ class CellList(_CellListBase, Cell): temp_dict[str(idx)] = cell self._cells = temp_dict + def __bool__(self): + return len(self._cells) != 0 + def __len__(self): return len(self._cells) diff --git a/tests/ut/python/graph_syntax/python_builtin_functions/test_bool.py b/tests/ut/python/graph_syntax/python_builtin_functions/test_bool.py index cc0e8cff9ef..61ca74999e9 100644 --- a/tests/ut/python/graph_syntax/python_builtin_functions/test_bool.py +++ b/tests/ut/python/graph_syntax/python_builtin_functions/test_bool.py @@ -15,8 +15,8 @@ """ test graph fallback buildin python function bool""" import pytest import numpy as np +import mindspore.nn as nn from mindspore import jit, context, Tensor - context.set_context(mode=context.GRAPH_MODE) @@ -199,3 +199,41 @@ def test_fallback_bool_with_input_tensor2(): return bool(x) assert foo() + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self): + return 1 + + +def test_bool_for_cell_object(): + """ + Feature: Bool function. + Description: Test bool() for cell object input + Expectation: No exception. + """ + @jit + def foo(): + net = Net() + return bool(net) + + assert foo() + + +def test_bool_for_cell_object_2(): + """ + Feature: Bool function. + Description: Test bool() for cell object input + Expectation: No exception. + """ + @jit + def foo(): + net = Net() + if net: + return True + return False + + assert foo()