diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc
index 69c9d4392ed..676b256d9a2 100644
--- a/mindspore/ccsrc/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pynative/pynative_execute.cc
@@ -110,7 +110,40 @@ py::object GetTupleObj(const py::object &obj) {
   return obj_tuple;
 }
 
-py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
+std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
+  std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
+  for (size_t i = 0; i < dtypes.size(); ++i) {
+    auto it = type_indexes.find(dtypes[i]);
+    if (it == type_indexes.end()) {
+      (void)type_indexes.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
+    } else {
+      it->second.push_back(i);
+    }
+  }
+  return type_indexes;
+}
+
+std::map<SignatureEnumDType, size_t> GetDstType(const py::tuple &py_args,
+                                                const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
+  std::map<SignatureEnumDType, size_t> dst_type;
+  for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
+    auto type = it->first;
+    auto indexes = it->second;
+    if (indexes.size() < 2) {
+      continue;
+    }
+    size_t m_index = indexes[0];
+    for (size_t i = 1; i < indexes.size(); ++i) {
+      if (py::isinstance<tensor::Tensor>(py_args[indexes[i]])) {
+        m_index = indexes[i];
+      }
+    }
+    (void)dst_type.insert(std::make_pair(type, m_index));
+  }
+  return dst_type;
+}
+
+py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args) {
   auto &py_args = *out_args;
   py::tuple input_mask(args.size());
   for (size_t i = 0; i < args.size(); ++i) {
@@ -129,30 +162,8 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
   if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
     return input_mask;
   }
-  std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
-  for (size_t i = 0; i < dtypes.size(); ++i) {
-    auto it = type_indexs.find(dtypes[i]);
-    if (it == type_indexs.end()) {
-      (void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
-    } else {
-      it->second.push_back(i);
-    }
-  }
-  std::map<SignatureEnumDType, size_t> dst_type;
-  for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
-    auto type = it->first;
-    auto indexs = it->second;
-    if (indexs.size() < 2) {
-      continue;
-    }
-    size_t m_index = indexs[0];
-    for (size_t i = 1; i < indexs.size(); ++i) {
-      if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) {
-        m_index = indexs[i];
-      }
-    }
-    (void)dst_type.insert(std::make_pair(type, m_index));
-  }
+  auto type_indexes = GetTypeIndex(dtypes);
+  auto dst_type = GetDstType(py_args, type_indexes);
   for (size_t i = 0; i < py_args.size(); ++i) {
     auto it = dst_type.find(dtypes[i]);
     if (it != dst_type.end() && it->second != i &&
@@ -542,28 +553,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
   return curr_g_->NewCNode(tuple_get_item_inputs);
 }
 
-py::tuple RunOp(const py::args &args) {
-  MS_LOG(DEBUG) << "RunOp start" << args.size();
-  py::object result;
-  // returns a null py::tuple on error
-  py::tuple err_ret(0);
-  PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
-
-  OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args);
-  MS_EXCEPTION_IF_NULL(op_exec_info);
-  if (op_exec_info->abstract != nullptr) {
-    py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
-    if (!output["value"].is_none()) {
-      py::tuple value_ret(1);
-      value_ret[0] = output["value"];
-      return value_ret;
-    }
-    if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
-      py::tuple value_ret(1);
-      value_ret[0] = "";
-      return value_ret;
-    }
-  }
-
+py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) {
"RunOp start, op name is: " << op_exec_info->op_name; mindspore::parse::python_adapter::set_python_env_flag(true); MsBackendPolicy backend_policy; @@ -584,7 +574,10 @@ py::tuple RunOp(const py::args &args) { if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) { backend_policy = kMsBackendVmOnly; } - result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status); + PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE; + // returns a null py::tuple on error + py::tuple err_ret(0); + py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status); if (status != PYNATIVE_SUCCESS) { MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name; return err_ret; @@ -599,6 +592,26 @@ py::tuple RunOp(const py::args &args) { return result; } +py::tuple RunOp(const py::args &args) { + MS_LOG(DEBUG) << "RunOp start" << args.size(); + OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args); + MS_EXCEPTION_IF_NULL(op_exec_info); + if (op_exec_info->abstract != nullptr) { + py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); + if (!output["value"].is_none()) { + py::tuple value_ret(1); + value_ret[0] = output["value"]; + return value_ret; + } + if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) { + py::tuple value_ret(1); + value_ret[0] = ""; + return value_ret; + } + } + return RunOp(op_exec_info, args); +} + void ClearPyNativeSession() { session = nullptr; } PynativeExecutor::~PynativeExecutor() { ClearRes(); } @@ -732,7 +745,11 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c return; } } + EndGraphByOutId(out_id, cell, out, args); +} +void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, + const py::args &args) { AnfNodePtr output_node; if (graph_info_map_[curr_g_].param_map.count(out_id)) { output_node = graph_info_map_[curr_g_].param_map[out_id]; @@ -776,27 +793,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c } } -void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, - const py::args &args) { - MS_LOG(INFO) << "GradNet start" << args.size(); - - std::size_t size = args.size(); - auto cell_id = GetId(cell); - if (graph_map_.count(cell_id) != 0) { - MS_LOG(DEBUG) << "GradNet already compiled"; - return; - } - MS_LOG(DEBUG) << "GradNet first compiled"; - std::vector new_params; - for (size_t i = 0; i < size; i++) { - ParameterPtr p = std::make_shared(df_builder_); - new_params.push_back(p); - } - MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size(); - new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end()); - df_builder_->set_parameters(new_params); - resource_->manager()->SetParameters(df_builder_, new_params); - +std::vector PynativeExecutor::GetWeightsArgs(const py::object &weights) { std::vector w_args; if (py::hasattr(weights, "__parameter_tuple__")) { auto tuple = weights.cast(); @@ -821,12 +818,12 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c } else { MS_LOG(EXCEPTION) << "training not paramter_tuple"; } - MS_EXCEPTION_IF_NULL(resource_->func_graph()); - auto g = GradGraph(resource_->func_graph(), grad, w_args, size); - resource_->set_func_graph(g); + return w_args; +} - // get the parameters items and add the value to args_spec +abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) { 
   abstract::AbstractBasePtrList args_spec;
+  std::size_t size = args.size();
   for (std::size_t i = 0; i < size; i++) {
     ValuePtr converted = nullptr;
     bool succ = parse::ConvertData(args[i], &converted);
@@ -852,6 +849,38 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
       param_node->set_abstract(ptr);
     }
   }
+
+  return args_spec;
+}
+
+void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
+                               const py::args &args) {
+  MS_LOG(INFO) << "GradNet start" << args.size();
+
+  std::size_t size = args.size();
+  auto cell_id = GetId(cell);
+  if (graph_map_.count(cell_id) != 0) {
+    MS_LOG(DEBUG) << "GradNet already compiled";
+    return;
+  }
+  MS_LOG(DEBUG) << "GradNet first compiled";
+  std::vector<AnfNodePtr> new_params;
+  for (size_t i = 0; i < size; i++) {
+    ParameterPtr p = std::make_shared<Parameter>(df_builder_);
+    new_params.push_back(p);
+  }
+  MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size();
+  new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
+  df_builder_->set_parameters(new_params);
+  resource_->manager()->SetParameters(df_builder_, new_params);
+
+  std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights);
+  MS_EXCEPTION_IF_NULL(resource_->func_graph());
+  auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
+  resource_->set_func_graph(g);
+
+  // get the parameters items and add the value to args_spec
+  abstract::AbstractBasePtrList args_spec = GetArgsSpec(args);
 
   MS_LOG(DEBUG) << "Args_spec size" << args_spec.size();
   resource_->set_args_spec(args_spec);
diff --git a/mindspore/ccsrc/pynative/pynative_execute.h b/mindspore/ccsrc/pynative/pynative_execute.h
index a0e8b448f4d..d0247b33d9e 100644
--- a/mindspore/ccsrc/pynative/pynative_execute.h
+++ b/mindspore/ccsrc/pynative/pynative_execute.h
@@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
 
 py::tuple RunOp(const py::args &args);
 
-py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
+py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args);
 
 void ClearPyNativeSession();
 
@@ -67,6 +67,9 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   }
   void NewGraph(const py::object &cell, const py::args &args);
   void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
+  void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args);
+  std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights);
+  abstract::AbstractBasePtrList GetArgsSpec(const py::args &args);
   void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
   void Clear(const std::string &flag = "");
   void Clean();
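
Note (not part of the patch): the extracted GetTypeIndex/GetDstType pair is the index-selection step of PyNative implicit type promotion. Argument positions that share a dtype signature category are grouped, and for each group with two or more members the last Tensor argument is chosen as the conversion target, so scalar inputs are later cast toward a tensor's type rather than the reverse. Below is a minimal standalone sketch of that logic, assuming std::string keys in place of SignatureEnumDType and a bool flag in place of the py::isinstance<tensor::Tensor> check; both stand-ins are hypothetical simplifications, not MindSpore's actual types.

#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Group argument positions by their declared dtype category.
std::map<std::string, std::vector<size_t>> GetTypeIndex(const std::vector<std::string> &dtypes) {
  std::map<std::string, std::vector<size_t>> type_indexes;
  for (size_t i = 0; i < dtypes.size(); ++i) {
    type_indexes[dtypes[i]].push_back(i);
  }
  return type_indexes;
}

// For each category shared by two or more arguments, pick the position that
// supplies the destination type: the last Tensor argument wins, mirroring
// the loop in the patch's GetDstType.
std::map<std::string, size_t> GetDstType(const std::vector<bool> &is_tensor,
                                         const std::map<std::string, std::vector<size_t>> &type_indexes) {
  std::map<std::string, size_t> dst_type;
  for (const auto &entry : type_indexes) {
    const auto &indexes = entry.second;
    if (indexes.size() < 2) {
      continue;  // only one argument uses this category: nothing to unify
    }
    size_t m_index = indexes[0];
    for (size_t i = 1; i < indexes.size(); ++i) {
      if (is_tensor[indexes[i]]) {
        m_index = indexes[i];
      }
    }
    dst_type[entry.first] = m_index;
  }
  return dst_type;
}

int main() {
  // Args 0 and 2 share category "T"; arg 2 is a Tensor, so it becomes the
  // conversion target for arg 0 (a scalar).
  auto groups = GetTypeIndex({"T", "S", "T"});
  auto dst = GetDstType({false, false, true}, groups);
  return dst.at("T") == 2 ? 0 : 1;  // expect index 2
}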
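A second note on the RunOp split: after this patch, the public RunOp(py::args) keeps only the infer-time fast path (return the value ConvertAbstractToPython already produced, or a placeholder when the primitive carries const_value), while the new overload taking OpExecInfoPtr owns backend selection and execution. A compilable sketch of the same two-layer shape, with hypothetical stand-in types (OpExecInfo here is a drastic reduction of the real struct, and py::args is reduced to a string):

#include <iostream>
#include <memory>
#include <optional>
#include <string>

struct OpExecInfo {
  std::string op_name;
  std::optional<std::string> inferred_value;  // set when infer already knows the result
};
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;

OpExecInfoPtr GenerateOpExecInfo(const std::string &args) {
  return std::make_shared<OpExecInfo>(OpExecInfo{args, std::nullopt});
}

// Overload that owns dispatch: in the patch this is where the backend
// policy is chosen and RunOpWithBackendPolicy is called.
std::string RunOp(const OpExecInfoPtr &info, const std::string & /*args*/) {
  std::cout << "RunOp start, op name is: " << info->op_name << "\n";
  return "device_result(" + info->op_name + ")";
}

// Public entry: keeps only the infer-value fast path, then delegates.
std::string RunOp(const std::string &args) {
  OpExecInfoPtr info = GenerateOpExecInfo(args);
  if (info->inferred_value) {
    return *info->inferred_value;  // constant-folded during infer; skip execution
  }
  return RunOp(info, args);
}

int main() { std::cout << RunOp("Add") << "\n"; }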