!47525 Enable bool to use cell object as input

Merge pull request !47525 from LiangZhibo/bool
This commit is contained in:
i-robot 2023-01-09 08:12:03 +00:00 committed by Gitee
commit ff541f82d0
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 53 additions and 2 deletions

View File

@ -2302,7 +2302,11 @@ def bool_func(*data):
return data != 0
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.")
if not F.isconstant(data):
if hasattr(data, "__bool__"):
return data.__bool__()
if hasattr(data, "__len__"):
return len(data) != 0
return True
return cast_to_bool(data)

View File

@ -151,6 +151,9 @@ class Cell(Cell_):
self.__dict__ = dict_
self._attr_synced = False
def __bool__(self):
return True
@property
def _cell_tag(self):
# `<class 'xxxxxxx'>` to `xxxxxxx`

View File

@ -233,6 +233,9 @@ class SequentialCell(Cell):
self._cells = temp_dict
self.cell_list = list(self._cells.values())
def __bool__(self):
return len(self._cells) != 0
def __len__(self):
return len(self._cells)
@ -390,6 +393,9 @@ class CellList(_CellListBase, Cell):
temp_dict[str(idx)] = cell
self._cells = temp_dict
def __bool__(self):
return len(self._cells) != 0
def __len__(self):
return len(self._cells)

View File

@ -15,8 +15,8 @@
""" test graph fallback buildin python function bool"""
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import jit, context, Tensor
context.set_context(mode=context.GRAPH_MODE)
@ -199,3 +199,41 @@ def test_fallback_bool_with_input_tensor2():
return bool(x)
assert foo()
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self):
return 1
def test_bool_for_cell_object():
"""
Feature: Bool function.
Description: Test bool() for cell object input
Expectation: No exception.
"""
@jit
def foo():
net = Net()
return bool(net)
assert foo()
def test_bool_for_cell_object_2():
"""
Feature: Bool function.
Description: Test bool() for cell object input
Expectation: No exception.
"""
@jit
def foo():
net = Net()
if net:
return True
return False
assert foo()