forked from mindspore-Ecosystem/mindspore
!47525 Enable bool to use cell object as input
Merge pull request !47525 from LiangZhibo/bool
This commit is contained in:
commit
ff541f82d0
|
@ -2302,7 +2302,11 @@ def bool_func(*data):
|
|||
return data != 0
|
||||
const_utils.raise_value_error("The truth value of an array with more than one element is ambiguous.")
|
||||
if not F.isconstant(data):
|
||||
if hasattr(data, "__bool__"):
|
||||
return data.__bool__()
|
||||
if hasattr(data, "__len__"):
|
||||
return len(data) != 0
|
||||
return True
|
||||
return cast_to_bool(data)
|
||||
|
||||
|
||||
|
|
|
@ -151,6 +151,9 @@ class Cell(Cell_):
|
|||
self.__dict__ = dict_
|
||||
self._attr_synced = False
|
||||
|
||||
def __bool__(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def _cell_tag(self):
|
||||
# `<class 'xxxxxxx'>` to `xxxxxxx`
|
||||
|
|
|
@ -233,6 +233,9 @@ class SequentialCell(Cell):
|
|||
self._cells = temp_dict
|
||||
self.cell_list = list(self._cells.values())
|
||||
|
||||
def __bool__(self):
|
||||
return len(self._cells) != 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self._cells)
|
||||
|
||||
|
@ -390,6 +393,9 @@ class CellList(_CellListBase, Cell):
|
|||
temp_dict[str(idx)] = cell
|
||||
self._cells = temp_dict
|
||||
|
||||
def __bool__(self):
|
||||
return len(self._cells) != 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self._cells)
|
||||
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
""" test graph fallback buildin python function bool"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import jit, context, Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
|
@ -199,3 +199,41 @@ def test_fallback_bool_with_input_tensor2():
|
|||
return bool(x)
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self):
|
||||
return 1
|
||||
|
||||
|
||||
def test_bool_for_cell_object():
|
||||
"""
|
||||
Feature: Bool function.
|
||||
Description: Test bool() for cell object input
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
net = Net()
|
||||
return bool(net)
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_bool_for_cell_object_2():
|
||||
"""
|
||||
Feature: Bool function.
|
||||
Description: Test bool() for cell object input
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
net = Net()
|
||||
if net:
|
||||
return True
|
||||
return False
|
||||
|
||||
assert foo()
|
||||
|
|
Loading…
Reference in New Issue