!17167 Refresh prim_abs_list_ key, avoid memory leak

From: @zhangzhaoju
Reviewed-by: @ginfung, @zh_qh
Signed-off-by: @zh_qh
Committed-by: mindspore-ci-bot, 2021-06-04 11:29:40 +08:00 (via Gitee)
commit a0b14b52e4
3 changed files with 100 additions and 36 deletions
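What changed: ForwardExecutor's prim_abs_list_ cache was keyed by PrimitivePtr, so every entry stored a shared_ptr that kept its Primitive (and everything reachable through its attrs) alive for as long as the cache existed, which appears to be the leak the title refers to. The key is now a value type, AbsCacheKey, that copies only the primitive's name, hash value, and attrs; it lives in the new pynative_abs_cache.h. The remaining hunks are mechanical fallout: PrimAbsInfo and AbstractListMap move into that header, and the local ResourcePtr and GradOperationPtr aliases are dropped, so call sites now spell out pipeline::ResourcePtr and prim::GradOperationPtr.

A minimal standalone illustration of the pinning behavior (stand-in Big type and default std::hash pointer keying, not MindSpore code):

#include <memory>
#include <unordered_map>

// Stand-in for a Primitive that owns sizeable state.
struct Big { char buf[1 << 20]; };

int main() {
  // Pointer-keyed cache, as prim_abs_list_ was before this commit.
  std::unordered_map<std::shared_ptr<Big>, int> cache;
  {
    auto p = std::make_shared<Big>();
    cache[p] = 0;
  }  // p is gone, but the map's key still owns the Big,
     // so the megabyte stays allocated until the cache itself is destroyed.
  return 0;
}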

mindspore/ccsrc/pipeline/pynative/pynative_abs_cache.h (new file)

@@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_ABS_CACHE_H
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_ABS_CACHE_H
#include <algorithm>
#include <string>
#include <utility>
#include <unordered_map>
#include "ir/anf.h"
namespace mindspore::pynative {
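// Cache key that copies the primitive's name, hash value, and attrs by value
// instead of holding a PrimitivePtr that would keep the Primitive alive.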
struct AbsCacheKey {
std::string prim_name_;
size_t prim_hash_value_;
std::unordered_map<std::string, ValuePtr> prim_attrs_;
};
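// Hash with the hash value the primitive has already computed.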
struct AbsCacheKeyHasher {
size_t operator()(const AbsCacheKey &key) const { return key.prim_hash_value_; }
};
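// Keys are equal when the primitive name and every attr match by value;
// identical ValuePtr pointers short-circuit the deep comparison.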
struct AbsCacheKeyEqual {
bool operator()(const AbsCacheKey &lk, const AbsCacheKey &rk) const {
if (lk.prim_attrs_.size() != rk.prim_attrs_.size()) {
return false;
}
if (lk.prim_name_ != rk.prim_name_) {
return false;
}
auto all = std::all_of(lk.prim_attrs_.begin(), lk.prim_attrs_.end(),
[&rk](const auto &item) -> bool {
auto iter = rk.prim_attrs_.find(item.first);
if (iter == rk.prim_attrs_.end()) {
return false;
}
if (item.second == iter->second) {
return true;
}
MS_EXCEPTION_IF_NULL(item.second);
MS_EXCEPTION_IF_NULL(iter->second);
return *item.second == *iter->second;
});
return all;
}
};
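// Cached inference result: the output abstract plus the attrs the evaluator added.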
struct PrimAbsInfo {
abstract::AbstractBasePtr abs;
bool is_dynamic_shape = false;
std::unordered_map<std::string, ValuePtr> attrs;
};
using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
using PrimAbsCache = std::unordered_map<AbsCacheKey, AbstractListMap, AbsCacheKeyHasher, AbsCacheKeyEqual>;
} // namespace mindspore::pynative
#endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_ABS_CACHE_H
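
A standalone sketch of how the new key behaves, with simplified stand-in types (Key/KeyHasher/KeyEqual mirror AbsCacheKey/AbsCacheKeyHasher/AbsCacheKeyEqual, and std::string stands in for ValuePtr; not the real MindSpore types): two independently built keys with equal name, hash, and attrs hit the same entry, and the map stores only copied values, so it cannot pin a Primitive object.

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Simplified analogue of AbsCacheKey: plain values, no pointer to a primitive.
struct Key {
  std::string name;
  std::size_t hash;
  std::unordered_map<std::string, std::string> attrs;  // stand-in for attr ValuePtrs
};
struct KeyHasher {
  std::size_t operator()(const Key &k) const { return k.hash; }  // reuse precomputed hash
};
struct KeyEqual {
  bool operator()(const Key &l, const Key &r) const {
    return l.name == r.name && l.attrs == r.attrs;  // value equality, like AbsCacheKeyEqual
  }
};

int main() {
  std::unordered_map<Key, int, KeyHasher, KeyEqual> cache;
  Key a{"Add", std::hash<std::string>{}("Add"), {{"T", "float32"}}};
  Key b{"Add", std::hash<std::string>{}("Add"), {{"T", "float32"}}};  // built independently
  cache[a] = 1;
  std::cout << cache.count(b) << '\n';  // prints 1: equal contents hit the same entry
}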

mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

@@ -883,7 +883,8 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
auto op_name = op_exec_info->op_name;
auto prim = op_exec_info->py_primitive;
- auto temp = prim_abs_list_.find(prim);
+ AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
+ auto temp = prim_abs_list_.find(key);
if (temp != prim_abs_list_.end()) {
MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
auto iter = temp->second.find(args_spec_list);
@@ -931,7 +932,8 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
// Add output abstract info into cache; the const value needs to be inferred every step
if (!prim_cache_hit && !op_exec_info->is_dynamic_shape) {
- auto &out = prim_abs_list_[prim];
+ AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
+ auto &out = prim_abs_list_[key];
out[args_spec_list].abs = op_exec_info->abstract;
out[args_spec_list].attrs = prim->evaluate_added_attrs();
}
@@ -1475,7 +1477,7 @@ void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_e
}
}
- void GradExecutor::SaveForwardTensorInfoInBpropGraph(const ResourcePtr &resource) {
+ void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
// Get all tensor ids that belong to forward ops
std::unordered_set<std::string> forward_op_tensor_id;
@@ -1851,7 +1853,7 @@ FuncGraphPtr GradExecutor::GetDfbuilder(const std::string &cell_id) {
return nullptr;
}
- ResourcePtr GradExecutor::GetResource(const std::string &cell_id) {
+ pipeline::ResourcePtr GradExecutor::GetResource(const std::string &cell_id) {
// If a top graph holds it
for (auto it = top_cell_list_.rbegin(); it != top_cell_list_.rend(); ++it) {
if (cell_id.find((*it)->cell_id()) != std::string::npos) {
@@ -2210,7 +2212,7 @@ std::string GradExecutor::GetGradCellId(bool has_sens, const py::object &cell, c
return cell_id;
}
- void GradExecutor::GradNetInner(py::object *ret, const GradOperationPtr &grad, const py::object &cell,
+ void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
const py::object &weights, const py::args &args) {
MS_EXCEPTION_IF_NULL(grad);
auto size = args.size();
@@ -2335,7 +2337,7 @@ abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::args &args, co
return args_spec;
}
- FuncGraphPtr GradExecutor::GetBpropGraph(const GradOperationPtr &grad, const py::object &cell,
+ FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell,
const std::vector<AnfNodePtr> &weights, size_t arg_size,
const py::args &args) {
bool build_formal_param = false;
@@ -2361,7 +2363,7 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const GradOperationPtr &grad, const py:
(void)GetArgsSpec(args, bprop_graph);
// Do opt for final bprop graph
- ResourcePtr resource = std::make_shared<pipeline::Resource>();
+ pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
resource->set_func_graph(bprop_graph);
auto manager = resource->manager();
MS_EXCEPTION_IF_NULL(manager);
@@ -2512,7 +2514,7 @@ void GradExecutor::SwitchTopcell() {
}
void GradExecutor::MakeNestedCnode(const py::object &cell, const std::string &cell_id, const py::args &forward_args,
- const ResourcePtr &resource, const py::object &out) {
+ const pipeline::ResourcePtr &resource, const py::object &out) {
if (cell_stack_.empty()) {
MS_LOG(DEBUG) << "No nested grad find";
return;
@@ -2528,7 +2530,7 @@ void GradExecutor::MakeNestedCnode(const py::object &cell, const std::string &ce
DumpGraphIR("first_grad_fg.ir", first_grad_fg);
auto out_id = GetId(out);
- ResourcePtr r = std::make_shared<pipeline::Resource>();
+ pipeline::ResourcePtr r = std::make_shared<pipeline::Resource>();
r->manager()->AddFuncGraph(first_grad_fg);
FuncGraphPtr second_grad_fg = ad::Grad(first_grad_fg, r);
DumpGraphIR("second_grad_fg.ir", second_grad_fg);
@@ -2628,7 +2630,7 @@ void GradExecutor::GradMsFunction(const py::object &out, const py::args &args) {
if (iter == ms_function_grad_cache.end()) {
ms_func_graph = BasicClone(executor->GetFuncGraph(graph_phase()));
MS_EXCEPTION_IF_NULL(ms_func_graph);
- ResourcePtr res = std::make_shared<pipeline::Resource>();
+ pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
res->set_func_graph(ms_func_graph);
res->manager()->AddFuncGraph(ms_func_graph, true);
fprop_g = ad::Grad(ms_func_graph, res, true);
@@ -2751,7 +2753,7 @@ void PynativeExecutor::GradMsFunction(const py::object &out, const py::args &arg
grad_executor()->GradMsFunction(out, args);
}
- void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
+ void PynativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args) {
py::object *ret = nullptr;
PynativeExecutorTry(grad_executor()->GradGraph, ret, grad, cell, weights, args);

mindspore/ccsrc/pipeline/pynative/pynative_execute.h

@@ -38,21 +38,11 @@
#include "pipeline/jit/resource.h"
#include "frontend/optimizer/ad/kpynative.h"
#include "frontend/operator/composite/composite.h"
#include "pipeline/pynative/pynative_abs_cache.h"
namespace mindspore::pynative {
namespace py = pybind11;
- using cell_id = std::string;
- using ResourcePtr = std::shared_ptr<pipeline::Resource>;
- using GradOperationPtr = std::shared_ptr<prim::GradOperation>;
- struct PrimAbsInfo {
- abstract::AbstractBasePtr abs;
- bool is_dynamic_shape = false;
- std::unordered_map<std::string, ValuePtr> attrs;
- };
- using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
- abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
+ using CellId = std::string;
using MsFunctionGradCache = std::unordered_map<std::string, std::pair<FuncGraphPtr, FuncGraphPtr>>;
using OpInfoWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
using TensorIdWithTensorObject = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;
@@ -73,7 +63,7 @@ class TopCellInfo {
public:
TopCellInfo() = default;
~TopCellInfo() = default;
- TopCellInfo(bool topest, size_t grad_order, ResourcePtr r, FuncGraphPtr df, std::string cellid)
+ TopCellInfo(bool topest, size_t grad_order, pipeline::ResourcePtr r, FuncGraphPtr df, std::string cellid)
: is_topest_(topest),
grad_order_(grad_order),
resource_(std::move(r)),
@@ -95,7 +85,7 @@ class TopCellInfo {
void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
bool forward_already_run() const { return forward_already_run_; }
void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
- ResourcePtr resource() { return resource_; }
+ pipeline::ResourcePtr resource() { return resource_; }
FuncGraphPtr df_builder() { return df_builder_; }
size_t op_num() const { return op_num_; }
void set_op_num(size_t op_num) { op_num_ = op_num; }
@@ -129,7 +119,7 @@ class TopCellInfo {
bool need_compile_graph_{false};
size_t op_num_{0};
size_t grad_order_{0};
- ResourcePtr resource_{nullptr};
+ pipeline::ResourcePtr resource_{nullptr};
FuncGraphPtr df_builder_{nullptr};
ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr};
std::string cell_id_;
@@ -167,7 +157,8 @@ class GradExecutor {
EndGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2),
std::forward<decltype(PH3)>(PH3), std::forward<decltype(PH4)>(PH4));
};
- std::function<void(py::object *, const GradOperationPtr &, const py::object &, const py::object &, const py::args &)>
+ std::function<void(py::object *, const prim::GradOperationPtr &, const py::object &, const py::object &,
+ const py::args &)>
GradGraph = [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4, auto &&PH5) {
GradNetInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3),
std::forward<decltype(PH4)>(PH4), std::forward<decltype(PH5)>(PH5));
@@ -199,7 +190,7 @@ class GradExecutor {
const OpExecInfoPtr &op_exec_info, ValuePtrList *input_values,
CNodePtr *ms_function_cnode);
void UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
- void SaveForwardTensorInfoInBpropGraph(const ResourcePtr &resource);
+ void SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource);
py::object CheckGraph(const py::object &cell, const py::args &args);
void RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args);
void EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell);
@@ -216,14 +207,14 @@ class GradExecutor {
void SwitchTopcell();
size_t GetHighOrderStackSize() const { return high_order_stack_.size(); }
void MakeNestedCnode(const py::object &cell, const std::string &cell_id, const py::args &forward_args,
- const ResourcePtr &resource, const py::object &out);
+ const pipeline::ResourcePtr &resource, const py::object &out);
void PushCellStack(const std::string &cell_id);
void PopCellStack();
void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell);
TopCellInfoPtr PopHighOrderGraphStack();
FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
- ResourcePtr GetResource(const std::string &cell_id = "");
+ pipeline::ResourcePtr GetResource(const std::string &cell_id = "");
bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id);
bool IsBpropGraph(const std::string &cell_id);
void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled);
@@ -234,9 +225,9 @@ class GradExecutor {
void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args);
std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args,
py::args *forward_args = nullptr);
- void GradNetInner(py::object *ret, const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
- const py::args &args);
- FuncGraphPtr GetBpropGraph(const GradOperationPtr &grad, const py::object &cell,
+ void GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
+ const py::object &weights, const py::args &args);
+ FuncGraphPtr GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell,
const std::vector<AnfNodePtr> &weights, size_t arg_size, const py::args &args);
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &bprop_graph);
@@ -285,7 +276,7 @@ class GradExecutor {
// Use a vector to keep order
std::vector<TopCellInfoPtr> top_cell_list_;
// Record all top cells which have been run
- std::map<cell_id, TopCellInfoPtr> already_run_top_cell_;
+ std::map<CellId, TopCellInfoPtr> already_run_top_cell_;
// Use a vector to keep order
ForwardExecutorWeakPtr forward_executor_;
};
@@ -336,7 +327,7 @@ class ForwardExecutor {
private:
GradExecutorWeakPtr grad_executor_;
- std::unordered_map<PrimitivePtr, AbstractListMap, PrimitiveHasher, PrimitiveTotalEqual> prim_abs_list_;
+ PrimAbsCache prim_abs_list_;
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
// Used to cache cast struct
std::unordered_map<std::string, OpExecInfoPtr> cast_struct_map_;
@@ -368,7 +359,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void GradMsFunction(const py::object &out, const py::args &args);
void NewGraph(const py::object &cell, const py::args &args);
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
- void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
+ void GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
+ const py::args &args);
py::object CheckGraph(const py::object &cell, const py::args &args);
py::object CheckAlreadyRun(const py::object &cell, const py::args &args);
py::object Run(const py::object &cell, const py::tuple &args);