forked from mindspore-Ecosystem/mindspore
Bypass specializing the FuncGraph which is a input of a Partial CNode and it been specialized in ProcessCNode in FirstPass
This commit is contained in:
parent
377d6108f9
commit
784d4b4315
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -58,7 +58,18 @@ const StringImmPtr kDeadNode = std::make_shared<StringImm>(kDeadNodeName);
|
|||
const StringImmPtr kPolyNode = std::make_shared<StringImm>(kPolyNodeName);
|
||||
|
||||
inline bool CanSpecializeValueNode(const AnfNodePtr &node) {
|
||||
if (IsValueNode<FuncGraph>(node) || IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
|
||||
if (IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
|
||||
return true;
|
||||
}
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
if (node->abstract() != nullptr) {
|
||||
auto abs_func = node->abstract()->cast<FuncGraphAbstractClosurePtr>();
|
||||
// If this funcgraph had specialized in ProcessCNode of FirstPass,
|
||||
// then ignore it.
|
||||
if (abs_func != nullptr && abs_func->specialized()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -410,6 +421,8 @@ void FuncGraphSpecializer::FirstPass() {
|
|||
parent->FirstPass();
|
||||
AnfNodePtr new_node = parent->GetReplicatedNode(node);
|
||||
if (new_node->isa<CNode>()) {
|
||||
MS_LOG(INFO) << "ProcessCNode in FirstPass for " << func_graph_->ToString() << ", node: " << node->DebugString()
|
||||
<< ", new_node: " << new_node->DebugString();
|
||||
parent->ProcessCNode(new_node->cast<CNodePtr>());
|
||||
}
|
||||
continue;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -95,3 +95,35 @@ def test_single_if_01():
|
|||
expect1 = Tensor(26, mstype.int32)
|
||||
expect2 = (Tensor(2, mstype.int32), Tensor(2, mstype.int32))
|
||||
control_flow_single_if(SingleIfNet1, x, y, expect1, expect2)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_single_if_any():
|
||||
"""
|
||||
Feature: compile and run control flow with if statement
|
||||
Description: true-branch func graph refer a CNode in construct as free variable.
|
||||
That CNode and the inputs will be specialized before ProcessCNode
|
||||
of true-branch func graph, so it's no need to specialize the inputs
|
||||
of that CNode again if it's a specialized func graph.
|
||||
Expectation: success
|
||||
"""
|
||||
x = Tensor([True, True, False])
|
||||
y = Tensor([False])
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, input1, input2):
|
||||
super().__init__()
|
||||
self.input1 = input1
|
||||
self.input2 = input2
|
||||
|
||||
def construct(self):
|
||||
if self.input1.all() == self.input2:
|
||||
return self.input1.any()
|
||||
return self.input2
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net(x, y)
|
||||
output = net()
|
||||
assert output
|
||||
|
|
Loading…
Reference in New Issue