!22352 Support 'no_eliminate' primitive
Merge pull request !22352 from hewei/no_eliminate
This commit is contained in:
commit
a030bd00a2
|
@ -56,10 +56,14 @@ static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &nod
|
|||
// will be ignored even if func_call() has side effects.
|
||||
return !var_name.empty() && var_name != "_";
|
||||
}
|
||||
// For primitive cnode, only those with side effects can be isolate nodes.
|
||||
// Primitive cnode with side effects can be isolate nodes.
|
||||
auto effect_info = GetPrimEffectInfo(prim);
|
||||
bool has_effects = (effect_info.memory || effect_info.io);
|
||||
return has_effects;
|
||||
if (has_effects) {
|
||||
return true;
|
||||
}
|
||||
// Primitive cnode with 'no_eliminate' flag can be isolate nodes.
|
||||
return GetPrimitiveFlag(prim, ATTR_NO_ELIMINATE);
|
||||
}
|
||||
|
||||
// Write variable records the variable name to corresponding node
|
||||
|
|
|
@ -1075,8 +1075,9 @@ class AutoMonadConverter {
|
|||
~AutoMonadConverter() = default;
|
||||
|
||||
bool Run() {
|
||||
// Handle cnodes if graph has side effects.
|
||||
if (HasSideEffects()) {
|
||||
// Handle cnodes for side effects.
|
||||
const auto &info = func_graph_->GetEffectInfo();
|
||||
if (info.state == EffectInfo::kDetected) {
|
||||
HandleCNodes();
|
||||
}
|
||||
|
||||
|
@ -1091,17 +1092,6 @@ class AutoMonadConverter {
|
|||
// Check if there are side effects from effect info.
|
||||
static bool HasSideEffects(const EffectInfo &info) { return (info.memory || info.io || info.load); }
|
||||
|
||||
// Check if current graph has side effects.
|
||||
bool HasSideEffects() const {
|
||||
const auto &info = func_graph_->GetEffectInfo();
|
||||
if (info.state != EffectInfo::kDetected) {
|
||||
// Effect info should have been set by SideEffectFinder, except unused graphs.
|
||||
MS_LOG(INFO) << "No effect info for unused graph: " << func_graph_->ToString();
|
||||
return false;
|
||||
}
|
||||
return HasSideEffects(info);
|
||||
}
|
||||
|
||||
// Gets effect info for a cnode.
|
||||
const EffectInfo &GetEffectInfo(const CNodePtr &cnode) const {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
@ -1137,15 +1127,60 @@ class AutoMonadConverter {
|
|||
if (info.io) {
|
||||
HandleIoEffects(cnode, update_state);
|
||||
}
|
||||
// If the node has no side effects but 'no_eliminate' flag is set,
|
||||
// we save it to no_eliminate_nodes and handle them late.
|
||||
if (!info.memory && !info.io && IsNoEliminateNode(cnode)) {
|
||||
no_eliminate_nodes_.emplace_back(cnode);
|
||||
}
|
||||
}
|
||||
cnode->SetEffectHandled(true);
|
||||
}
|
||||
// Insert Depend nodes for states if required.
|
||||
// Attach no eliminate nodes to output.
|
||||
HandleNoEliminateNodes();
|
||||
// Attach monad to output if required.
|
||||
if (update_state) {
|
||||
InsertStateDepends();
|
||||
AttachMonadToOutput();
|
||||
}
|
||||
}
|
||||
|
||||
// Return true if the given cnode is primitive cnode with 'no_eliminate' flag.
|
||||
bool IsNoEliminateNode(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr || cnode->size() == 0) {
|
||||
return false;
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return GetPrimitiveFlag(prim, ATTR_NO_ELIMINATE);
|
||||
}
|
||||
|
||||
// Attach no eliminate nodes to output.
|
||||
void HandleNoEliminateNodes() {
|
||||
if (no_eliminate_nodes_.empty()) {
|
||||
// Skip if no nodes to be handled.
|
||||
return;
|
||||
}
|
||||
// If only one node, attach it to output directly.
|
||||
if (no_eliminate_nodes_.size() == 1) {
|
||||
AttachToOutput(no_eliminate_nodes_.front());
|
||||
return;
|
||||
}
|
||||
// For multiple nodes, attach them to output by a tuple.
|
||||
std::vector<AnfNodePtr> tuple_inputs;
|
||||
AbstractBasePtrList element_abstracts;
|
||||
tuple_inputs.reserve(no_eliminate_nodes_.size() + 1);
|
||||
element_abstracts.reserve(no_eliminate_nodes_.size());
|
||||
tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
for (auto &node : no_eliminate_nodes_) {
|
||||
tuple_inputs.emplace_back(node);
|
||||
element_abstracts.emplace_back(node->abstract());
|
||||
}
|
||||
auto make_tuple_node = func_graph_->NewCNode(tuple_inputs);
|
||||
make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
|
||||
AttachToOutput(make_tuple_node);
|
||||
}
|
||||
|
||||
// Clean no side effect dependency nodes.
|
||||
// From: output = Depend(output, StopGrad)
|
||||
// return output
|
||||
|
@ -1322,34 +1357,31 @@ class AutoMonadConverter {
|
|||
manager_->AddEdge(cnode, monad);
|
||||
}
|
||||
|
||||
void InsertStateDepends() const {
|
||||
void AttachMonadToOutput() const {
|
||||
if (u_) {
|
||||
// Insert Depend node for UMonad,
|
||||
// Gradient is required for memory side effects.
|
||||
InsertStateDepend(u_);
|
||||
AttachToOutput(u_);
|
||||
}
|
||||
if (io_) {
|
||||
// No gradient required for IO operations.
|
||||
InsertStateDepend(io_);
|
||||
AttachToOutput(io_);
|
||||
}
|
||||
}
|
||||
|
||||
void InsertStateDepend(const AnfNodePtr &state) const {
|
||||
void AttachToOutput(const AnfNodePtr &node) const {
|
||||
auto output = GetGraphOutput();
|
||||
auto depend = NewValueNode(prim::kPrimDepend);
|
||||
// If isolated nodes dependencies exist.
|
||||
if (IsPrimitiveCNode(output, prim::kPrimDepend) &&
|
||||
IsPrimitiveCNode(output->cast<CNodePtr>()->input(2), prim::kPrimStopGradient)) {
|
||||
// Insert state Depend node into isolated Depend node.
|
||||
// Insert new Depend node before isolated Depend node.
|
||||
auto isolated_depend = output->cast<CNodePtr>();
|
||||
auto &orig_output = isolated_depend->input(1);
|
||||
auto state_depend = func_graph_->NewCNode({depend, orig_output, state});
|
||||
auto state_depend = func_graph_->NewCNode({depend, orig_output, node});
|
||||
state_depend->set_abstract(orig_output->abstract());
|
||||
manager_->SetEdge(isolated_depend, 1, state_depend);
|
||||
return;
|
||||
}
|
||||
// Insert Depend node and set it as output, if no isolated nodes.
|
||||
auto depend_cnode = func_graph_->NewCNode({depend, output, state});
|
||||
auto depend_cnode = func_graph_->NewCNode({depend, output, node});
|
||||
depend_cnode->set_abstract(output->abstract());
|
||||
func_graph_->set_output(depend_cnode);
|
||||
}
|
||||
|
@ -1448,6 +1480,9 @@ class AutoMonadConverter {
|
|||
// True if there are side effect cnodes within this func graph.
|
||||
bool has_effect_cnodes_ = false;
|
||||
|
||||
// CNodes that should not be eliminated even it is isolated node.
|
||||
std::vector<CNodePtr> no_eliminate_nodes_;
|
||||
|
||||
// Current memory state node, null if no memory side effects.
|
||||
AnfNodePtr u_;
|
||||
|
||||
|
|
|
@ -47,4 +47,5 @@ const char ATTR_MAX_SHAPE[] = "max_shape";
|
|||
const char ATTR_MIN_VALUE[] = "min_value";
|
||||
const char ATTR_MAX_VALUE[] = "max_value";
|
||||
const char ATTR_NO_BROADEN[] = "no_broaden";
|
||||
const char ATTR_NO_ELIMINATE[] = "no_eliminate";
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -42,6 +42,7 @@ extern const char ATTR_MAX_SHAPE[];
|
|||
extern const char ATTR_MIN_VALUE[];
|
||||
extern const char ATTR_MAX_VALUE[];
|
||||
extern const char ATTR_NO_BROADEN[];
|
||||
extern const char ATTR_NO_ELIMINATE[];
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_UTILS_FLAGS_H
|
||||
|
|
Loading…
Reference in New Issue