forked from mindspore-Ecosystem/mindspore
!16137 [PyNative]Solve Opmask Leak Problem
From: @chenyijie6 Reviewed-by: @wilfchen,@limingqi107 Signed-off-by: @limingqi107
This commit is contained in:
commit
84a545f0c2
|
@ -709,28 +709,9 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
|
|||
return op_exec_info;
|
||||
}
|
||||
|
||||
bool ForwardExecutor::FindOpMask(py::object obj, std::vector<int64_t> *op_masks, const std::string &id) {
|
||||
bool op_mask = false;
|
||||
auto temp = op_mask_map_.find(id);
|
||||
if (temp != op_mask_map_.end()) {
|
||||
op_mask = temp->second;
|
||||
(*op_masks).emplace_back(op_mask);
|
||||
} else {
|
||||
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
||||
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
|
||||
if (meta_tensor) {
|
||||
op_mask = meta_tensor->is_parameter();
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Gen args op_mask " << op_mask;
|
||||
op_mask_map_[id] = op_mask;
|
||||
(*op_masks).emplace_back(op_mask);
|
||||
}
|
||||
return op_mask;
|
||||
}
|
||||
|
||||
void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
|
||||
std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(op_masks);
|
||||
auto prim = op_exec_info->py_primitive;
|
||||
for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
|
||||
abstract::AbstractBasePtr abs = nullptr;
|
||||
|
@ -740,8 +721,16 @@ void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector
|
|||
if (it != node_abs_map_.end()) {
|
||||
abs = it->second;
|
||||
}
|
||||
// Find the opmask of input obj
|
||||
bool op_mask = FindOpMask(obj, op_masks, id);
|
||||
|
||||
bool op_mask = false;
|
||||
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
||||
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
|
||||
if (meta_tensor) {
|
||||
op_mask = meta_tensor->is_parameter();
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Gen args i " << i << op_mask;
|
||||
(*op_masks).emplace_back(op_mask);
|
||||
|
||||
// Construct grad graph
|
||||
if (grad()->need_construct_graph()) {
|
||||
|
@ -1542,7 +1531,6 @@ void ForwardExecutor::ClearRes() {
|
|||
prim_abs_list_.clear();
|
||||
node_abs_map_.clear();
|
||||
cast_struct_map_.clear();
|
||||
op_mask_map_.clear();
|
||||
cell_op_index_with_tensor_id_.clear();
|
||||
cell_tensor_id_with_tensor_.clear();
|
||||
}
|
||||
|
|
|
@ -406,7 +406,6 @@ class ForwardExecutor {
|
|||
PynativeStatusCode *status);
|
||||
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
|
||||
abstract::AbstractBasePtrList *args_spec_list);
|
||||
bool FindOpMask(py::object obj, std::vector<int64_t> *op_masks, const std::string &id);
|
||||
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
|
||||
abstract::AbstractBasePtrList *args_spec_list);
|
||||
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
|
||||
|
@ -435,8 +434,6 @@ class ForwardExecutor {
|
|||
std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_;
|
||||
// Used to cache cast struct
|
||||
std::unordered_map<std::string, OpExecInfoPtr> cast_struct_map_;
|
||||
// Used to cache op_mask
|
||||
std::unordered_map<std::string, int64_t> op_mask_map_;
|
||||
};
|
||||
|
||||
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
||||
|
|
Loading…
Reference in New Issue