forked from OSSInnovation/mindspore
!5962 fix mobilenetv2 loss error in pynative mode
Merge pull request !5962 from chujinjin/fix_mobilenetv2_loss_error
This commit is contained in:
commit
31fa3edb96
|
@ -767,9 +767,14 @@ void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr
|
|||
if (iter != op_forward_map_.end()) {
|
||||
return;
|
||||
}
|
||||
op_forward_map_[id] = value;
|
||||
auto tuple_info_iter = obj_to_forward_id_tuple_info_.find(id);
|
||||
ValuePtr temp_value = value;
|
||||
if (tuple_info_iter != obj_to_forward_id_tuple_info_.end()) {
|
||||
temp_value = tuple_info_iter->second;
|
||||
}
|
||||
op_forward_map_[id] = temp_value;
|
||||
MS_LOG(DEBUG) << "Save op forward value: "
|
||||
<< "(" << id << "), " << value;
|
||||
<< "(" << id << "), " << temp_value;
|
||||
}
|
||||
|
||||
void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) {
|
||||
|
@ -799,6 +804,14 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CN
|
|||
cnode->set_forward(value, op_id);
|
||||
++op_id_map_[id];
|
||||
auto out_id = GetId(out_real);
|
||||
if (py::isinstance<py::tuple>(out_real)) {
|
||||
auto tuple_item = py::cast<py::tuple>(out_real);
|
||||
for (size_t i = 0; i < tuple_item.size(); i++) {
|
||||
auto tuple_item_id = GetId(tuple_item[i]);
|
||||
obj_to_forward_id_[tuple_item_id] = op_id;
|
||||
}
|
||||
obj_to_forward_id_tuple_info_[op_id] = value;
|
||||
}
|
||||
obj_to_forward_id_[out_id] = op_id;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -155,6 +155,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
std::unordered_map<std::string, ValuePtr> op_forward_map_;
|
||||
std::unordered_map<std::string, size_t> op_id_map_;
|
||||
std::unordered_map<std::string, std::string> obj_to_forward_id_;
|
||||
std::unordered_map<std::string, ValuePtr> obj_to_forward_id_tuple_info_;
|
||||
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
|
||||
std::unordered_map<std::string, FuncGraphPtr> df_builder_map_;
|
||||
// the stack that records the context of graph created, the bottom is the top graph
|
||||
|
|
Loading…
Reference in New Issue