forked from mindspore-Ecosystem/mindspore
commit
7c8ac0a90b
|
@ -906,16 +906,16 @@ AnfNodePtr ForwardExecutor::ConstructForwardGraph(const OpExecInfoPtr &op_exec_i
|
|||
inputs.emplace_back(NewValueNode(prim));
|
||||
for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
|
||||
const auto &obj = op_exec_info->op_inputs[i];
|
||||
int64_t op_mask = false;
|
||||
bool op_mask = false;
|
||||
tensor::MetaTensorPtr meta_tensor = nullptr;
|
||||
if (py::isinstance<tensor::MetaTensor>(obj)) {
|
||||
meta_tensor = obj.cast<tensor::MetaTensorPtr>();
|
||||
if (meta_tensor) {
|
||||
op_mask = static_cast<int64_t>(meta_tensor->is_parameter());
|
||||
op_mask = meta_tensor->is_parameter();
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Args i " << i << ", op mask " << op_mask;
|
||||
op_masks.emplace_back(op_mask);
|
||||
op_masks.emplace_back(static_cast<int64_t>(op_mask));
|
||||
|
||||
// Construct grad graph
|
||||
if (grad()->need_construct_graph()) {
|
||||
|
@ -1280,18 +1280,17 @@ AnfNodePtr GradExecutor::GetObjNode(const py::object &obj, const std::string &ob
|
|||
} else {
|
||||
out_obj = PyAttrValue(obj);
|
||||
}
|
||||
for (auto &idx : out.second) {
|
||||
idx = static_cast<size_t>(idx);
|
||||
for (const auto idx : out.second) {
|
||||
std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
|
||||
node = curr_g_->NewCNode(tuple_get_item_inputs);
|
||||
if (out_obj->isa<ValueTuple>()) {
|
||||
node->add_input_value(out_obj, "");
|
||||
node->add_input_value(MakeValue(idx), "");
|
||||
out_obj = (*out_obj->cast<ValueTuplePtr>())[idx];
|
||||
out_obj = (*out_obj->cast<ValueTuplePtr>())[static_cast<size_t>(idx)];
|
||||
node->set_forward(out_obj, "");
|
||||
}
|
||||
if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
|
||||
auto prim_abs = dyn_cast<abstract::AbstractTuple>(abs)->elements()[idx];
|
||||
auto prim_abs = dyn_cast<abstract::AbstractTuple>(abs)->elements()[static_cast<size_t>(idx)];
|
||||
MS_LOG(DEBUG) << "Set tuple getitem abs " << prim_abs->ToString();
|
||||
node->set_abstract(prim_abs);
|
||||
}
|
||||
|
@ -1664,7 +1663,6 @@ py::object ForwardExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, Pynativ
|
|||
py::tuple result(op_inputs.size());
|
||||
for (size_t i = 0; i < op_inputs.size(); i++) {
|
||||
py::object input = op_inputs[i];
|
||||
auto input_obj_id = GetId(input);
|
||||
auto tensor = py::cast<tensor::TensorPtr>(input);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
if (op_exec_info->op_name == "HookBackward") {
|
||||
|
@ -1844,7 +1842,7 @@ bool GradExecutor::IsNestedGrad() const {
|
|||
return grad_order_ > 1;
|
||||
}
|
||||
|
||||
bool GradExecutor::IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id) {
|
||||
bool GradExecutor::IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id) const {
|
||||
// just compare obj_id, ignore args id
|
||||
return l_cell_id.compare(0, PTR_LEN, r_cell_id, 0, PTR_LEN) == 0;
|
||||
}
|
||||
|
@ -2000,6 +1998,7 @@ void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py
|
|||
}
|
||||
|
||||
void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) {
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
auto cell_id = GetCellId(cell, args);
|
||||
MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
|
||||
if (top_cell_ != nullptr && cell_stack_.empty()) {
|
||||
|
@ -2104,14 +2103,14 @@ void GradExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const p
|
|||
}
|
||||
}
|
||||
|
||||
void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, const FuncGraphPtr &curr_g,
|
||||
const py::object &out, const std::string &out_id) {
|
||||
void GradExecutor::CreateMakeTupleNodeForMultiOut(const FuncGraphPtr &curr_g, const py::object &out,
|
||||
const std::string &out_id) {
|
||||
MS_EXCEPTION_IF_NULL(curr_g);
|
||||
auto out_tuple = out.cast<py::tuple>();
|
||||
// get input node and value
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
ValuePtrList input_args;
|
||||
std::vector<int> value_index;
|
||||
std::vector<size_t> value_index;
|
||||
for (size_t i = 0; i < out_tuple.size(); i++) {
|
||||
auto v = parse::data_converter::PyDataToValue(out_tuple[i]);
|
||||
// Graph have no define for grad
|
||||
|
@ -2141,6 +2140,7 @@ void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, co
|
|||
}
|
||||
|
||||
void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args) {
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
const auto &cell_id = GetCellId(cell, args);
|
||||
MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
|
||||
if (cell_stack_.empty()) {
|
||||
|
@ -2162,7 +2162,7 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
|
|||
MS_EXCEPTION_IF_NULL(graph_info);
|
||||
if (graph_info->node_map.find(out_id) == graph_info->node_map.end()) {
|
||||
if (py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out)) {
|
||||
CreateMakeTupleNodeForMultiOut(cell_id, curr_g_, out, out_id);
|
||||
CreateMakeTupleNodeForMultiOut(curr_g_, out, out_id);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id;
|
||||
MakeValueNode(out, out_id);
|
||||
|
@ -2268,6 +2268,7 @@ std::string GradExecutor::GetGradCellId(bool has_sens, const py::object &cell, c
|
|||
|
||||
void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
|
||||
const py::object &weights, const py::args &args) {
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
MS_EXCEPTION_IF_NULL(grad);
|
||||
auto size = args.size();
|
||||
const auto &cell_id = GetGradCellId(grad->sens_param(), cell, args);
|
||||
|
@ -2746,6 +2747,7 @@ void GradExecutor::GradMsFunction(const py::object &out, const py::args &args) {
|
|||
}
|
||||
|
||||
void GradExecutor::ClearGrad(const py::object &cell, const py::args &args) {
|
||||
MS_LOG(DEBUG) << "Clear top cell grad resource " << GetCellId(cell, args);
|
||||
if (grad_order_ > 0) {
|
||||
--grad_order_;
|
||||
}
|
||||
|
@ -2829,8 +2831,8 @@ void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
|
|||
MS_LOG(DEBUG) << "Grad flag is false";
|
||||
return;
|
||||
}
|
||||
py::object *ret = nullptr;
|
||||
PynativeExecutorTry(grad_executor()->InitGraph, ret, cell, args);
|
||||
py::object ret;
|
||||
PynativeExecutorTry(grad_executor()->InitGraph, &ret, cell, args);
|
||||
}
|
||||
|
||||
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
|
||||
|
@ -2839,8 +2841,8 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
|
|||
return;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Enter end graph process.";
|
||||
py::object *ret = nullptr;
|
||||
PynativeExecutorTry(grad_executor()->LinkGraph, ret, cell, out, args);
|
||||
py::object ret;
|
||||
PynativeExecutorTry(grad_executor()->LinkGraph, &ret, cell, out, args);
|
||||
MS_LOG(DEBUG) << "Leave end graph process.";
|
||||
}
|
||||
|
||||
|
@ -2850,8 +2852,8 @@ void PynativeExecutor::GradMsFunction(const py::object &out, const py::args &arg
|
|||
|
||||
void PynativeExecutor::GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
|
||||
const py::args &args) {
|
||||
py::object *ret = nullptr;
|
||||
PynativeExecutorTry(grad_executor()->GradGraph, ret, grad, cell, weights, args);
|
||||
py::object ret;
|
||||
PynativeExecutorTry(grad_executor()->GradGraph, &ret, grad, cell, weights, args);
|
||||
}
|
||||
|
||||
void PynativeExecutor::Sync() {
|
||||
|
|
|
@ -221,7 +221,7 @@ class GradExecutor {
|
|||
void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled);
|
||||
// Manage resource when run grad process.
|
||||
bool IsBpropGraph(const std::string &cell_id);
|
||||
bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id);
|
||||
bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id) const;
|
||||
void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
|
||||
void NewGraphInner(py::object *ret, const py::object &cell, const py::args &args);
|
||||
void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args);
|
||||
|
@ -253,8 +253,7 @@ class GradExecutor {
|
|||
const std::vector<int64_t> &index) {
|
||||
top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index);
|
||||
}
|
||||
void CreateMakeTupleNodeForMultiOut(const std::string &cell_id, const FuncGraphPtr &curr_g, const py::object &out,
|
||||
const std::string &out_id);
|
||||
void CreateMakeTupleNodeForMultiOut(const FuncGraphPtr &curr_g, const py::object &out, const std::string &out_id);
|
||||
|
||||
private:
|
||||
bool grad_flag_{false};
|
||||
|
|
Loading…
Reference in New Issue