forked from mindspore-Ecosystem/mindspore
optimize infer in pynative mode
parent 61639d9020
commit a09565389c
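Summary of the change, as read from the diff below: instead of re-running abstract (shape/type) inference for every executed primitive, PynativeExecutor::RunOpInner now keeps a cache (prim_abs_list) keyed by the primitive's new stable id and the abstract signature of its inputs; on a hit it reuses the cached output abstract and the evaluator-added attributes and skips EvalOnePrim entirely. Supporting changes: OpExecInfo carries its inputs as the original py::list plus a std::vector<bool> parameter mask, ConvertInputs and GenerateOpExecInfo fill the OpExecInfo in place, per-object abstracts are memoized in node_abs_map_, Primitive gains a thread-safe unique id and an is_const_value flag replacing the old const_value attribute, and a "pynative" group is added to the profiler report.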
@@ -351,7 +351,7 @@ bool ExecuteAction(const ResourcePtr &res) {
   }
   auto graph_id = res->results()[kOutput].cast<GraphId>();
   std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
-  std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr);
+  compile::MsBackend *msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr).get();
   MS_EXCEPTION_IF_NULL(msbc_ptr);
   compile::VmEvalFuncPtr run =
     std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef {
@@ -205,6 +205,7 @@ Resource::Resource(const py::object &obj)
 Resource::~Resource() {
   MS_LOG(DEBUG) << "Resource clear";
 
+  std::unordered_map<std::string, Any>().swap(results_);
   // If exit normally, these global variables will be cleaned
   // in Resource::Clean call by MsPipeline::Compile, but if exit with MS_LOGEXCEPTION,
   // these global variables may not being cleaned, it may
@@ -54,12 +54,12 @@ struct OpExecInfo {
   AbstractBasePtr abstract;
   ValuePtr value = nullptr;
 
-  py::tuple op_inputs;
-  py::tuple inputs_mask;
+  py::list op_inputs;
   py::dict op_attrs;
+  std::vector<bool> inputs_mask;
 };
 using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
-OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args);
+OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
 
 const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
 }  // namespace pynative
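OpExecInfo now keeps the inputs as the original py::list handed over from Python instead of copying them into a py::tuple, and the parameter mask becomes a plain std::vector<bool> computed while the CNode is built, so downstream code reads it with static_cast rather than py::cast (see ConstructInputTensor below).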
@@ -179,8 +179,10 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
     if (!has_int && !py::isinstance<py::bool_>(py_args[index]) && py::isinstance<py::int_>(py_args[index])) {
       has_int = true;
     }
-    if (py::isinstance<tensor::Tensor>(py_args[index])) {
-      auto arg = py::cast<tensor::TensorPtr>(py_args[index]);
+
+    auto obj = py_args[index];
+    if (py::isinstance<tensor::Tensor>(obj)) {
+      auto arg = py::cast<tensor::TensorPtr>(obj);
       TypeId arg_type_id = arg->data_type();
       auto type_priority = prim::type_map.find(arg_type_id);
       if (type_priority == prim::type_map.end()) {
@@ -230,24 +232,19 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) {
 
   return RunOp(args)[0];
 }
-py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args,
-                        py::list *const out_args_list) {
-  auto &py_args = *out_args;
-  py::tuple input_mask(args.size());
-  for (size_t i = 0; i < args.size(); ++i) {
-    input_mask[i] = py::hasattr(args[i], "__parameter__");
-    py_args[i] = args[i];
-  }
+
+void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExecInfoPtr &op_exec_info) {
+  auto &out_args = op_exec_info->op_inputs;
   auto signature = prim->signatures();
   std::vector<SignatureEnumDType> dtypes;
   (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
                        [](const Signature &sig) { return sig.dtype; });
   int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
   if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) {
-    return input_mask;
+    return;
   }
   auto type_indexes = GetTypeIndex(dtypes);
-  auto dst_type = GetDstType(py_args, type_indexes);
+  auto dst_type = GetDstType(out_args, type_indexes);
 
   for (size_t i = 0; i < dtypes.size(); ++i) {
     if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
@@ -257,8 +254,10 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
     if (it == dst_type.end() || it->second == kTypeUnknown) {
       continue;
     }
-    if (py::isinstance<tensor::Tensor>(py_args[i])) {
-      auto arg = py::cast<tensor::TensorPtr>(py_args[i]);
+
+    auto obj = out_args[i];
+    if (py::isinstance<tensor::Tensor>(obj)) {
+      auto arg = py::cast<tensor::TensorPtr>(obj);
       if (arg->data_type() == it->second) {
         continue;
       }
@@ -267,32 +266,29 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
                               TypeIdToMsTypeStr(it->second));
       }
     }
-    if (!py::isinstance<tensor::Tensor>(py_args[i]) && !py::isinstance<py::int_>(py_args[i]) &&
-        !py::isinstance<py::float_>(py_args[i])) {
+
+    if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) {
       MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i << "th input is a not support type: "
-                              << py::cast<std::string>(py_args[1].attr("__class__").attr("__name__"))
-                              << ", and the value is " << py::cast<py::str>(py_args[i]) << ".";
+                              << py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is "
+                              << py::cast<py::str>(obj) << ".";
     }
-    py::object cast_output = DoAutoCast(py_args[i], it->second);
-    (*out_args)[i] = cast_output;
-    (*out_args_list)[i] = cast_output;
+    py::object cast_output = DoAutoCast(out_args[i], it->second);
+    out_args[i] = cast_output;
+    ValuePtr input_value = PyAttrValue(cast_output);
   }
-  return input_mask;
 }
 
-void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
-  size_t size = py_args.size();
-  AbstractBasePtrList args_spec_list;
-  for (size_t i = 0; i < size; i++) {
-    ValuePtr input_value = PyAttrValue(py_args[i]);
-    args_spec_list.emplace_back(
-      abstract::FromValueInside(input_value, !prim->ObjHasAttr("const_value") && input_value->isa<tensor::Tensor>()));
-  }
+void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info,
+                   const abstract::AbstractBasePtrList &args_spec_list) {
+  MS_LOG(DEBUG) << "prim " << prim->name() << "input infer" << mindspore::ToString(args_spec_list);
+  prim->BeginRecordAddAttr();
   AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
+  prim->EndRecordAddAttr();
   op_exec_info->abstract = infer_res;
+  MS_LOG(DEBUG) << "prim " << prim->name() << "infer result " << op_exec_info->abstract->ToString();
 }
 
-OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args) {
+OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
   if (args.size() != PY_ARGS_NUM) {
     MS_LOG(ERROR) << "Three args are needed by RunOp";
     return nullptr;
@@ -304,26 +300,14 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args)
   if (!prim->HasPyObj()) {
     MS_LOG(EXCEPTION) << "pyobj is empty";
   }
 
-  py::list a = args[PY_INPUTS];
-  size_t input_num = a.size();
-  op_exec_info->op_inputs = py::tuple(input_num);
-
-  op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs, out_args);
-  // use python infer method
-  if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
-    PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
-  }
   op_exec_info->py_primitive = prim;
   op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
   auto inst = PynativeExecutor::GetInstance();
   if (inst->grad_flag()) {
     op_exec_info->value = inst->GetForwardValue(op_exec_info);
   }
-  if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
-    MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
-    return nullptr;
-  }
+  op_exec_info->op_inputs = args[PY_INPUTS];
+  ConvertInputs(prim, args[PY_INPUTS], op_exec_info);
   return op_exec_info;
 }
@@ -358,8 +342,9 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
   MS_EXCEPTION_IF_NULL(status);
   MS_EXCEPTION_IF_NULL(op_exec_info);
   MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
+
+  auto &op_inputs = op_exec_info->op_inputs;
   if (op_exec_info->op_name == "HookBackward") {
-    auto op_inputs = op_exec_info->op_inputs;
     py::tuple result(op_inputs.size());
     for (size_t i = 0; i < op_inputs.size(); i++) {
       py::object input = op_inputs[i];
@@ -375,7 +360,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
   }
   auto primitive = op_exec_info->py_primitive;
   MS_EXCEPTION_IF_NULL(primitive);
-  auto result = primitive->RunPyComputeFunction(op_exec_info->op_inputs);
+  auto result = primitive->RunPyComputeFunction(op_inputs);
   if (py::isinstance<py::none>(result)) {
     MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func";
     *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
@@ -456,8 +441,9 @@ void ConvertMultiPyObjectToTensor(const py::object &input_object, const Primitiv
   if (tuple_inputs.size() == 0) {
     MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
   }
-  if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
-    PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
+  auto inputs = py::cast<py::tuple>(input_object);
+  if (py::isinstance<tensor::Tensor>(inputs[0])) {
+    PlantTensorTupleToVector(inputs, op_prim, input_tensors);
   } else {
     ConvertValueTupleToTensor(input_object, input_tensors);
     *tensor_mask = kValueNodeTensorMask;
@@ -509,10 +495,6 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te
   PrimitivePtr op_prim = op_run_info->py_primitive;
   MS_EXCEPTION_IF_NULL(op_prim);
 
-  if (op_run_info->op_inputs.size() != op_run_info->inputs_mask.size()) {
-    MS_LOG(EXCEPTION) << "Op input size " << op_run_info->op_inputs.size() << " should be equal to op input mask size "
-                      << op_run_info->inputs_mask.size();
-  }
   opt::ConstInputToAttrInfoRegister reg;
   bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
   size_t input_num = op_run_info->op_inputs.size();
@@ -523,7 +505,7 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te
       continue;
     }
     // convert const and tuple input to tensor
-    int tensor_mask = py::cast<int>(op_run_info->inputs_mask[index]);
+    int tensor_mask = static_cast<int>(op_run_info->inputs_mask[index]);
     ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
     // mark tensors, data : 0, weight : 1, valuenode: 2
     std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
@@ -550,7 +532,6 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
   MS_EXCEPTION_IF_NULL(op_exec_info);
   MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
-  ms_context->set_enable_pynative_infer(true);
   std::string device_target = ms_context->device_target();
   if (device_target != kAscendDevice && device_target != kGPUDevice) {
@@ -573,6 +554,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
   py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors);
+  ms_context->set_enable_pynative_infer(false);
   *status = PYNATIVE_SUCCESS;
   MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms";
   return result;
 }
 
@@ -626,29 +608,65 @@ ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) {
   return nullptr;
 }
 
-CNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
-  if (!grad_flag_ || graph_info_map_.empty()) {
-    return nullptr;
-  }
+AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
+                                       abstract::AbstractBasePtrList *args_spec_list) {
+  CNodePtr cnode = nullptr;
   std::vector<AnfNodePtr> inputs;
   auto prim = op_exec_info->py_primitive;
   inputs.push_back(NewValueNode(prim));
-  py::tuple op_masks = op_exec_info->inputs_mask;
-  AbstractBasePtrList args_spec_list;
-  for (size_t i = 0; i < args.size(); i++) {
-    auto node = GetInput(args[i], op_masks[i]);
-    args_spec_list.push_back(node->abstract());
+
+  size_t size = op_exec_info->op_inputs.size();
+  for (size_t i = 0; i < size; i++) {
+    auto obj = op_exec_info->op_inputs[i];
+    bool op_mask = py::hasattr(obj, "__parameter__");
+    (*op_masks).push_back(op_mask);
+    MS_LOG(DEBUG) << "gen args i " << i << op_exec_info->op_name << " op mask" << op_mask << "grad_flag_" << grad_flag_;
+
+    AnfNodePtr node = nullptr;
+    abstract::AbstractBasePtr abs = nullptr;
+    auto id = GetId(obj);
+    if (node_abs_map_.find(id) != node_abs_map_.end()) {
+      abs = node_abs_map_[id];
+    }
+    if (!graph_info_map_.empty()) {
+      node = GetInput(obj, op_mask);
+    }
+    if (node != nullptr && node->abstract() != nullptr) {
+      abs = node->abstract();
+    }
+    if (abs == nullptr || prim->is_const_value()) {
+      MS_LOG(DEBUG) << "MakeCnode get node no in map" << id;
+      ValuePtr input_value = PyAttrValue(obj);
+      bool broaden = !prim->is_const_value() && input_value->isa<tensor::Tensor>();
+      abs = abstract::FromValueInside(input_value, broaden);
+      node_abs_map_[id] = abs;
+    }
+    (*args_spec_list).push_back(abs);
     inputs.push_back(node);
   }
 
-  auto cnode = curr_g_->NewCNode(inputs);
-  MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4);
-  py::object out_real = out;
-  if (out.size() == 1) {
-    MS_LOG(DEBUG) << "MakeCnode out size is one.";
-    out_real = out[0];
+  MS_LOG(DEBUG) << "MakeCnode args end";
+  if (grad_flag_) {
+    if (curr_g_ != nullptr) {
+      cnode = curr_g_->NewCNode(inputs);
+      MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4);
+    }
   }
+
+  return cnode;
+}
+
+void PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out_real,
+                                 const AnfNodePtr &cnode) {
+  if (!grad_flag_ || graph_info_map_.empty()) {
+    MS_LOG(DEBUG) << "no graph cnode";
+    return;
+  }
+
   std::string obj_id = GetId(out_real);
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_LOG(DEBUG) << "MakeCnode set obj node id " << cnode->DebugString(4) << "id " << obj_id;
+
   if (py::isinstance<py::tuple>(out_real)) {
     auto value = py::cast<py::tuple>(out_real);
     if (value.size() > 1) {
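MakeCNode is split in two: the first overload always walks the op inputs to collect the parameter mask and the args_spec_list of input abstracts, consulting the node_abs_map_ cache before falling back to FromValueInside, and only allocates a CNode when grad_flag_ is set and a graph is being recorded; the second overload attaches the op's Python output to that CNode afterwards.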
@@ -659,10 +677,8 @@ CNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py
       }
     }
   }
   MS_LOG(DEBUG) << "MakeCnode set node id " << obj_id;
   set_obj_node_map(curr_g_, obj_id, cnode);
   set_pyobj(curr_g_, obj_id);
-  return cnode;
 }
 
 void PynativeExecutor::SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value) {
@@ -694,22 +710,32 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CN
 }
 
 AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
-  auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)];
+  auto id = GetId(obj);
+  auto &out = graph_info_map_[curr_g_].obj_node_map[id];
   if (out.second.size() == 1 && out.second[0] == -1) {
     return out.first;
   }
   auto node = out.first;
   MS_LOG(DEBUG) << "output size " << out.second.size() << node->DebugString();
+  auto abs = node->abstract();
   for (auto &idx : out.second) {
     std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
     node = curr_g_->NewCNode(tuple_get_item_inputs);
+    if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
+      auto prim_abs = dyn_cast<abstract::AbstractTuple>(abs)->elements()[idx];
+      MS_LOG(DEBUG) << "set tuple getitem abs" << prim_abs->ToString();
+      node->set_abstract(prim_abs);
+    }
   }
+  if (node->abstract() != nullptr) {
+    node_abs_map_[id] = node->abstract();
+  }
   MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6);
   node->cast<CNodePtr>()->set_forward(PyAttrValue(obj));
   return node;
 }
 
-py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
+py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
   MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
   mindspore::parse::python_adapter::set_python_env_flag(true);
   MsBackendPolicy backend_policy;
@@ -739,45 +765,89 @@ py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
     return err_ret;
   }
 
-  if (op_exec_info->op_name != prim::kPrimMixedPrecisionCast->name()) {
-    auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
-    if (cnode != nullptr) {
-      cnode->set_abstract(op_exec_info->abstract);
-      MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString();
-    }
-    PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode, result);
-    MS_LOG(DEBUG) << "RunOp end";
-  }
-
+  MS_LOG(DEBUG) << "RunOp end";
   return result;
 }
 
-py::tuple RunOpInner(const py::args &args) {
+py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
   MS_LOG(DEBUG) << "RunOp start" << args.size();
-  py::list args_input = args[PY_INPUTS];
+  OpExecInfoPtr op_exec_info = nullptr;
+  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
+  auto name = py::cast<std::string>(args[PY_NAME]);
+  abstract::AbstractBasePtrList args_spec_list;
+  std::vector<bool> op_masks;
+  op_exec_info = GenerateOpExecInfo(args);
+  if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
+    return RunOpInner(op_exec_info);
+  }
+  auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
+  bool is_find = false;
+  if (prim_abs_list.find(prim->id()) != prim_abs_list.end()) {
+    auto abs_list = prim_abs_list[prim->id()];
+    MS_LOG(DEBUG) << "match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
+    if (abs_list.find(args_spec_list) != abs_list.end()) {
+      MS_LOG(DEBUG) << "match prim ok" << op_exec_info->op_name;
+      op_exec_info->abstract = abs_list[args_spec_list].abs;
+      prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
+      is_find = true;
+    }
+  }
 
-  OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args, &args_input);
+  if (op_exec_info->abstract == nullptr) {
+    // use python infer method
+    if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
+      PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list);
+    }
+  }
+
+  if (cnode != nullptr) {
+    cnode->set_abstract(op_exec_info->abstract);
+    MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString();
+  }
+
+  op_exec_info->inputs_mask = op_masks;
   MS_EXCEPTION_IF_NULL(op_exec_info);
 
   if (op_exec_info->abstract != nullptr) {
     MS_LOG(DEBUG) << "run op infer" << name << op_exec_info->abstract->ToString();
     py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
     if (!output["value"].is_none()) {
       py::tuple value_ret(1);
       value_ret[0] = output["value"];
       return value_ret;
     }
-    if (op_exec_info->py_primitive->ObjHasAttr("const_value")) {
+    if (op_exec_info->py_primitive->is_const_value()) {
       py::tuple value_ret(1);
       value_ret[0] = "";
       return value_ret;
     }
   }
-  return RunOpInner(op_exec_info, args_input);
+
+  if (!is_find) {
+    // const_value need infer every step
+    auto &out = prim_abs_list[prim->id()];
+    out[args_spec_list].abs = op_exec_info->abstract;
+    out[args_spec_list].attrs = prim->evaluate_added_attrs();
+    MS_LOG(DEBUG) << "set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
+  }
+
+  auto result = RunOpInner(op_exec_info);
+  py::object out_real = result;
+  if (result.size() == 1) {
+    MS_LOG(DEBUG) << "MakeCnode out size is one.";
+    out_real = result[0];
+  }
+  std::string obj_id = GetId(out_real);
+  node_abs_map_[obj_id] = op_exec_info->abstract;
+  PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, out_real, cnode);
+  if (cnode != nullptr) {
+    PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode->cast<CNodePtr>(), result);
+  }
+  return result;
 }
 
 py::tuple RunOp(const py::args &args) {
   try {
-    return RunOpInner(args);
+    return PynativeExecutor::GetInstance()->RunOpInner(args);
   } catch (const py::error_already_set &ex) {
     // print function call stack info before release
     std::ostringstream oss;
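The block above is the heart of the optimization: look up prim_abs_list by prim->id(), match the args_spec_list of input abstracts, and only fall back to PynativeInfer on a miss, recording the result for the next step. Below is a minimal, self-contained sketch of that cache pattern; every name in it is a stand-in (the real code keys an unordered_map on AbstractBasePtrList with AbstractBasePtrListHasher/AbstractBasePtrListEqual and stores AbstractBasePtr plus the evaluator-added attrs):

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for PrimAbsInfo: the inferred output abstract plus the attrs
// the evaluator added during inference.
struct AbsInfo {
  std::string abstract;
  std::unordered_map<std::string, std::string> attrs;
};

// Stand-in for AbstractListMap: one entry per distinct input-abstract tuple.
using AbsCache = std::unordered_map<std::string, AbsInfo>;

// Stand-in for prim_abs_list: one cache per primitive id.
static std::unordered_map<std::string, AbsCache> prim_abs_cache;

// Flatten the input abstracts into a lookup key (the real code hashes an
// AbstractBasePtrList directly instead of building a string).
static std::string MakeKey(const std::vector<std::string> &input_abs) {
  std::string key;
  for (const auto &a : input_abs) {
    key += a + ";";
  }
  return key;
}

static AbsInfo InferOp(const std::string &prim_id, const std::vector<std::string> &input_abs) {
  auto &cache = prim_abs_cache[prim_id];
  auto key = MakeKey(input_abs);
  auto it = cache.find(key);
  if (it != cache.end()) {
    return it->second;  // hit: reuse the recorded abstract, skip inference
  }
  AbsInfo out{"Tensor[f32,(2,3)]", {}};  // placeholder for the EvalOnePrim result
  cache.emplace(key, out);               // miss: infer once, record for next step
  return out;
}

int main() {
  InferOp("P1", {"Tensor[f32,(2,3)]", "Tensor[f32,(2,3)]"});  // miss: infers and caches
  InferOp("P1", {"Tensor[f32,(2,3)]", "Tensor[f32,(2,3)]"});  // hit: cached result
  std::cout << "cached signatures for P1: " << prim_abs_cache["P1"].size() << std::endl;
  return 0;
}

Reuse across steps works because input abstracts are broadened (values dropped, shapes and dtypes kept) for tensor inputs, so two calls with the same shapes and dtypes infer the same output; the "const_value need infer every step" comment above marks the exception, where outputs depend on concrete values.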
@@ -857,11 +927,11 @@ AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::str
   return node;
 }
 
-AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &op_mask) {
+AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
   AnfNodePtr node = nullptr;
   std::string obj_id = GetId(obj);
 
-  if (op_mask != nullptr && py::cast<bool>(op_mask)) {
+  if (op_mask) {
     MS_LOG(DEBUG) << "Topgraph free parameter";
     // get the parameter name from parameter object
     auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name");
@@ -905,8 +975,9 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o
 
     auto tuple_size = static_cast<int>(tuple.size());
     for (int i = 0; i < tuple_size; i++) {
-      args.push_back(GetInput(tuple[i], py::object()));
+      args.push_back(GetInput(tuple[i], false));
     }
+
     auto cnode = curr_g_->NewCNode(args);
     set_obj_node_map(curr_g_, GetId(obj), cnode);
     node = cnode;
@@ -960,7 +1031,7 @@ void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &o
     auto tuple_size = static_cast<int>(tuple.size());
     auto cnode = curr_g_->NewCNode(args);
     for (int i = 0; i < tuple_size; i++) {
-      args.push_back(GetInput(tuple[i], py::object()));
+      args.push_back(GetInput(tuple[i], false));
       set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i);
       SetTupleOutput(tuple[i], cnode, std::vector<int>{i});
     }
@@ -1000,7 +1071,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
   if (curr_g_ != top_g_) {
     Popp();
     for (size_t i = 0; i < args.size(); i++) {
      auto input = GetInput(args[i], false);
-      auto input = GetInput(args[i], py::object());
+      auto input = GetInput(args[i], false);
       inputs.push_back(input);
     }
     auto out_cnode = curr_g_->NewCNode(inputs);
@@ -1156,6 +1227,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
     curr_g_ = nullptr;
     graph_info_map_.clear();
     op_id_map_.clear();
+    // node_abs_map_.clear();
     std::stack<FuncGraphPtr>().swap(graph_p_);
   }
 
@@ -41,12 +41,20 @@ namespace py = pybind11;
 using ResourcePtr = std::shared_ptr<pipeline::Resource>;
 using GradOperationPtr = std::shared_ptr<prim::GradOperation>;
 
+struct PrimAbsInfo {
+  abstract::AbstractBasePtr abs;
+  std::unordered_map<std::string, ValuePtr> attrs;
+};
+
+using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
+                                           abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
+
 py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
 
 py::tuple RunOp(const py::args &args);
 
-py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args,
-                        py::list *const out_args_list);
+void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args,
+                   py::list *const out_args_list);
 
 void ClearPyNativeSession();
 
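PrimAbsInfo and AbstractListMap are the storage behind prim_abs_list used by RunOpInner above: one entry per distinct tuple of input abstracts, holding the inferred output abstract together with the attributes the evaluator added while inferring.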
@@ -82,7 +90,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void ClearRes();
   bool grad_flag() { return grad_flag_; }
   void set_grad_flag(bool flag) { grad_flag_ = flag; }
-  AnfNodePtr GetInput(const py::object &obj, const py::object &op_mask);
+  AnfNodePtr GetInput(const py::object &obj, bool op_mask);
   AnfNodePtr GetObjNode(const py::object &obj);
   FuncGraphPtr curr_g() { return curr_g_; }
   void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
@@ -95,11 +103,14 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) {
     graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
   }
-  CNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out);
+  AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
+                       abstract::AbstractBasePtrList *args_spec_list);
+  void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode);
   ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info);
   void SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value);
   void SaveForwardResult(const CNodePtr &cnode, const py::object &out);
   void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);
+
   py::object Run(const py::tuple &args, const py::object &phase);
 
   void Pushp();
@@ -108,6 +119,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
                     size_t arg_size);
   void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector<int> idx);
   AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
+  py::tuple RunOpInner(const py::args &args);
+  py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
 
   ~PynativeExecutor();
 
@@ -123,10 +136,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
   std::unordered_map<std::string, ValuePtr> op_forward_map_;
   std::unordered_map<std::string, size_t> op_id_map_;
+  std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
   std::stack<FuncGraphPtr> graph_p_;
   FuncGraphPtr top_g_;
   FuncGraphPtr df_builder_;
   FuncGraphPtr curr_g_;
+  std::unordered_map<std::string, AbstractListMap> prim_abs_list;
 };
 
 using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
@@ -220,6 +220,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
                            .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
                            .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
                            .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
+                           .def("set_is_const_value", &PrimitivePy::set_is_const_value, "Set primitive is const value.")
                            .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
                            .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
                            .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
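Exposing set_is_const_value through pybind lets Python-defined primitives replace the old const_value attribute (see the ops and composite changes below). Const-value primitives keep concrete values in their input abstracts (no broadening in MakeCNode) and are re-inferred every step, since their outputs depend on the actual input values rather than just shapes and dtypes.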
@@ -21,6 +21,31 @@
 
 namespace mindspore {
 
+static std::string MakeId() {
+  // Use atomic to make id generator thread safe.
+  static std::atomic<uint64_t> last_id{1};
+  return "P" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed));
+}
+
+Primitive::Primitive(const std::string &name, const bool is_base, const PrimType prim_type)
+    : Named(name),
+      is_base_(is_base),
+      has_signature_(false),
+      prim_type_(prim_type),
+      record_evaluate_add_attr_(false),
+      is_const_value_(false),
+      id_(MakeId()) {}
+
+Primitive::Primitive(const Primitive &prim)
+    : Named(prim),
+      attrs_(prim.attrs_),
+      instance_name_(prim.instance_name_),
+      is_base_(prim.is_base_),
+      has_signature_(prim.has_signature_),
+      prim_type_(prim.prim_type_),
+      record_evaluate_add_attr_(false),
+      id_(prim.id_) {}
+
 abstract::AbstractBasePtr Primitive::ToAbstract() {
   return std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), nullptr);
 }
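The id gives each Primitive instance a stable cache key for prim_abs_list; fetch_add on a static std::atomic makes generation thread safe without a lock, and the copy constructor copies id_, so copies of a primitive share the same cache entries.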
@@ -40,22 +40,8 @@ enum PrimType {
 
 class Primitive : public Named {
  public:
-  explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
-      : Named(name),
-        is_base_(is_base),
-        has_signature_(false),
-        prim_type_(prim_type),
-        record_evaluate_add_attr_(false) {}
-
-  Primitive(const Primitive &prim)
-      : Named(prim),
-        attrs_(prim.attrs_),
-        instance_name_(prim.instance_name_),
-        is_base_(prim.is_base_),
-        has_signature_(prim.has_signature_),
-        prim_type_(prim.prim_type_),
-        record_evaluate_add_attr_(false) {}
-
+  explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn);
+  Primitive(const Primitive &prim);
   MS_DECLARE_PARENT(Primitive, Named);
   abstract::AbstractBasePtr ToAbstract();
   abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
@@ -91,6 +77,12 @@ class Primitive : public Named {
 
   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
   const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; }
+  void set_evaluate_added_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
+    for (auto &attr : attrs) {
+      MS_LOG(INFO) << " set evalu attrl " << name() << attr.first;
+      attrs_[attr.first] = attr.second;
+    }
+  }
 
   // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
   bool HasAttr() const { return !attrs_.empty(); }
@@ -117,6 +109,9 @@ class Primitive : public Named {
   bool is_base() const { return is_base_; }
   virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call a empty function!"; }
   virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; }
+  void set_is_const_value(bool value) { is_const_value_ = value; }
+  bool is_const_value() const { return is_const_value_; }
+  std::string id() const { return id_; }
 
  protected:
   std::unordered_map<std::string, ValuePtr> attrs_;
@@ -128,6 +123,8 @@ class Primitive : public Named {
   bool has_signature_;
   PrimType prim_type_;
   bool record_evaluate_add_attr_;
+  bool is_const_value_;
+  std::string id_{""};
 };
 
 inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
@@ -335,7 +335,7 @@ static void PrintTimeStat(std::ostringstream &oss, const TimeInfoGroup &group, c
 void MsProfile::Print() {
   GetProfile()->Print();
   std::vector<std::string> items = {"substitution.", "renormalize.", "replace.", "match.",
-                                    "func_graph_cloner_run.", "meta_graph.", "manager."};
+                                    "func_graph_cloner_run.", "meta_graph.", "manager.", "pynative"};
   std::vector<TimeInfoGroup> groups(items.size() + 1);
   const auto &stat = GetSingleton().time_stat_;
   // group all time infos
@@ -28,7 +28,7 @@ hastype = Primitive('hastype')
 cast = P.Cast()
 dtype = P.DType()
 isconstant = Primitive('is_constant')
-isconstant.add_prim_attr('const_value', True)
+isconstant.set_is_const_value(True)
 
 
 issubclass_ = P.IsSubClass()
@@ -1027,7 +1027,7 @@ class InvertPermutation(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """init InvertPermutation"""
-        self.const_value = True
+        self.set_is_const_value(True)
 
     def __infer__(self, x):
         x_shp = x['shape']
@@ -352,7 +352,7 @@ def constexpr(fn=None, get_instance=True, name=None):
         def __init__(self):
             op_name = name if name else fn.__name__
             PrimitiveWithInfer.__init__(self, op_name)
-            self.const_value = True
+            self.set_is_const_value(True)
 
         def infer_value(self, *args):
             return fn(*args)
@@ -65,27 +65,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
   py::none py_none;
   py::args args = py::make_tuple(conv_obj, op_name, op_inputs);
-  py::list args_input = args[PY_INPUTS];
-  return GenerateOpExecInfo(args, &args_input);
-}
-
-TEST_F(TestPynativeExecute, TestRunOpInVM) {
-  py::tuple result;
-  PynativeStatusCode status;
-  auto op_exec_info_ptr = ConstructOpExecInfo();
-  result = pynative::RunOpInVM(op_exec_info_ptr, &status);
-  ASSERT_EQ(status, PYNATIVE_SUCCESS);
-}
-
-TEST_F(TestPynativeExecute, TestRunOp) {
-  py::none py_none;
-  auto op_exec_info_ptr = ConstructOpExecInfo();
-  py::tuple outputs = pynative::RunOp(
-    py::make_tuple(op_exec_info_ptr->py_primitive, op_exec_info_ptr->op_name, op_exec_info_ptr->op_inputs));
-  if (outputs.size() == 0) {
-    FAIL();
-  } else {
-    SUCCEED();
-  }
+  return GenerateOpExecInfo(args);
 }
 
 TEST_F(TestPynativeExecute, TestCreateContext) {