forked from mindspore-Ecosystem/mindspore
parent
43174475e6
commit
1af8735a19
|
@ -163,14 +163,8 @@ std::string GetId(const py::object &obj) {
|
|||
return prefix;
|
||||
}
|
||||
|
||||
if (py::isinstance<Cell>(obj)) {
|
||||
auto cell = py::cast<CellPtr>(obj);
|
||||
MS_EXCEPTION_IF_NULL(cell);
|
||||
return std::to_string(reinterpret_cast<size_t>(cell.get()));
|
||||
} else {
|
||||
py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
|
||||
return py::cast<std::string>(ret);
|
||||
}
|
||||
py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj);
|
||||
return py::cast<std::string>(ret);
|
||||
}
|
||||
|
||||
std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
|
||||
|
@ -623,10 +617,10 @@ void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<ten
|
|||
for (auto &pre_tensor : pre_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(pre_tensor);
|
||||
MS_LOG(DEBUG) << "Replace Old tensor " << pre_tensor.get() << " id " << pre_tensor->id()
|
||||
<< " device_address: " << pre_tensor->device_address()->GetMutablePtr() << " shape and type "
|
||||
<< " device_address: " << pre_tensor->device_address() << " shape and type "
|
||||
<< pre_tensor->GetShapeAndDataTypeInfo() << " with New tensor " << new_tensor.get() << " id "
|
||||
<< new_tensor->id() << " device_address " << new_tensor->device_address()->GetMutablePtr()
|
||||
<< " shape and dtype " << new_tensor->GetShapeAndDataTypeInfo();
|
||||
<< new_tensor->id() << " device_address " << new_tensor->device_address() << " shape and dtype "
|
||||
<< new_tensor->GetShapeAndDataTypeInfo();
|
||||
pre_tensor->set_shape(new_tensor->shape());
|
||||
pre_tensor->set_data_type(new_tensor->data_type());
|
||||
if (device_target != kCPUDevice) {
|
||||
|
@ -2564,7 +2558,9 @@ void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const p
|
|||
MS_LOG(DEBUG) << "Eval run end " << value.ToString();
|
||||
*ret = BaseRefToPyData(value);
|
||||
// Clear device memory resource of top cell when it has been ran.
|
||||
if (top_cell()->is_topest()) {
|
||||
auto has_higher_order = std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
|
||||
[](const TopCellInfoPtr &value) { return !value->is_topest(); });
|
||||
if (top_cell()->is_topest() && !has_higher_order) {
|
||||
top_cell()->ClearDeviceMemory();
|
||||
}
|
||||
// High order
|
||||
|
@ -2607,52 +2603,54 @@ void GradExecutor::MakeNestedCnode(const py::object &cell, const std::string &ce
|
|||
MS_EXCEPTION_IF_NULL(first_grad_fg);
|
||||
DumpGraphIR("first_grad_fg.ir", first_grad_fg);
|
||||
|
||||
auto out_id = GetId(out);
|
||||
auto first_df_builder = GetDfbuilder(cell_id);
|
||||
MS_EXCEPTION_IF_NULL(first_df_builder);
|
||||
auto first_graph_info = top_cell()->graph_info_map().at(first_df_builder);
|
||||
MS_EXCEPTION_IF_NULL(first_graph_info);
|
||||
SwitchTopcell();
|
||||
auto second_df_builder = GetDfbuilder(top_cell()->cell_id());
|
||||
MS_EXCEPTION_IF_NULL(second_df_builder);
|
||||
auto second_graph_info = top_cell()->graph_info_map()[second_df_builder];
|
||||
MS_EXCEPTION_IF_NULL(second_graph_info);
|
||||
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(first_grad_fg)};
|
||||
for (size_t i = 0; i < forward_args.size(); ++i) {
|
||||
inputs.emplace_back(GetInput(forward_args[i], false));
|
||||
}
|
||||
|
||||
ValuePtrList weights_args;
|
||||
auto manager = Manage({first_grad_fg}, false);
|
||||
for (const auto &fir : first_graph_info->params) {
|
||||
auto p = fir.second->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(p);
|
||||
if (!p->has_default()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto &sec : second_graph_info->params) {
|
||||
MS_LOG(DEBUG) << "Param name " << fir.first << " ptr " << fir.second.get();
|
||||
if (fir.second->name() == sec.second->name()) {
|
||||
manager->Replace(fir.second, sec.second);
|
||||
inputs.emplace_back(sec.second);
|
||||
weights_args.emplace_back(sec.second->default_param());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pipeline::ResourcePtr r = std::make_shared<pipeline::Resource>();
|
||||
r->manager()->AddFuncGraph(first_grad_fg);
|
||||
FuncGraphPtr second_grad_fg = ad::Grad(first_grad_fg, r);
|
||||
DumpGraphIR("second_grad_fg.ir", second_grad_fg);
|
||||
r->Clean();
|
||||
|
||||
auto first_df_builder = GetDfbuilder(cell_id);
|
||||
MS_EXCEPTION_IF_NULL(first_df_builder);
|
||||
auto first_graph_info = top_cell()->graph_info_map().at(first_df_builder);
|
||||
MS_EXCEPTION_IF_NULL(first_graph_info);
|
||||
|
||||
SwitchTopcell();
|
||||
|
||||
auto second_df_builder = GetDfbuilder(top_cell()->cell_id());
|
||||
MS_EXCEPTION_IF_NULL(second_df_builder);
|
||||
for (const auto &it : first_graph_info->params) {
|
||||
SetParamNodeMapInGraphInfoMap(second_df_builder, it.first, it.second);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(first_grad_fg)};
|
||||
for (size_t i = 0; i < forward_args.size(); ++i) {
|
||||
inputs.emplace_back(GetInput(forward_args[i], false));
|
||||
}
|
||||
ValuePtrList weights_args;
|
||||
auto first_grad_all_params = first_grad_fg->parameters();
|
||||
for (const auto &it : first_grad_all_params) {
|
||||
auto p = it->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(p);
|
||||
if (!p->has_default()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto &w : first_graph_info->params) {
|
||||
auto param = w.second;
|
||||
if (param->has_default() && param->name() == p->name()) {
|
||||
inputs.emplace_back(param);
|
||||
weights_args.emplace_back(param->default_param());
|
||||
}
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Get pre graph ptr " << curr_g().get();
|
||||
auto cnode = curr_g_->NewCNode(inputs);
|
||||
SetTupleArgsToGraphInfoMap(curr_g_, out, cnode);
|
||||
SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode);
|
||||
auto cnode = curr_g()->NewCNode(inputs);
|
||||
auto out_id = GetId(out);
|
||||
SetTupleArgsToGraphInfoMap(curr_g(), out, cnode);
|
||||
SetNodeMapInGraphInfoMap(curr_g(), out_id, cnode);
|
||||
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString();
|
||||
|
||||
// Get input values
|
||||
ValuePtrList input_args;
|
||||
for (size_t i = 0; i < forward_args.size(); ++i) {
|
||||
auto arg = parse::data_converter::PyDataToValue(forward_args[i]);
|
||||
|
@ -2660,7 +2658,7 @@ void GradExecutor::MakeNestedCnode(const py::object &cell, const std::string &ce
|
|||
input_args.emplace_back(arg);
|
||||
}
|
||||
input_args.insert(input_args.end(), weights_args.begin(), weights_args.end());
|
||||
// Get output value
|
||||
// Get output values
|
||||
py::object new_out;
|
||||
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME) && !py::isinstance<py::tuple>(out)) {
|
||||
new_out = py::make_tuple(out);
|
||||
|
@ -2669,7 +2667,6 @@ void GradExecutor::MakeNestedCnode(const py::object &cell, const std::string &ce
|
|||
}
|
||||
auto out_value = parse::data_converter::PyDataToValue(new_out);
|
||||
MS_EXCEPTION_IF_NULL(out_value);
|
||||
// Add out and dout
|
||||
if (!top_cell()->k_pynative_cell_ptr()->KPynativeWithFProp(cnode, input_args, out_value, second_grad_fg)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to run ad grad for second grad graph " << cnode->ToString();
|
||||
}
|
||||
|
|
|
@ -308,8 +308,6 @@ class ForwardExecutor {
|
|||
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
|
||||
PynativeStatusCode *status);
|
||||
void GetInputsArgsSpec(const OpExecInfoPtr &op_exec_info, abstract::AbstractBasePtrList *args_spec_list);
|
||||
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
|
||||
const abstract::AbstractBasePtr &abs, const std::string &id, size_t index);
|
||||
void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
|
||||
bool *prim_cache_hit);
|
||||
void GetOpOutput(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
|
||||
|
|
Loading…
Reference in New Issue