forked from mindspore-Ecosystem/mindspore
!597 Transform const to variable when having assign in export process
Merge pull request !597 from ghzl/trans-const-to-variable-in-assign
This commit is contained in:
commit
7c7d95acf9
|
@ -1155,6 +1155,9 @@ void DfGraphConvertor::SetOpControlInput(const AnfNodePtr node) {
|
|||
}
|
||||
}
|
||||
|
||||
const std::vector<std::string> trans_var_list = {prim::kPrimAssign->name(), string(kNameAssignAdd),
|
||||
string(kNameAssignSub)};
|
||||
|
||||
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
|
||||
OperatorPtr src = Convert(node);
|
||||
auto &inputs = node->inputs();
|
||||
|
@ -1167,6 +1170,26 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
|
|||
if (IsValueNode<None>(pred)) {
|
||||
continue;
|
||||
}
|
||||
// transform "Const" op to "Variable" op when the next node is "Assign" op.
|
||||
std::string c_name = GetCNodeFuncName(node);
|
||||
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
|
||||
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) {
|
||||
std::string name = std::static_pointer_cast<Parameter>(pred)->name();
|
||||
auto op_itor = op_cache_.find(pred.get());
|
||||
if (op_itor == op_cache_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << ".";
|
||||
}
|
||||
if (op_itor->second != nullptr &&
|
||||
(op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") &&
|
||||
vars_.find(name) != vars_.end()) {
|
||||
auto variable = std::make_shared<Variable>(name);
|
||||
auto desc = vars_[name]->GetOutputDesc("y");
|
||||
(void)variable->update_output_desc_y(desc);
|
||||
MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
|
||||
op_itor->second = variable; // replace parameter with variable
|
||||
vars_[name] = variable;
|
||||
}
|
||||
}
|
||||
// find in out_hadnle_cache_ first
|
||||
auto it = out_handle_cache_.find(pred.get());
|
||||
if (it != out_handle_cache_.end()) {
|
||||
|
|
Loading…
Reference in New Issue