refactor bool op parsing to be consistent with pynative mode

add testcase of st
This commit is contained in:
huangdongrun 2020-08-18 20:17:15 +08:00
parent c7b7af6c3a
commit f30418991c
6 changed files with 114 additions and 19 deletions

View File

@ -737,8 +737,7 @@ AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object
return block->func_graph()->NewCNode({op_node, left_node, right_node});
}
AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list,
const py::object &op) {
AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
// if there is only one bool op now
if (value_list.size() == 1) {
AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
@ -749,11 +748,41 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
for (size_t i = 1; i < value_list.size(); i++) {
rest.append(value_list[i]);
}
MS_EXCEPTION_IF_NULL(block);
TraceManager::DebugTrace(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr true_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
TraceManager::DebugTrace(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
FunctionBlockPtr false_block = MakeFunctionBlock(*this);
TraceManager::EndTrace();
MakeConditionBlocks(block, true_block, false_block);
FunctionBlockPtr b1, b2;
AnfNodePtr first_node = ParseExprNode(block, first);
AnfNodePtr rest_node = ProcessBoolOpValueList(block, rest, op);
auto op_node = block->MakeResolveAstOp(op);
return block->func_graph()->NewCNode({op_node, first_node, rest_node});
// if it is and, we need to process the rest nodes;
// if it is or, we continue to next
if (mode == AST_SUB_TYPE_AND) {
b1 = true_block;
b2 = false_block;
} else if (mode == AST_SUB_TYPE_OR) {
b2 = true_block;
b1 = false_block;
} else {
MS_LOG(ERROR) << "Not supported mode: " << mode;
return nullptr;
}
AnfNodePtr test_node = ParseExprNode(block, first);
AnfNodePtr rest_node = ProcessBoolOpValueList(b1, rest, mode);
b1->func_graph()->set_output(rest_node);
b2->func_graph()->set_output(test_node);
auto cond_node = block->ForceToBoolNode(test_node);
auto switch_app =
block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()),
NewValueNode(false_block->func_graph())});
std::vector<AnfNodePtr> call_graph_nodes{switch_app};
auto switch_app_call = block->func_graph()->NewCNode(call_graph_nodes);
return switch_app_call;
}
}
@ -761,8 +790,13 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast BoolOp";
py::object op_node = python_adapter::GetPyObjAttr(node, "op");
AstSubType op_type = ast_->GetOpType(op_node);
if (op_type == AST_SUB_TYPE_UNKNOWN) {
MS_LOG(WARNING) << "ProcessBoolOp, got unkown op type";
return nullptr;
}
py::list op_values = python_adapter::GetPyObjAttr(node, "values");
return ProcessBoolOpValueList(block, op_values, op_node);
return ProcessBoolOpValueList(block, op_values, op_type);
}
// Process a function def

View File

@ -206,7 +206,7 @@ class Parser {
void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
// process a bool operation value list
AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, const py::object &op);
AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode);
CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node,
const AnfNodePtr &op_iter);

View File

@ -45,7 +45,9 @@ def _logical_not_tensor(x):
Returns:
Tensor, Return logical not operation result of x.
"""
return F.logical_not(x)
if F.isconstant(x):
return F.bool_not(x.__bool__())
return F.logical_not(x.__bool__())
@logical_not.register("Tuple")

View File

@ -61,8 +61,7 @@ class ControlSimpleIfWithAssign(nn.Cell):
class ControlIfinIf(nn.Cell):
def __init__(self):
super().__init__()
"""pass"""
def construct(self, x, y):
if x > y:
@ -151,6 +150,40 @@ class ControlMixedWhileIf(nn.Cell):
return out
class AndOperation(nn.Cell):
def __init__(self):
super().__init__()
self.reduce_sum = op.ReduceSum()
def construct(self, x, y):
x_sum = self.reduce_sum(x)
y_sum = self.reduce_sum(y)
out = x_sum and y_sum
return out
class OrOperation(nn.Cell):
def __init__(self):
super().__init__()
self.reduce_sum = op.ReduceSum()
def construct(self, x, y):
x_sum = self.reduce_sum(x)
y_sum = self.reduce_sum(y)
out = x_sum or y_sum
return out
class NotOperation(nn.Cell):
def __init__(self):
super().__init__()
self.reduce_sum = op.ReduceSum()
def construct(self, x):
x_sum = self.reduce_sum(x)
return not x_sum
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -248,3 +281,27 @@ def test_mixed_while_if():
output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4)
expect = np.array(3318).astype(np.int32)
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_and_or_operation():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
x = np.array([0, 1]).astype(np.float32)
y = np.array([0, 0]).astype(np.float32)
net = AndOperation()
output = net(Tensor(x), Tensor(y))
expect = np.sum(x) and np.sum(y)
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
net = OrOperation()
output = net(Tensor(x), Tensor(y))
expect = np.sum(x) or np.sum(y)
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
net = NotOperation()
output = net(Tensor(x))
expect = not np.sum(x)
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)

View File

@ -103,15 +103,15 @@ class LogicalTensorOpsNet(nn.Cell):
self.const_true = Tensor(True, dtype=mstype.bool_)
def construct(self, x, y):
ret = x and y and (y or self.const_true) and (not self.const_true)
ret = x and y and (y or self.const_true) and (not y)
return ret
test_case_ops = [
('CompareOpsNet', {
'block': ComparisonOpsNet(),
'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32),
Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}),
'desc_inputs': [Tensor(1.0, dtype=mstype.float32),
Tensor(1.0, dtype=mstype.float32)]}),
('MathOpsNet', {
'block': MathOpsNet(),
'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32),
@ -126,8 +126,8 @@ test_case_ops = [
Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}),
('LogicalTensorOps', {
'block': LogicalTensorOpsNet(),
'desc_inputs': [Tensor(np.ones([6, 9, 10]).astype(np.bool_), dtype=mstype.bool_),
Tensor(np.zeros([6, 9, 10]).astype(np.bool_), dtype=mstype.bool_)]}),
'desc_inputs': [Tensor(True, dtype=mstype.bool_),
Tensor(False, dtype=mstype.bool_)]}),
]
test_case_lists = [test_case_ops]

View File

@ -41,10 +41,12 @@ def vm_impl_tensor_add(self):
# pylint: disable=used-before-assignment
@vm_impl_getters.register(P.LogicalNot)
def vm_impl_logical_not(self):
x = x.asnumpy()
out = vm.logical_not(x)
return Tensor(out)
def vm_impl(x):
x = x.asnumpy()
out = vm.logical_not(x)
return Tensor(out)
return vm_impl
@vm_impl_getters.register(P.MatMul)
def vm_impl_mat_mul(self):