forked from mindspore-Ecosystem/mindspore
!17167 Refresh prim_abs_list_ key, avoid memory leak
From: @zhangzhaoju  Reviewed-by: @ginfung, @zh_qh  Signed-off-by: @zh_qh
Commit: a0b14b52e4
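Why keying the cache by PrimitivePtr can leak: an unordered_map whose key type is a shared_ptr co-owns every key it stores, so a cached Primitive can never be released while the cache lives. Below is a minimal standalone sketch of that failure mode — the Primitive struct and the values in it are stand-ins for illustration, not MindSpore code. The commit avoids the problem by keying on the primitive's name, hash, and attributes instead.

// Minimal sketch of the leak this commit avoids (stand-in types only).
// A shared_ptr key keeps its pointee alive as long as the map holds the entry.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Primitive {
  std::string name;
  ~Primitive() { std::cout << "Primitive " << name << " destroyed\n"; }
};

int main() {
  // Key type is shared_ptr<Primitive>: the cache co-owns every key.
  std::unordered_map<std::shared_ptr<Primitive>, int> cache;
  {
    auto prim = std::make_shared<Primitive>();
    prim->name = "MatMul";
    cache[prim] = 42;
  }  // prim leaves scope here, but no "destroyed" message prints:
  std::cout << "cache still owns " << cache.size() << " primitive(s)\n";
  return 0;  // the Primitive is only freed when the cache itself dies
}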
pipeline/pynative/pynative_abs_cache.h (new file)
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_ABS_CACHE_H
+#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_ABS_CACHE_H
+#include <string>
+#include <utility>
+#include <unordered_map>
+
+#include "ir/anf.h"
+
+namespace mindspore::pynative {
+struct AbsCacheKey {
+  std::string prim_name_;
+  size_t prim_hash_value_;
+  std::unordered_map<std::string, ValuePtr> prim_attrs_;
+};
+
+struct AbsCacheKeyHasher {
+  size_t operator()(const AbsCacheKey &key) const { return key.prim_hash_value_; }
+};
+
+struct AbsCacheKeyEqual {
+  bool operator()(const AbsCacheKey &lk, const AbsCacheKey &rk) const {
+    if (lk.prim_attrs_.size() != rk.prim_attrs_.size()) {
+      return false;
+    }
+    if (lk.prim_name_ != rk.prim_name_) {
+      return false;
+    }
+
+    auto all = std::all_of(lk.prim_attrs_.begin(), lk.prim_attrs_.end(),
+                           [&rk](const std::pair<std::string, ValuePtr> &item) -> bool {
+                             auto iter = rk.prim_attrs_.find(item.first);
+                             if (iter == rk.prim_attrs_.end()) {
+                               return false;
+                             }
+                             if (item.second == iter->second) {
+                               return true;
+                             }
+                             MS_EXCEPTION_IF_NULL(item.second);
+                             MS_EXCEPTION_IF_NULL(iter->second);
+                             return *item.second == *iter->second;
+                           });
+    return all;
+  }
+};
+
+struct PrimAbsInfo {
+  abstract::AbstractBasePtr abs;
+  bool is_dynamic_shape = false;
+  std::unordered_map<std::string, ValuePtr> attrs;
+};
+using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
+                                           abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
+using PrimAbsCache = std::unordered_map<AbsCacheKey, AbstractListMap, AbsCacheKeyHasher, AbsCacheKeyEqual>;
+}  // namespace mindspore::pynative
+#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_ABS_CACHE_H
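For readers unfamiliar with the custom-hasher pattern used by AbsCacheKey above, here is a self-contained analog — attribute values are plain ints standing in for ValuePtr, and all names are illustrative, not MindSpore API. The hasher simply returns the precomputed hash; the equality functor mirrors AbsCacheKeyEqual by running cheap checks (attribute count, name) before the per-attribute comparison.

// Self-contained analog of AbsCacheKey / AbsCacheKeyHasher / AbsCacheKeyEqual.
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Stand-in for AbsCacheKey: name + precomputed hash + attribute map.
struct Key {
  std::string name;
  size_t hash_value;
  std::unordered_map<std::string, int> attrs;
};

// Stand-in for AbsCacheKeyHasher: the hash was computed up front.
struct KeyHasher {
  size_t operator()(const Key &k) const { return k.hash_value; }
};

// Stand-in for AbsCacheKeyEqual: cheap checks first, deep compare last.
struct KeyEqual {
  bool operator()(const Key &l, const Key &r) const {
    if (l.attrs.size() != r.attrs.size()) return false;
    if (l.name != r.name) return false;
    return l.attrs == r.attrs;  // element-wise compare, only on bucket match
  }
};

int main() {
  std::unordered_map<Key, int, KeyHasher, KeyEqual> cache;
  Key k{"MatMul", std::hash<std::string>{}("MatMul"), {{"transpose_a", 0}}};
  cache[k] = 1;
  std::cout << cache.count(k) << "\n";  // prints 1: name/hash/attrs all match
  return 0;
}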
@@ -883,7 +883,8 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
   auto op_name = op_exec_info->op_name;
   auto prim = op_exec_info->py_primitive;
 
-  auto temp = prim_abs_list_.find(prim);
+  AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
+  auto temp = prim_abs_list_.find(key);
   if (temp != prim_abs_list_.end()) {
     MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
     auto iter = temp->second.find(args_spec_list);
@@ -931,7 +932,8 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
 
   // Add output abstract info into cache; the const value needs to be inferred every step
   if (!prim_cache_hit && !op_exec_info->is_dynamic_shape) {
-    auto &out = prim_abs_list_[prim];
+    AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
+    auto &out = prim_abs_list_[key];
     out[args_spec_list].abs = op_exec_info->abstract;
     out[args_spec_list].attrs = prim->evaluate_added_attrs();
   }
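The two hunks above are the heart of the fix: both the lookup and the insertion now build an AbsCacheKey by value from the live primitive (prim->name(), prim->Hash(), prim->attrs()), so the cache copies what it needs instead of sharing ownership of the Primitive. A minimal sketch of the resulting lifetime behavior, again with stand-in types rather than MindSpore code:

// Sketch of the fixed flow: a by-value key decouples the cache from the
// primitive's lifetime (stand-in types, illustrative names only).
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Primitive {
  std::string name;
  size_t Hash() const { return std::hash<std::string>{}(name); }
  ~Primitive() { std::cout << "Primitive " << name << " destroyed\n"; }
};

struct Key {
  std::string name;
  size_t hash;
};
struct Hasher {
  size_t operator()(const Key &k) const { return k.hash; }
};
struct Equal {
  bool operator()(const Key &l, const Key &r) const { return l.name == r.name; }
};

int main() {
  std::unordered_map<Key, int, Hasher, Equal> cache;
  {
    auto prim = std::make_shared<Primitive>();
    prim->name = "ReLU";
    cache[Key{prim->name, prim->Hash()}] = 1;  // key copied by value
  }  // "Primitive ReLU destroyed" prints here: the cache holds no reference
  std::cout << "cache still has " << cache.size() << " entry\n";
  return 0;
}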
@@ -1475,7 +1477,7 @@ void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_e
   }
 }
 
-void GradExecutor::SaveForwardTensorInfoInBpropGraph(const ResourcePtr &resource) {
+void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) {
   MS_EXCEPTION_IF_NULL(resource);
   // Get all tensor ids that belong to forward ops
   std::unordered_set<std::string> forward_op_tensor_id;
@@ -1851,7 +1853,7 @@ FuncGraphPtr GradExecutor::GetDfbuilder(const std::string &cell_id) {
   return nullptr;
 }
 
-ResourcePtr GradExecutor::GetResource(const std::string &cell_id) {
+pipeline::ResourcePtr GradExecutor::GetResource(const std::string &cell_id) {
   // If a top graph holds it
   for (auto it = top_cell_list_.rbegin(); it != top_cell_list_.rend(); ++it) {
     if (cell_id.find((*it)->cell_id()) != std::string::npos) {
@@ -2210,7 +2212,7 @@ std::string GradExecutor::GetGradCellId(bool has_sens, const py::object &cell, c
   return cell_id;
 }
 
-void GradExecutor::GradNetInner(py::object *ret, const GradOperationPtr &grad, const py::object &cell,
+void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
                                 const py::object &weights, const py::args &args) {
   MS_EXCEPTION_IF_NULL(grad);
   auto size = args.size();
@@ -2335,7 +2337,7 @@ abstract::AbstractBasePtrList GradExecutor::GetArgsSpec(const py::args &args, co
   return args_spec;
 }
 
-FuncGraphPtr GradExecutor::GetBpropGraph(const GradOperationPtr &grad, const py::object &cell,
+FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell,
                                          const std::vector<AnfNodePtr> &weights, size_t arg_size,
                                          const py::args &args) {
   bool build_formal_param = false;
@@ -2361,7 +2363,7 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const GradOperationPtr &grad, const py:
   (void)GetArgsSpec(args, bprop_graph);
 
   // Do opt for final bprop graph
-  ResourcePtr resource = std::make_shared<pipeline::Resource>();
+  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
   resource->set_func_graph(bprop_graph);
   auto manager = resource->manager();
   MS_EXCEPTION_IF_NULL(manager);
@@ -2512,7 +2514,7 @@ void GradExecutor::SwitchTopcell() {
 }
 
 void GradExecutor::MakeNestedCnode(const py::object &cell, const std::string &cell_id, const py::args &forward_args,
-                                   const ResourcePtr &resource, const py::object &out) {
+                                   const pipeline::ResourcePtr &resource, const py::object &out) {
   if (cell_stack_.empty()) {
     MS_LOG(DEBUG) << "No nested grad found";
     return;
@@ -2528,7 +2530,7 @@ void GradExecutor::MakeNestedCnode(const py::object &cell, const std::string &ce
   DumpGraphIR("first_grad_fg.ir", first_grad_fg);
 
   auto out_id = GetId(out);
-  ResourcePtr r = std::make_shared<pipeline::Resource>();
+  pipeline::ResourcePtr r = std::make_shared<pipeline::Resource>();
   r->manager()->AddFuncGraph(first_grad_fg);
   FuncGraphPtr second_grad_fg = ad::Grad(first_grad_fg, r);
   DumpGraphIR("second_grad_fg.ir", second_grad_fg);
@@ -2628,7 +2630,7 @@ void GradExecutor::GradMsFunction(const py::object &out, const py::args &args) {
   if (iter == ms_function_grad_cache.end()) {
     ms_func_graph = BasicClone(executor->GetFuncGraph(graph_phase()));
     MS_EXCEPTION_IF_NULL(ms_func_graph);
-    ResourcePtr res = std::make_shared<pipeline::Resource>();
+    pipeline::ResourcePtr res = std::make_shared<pipeline::Resource>();
     res->set_func_graph(ms_func_graph);
     res->manager()->AddFuncGraph(ms_func_graph, true);
     fprop_g = ad::Grad(ms_func_graph, res, true);
@@ -2751,7 +2753,7 @@ void PynativeExecutor::GradMsFunction(const py::object &out, const py::args &arg
   grad_executor()->GradMsFunction(out, args);
 }
 
-void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
+void PynativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                                const py::args &args) {
   py::object *ret = nullptr;
   PynativeExecutorTry(grad_executor()->GradGraph, ret, grad, cell, weights, args);

@@ -38,21 +38,11 @@
 #include "pipeline/jit/resource.h"
 #include "frontend/optimizer/ad/kpynative.h"
 #include "frontend/operator/composite/composite.h"
+#include "pipeline/pynative/pynative_abs_cache.h"
 
 namespace mindspore::pynative {
 namespace py = pybind11;
-using cell_id = std::string;
-using ResourcePtr = std::shared_ptr<pipeline::Resource>;
-using GradOperationPtr = std::shared_ptr<prim::GradOperation>;
-
-struct PrimAbsInfo {
-  abstract::AbstractBasePtr abs;
-  bool is_dynamic_shape = false;
-  std::unordered_map<std::string, ValuePtr> attrs;
-};
-
-using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
-                                           abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
+using CellId = std::string;
 using MsFunctionGradCache = std::unordered_map<std::string, std::pair<FuncGraphPtr, FuncGraphPtr>>;
 using OpInfoWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
 using TensorIdWithTensorObject = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;
@@ -73,7 +63,7 @@ class TopCellInfo {
  public:
   TopCellInfo() = default;
   ~TopCellInfo() = default;
-  TopCellInfo(bool topest, size_t grad_order, ResourcePtr r, FuncGraphPtr df, std::string cellid)
+  TopCellInfo(bool topest, size_t grad_order, pipeline::ResourcePtr r, FuncGraphPtr df, std::string cellid)
       : is_topest_(topest),
         grad_order_(grad_order),
         resource_(std::move(r)),
@@ -95,7 +85,7 @@ class TopCellInfo {
   void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
   bool forward_already_run() const { return forward_already_run_; }
   void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
-  ResourcePtr resource() { return resource_; }
+  pipeline::ResourcePtr resource() { return resource_; }
   FuncGraphPtr df_builder() { return df_builder_; }
   size_t op_num() const { return op_num_; }
   void set_op_num(size_t op_num) { op_num_ = op_num; }
@@ -129,7 +119,7 @@ class TopCellInfo {
   bool need_compile_graph_{false};
   size_t op_num_{0};
   size_t grad_order_{0};
-  ResourcePtr resource_{nullptr};
+  pipeline::ResourcePtr resource_{nullptr};
   FuncGraphPtr df_builder_{nullptr};
   ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr};
   std::string cell_id_;
@@ -167,7 +157,8 @@ class GradExecutor {
     EndGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2),
                   std::forward<decltype(PH3)>(PH3), std::forward<decltype(PH4)>(PH4));
   };
-  std::function<void(py::object *, const GradOperationPtr &, const py::object &, const py::object &, const py::args &)>
+  std::function<void(py::object *, const prim::GradOperationPtr &, const py::object &, const py::object &,
+                     const py::args &)>
     GradGraph = [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4, auto &&PH5) {
       GradNetInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3),
                    std::forward<decltype(PH4)>(PH4), std::forward<decltype(PH5)>(PH5));
@@ -199,7 +190,7 @@ class GradExecutor {
                               const OpExecInfoPtr &op_exec_info, ValuePtrList *input_values,
                               CNodePtr *ms_function_cnode);
   void UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
-  void SaveForwardTensorInfoInBpropGraph(const ResourcePtr &resource);
+  void SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource);
   py::object CheckGraph(const py::object &cell, const py::args &args);
   void RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args);
   void EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell);
@@ -216,14 +207,14 @@ class GradExecutor {
   void SwitchTopcell();
   size_t GetHighOrderStackSize() const { return high_order_stack_.size(); }
   void MakeNestedCnode(const py::object &cell, const std::string &cell_id, const py::args &forward_args,
-                       const ResourcePtr &resource, const py::object &out);
+                       const pipeline::ResourcePtr &resource, const py::object &out);
   void PushCellStack(const std::string &cell_id);
   void PopCellStack();
   void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell);
   TopCellInfoPtr PopHighOrderGraphStack();
 
   FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
-  ResourcePtr GetResource(const std::string &cell_id = "");
+  pipeline::ResourcePtr GetResource(const std::string &cell_id = "");
   bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id);
   bool IsBpropGraph(const std::string &cell_id);
   void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled);
@@ -234,9 +225,9 @@ class GradExecutor {
   void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args);
   std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args,
                             py::args *forward_args = nullptr);
-  void GradNetInner(py::object *ret, const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
-                    const py::args &args);
-  FuncGraphPtr GetBpropGraph(const GradOperationPtr &grad, const py::object &cell,
+  void GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
+                    const py::object &weights, const py::args &args);
+  FuncGraphPtr GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell,
                              const std::vector<AnfNodePtr> &weights, size_t arg_size, const py::args &args);
   std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
   abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &bprop_graph);
@@ -285,7 +276,7 @@ class GradExecutor {
   // Use a vector to keep order
   std::vector<TopCellInfoPtr> top_cell_list_;
   // Record all top cells that have been run
-  std::map<cell_id, TopCellInfoPtr> already_run_top_cell_;
+  std::map<CellId, TopCellInfoPtr> already_run_top_cell_;
   // Use a vector to keep order
   ForwardExecutorWeakPtr forward_executor_;
 };
@@ -336,7 +327,7 @@ class ForwardExecutor {
 
  private:
   GradExecutorWeakPtr grad_executor_;
-  std::unordered_map<PrimitivePtr, AbstractListMap, PrimitiveHasher, PrimitiveTotalEqual> prim_abs_list_;
+  PrimAbsCache prim_abs_list_;
   std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
   // Used to cache cast struct
   std::unordered_map<std::string, OpExecInfoPtr> cast_struct_map_;
@@ -368,7 +359,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void GradMsFunction(const py::object &out, const py::args &args);
   void NewGraph(const py::object &cell, const py::args &args);
   void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
-  void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
+  void GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
+               const py::args &args);
   py::object CheckGraph(const py::object &cell, const py::args &args);
   py::object CheckAlreadyRun(const py::object &cell, const py::args &args);
   py::object Run(const py::object &cell, const py::tuple &args);