fix-bug-avoid-multi-attr-value-be-eliminated-in-pynative-mode

This commit is contained in:
lvliang 2020-08-10 09:07:32 +08:00
parent 13a66805b3
commit e1a3c39fac
3 changed files with 14 additions and 3 deletions

View File

@ -62,7 +62,14 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
std::unordered_map<AnfNodePtr, AnfNodePtr> transed_nodes;
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode);
MS_EXCEPTION_IF_NULL(real_input);
if (!real_input->isa<Parameter>() && !real_input->isa<ValueNode>()) {
return nullptr;
}
}
if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
return nullptr;
}
bool cnode_input_changed = false;

View File

@ -41,6 +41,9 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
}
// Prim Eliminate (identity)
MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
return nullptr;
}
// ConstantDuplicateMul
auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr {

View File

@ -393,9 +393,7 @@ bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_i
ValuePtr value = parse::data_converter::PyDataToValue(input_object);
MS_EXCEPTION_IF_NULL(value);
auto input_name = input_names_vec[input_index];
op_prim->BeginRecordAddAttr();
op_prim->AddAttr(input_name, value);
op_prim->EndRecordAddAttr();
return true;
}
return false;
@ -499,6 +497,8 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te
opt::ConstInputToAttrInfoRegister reg;
bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg);
op_prim->BeginRecordAddAttr();
size_t input_num = op_run_info->op_inputs.size();
for (size_t index = 0; index < input_num; ++index) {
// convert const input to attr
@ -513,6 +513,7 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te
std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
}
op_prim->EndRecordAddAttr();
}
void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) {