!5913 add count of graphs using the parameter

Merge pull request !5913 from limingqi107/master
This commit is contained in:
mindspore-ci-bot 2020-09-09 14:15:51 +08:00 committed by Gitee
commit f480e48271
3 changed files with 17 additions and 2 deletions

View File

@ -469,6 +469,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
}
TraceManager::EndTrace();
}
new_parameter->IncreaseUsedGraphCount();
graph_inputs->push_back(new_parameter);
valid_inputs->push_back(true);
return new_parameter;
@ -812,6 +813,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
}
TraceManager::EndTrace();
}
new_parameter->IncreaseUsedGraphCount();
return new_parameter;
}

View File

@ -803,11 +803,18 @@ void KernelRuntime::ClearOutputAddress(const std::vector<AnfNodePtr> &inputs,
if (!input_node->isa<Parameter>()) {
continue;
}
auto parameter = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
parameter->DecreaseUsedGraphCount();
// Only the parameter has no graph used, then clear the output address.
if (parameter->used_graph_count() != 0) {
continue;
}
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(input_node); ++index) {
if (!AnfAlgo::OutputAddrExist(input_node, index)) {
continue;
}
AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
AnfAlgo::SetOutputAddr(nullptr, index, input_node.get());
}
}
// clear input value node output address.

View File

@ -282,7 +282,7 @@ class ANode : public AnfNode {
class Parameter : public ANode {
public:
explicit Parameter(const FuncGraphPtr &func_graph)
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {}
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), used_graph_count_(0) {}
~Parameter() override = default;
MS_DECLARE_PARENT(Parameter, ANode);
@ -300,6 +300,10 @@ class Parameter : public ANode {
ValuePtr default_param() const { return default_param_; }
ParamInfoPtr param_info() const;
void IncreaseUsedGraphCount() { used_graph_count_++; }
void DecreaseUsedGraphCount() { used_graph_count_--; }
int used_graph_count() const { return used_graph_count_; }
bool operator==(const AnfNode &other) const override {
if (!other.isa<Parameter>()) {
return false;
@ -315,6 +319,8 @@ class Parameter : public ANode {
std::string name_;
bool has_default_;
ValuePtr default_param_;
// The count of graphs using the parameter.
int used_graph_count_;
};
using ParameterPtr = std::shared_ptr<Parameter>;