!18881 To handle Parameter only for top func graph.

Merge pull request !18881 from 张清华/opt
This commit is contained in:
i-robot 2021-06-26 14:38:17 +00:00 committed by Gitee
commit d6474c10ee
6 changed files with 91 additions and 46 deletions

View File

@ -280,11 +280,11 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
auto base_graph = cloner->cloned_func_graph()[fg];
MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString();
if (fg->used_global_parameters().empty() || graphs.size() <= 1 || fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
if (fg->paramter_obj_nodes().empty() || graphs.size() <= 1 || fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
continue;
}
auto &cloned_nodes = *cloner->cloned_node();
for (auto &fv : fg->used_global_parameters()) {
for (auto &fv : fg->paramter_obj_nodes()) {
TraceGuard guard(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
auto param = base_graph->add_parameter();
auto &node_users = res->manager()->node_users()[fv];
@ -298,10 +298,10 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
repl_n->set_input(IntToSize(n.second), param);
}
}
MS_LOG(DEBUG) << "Fg0 used_global_parameters size :" << fg->used_global_parameters().size();
MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size();
for (auto &g : graphs) {
auto &fvs = g->used_global_parameters();
auto &fvs = g->paramter_obj_nodes();
std::vector<AnfNodePtr> new_node_inputs;
new_node_inputs.push_back(NewValueNode(base_graph));
for (auto &p : g->parameters()) {

View File

@ -67,13 +67,13 @@ bool SymbolResolver::Resolve() {
}
namespace {
// if any mixed precision flag add a cast node after the parameter node.
// If any mixed precision flag add a cast node after the parameter node.
// argument obj should be python Parameter object
// it will be converted to Parameter node here
AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
MS_EXCEPTION_IF_NULL(func_graph);
// parameter object should not be none
// Parameter object should not be none
if (py::isinstance<py::none>(obj)) {
MS_LOG(EXCEPTION) << "Resolve class Parameter error because obj is null.";
}
@ -82,33 +82,38 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
MS_LOG(EXCEPTION) << "Resolve class Parameter error: cannot find name attr for obj";
}
// get the parameter name from parameter object
// Get the parameter name from parameter object
auto name_attr = python_adapter::GetPyObjAttr(obj, "name");
if (py::isinstance<py::none>(name_attr)) {
MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
}
auto param_name = py::cast<std::string>(name_attr);
auto top_graph = Parser::GetTopFuncGraph();
// if the parameter node has been created , return it
auto top_func_graph = Parser::GetTopFuncGraph();
// If the parameter node has been created , return it.
AnfNodePtr para_node = nullptr;
for (auto const &param : top_graph->parameters()) {
for (auto const &param : top_func_graph->parameters()) {
auto param_node = dyn_cast<Parameter>(param);
if (param_node != nullptr && param_node->name() == param_name) {
para_node = param;
MS_LOG(DEBUG) << "Found existing parameter for " << func_graph->ToString()
<< ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
break;
}
}
if (para_node == nullptr) {
auto node = top_graph->AddWeightParameter(param_name);
auto node = top_func_graph->AddWeightParameter(param_name);
auto value = py::cast<tensor::MetaTensorPtr>(obj);
node->set_default_param(value);
// set_abstract for parameter
// Set abstract for parameter
auto abs = value->ToAbstract();
node->set_abstract(abs);
para_node = node;
MS_LOG(DEBUG) << "Created a new weight parameter for " << func_graph->ToString()
<< ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
}
func_graph->add_used_global_parameters(para_node);
func_graph->add_parameter_obj_node(para_node);
return para_node;
}
@ -141,9 +146,9 @@ void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(param_ptr);
if (param_ptr->has_default()) {
param_ptr->set_func_graph(top_graph);
func_graph->add_used_global_parameters(param_ptr);
func_graph->add_parameter_obj_node(param_ptr);
// update top_graph
// Update top_graph
top_graph->add_parameter(param_ptr);
size_t hyper_param_count = top_graph->hyper_param_count();
top_graph->set_hyper_param_count(hyper_param_count + 1);
@ -164,7 +169,6 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
return false;
}
MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString();
output = param;
} else if (py::hasattr(obj, "__parameter_tuple__")) {
auto tuple = obj.cast<py::tuple>();
@ -335,7 +339,7 @@ opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &ir
opt::OptPassGroupMap map({
{"resolve",
{
// for resolve and getattr primitive;
// For resolve and getattr primitive;
irpass.resolver_resolve_and_getattr_,
}},
});
@ -371,7 +375,7 @@ bool ResolveAll(const FuncGraphManagerPtr &manager) {
"called from root graph, so it's not necessary to pass all graphs as roots. "
"Please ensure your usage.";
}
// should not use pipeline::Resource as Resource::Clean will clean some
// Should not use pipeline::Resource as Resource::Clean will clean some
// global variable such as ScopeManager, it will cause JExpandedGraphs::GetBprop
// fail as valid scope has been cleaned.
auto res = std::make_shared<pipeline::ResourceBase>();

View File

@ -1043,7 +1043,9 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
std::string name = refkey->tag();
MS_EXCEPTION_IF_NULL(node_conf->node());
MS_EXCEPTION_IF_NULL(node_conf->node()->func_graph());
if (node_conf->node()->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Should not evaluate a ValueNode, node: " << node_conf->node()->DebugString();
}
const auto &manager = node_conf->node()->func_graph()->manager();
auto node = FindParameterNodeByString(manager, name);
if (node == nullptr) {

View File

@ -56,6 +56,10 @@ FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisConte
MS_EXCEPTION_IF_NULL(fg);
MS_EXCEPTION_IF_NULL(context);
MS_LOG(DEBUG) << "Specialize topmost function graph: " << context->func_graph()->ToString();
if (top_context_ == nullptr) {
top_context_ = context;
MS_LOG(INFO) << "Specialize set top func graph context: " << context->ToString();
}
return SpecializeFuncGraph(fg, context);
}
@ -108,16 +112,11 @@ FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const Fu
AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
FuncGraphPtr fg = node->func_graph();
if (node->isa<ValueNode>()) {
return node;
}
std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
while (fg != nullptr && fg != specializer->func_graph_) {
specializer = specializer->parent_;
MS_EXCEPTION_IF_NULL(specializer);
}
std::shared_ptr<FuncGraphSpecializer> specializer = GetTopSpecializer(node);
// If had replicated, just return that.
auto iter = specializer->repl_node_->find(node);
if (iter != specializer->repl_node_->end()) {
@ -169,14 +168,7 @@ void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const An
}
AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
FuncGraphPtr fg = node->func_graph();
std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
while (fg != nullptr && fg != specializer->func_graph_) {
specializer = specializer->parent_;
}
std::shared_ptr<FuncGraphSpecializer> specializer = GetTopSpecializer(node);
MS_EXCEPTION_IF_NULL(specializer->repl_node_);
auto iter = specializer->repl_node_->find(node);
if (iter != specializer->repl_node_->end()) {
@ -185,6 +177,40 @@ AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
return node;
}
// Return itself if node's ValueNode as top,
// return the top func graph specializer as top if node's forward Parameter,
// or, return the top parent specializer as top.
std::shared_ptr<FuncGraphSpecializer> FuncGraphSpecializer::GetTopSpecializer(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
FuncGraphPtr fg = node->func_graph();
if (fg == nullptr) { // If ValueNode, return current specializer.
MS_LOG(DEBUG) << "Node's a ValueNode, node: " << node->DebugString();
return shared_from_this();
}
std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
while (fg != specializer->func_graph_) {
if (specializer->parent_ == nullptr && node->isa<Parameter>()) {
// If `parent_` is null and forwarded `node` is a Parameter, we'll try to use top func graph as parent.
MS_EXCEPTION_IF_NULL(specializer_->top_context());
if (specializer_->top_context()->func_graph() == fg) { // `fg` is top func graph.
specializer = specializer_->GetFuncGraphSpecializer(specializer_->top_context());
MS_LOG(INFO) << "Used top func graph specializer as parent for " << func_graph_->ToString()
<< ", node: " << node->DebugString() << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
MS_EXCEPTION_IF_NULL(specializer);
break;
}
} else {
specializer = specializer->parent_;
}
if (specializer == nullptr) {
MS_LOG(EXCEPTION) << "`specializer` should not be null, node: " << node->DebugString()
<< ", NodeInfo: " << trace::GetDebugInfo(node->debug_info()) << ".\n"
<< func_graph_->ToString() << " has no parent context? At least not " << fg->ToString();
}
}
return specializer;
}
void FuncGraphSpecializer::Run() {
MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString()
<< ", cloned func graph name: " << specialized_func_graph_->ToString()
@ -205,14 +231,24 @@ void FuncGraphSpecializer::FirstPass() {
continue;
}
if (node->func_graph() != func_graph_) {
if (parent_ == nullptr) {
MS_LOG(EXCEPTION) << "Parent must not null NodeInfo: " << trace::GetDebugInfo(node->debug_info());
std::shared_ptr<FuncGraphSpecializer> parent = nullptr;
if (parent_ != nullptr) {
parent = parent_;
} else if (specializer_->top_context()->func_graph() == node->func_graph() && node->isa<Parameter>()) {
// If `parent_` is null and forwarded `node` is a Parameter, we'll try to use top func graph as parent.
parent = specializer_->GetFuncGraphSpecializer(specializer_->top_context());
MS_LOG(INFO) << "Used top func graph specializer as parent for " << func_graph_->ToString()
<< ", node: " << node->DebugString() << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
}
parent_->AddTodoItem(node);
parent_->FirstPass();
AnfNodePtr new_node = parent_->GetReplicatedNode(node);
if (parent == nullptr) {
MS_LOG(EXCEPTION) << "Parent must not null, node: " << node->DebugString()
<< ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
}
parent->AddTodoItem(node);
parent->FirstPass();
AnfNodePtr new_node = parent->GetReplicatedNode(node);
if (node->isa<CNode>()) {
parent_->ProcessCNode(new_node->cast<CNodePtr>());
parent->ProcessCNode(new_node->cast<CNodePtr>());
}
continue;
}

View File

@ -45,7 +45,7 @@ class FuncGraphSpecializer;
// Specialize a func graph using analyzed abstract values.
class ProgramSpecializer {
public:
explicit ProgramSpecializer(const std::shared_ptr<AnalysisEngine> &engine) : engine_(engine) {
explicit ProgramSpecializer(const std::shared_ptr<AnalysisEngine> &engine) : engine_(engine), top_context_(nullptr) {
mng_ = engine_->func_graph_manager();
}
~ProgramSpecializer() = default;
@ -60,12 +60,15 @@ class ProgramSpecializer {
std::shared_ptr<AnalysisEngine> engine() { return engine_; }
AnalysisContextPtr top_context() { return top_context_; }
private:
std::shared_ptr<AnalysisEngine> engine_;
std::unordered_set<AnfNodePtr> seen_;
FuncGraphManagerPtr mng_;
std::unordered_map<AnalysisContextPtr, std::shared_ptr<FuncGraphSpecializer>, ContextHasher, ContextEqual>
specializations_;
AnalysisContextPtr top_context_;
};
class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> {
@ -78,6 +81,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
void Run();
FuncGraphPtr specialized_func_graph() { return specialized_func_graph_; }
std::shared_ptr<FuncGraphSpecializer> GetTopSpecializer(const AnfNodePtr &node);
private:
ProgramSpecializer *specializer_;
FuncGraphPtr func_graph_;

View File

@ -350,8 +350,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
const std::vector<AnfNodePtr> &specialized_parameter_list,
std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
const std::vector<AnfNodePtr> &used_global_parameters() const { return used_global_parameters_; }
void add_used_global_parameters(const AnfNodePtr &p) { used_global_parameters_.push_back(p); }
const std::vector<AnfNodePtr> &paramter_obj_nodes() const { return paramter_obj_nodes_; }
void add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); }
std::unordered_map<std::string, ValuePtr> attrs_;
std::vector<BaseShapePtr> joined_shapes_;
@ -428,9 +428,7 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
// Parameters of this function.
std::vector<AnfNodePtr> parameters_;
// Global parameters used by this function.
std::vector<AnfNodePtr> used_global_parameters_;
std::vector<AnfNodePtr> paramter_obj_nodes_;
// Whether there is a *args and **kwargs, and count kwonlyargs'number.
bool has_vararg_;