Link control arrow for first input of updatestate.

This commit is contained in:
gaoyong10 2022-09-26 17:28:15 +08:00
parent f58355e67f
commit 9b6d9ce967
1 changed files with 28 additions and 5 deletions

View File

@ -1535,6 +1535,24 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor,
}
namespace {
void FetchRealInputByNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *inputs) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(inputs);
if (!node->isa<CNode>()) {
return;
}
if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
const auto &cnode = node->cast<CNodePtr>();
for (const auto &input : cnode->inputs()) {
MS_EXCEPTION_IF_NULL(input);
FetchRealInputByNode(input, inputs);
}
} else {
(void)inputs->emplace_back(node);
}
}
void GetRealDependInputByUpdateState(const AnfNodePtr &update_state, std::vector<AnfNodePtr> *real_depend_inputs) {
MS_EXCEPTION_IF_NULL(update_state);
MS_EXCEPTION_IF_NULL(real_depend_inputs);
@ -1549,18 +1567,23 @@ void GetRealDependInputByUpdateState(const AnfNodePtr &update_state, std::vector
MS_EXCEPTION_IF_NULL(u_input);
bool is_u_input_valid = true;
std::vector<AnfNodePtr> real_inputs;
for (size_t i = kUpdateStateRealInput; i < inputs.size(); ++i) {
MS_EXCEPTION_IF_NULL(inputs[i]);
(void)real_depend_inputs->emplace_back(inputs[i]);
FetchRealInputByNode(inputs[i], &real_inputs);
}
for (size_t i = 0; i < real_inputs.size(); ++i) {
const auto &input = real_inputs[i];
MS_EXCEPTION_IF_NULL(input);
(void)real_depend_inputs->emplace_back(input);
// Check the u input of update state.
if (inputs[i]->isa<CNode>()) {
const auto &input_cnode = inputs[i]->cast<CNodePtr>();
if (input->isa<CNode>()) {
const auto &input_cnode = input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
if (std::find(input_cnode->inputs().begin(), input_cnode->inputs().end(), u_input) !=
input_cnode->inputs().end()) {
MS_LOG(DEBUG) << "U input node:" << u_input->DebugString() << " of update state:" << update_state->DebugString()
<< " is input of update state input node:" << inputs[i]->DebugString();
<< " is input of update state input node:" << input->DebugString();
is_u_input_valid = false;
}
}