forked from mindspore-Ecosystem/mindspore
!18881 To handle Parameter only for top func graph.
Merge pull request !18881 from 张清华/opt
This commit is contained in:
commit
d6474c10ee
|
@ -280,11 +280,11 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
|
|||
auto base_graph = cloner->cloned_func_graph()[fg];
|
||||
MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString();
|
||||
|
||||
if (fg->used_global_parameters().empty() || graphs.size() <= 1 || fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
|
||||
if (fg->paramter_obj_nodes().empty() || graphs.size() <= 1 || fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
|
||||
continue;
|
||||
}
|
||||
auto &cloned_nodes = *cloner->cloned_node();
|
||||
for (auto &fv : fg->used_global_parameters()) {
|
||||
for (auto &fv : fg->paramter_obj_nodes()) {
|
||||
TraceGuard guard(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
|
||||
auto param = base_graph->add_parameter();
|
||||
auto &node_users = res->manager()->node_users()[fv];
|
||||
|
@ -298,10 +298,10 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
|
|||
repl_n->set_input(IntToSize(n.second), param);
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Fg0 used_global_parameters size :" << fg->used_global_parameters().size();
|
||||
MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size();
|
||||
|
||||
for (auto &g : graphs) {
|
||||
auto &fvs = g->used_global_parameters();
|
||||
auto &fvs = g->paramter_obj_nodes();
|
||||
std::vector<AnfNodePtr> new_node_inputs;
|
||||
new_node_inputs.push_back(NewValueNode(base_graph));
|
||||
for (auto &p : g->parameters()) {
|
||||
|
|
|
@ -67,13 +67,13 @@ bool SymbolResolver::Resolve() {
|
|||
}
|
||||
|
||||
namespace {
|
||||
// if any mixed precision flag add a cast node after the parameter node.
|
||||
// If any mixed precision flag add a cast node after the parameter node.
|
||||
// argument obj should be python Parameter object
|
||||
// it will be converted to Parameter node here
|
||||
AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
// parameter object should not be none
|
||||
// Parameter object should not be none
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "Resolve class Parameter error because obj is null.";
|
||||
}
|
||||
|
@ -82,33 +82,38 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
|
|||
MS_LOG(EXCEPTION) << "Resolve class Parameter error: cannot find name attr for obj";
|
||||
}
|
||||
|
||||
// get the parameter name from parameter object
|
||||
// Get the parameter name from parameter object
|
||||
auto name_attr = python_adapter::GetPyObjAttr(obj, "name");
|
||||
if (py::isinstance<py::none>(name_attr)) {
|
||||
MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
|
||||
}
|
||||
|
||||
auto param_name = py::cast<std::string>(name_attr);
|
||||
auto top_graph = Parser::GetTopFuncGraph();
|
||||
// if the parameter node has been created , return it
|
||||
auto top_func_graph = Parser::GetTopFuncGraph();
|
||||
// If the parameter node has been created , return it.
|
||||
AnfNodePtr para_node = nullptr;
|
||||
for (auto const ¶m : top_graph->parameters()) {
|
||||
for (auto const ¶m : top_func_graph->parameters()) {
|
||||
auto param_node = dyn_cast<Parameter>(param);
|
||||
if (param_node != nullptr && param_node->name() == param_name) {
|
||||
para_node = param;
|
||||
MS_LOG(DEBUG) << "Found existing parameter for " << func_graph->ToString()
|
||||
<< ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (para_node == nullptr) {
|
||||
auto node = top_graph->AddWeightParameter(param_name);
|
||||
auto node = top_func_graph->AddWeightParameter(param_name);
|
||||
auto value = py::cast<tensor::MetaTensorPtr>(obj);
|
||||
node->set_default_param(value);
|
||||
// set_abstract for parameter
|
||||
// Set abstract for parameter
|
||||
auto abs = value->ToAbstract();
|
||||
node->set_abstract(abs);
|
||||
para_node = node;
|
||||
MS_LOG(DEBUG) << "Created a new weight parameter for " << func_graph->ToString()
|
||||
<< ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
|
||||
}
|
||||
func_graph->add_used_global_parameters(para_node);
|
||||
func_graph->add_parameter_obj_node(para_node);
|
||||
|
||||
return para_node;
|
||||
}
|
||||
|
||||
|
@ -141,9 +146,9 @@ void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
|
|||
MS_EXCEPTION_IF_NULL(param_ptr);
|
||||
if (param_ptr->has_default()) {
|
||||
param_ptr->set_func_graph(top_graph);
|
||||
func_graph->add_used_global_parameters(param_ptr);
|
||||
func_graph->add_parameter_obj_node(param_ptr);
|
||||
|
||||
// update top_graph
|
||||
// Update top_graph
|
||||
top_graph->add_parameter(param_ptr);
|
||||
size_t hyper_param_count = top_graph->hyper_param_count();
|
||||
top_graph->set_hyper_param_count(hyper_param_count + 1);
|
||||
|
@ -164,7 +169,6 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
|
|||
return false;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString();
|
||||
|
||||
output = param;
|
||||
} else if (py::hasattr(obj, "__parameter_tuple__")) {
|
||||
auto tuple = obj.cast<py::tuple>();
|
||||
|
@ -335,7 +339,7 @@ opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &ir
|
|||
opt::OptPassGroupMap map({
|
||||
{"resolve",
|
||||
{
|
||||
// for resolve and getattr primitive;
|
||||
// For resolve and getattr primitive;
|
||||
irpass.resolver_resolve_and_getattr_,
|
||||
}},
|
||||
});
|
||||
|
@ -371,7 +375,7 @@ bool ResolveAll(const FuncGraphManagerPtr &manager) {
|
|||
"called from root graph, so it's not necessary to pass all graphs as roots. "
|
||||
"Please ensure your usage.";
|
||||
}
|
||||
// should not use pipeline::Resource as Resource::Clean will clean some
|
||||
// Should not use pipeline::Resource as Resource::Clean will clean some
|
||||
// global variable such as ScopeManager, it will cause JExpandedGraphs::GetBprop
|
||||
// fail as valid scope has been cleaned.
|
||||
auto res = std::make_shared<pipeline::ResourceBase>();
|
||||
|
|
|
@ -1043,7 +1043,9 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
|
||||
std::string name = refkey->tag();
|
||||
MS_EXCEPTION_IF_NULL(node_conf->node());
|
||||
MS_EXCEPTION_IF_NULL(node_conf->node()->func_graph());
|
||||
if (node_conf->node()->func_graph() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Should not evaluate a ValueNode, node: " << node_conf->node()->DebugString();
|
||||
}
|
||||
const auto &manager = node_conf->node()->func_graph()->manager();
|
||||
auto node = FindParameterNodeByString(manager, name);
|
||||
if (node == nullptr) {
|
||||
|
|
|
@ -56,6 +56,10 @@ FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisConte
|
|||
MS_EXCEPTION_IF_NULL(fg);
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_LOG(DEBUG) << "Specialize topmost function graph: " << context->func_graph()->ToString();
|
||||
if (top_context_ == nullptr) {
|
||||
top_context_ = context;
|
||||
MS_LOG(INFO) << "Specialize set top func graph context: " << context->ToString();
|
||||
}
|
||||
return SpecializeFuncGraph(fg, context);
|
||||
}
|
||||
|
||||
|
@ -108,16 +112,11 @@ FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const Fu
|
|||
|
||||
AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
FuncGraphPtr fg = node->func_graph();
|
||||
|
||||
if (node->isa<ValueNode>()) {
|
||||
return node;
|
||||
}
|
||||
std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
|
||||
while (fg != nullptr && fg != specializer->func_graph_) {
|
||||
specializer = specializer->parent_;
|
||||
MS_EXCEPTION_IF_NULL(specializer);
|
||||
}
|
||||
std::shared_ptr<FuncGraphSpecializer> specializer = GetTopSpecializer(node);
|
||||
|
||||
// If had replicated, just return that.
|
||||
auto iter = specializer->repl_node_->find(node);
|
||||
if (iter != specializer->repl_node_->end()) {
|
||||
|
@ -169,14 +168,7 @@ void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const An
|
|||
}
|
||||
|
||||
AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
FuncGraphPtr fg = node->func_graph();
|
||||
|
||||
std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
|
||||
while (fg != nullptr && fg != specializer->func_graph_) {
|
||||
specializer = specializer->parent_;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraphSpecializer> specializer = GetTopSpecializer(node);
|
||||
MS_EXCEPTION_IF_NULL(specializer->repl_node_);
|
||||
auto iter = specializer->repl_node_->find(node);
|
||||
if (iter != specializer->repl_node_->end()) {
|
||||
|
@ -185,6 +177,40 @@ AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
|
|||
return node;
|
||||
}
|
||||
|
||||
// Return itself if node's ValueNode as top,
|
||||
// return the top func graph specializer as top if node's forward Parameter,
|
||||
// or, return the top parent specializer as top.
|
||||
std::shared_ptr<FuncGraphSpecializer> FuncGraphSpecializer::GetTopSpecializer(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
FuncGraphPtr fg = node->func_graph();
|
||||
if (fg == nullptr) { // If ValueNode, return current specializer.
|
||||
MS_LOG(DEBUG) << "Node's a ValueNode, node: " << node->DebugString();
|
||||
return shared_from_this();
|
||||
}
|
||||
std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
|
||||
while (fg != specializer->func_graph_) {
|
||||
if (specializer->parent_ == nullptr && node->isa<Parameter>()) {
|
||||
// If `parent_` is null and forwarded `node` is a Parameter, we'll try to use top func graph as parent.
|
||||
MS_EXCEPTION_IF_NULL(specializer_->top_context());
|
||||
if (specializer_->top_context()->func_graph() == fg) { // `fg` is top func graph.
|
||||
specializer = specializer_->GetFuncGraphSpecializer(specializer_->top_context());
|
||||
MS_LOG(INFO) << "Used top func graph specializer as parent for " << func_graph_->ToString()
|
||||
<< ", node: " << node->DebugString() << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||
MS_EXCEPTION_IF_NULL(specializer);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
specializer = specializer->parent_;
|
||||
}
|
||||
if (specializer == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "`specializer` should not be null, node: " << node->DebugString()
|
||||
<< ", NodeInfo: " << trace::GetDebugInfo(node->debug_info()) << ".\n"
|
||||
<< func_graph_->ToString() << " has no parent context? At least not " << fg->ToString();
|
||||
}
|
||||
}
|
||||
return specializer;
|
||||
}
|
||||
|
||||
void FuncGraphSpecializer::Run() {
|
||||
MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString()
|
||||
<< ", cloned func graph name: " << specialized_func_graph_->ToString()
|
||||
|
@ -205,14 +231,24 @@ void FuncGraphSpecializer::FirstPass() {
|
|||
continue;
|
||||
}
|
||||
if (node->func_graph() != func_graph_) {
|
||||
if (parent_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Parent must not null NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||
std::shared_ptr<FuncGraphSpecializer> parent = nullptr;
|
||||
if (parent_ != nullptr) {
|
||||
parent = parent_;
|
||||
} else if (specializer_->top_context()->func_graph() == node->func_graph() && node->isa<Parameter>()) {
|
||||
// If `parent_` is null and forwarded `node` is a Parameter, we'll try to use top func graph as parent.
|
||||
parent = specializer_->GetFuncGraphSpecializer(specializer_->top_context());
|
||||
MS_LOG(INFO) << "Used top func graph specializer as parent for " << func_graph_->ToString()
|
||||
<< ", node: " << node->DebugString() << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||
}
|
||||
parent_->AddTodoItem(node);
|
||||
parent_->FirstPass();
|
||||
AnfNodePtr new_node = parent_->GetReplicatedNode(node);
|
||||
if (parent == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Parent must not null, node: " << node->DebugString()
|
||||
<< ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||
}
|
||||
parent->AddTodoItem(node);
|
||||
parent->FirstPass();
|
||||
AnfNodePtr new_node = parent->GetReplicatedNode(node);
|
||||
if (node->isa<CNode>()) {
|
||||
parent_->ProcessCNode(new_node->cast<CNodePtr>());
|
||||
parent->ProcessCNode(new_node->cast<CNodePtr>());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ class FuncGraphSpecializer;
|
|||
// Specialize a func graph using analyzed abstract values.
|
||||
class ProgramSpecializer {
|
||||
public:
|
||||
explicit ProgramSpecializer(const std::shared_ptr<AnalysisEngine> &engine) : engine_(engine) {
|
||||
explicit ProgramSpecializer(const std::shared_ptr<AnalysisEngine> &engine) : engine_(engine), top_context_(nullptr) {
|
||||
mng_ = engine_->func_graph_manager();
|
||||
}
|
||||
~ProgramSpecializer() = default;
|
||||
|
@ -60,12 +60,15 @@ class ProgramSpecializer {
|
|||
|
||||
std::shared_ptr<AnalysisEngine> engine() { return engine_; }
|
||||
|
||||
AnalysisContextPtr top_context() { return top_context_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<AnalysisEngine> engine_;
|
||||
std::unordered_set<AnfNodePtr> seen_;
|
||||
FuncGraphManagerPtr mng_;
|
||||
std::unordered_map<AnalysisContextPtr, std::shared_ptr<FuncGraphSpecializer>, ContextHasher, ContextEqual>
|
||||
specializations_;
|
||||
AnalysisContextPtr top_context_;
|
||||
};
|
||||
|
||||
class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> {
|
||||
|
@ -78,6 +81,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
|
|||
void Run();
|
||||
FuncGraphPtr specialized_func_graph() { return specialized_func_graph_; }
|
||||
|
||||
std::shared_ptr<FuncGraphSpecializer> GetTopSpecializer(const AnfNodePtr &node);
|
||||
|
||||
private:
|
||||
ProgramSpecializer *specializer_;
|
||||
FuncGraphPtr func_graph_;
|
||||
|
|
|
@ -350,8 +350,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
|||
const std::vector<AnfNodePtr> &specialized_parameter_list,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
|
||||
|
||||
const std::vector<AnfNodePtr> &used_global_parameters() const { return used_global_parameters_; }
|
||||
void add_used_global_parameters(const AnfNodePtr &p) { used_global_parameters_.push_back(p); }
|
||||
const std::vector<AnfNodePtr> ¶mter_obj_nodes() const { return paramter_obj_nodes_; }
|
||||
void add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); }
|
||||
|
||||
std::unordered_map<std::string, ValuePtr> attrs_;
|
||||
std::vector<BaseShapePtr> joined_shapes_;
|
||||
|
@ -428,9 +428,7 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
|||
|
||||
// Parameters of this function.
|
||||
std::vector<AnfNodePtr> parameters_;
|
||||
|
||||
// Global parameters used by this function.
|
||||
std::vector<AnfNodePtr> used_global_parameters_;
|
||||
std::vector<AnfNodePtr> paramter_obj_nodes_;
|
||||
|
||||
// Whether there is a *args and **kwargs, and count kwonlyargs'number.
|
||||
bool has_vararg_;
|
||||
|
|
Loading…
Reference in New Issue