Pybind11::object && PrimitivePy recycle optimize

This commit is contained in:
zhangzhaoju 2021-04-13 16:33:57 +08:00
parent fe1d6e5a78
commit a1db9c4959
16 changed files with 257 additions and 250 deletions

View File

@ -1997,20 +1997,25 @@ class IrParser {
// restore python function of PrimitivePy from serialized file
py::object py_obj = LoadObject(lexer_.GetTokenText());
PrimitivePyPtr ptr = nullptr;
py::object py_adapter = py_obj;
static auto len = strlen("PrimitivePy::");
bool cloned = false;
if (py::hasattr(py_obj, "__setattr_flag__") && py::hasattr(py_obj, "_clone")) {
auto clone_fn = py_obj.attr("_clone");
py::object new_obj = clone_fn();
ptr = new_obj.cast<PrimitivePyPtr>();
if (ptr == nullptr) {
MS_LOG(EXCEPTION) << "Cast to type 'PrimitivePyPtr' error";
}
} else {
auto len = strlen("PrimitivePy::");
if (id.size() < len) {
return TOK_ERROR;
}
ptr = std::make_shared<PrimitivePy>(id.substr(len), py_obj);
py_adapter = clone_fn();
cloned = true;
} else if (id.size() < len) {
return TOK_ERROR;
}
auto prim_adapter = py_adapter.cast<PrimitivePyAdapterPtr>();
MS_EXCEPTION_IF_NULL(prim_adapter);
if (!cloned) {
prim_adapter->set_name(id.substr(len));
}
PrimitivePyPtr ptr = prim_adapter->attached_primitive();
if (ptr == nullptr) {
ptr = std::make_shared<PrimitivePy>(py_adapter, prim_adapter);
prim_adapter->set_attached_primitive(ptr);
}
*val_ptr = ptr;

View File

@ -366,8 +366,7 @@ FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::Res
auto func_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> outputs;
auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut", py::object());
auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
bprop_cut->CopyHookFunction(prim);
auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));

View File

@ -157,12 +157,10 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<Pattern, std::shared_ptr<Pattern>>(*m, "Pattern").def(py::init<>());
(void)py::class_<OneOf, std::shared_ptr<OneOf>, Pattern>(*m, "OneOf_").def(py::init<vector<PatternPtr>>());
(void)py::class_<Prim, std::shared_ptr<Prim>, Pattern>(*m, "Prim_", py::dynamic_attr())
.def(py::init<vector<PrimitivePyPtr>, string>())
.def(py::init<vector<string>, string>());
.def(py::init<vector<py::object>, string>());
(void)py::class_<Call, std::shared_ptr<Call>, Pattern>(*m, "Call_")
.def(py::init<PatternPtr, vector<PatternPtr>>())
.def(py::init<PrimitivePyPtr, vector<PatternPtr>>())
.def(py::init<string, vector<PatternPtr>>());
.def(py::init<py::object, vector<PatternPtr>>());
(void)py::class_<NoneOf, std::shared_ptr<NoneOf>, Pattern>(*m, "NoneOf_").def(py::init<vector<PatternPtr>>());
(void)py::class_<Any, std::shared_ptr<Any>, Pattern>(*m, "Any").def(py::init<>());
(void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")

View File

@ -87,16 +87,18 @@ class Prim : public Pattern {
public:
Prim() { unique_name_ = std::to_string(g_id_++); }
~Prim() = default;
Prim(vector<PrimitivePyPtr> prims, string name) : primitives_(prims), name_(name) {
Prim(vector<py::object> prim_objs, string name) : name_(name) {
unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
// Default using the first prim to build target
matched_prim_ = primitives_[0];
}
Prim(vector<string> types, string name) : types_(types), name_(name) {
unique_name_ = std::to_string(g_id_++) + "Prim_" + name;
// Make primitives_
for (auto &iter : types) {
primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
for (auto &prim_obj : prim_objs) {
if (py::isinstance<PrimitivePyAdapter>(prim_obj)) {
auto prim_adapter = prim_obj.cast<PrimitivePyAdapterPtr>();
primitives_.push_back(std::make_shared<PrimitivePy>(prim_obj, prim_adapter));
} else if (py::isinstance<py::str>(prim_obj)) {
std::string prim_name = prim_obj.cast<py::str>();
primitives_.push_back(std::make_shared<PrimitivePy>(prim_name));
} else {
MS_LOG(EXCEPTION) << "Parameter of Prim::__init__ must be Primitive_ type or Prim name, please check input.";
}
}
// Default using the first prim to build target
matched_prim_ = primitives_[0];
@ -111,7 +113,6 @@ class Prim : public Pattern {
}
private:
vector<string> types_;
vector<PrimitivePyPtr> primitives_;
string name_;
PrimitivePyPtr matched_prim_{nullptr};
@ -127,16 +128,19 @@ class Call : public Pattern {
unique_name_ = std::to_string(g_id_++) + "Call_" + prim_pattern->unique_name();
inputs_ = inputs;
}
Call(PrimitivePyPtr prim, vector<PatternPtr> inputs) {
prim_ = prim;
Call(py::object prim_obj, vector<PatternPtr> inputs) {
if (py::isinstance<PrimitivePyAdapter>(prim_obj)) {
auto prim_adapter = prim_obj.cast<PrimitivePyAdapterPtr>();
prim_ = std::make_shared<PrimitivePy>(prim_obj, prim_adapter);
} else if (py::isinstance<py::str>(prim_obj)) {
std::string prim_name = prim_obj.cast<py::str>();
prim_ = std::make_shared<PrimitivePy>(prim_name);
} else {
MS_LOG(EXCEPTION) << "Parameter of Call::__init__ must be Primitive_ type or Prim name, please check input.";
}
unique_name_ = std::to_string(g_id_++) + "Call_" + prim_->ToString();
inputs_ = inputs;
}
Call(string prim_str, vector<PatternPtr> inputs) {
prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
unique_name_ = std::to_string(g_id_++) + "CallStr_" + prim_->ToString();
inputs_ = inputs;
}
MS_DECLARE_PARENT(Call, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override;
PrimitivePtr prim_value() { return prim_; }

View File

@ -119,7 +119,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
auto bprop_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> outputs;
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut");
fake_bprop->set_hook(bprop_func);
(void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
outputs.push_back(NewValueNode(fake_bprop));
@ -236,16 +236,21 @@ ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) {
// desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
}
auto primitive = obj.cast<PrimitivePyPtr>();
py::object adapter_obj = obj;
if (py::hasattr(obj, "__setattr_flag__")) {
if (py::hasattr(obj, "_clone")) {
auto clone_fn = obj.attr("_clone");
adapter_obj = clone_fn();
}
}
auto prim_adapter = adapter_obj.cast<PrimitivePyAdapterPtr>();
MS_EXCEPTION_IF_NULL(prim_adapter);
auto primitive = prim_adapter->attached_primitive();
if (primitive == nullptr) {
MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null";
return nullptr;
}
if (py::hasattr(obj, "__setattr_flag__") && py::hasattr(obj, "_clone")) {
auto clone_fn = obj.attr("_clone");
py::object new_obj = clone_fn();
primitive = new_obj.cast<PrimitivePyPtr>();
primitive = std::make_shared<PrimitivePy>(adapter_obj, prim_adapter);
prim_adapter->set_attached_primitive(primitive);
}
if (use_signature) {
return std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
}

View File

@ -371,7 +371,6 @@ void ExecutorPy::DelNetRes(const std::string &id) {
void ExecutorPy::ClearRes() {
MS_LOG(INFO) << "Clean executor resource!";
Resource::mem_cleaner().ClearPrimitivePyPythonObj();
#ifdef ENABLE_DUMP_IR
mindspore::RDR::ClearAll();
#endif
@ -384,7 +383,7 @@ ExecutorPy::~ExecutorPy() {
}
void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weight_node,
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> *fake_quant_table) {
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> *fake_quant_table) {
std::string weight_name;
auto x = root_node->input(1);
if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
@ -437,15 +436,15 @@ void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weig
return;
}
auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
(*fake_quant_table)[weight_name] = std::make_pair(quant_op, fakequant_min_node_name);
(*fake_quant_table)[weight_name] = std::make_pair(quant_op->adapter(), fakequant_min_node_name);
}
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
const std::string &phase_s) {
FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> fake_quant_table;
auto filter = [](const AnfNodePtr &node) {
return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) ||
IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative));
@ -472,7 +471,6 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
}
GetWeightInfo(root_node, weight_node, &fake_quant_table);
}
return fake_quant_table;
}
@ -1162,9 +1160,6 @@ void StartUpProfiling() {
}
void InitPipeline() {
// If previous pipeline exit with exception, memory cleaner's flags maybe unpredictable, so init when a new pipeline
// start.
pipeline::Resource::mem_cleaner().Init();
// set python env flag
mindspore::parse::python_adapter::set_python_env_flag(true);
// Startup profiling before open tsd

View File

@ -105,13 +105,14 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
static void DebugTerminate(bool val) { debugger_terminate_ = val; }
void TerminateDebugger();
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> FetchInfoForQuantExport(const std::string &phase_s);
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> FetchInfoForQuantExport(
const std::string &phase_s);
private:
ExecutorPy();
void ConvertObjectToTensors(const py::dict &dict, std::map<std::string, tensor::TensorPtr> *tensors);
void GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weight_node,
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> *fake_quant_table);
std::map<std::string, std::pair<PrimitivePyAdapterPtr, std::string>> *fake_quant_table);
void GetGeBackendPolicy() const;
// filter some pipeline actions according to phase, e.g. when exporting onnx, it is no need to execute actions after
// 'validate' stage

View File

@ -308,91 +308,5 @@ void Resource::Clean() {
is_cleaned_ = true;
}
void MemoryCleaner::Init() {
pynative_in_construct_process_ = false;
pynative_in_end_graph_process_ = false;
pynative_released_history_.clear();
pynative_new_primtives_squence_.clear();
}
MemoryCleaner Resource::mem_cleaner_ = MemoryCleaner();
void MemoryCleaner::RecordPrimitivePy(PrimitivePy *prim) {
if (prim == nullptr) {
return;
}
all_primitives_[prim] = true;
}
void MemoryCleaner::ReleasePrimitivePyObj(PrimitivePy *prim) {
if (prim == nullptr) {
return;
}
auto it = all_primitives_.find(prim);
if (it == all_primitives_.end()) {
return;
}
// If flag is false,the pointer hased been released, so it can't be visited.
if (!it->second) {
return;
}
all_primitives_[prim] = false;
prim->SetPyObj(py::none());
}
void MemoryCleaner::ClearPrimitivePyPythonObj() {
for (auto &it : all_primitives_) {
if (it.second) {
it.first->SetPyObj(py::none());
}
}
all_primitives_.clear();
}
void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) {
if (prim == nullptr) {
return;
}
if (pynative_short_life_primitives_.find(prim) != pynative_short_life_primitives_.end()) {
return;
}
MS_LOG(DEBUG) << "Record pynative tmp primitive:" << prim->ToString();
pynative_short_life_primitives_.insert(prim);
pynative_new_primtives_squence_.push_back(prim->ToString());
}
void MemoryCleaner::ErasePynativeShortLifePrimitivePy(PrimitivePy *prim) {
if (prim == nullptr) {
return;
}
if (pynative_short_life_primitives_.find(prim) == pynative_short_life_primitives_.end()) {
return;
}
pynative_short_life_primitives_.erase(prim);
MS_LOG(DEBUG) << "Erase pynative tmp primitive:" << prim->ToString();
}
void MemoryCleaner::ClearPynativeShortLifePrimitivePy() {
// If the primitives name sequence never been released before, keep the primtives alive
if (std::find(pynative_released_history_.begin(), pynative_released_history_.end(),
pynative_new_primtives_squence_) == pynative_released_history_.end()) {
pynative_released_history_.push_back(pynative_new_primtives_squence_);
} else {
for (auto &primitive : pynative_short_life_primitives_) {
ReleasePrimitivePyObj(primitive);
}
}
pynative_short_life_primitives_.clear();
pynative_new_primtives_squence_.clear();
}
void MemoryCleaner::EnterPynativeConstructProcess() { pynative_in_construct_process_ = true; }
void MemoryCleaner::LeavePynativeConstructProcess() {
pynative_in_construct_process_ = false;
ClearPynativeShortLifePrimitivePy();
}
bool MemoryCleaner::IsInPynativeConstructProcess() const { return pynative_in_construct_process_; }
void MemoryCleaner::EnterPynativeEndGraphProcess() { pynative_in_end_graph_process_ = true; }
void MemoryCleaner::LeavePynativeEndGraphProcess() { pynative_in_end_graph_process_ = false; }
bool MemoryCleaner::IsInPynativeEndGraphProcess() const { return pynative_in_end_graph_process_; }
} // namespace pipeline
} // namespace mindspore

View File

@ -53,39 +53,6 @@ BuiltInTypeMap &GetMethodMap();
BuiltInTypeMap &GetAttrMap();
class MemoryCleaner {
public:
MemoryCleaner() = default;
~MemoryCleaner() = default;
void Init();
void RecordPrimitivePy(PrimitivePy *prim);
void ReleasePrimitivePyObj(PrimitivePy *prim);
void ClearPrimitivePyPythonObj();
void RecordPynativeShortLifePrimitivePy(PrimitivePy *prim);
void ErasePynativeShortLifePrimitivePy(PrimitivePy *prim);
void ClearPynativeShortLifePrimitivePy();
void EnterPynativeConstructProcess();
void LeavePynativeConstructProcess();
bool IsInPynativeConstructProcess() const;
void EnterPynativeEndGraphProcess();
void LeavePynativeEndGraphProcess();
bool IsInPynativeEndGraphProcess() const;
private:
std::unordered_map<PrimitivePy *, bool> all_primitives_;
// PrimitivePy objects that created in pynative construct process.These primitives should be released after construct
// finished.
std::unordered_set<PrimitivePy *> pynative_short_life_primitives_;
// Sequence of primtive names in one construct process.
std::vector<std::string> pynative_new_primtives_squence_;
std::vector<std::vector<std::string>> pynative_released_history_;
bool pynative_in_construct_process_{false};
bool pynative_in_end_graph_process_{false};
};
class Resource : public ResourceBase {
public:
explicit Resource(const py::object &obj = py::none());
@ -118,7 +85,6 @@ class Resource : public ResourceBase {
// ExecutorPy::Compile() can be called multiple times, so cache
// should be cleared.
void Clean();
static MemoryCleaner &mem_cleaner() { return mem_cleaner_; }
private:
abstract::AnalysisEnginePtr engine_;
@ -128,8 +94,6 @@ class Resource : public ResourceBase {
bool is_cleaned_;
bool gpu_loopsink_flag_{false};
int64_t gpu_loopsink_size_{1};
// Used to handle mem leak objects.
static MemoryCleaner mem_cleaner_;
};
using ResourcePtr = std::shared_ptr<pipeline::Resource>;

View File

@ -56,7 +56,6 @@ struct OpExecInfo {
AbstractBasePtr abstract;
py::list op_inputs;
py::dict op_attrs;
std::vector<int64_t> inputs_mask;
bool is_dynamic_shape = false;
std::string next_op_name = "";

View File

@ -689,13 +689,18 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
}
grad()->op_index_map()[op_name]++;
}
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
MS_EXCEPTION_IF_NULL(prim);
auto adapter = py::cast<PrimitivePyAdapterPtr>(args[PY_PRIM]);
MS_EXCEPTION_IF_NULL(adapter);
auto prim = adapter->attached_primitive();
if (prim == nullptr) {
prim = std::make_shared<PrimitivePy>(args[PY_PRIM], adapter);
adapter->set_attached_primitive(prim);
}
if (!prim->HasPyObj()) {
MS_LOG(EXCEPTION) << "Pyobj is empty";
}
op_exec_info->py_primitive = prim;
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
op_exec_info->op_inputs = args[PY_INPUTS];
return op_exec_info;
}
@ -3264,10 +3269,7 @@ void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
MS_LOG(DEBUG) << "Enter end graph process.";
py::object *ret = nullptr;
auto &mem_cleaner = pipeline::Resource::mem_cleaner();
mem_cleaner.EnterPynativeEndGraphProcess();
PynativeExecutorTry(grad_executor()->LinkGraph, ret, cell, out, args);
mem_cleaner.LeavePynativeEndGraphProcess();
MS_LOG(DEBUG) << "Leave end graph process.";
}
@ -3289,7 +3291,6 @@ void PynativeExecutor::EnterConstruct(const py::object &cell) {
return;
}
py_top_cell_ = cell.ptr();
pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess();
MS_LOG(DEBUG) << "Enter construct process.";
}
@ -3298,7 +3299,6 @@ void PynativeExecutor::LeaveConstruct(const py::object &cell) {
return;
}
py_top_cell_ = nullptr;
pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess();
MS_LOG(DEBUG) << "Leave construct process.";
}

View File

@ -18,6 +18,7 @@
#include <mutex>
#include <map>
#include <utility>
#include "ir/signature.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/python_adapter.h"
@ -57,22 +58,21 @@ void SyncData(const py::object &arg) {
} // namespace
std::map<std::string, py::object> PrimitivePy::hook_grad_;
PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj)
: Primitive(name, false), python_obj_(python_obj), signatures_() {
auto &mem_cleaner = pipeline::Resource::mem_cleaner();
mem_cleaner.RecordPrimitivePy(this);
MS_LOG(DEBUG) << "New primitive:" << name;
if (mem_cleaner.IsInPynativeConstructProcess() && !mem_cleaner.IsInPynativeEndGraphProcess()) {
mem_cleaner.RecordPynativeShortLifePrimitivePy(this);
}
PrimitivePy::PrimitivePy(const std::string &name) : Primitive(name, false), python_obj_(py::none()) {}
PrimitivePy::PrimitivePy(const py::object &python_obj, const PrimitivePyAdapterPtr &adapter)
: Primitive(adapter->name_, false), python_obj_(python_obj), adapter_(adapter) {
MS_LOG(DEBUG) << "New primitive:" << adapter->name_;
set_signatures(adapter->signatures_);
Primitive::SetAttrs(adapter->attrs_);
Primitive::set_prim_type(adapter->prim_type_);
Primitive::set_const_prim(adapter->is_const_prim_);
Primitive::set_const_input_indexes(adapter->const_input_indexes_);
set_hook(adapter->hook_);
set_instance_name(adapter->instance_name_);
}
PrimitivePy::~PrimitivePy() {
// Erase primitive here to set released flag false, to avoid calling released pointer when clear primitives in
// resource.
pipeline::Resource::mem_cleaner().ReleasePrimitivePyObj(this);
MS_LOG(DEBUG) << "Release:" << ToString();
}
void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; }
PrimitivePy::~PrimitivePy() { MS_LOG(DEBUG) << "Release:" << ToString(); }
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
signatures_ = signatures;
set_has_signature(!signatures.empty());
@ -272,29 +272,6 @@ py::function PrimitivePy::GetComputeFunction() const {
return vm_fn;
}
void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
std::string attr_name = name;
ValuePtr converted_ret = nullptr;
if (py::isinstance<py::module>(obj)) {
MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
}
bool converted = parse::ConvertData(obj, &converted_ret);
if (!converted) {
MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
}
if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
attr_name = kOpAttrNameReplaceMap[attr_name];
}
const std::string &prim_name = this->name();
CheckAndConvertUtils::ConvertAttrValueToInt(prim_name, attr_name, &converted_ret);
(void)this->AddAttr(attr_name, converted_ret);
}
void PrimitivePy::DelPyAttr(const py::str &name) {
std::string attr_name = name;
(void)this->DelAttr(attr_name);
}
py::dict PrimitivePy::GetAttrDict() {
py::dict attr_dict;
for (auto &attr : attrs_) {
@ -338,9 +315,11 @@ bool PrimitivePy::HasComputeFunction() const {
PrimitivePtr PrimitivePy::Clone() {
auto clone_fn = python_obj_.attr("_clone");
py::object new_obj = clone_fn();
auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
return cloned_prim;
py::object obj_adapter = clone_fn();
auto prim_adapter = obj_adapter.cast<PrimitivePyAdapterPtr>();
auto prim = std::make_shared<PrimitivePy>(obj_adapter, prim_adapter);
prim_adapter->set_attached_primitive(prim);
return prim;
}
py::dict PrimitivePy::RunInfer(const py::tuple &args) {
@ -379,6 +358,113 @@ py::object PrimitivePy::RunInferValue(const py::tuple &args) {
return infer_value(*args);
}
PrimitivePyAdapter::PrimitivePyAdapter(const py::str &name) : name_(name) {}
void PrimitivePyAdapter::AddPyAttr(const py::str &name, const py::object &obj) {
std::string attr_name = name;
ValuePtr converted_ret = nullptr;
if (py::isinstance<py::module>(obj)) {
MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
}
bool converted = parse::ConvertData(obj, &converted_ret);
if (!converted) {
MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
}
if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
attr_name = kOpAttrNameReplaceMap[attr_name];
}
CheckAndConvertUtils::ConvertAttrValueToInt(name_, name, &converted_ret);
auto prim = attached_primitive_.lock();
if (prim != nullptr) {
prim->AddAttr(attr_name, converted_ret);
} else {
attrs_[attr_name] = converted_ret;
}
}
void PrimitivePyAdapter::DelPyAttr(const py::str &name) {
std::string attr_name = name;
auto prim = attached_primitive_.lock();
if (prim != nullptr) {
prim->DelAttr(attr_name);
} else {
attrs_.erase(attr_name);
}
}
py::dict PrimitivePyAdapter::GetAttrDict() {
auto prim = attached_primitive_.lock();
if (prim != nullptr) {
return prim->GetAttrDict();
}
py::dict attr_dict;
for (auto &attr : attrs_) {
attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
}
return attr_dict;
}
void PrimitivePyAdapter::set_prim_type(const PrimType t) {
auto prim = attached_primitive_.lock();
if (prim != nullptr) {
prim->set_prim_type(t);
} else {
prim_type_ = t;
}
}
void PrimitivePyAdapter::set_const_prim(bool is_const_prim) {
auto prim = attached_primitive_.lock();
if (prim != nullptr) {
prim->set_const_prim(is_const_prim);
} else {
is_const_prim_ = is_const_prim;
}
}
void PrimitivePyAdapter::set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
auto prim = attached_primitive_.lock();
if (prim != nullptr) {
prim->set_const_input_indexes(const_input_indexes);
} else {
const_input_indexes_ = const_input_indexes;
}
}
void PrimitivePyAdapter::set_signatures(const std::vector<Signature> &signatures) {
auto prim = attached_primitive_.lock();
if (prim != nullptr) {
prim->set_signatures(signatures);
} else {
signatures_ = signatures;
}
}
void PrimitivePyAdapter::set_hook(const py::function &hook) {
auto prim = attached_primitive_.lock();
if (prim != nullptr) {
prim->set_hook(hook);
} else {
hook_ = hook;
}
}
void PrimitivePyAdapter::set_instance_name(const std::string &s) {
auto prim = attached_primitive_.lock();
if (prim != nullptr) {
prim->set_instance_name(s);
} else {
instance_name_ = s;
}
}
void PrimitivePyAdapter::set_attached_primitive(const PrimitivePyPtr &prim) {
if (attached_primitive_.lock() != nullptr) {
MS_LOG(EXCEPTION) << "PrimitivePyAdapter can't attach to multi Primitive.";
}
MS_EXCEPTION_IF_NULL(prim);
attached_primitive_ = prim;
}
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
.value("unknown", PrimType::kPrimTypeUnknown)
@ -386,18 +472,20 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
.value("py_infer_shape", PrimType::kPrimTypePyInferShape)
.value("user_custom", PrimType::kPrimTypeUserCustom)
.value("py_infer_check", PrimType::kPrimTypePyInferCheck);
(void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
.def(py::init<py::str &, py::object>())
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
.def("del_attr", &PrimitivePy::DelPyAttr, "del primitive attr")
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
.def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.")
.def("set_const_input_indexes", &PrimitivePy::set_const_input_indexes,
(void)py::class_<PrimitivePyAdapter, std::shared_ptr<PrimitivePyAdapter>>(*m, "Primitive_")
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePyAdapter::parse_info_)
.def(py::init<py::str &>())
.def("add_attr", &PrimitivePyAdapter::AddPyAttr, "add primitive attr")
.def("del_attr", &PrimitivePyAdapter::DelPyAttr, "del primitive attr")
.def("get_attr_dict", &PrimitivePyAdapter::GetAttrDict, "get primitive attr")
.def("set_prim_type", &PrimitivePyAdapter::set_prim_type, "Set primitive type.")
.def("set_const_prim", &PrimitivePyAdapter::set_const_prim, "Set primitive is const.")
.def("set_const_input_indexes", &PrimitivePyAdapter::set_const_input_indexes,
"Set primitive const input indexes.")
.def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
.def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
.def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
.def("set_signatures", &PrimitivePyAdapter::set_signatures,
"Set primitive inputs signature.")
.def("register_hook", &PrimitivePyAdapter::set_hook, "Set primitive hook function.")
.def("set_instance_name", &PrimitivePyAdapter::set_instance_name,
"Set primitive instance name.");
}));
} // namespace mindspore

View File

@ -34,9 +34,18 @@
namespace py = pybind11;
namespace mindspore {
class PrimitivePy;
using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
using PrimitivePyWeakPtr = std::weak_ptr<PrimitivePy>;
class PrimitivePyAdapter;
using PrimitivePyAdapterPtr = std::shared_ptr<PrimitivePyAdapter>;
class PrimitivePy : public Primitive {
public:
PrimitivePy(const py::str &name, const py::object &python_obj);
explicit PrimitivePy(const std::string &name);
PrimitivePy(const py::object &python_obj, const PrimitivePyAdapterPtr &adapter);
~PrimitivePy() override;
MS_DECLARE_PARENT(PrimitivePy, Primitive);
py::function GetBpropFunction();
@ -47,10 +56,6 @@ class PrimitivePy : public Primitive {
void CopyHookFunction(const PrimitivePtr &primitive) override;
void AddPyAttr(const py::str &name, const py::object &obj);
void DelPyAttr(const py::str &name);
py::dict GetAttrDict();
void set_hook(const py::function &hook) { hook_ = hook; }
py::function hook() const { return hook_; }
@ -61,13 +66,13 @@ class PrimitivePy : public Primitive {
bool HasComputeFunction() const;
const bool parse_info_ = true;
const py::object &GetPyObj() const { return python_obj_; }
void SetPyObj(const py::object &obj);
py::dict RunInfer(const py::tuple &args);
void RunCheck(const py::tuple &args);
py::object RunInferValue(const py::tuple &args);
bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); }
bool HasPyObj() { return python_obj_.operator bool(); }
PrimitivePtr Clone() override;
PrimitivePyAdapterPtr adapter() const { return adapter_; }
bool is_tuple_input_ = false;
private:
@ -75,11 +80,41 @@ class PrimitivePy : public Primitive {
void ConvertCTensorToPyTensor(const py::tuple &input_args, py::tuple *convert_args) const;
void CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out) const;
py::object python_obj_;
PrimitivePyAdapterPtr adapter_;
py::function hook_;
std::vector<Signature> signatures_;
static std::map<std::string, py::object> hook_grad_;
};
using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
class PrimitivePyAdapter {
public:
explicit PrimitivePyAdapter(const py::str &name);
~PrimitivePyAdapter() = default;
void AddPyAttr(const py::str &name, const py::object &obj);
void DelPyAttr(const py::str &name);
py::dict GetAttrDict();
void set_prim_type(const PrimType t);
void set_const_prim(bool is_const_prim);
void set_const_input_indexes(const std::vector<size_t> &const_input_indexes);
void set_signatures(const std::vector<Signature> &signatures);
void set_hook(const py::function &hook);
void set_instance_name(const std::string &s);
void set_attached_primitive(const PrimitivePyPtr &prim);
PrimitivePyPtr attached_primitive() { return attached_primitive_.lock(); }
void set_name(const std::string &name) { name_ = name; }
const bool parse_info_ = true;
private:
friend PrimitivePy;
std::string name_;
PrimitivePyWeakPtr attached_primitive_;
std::unordered_map<std::string, ValuePtr> attrs_;
PrimType prim_type_{kPrimTypeBuiltIn};
bool is_const_prim_{false};
std::vector<size_t> const_input_indexes_;
std::vector<Signature> signatures_;
py::function hook_;
std::string instance_name_;
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_PY_H_

View File

@ -99,7 +99,7 @@ class Primitive : public Named {
}
void set_prim_type(const PrimType t) { prim_type_ = t; }
virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); }
void set_instance_name(const std::string s) { instance_name_ = s; }
void set_instance_name(const std::string &s) { instance_name_ = s; }
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }

View File

@ -50,7 +50,7 @@ class Primitive(Primitive_):
self.attrs = {}
self.init_attrs = {"name": name}
self._update_parameter = False
Primitive_.__init__(self, name, self)
Primitive_.__init__(self, name)
if hasattr(self.__class__, '__mindspore_signature__'):
out = self._fill_signature(self.__class__.__mindspore_signature__)
self.set_signatures(out)

View File

@ -94,13 +94,13 @@ TEST_F(TestCompileSegmentRunner, test_if) {
TEST_F(TestCompileSegmentRunner, test_RunOperation1) {
VectorRef args({1});
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimIdentity->name()), py::none()), args);
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimIdentity->name())), args);
ASSERT_EQ(py::cast<int>(BaseRefToPyData(res)), 1);
}
TEST_F(TestCompileSegmentRunner, test_RunOperation2) {
VectorRef args({1, 2});
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimScalarGt->name()), py::none()), args);
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimScalarGt->name())), args);
ASSERT_EQ(py::cast<bool>(BaseRefToPyData(res)), false);
}
} // namespace compile