Optimize infer in pynative mode

kpy 2020-07-13 14:45:12 +08:00 committed by kuangpeiyu
parent 61639d9020
commit a09565389c
13 changed files with 236 additions and 145 deletions

View File

@@ -351,7 +351,7 @@ bool ExecuteAction(const ResourcePtr &res) {
}
auto graph_id = res->results()[kOutput].cast<GraphId>();
std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr);
compile::MsBackend *msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr).get();
MS_EXCEPTION_IF_NULL(msbc_ptr);
compile::VmEvalFuncPtr run =
std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef {

View File

@@ -205,6 +205,7 @@ Resource::Resource(const py::object &obj)
Resource::~Resource() {
MS_LOG(DEBUG) << "Resource clear";
std::unordered_map<std::string, Any>().swap(results_);
// If exiting normally, these global variables will be cleaned
// in the Resource::Clean call by MsPipeline::Compile, but if exiting with MS_LOGEXCEPTION,
// these global variables may not be cleaned; it may

View File

@@ -54,12 +54,12 @@ struct OpExecInfo {
AbstractBasePtr abstract;
ValuePtr value = nullptr;
py::tuple op_inputs;
py::tuple inputs_mask;
py::list op_inputs;
py::dict op_attrs;
std::vector<bool> inputs_mask;
};
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args);
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
} // namespace pynative

View File

@@ -179,8 +179,10 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
if (!has_int && !py::isinstance<py::bool_>(py_args[index]) && py::isinstance<py::int_>(py_args[index])) {
has_int = true;
}
if (py::isinstance<tensor::Tensor>(py_args[index])) {
auto arg = py::cast<tensor::TensorPtr>(py_args[index]);
auto obj = py_args[index];
if (py::isinstance<tensor::Tensor>(obj)) {
auto arg = py::cast<tensor::TensorPtr>(obj);
TypeId arg_type_id = arg->data_type();
auto type_priority = prim::type_map.find(arg_type_id);
if (type_priority == prim::type_map.end()) {
@@ -230,24 +232,19 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) {
return RunOp(args)[0];
}
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args,
py::list *const out_args_list) {
auto &py_args = *out_args;
py::tuple input_mask(args.size());
for (size_t i = 0; i < args.size(); ++i) {
input_mask[i] = py::hasattr(args[i], "__parameter__");
py_args[i] = args[i];
}
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExecInfoPtr &op_exec_info) {
auto &out_args = op_exec_info->op_inputs;
auto signature = prim->signatures();
std::vector<SignatureEnumDType> dtypes;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
[](const Signature &sig) { return sig.dtype; });
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) {
return input_mask;
return;
}
auto type_indexes = GetTypeIndex(dtypes);
auto dst_type = GetDstType(py_args, type_indexes);
auto dst_type = GetDstType(out_args, type_indexes);
for (size_t i = 0; i < dtypes.size(); ++i) {
if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
@@ -257,8 +254,10 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
if (it == dst_type.end() || it->second == kTypeUnknown) {
continue;
}
if (py::isinstance<tensor::Tensor>(py_args[i])) {
auto arg = py::cast<tensor::TensorPtr>(py_args[i]);
auto obj = out_args[i];
if (py::isinstance<tensor::Tensor>(obj)) {
auto arg = py::cast<tensor::TensorPtr>(obj);
if (arg->data_type() == it->second) {
continue;
}
@@ -267,32 +266,29 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
TypeIdToMsTypeStr(it->second));
}
}
if (!py::isinstance<tensor::Tensor>(py_args[i]) && !py::isinstance<py::int_>(py_args[i]) &&
!py::isinstance<py::float_>(py_args[i])) {
if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) {
MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i << "th input is an unsupported type: "
<< py::cast<std::string>(py_args[1].attr("__class__").attr("__name__"))
<< ", and the value is " << py::cast<py::str>(py_args[i]) << ".";
<< py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is "
<< py::cast<py::str>(obj) << ".";
}
py::object cast_output = DoAutoCast(py_args[i], it->second);
(*out_args)[i] = cast_output;
(*out_args_list)[i] = cast_output;
py::object cast_output = DoAutoCast(out_args[i], it->second);
out_args[i] = cast_output;
ValuePtr input_value = PyAttrValue(cast_output);
}
return input_mask;
}
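
For reference, ConvertInputs applies implicit type promotion: inputs sharing a signature dtype group are all cast to the group's highest-priority dtype. Below is a minimal standalone sketch of that rule, with stand-in types rather than the real prim::type_map and DoAutoCast:

// Illustrative sketch only (stand-in types, not the MindSpore API).
#include <map>
#include <stdexcept>
#include <vector>

enum class SketchTypeId { kInt32, kFloat16, kFloat32 };

// Higher value = higher promotion priority (assumed ordering for illustration).
static const std::map<SketchTypeId, int> kTypePriority = {
    {SketchTypeId::kInt32, 0}, {SketchTypeId::kFloat16, 1}, {SketchTypeId::kFloat32, 2}};

// Pick the destination dtype for one signature group; every input in the
// group is then cast to the returned dtype.
SketchTypeId GetDstTypeSketch(const std::vector<SketchTypeId> &group) {
  if (group.empty()) throw std::invalid_argument("empty dtype group");
  SketchTypeId dst = group.front();
  for (SketchTypeId t : group) {
    if (kTypePriority.at(t) > kTypePriority.at(dst)) dst = t;
  }
  return dst;
}
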
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
size_t size = py_args.size();
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < size; i++) {
ValuePtr input_value = PyAttrValue(py_args[i]);
args_spec_list.emplace_back(
abstract::FromValueInside(input_value, !prim->ObjHasAttr("const_value") && input_value->isa<tensor::Tensor>()));
}
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info,
const abstract::AbstractBasePtrList &args_spec_list) {
MS_LOG(DEBUG) << "prim " << prim->name() << "input infer" << mindspore::ToString(args_spec_list);
prim->BeginRecordAddAttr();
AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
prim->EndRecordAddAttr();
op_exec_info->abstract = infer_res;
MS_LOG(DEBUG) << "prim " << prim->name() << "infer result " << op_exec_info->abstract->ToString();
}
OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args) {
OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
if (args.size() != PY_ARGS_NUM) {
MS_LOG(ERROR) << "Three args are needed by RunOp";
return nullptr;
@@ -304,26 +300,14 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args)
if (!prim->HasPyObj()) {
MS_LOG(EXCEPTION) << "pyobj is empty";
}
py::list a = args[PY_INPUTS];
size_t input_num = a.size();
op_exec_info->op_inputs = py::tuple(input_num);
op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs, out_args);
// use python infer method
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
}
op_exec_info->py_primitive = prim;
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
auto inst = PynativeExecutor::GetInstance();
if (inst->grad_flag()) {
op_exec_info->value = inst->GetForwardValue(op_exec_info);
}
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
return nullptr;
}
op_exec_info->op_inputs = args[PY_INPUTS];
ConvertInputs(prim, args[PY_INPUTS], op_exec_info);
return op_exec_info;
}
@@ -358,8 +342,9 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
MS_EXCEPTION_IF_NULL(status);
MS_EXCEPTION_IF_NULL(op_exec_info);
MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
auto &op_inputs = op_exec_info->op_inputs;
if (op_exec_info->op_name == "HookBackward") {
auto op_inputs = op_exec_info->op_inputs;
py::tuple result(op_inputs.size());
for (size_t i = 0; i < op_inputs.size(); i++) {
py::object input = op_inputs[i];
@@ -375,7 +360,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
}
auto primitive = op_exec_info->py_primitive;
MS_EXCEPTION_IF_NULL(primitive);
auto result = primitive->RunPyComputeFunction(op_exec_info->op_inputs);
auto result = primitive->RunPyComputeFunction(op_inputs);
if (py::isinstance<py::none>(result)) {
MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func";
*status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
@@ -456,8 +441,9 @@ void ConvertMultiPyObjectToTensor(const py::object &input_object, const Primitiv
if (tuple_inputs.size() == 0) {
MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
}
if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
auto inputs = py::cast<py::tuple>(input_object);
if (py::isinstance<tensor::Tensor>(inputs[0])) {
PlantTensorTupleToVector(inputs, op_prim, input_tensors);
} else {
ConvertValueTupleToTensor(input_object, input_tensors);
*tensor_mask = kValueNodeTensorMask;
@@ -509,10 +495,6 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te
PrimitivePtr op_prim = op_run_info->py_primitive;
MS_EXCEPTION_IF_NULL(op_prim);
if (op_run_info->op_inputs.size() != op_run_info->inputs_mask.size()) {
MS_LOG(EXCEPTION) << "Op input size " << op_run_info->op_inputs.size() << " should be equal to op input mask size "
<< op_run_info->inputs_mask.size();
}
opt::ConstInputToAttrInfoRegister reg;
bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
size_t input_num = op_run_info->op_inputs.size();
@@ -523,7 +505,7 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te
continue;
}
// convert const and tuple input to tensor
int tensor_mask = py::cast<int>(op_run_info->inputs_mask[index]);
int tensor_mask = static_cast<int>(op_run_info->inputs_mask[index]);
ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
// mark tensors, data : 0, weight : 1, valuenode: 2
std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
@@ -550,7 +532,6 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
MS_EXCEPTION_IF_NULL(op_exec_info);
MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_enable_pynative_infer(true);
std::string device_target = ms_context->device_target();
if (device_target != kAscendDevice && device_target != kGPUDevice) {
@@ -573,6 +554,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors);
ms_context->set_enable_pynative_infer(false);
*status = PYNATIVE_SUCCESS;
MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms";
return result;
}
@@ -626,29 +608,65 @@ ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) {
return nullptr;
}
CNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
if (!grad_flag_ || graph_info_map_.empty()) {
return nullptr;
}
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list) {
CNodePtr cnode = nullptr;
std::vector<AnfNodePtr> inputs;
auto prim = op_exec_info->py_primitive;
inputs.push_back(NewValueNode(prim));
py::tuple op_masks = op_exec_info->inputs_mask;
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < args.size(); i++) {
auto node = GetInput(args[i], op_masks[i]);
args_spec_list.push_back(node->abstract());
size_t size = op_exec_info->op_inputs.size();
for (size_t i = 0; i < size; i++) {
auto obj = op_exec_info->op_inputs[i];
bool op_mask = py::hasattr(obj, "__parameter__");
(*op_masks).push_back(op_mask);
MS_LOG(DEBUG) << "gen args i " << i << op_exec_info->op_name << " op mask" << op_mask << "grad_flag_" << grad_flag_;
AnfNodePtr node = nullptr;
abstract::AbstractBasePtr abs = nullptr;
auto id = GetId(obj);
if (node_abs_map_.find(id) != node_abs_map_.end()) {
abs = node_abs_map_[id];
}
if (!graph_info_map_.empty()) {
node = GetInput(obj, op_mask);
}
if (node != nullptr && node->abstract() != nullptr) {
abs = node->abstract();
}
if (abs == nullptr || prim->is_const_value()) {
MS_LOG(DEBUG) << "MakeCnode get node no in map" << id;
ValuePtr input_value = PyAttrValue(obj);
bool broaden = !prim->is_const_value() && input_value->isa<tensor::Tensor>();
abs = abstract::FromValueInside(input_value, broaden);
node_abs_map_[id] = abs;
}
(*args_spec_list).push_back(abs);
inputs.push_back(node);
}
auto cnode = curr_g_->NewCNode(inputs);
MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4);
py::object out_real = out;
if (out.size() == 1) {
MS_LOG(DEBUG) << "MakeCnode out size is one.";
out_real = out[0];
MS_LOG(DEBUG) << "MakeCnode args end";
if (grad_flag_) {
if (curr_g_ != nullptr) {
cnode = curr_g_->NewCNode(inputs);
MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4);
}
}
return cnode;
}
void PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out_real,
const AnfNodePtr &cnode) {
if (!grad_flag_ || graph_info_map_.empty()) {
MS_LOG(DEBUG) << "no graph cnode";
return;
}
std::string obj_id = GetId(out_real);
MS_EXCEPTION_IF_NULL(cnode);
MS_LOG(DEBUG) << "MakeCnode set obj node id " << cnode->DebugString(4) << "id " << obj_id;
if (py::isinstance<py::tuple>(out_real)) {
auto value = py::cast<py::tuple>(out_real);
if (value.size() > 1) {
@@ -659,10 +677,8 @@ CNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py
}
}
}
MS_LOG(DEBUG) << "MakeCnode set node id " << obj_id;
set_obj_node_map(curr_g_, obj_id, cnode);
set_pyobj(curr_g_, obj_id);
return cnode;
}
void PynativeExecutor::SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value) {
@@ -694,22 +710,32 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CN
}
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)];
auto id = GetId(obj);
auto &out = graph_info_map_[curr_g_].obj_node_map[id];
if (out.second.size() == 1 && out.second[0] == -1) {
return out.first;
}
auto node = out.first;
MS_LOG(DEBUG) << "output size " << out.second.size() << node->DebugString();
auto abs = node->abstract();
for (auto &idx : out.second) {
std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
node = curr_g_->NewCNode(tuple_get_item_inputs);
if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
auto prim_abs = dyn_cast<abstract::AbstractTuple>(abs)->elements()[idx];
MS_LOG(DEBUG) << "set tuple getitem abs" << prim_abs->ToString();
node->set_abstract(prim_abs);
}
}
if (node->abstract() != nullptr) {
node_abs_map_[id] = node->abstract();
}
MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6);
node->cast<CNodePtr>()->set_forward(PyAttrValue(obj));
return node;
}
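
The new abstract bookkeeping in GetObjNode can be read as a path lookup into a tuple abstract: each recorded index selects one element abstract from the AbstractTuple, so the emitted TupleGetItem node carries its abstract without re-running inference. A rough standalone sketch of that lookup, with stand-in types:

// Illustrative sketch only (stand-in types, not the MindSpore API): walk an
// index path through nested tuple abstracts so each TupleGetItem node can
// carry the element's abstract without re-inference.
#include <cstddef>
#include <memory>
#include <vector>

struct AbsSketch;
using AbsSketchPtr = std::shared_ptr<AbsSketch>;
struct AbsSketch {
  std::vector<AbsSketchPtr> elements;  // non-empty only for tuple abstracts
};

AbsSketchPtr IndexAbstract(AbsSketchPtr abs, const std::vector<int> &path) {
  for (int idx : path) {
    if (abs == nullptr || idx < 0 || static_cast<std::size_t>(idx) >= abs->elements.size()) {
      return nullptr;  // not a tuple abstract, or index out of range
    }
    abs = abs->elements[idx];
  }
  return abs;
}
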
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
mindspore::parse::python_adapter::set_python_env_flag(true);
MsBackendPolicy backend_policy;
@@ -739,45 +765,89 @@ py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
return err_ret;
}
if (op_exec_info->op_name != prim::kPrimMixedPrecisionCast->name()) {
auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
if (cnode != nullptr) {
cnode->set_abstract(op_exec_info->abstract);
MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString();
}
PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode, result);
MS_LOG(DEBUG) << "RunOp end";
}
MS_LOG(DEBUG) << "RunOp end";
return result;
}
py::tuple RunOpInner(const py::args &args) {
py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
MS_LOG(DEBUG) << "RunOp start" << args.size();
py::list args_input = args[PY_INPUTS];
OpExecInfoPtr op_exec_info = nullptr;
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
auto name = py::cast<std::string>(args[PY_NAME]);
abstract::AbstractBasePtrList args_spec_list;
std::vector<bool> op_masks;
op_exec_info = GenerateOpExecInfo(args);
if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
return RunOpInner(op_exec_info);
}
auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
bool is_find = false;
if (prim_abs_list.find(prim->id()) != prim_abs_list.end()) {
auto abs_list = prim_abs_list[prim->id()];
MS_LOG(DEBUG) << "match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
if (abs_list.find(args_spec_list) != abs_list.end()) {
MS_LOG(DEBUG) << "match prim ok" << op_exec_info->op_name;
op_exec_info->abstract = abs_list[args_spec_list].abs;
prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
is_find = true;
}
}
OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args, &args_input);
if (op_exec_info->abstract == nullptr) {
// use python infer method
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list);
}
}
if (cnode != nullptr) {
cnode->set_abstract(op_exec_info->abstract);
MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString();
}
op_exec_info->inputs_mask = op_masks;
MS_EXCEPTION_IF_NULL(op_exec_info);
if (op_exec_info->abstract != nullptr) {
MS_LOG(DEBUG) << "run op infer" << name << op_exec_info->abstract->ToString();
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
if (!output["value"].is_none()) {
py::tuple value_ret(1);
value_ret[0] = output["value"];
return value_ret;
}
if (op_exec_info->py_primitive->ObjHasAttr("const_value")) {
if (op_exec_info->py_primitive->is_const_value()) {
py::tuple value_ret(1);
value_ret[0] = "";
return value_ret;
}
}
return RunOpInner(op_exec_info, args_input);
if (!is_find) {
// const_value prims need to be inferred at every step
auto &out = prim_abs_list[prim->id()];
out[args_spec_list].abs = op_exec_info->abstract;
out[args_spec_list].attrs = prim->evaluate_added_attrs();
MS_LOG(DEBUG) << "set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
}
auto result = RunOpInner(op_exec_info);
py::object out_real = result;
if (result.size() == 1) {
MS_LOG(DEBUG) << "MakeCnode out size is one.";
out_real = result[0];
}
std::string obj_id = GetId(out_real);
node_abs_map_[obj_id] = op_exec_info->abstract;
PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, out_real, cnode);
if (cnode != nullptr) {
PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode->cast<CNodePtr>(), result);
}
return result;
}
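
The is_find path above is the heart of this optimization: infer results are memoized in prim_abs_list per primitive id, keyed by the abstract signature of the inputs, so repeated steps reuse the cached abstract and evaluate-added attrs instead of calling PynativeInfer again. A minimal standalone sketch of the lookup follows, with string stand-ins for the real AbstractBasePtrList key:

// Illustrative sketch only; names and types are stand-ins, not the MindSpore API.
#include <string>
#include <unordered_map>

struct PrimAbsInfoSketch {
  std::string abs;                                     // stand-in for the inferred abstract
  std::unordered_map<std::string, std::string> attrs;  // evaluate-added attrs
};

// prim id -> (input abstract signature -> cached infer result)
static std::unordered_map<std::string, std::unordered_map<std::string, PrimAbsInfoSketch>>
    g_prim_abs_cache;

// Returns a pointer into the cache on hit, nullptr on miss; on a miss the
// caller runs infer once and stores the result for later steps.
const PrimAbsInfoSketch *LookupInfer(const std::string &prim_id, const std::string &args_sig) {
  auto it = g_prim_abs_cache.find(prim_id);
  if (it == g_prim_abs_cache.end()) return nullptr;
  auto jt = it->second.find(args_sig);
  return jt == it->second.end() ? nullptr : &jt->second;
}
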
py::tuple RunOp(const py::args &args) {
try {
return RunOpInner(args);
return PynativeExecutor::GetInstance()->RunOpInner(args);
} catch (const py::error_already_set &ex) {
// print function call stack info before release
std::ostringstream oss;
@@ -857,11 +927,11 @@ AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::str
return node;
}
AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &op_mask) {
AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
AnfNodePtr node = nullptr;
std::string obj_id = GetId(obj);
if (op_mask != nullptr && py::cast<bool>(op_mask)) {
if (op_mask) {
MS_LOG(DEBUG) << "Topgraph free parameter";
// get the parameter name from parameter object
auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name");
@@ -905,8 +975,9 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o
auto tuple_size = static_cast<int>(tuple.size());
for (int i = 0; i < tuple_size; i++) {
args.push_back(GetInput(tuple[i], py::object()));
args.push_back(GetInput(tuple[i], false));
}
auto cnode = curr_g_->NewCNode(args);
set_obj_node_map(curr_g_, GetId(obj), cnode);
node = cnode;
@@ -960,7 +1031,7 @@ void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &o
auto tuple_size = static_cast<int>(tuple.size());
auto cnode = curr_g_->NewCNode(args);
for (int i = 0; i < tuple_size; i++) {
args.push_back(GetInput(tuple[i], py::object()));
args.push_back(GetInput(tuple[i], false));
set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i);
SetTupleOutput(tuple[i], cnode, std::vector<int>{i});
}
@@ -1000,7 +1071,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
if (curr_g_ != top_g_) {
Popp();
for (size_t i = 0; i < args.size(); i++) {
auto input = GetInput(args[i], py::object());
auto input = GetInput(args[i], false);
inputs.push_back(input);
}
auto out_cnode = curr_g_->NewCNode(inputs);
@@ -1156,6 +1227,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
curr_g_ = nullptr;
graph_info_map_.clear();
op_id_map_.clear();
// node_abs_map_.clear();
std::stack<FuncGraphPtr>().swap(graph_p_);
}

View File

@@ -41,12 +41,20 @@ namespace py = pybind11;
using ResourcePtr = std::shared_ptr<pipeline::Resource>;
using GradOperationPtr = std::shared_ptr<prim::GradOperation>;
struct PrimAbsInfo {
abstract::AbstractBasePtr abs;
std::unordered_map<std::string, ValuePtr> attrs;
};
using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
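
AbstractListMap relies on the standard unordered_map-with-custom-hasher pattern, since AbstractBasePtrList has no default std::hash. A simplified illustration with a string-vector stand-in key:

// Illustrative sketch only; the real AbstractBasePtrListHasher combines the
// hash of every AbstractBasePtr in the list.
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

using SketchKey = std::vector<std::string>;  // stand-in for AbstractBasePtrList

struct SketchKeyHasher {
  std::size_t operator()(const SketchKey &key) const {
    std::size_t hash = key.size();
    for (const auto &item : key) {
      hash = hash * 31 + std::hash<std::string>()(item);
    }
    return hash;
  }
};

struct SketchKeyEqual {
  bool operator()(const SketchKey &lhs, const SketchKey &rhs) const { return lhs == rhs; }
};

using SketchAbstractListMap = std::unordered_map<SketchKey, int, SketchKeyHasher, SketchKeyEqual>;
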
py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
py::tuple RunOp(const py::args &args);
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args,
py::list *const out_args_list);
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, const OpExecInfoPtr &op_exec_info);
void ClearPyNativeSession();
@@ -82,7 +90,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void ClearRes();
bool grad_flag() { return grad_flag_; }
void set_grad_flag(bool flag) { grad_flag_ = flag; }
AnfNodePtr GetInput(const py::object &obj, const py::object &op_mask);
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
AnfNodePtr GetObjNode(const py::object &obj);
FuncGraphPtr curr_g() { return curr_g_; }
void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
@@ -95,11 +103,14 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) {
graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
}
CNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out);
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode);
ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info);
void SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value);
void SaveForwardResult(const CNodePtr &cnode, const py::object &out);
void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);
py::object Run(const py::tuple &args, const py::object &phase);
void Pushp();
@@ -108,6 +119,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
size_t arg_size);
void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector<int> idx);
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
py::tuple RunOpInner(const py::args &args);
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
~PynativeExecutor();
@@ -123,10 +136,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
std::unordered_map<std::string, ValuePtr> op_forward_map_;
std::unordered_map<std::string, size_t> op_id_map_;
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
std::stack<FuncGraphPtr> graph_p_;
FuncGraphPtr top_g_;
FuncGraphPtr df_builder_;
FuncGraphPtr curr_g_;
std::unordered_map<std::string, AbstractListMap> prim_abs_list;
};
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;

View File

@@ -220,6 +220,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
.def("set_is_const_value", &PrimitivePy::set_is_const_value, "Set primitive is const value.")
.def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
.def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
.def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");

View File

@@ -21,6 +21,31 @@
namespace mindspore {
static std::string MakeId() {
// Use atomic to make id generator thread safe.
static std::atomic<uint64_t> last_id{1};
return "P" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed));
}
Primitive::Primitive(const std::string &name, const bool is_base, const PrimType prim_type)
: Named(name),
is_base_(is_base),
has_signature_(false),
prim_type_(prim_type),
record_evaluate_add_attr_(false),
is_const_value_(false),
id_(MakeId()) {}
Primitive::Primitive(const Primitive &prim)
: Named(prim),
attrs_(prim.attrs_),
instance_name_(prim.instance_name_),
is_base_(prim.is_base_),
has_signature_(prim.has_signature_),
prim_type_(prim.prim_type_),
record_evaluate_add_attr_(false),
is_const_value_(prim.is_const_value_),
id_(prim.id_) {}
abstract::AbstractBasePtr Primitive::ToAbstract() {
return std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), nullptr);
}
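
Note that the copy constructor keeps id_(prim.id_), so copies of a Primitive share one cache key in prim_abs_list, while each freshly constructed Primitive draws a new id from the atomic counter. A small self-contained illustration of that id scheme, with a simplified stand-in class:

// Illustrative sketch only (stand-in class, not the MindSpore API): fetch_add
// on the atomic counter is the only synchronization needed for process-unique
// ids, and the copy keeps the source id on purpose so cached infer results
// also apply to the copy.
#include <atomic>
#include <cassert>
#include <cstdint>
#include <string>

static std::string MakeIdSketch() {
  static std::atomic<std::uint64_t> last_id{1};
  return "P" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed));
}

struct PrimSketch {
  std::string id = MakeIdSketch();
  PrimSketch() = default;
  PrimSketch(const PrimSketch &other) : id(other.id) {}  // share the cache key
};

int main() {
  PrimSketch a;
  PrimSketch b = a;
  assert(a.id == b.id);  // copy shares the id, hence the cache entry
  PrimSketch c;
  assert(c.id != a.id);  // a fresh primitive gets a fresh id
  return 0;
}
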

View File

@@ -40,22 +40,8 @@ enum PrimType {
class Primitive : public Named {
public:
explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
: Named(name),
is_base_(is_base),
has_signature_(false),
prim_type_(prim_type),
record_evaluate_add_attr_(false) {}
Primitive(const Primitive &prim)
: Named(prim),
attrs_(prim.attrs_),
instance_name_(prim.instance_name_),
is_base_(prim.is_base_),
has_signature_(prim.has_signature_),
prim_type_(prim.prim_type_),
record_evaluate_add_attr_(false) {}
explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn);
Primitive(const Primitive &prim);
MS_DECLARE_PARENT(Primitive, Named);
abstract::AbstractBasePtr ToAbstract();
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
@@ -91,6 +77,12 @@ class Primitive : public Named {
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; }
void set_evaluate_added_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) {
MS_LOG(INFO) << " set evalu attrl " << name() << attr.first;
attrs_[attr.first] = attr.second;
}
}
// Check whether this Primitive has any attribute. Primitives like scalar_add, return, etc. don't have any attribute.
bool HasAttr() const { return !attrs_.empty(); }
@@ -117,6 +109,9 @@ class Primitive : public Named {
bool is_base() const { return is_base_; }
virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call an empty function!"; }
virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call an empty function!"; }
void set_is_const_value(bool value) { is_const_value_ = value; }
bool is_const_value() const { return is_const_value_; }
std::string id() const { return id_; }
protected:
std::unordered_map<std::string, ValuePtr> attrs_;
@@ -128,6 +123,8 @@ class Primitive : public Named {
bool has_signature_;
PrimType prim_type_;
bool record_evaluate_add_attr_;
bool is_const_value_;
std::string id_{""};
};
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {

View File

@@ -335,7 +335,7 @@ static void PrintTimeStat(std::ostringstream &oss, const TimeInfoGroup &group, c
void MsProfile::Print() {
GetProfile()->Print();
std::vector<std::string> items = {"substitution.", "renormalize.", "replace.", "match.",
"func_graph_cloner_run.", "meta_graph.", "manager."};
"func_graph_cloner_run.", "meta_graph.", "manager.", "pynative"};
std::vector<TimeInfoGroup> groups(items.size() + 1);
const auto &stat = GetSingleton().time_stat_;
// group all time infos

View File

@@ -28,7 +28,7 @@ hastype = Primitive('hastype')
cast = P.Cast()
dtype = P.DType()
isconstant = Primitive('is_constant')
isconstant.add_prim_attr('const_value', True)
isconstant.set_is_const_value(True)
issubclass_ = P.IsSubClass()

View File

@@ -1027,7 +1027,7 @@ class InvertPermutation(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init InvertPermutation"""
self.const_value = True
self.set_is_const_value(True)
def __infer__(self, x):
x_shp = x['shape']

View File

@@ -352,7 +352,7 @@ def constexpr(fn=None, get_instance=True, name=None):
def __init__(self):
op_name = name if name else fn.__name__
PrimitiveWithInfer.__init__(self, op_name)
self.const_value = True
self.set_is_const_value(True)
def infer_value(self, *args):
return fn(*args)

View File

@@ -65,27 +65,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
py::none py_none;
py::args args = py::make_tuple(conv_obj, op_name, op_inputs);
py::list args_input = args[PY_INPUTS];
return GenerateOpExecInfo(args, &args_input);
}
TEST_F(TestPynativeExecute, TestRunOpInVM) {
py::tuple result;
PynativeStatusCode status;
auto op_exec_info_ptr = ConstructOpExecInfo();
result = pynative::RunOpInVM(op_exec_info_ptr, &status);
ASSERT_EQ(status, PYNATIVE_SUCCESS);
}
TEST_F(TestPynativeExecute, TestRunOp) {
py::none py_none;
auto op_exec_info_ptr = ConstructOpExecInfo();
py::tuple outputs = pynative::RunOp(
py::make_tuple(op_exec_info_ptr->py_primitive, op_exec_info_ptr->op_name, op_exec_info_ptr->op_inputs));
if (outputs.size() == 0) {
FAIL();
} else {
SUCCEED();
}
return GenerateOpExecInfo(args);
}
TEST_F(TestPynativeExecute, TestCreateContext) {