Erase input replace
Signed-off-by: zjun <zhangjun0@huawei.com>
This commit is contained in:
parent
62b39b5d7e
commit
82653f31f7
|
@ -781,13 +781,19 @@ bool KPynativeCellImpl::BackPropagateOneCNodeWithFPropFuncGraph(const CNodePtr &
|
|||
CNodePtr bprop_cnode;
|
||||
if (by_value) {
|
||||
AnfNodePtrList args_node_list;
|
||||
(void)std::transform(adjoint->op_args().begin(), adjoint->op_args().end(), std::back_inserter(args_node_list),
|
||||
[](const ValuePtr &value) {
|
||||
auto v_node = NewValueNode(value);
|
||||
v_node->set_abstract(value->ToAbstract()->Broaden());
|
||||
return v_node;
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < adjoint->op_args().size(); ++i) {
|
||||
auto input_node = cnode->input(i + 1);
|
||||
if (input_node->isa<Parameter>()) {
|
||||
bool is_weight = input_node->cast<ParameterPtr>()->has_default();
|
||||
if (!is_weight || need_grad_weights_.find(input_node) != need_grad_weights_.end()) {
|
||||
args_node_list.push_back(input_node);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
auto v_node = NewValueNode(adjoint->op_args()[i]);
|
||||
v_node->set_abstract(adjoint->op_args()[i]->ToAbstract()->Broaden());
|
||||
args_node_list.push_back(v_node);
|
||||
}
|
||||
bprop_cnode = GetBPropFromFProp(fprop_fg, args_node_list);
|
||||
} else {
|
||||
const auto &k_node_list = BuildKNodeListFromPrimalCNode(cnode, adjoint);
|
||||
|
|
|
@ -1490,10 +1490,6 @@ void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_e
|
|||
MS_LOG(DEBUG) << "Current op info: " << op_info;
|
||||
|
||||
std::vector<tensor::TensorPtr> all_op_tensors;
|
||||
// Get input tensors
|
||||
for (size_t i = 0; i < op_exec_info->op_inputs.size(); ++i) {
|
||||
TensorValueToTensor(parse::data_converter::PyDataToValue(op_exec_info->op_inputs[i]), &all_op_tensors);
|
||||
}
|
||||
// Get output tensors
|
||||
TensorValueToTensor(parse::data_converter::PyDataToValue(out_real), &all_op_tensors);
|
||||
// Save all tensors info of current op
|
||||
|
@ -1941,6 +1937,8 @@ void GradExecutor::HandleInputArgsForTopCell(const py::args &args, bool is_bprop
|
|||
// Convert input args to parameters for top cell graph in construct.
|
||||
std::vector<ValuePtr> input_param_values;
|
||||
py::args only_tensors = FilterTensorArgs(args);
|
||||
auto df_builder = GetDfbuilder(top_cell_->cell_id());
|
||||
MS_EXCEPTION_IF_NULL(df_builder);
|
||||
for (size_t i = 0; i < only_tensors.size(); ++i) {
|
||||
auto new_param = curr_g_->add_parameter();
|
||||
auto param_i = only_tensors[i];
|
||||
|
@ -1954,6 +1952,7 @@ void GradExecutor::HandleInputArgsForTopCell(const py::args &args, bool is_bprop
|
|||
SetTupleArgsToGraphInfoMap(curr_g_, param_i, new_param, true);
|
||||
SetNodeMapInGraphInfoMap(curr_g_, param_i_id, new_param);
|
||||
SetParamNodeMapInGraphInfoMap(curr_g_, param_i_id, new_param);
|
||||
SetParamNodeMapInGraphInfoMap(df_builder, param_i_id, new_param);
|
||||
}
|
||||
top_cell()->set_k_pynative_cell_ptr(ad::GradPynativeCellBegin(curr_g_->parameters(), input_param_values));
|
||||
}
|
||||
|
@ -1988,6 +1987,7 @@ void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py
|
|||
auto graph_info_cg = std::make_shared<GraphInfo>(cell_id);
|
||||
top_cell()->graph_info_map()[curr_g_] = graph_info_cg;
|
||||
auto df_builder = GetDfbuilder(cell_id);
|
||||
MS_EXCEPTION_IF_NULL(df_builder);
|
||||
auto graph_info_df = std::make_shared<GraphInfo>(cell_id);
|
||||
top_cell()->graph_info_map()[df_builder] = graph_info_df;
|
||||
HandleInputArgsForTopCell(args, false);
|
||||
|
@ -2619,28 +2619,41 @@ void GradExecutor::MakeNestedCnode(const py::object &cell, const std::string &ce
|
|||
SwitchTopcell();
|
||||
auto second_df_builder = GetDfbuilder(top_cell()->cell_id());
|
||||
MS_EXCEPTION_IF_NULL(second_df_builder);
|
||||
auto second_graph_info = top_cell()->graph_info_map()[second_df_builder];
|
||||
auto second_graph_info = top_cell()->graph_info_map().at(second_df_builder);
|
||||
MS_EXCEPTION_IF_NULL(second_graph_info);
|
||||
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(first_grad_fg)};
|
||||
for (size_t i = 0; i < forward_args.size(); ++i) {
|
||||
inputs.emplace_back(GetInput(forward_args[i], false));
|
||||
}
|
||||
|
||||
ValuePtrList weights_args;
|
||||
std::unordered_set<std::string> params_set;
|
||||
std::unordered_set<std::string> params_weights_set;
|
||||
std::unordered_set<std::string> params_inputs_set;
|
||||
for (const auto &sec : second_graph_info->params) {
|
||||
if (sec.second->has_default()) {
|
||||
params_set.emplace(sec.first);
|
||||
params_weights_set.emplace(sec.first);
|
||||
} else {
|
||||
params_inputs_set.insert(sec.first);
|
||||
}
|
||||
}
|
||||
auto manager = Manage({first_grad_fg}, false);
|
||||
// Replace inputs param
|
||||
for (size_t i = 0; i < forward_args.size(); ++i) {
|
||||
const auto &id = GetId(forward_args[i]);
|
||||
if (params_inputs_set.count(id)) {
|
||||
// Can find in second graph
|
||||
const auto &input_param_second = second_graph_info->params.at(id);
|
||||
manager->Replace(first_graph_info->params.at(id), input_param_second);
|
||||
inputs.emplace_back(input_param_second);
|
||||
} else {
|
||||
inputs.emplace_back(GetInput(forward_args[i], false));
|
||||
}
|
||||
}
|
||||
|
||||
// Replace weights param
|
||||
ValuePtrList weights_args;
|
||||
for (const auto &fir : first_graph_info->params) {
|
||||
if (!fir.second->has_default()) {
|
||||
continue;
|
||||
}
|
||||
// This weight param is missing from the second graph; it needs to be added there
|
||||
if (!params_set.count(fir.first)) {
|
||||
// The second graph does not have this weight param, so add it to the second graph
|
||||
if (!params_weights_set.count(fir.first)) {
|
||||
SetParamNodeMapInGraphInfoMap(second_df_builder, fir.first, fir.second);
|
||||
inputs.emplace_back(fir.second);
|
||||
weights_args.emplace_back(fir.second->default_param());
|
||||
|
|
Loading…
Reference in New Issue