Fix codex and codereview.

rick_sanchez 2020-06-19 13:07:27 +08:00
parent 8867c67d61
commit e5c0e052e0
2 changed files with 106 additions and 74 deletions

@@ -110,7 +110,40 @@ py::object GetTupleObj(const py::object &obj) {
return obj_tuple;
}
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
for (size_t i = 0; i < dtypes.size(); ++i) {
auto it = type_indexes.find(dtypes[i]);
if (it == type_indexes.end()) {
(void)type_indexes.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
} else {
it->second.push_back(i);
}
}
return type_indexes;
}
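
The new GetTypeIndex helper groups argument positions by their signature dtype so that later promotion can treat same-dtype arguments together. A minimal, self-contained sketch of the same grouping, with int standing in for SignatureEnumDType (a hypothetical stand-in, not the MindSpore type):

#include <cstdio>
#include <map>
#include <vector>

// Group argument positions by dtype, as GetTypeIndex does above;
// int stands in for SignatureEnumDType here.
std::map<int, std::vector<size_t>> GroupByType(const std::vector<int> &dtypes) {
  std::map<int, std::vector<size_t>> type_indexes;
  for (size_t i = 0; i < dtypes.size(); ++i) {
    type_indexes[dtypes[i]].push_back(i);
  }
  return type_indexes;
}

int main() {
  // Arguments 0 and 2 share dtype 7; argument 1 has dtype 3.
  auto groups = GroupByType({7, 3, 7});
  for (const auto &kv : groups) {
    std::printf("dtype %d:", kv.first);
    for (size_t idx : kv.second) {
      std::printf(" %zu", idx);
    }
    std::printf("\n");
  }
  return 0;  // prints "dtype 3: 1" and "dtype 7: 0 2"
}
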
std::map<SignatureEnumDType, size_t> GetDstType(const py::tuple &py_args,
const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
std::map<SignatureEnumDType, size_t> dst_type;
for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) {
auto type = it->first;
auto indexes = it->second;
if (indexes.size() < 2) {
continue;
}
size_t m_index = indexes[0];
for (size_t i = 1; i < indexes.size(); ++i) {
if (py::isinstance<tensor::Tensor>(py_args[indexes[i]])) {
m_index = indexes[i];
}
}
(void)dst_type.insert(std::make_pair(type, m_index));
}
return dst_type;
}
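
Within each dtype group of two or more arguments, GetDstType picks the promotion target: the last member that passes the py::isinstance<tensor::Tensor> check, or the group's first member if none does. A hedged sketch of that selection, with a plain bool mask standing in for the isinstance check:

#include <cstdio>
#include <map>
#include <vector>

// Mirror GetDstType's choice of a "master" index per dtype group:
// the last tensor in the group wins; otherwise the first index does.
std::map<int, size_t> PickDstIndex(const std::vector<bool> &is_tensor,
                                   const std::map<int, std::vector<size_t>> &type_indexes) {
  std::map<int, size_t> dst_type;
  for (const auto &kv : type_indexes) {
    const std::vector<size_t> &indexes = kv.second;
    if (indexes.size() < 2) {
      continue;  // nothing to promote within a single-member group
    }
    size_t m_index = indexes[0];
    for (size_t i = 1; i < indexes.size(); ++i) {
      if (is_tensor[indexes[i]]) {
        m_index = indexes[i];
      }
    }
    dst_type[kv.first] = m_index;
  }
  return dst_type;
}

int main() {
  // Arguments 0 and 2 share dtype 7, and argument 2 is a tensor.
  auto dst = PickDstIndex({false, false, true}, {{3, {1}}, {7, {0, 2}}});
  std::printf("dtype 7 promotes to argument %zu\n", dst.at(7));  // argument 2
  return 0;
}
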
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args) {
auto &py_args = *out_args;
py::tuple input_mask(args.size());
for (size_t i = 0; i < args.size(); ++i) {
@@ -129,30 +162,8 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
return input_mask;
}
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
for (size_t i = 0; i < dtypes.size(); ++i) {
auto it = type_indexs.find(dtypes[i]);
if (it == type_indexs.end()) {
(void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
} else {
it->second.push_back(i);
}
}
std::map<SignatureEnumDType, size_t> dst_type;
for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
auto type = it->first;
auto indexs = it->second;
if (indexs.size() < 2) {
continue;
}
size_t m_index = indexs[0];
for (size_t i = 1; i < indexs.size(); ++i) {
if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) {
m_index = indexs[i];
}
}
(void)dst_type.insert(std::make_pair(type, m_index));
}
auto type_indexes = GetTypeIndex(dtypes);
auto dst_type = GetDstType(py_args, type_indexes);
for (size_t i = 0; i < py_args.size(); ++i) {
auto it = dst_type.find(dtypes[i]);
if (it != dst_type.end() && it->second != i &&
@@ -542,28 +553,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
return curr_g_->NewCNode(tuple_get_item_inputs);
}
py::tuple RunOp(const py::args &args) {
MS_LOG(DEBUG) << "RunOp start, args size: " << args.size();
py::object result;
// returns a null py::tuple on error
py::tuple err_ret(0);
PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args);
MS_EXCEPTION_IF_NULL(op_exec_info);
if (op_exec_info->abstract != nullptr) {
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
if (!output["value"].is_none()) {
py::tuple value_ret(1);
value_ret[0] = output["value"];
return value_ret;
}
if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
py::tuple value_ret(1);
value_ret[0] = "";
return value_ret;
}
}
py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) {
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
mindspore::parse::python_adapter::set_python_env_flag(true);
MsBackendPolicy backend_policy;
@@ -584,7 +574,10 @@ py::tuple RunOp(const py::args &args) {
if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) {
backend_policy = kMsBackendVmOnly;
}
result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE;
// returns a null py::tuple on error
py::tuple err_ret(0);
py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status);
if (status != PYNATIVE_SUCCESS) {
MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name;
return err_ret;
@@ -599,6 +592,26 @@ py::tuple RunOp(const py::args &args) {
return result;
}
py::tuple RunOp(const py::args &args) {
MS_LOG(DEBUG) << "RunOp start, args size: " << args.size();
OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args);
MS_EXCEPTION_IF_NULL(op_exec_info);
if (op_exec_info->abstract != nullptr) {
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
if (!output["value"].is_none()) {
py::tuple value_ret(1);
value_ret[0] = output["value"];
return value_ret;
}
if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
py::tuple value_ret(1);
value_ret[0] = "";
return value_ret;
}
}
return RunOp(op_exec_info, args);
}
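
The split above leaves the py::args overload of RunOp with only the infer-time fast paths (a value already known from abstract inference, or an empty placeholder for const_value primitives) and delegates real execution to the OpExecInfoPtr overload. The tuple conventions it relies on are plain pybind11; a small illustrative module using the same ones (the module and function names are made up):

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Same return conventions as RunOp: a 1-element tuple carrying a
// folded value, or an empty tuple signalling failure to the caller.
py::tuple RunFakeOp(bool fail) {
  if (fail) {
    py::tuple err_ret(0);  // null tuple on error
    return err_ret;
  }
  py::tuple value_ret(1);
  value_ret[0] = py::int_(42);  // stand-in for an inferred value
  return value_ret;
}

PYBIND11_MODULE(runop_sketch, m) {
  m.def("run_fake_op", &RunFakeOp, py::arg("fail") = false);
}
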
void ClearPyNativeSession() { session = nullptr; }
PynativeExecutor::~PynativeExecutor() { ClearRes(); }
@@ -732,7 +745,11 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
return;
}
}
EndGraphByOutId(out_id, cell, out, args);
}
void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out,
const py::args &args) {
AnfNodePtr output_node;
if (graph_info_map_[curr_g_].param_map.count(out_id)) {
output_node = graph_info_map_[curr_g_].param_map[out_id];
@@ -776,27 +793,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
}
}
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args) {
MS_LOG(INFO) << "GradNet start, args size: " << args.size();
std::size_t size = args.size();
auto cell_id = GetId(cell);
if (graph_map_.count(cell_id) != 0) {
MS_LOG(DEBUG) << "GradNet already compiled";
return;
}
MS_LOG(DEBUG) << "GradNet first compiled";
std::vector<AnfNodePtr> new_params;
for (size_t i = 0; i < size; i++) {
ParameterPtr p = std::make_shared<Parameter>(df_builder_);
new_params.push_back(p);
}
MS_LOG(DEBUG) << "GradNet start, weight size: " << df_builder_->parameters().size();
new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
df_builder_->set_parameters(new_params);
resource_->manager()->SetParameters(df_builder_, new_params);
std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights) {
std::vector<AnfNodePtr> w_args;
if (py::hasattr(weights, "__parameter_tuple__")) {
auto tuple = weights.cast<py::tuple>();
@@ -821,12 +818,12 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
} else {
MS_LOG(EXCEPTION) << "Training weights are not a parameter_tuple";
}
MS_EXCEPTION_IF_NULL(resource_->func_graph());
auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
resource_->set_func_graph(g);
return w_args;
}
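
GetWeightsArgs requires the weights object to carry __parameter_tuple__ and walks it element by element; when the attribute is missing it raises, as the exception branch above shows. A hedged sketch of the same hasattr-then-cast walk (the helper name and the string collection are illustrative only):

#include <pybind11/pybind11.h>

#include <string>
#include <vector>

namespace py = pybind11;

// Walk a ParameterTuple-like object the way GetWeightsArgs does:
// check for __parameter_tuple__, cast to a tuple, visit each item.
std::vector<std::string> CollectParamNames(const py::object &weights) {
  if (!py::hasattr(weights, "__parameter_tuple__")) {
    throw py::type_error("weights must provide __parameter_tuple__");
  }
  auto tuple = weights.cast<py::tuple>();  // ParameterTuple is tuple-like
  std::vector<std::string> names;
  for (const auto &param : tuple) {
    names.push_back(py::str(param).cast<std::string>());
  }
  return names;
}
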
// get the parameters items and add the value to args_spec
abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) {
abstract::AbstractBasePtrList args_spec;
std::size_t size = args.size();
for (std::size_t i = 0; i < size; i++) {
ValuePtr converted = nullptr;
bool succ = parse::ConvertData(args[i], &converted);
@@ -852,6 +849,38 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c
param_node->set_abstract(ptr);
}
}
return args_spec;
}
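
GetArgsSpec applies a convert-or-fail loop: each positional argument is converted to an internal value via parse::ConvertData, its abstract is collected, and a failed conversion aborts the whole build. A minimal sketch of that shape, with std::optional and strings standing in for the MindSpore value and abstract types (all names hypothetical):

#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Stand-in for parse::ConvertData: succeed or report failure.
std::optional<std::string> ConvertArg(int arg) {
  if (arg < 0) {
    return std::nullopt;  // unconvertible input
  }
  return "abstract<" + std::to_string(arg) + ">";
}

// Convert every argument or throw, mirroring GetArgsSpec's
// all-or-nothing collection of args_spec entries.
std::vector<std::string> BuildArgsSpec(const std::vector<int> &args) {
  std::vector<std::string> args_spec;
  for (size_t i = 0; i < args.size(); ++i) {
    auto converted = ConvertArg(args[i]);
    if (!converted) {
      throw std::runtime_error("args convert error at index " + std::to_string(i));
    }
    args_spec.push_back(*converted);
  }
  return args_spec;
}
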
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args) {
MS_LOG(INFO) << "GradNet start, args size: " << args.size();
std::size_t size = args.size();
auto cell_id = GetId(cell);
if (graph_map_.count(cell_id) != 0) {
MS_LOG(DEBUG) << "GradNet already compiled";
return;
}
MS_LOG(DEBUG) << "GradNet first compiled";
std::vector<AnfNodePtr> new_params;
for (size_t i = 0; i < size; i++) {
ParameterPtr p = std::make_shared<Parameter>(df_builder_);
new_params.push_back(p);
}
MS_LOG(DEBUG) << "GradNet start, weight size: " << df_builder_->parameters().size();
new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end());
df_builder_->set_parameters(new_params);
resource_->manager()->SetParameters(df_builder_, new_params);
std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights);
MS_EXCEPTION_IF_NULL(resource_->func_graph());
auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
resource_->set_func_graph(g);
// get the parameters items and add the value to args_spec
abstract::AbstractBasePtrList args_spec = GetArgsSpec(args);
MS_LOG(DEBUG) << "Args_spec size: " << args_spec.size();
resource_->set_args_spec(args_spec);
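
After the refactor, GradNet itself keeps only the parameter bookkeeping: one fresh placeholder parameter per positional argument, with the existing weight parameters appended behind them, before delegating to GetWeightsArgs, GradGraph, and GetArgsSpec. A self-contained sketch of that ordering (Param and the names are stand-ins, not the AnfNode types):

#include <cstdio>
#include <memory>
#include <string>
#include <vector>

struct Param {
  std::string name;
};
using ParamPtr = std::shared_ptr<Param>;

// One placeholder per positional arg, then the weight parameters,
// matching the order GradNet installs into df_builder_.
std::vector<ParamPtr> BuildGradParams(size_t arg_count, const std::vector<ParamPtr> &weights) {
  std::vector<ParamPtr> new_params;
  for (size_t i = 0; i < arg_count; ++i) {
    new_params.push_back(std::make_shared<Param>(Param{"arg" + std::to_string(i)}));
  }
  new_params.insert(new_params.end(), weights.begin(), weights.end());
  return new_params;
}

int main() {
  auto params = BuildGradParams(2, {std::make_shared<Param>(Param{"w0"})});
  for (const auto &p : params) {
    std::printf("%s\n", p->name.c_str());  // arg0, arg1, w0
  }
  return 0;
}
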

@@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
py::tuple RunOp(const py::args &args);
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args);
void ClearPyNativeSession();
@@ -67,6 +67,9 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
}
void NewGraph(const py::object &cell, const py::args &args);
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args);
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights);
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args);
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
void Clear(const std::string &flag = "");
void Clean();