forked from mindspore-Ecosystem/mindspore
Pybind11::object && PrimitivePy recycle optimize
This commit is contained in:
parent
fe1d6e5a78
commit
a1db9c4959
|
@ -1997,20 +1997,25 @@ class IrParser {
|
|||
|
||||
// restore python function of PrimitivePy from serialized file
|
||||
py::object py_obj = LoadObject(lexer_.GetTokenText());
|
||||
PrimitivePyPtr ptr = nullptr;
|
||||
py::object py_adapter = py_obj;
|
||||
static auto len = strlen("PrimitivePy::");
|
||||
bool cloned = false;
|
||||
if (py::hasattr(py_obj, "__setattr_flag__") && py::hasattr(py_obj, "_clone")) {
|
||||
auto clone_fn = py_obj.attr("_clone");
|
||||
py::object new_obj = clone_fn();
|
||||
ptr = new_obj.cast<PrimitivePyPtr>();
|
||||
if (ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cast to type 'PrimitivePyPtr' error";
|
||||
}
|
||||
} else {
|
||||
auto len = strlen("PrimitivePy::");
|
||||
if (id.size() < len) {
|
||||
return TOK_ERROR;
|
||||
}
|
||||
ptr = std::make_shared<PrimitivePy>(id.substr(len), py_obj);
|
||||
py_adapter = clone_fn();
|
||||
cloned = true;
|
||||
} else if (id.size() < len) {
|
||||
return TOK_ERROR;
|
||||
}
|
||||
auto prim_adapter = py_adapter.cast<PrimitivePyAdapterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_adapter);
|
||||
if (!cloned) {
|
||||
prim_adapter->set_name(id.substr(len));
|
||||
}
|
||||
PrimitivePyPtr ptr = prim_adapter->attached_primitive();
|
||||
if (ptr == nullptr) {
|
||||
ptr = std::make_shared<PrimitivePy>(py_adapter, prim_adapter);
|
||||
prim_adapter->set_attached_primitive(ptr);
|
||||
}
|
||||
*val_ptr = ptr;
|
||||
|
||||
|
|
|
@ -366,8 +366,7 @@ FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::Res
|
|||
|
||||
auto func_graph = std::make_shared<FuncGraph>();
|
||||
std::vector<AnfNodePtr> outputs;
|
||||
|
||||
auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut", py::object());
|
||||
auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
|
||||
bprop_cut->CopyHookFunction(prim);
|
||||
|
||||
auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
|
||||
|
|
|
@ -157,12 +157,10 @@ REGISTER_PYBIND_DEFINE(
|
|||
(void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
|
||||
(void)py::class_<OneOf, std::shared_ptr<OneOf>, Pattern>(*m, "OneOf_").def(py::init<vector<PatternPtr>>());
|
||||
(void)py::class_<Prim, std::shared_ptr<Prim>, Pattern>(*m, "Prim_", py::dynamic_attr())
|
||||
.def(py::init<vector<PrimitivePyPtr>, string>())
|
||||
.def(py::init<vector<string>, string>());
|
||||
.def(py::init<vector<py::object>, string>());
|
||||
(void)py::class_<Call, std::shared_ptr<Call>, Pattern>(*m, "Call_")
|
||||
.def(py::init<PatternPtr, vector<PatternPtr>>())
|
||||
.def(py::init<PrimitivePyPtr, vector<PatternPtr>>())
|
||||
.def(py::init<string, vector<PatternPtr>>());
|
||||
.def(py::init<py::object, vector<PatternPtr>>());
|
||||
(void)py::class_<NoneOf, std::shared_ptr<NoneOf>, Pattern>(*m, "NoneOf_").def(py::init<vector<PatternPtr>>());
|
||||
(void)py::class_<Any, std::shared_ptr<Any>, Pattern>(*m, "Any").def(py::init<>());
|
||||
(void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
|
||||
|
|
|
@ -87,16 +87,18 @@ class Prim : public Pattern {
|
|||
public:
|
||||
Prim() { unique_name_ = std::to_string(g_id_++); }
|
||||
~Prim() = default;
|
||||
Prim(vector<PrimitivePyPtr> prims, string name) : primitives_(prims), name_(name) {
|
||||
Prim(vector<py::object> prim_objs, string name) : name_(name) {
|
||||
unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
|
||||
// Default using the first prim to build target
|
||||
matched_prim_ = primitives_[0];
|
||||
}
|
||||
Prim(vector<string> types, string name) : types_(types), name_(name) {
|
||||
unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
|
||||
// Make primitives_
|
||||
for (auto &iter : types) {
|
||||
primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
|
||||
for (auto &prim_obj : prim_objs) {
|
||||
if (py::isinstance<PrimitivePyAdapter>(prim_obj)) {
|
||||
auto prim_adapter = prim_obj.cast<PrimitivePyAdapterPtr>();
|
||||
primitives_.push_back(std::make_shared<PrimitivePy>(prim_obj, prim_adapter));
|
||||
} else if (py::isinstance<py::str>(prim_obj)) {
|
||||
std::string prim_name = prim_obj.cast<py::str>();
|
||||
primitives_.push_back(std::make_shared<PrimitivePy>(prim_name));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Parameter of Prim::__init__ must be Primitive_ type or Prim name, please check input.";
|
||||
}
|
||||
}
|
||||
// Default using the first prim to build target
|
||||
matched_prim_ = primitives_[0];
|
||||
|
@ -111,7 +113,6 @@ class Prim : public Pattern {
|
|||
}
|
||||
|
||||
private:
|
||||
vector<string> types_;
|
||||
vector<PrimitivePyPtr> primitives_;
|
||||
string name_;
|
||||
PrimitivePyPtr matched_prim_{nullptr};
|
||||
|
@ -127,16 +128,19 @@ class Call : public Pattern {
|
|||
unique_name_ = std::to_string(g_id_++) + "Call_" + prim_pattern->unique_name();
|
||||
inputs_ = inputs;
|
||||
}
|
||||
Call(PrimitivePyPtr prim, vector<PatternPtr> inputs) {
|
||||
prim_ = prim;
|
||||
Call(py::object prim_obj, vector<PatternPtr> inputs) {
|
||||
if (py::isinstance<PrimitivePyAdapter>(prim_obj)) {
|
||||
auto prim_adapter = prim_obj.cast<PrimitivePyAdapterPtr>();
|
||||
prim_ = std::make_shared<PrimitivePy>(prim_obj, prim_adapter);
|
||||
} else if (py::isinstance<py::str>(prim_obj)) {
|
||||
std::string prim_name = prim_obj.cast<py::str>();
|
||||
prim_ = std::make_shared<PrimitivePy>(prim_name);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Parameter of Call::__init__ must be Primitive_ type or Prim name, please check input.";
|
||||
}
|
||||
unique_name_ = std::to_string(g_id_++) + "Call_" + prim_->ToString();
|
||||
inputs_ = inputs;
|
||||
}
|
||||
Call(string prim_str, vector<PatternPtr> inputs) {
|
||||
prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
|
||||
unique_name_ = std::to_string(g_id_++) + "CallStr_" + prim_->ToString();
|
||||
inputs_ = inputs;
|
||||
}
|
||||
MS_DECLARE_PARENT(Call, Pattern);
|
||||
MatchResultPtr match(const AnfNodePtr &node) override;
|
||||
PrimitivePtr prim_value() { return prim_; }
|
||||
|
|
|
@ -119,7 +119,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
|
|||
auto bprop_graph = std::make_shared<FuncGraph>();
|
||||
std::vector<AnfNodePtr> outputs;
|
||||
|
||||
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
|
||||
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut");
|
||||
fake_bprop->set_hook(bprop_func);
|
||||
(void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
|
||||
outputs.push_back(NewValueNode(fake_bprop));
|
||||
|
@ -236,16 +236,21 @@ ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) {
|
|||
// desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
|
||||
return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
|
||||
}
|
||||
auto primitive = obj.cast<PrimitivePyPtr>();
|
||||
py::object adapter_obj = obj;
|
||||
if (py::hasattr(obj, "__setattr_flag__")) {
|
||||
if (py::hasattr(obj, "_clone")) {
|
||||
auto clone_fn = obj.attr("_clone");
|
||||
adapter_obj = clone_fn();
|
||||
}
|
||||
}
|
||||
auto prim_adapter = adapter_obj.cast<PrimitivePyAdapterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_adapter);
|
||||
auto primitive = prim_adapter->attached_primitive();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null";
|
||||
return nullptr;
|
||||
}
|
||||
if (py::hasattr(obj, "__setattr_flag__") && py::hasattr(obj, "_clone")) {
|
||||
auto clone_fn = obj.attr("_clone");
|
||||
py::object new_obj = clone_fn();
|
||||
primitive = new_obj.cast<PrimitivePyPtr>();
|
||||
primitive = std::make_shared<PrimitivePy>(adapter_obj, prim_adapter);
|
||||
prim_adapter->set_attached_primitive(primitive);
|
||||
}
|
||||
|
||||
if (use_signature) {
|
||||
return std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
|
||||
}
|
||||
|
|
|
@ -371,7 +371,6 @@ void ExecutorPy::DelNetRes(const std::string &id) {
|
|||
|
||||
void ExecutorPy::ClearRes() {
|
||||
MS_LOG(INFO) << "Clean executor resource!";
|
||||
Resource::mem_cleaner().ClearPrimitivePyPythonObj();
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
mindspore::RDR::ClearAll();
|
||||
#endif
|
||||
|
@ -384,7 +383,7 @@ ExecutorPy::~ExecutorPy() {
|
|||
}
|
||||
|
||||
void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weight_node,
|
||||
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> *fake_quant_table) {
|
||||
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> *fake_quant_table) {
|
||||
std::string weight_name;
|
||||
auto x = root_node->input(1);
|
||||
if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
|
||||
|
@ -437,15 +436,15 @@ void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weig
|
|||
return;
|
||||
}
|
||||
auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
|
||||
(*fake_quant_table)[weight_name] = std::make_pair(quant_op, fakequant_min_node_name);
|
||||
(*fake_quant_table)[weight_name] = std::make_pair(quant_op->adapter(), fakequant_min_node_name);
|
||||
}
|
||||
|
||||
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
|
||||
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
|
||||
const std::string &phase_s) {
|
||||
FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
|
||||
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
|
||||
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> fake_quant_table;
|
||||
auto filter = [](const AnfNodePtr &node) {
|
||||
return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) ||
|
||||
IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative));
|
||||
|
@ -472,7 +471,6 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
|
|||
}
|
||||
GetWeightInfo(root_node, weight_node, &fake_quant_table);
|
||||
}
|
||||
|
||||
return fake_quant_table;
|
||||
}
|
||||
|
||||
|
@ -1162,9 +1160,6 @@ void StartUpProfiling() {
|
|||
}
|
||||
|
||||
void InitPipeline() {
|
||||
// If previous pipeline exit with exception, memory cleaner's flags maybe unpredictable, so init when a new pipeline
|
||||
// start.
|
||||
pipeline::Resource::mem_cleaner().Init();
|
||||
// set python env flag
|
||||
mindspore::parse::python_adapter::set_python_env_flag(true);
|
||||
// Startup profiling before open tsd
|
||||
|
|
|
@ -105,13 +105,14 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
|
|||
static void DebugTerminate(bool val) { debugger_terminate_ = val; }
|
||||
void TerminateDebugger();
|
||||
|
||||
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> FetchInfoForQuantExport(const std::string &phase_s);
|
||||
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> FetchInfoForQuantExport(
|
||||
const std::string &phase_s);
|
||||
|
||||
private:
|
||||
ExecutorPy();
|
||||
void ConvertObjectToTensors(const py::dict &dict, std::map<std::string, tensor::TensorPtr> *tensors);
|
||||
void GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weight_node,
|
||||
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> *fake_quant_table);
|
||||
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> *fake_quant_table);
|
||||
void GetGeBackendPolicy() const;
|
||||
// filter some pipeline actions according to phase, e.g. when exporting onnx, it is no need to execute actions after
|
||||
// 'validate' stage
|
||||
|
|
|
@ -308,91 +308,5 @@ void Resource::Clean() {
|
|||
is_cleaned_ = true;
|
||||
}
|
||||
|
||||
void MemoryCleaner::Init() {
|
||||
pynative_in_construct_process_ = false;
|
||||
pynative_in_end_graph_process_ = false;
|
||||
pynative_released_history_.clear();
|
||||
pynative_new_primtives_squence_.clear();
|
||||
}
|
||||
|
||||
MemoryCleaner Resource::mem_cleaner_ = MemoryCleaner();
|
||||
void MemoryCleaner::RecordPrimitivePy(PrimitivePy *prim) {
|
||||
if (prim == nullptr) {
|
||||
return;
|
||||
}
|
||||
all_primitives_[prim] = true;
|
||||
}
|
||||
|
||||
void MemoryCleaner::ReleasePrimitivePyObj(PrimitivePy *prim) {
|
||||
if (prim == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto it = all_primitives_.find(prim);
|
||||
if (it == all_primitives_.end()) {
|
||||
return;
|
||||
}
|
||||
// If flag is false,the pointer hased been released, so it can't be visited.
|
||||
if (!it->second) {
|
||||
return;
|
||||
}
|
||||
all_primitives_[prim] = false;
|
||||
prim->SetPyObj(py::none());
|
||||
}
|
||||
|
||||
void MemoryCleaner::ClearPrimitivePyPythonObj() {
|
||||
for (auto &it : all_primitives_) {
|
||||
if (it.second) {
|
||||
it.first->SetPyObj(py::none());
|
||||
}
|
||||
}
|
||||
all_primitives_.clear();
|
||||
}
|
||||
|
||||
void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) {
|
||||
if (prim == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (pynative_short_life_primitives_.find(prim) != pynative_short_life_primitives_.end()) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Record pynative tmp primitive:" << prim->ToString();
|
||||
pynative_short_life_primitives_.insert(prim);
|
||||
pynative_new_primtives_squence_.push_back(prim->ToString());
|
||||
}
|
||||
|
||||
void MemoryCleaner::ErasePynativeShortLifePrimitivePy(PrimitivePy *prim) {
|
||||
if (prim == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (pynative_short_life_primitives_.find(prim) == pynative_short_life_primitives_.end()) {
|
||||
return;
|
||||
}
|
||||
pynative_short_life_primitives_.erase(prim);
|
||||
MS_LOG(DEBUG) << "Erase pynative tmp primitive:" << prim->ToString();
|
||||
}
|
||||
|
||||
void MemoryCleaner::ClearPynativeShortLifePrimitivePy() {
|
||||
// If the primitives name sequence never been released before, keep the primtives alive
|
||||
if (std::find(pynative_released_history_.begin(), pynative_released_history_.end(),
|
||||
pynative_new_primtives_squence_) == pynative_released_history_.end()) {
|
||||
pynative_released_history_.push_back(pynative_new_primtives_squence_);
|
||||
} else {
|
||||
for (auto &primitive : pynative_short_life_primitives_) {
|
||||
ReleasePrimitivePyObj(primitive);
|
||||
}
|
||||
}
|
||||
pynative_short_life_primitives_.clear();
|
||||
pynative_new_primtives_squence_.clear();
|
||||
}
|
||||
|
||||
void MemoryCleaner::EnterPynativeConstructProcess() { pynative_in_construct_process_ = true; }
|
||||
void MemoryCleaner::LeavePynativeConstructProcess() {
|
||||
pynative_in_construct_process_ = false;
|
||||
ClearPynativeShortLifePrimitivePy();
|
||||
}
|
||||
bool MemoryCleaner::IsInPynativeConstructProcess() const { return pynative_in_construct_process_; }
|
||||
void MemoryCleaner::EnterPynativeEndGraphProcess() { pynative_in_end_graph_process_ = true; }
|
||||
void MemoryCleaner::LeavePynativeEndGraphProcess() { pynative_in_end_graph_process_ = false; }
|
||||
bool MemoryCleaner::IsInPynativeEndGraphProcess() const { return pynative_in_end_graph_process_; }
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -53,39 +53,6 @@ BuiltInTypeMap &GetMethodMap();
|
|||
|
||||
BuiltInTypeMap &GetAttrMap();
|
||||
|
||||
class MemoryCleaner {
|
||||
public:
|
||||
MemoryCleaner() = default;
|
||||
~MemoryCleaner() = default;
|
||||
void Init();
|
||||
|
||||
void RecordPrimitivePy(PrimitivePy *prim);
|
||||
void ReleasePrimitivePyObj(PrimitivePy *prim);
|
||||
void ClearPrimitivePyPythonObj();
|
||||
|
||||
void RecordPynativeShortLifePrimitivePy(PrimitivePy *prim);
|
||||
void ErasePynativeShortLifePrimitivePy(PrimitivePy *prim);
|
||||
void ClearPynativeShortLifePrimitivePy();
|
||||
|
||||
void EnterPynativeConstructProcess();
|
||||
void LeavePynativeConstructProcess();
|
||||
bool IsInPynativeConstructProcess() const;
|
||||
void EnterPynativeEndGraphProcess();
|
||||
void LeavePynativeEndGraphProcess();
|
||||
bool IsInPynativeEndGraphProcess() const;
|
||||
|
||||
private:
|
||||
std::unordered_map<PrimitivePy *, bool> all_primitives_;
|
||||
// PrimitivePy objects that created in pynative construct process.These primitives should be released after construct
|
||||
// finished.
|
||||
std::unordered_set<PrimitivePy *> pynative_short_life_primitives_;
|
||||
// Sequence of primtive names in one construct process.
|
||||
std::vector<std::string> pynative_new_primtives_squence_;
|
||||
std::vector<std::vector<std::string>> pynative_released_history_;
|
||||
bool pynative_in_construct_process_{false};
|
||||
bool pynative_in_end_graph_process_{false};
|
||||
};
|
||||
|
||||
class Resource : public ResourceBase {
|
||||
public:
|
||||
explicit Resource(const py::object &obj = py::none());
|
||||
|
@ -118,7 +85,6 @@ class Resource : public ResourceBase {
|
|||
// ExecutorPy::Compile() can be called multiple times, so cache
|
||||
// should be cleared.
|
||||
void Clean();
|
||||
static MemoryCleaner &mem_cleaner() { return mem_cleaner_; }
|
||||
|
||||
private:
|
||||
abstract::AnalysisEnginePtr engine_;
|
||||
|
@ -128,8 +94,6 @@ class Resource : public ResourceBase {
|
|||
bool is_cleaned_;
|
||||
bool gpu_loopsink_flag_{false};
|
||||
int64_t gpu_loopsink_size_{1};
|
||||
// Used to handle mem leak objects.
|
||||
static MemoryCleaner mem_cleaner_;
|
||||
};
|
||||
|
||||
using ResourcePtr = std::shared_ptr<pipeline::Resource>;
|
||||
|
|
|
@ -56,7 +56,6 @@ struct OpExecInfo {
|
|||
AbstractBasePtr abstract;
|
||||
|
||||
py::list op_inputs;
|
||||
py::dict op_attrs;
|
||||
std::vector<int64_t> inputs_mask;
|
||||
bool is_dynamic_shape = false;
|
||||
std::string next_op_name = "";
|
||||
|
|
|
@ -689,13 +689,18 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
|
|||
}
|
||||
grad()->op_index_map()[op_name]++;
|
||||
}
|
||||
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto adapter = py::cast<PrimitivePyAdapterPtr>(args[PY_PRIM]);
|
||||
MS_EXCEPTION_IF_NULL(adapter);
|
||||
auto prim = adapter->attached_primitive();
|
||||
if (prim == nullptr) {
|
||||
prim = std::make_shared<PrimitivePy>(args[PY_PRIM], adapter);
|
||||
adapter->set_attached_primitive(prim);
|
||||
}
|
||||
|
||||
if (!prim->HasPyObj()) {
|
||||
MS_LOG(EXCEPTION) << "Pyobj is empty";
|
||||
}
|
||||
op_exec_info->py_primitive = prim;
|
||||
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
|
||||
op_exec_info->op_inputs = args[PY_INPUTS];
|
||||
return op_exec_info;
|
||||
}
|
||||
|
@ -3264,10 +3269,7 @@ void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
|
|||
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
|
||||
MS_LOG(DEBUG) << "Enter end graph process.";
|
||||
py::object *ret = nullptr;
|
||||
auto &mem_cleaner = pipeline::Resource::mem_cleaner();
|
||||
mem_cleaner.EnterPynativeEndGraphProcess();
|
||||
PynativeExecutorTry(grad_executor()->LinkGraph, ret, cell, out, args);
|
||||
mem_cleaner.LeavePynativeEndGraphProcess();
|
||||
MS_LOG(DEBUG) << "Leave end graph process.";
|
||||
}
|
||||
|
||||
|
@ -3289,7 +3291,6 @@ void PynativeExecutor::EnterConstruct(const py::object &cell) {
|
|||
return;
|
||||
}
|
||||
py_top_cell_ = cell.ptr();
|
||||
pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess();
|
||||
MS_LOG(DEBUG) << "Enter construct process.";
|
||||
}
|
||||
|
||||
|
@ -3298,7 +3299,6 @@ void PynativeExecutor::LeaveConstruct(const py::object &cell) {
|
|||
return;
|
||||
}
|
||||
py_top_cell_ = nullptr;
|
||||
pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess();
|
||||
MS_LOG(DEBUG) << "Leave construct process.";
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <mutex>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "ir/signature.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
|
@ -57,22 +58,21 @@ void SyncData(const py::object &arg) {
|
|||
} // namespace
|
||||
std::map<std::string, py::object> PrimitivePy::hook_grad_;
|
||||
|
||||
PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj)
|
||||
: Primitive(name, false), python_obj_(python_obj), signatures_() {
|
||||
auto &mem_cleaner = pipeline::Resource::mem_cleaner();
|
||||
mem_cleaner.RecordPrimitivePy(this);
|
||||
MS_LOG(DEBUG) << "New primitive:" << name;
|
||||
if (mem_cleaner.IsInPynativeConstructProcess() && !mem_cleaner.IsInPynativeEndGraphProcess()) {
|
||||
mem_cleaner.RecordPynativeShortLifePrimitivePy(this);
|
||||
}
|
||||
PrimitivePy::PrimitivePy(const std::string &name) : Primitive(name, false), python_obj_(py::none()) {}
|
||||
|
||||
PrimitivePy::PrimitivePy(const py::object &python_obj, const PrimitivePyAdapterPtr &adapter)
|
||||
: Primitive(adapter->name_, false), python_obj_(python_obj), adapter_(adapter) {
|
||||
MS_LOG(DEBUG) << "New primitive:" << adapter->name_;
|
||||
set_signatures(adapter->signatures_);
|
||||
Primitive::SetAttrs(adapter->attrs_);
|
||||
Primitive::set_prim_type(adapter->prim_type_);
|
||||
Primitive::set_const_prim(adapter->is_const_prim_);
|
||||
Primitive::set_const_input_indexes(adapter->const_input_indexes_);
|
||||
set_hook(adapter->hook_);
|
||||
set_instance_name(adapter->instance_name_);
|
||||
}
|
||||
PrimitivePy::~PrimitivePy() {
|
||||
// Erase primitive here to set released flag false, to avoid calling released pointer when clear primitives in
|
||||
// resource.
|
||||
pipeline::Resource::mem_cleaner().ReleasePrimitivePyObj(this);
|
||||
MS_LOG(DEBUG) << "Release:" << ToString();
|
||||
}
|
||||
void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; }
|
||||
PrimitivePy::~PrimitivePy() { MS_LOG(DEBUG) << "Release:" << ToString(); }
|
||||
|
||||
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
|
||||
signatures_ = signatures;
|
||||
set_has_signature(!signatures.empty());
|
||||
|
@ -272,29 +272,6 @@ py::function PrimitivePy::GetComputeFunction() const {
|
|||
return vm_fn;
|
||||
}
|
||||
|
||||
void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
|
||||
std::string attr_name = name;
|
||||
ValuePtr converted_ret = nullptr;
|
||||
if (py::isinstance<py::module>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
|
||||
}
|
||||
bool converted = parse::ConvertData(obj, &converted_ret);
|
||||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
|
||||
}
|
||||
if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
|
||||
attr_name = kOpAttrNameReplaceMap[attr_name];
|
||||
}
|
||||
const std::string &prim_name = this->name();
|
||||
CheckAndConvertUtils::ConvertAttrValueToInt(prim_name, attr_name, &converted_ret);
|
||||
(void)this->AddAttr(attr_name, converted_ret);
|
||||
}
|
||||
|
||||
void PrimitivePy::DelPyAttr(const py::str &name) {
|
||||
std::string attr_name = name;
|
||||
(void)this->DelAttr(attr_name);
|
||||
}
|
||||
|
||||
py::dict PrimitivePy::GetAttrDict() {
|
||||
py::dict attr_dict;
|
||||
for (auto &attr : attrs_) {
|
||||
|
@ -338,9 +315,11 @@ bool PrimitivePy::HasComputeFunction() const {
|
|||
|
||||
PrimitivePtr PrimitivePy::Clone() {
|
||||
auto clone_fn = python_obj_.attr("_clone");
|
||||
py::object new_obj = clone_fn();
|
||||
auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
|
||||
return cloned_prim;
|
||||
py::object obj_adapter = clone_fn();
|
||||
auto prim_adapter = obj_adapter.cast<PrimitivePyAdapterPtr>();
|
||||
auto prim = std::make_shared<PrimitivePy>(obj_adapter, prim_adapter);
|
||||
prim_adapter->set_attached_primitive(prim);
|
||||
return prim;
|
||||
}
|
||||
|
||||
py::dict PrimitivePy::RunInfer(const py::tuple &args) {
|
||||
|
@ -379,6 +358,113 @@ py::object PrimitivePy::RunInferValue(const py::tuple &args) {
|
|||
return infer_value(*args);
|
||||
}
|
||||
|
||||
PrimitivePyAdapter::PrimitivePyAdapter(const py::str &name) : name_(name) {}
|
||||
|
||||
void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) {
|
||||
std::string attr_name = name;
|
||||
ValuePtr converted_ret = nullptr;
|
||||
if (py::isinstance<py::module>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
|
||||
}
|
||||
bool converted = parse::ConvertData(obj, &converted_ret);
|
||||
if (!converted) {
|
||||
MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
|
||||
}
|
||||
if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
|
||||
attr_name = kOpAttrNameReplaceMap[attr_name];
|
||||
}
|
||||
CheckAndConvertUtils::ConvertAttrValueToInt(name_, name, &converted_ret);
|
||||
auto prim = attached_primitive_.lock();
|
||||
if (prim != nullptr) {
|
||||
prim->AddAttr(attr_name, converted_ret);
|
||||
} else {
|
||||
attrs_[attr_name] = converted_ret;
|
||||
}
|
||||
}
|
||||
|
||||
void PrimitivePyAdapter::DelPyAttr(const py::str &name) {
|
||||
std::string attr_name = name;
|
||||
auto prim = attached_primitive_.lock();
|
||||
if (prim != nullptr) {
|
||||
prim->DelAttr(attr_name);
|
||||
} else {
|
||||
attrs_.erase(attr_name);
|
||||
}
|
||||
}
|
||||
|
||||
py::dict PrimitivePyAdapter::GetAttrDict() {
|
||||
auto prim = attached_primitive_.lock();
|
||||
if (prim != nullptr) {
|
||||
return prim->GetAttrDict();
|
||||
}
|
||||
|
||||
py::dict attr_dict;
|
||||
for (auto &attr : attrs_) {
|
||||
attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
|
||||
}
|
||||
return attr_dict;
|
||||
}
|
||||
|
||||
void PrimitivePyAdapter::set_prim_type(const PrimType t) {
|
||||
auto prim = attached_primitive_.lock();
|
||||
if (prim != nullptr) {
|
||||
prim->set_prim_type(t);
|
||||
} else {
|
||||
prim_type_ = t;
|
||||
}
|
||||
}
|
||||
void PrimitivePyAdapter::set_const_prim(bool is_const_prim) {
|
||||
auto prim = attached_primitive_.lock();
|
||||
if (prim != nullptr) {
|
||||
prim->set_const_prim(is_const_prim);
|
||||
} else {
|
||||
is_const_prim_ = is_const_prim;
|
||||
}
|
||||
}
|
||||
void PrimitivePyAdapter::set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
|
||||
auto prim = attached_primitive_.lock();
|
||||
if (prim != nullptr) {
|
||||
prim->set_const_input_indexes(const_input_indexes);
|
||||
} else {
|
||||
const_input_indexes_ = const_input_indexes;
|
||||
}
|
||||
}
|
||||
|
||||
void PrimitivePyAdapter::set_signatures(const std::vector<Signature> &signatures) {
|
||||
auto prim = attached_primitive_.lock();
|
||||
if (prim != nullptr) {
|
||||
prim->set_signatures(signatures);
|
||||
} else {
|
||||
signatures_ = signatures;
|
||||
}
|
||||
}
|
||||
|
||||
void PrimitivePyAdapter::set_hook(const py::function &hook) {
|
||||
auto prim = attached_primitive_.lock();
|
||||
if (prim != nullptr) {
|
||||
prim->set_hook(hook);
|
||||
} else {
|
||||
hook_ = hook;
|
||||
}
|
||||
}
|
||||
|
||||
void PrimitivePyAdapter::set_instance_name(const std::string &s) {
|
||||
auto prim = attached_primitive_.lock();
|
||||
if (prim != nullptr) {
|
||||
prim->set_instance_name(s);
|
||||
} else {
|
||||
instance_name_ = s;
|
||||
}
|
||||
}
|
||||
|
||||
void PrimitivePyAdapter::set_attached_primitive(const PrimitivePyPtr &prim) {
|
||||
if (attached_primitive_.lock() != nullptr) {
|
||||
MS_LOG(EXCEPTION) << "PrimitivePyAdapter can't attach to multi Primitive.";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
attached_primitive_ = prim;
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
|
||||
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
|
||||
.value("unknown", PrimType::kPrimTypeUnknown)
|
||||
|
@ -386,18 +472,20 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
|
|||
.value("py_infer_shape", PrimType::kPrimTypePyInferShape)
|
||||
.value("user_custom", PrimType::kPrimTypeUserCustom)
|
||||
.value("py_infer_check", PrimType::kPrimTypePyInferCheck);
|
||||
(void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
|
||||
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
|
||||
.def(py::init<py::str &, py::object>())
|
||||
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
|
||||
.def("del_attr", &PrimitivePy::DelPyAttr, "del primitive attr")
|
||||
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
|
||||
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
|
||||
.def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.")
|
||||
.def("set_const_input_indexes", &PrimitivePy::set_const_input_indexes,
|
||||
(void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_")
|
||||
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_)
|
||||
.def(py::init<py::str &>())
|
||||
.def("add_attr", &PrimitivePyAdapter::AddPyAttr, "add primitive attr")
|
||||
.def("del_attr", &PrimitivePyAdapter::DelPyAttr, "del primitive attr")
|
||||
.def("get_attr_dict", &PrimitivePyAdapter::GetAttrDict, "get primitive attr")
|
||||
.def("set_prim_type", &PrimitivePyAdapter::set_prim_type, "Set primitive type.")
|
||||
.def("set_const_prim", &PrimitivePyAdapter::set_const_prim, "Set primitive is const.")
|
||||
.def("set_const_input_indexes", &PrimitivePyAdapter::set_const_input_indexes,
|
||||
"Set primitive const input indexes.")
|
||||
.def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
|
||||
.def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
|
||||
.def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
|
||||
.def("set_signatures", &PrimitivePyAdapter::set_signatures,
|
||||
"Set primitive inputs signature.")
|
||||
.def("register_hook", &PrimitivePyAdapter::set_hook, "Set primitive hook function.")
|
||||
.def("set_instance_name", &PrimitivePyAdapter::set_instance_name,
|
||||
"Set primitive instance name.");
|
||||
}));
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,9 +34,18 @@
|
|||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
|
||||
class PrimitivePy;
|
||||
using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
|
||||
using PrimitivePyWeakPtr = std::weak_ptr<PrimitivePy>;
|
||||
|
||||
class PrimitivePyAdapter;
|
||||
using PrimitivePyAdapterPtr = std::shared_ptr<PrimitivePyAdapter>;
|
||||
|
||||
class PrimitivePy : public Primitive {
|
||||
public:
|
||||
PrimitivePy(const py::str &name, const py::object &python_obj);
|
||||
explicit PrimitivePy(const std::string &name);
|
||||
PrimitivePy(const py::object &python_obj, const PrimitivePyAdapterPtr &adapter);
|
||||
~PrimitivePy() override;
|
||||
MS_DECLARE_PARENT(PrimitivePy, Primitive);
|
||||
py::function GetBpropFunction();
|
||||
|
@ -47,10 +56,6 @@ class PrimitivePy : public Primitive {
|
|||
|
||||
void CopyHookFunction(const PrimitivePtr &primitive) override;
|
||||
|
||||
void AddPyAttr(const py::str &name, const py::object &obj);
|
||||
|
||||
void DelPyAttr(const py::str &name);
|
||||
|
||||
py::dict GetAttrDict();
|
||||
void set_hook(const py::function &hook) { hook_ = hook; }
|
||||
py::function hook() const { return hook_; }
|
||||
|
@ -61,13 +66,13 @@ class PrimitivePy : public Primitive {
|
|||
bool HasComputeFunction() const;
|
||||
const bool parse_info_ = true;
|
||||
const py::object &GetPyObj() const { return python_obj_; }
|
||||
void SetPyObj(const py::object &obj);
|
||||
py::dict RunInfer(const py::tuple &args);
|
||||
void RunCheck(const py::tuple &args);
|
||||
py::object RunInferValue(const py::tuple &args);
|
||||
bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); }
|
||||
bool HasPyObj() { return python_obj_.operator bool(); }
|
||||
PrimitivePtr Clone() override;
|
||||
PrimitivePyAdapterPtr adapter() const { return adapter_; }
|
||||
bool is_tuple_input_ = false;
|
||||
|
||||
private:
|
||||
|
@ -75,11 +80,41 @@ class PrimitivePy : public Primitive {
|
|||
void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const;
|
||||
void CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const;
|
||||
py::object python_obj_;
|
||||
PrimitivePyAdapterPtr adapter_;
|
||||
py::function hook_;
|
||||
std::vector<Signature> signatures_;
|
||||
static std::map<std::string, py::object> hook_grad_;
|
||||
};
|
||||
|
||||
using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
|
||||
class PrimitivePyAdapter {
|
||||
public:
|
||||
explicit PrimitivePyAdapter(const py::str &name);
|
||||
~PrimitivePyAdapter() = default;
|
||||
void AddPyAttr(const py::str &name, const py::object &obj);
|
||||
void DelPyAttr(const py::str &name);
|
||||
py::dict GetAttrDict();
|
||||
void set_prim_type(const PrimType t);
|
||||
void set_const_prim(bool is_const_prim);
|
||||
void set_const_input_indexes(const std::vector<size_t> &const_input_indexes);
|
||||
void set_signatures(const std::vector<Signature> &signatures);
|
||||
void set_hook(const py::function &hook);
|
||||
void set_instance_name(const std::string &s);
|
||||
void set_attached_primitive(const PrimitivePyPtr &prim);
|
||||
PrimitivePyPtr attached_primitive() { return attached_primitive_.lock(); }
|
||||
void set_name(const std::string &name) { name_ = name; }
|
||||
const bool parse_info_ = true;
|
||||
|
||||
private:
|
||||
friend PrimitivePy;
|
||||
std::string name_;
|
||||
PrimitivePyWeakPtr attached_primitive_;
|
||||
std::unordered_map<std::string, ValuePtr> attrs_;
|
||||
PrimType prim_type_{kPrimTypeBuiltIn};
|
||||
bool is_const_prim_{false};
|
||||
std::vector<size_t> const_input_indexes_;
|
||||
std::vector<Signature> signatures_;
|
||||
py::function hook_;
|
||||
std::string instance_name_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_
|
||||
|
|
|
@ -99,7 +99,7 @@ class Primitive : public Named {
|
|||
}
|
||||
void set_prim_type(const PrimType t) { prim_type_ = t; }
|
||||
virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); }
|
||||
void set_instance_name(const std::string s) { instance_name_ = s; }
|
||||
void set_instance_name(const std::string &s) { instance_name_ = s; }
|
||||
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
|
||||
bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
|
||||
bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
|
||||
|
|
|
@ -50,7 +50,7 @@ class Primitive(Primitive_):
|
|||
self.attrs = {}
|
||||
self.init_attrs = {"name": name}
|
||||
self._update_parameter = False
|
||||
Primitive_.__init__(self, name, self)
|
||||
Primitive_.__init__(self, name)
|
||||
if hasattr(self.__class__, '__mindspore_signature__'):
|
||||
out = self._fill_signature(self.__class__.__mindspore_signature__)
|
||||
self.set_signatures(out)
|
||||
|
|
|
@ -94,13 +94,13 @@ TEST_F(TestCompileSegmentRunner, test_if) {
|
|||
|
||||
TEST_F(TestCompileSegmentRunner, test_RunOperation1) {
|
||||
VectorRef args({1});
|
||||
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimIdentity->name()), py::none()), args);
|
||||
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimIdentity->name())), args);
|
||||
ASSERT_EQ(py::cast<int>(BaseRefToPyData(res)), 1);
|
||||
}
|
||||
|
||||
TEST_F(TestCompileSegmentRunner, test_RunOperation2) {
|
||||
VectorRef args({1, 2});
|
||||
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimScalarGt->name()), py::none()), args);
|
||||
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimScalarGt->name())), args);
|
||||
ASSERT_EQ(py::cast<bool>(BaseRefToPyData(res)), false);
|
||||
}
|
||||
} // namespace compile
|
||||
|
|
Loading…
Reference in New Issue