forked from mindspore-Ecosystem/mindspore
Fix codex and codereview.
This commit is contained in:
parent
8867c67d61
commit
e5c0e052e0
|
@ -110,7 +110,40 @@ py::object GetTupleObj(const py::object &obj) {
|
||||||
return obj_tuple;
|
return obj_tuple;
|
||||||
}
|
}
|
||||||
|
|
||||||
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
|
std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
|
||||||
|
std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
|
||||||
|
for (size_t i = 0; i < dtypes.size(); ++i) {
|
||||||
|
auto it = type_indexes.find(dtypes[i]);
|
||||||
|
if (it == type_indexes.end()) {
|
||||||
|
(void)type_indexes.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
|
||||||
|
} else {
|
||||||
|
it->second.push_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return type_indexes;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<SignatureEnumDType, size_t> GetDstType(const py::tuple &py_args,
|
||||||
|
const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
|
||||||
|
std::map<SignatureEnumDType, size_t> dst_type;
|
||||||
|
for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
|
||||||
|
auto type = it->first;
|
||||||
|
auto indexes = it->second;
|
||||||
|
if (indexes.size() < 2) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
size_t m_index = indexes[0];
|
||||||
|
for (size_t i = 1; i < indexes.size(); ++i) {
|
||||||
|
if (py::isinstance<tensor::Tensor>(py_args[indexes[i]])) {
|
||||||
|
m_index = indexes[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(void)dst_type.insert(std::make_pair(type, m_index));
|
||||||
|
}
|
||||||
|
return dst_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args) {
|
||||||
auto &py_args = *out_args;
|
auto &py_args = *out_args;
|
||||||
py::tuple input_mask(args.size());
|
py::tuple input_mask(args.size());
|
||||||
for (size_t i = 0; i < args.size(); ++i) {
|
for (size_t i = 0; i < args.size(); ++i) {
|
||||||
|
@ -129,30 +162,8 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
|
||||||
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
|
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
|
||||||
return input_mask;
|
return input_mask;
|
||||||
}
|
}
|
||||||
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
auto type_indexes = GetTypeIndex(dtypes);
|
||||||
for (size_t i = 0; i < dtypes.size(); ++i) {
|
auto dst_type = GetDstType(py_args, type_indexes);
|
||||||
auto it = type_indexs.find(dtypes[i]);
|
|
||||||
if (it == type_indexs.end()) {
|
|
||||||
(void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
|
|
||||||
} else {
|
|
||||||
it->second.push_back(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::map<SignatureEnumDType, size_t> dst_type;
|
|
||||||
for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
|
|
||||||
auto type = it->first;
|
|
||||||
auto indexs = it->second;
|
|
||||||
if (indexs.size() < 2) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
size_t m_index = indexs[0];
|
|
||||||
for (size_t i = 1; i < indexs.size(); ++i) {
|
|
||||||
if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) {
|
|
||||||
m_index = indexs[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(void)dst_type.insert(std::make_pair(type, m_index));
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < py_args.size(); ++i) {
|
for (size_t i = 0; i < py_args.size(); ++i) {
|
||||||
auto it = dst_type.find(dtypes[i]);
|
auto it = dst_type.find(dtypes[i]);
|
||||||
if (it != dst_type.end() && it->second != i &&
|
if (it != dst_type.end() && it->second != i &&
|
||||||
|
@ -542,28 +553,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
|
||||||
return curr_g_->NewCNode(tuple_get_item_inputs);
|
return curr_g_->NewCNode(tuple_get_item_inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
py::tuple RunOp(const py::args &args) {
|
py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) {
|
||||||
MS_LOG(DEBUG) << "RunOp start" << args.size();
|
|
||||||
py::object result;
|
|
||||||
// returns a null py::tuple on error
|
|
||||||
py::tuple err_ret(0);
|
|
||||||
PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
|
|
||||||
|
|
||||||
OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args);
|
|
||||||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
|
||||||
if (op_exec_info->abstract != nullptr) {
|
|
||||||
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
|
|
||||||
if (!output["value"].is_none()) {
|
|
||||||
py::tuple value_ret(1);
|
|
||||||
value_ret[0] = output["value"];
|
|
||||||
return value_ret;
|
|
||||||
}
|
|
||||||
if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
|
|
||||||
py::tuple value_ret(1);
|
|
||||||
value_ret[0] = "";
|
|
||||||
return value_ret;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
|
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
|
||||||
mindspore::parse::python_adapter::set_python_env_flag(true);
|
mindspore::parse::python_adapter::set_python_env_flag(true);
|
||||||
MsBackendPolicy backend_policy;
|
MsBackendPolicy backend_policy;
|
||||||
|
@ -584,7 +574,10 @@ py::tuple RunOp(const py::args &args) {
|
||||||
if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
|
if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
|
||||||
backend_policy = kMsBackendVmOnly;
|
backend_policy = kMsBackendVmOnly;
|
||||||
}
|
}
|
||||||
result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
|
PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
|
||||||
|
// returns a null py::tuple on error
|
||||||
|
py::tuple err_ret(0);
|
||||||
|
py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
|
||||||
if (status != PYNATIVE_SUCCESS) {
|
if (status != PYNATIVE_SUCCESS) {
|
||||||
MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name;
|
MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name;
|
||||||
return err_ret;
|
return err_ret;
|
||||||
|
@ -599,6 +592,26 @@ py::tuple RunOp(const py::args &args) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
py::tuple RunOp(const py::args &args) {
|
||||||
|
MS_LOG(DEBUG) << "RunOp start" << args.size();
|
||||||
|
OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args);
|
||||||
|
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||||
|
if (op_exec_info->abstract != nullptr) {
|
||||||
|
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
|
||||||
|
if (!output["value"].is_none()) {
|
||||||
|
py::tuple value_ret(1);
|
||||||
|
value_ret[0] = output["value"];
|
||||||
|
return value_ret;
|
||||||
|
}
|
||||||
|
if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
|
||||||
|
py::tuple value_ret(1);
|
||||||
|
value_ret[0] = "";
|
||||||
|
return value_ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RunOp(op_exec_info, args);
|
||||||
|
}
|
||||||
|
|
||||||
void ClearPyNativeSession() { session = nullptr; }
|
void ClearPyNativeSession() { session = nullptr; }
|
||||||
|
|
||||||
PynativeExecutor::~PynativeExecutor() { ClearRes(); }
|
PynativeExecutor::~PynativeExecutor() { ClearRes(); }
|
||||||
|
@ -732,7 +745,11 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
EndGraphByOutId(out_id, cell, out, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out,
|
||||||
|
const py::args &args) {
|
||||||
AnfNodePtr output_node;
|
AnfNodePtr output_node;
|
||||||
if (graph_info_map_[curr_g_].param_map.count(out_id)) {
|
if (graph_info_map_[curr_g_].param_map.count(out_id)) {
|
||||||
output_node = graph_info_map_[curr_g_].param_map[out_id];
|
output_node = graph_info_map_[curr_g_].param_map[out_id];
|
||||||
|
@ -776,27 +793,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
|
std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights) {
|
||||||
const py::args &args) {
|
|
||||||
MS_LOG(INFO) << "GradNet start" << args.size();
|
|
||||||
|
|
||||||
std::size_t size = args.size();
|
|
||||||
auto cell_id = GetId(cell);
|
|
||||||
if (graph_map_.count(cell_id) != 0) {
|
|
||||||
MS_LOG(DEBUG) << "GradNet already compiled";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "GradNet first compiled";
|
|
||||||
std::vector<AnfNodePtr> new_params;
|
|
||||||
for (size_t i = 0; i < size; i++) {
|
|
||||||
ParameterPtr p = std::make_shared<Parameter>(df_builder_);
|
|
||||||
new_params.push_back(p);
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size();
|
|
||||||
new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
|
|
||||||
df_builder_->set_parameters(new_params);
|
|
||||||
resource_->manager()->SetParameters(df_builder_, new_params);
|
|
||||||
|
|
||||||
std::vector<AnfNodePtr> w_args;
|
std::vector<AnfNodePtr> w_args;
|
||||||
if (py::hasattr(weights, "__parameter_tuple__")) {
|
if (py::hasattr(weights, "__parameter_tuple__")) {
|
||||||
auto tuple = weights.cast<py::tuple>();
|
auto tuple = weights.cast<py::tuple>();
|
||||||
|
@ -821,12 +818,12 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "training not paramter_tuple";
|
MS_LOG(EXCEPTION) << "training not paramter_tuple";
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(resource_->func_graph());
|
return w_args;
|
||||||
auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
|
}
|
||||||
resource_->set_func_graph(g);
|
|
||||||
|
|
||||||
// get the parameters items and add the value to args_spec
|
abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) {
|
||||||
abstract::AbstractBasePtrList args_spec;
|
abstract::AbstractBasePtrList args_spec;
|
||||||
|
std::size_t size = args.size();
|
||||||
for (std::size_t i = 0; i < size; i++) {
|
for (std::size_t i = 0; i < size; i++) {
|
||||||
ValuePtr converted = nullptr;
|
ValuePtr converted = nullptr;
|
||||||
bool succ = parse::ConvertData(args[i], &converted);
|
bool succ = parse::ConvertData(args[i], &converted);
|
||||||
|
@ -852,6 +849,38 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
|
||||||
param_node->set_abstract(ptr);
|
param_node->set_abstract(ptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return args_spec;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
|
||||||
|
const py::args &args) {
|
||||||
|
MS_LOG(INFO) << "GradNet start" << args.size();
|
||||||
|
|
||||||
|
std::size_t size = args.size();
|
||||||
|
auto cell_id = GetId(cell);
|
||||||
|
if (graph_map_.count(cell_id) != 0) {
|
||||||
|
MS_LOG(DEBUG) << "GradNet already compiled";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "GradNet first compiled";
|
||||||
|
std::vector<AnfNodePtr> new_params;
|
||||||
|
for (size_t i = 0; i < size; i++) {
|
||||||
|
ParameterPtr p = std::make_shared<Parameter>(df_builder_);
|
||||||
|
new_params.push_back(p);
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size();
|
||||||
|
new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
|
||||||
|
df_builder_->set_parameters(new_params);
|
||||||
|
resource_->manager()->SetParameters(df_builder_, new_params);
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights);
|
||||||
|
MS_EXCEPTION_IF_NULL(resource_->func_graph());
|
||||||
|
auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
|
||||||
|
resource_->set_func_graph(g);
|
||||||
|
|
||||||
|
// get the parameters items and add the value to args_spec
|
||||||
|
abstract::AbstractBasePtrList args_spec = GetArgsSpec(args);
|
||||||
MS_LOG(DEBUG) << "Args_spec size" << args_spec.size();
|
MS_LOG(DEBUG) << "Args_spec size" << args_spec.size();
|
||||||
|
|
||||||
resource_->set_args_spec(args_spec);
|
resource_->set_args_spec(args_spec);
|
||||||
|
|
|
@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
|
||||||
|
|
||||||
py::tuple RunOp(const py::args &args);
|
py::tuple RunOp(const py::args &args);
|
||||||
|
|
||||||
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
|
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args);
|
||||||
|
|
||||||
void ClearPyNativeSession();
|
void ClearPyNativeSession();
|
||||||
|
|
||||||
|
@ -67,6 +67,9 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
||||||
}
|
}
|
||||||
void NewGraph(const py::object &cell, const py::args &args);
|
void NewGraph(const py::object &cell, const py::args &args);
|
||||||
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
|
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
|
||||||
|
void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args);
|
||||||
|
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights);
|
||||||
|
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args);
|
||||||
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
|
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
|
||||||
void Clear(const std::string &flag = "");
|
void Clear(const std::string &flag = "");
|
||||||
void Clean();
|
void Clean();
|
||||||
|
|
Loading…
Reference in New Issue