forked from mindspore-Ecosystem/mindspore
!2861 use addparam to replace setparam to reduce overhead
Merge pull request !2861 from xychow/optimize-setparam-to-addparam
This commit is contained in:
commit
056f9f6dc1
|
@ -68,9 +68,7 @@ ParameterPtr FuncGraph::add_parameter() {
|
|||
|
||||
void FuncGraph::add_parameter(const ParameterPtr &p) {
|
||||
if (manager_.lock()) {
|
||||
std::vector<AnfNodePtr> new_params = parameters_;
|
||||
new_params.push_back(p);
|
||||
manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params);
|
||||
manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
|
||||
} else {
|
||||
parameters_.push_back(p);
|
||||
}
|
||||
|
@ -82,12 +80,8 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
|
|||
p->set_name(name);
|
||||
p->debug_info()->set_name(name);
|
||||
|
||||
std::vector<AnfNodePtr> new_params = parameters_;
|
||||
// append parameter
|
||||
new_params.push_back(p);
|
||||
|
||||
if (manager_.lock()) {
|
||||
manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params);
|
||||
manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
|
||||
} else {
|
||||
parameters_.push_back(p);
|
||||
}
|
||||
|
|
|
@ -158,6 +158,7 @@ class FuncGraph : public FuncGraphBase {
|
|||
const std::vector<AnfNodePtr> ¶meters() const { return parameters_; }
|
||||
virtual ParameterPtr add_parameter();
|
||||
void add_parameter(const ParameterPtr &p);
|
||||
void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); }
|
||||
void set_parameters(const std::vector<AnfNodePtr> ¶ms) { parameters_ = params; }
|
||||
// add a weight parameter with specific name
|
||||
ParameterPtr AddWeightParameter(const std::string &name);
|
||||
|
|
|
@ -420,6 +420,12 @@ void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<A
|
|||
tr.Commit();
|
||||
}
|
||||
|
||||
void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter) {
|
||||
auto tr = Transact();
|
||||
tr.AddParameter(fg, parameter);
|
||||
tr.Commit();
|
||||
}
|
||||
|
||||
bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
|
||||
auto tr = Transact();
|
||||
bool success = tr.Replace(old_node, new_node);
|
||||
|
@ -532,25 +538,37 @@ void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupl
|
|||
for (auto &iter : changes) {
|
||||
auto operation = iter.op;
|
||||
auto args = iter.args;
|
||||
if (operation == Change::kTxSetEdge) {
|
||||
auto edge = args.cast<ArgsOfSetEdge>();
|
||||
auto old_node = edge.root_node->input(edge.index);
|
||||
(*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1;
|
||||
(*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1;
|
||||
(*rms)[old_node] += 1;
|
||||
(*adds)[edge.new_node] += 1;
|
||||
edge.root_node->set_input(edge.index, edge.new_node);
|
||||
} else if (operation == Change::kTxSetParams) {
|
||||
auto param = args.cast<ArgsOfSetParams>();
|
||||
MS_EXCEPTION_IF_NULL(param.func_graph);
|
||||
auto old_parameters = param.func_graph->parameters();
|
||||
for (auto &p : param.params) {
|
||||
(*adds)[p] += 1;
|
||||
}
|
||||
for (auto &p : old_parameters) {
|
||||
(*rms)[p] += 1;
|
||||
}
|
||||
param.func_graph->set_parameters(param.params);
|
||||
switch (operation) {
|
||||
case Change::kTxSetEdge: {
|
||||
auto edge = args.cast<ArgsOfSetEdge>();
|
||||
auto old_node = edge.root_node->input(edge.index);
|
||||
(*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1;
|
||||
(*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1;
|
||||
(*rms)[old_node] += 1;
|
||||
(*adds)[edge.new_node] += 1;
|
||||
edge.root_node->set_input(edge.index, edge.new_node);
|
||||
} break;
|
||||
case Change::kTxSetParams: {
|
||||
auto param = args.cast<ArgsOfSetParams>();
|
||||
MS_EXCEPTION_IF_NULL(param.func_graph);
|
||||
auto old_parameters = param.func_graph->parameters();
|
||||
for (auto &p : param.params) {
|
||||
(*adds)[p] += 1;
|
||||
}
|
||||
for (auto &p : old_parameters) {
|
||||
(*rms)[p] += 1;
|
||||
}
|
||||
param.func_graph->set_parameters(param.params);
|
||||
} break;
|
||||
case Change::kTxAddParam: {
|
||||
auto param = args.cast<ArgsOfAddParam>();
|
||||
MS_EXCEPTION_IF_NULL(param.func_graph);
|
||||
(*adds)[param.param] += 1;
|
||||
auto param_node = param.param->cast<ParameterPtr>();
|
||||
param.func_graph->append_parameter(param_node);
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -599,6 +617,10 @@ void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfN
|
|||
changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params});
|
||||
}
|
||||
|
||||
void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m) {
|
||||
changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param});
|
||||
}
|
||||
|
||||
bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
|
||||
MS_EXCEPTION_IF_NULL(old_node);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
|
|
|
@ -310,6 +310,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
|
|||
void KeepRoots(const std::vector<FuncGraphPtr> &roots = {});
|
||||
void RemoveRoots();
|
||||
void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters);
|
||||
void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter);
|
||||
void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
|
||||
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
|
||||
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
|
||||
|
@ -400,6 +401,7 @@ class FuncGraphTransaction {
|
|||
|
||||
// set parameters of a func graph
|
||||
void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms);
|
||||
void AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m);
|
||||
|
||||
// replace old_node with new_node
|
||||
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
|
||||
|
@ -427,6 +429,18 @@ struct ArgsOfSetParams {
|
|||
}
|
||||
};
|
||||
|
||||
// args for add param
|
||||
struct ArgsOfAddParam {
|
||||
FuncGraphPtr func_graph;
|
||||
AnfNodePtr param;
|
||||
bool operator==(const ArgsOfAddParam &other) const { return &other == this; }
|
||||
|
||||
friend std::ostream &operator<<(std::ostream &os, const ArgsOfAddParam &) {
|
||||
os << "[ArgsOfAddParam]";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
// args for set edge
|
||||
struct ArgsOfSetEdge {
|
||||
CNodePtr root_node;
|
||||
|
@ -441,7 +455,7 @@ struct ArgsOfSetEdge {
|
|||
};
|
||||
|
||||
struct Change {
|
||||
enum OpName { kTxSetParams, kTxSetEdge };
|
||||
enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam };
|
||||
OpName op;
|
||||
Any args;
|
||||
Change(OpName name, const Any ¶) : op(name), args(para) {}
|
||||
|
|
Loading…
Reference in New Issue