forked from mindspore-Ecosystem/mindspore
refactor bool op parsing to be consistent with pynative mode
add testcase of st
This commit is contained in:
parent
c7b7af6c3a
commit
f30418991c
|
@ -737,8 +737,7 @@ AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object
|
|||
return block->func_graph()->NewCNode({op_node, left_node, right_node});
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list,
|
||||
const py::object &op) {
|
||||
AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
|
||||
// if there is only one bool op now
|
||||
if (value_list.size() == 1) {
|
||||
AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
|
||||
|
@ -749,11 +748,41 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
|
|||
for (size_t i = 1; i < value_list.size(); i++) {
|
||||
rest.append(value_list[i]);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
TraceManager::DebugTrace(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
|
||||
FunctionBlockPtr true_block = MakeFunctionBlock(*this);
|
||||
TraceManager::EndTrace();
|
||||
TraceManager::DebugTrace(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
|
||||
FunctionBlockPtr false_block = MakeFunctionBlock(*this);
|
||||
TraceManager::EndTrace();
|
||||
MakeConditionBlocks(block, true_block, false_block);
|
||||
FunctionBlockPtr b1, b2;
|
||||
|
||||
AnfNodePtr first_node = ParseExprNode(block, first);
|
||||
AnfNodePtr rest_node = ProcessBoolOpValueList(block, rest, op);
|
||||
auto op_node = block->MakeResolveAstOp(op);
|
||||
return block->func_graph()->NewCNode({op_node, first_node, rest_node});
|
||||
// if it is and, we need to process the rest nodes;
|
||||
// if it is or, we continue to next
|
||||
if (mode == AST_SUB_TYPE_AND) {
|
||||
b1 = true_block;
|
||||
b2 = false_block;
|
||||
} else if (mode == AST_SUB_TYPE_OR) {
|
||||
b2 = true_block;
|
||||
b1 = false_block;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Not supported mode: " << mode;
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr test_node = ParseExprNode(block, first);
|
||||
AnfNodePtr rest_node = ProcessBoolOpValueList(b1, rest, mode);
|
||||
b1->func_graph()->set_output(rest_node);
|
||||
b2->func_graph()->set_output(test_node);
|
||||
|
||||
auto cond_node = block->ForceToBoolNode(test_node);
|
||||
auto switch_app =
|
||||
block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()),
|
||||
NewValueNode(false_block->func_graph())});
|
||||
|
||||
std::vector<AnfNodePtr> call_graph_nodes{switch_app};
|
||||
auto switch_app_call = block->func_graph()->NewCNode(call_graph_nodes);
|
||||
return switch_app_call;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -761,8 +790,13 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
|
|||
AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast BoolOp";
|
||||
py::object op_node = python_adapter::GetPyObjAttr(node, "op");
|
||||
AstSubType op_type = ast_->GetOpType(op_node);
|
||||
if (op_type == AST_SUB_TYPE_UNKNOWN) {
|
||||
MS_LOG(WARNING) << "ProcessBoolOp, got unkown op type";
|
||||
return nullptr;
|
||||
}
|
||||
py::list op_values = python_adapter::GetPyObjAttr(node, "values");
|
||||
return ProcessBoolOpValueList(block, op_values, op_node);
|
||||
return ProcessBoolOpValueList(block, op_values, op_type);
|
||||
}
|
||||
|
||||
// Process a function def
|
||||
|
|
|
@ -206,7 +206,7 @@ class Parser {
|
|||
void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
|
||||
|
||||
// process a bool operation value list
|
||||
AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, const py::object &op);
|
||||
AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode);
|
||||
|
||||
CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node,
|
||||
const AnfNodePtr &op_iter);
|
||||
|
|
|
@ -45,7 +45,9 @@ def _logical_not_tensor(x):
|
|||
Returns:
|
||||
Tensor, Return logical not operation result of x.
|
||||
"""
|
||||
return F.logical_not(x)
|
||||
if F.isconstant(x):
|
||||
return F.bool_not(x.__bool__())
|
||||
return F.logical_not(x.__bool__())
|
||||
|
||||
|
||||
@logical_not.register("Tuple")
|
||||
|
|
|
@ -61,8 +61,7 @@ class ControlSimpleIfWithAssign(nn.Cell):
|
|||
|
||||
|
||||
class ControlIfinIf(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
"""pass"""
|
||||
|
||||
def construct(self, x, y):
|
||||
if x > y:
|
||||
|
@ -151,6 +150,40 @@ class ControlMixedWhileIf(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
class AndOperation(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reduce_sum = op.ReduceSum()
|
||||
|
||||
def construct(self, x, y):
|
||||
x_sum = self.reduce_sum(x)
|
||||
y_sum = self.reduce_sum(y)
|
||||
out = x_sum and y_sum
|
||||
return out
|
||||
|
||||
|
||||
class OrOperation(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reduce_sum = op.ReduceSum()
|
||||
|
||||
def construct(self, x, y):
|
||||
x_sum = self.reduce_sum(x)
|
||||
y_sum = self.reduce_sum(y)
|
||||
out = x_sum or y_sum
|
||||
return out
|
||||
|
||||
|
||||
class NotOperation(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reduce_sum = op.ReduceSum()
|
||||
|
||||
def construct(self, x):
|
||||
x_sum = self.reduce_sum(x)
|
||||
return not x_sum
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -248,3 +281,27 @@ def test_mixed_while_if():
|
|||
output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4)
|
||||
expect = np.array(3318).astype(np.int32)
|
||||
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_and_or_operation():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
x = np.array([0, 1]).astype(np.float32)
|
||||
y = np.array([0, 0]).astype(np.float32)
|
||||
net = AndOperation()
|
||||
output = net(Tensor(x), Tensor(y))
|
||||
expect = np.sum(x) and np.sum(y)
|
||||
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
net = OrOperation()
|
||||
output = net(Tensor(x), Tensor(y))
|
||||
expect = np.sum(x) or np.sum(y)
|
||||
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
net = NotOperation()
|
||||
output = net(Tensor(x))
|
||||
expect = not np.sum(x)
|
||||
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
|
||||
|
|
|
@ -103,15 +103,15 @@ class LogicalTensorOpsNet(nn.Cell):
|
|||
self.const_true = Tensor(True, dtype=mstype.bool_)
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = x and y and (y or self.const_true) and (not self.const_true)
|
||||
ret = x and y and (y or self.const_true) and (not y)
|
||||
return ret
|
||||
|
||||
|
||||
test_case_ops = [
|
||||
('CompareOpsNet', {
|
||||
'block': ComparisonOpsNet(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32),
|
||||
Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}),
|
||||
'desc_inputs': [Tensor(1.0, dtype=mstype.float32),
|
||||
Tensor(1.0, dtype=mstype.float32)]}),
|
||||
('MathOpsNet', {
|
||||
'block': MathOpsNet(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 9, 10]), dtype=mstype.float32),
|
||||
|
@ -126,8 +126,8 @@ test_case_ops = [
|
|||
Tensor(np.zeros([6, 9, 10]), dtype=mstype.float32)]}),
|
||||
('LogicalTensorOps', {
|
||||
'block': LogicalTensorOpsNet(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 9, 10]).astype(np.bool_), dtype=mstype.bool_),
|
||||
Tensor(np.zeros([6, 9, 10]).astype(np.bool_), dtype=mstype.bool_)]}),
|
||||
'desc_inputs': [Tensor(True, dtype=mstype.bool_),
|
||||
Tensor(False, dtype=mstype.bool_)]}),
|
||||
]
|
||||
|
||||
test_case_lists = [test_case_ops]
|
||||
|
|
|
@ -41,10 +41,12 @@ def vm_impl_tensor_add(self):
|
|||
# pylint: disable=used-before-assignment
|
||||
@vm_impl_getters.register(P.LogicalNot)
|
||||
def vm_impl_logical_not(self):
|
||||
x = x.asnumpy()
|
||||
out = vm.logical_not(x)
|
||||
return Tensor(out)
|
||||
def vm_impl(x):
|
||||
x = x.asnumpy()
|
||||
out = vm.logical_not(x)
|
||||
return Tensor(out)
|
||||
|
||||
return vm_impl
|
||||
|
||||
@vm_impl_getters.register(P.MatMul)
|
||||
def vm_impl_mat_mul(self):
|
||||
|
|
Loading…
Reference in New Issue