forked from mindspore-Ecosystem/mindspore
move master code to r1.10
This commit is contained in:
parent
79b3d3c8b0
commit
c5023912db
|
@ -1103,8 +1103,8 @@ bool IrExportBuilder::SetTypeToAttributeProto_irs(const ValuePtr &value, mind_ir
|
|||
} else if (value->isa<UInt>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
|
||||
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
auto uint_value = value->cast<FloatPtr>();
|
||||
auto data_type = GetMindirDataBitsFloatType(uint_value->nbits());
|
||||
auto uint_value = value->cast<UIntPtr>();
|
||||
auto data_type = GetMindirDataBitsUIntType(uint_value->nbits());
|
||||
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -1849,7 +1849,11 @@ def bool_func(*data):
|
|||
return data != 0
|
||||
const_utils.raise_value_error("The truth value of an array with several elements is ambiguous.")
|
||||
if not F.isconstant(data):
|
||||
return len(data) != 0
|
||||
if hasattr(data, "__bool__"):
|
||||
return data.__bool__()
|
||||
if hasattr(data, "__len__"):
|
||||
return len(data) != 0
|
||||
return True
|
||||
return cast_to_bool(data)
|
||||
|
||||
|
||||
|
|
|
@ -179,6 +179,9 @@ class Cell(Cell_):
|
|||
self.__dict__ = dict_
|
||||
self._attr_synced = False
|
||||
|
||||
def __bool__(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def _cell_tag(self):
|
||||
# `<class 'xxxxxxx'>` to `xxxxxxx`
|
||||
|
|
|
@ -233,6 +233,9 @@ class SequentialCell(Cell):
|
|||
self._cells = temp_dict
|
||||
self.cell_list = list(self._cells.values())
|
||||
|
||||
def __bool__(self):
|
||||
return len(self._cells) != 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self._cells)
|
||||
|
||||
|
@ -355,6 +358,9 @@ class CellList(_CellListBase, Cell):
|
|||
temp_dict[str(idx)] = cell
|
||||
self._cells = temp_dict
|
||||
|
||||
def __bool__(self):
|
||||
return len(self._cells) != 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self._cells)
|
||||
|
||||
|
|
|
@ -15,8 +15,8 @@
|
|||
""" test graph fallback buildin python function bool"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import ms_function, context, Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
|
@ -185,3 +185,38 @@ def test_fallback_bool_with_type_input():
|
|||
return bool(int)
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
return 1
|
||||
|
||||
|
||||
def test_bool_for_cell_object():
|
||||
"""
|
||||
Feature: Bool function.
|
||||
Description: Test bool() for cell object input
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
net = Net()
|
||||
return bool(net)
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_bool_for_cell_object_2():
|
||||
"""
|
||||
Feature: Bool function.
|
||||
Description: Test bool() for cell object input
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
net = Net()
|
||||
if net:
|
||||
return True
|
||||
return False
|
||||
|
||||
assert foo()
|
||||
|
|
Loading…
Reference in New Issue