forked from mindspore-Ecosystem/mindspore
Link control arrow for first input of updatestate.
This commit is contained in:
parent
f58355e67f
commit
9b6d9ce967
|
@ -1535,6 +1535,24 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor,
|
|||
}
|
||||
|
||||
namespace {
|
||||
void FetchRealInputByNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *inputs) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(inputs);
|
||||
if (!node->isa<CNode>()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
for (const auto &input : cnode->inputs()) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
FetchRealInputByNode(input, inputs);
|
||||
}
|
||||
} else {
|
||||
(void)inputs->emplace_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
void GetRealDependInputByUpdateState(const AnfNodePtr &update_state, std::vector<AnfNodePtr> *real_depend_inputs) {
|
||||
MS_EXCEPTION_IF_NULL(update_state);
|
||||
MS_EXCEPTION_IF_NULL(real_depend_inputs);
|
||||
|
@ -1549,18 +1567,23 @@ void GetRealDependInputByUpdateState(const AnfNodePtr &update_state, std::vector
|
|||
MS_EXCEPTION_IF_NULL(u_input);
|
||||
|
||||
bool is_u_input_valid = true;
|
||||
std::vector<AnfNodePtr> real_inputs;
|
||||
for (size_t i = kUpdateStateRealInput; i < inputs.size(); ++i) {
|
||||
MS_EXCEPTION_IF_NULL(inputs[i]);
|
||||
(void)real_depend_inputs->emplace_back(inputs[i]);
|
||||
FetchRealInputByNode(inputs[i], &real_inputs);
|
||||
}
|
||||
for (size_t i = 0; i < real_inputs.size(); ++i) {
|
||||
const auto &input = real_inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
(void)real_depend_inputs->emplace_back(input);
|
||||
|
||||
// Check the u input of update state.
|
||||
if (inputs[i]->isa<CNode>()) {
|
||||
const auto &input_cnode = inputs[i]->cast<CNodePtr>();
|
||||
if (input->isa<CNode>()) {
|
||||
const auto &input_cnode = input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||
if (std::find(input_cnode->inputs().begin(), input_cnode->inputs().end(), u_input) !=
|
||||
input_cnode->inputs().end()) {
|
||||
MS_LOG(DEBUG) << "U input node:" << u_input->DebugString() << " of update state:" << update_state->DebugString()
|
||||
<< " is input of update state input node:" << inputs[i]->DebugString();
|
||||
<< " is input of update state input node:" << input->DebugString();
|
||||
is_u_input_valid = false;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue