!5913 add count of graphs using the parameter
Merge pull request !5913 from limingqi107/master
This commit is contained in:
commit
f480e48271
|
@ -469,6 +469,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
|
|||
}
|
||||
TraceManager::EndTrace();
|
||||
}
|
||||
new_parameter->IncreaseUsedGraphCount();
|
||||
graph_inputs->push_back(new_parameter);
|
||||
valid_inputs->push_back(true);
|
||||
return new_parameter;
|
||||
|
@ -812,6 +813,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
|
|||
}
|
||||
TraceManager::EndTrace();
|
||||
}
|
||||
new_parameter->IncreaseUsedGraphCount();
|
||||
|
||||
return new_parameter;
|
||||
}
|
||||
|
|
|
@ -803,11 +803,18 @@ void KernelRuntime::ClearOutputAddress(const std::vector<AnfNodePtr> &inputs,
|
|||
if (!input_node->isa<Parameter>()) {
|
||||
continue;
|
||||
}
|
||||
auto parameter = input_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
parameter->DecreaseUsedGraphCount();
|
||||
// Only the parameter has no graph used, then clear the output address.
|
||||
if (parameter->used_graph_count() != 0) {
|
||||
continue;
|
||||
}
|
||||
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(input_node); ++index) {
|
||||
if (!AnfAlgo::OutputAddrExist(input_node, index)) {
|
||||
continue;
|
||||
}
|
||||
AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
|
||||
AnfAlgo::SetOutputAddr(nullptr, index, input_node.get());
|
||||
}
|
||||
}
|
||||
// clear input value node output address.
|
||||
|
|
|
@ -282,7 +282,7 @@ class ANode : public AnfNode {
|
|||
class Parameter : public ANode {
|
||||
public:
|
||||
explicit Parameter(const FuncGraphPtr &func_graph)
|
||||
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {}
|
||||
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), used_graph_count_(0) {}
|
||||
~Parameter() override = default;
|
||||
MS_DECLARE_PARENT(Parameter, ANode);
|
||||
|
||||
|
@ -300,6 +300,10 @@ class Parameter : public ANode {
|
|||
ValuePtr default_param() const { return default_param_; }
|
||||
ParamInfoPtr param_info() const;
|
||||
|
||||
void IncreaseUsedGraphCount() { used_graph_count_++; }
|
||||
void DecreaseUsedGraphCount() { used_graph_count_--; }
|
||||
int used_graph_count() const { return used_graph_count_; }
|
||||
|
||||
bool operator==(const AnfNode &other) const override {
|
||||
if (!other.isa<Parameter>()) {
|
||||
return false;
|
||||
|
@ -315,6 +319,8 @@ class Parameter : public ANode {
|
|||
std::string name_;
|
||||
bool has_default_;
|
||||
ValuePtr default_param_;
|
||||
// The count of graphs using the parameter.
|
||||
int used_graph_count_;
|
||||
};
|
||||
using ParameterPtr = std::shared_ptr<Parameter>;
|
||||
|
||||
|
|
Loading…
Reference in New Issue