!22352 Support 'no_eliminate' primitive

Merge pull request !22352 from hewei/no_eliminate
This commit is contained in:
i-robot 2021-08-26 01:08:02 +00:00 committed by Gitee
commit a030bd00a2
4 changed files with 68 additions and 27 deletions

View File

@ -56,10 +56,14 @@ static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &nod
// will be ignored even if func_call() has side effects.
return !var_name.empty() && var_name != "_";
}
// For primitive cnode, only those with side effects can be isolate nodes.
// Primitive cnode with side effects can be isolate nodes.
auto effect_info = GetPrimEffectInfo(prim);
bool has_effects = (effect_info.memory || effect_info.io);
return has_effects;
if (has_effects) {
return true;
}
// Primitive cnode with 'no_eliminate' flag can be isolate nodes.
return GetPrimitiveFlag(prim, ATTR_NO_ELIMINATE);
}
// Write variable records the variable name to corresponding node

View File

@ -1075,8 +1075,9 @@ class AutoMonadConverter {
~AutoMonadConverter() = default;
bool Run() {
// Handle cnodes if graph has side effects.
if (HasSideEffects()) {
// Handle cnodes for side effects.
const auto &info = func_graph_->GetEffectInfo();
if (info.state == EffectInfo::kDetected) {
HandleCNodes();
}
@ -1091,17 +1092,6 @@ class AutoMonadConverter {
// Check if there are side effects from effect info.
static bool HasSideEffects(const EffectInfo &info) { return (info.memory || info.io || info.load); }
// Check if current graph has side effects.
bool HasSideEffects() const {
const auto &info = func_graph_->GetEffectInfo();
if (info.state != EffectInfo::kDetected) {
// Effect info should have been set by SideEffectFinder, except unused graphs.
MS_LOG(INFO) << "No effect info for unused graph: " << func_graph_->ToString();
return false;
}
return HasSideEffects(info);
}
// Gets effect info for a cnode.
const EffectInfo &GetEffectInfo(const CNodePtr &cnode) const {
MS_EXCEPTION_IF_NULL(cnode);
@ -1137,15 +1127,60 @@ class AutoMonadConverter {
if (info.io) {
HandleIoEffects(cnode, update_state);
}
// If the node has no side effects but 'no_eliminate' flag is set,
// we save it to no_eliminate_nodes and handle them late.
if (!info.memory && !info.io && IsNoEliminateNode(cnode)) {
no_eliminate_nodes_.emplace_back(cnode);
}
}
cnode->SetEffectHandled(true);
}
// Insert Depend nodes for states if required.
// Attach no eliminate nodes to output.
HandleNoEliminateNodes();
// Attach monad to output if required.
if (update_state) {
InsertStateDepends();
AttachMonadToOutput();
}
}
// Return true if the given cnode is primitive cnode with 'no_eliminate' flag.
bool IsNoEliminateNode(const CNodePtr &cnode) {
if (cnode == nullptr || cnode->size() == 0) {
return false;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim == nullptr) {
return false;
}
return GetPrimitiveFlag(prim, ATTR_NO_ELIMINATE);
}
// Attach no eliminate nodes to output.
void HandleNoEliminateNodes() {
if (no_eliminate_nodes_.empty()) {
// Skip if no nodes to be handled.
return;
}
// If only one node, attach it to output directly.
if (no_eliminate_nodes_.size() == 1) {
AttachToOutput(no_eliminate_nodes_.front());
return;
}
// For multiple nodes, attach them to output by a tuple.
std::vector<AnfNodePtr> tuple_inputs;
AbstractBasePtrList element_abstracts;
tuple_inputs.reserve(no_eliminate_nodes_.size() + 1);
element_abstracts.reserve(no_eliminate_nodes_.size());
tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (auto &node : no_eliminate_nodes_) {
tuple_inputs.emplace_back(node);
element_abstracts.emplace_back(node->abstract());
}
auto make_tuple_node = func_graph_->NewCNode(tuple_inputs);
make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
AttachToOutput(make_tuple_node);
}
// Clean no side effect dependency nodes.
// From: output = Depend(output, StopGrad)
// return output
@ -1322,34 +1357,31 @@ class AutoMonadConverter {
manager_->AddEdge(cnode, monad);
}
void InsertStateDepends() const {
void AttachMonadToOutput() const {
if (u_) {
// Insert Depend node for UMonad,
// Gradient is required for memory side effects.
InsertStateDepend(u_);
AttachToOutput(u_);
}
if (io_) {
// No gradient required for IO operations.
InsertStateDepend(io_);
AttachToOutput(io_);
}
}
void InsertStateDepend(const AnfNodePtr &state) const {
void AttachToOutput(const AnfNodePtr &node) const {
auto output = GetGraphOutput();
auto depend = NewValueNode(prim::kPrimDepend);
// If isolated nodes dependencies exist.
if (IsPrimitiveCNode(output, prim::kPrimDepend) &&
IsPrimitiveCNode(output->cast<CNodePtr>()->input(2), prim::kPrimStopGradient)) {
// Insert state Depend node into isolated Depend node.
// Insert new Depend node before isolated Depend node.
auto isolated_depend = output->cast<CNodePtr>();
auto &orig_output = isolated_depend->input(1);
auto state_depend = func_graph_->NewCNode({depend, orig_output, state});
auto state_depend = func_graph_->NewCNode({depend, orig_output, node});
state_depend->set_abstract(orig_output->abstract());
manager_->SetEdge(isolated_depend, 1, state_depend);
return;
}
// Insert Depend node and set it as output, if no isolated nodes.
auto depend_cnode = func_graph_->NewCNode({depend, output, state});
auto depend_cnode = func_graph_->NewCNode({depend, output, node});
depend_cnode->set_abstract(output->abstract());
func_graph_->set_output(depend_cnode);
}
@ -1448,6 +1480,9 @@ class AutoMonadConverter {
// True if there are side effect cnodes within this func graph.
bool has_effect_cnodes_ = false;
// CNodes that should not be eliminated even it is isolated node.
std::vector<CNodePtr> no_eliminate_nodes_;
// Current memory state node, null if no memory side effects.
AnfNodePtr u_;

View File

@ -47,4 +47,5 @@ const char ATTR_MAX_SHAPE[] = "max_shape";
const char ATTR_MIN_VALUE[] = "min_value";
const char ATTR_MAX_VALUE[] = "max_value";
const char ATTR_NO_BROADEN[] = "no_broaden";
const char ATTR_NO_ELIMINATE[] = "no_eliminate";
} // namespace mindspore

View File

@ -42,6 +42,7 @@ extern const char ATTR_MAX_SHAPE[];
extern const char ATTR_MIN_VALUE[];
extern const char ATTR_MAX_VALUE[];
extern const char ATTR_NO_BROADEN[];
extern const char ATTR_NO_ELIMINATE[];
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_FLAGS_H