!597 Transform const to variable when having assign in export process

Merge pull request !597 from ghzl/trans-const-to-variable-in-assign
This commit is contained in:
mindspore-ci-bot 2020-04-26 11:29:46 +08:00 committed by Gitee
commit 7c7d95acf9
1 changed files with 23 additions and 0 deletions

23
mindspore/ccsrc/transform/convert.cc Executable file → Normal file
View File

@ -1155,6 +1155,9 @@ void DfGraphConvertor::SetOpControlInput(const AnfNodePtr node) {
}
}
const std::vector<std::string> trans_var_list = {prim::kPrimAssign->name(), string(kNameAssignAdd),
string(kNameAssignSub)};
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
OperatorPtr src = Convert(node);
auto &inputs = node->inputs();
@ -1167,6 +1170,26 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
if (IsValueNode<None>(pred)) {
continue;
}
// transform "Const" op to "Variable" op when the next node is "Assign" op.
std::string c_name = GetCNodeFuncName(node);
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) {
std::string name = std::static_pointer_cast<Parameter>(pred)->name();
auto op_itor = op_cache_.find(pred.get());
if (op_itor == op_cache_.end()) {
MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << ".";
}
if (op_itor->second != nullptr &&
(op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") &&
vars_.find(name) != vars_.end()) {
auto variable = std::make_shared<Variable>(name);
auto desc = vars_[name]->GetOutputDesc("y");
(void)variable->update_output_desc_y(desc);
MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
op_itor->second = variable; // replace parameter with variable
vars_[name] = variable;
}
}
// find in out_hadnle_cache_ first
auto it = out_handle_cache_.find(pred.get());
if (it != out_handle_cache_.end()) {