forked from mindspore-Ecosystem/mindspore
fix-bug-avoid-multi-attr-value-be-eliminated-in-pynative-mode
This commit is contained in:
parent
13a66805b3
commit
e1a3c39fac
|
@ -62,7 +62,14 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> transed_nodes;
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
|
||||
auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode);
|
||||
MS_EXCEPTION_IF_NULL(real_input);
|
||||
if (!real_input->isa<Parameter>() && !real_input->isa<ValueNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
|
||||
return nullptr;
|
||||
}
|
||||
bool cnode_input_changed = false;
|
||||
|
|
|
@ -41,6 +41,9 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
|
|||
}
|
||||
// Prim Eliminate (identity)
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);
|
||||
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// ConstantDuplicateMul
|
||||
auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr {
|
||||
|
|
|
@ -393,9 +393,7 @@ bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_i
|
|||
ValuePtr value = parse::data_converter::PyDataToValue(input_object);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
auto input_name = input_names_vec[input_index];
|
||||
op_prim->BeginRecordAddAttr();
|
||||
op_prim->AddAttr(input_name, value);
|
||||
op_prim->EndRecordAddAttr();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -499,6 +497,8 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te
|
|||
|
||||
opt::ConstInputToAttrInfoRegister reg;
|
||||
bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®);
|
||||
|
||||
op_prim->BeginRecordAddAttr();
|
||||
size_t input_num = op_run_info->op_inputs.size();
|
||||
for (size_t index = 0; index < input_num; ++index) {
|
||||
// convert const input to attr
|
||||
|
@ -513,6 +513,7 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te
|
|||
std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
|
||||
tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
|
||||
}
|
||||
op_prim->EndRecordAddAttr();
|
||||
}
|
||||
|
||||
void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) {
|
||||
|
|
Loading…
Reference in New Issue