forked from mindspore-Ecosystem/mindspore
!16137 [PyNative]Solve Opmask Leak Problem
From: @chenyijie6 Reviewed-by: @wilfchen,@limingqi107 Signed-off-by: @limingqi107
This commit is contained in:
commit
84a545f0c2
|
@ -709,28 +709,9 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
|
||||||
return op_exec_info;
|
return op_exec_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ForwardExecutor::FindOpMask(py::object obj, std::vector<int64_t> *op_masks, const std::string &id) {
|
|
||||||
bool op_mask = false;
|
|
||||||
auto temp = op_mask_map_.find(id);
|
|
||||||
if (temp != op_mask_map_.end()) {
|
|
||||||
op_mask = temp->second;
|
|
||||||
(*op_masks).emplace_back(op_mask);
|
|
||||||
} else {
|
|
||||||
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
|
||||||
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
|
|
||||||
if (meta_tensor) {
|
|
||||||
op_mask = meta_tensor->is_parameter();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "Gen args op_mask " << op_mask;
|
|
||||||
op_mask_map_[id] = op_mask;
|
|
||||||
(*op_masks).emplace_back(op_mask);
|
|
||||||
}
|
|
||||||
return op_mask;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
|
void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
|
||||||
std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) {
|
std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(op_masks);
|
||||||
auto prim = op_exec_info->py_primitive;
|
auto prim = op_exec_info->py_primitive;
|
||||||
for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
|
for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
|
||||||
abstract::AbstractBasePtr abs = nullptr;
|
abstract::AbstractBasePtr abs = nullptr;
|
||||||
|
@ -740,8 +721,16 @@ void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector
|
||||||
if (it != node_abs_map_.end()) {
|
if (it != node_abs_map_.end()) {
|
||||||
abs = it->second;
|
abs = it->second;
|
||||||
}
|
}
|
||||||
// Find the opmask of input obj
|
|
||||||
bool op_mask = FindOpMask(obj, op_masks, id);
|
bool op_mask = false;
|
||||||
|
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
||||||
|
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
|
||||||
|
if (meta_tensor) {
|
||||||
|
op_mask = meta_tensor->is_parameter();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "Gen args i " << i << op_mask;
|
||||||
|
(*op_masks).emplace_back(op_mask);
|
||||||
|
|
||||||
// Construct grad graph
|
// Construct grad graph
|
||||||
if (grad()->need_construct_graph()) {
|
if (grad()->need_construct_graph()) {
|
||||||
|
@ -1542,7 +1531,6 @@ void ForwardExecutor::ClearRes() {
|
||||||
prim_abs_list_.clear();
|
prim_abs_list_.clear();
|
||||||
node_abs_map_.clear();
|
node_abs_map_.clear();
|
||||||
cast_struct_map_.clear();
|
cast_struct_map_.clear();
|
||||||
op_mask_map_.clear();
|
|
||||||
cell_op_index_with_tensor_id_.clear();
|
cell_op_index_with_tensor_id_.clear();
|
||||||
cell_tensor_id_with_tensor_.clear();
|
cell_tensor_id_with_tensor_.clear();
|
||||||
}
|
}
|
||||||
|
|
|
@ -406,7 +406,6 @@ class ForwardExecutor {
|
||||||
PynativeStatusCode *status);
|
PynativeStatusCode *status);
|
||||||
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
|
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
|
||||||
abstract::AbstractBasePtrList *args_spec_list);
|
abstract::AbstractBasePtrList *args_spec_list);
|
||||||
bool FindOpMask(py::object obj, std::vector<int64_t> *op_masks, const std::string &id);
|
|
||||||
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
|
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
|
||||||
abstract::AbstractBasePtrList *args_spec_list);
|
abstract::AbstractBasePtrList *args_spec_list);
|
||||||
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
|
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
|
||||||
|
@ -435,8 +434,6 @@ class ForwardExecutor {
|
||||||
std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_;
|
std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_;
|
||||||
// Used to cache cast struct
|
// Used to cache cast struct
|
||||||
std::unordered_map<std::string, OpExecInfoPtr> cast_struct_map_;
|
std::unordered_map<std::string, OpExecInfoPtr> cast_struct_map_;
|
||||||
// Used to cache op_mask
|
|
||||||
std::unordered_map<std::string, int64_t> op_mask_map_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
||||||
|
|
Loading…
Reference in New Issue