move master code to r1.10

This commit is contained in:
liangzhibo 2023-03-03 09:26:43 +08:00
parent 79b3d3c8b0
commit c5023912db
5 changed files with 52 additions and 4 deletions

View File

@ -1103,8 +1103,8 @@ bool IrExportBuilder::SetTypeToAttributeProto_irs(const ValuePtr &value, mind_ir
} else if (value->isa<UInt>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
auto uint_value = value->cast<FloatPtr>();
auto data_type = GetMindirDataBitsFloatType(uint_value->nbits());
auto uint_value = value->cast<UIntPtr>();
auto data_type = GetMindirDataBitsUIntType(uint_value->nbits());
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
}

View File

@ -1849,7 +1849,11 @@ def bool_func(*data):
return data != 0
const_utils.raise_value_error("The truth value of an array with several elements is ambiguous.")
if not F.isconstant(data):
return len(data) != 0
if hasattr(data, "__bool__"):
return data.__bool__()
if hasattr(data, "__len__"):
return len(data) != 0
return True
return cast_to_bool(data)

View File

@ -179,6 +179,9 @@ class Cell(Cell_):
self.__dict__ = dict_
self._attr_synced = False
def __bool__(self):
return True
@property
def _cell_tag(self):
# `<class 'xxxxxxx'>` to `xxxxxxx`

View File

@ -233,6 +233,9 @@ class SequentialCell(Cell):
self._cells = temp_dict
self.cell_list = list(self._cells.values())
def __bool__(self):
return len(self._cells) != 0
def __len__(self):
return len(self._cells)
@ -355,6 +358,9 @@ class CellList(_CellListBase, Cell):
temp_dict[str(idx)] = cell
self._cells = temp_dict
def __bool__(self):
return len(self._cells) != 0
def __len__(self):
return len(self._cells)

View File

@ -15,8 +15,8 @@
""" test graph fallback buildin python function bool"""
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import ms_function, context, Tensor
context.set_context(mode=context.GRAPH_MODE)
@ -185,3 +185,38 @@ def test_fallback_bool_with_type_input():
return bool(int)
assert foo()
class Net(nn.Cell):
def construct(self):
return 1
def test_bool_for_cell_object():
"""
Feature: Bool function.
Description: Test bool() for cell object input
Expectation: No exception.
"""
@ms_function
def foo():
net = Net()
return bool(net)
assert foo()
def test_bool_for_cell_object_2():
"""
Feature: Bool function.
Description: Test bool() for cell object input
Expectation: No exception.
"""
@ms_function
def foo():
net = Net()
if net:
return True
return False
assert foo()