!16137 [PyNative]Solve Opmask Leak Problem

From: @chenyijie6
Reviewed-by: @wilfchen,@limingqi107
Signed-off-by: @limingqi107
This commit is contained in:
mindspore-ci-bot 2021-05-11 14:53:55 +08:00 committed by Gitee
commit 84a545f0c2
2 changed files with 11 additions and 26 deletions

View File

@ -709,28 +709,9 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
return op_exec_info; return op_exec_info;
} }
bool ForwardExecutor::FindOpMask(py::object obj, std::vector<int64_t> *op_masks, const std::string &id) {
bool op_mask = false;
auto temp = op_mask_map_.find(id);
if (temp != op_mask_map_.end()) {
op_mask = temp->second;
(*op_masks).emplace_back(op_mask);
} else {
if (py::isinstance<tensor::MetaTensor>(obj)) {
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
if (meta_tensor) {
op_mask = meta_tensor->is_parameter();
}
}
MS_LOG(DEBUG) << "Gen args op_mask " << op_mask;
op_mask_map_[id] = op_mask;
(*op_masks).emplace_back(op_mask);
}
return op_mask;
}
void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) { std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) {
MS_EXCEPTION_IF_NULL(op_masks);
auto prim = op_exec_info->py_primitive; auto prim = op_exec_info->py_primitive;
for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) { for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
abstract::AbstractBasePtr abs = nullptr; abstract::AbstractBasePtr abs = nullptr;
@ -740,8 +721,16 @@ void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector
if (it != node_abs_map_.end()) { if (it != node_abs_map_.end()) {
abs = it->second; abs = it->second;
} }
// Find the opmask of input obj
bool op_mask = FindOpMask(obj, op_masks, id); bool op_mask = false;
if (py::isinstance<tensor::MetaTensor>(obj)) {
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
if (meta_tensor) {
op_mask = meta_tensor->is_parameter();
}
}
MS_LOG(DEBUG) << "Gen args i " << i << op_mask;
(*op_masks).emplace_back(op_mask);
// Construct grad graph // Construct grad graph
if (grad()->need_construct_graph()) { if (grad()->need_construct_graph()) {
@ -1542,7 +1531,6 @@ void ForwardExecutor::ClearRes() {
prim_abs_list_.clear(); prim_abs_list_.clear();
node_abs_map_.clear(); node_abs_map_.clear();
cast_struct_map_.clear(); cast_struct_map_.clear();
op_mask_map_.clear();
cell_op_index_with_tensor_id_.clear(); cell_op_index_with_tensor_id_.clear();
cell_tensor_id_with_tensor_.clear(); cell_tensor_id_with_tensor_.clear();
} }

View File

@ -406,7 +406,6 @@ class ForwardExecutor {
PynativeStatusCode *status); PynativeStatusCode *status);
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
abstract::AbstractBasePtrList *args_spec_list); abstract::AbstractBasePtrList *args_spec_list);
bool FindOpMask(py::object obj, std::vector<int64_t> *op_masks, const std::string &id);
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs, void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
abstract::AbstractBasePtrList *args_spec_list); abstract::AbstractBasePtrList *args_spec_list);
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj, abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
@ -435,8 +434,6 @@ class ForwardExecutor {
std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_; std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_;
// Used to cache cast struct // Used to cache cast struct
std::unordered_map<std::string, OpExecInfoPtr> cast_struct_map_; std::unordered_map<std::string, OpExecInfoPtr> cast_struct_map_;
// Used to cache op_mask
std::unordered_map<std::string, int64_t> op_mask_map_;
}; };
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {