Merge pull request !20521 from zjun/fix_pclint_mas
This commit is contained in:
i-robot 2021-07-21 01:12:12 +00:00 committed by Gitee
commit 7c8ac0a90b
2 changed files with 23 additions and 22 deletions

View File

@ -906,16 +906,16 @@ AnfNodePtr ForwardExecutor::ConstructForwardGraph(const OpExecInfoPtr &op_exec_i
inputs.emplace_back(NewValueNode(prim));
for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
const auto &obj = op_exec_info->op_inputs[i];
int64_t op_mask = false;
bool op_mask = false;
tensor::MetaTensorPtr meta_tensor = nullptr;
if (py::isinstance<tensor::MetaTensor>(obj)) {
meta_tensor = obj.cast<tensor::MetaTensorPtr>();
if (meta_tensor) {
op_mask = static_cast<int64_t>(meta_tensor->is_parameter());
op_mask = meta_tensor->is_parameter();
}
}
MS_LOG(DEBUG) << "Args i " << i << ", op mask " << op_mask;
op_masks.emplace_back(op_mask);
op_masks.emplace_back(static_cast<int64_t>(op_mask));
// Construct grad graph
if (grad()->need_construct_graph()) {
@ -1280,18 +1280,17 @@ AnfNodePtr GradExecutor::GetObjNode(const py::object &obj, const std::string &ob
} else {
out_obj = PyAttrValue(obj);
}
for (auto &idx : out.second) {
idx = static_cast<size_t>(idx);
for (const auto idx : out.second) {
std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
node = curr_g_->NewCNode(tuple_get_item_inputs);
if (out_obj->isa<ValueTuple>()) {
node->add_input_value(out_obj, "");
node->add_input_value(MakeValue(idx), "");
out_obj = (*out_obj->cast<ValueTuplePtr>())[idx];
out_obj = (*out_obj->cast<ValueTuplePtr>())[static_cast<size_t>(idx)];
node->set_forward(out_obj, "");
}
if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
auto prim_abs = dyn_cast<abstract::AbstractTuple>(abs)->elements()[idx];
auto prim_abs = dyn_cast<abstract::AbstractTuple>(abs)->elements()[static_cast<size_t>(idx)];
MS_LOG(DEBUG) << "Set tuple getitem abs " << prim_abs->ToString();
node->set_abstract(prim_abs);
}
@ -1664,7 +1663,6 @@ py::object ForwardExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, Pynativ
py::tuple result(op_inputs.size());
for (size_t i = 0; i < op_inputs.size(); i++) {
py::object input = op_inputs[i];
auto input_obj_id = GetId(input);
auto tensor = py::cast<tensor::TensorPtr>(input);
MS_EXCEPTION_IF_NULL(tensor);
if (op_exec_info->op_name == "HookBackward") {
@ -1844,7 +1842,7 @@ bool GradExecutor::IsNestedGrad() const {
return grad_order_ > 1;
}
bool GradExecutor::IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id) {
bool GradExecutor::IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id) const {
// just compare obj_id, ignore args id
return l_cell_id.compare(0, PTR_LEN, r_cell_id, 0, PTR_LEN) == 0;
}
@ -2000,6 +1998,7 @@ void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py
}
void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) {
MS_EXCEPTION_IF_NULL(ret);
auto cell_id = GetCellId(cell, args);
MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
if (top_cell_ != nullptr && cell_stack_.empty()) {
@ -2104,14 +2103,14 @@ void GradExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const p
}
}
void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, const FuncGraphPtr &curr_g,
const py::object &out, const std::string &out_id) {
void GradExecutor::CreateMakeTupleNodeForMultiOut(const FuncGraphPtr &curr_g, const py::object &out,
const std::string &out_id) {
MS_EXCEPTION_IF_NULL(curr_g);
auto out_tuple = out.cast<py::tuple>();
// get input node and value
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
ValuePtrList input_args;
std::vector<int> value_index;
std::vector<size_t> value_index;
for (size_t i = 0; i < out_tuple.size(); i++) {
auto v = parse::data_converter::PyDataToValue(out_tuple[i]);
// Graph have no define for grad
@ -2141,6 +2140,7 @@ void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, co
}
void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args) {
MS_EXCEPTION_IF_NULL(ret);
const auto &cell_id = GetCellId(cell, args);
MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
if (cell_stack_.empty()) {
@ -2162,7 +2162,7 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
MS_EXCEPTION_IF_NULL(graph_info);
if (graph_info->node_map.find(out_id) == graph_info->node_map.end()) {
if (py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out)) {
CreateMakeTupleNodeForMultiOut(cell_id, curr_g_, out, out_id);
CreateMakeTupleNodeForMultiOut(curr_g_, out, out_id);
} else {
MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id;
MakeValueNode(out, out_id);
@ -2268,6 +2268,7 @@ std::string GradExecutor::GetGradCellId(bool has_sens, const py::object &cell, c
void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
const py::object &weights, const py::args &args) {
MS_EXCEPTION_IF_NULL(ret);
MS_EXCEPTION_IF_NULL(grad);
auto size = args.size();
const auto &cell_id = GetGradCellId(grad->sens_param(), cell, args);
@ -2746,6 +2747,7 @@ void GradExecutor::GradMsFunction(const py::object &out, const py::args &args) {
}
void GradExecutor::ClearGrad(const py::object &cell, const py::args &args) {
MS_LOG(DEBUG) << "Clear top cell grad resource " << GetCellId(cell, args);
if (grad_order_ > 0) {
--grad_order_;
}
@ -2829,8 +2831,8 @@ void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
MS_LOG(DEBUG) << "Grad flag is false";
return;
}
py::object *ret = nullptr;
PynativeExecutorTry(grad_executor()->InitGraph, ret, cell, args);
py::object ret;
PynativeExecutorTry(grad_executor()->InitGraph, &ret, cell, args);
}
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
@ -2839,8 +2841,8 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
return;
}
MS_LOG(DEBUG) << "Enter end graph process.";
py::object *ret = nullptr;
PynativeExecutorTry(grad_executor()->LinkGraph, ret, cell, out, args);
py::object ret;
PynativeExecutorTry(grad_executor()->LinkGraph, &ret, cell, out, args);
MS_LOG(DEBUG) << "Leave end graph process.";
}
@ -2850,8 +2852,8 @@ void PynativeExecutor::GradMsFunction(const py::object &out, const py::args &arg
void PynativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args) {
py::object *ret = nullptr;
PynativeExecutorTry(grad_executor()->GradGraph, ret, grad, cell, weights, args);
py::object ret;
PynativeExecutorTry(grad_executor()->GradGraph, &ret, grad, cell, weights, args);
}
void PynativeExecutor::Sync() {

View File

@ -221,7 +221,7 @@ class GradExecutor {
void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled);
// Manage resource when run grad process.
bool IsBpropGraph(const std::string &cell_id);
bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id);
bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id) const;
void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
void NewGraphInner(py::object *ret, const py::object &cell, const py::args &args);
void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args);
@ -253,8 +253,7 @@ class GradExecutor {
const std::vector<int64_t> &index) {
top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index);
}
void CreateMakeTupleNodeForMultiOut(const std::string &cell_id, const FuncGraphPtr &curr_g, const py::object &out,
const std::string &out_id);
void CreateMakeTupleNodeForMultiOut(const FuncGraphPtr &curr_g, const py::object &out, const std::string &out_id);
private:
bool grad_flag_{false};