Signed-off-by: zjun <zhangjun0@huawei.com>
This commit is contained in:
zjun 2021-06-25 10:24:43 +08:00
parent 43174475e6
commit 1af8735a19
2 changed files with 48 additions and 53 deletions

View File

@ -163,14 +163,8 @@ std::string GetId(const py::object &obj) {
return prefix;
}
if (py::isinstance<Cell>(obj)) {
auto cell = py::cast<CellPtr>(obj);
MS_EXCEPTION_IF_NULL(cell);
return std::to_string(reinterpret_cast<size_t>(cell.get()));
} else {
py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
return py::cast<std::string>(ret);
}
py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
return py::cast<std::string>(ret);
}
std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
@ -623,10 +617,10 @@ void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<ten
for (auto &pre_tensor : pre_tensors) {
MS_EXCEPTION_IF_NULL(pre_tensor);
MS_LOG(DEBUG) << "Replace Old tensor " << pre_tensor.get() << " id " << pre_tensor->id()
<< " device_address: " << pre_tensor->device_address()->GetMutablePtr() << " shape and type "
<< " device_address: " << pre_tensor->device_address() << " shape and type "
<< pre_tensor->GetShapeAndDataTypeInfo() << " with New tensor " << new_tensor.get() << " id "
<< new_tensor->id() << " device_address " << new_tensor->device_address()->GetMutablePtr()
<< " shape and dtype " << new_tensor->GetShapeAndDataTypeInfo();
<< new_tensor->id() << " device_address " << new_tensor->device_address() << " shape and dtype "
<< new_tensor->GetShapeAndDataTypeInfo();
pre_tensor->set_shape(new_tensor->shape());
pre_tensor->set_data_type(new_tensor->data_type());
if (device_target != kCPUDevice) {
@ -2564,7 +2558,9 @@ void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const p
MS_LOG(DEBUG) << "Eval run end " << value.ToString();
*ret = BaseRefToPyData(value);
// Clear device memory resource of top cell when it has been ran.
if (top_cell()->is_topest()) {
auto has_higher_order = std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
[](const TopCellInfoPtr &value) { return !value->is_topest(); });
if (top_cell()->is_topest() && !has_higher_order) {
top_cell()->ClearDeviceMemory();
}
// High order
@ -2607,52 +2603,54 @@ void GradExecutor::MakeNestedCnode(const py::object &cell, const std::string &ce
MS_EXCEPTION_IF_NULL(first_grad_fg);
DumpGraphIR("first_grad_fg.ir", first_grad_fg);
auto out_id = GetId(out);
auto first_df_builder = GetDfbuilder(cell_id);
MS_EXCEPTION_IF_NULL(first_df_builder);
auto first_graph_info = top_cell()->graph_info_map().at(first_df_builder);
MS_EXCEPTION_IF_NULL(first_graph_info);
SwitchTopcell();
auto second_df_builder = GetDfbuilder(top_cell()->cell_id());
MS_EXCEPTION_IF_NULL(second_df_builder);
auto second_graph_info = top_cell()->graph_info_map()[second_df_builder];
MS_EXCEPTION_IF_NULL(second_graph_info);
std::vector<AnfNodePtr> inputs{NewValueNode(first_grad_fg)};
for (size_t i = 0; i < forward_args.size(); ++i) {
inputs.emplace_back(GetInput(forward_args[i], false));
}
ValuePtrList weights_args;
auto manager = Manage({first_grad_fg}, false);
for (const auto &fir : first_graph_info->params) {
auto p = fir.second->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(p);
if (!p->has_default()) {
continue;
}
for (const auto &sec : second_graph_info->params) {
MS_LOG(DEBUG) << "Param name " << fir.first << " ptr " << fir.second.get();
if (fir.second->name() == sec.second->name()) {
manager->Replace(fir.second, sec.second);
inputs.emplace_back(sec.second);
weights_args.emplace_back(sec.second->default_param());
break;
}
}
}
pipeline::ResourcePtr r = std::make_shared<pipeline::Resource>();
r->manager()->AddFuncGraph(first_grad_fg);
FuncGraphPtr second_grad_fg = ad::Grad(first_grad_fg, r);
DumpGraphIR("second_grad_fg.ir", second_grad_fg);
r->Clean();
auto first_df_builder = GetDfbuilder(cell_id);
MS_EXCEPTION_IF_NULL(first_df_builder);
auto first_graph_info = top_cell()->graph_info_map().at(first_df_builder);
MS_EXCEPTION_IF_NULL(first_graph_info);
SwitchTopcell();
auto second_df_builder = GetDfbuilder(top_cell()->cell_id());
MS_EXCEPTION_IF_NULL(second_df_builder);
for (const auto &it : first_graph_info->params) {
SetParamNodeMapInGraphInfoMap(second_df_builder, it.first, it.second);
}
std::vector<AnfNodePtr> inputs{NewValueNode(first_grad_fg)};
for (size_t i = 0; i < forward_args.size(); ++i) {
inputs.emplace_back(GetInput(forward_args[i], false));
}
ValuePtrList weights_args;
auto first_grad_all_params = first_grad_fg->parameters();
for (const auto &it : first_grad_all_params) {
auto p = it->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(p);
if (!p->has_default()) {
continue;
}
for (const auto &w : first_graph_info->params) {
auto param = w.second;
if (param->has_default() && param->name() == p->name()) {
inputs.emplace_back(param);
weights_args.emplace_back(param->default_param());
}
}
}
MS_LOG(DEBUG) << "Get pre graph ptr " << curr_g().get();
auto cnode = curr_g_->NewCNode(inputs);
SetTupleArgsToGraphInfoMap(curr_g_, out, cnode);
SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode);
auto cnode = curr_g()->NewCNode(inputs);
auto out_id = GetId(out);
SetTupleArgsToGraphInfoMap(curr_g(), out, cnode);
SetNodeMapInGraphInfoMap(curr_g(), out_id, cnode);
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString();
// Get input values
ValuePtrList input_args;
for (size_t i = 0; i < forward_args.size(); ++i) {
auto arg = parse::data_converter::PyDataToValue(forward_args[i]);
@ -2660,7 +2658,7 @@ void GradExecutor::MakeNestedCnode(const py::object &cell, const std::string &ce
input_args.emplace_back(arg);
}
input_args.insert(input_args.end(), weights_args.begin(), weights_args.end());
// Get output value
// Get output values
py::object new_out;
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !py::isinstance<py::tuple>(out)) {
new_out = py::make_tuple(out);
@ -2669,7 +2667,6 @@ void GradExecutor::MakeNestedCnode(const py::object &cell, const std::string &ce
}
auto out_value = parse::data_converter::PyDataToValue(new_out);
MS_EXCEPTION_IF_NULL(out_value);
// Add out and dout
if (!top_cell()->k_pynative_cell_ptr()->KPynativeWithFProp(cnode, input_args, out_value, second_grad_fg)) {
MS_LOG(EXCEPTION) << "Failed to run ad grad for second grad graph " << cnode->ToString();
}

View File

@ -308,8 +308,6 @@ class ForwardExecutor {
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
PynativeStatusCode *status);
void GetInputsArgsSpec(const OpExecInfoPtr &op_exec_info, abstract::AbstractBasePtrList *args_spec_list);
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
const abstract::AbstractBasePtr &abs, const std::string &id, size_t index);
void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
bool *prim_cache_hit);
void GetOpOutput(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,