!48536 Optimizer reusing infer op

Merge pull request !48536 from ZPaC/optimize-reusing-infer-op
This commit is contained in:
i-robot 2023-02-08 07:13:41 +00:00 committed by Gitee
commit 9a2cdb8dd9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 30 additions and 10 deletions

View File

@ -615,7 +615,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraph
MS_EXCEPTION_IF_NULL(node);
// Use TupleToTensor op as the input of this node. Then TupleUnfoldToTuple pattern will be matched.
auto prim = NewValueNode(prim::kPrimTupleToTensor);
auto prim = NewValueNode(std::make_shared<Primitive>(prim::kTupleToTensor));
MS_EXCEPTION_IF_NULL(prim);
AnfNodePtrList inputs = {prim, input};
CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
@ -624,10 +624,20 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraph
// Data type of the tensor should be set as an attr of TupleToTensor op.
size_t input_index = GetInputNodeIndex(input, node);
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
// There might be nested tuples, we need to find one step further to get element's data type.
if (data_type == kObjectTypeTuple) {
auto seq_abs = input->abstract();
MS_EXCEPTION_IF_NULL(seq_abs);
if (!seq_abs->isa<abstract::AbstractSequence>()) {
MS_LOG(EXCEPTION) << "Input " << input->DebugString() << " is not tuple output";
}
data_type = seq_abs->cast<abstract::AbstractSequencePtr>()->ElementsType()[kIndex0]->type_id();
MS_LOG(DEBUG) << "Input " << input->DebugString() << " real data type is " << data_type;
}
common::AnfAlgo::SetNodeAttr(kAttrDType, TypeIdToType(data_type), tuple_to_tensor);
// Set abstract for TupleToTensor op according to user node's input shape and type.
auto abs = GenerateAbsByOpInfer(prim::kPrimTupleToTensor, {input});
auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tuple_to_tensor), {input});
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
tuple_to_tensor->set_abstract(abs);
@ -690,7 +700,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &f
MS_EXCEPTION_IF_NULL(node);
// Simply insert TupleToTensor op between 'input' and 'node'.
auto prim = NewValueNode(prim::kPrimTupleToTensor);
auto prim = NewValueNode(std::make_shared<Primitive>(prim::kTupleToTensor));
MS_EXCEPTION_IF_NULL(prim);
AnfNodePtrList inputs = {prim, input};
CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
@ -699,10 +709,20 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &f
// Data type of the tensor should be set as an attr of TupleToTensor op.
size_t input_index = GetInputNodeIndex(input, node);
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
// There might be nested tuples, we need to find one step further to get element's data type.
if (data_type == kObjectTypeTuple) {
auto seq_abs = input->abstract();
MS_EXCEPTION_IF_NULL(seq_abs);
if (!seq_abs->isa<abstract::AbstractSequence>()) {
MS_LOG(EXCEPTION) << "Input " << input->DebugString() << " is not tuple output";
}
data_type = seq_abs->cast<abstract::AbstractSequencePtr>()->ElementsType()[kIndex0]->type_id();
MS_LOG(DEBUG) << "Input " << input->DebugString() << " real data type is " << data_type;
}
common::AnfAlgo::SetNodeAttr(kAttrDType, TypeIdToType(data_type), tuple_to_tensor);
// Set abstract for TupleToTensor op according to user node's input shape and type.
auto abs = GenerateAbsByOpInfer(prim::kPrimTupleToTensor, {input});
auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tuple_to_tensor), {input});
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
tuple_to_tensor->set_abstract(abs);
@ -718,7 +738,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessScalarToTensor(const FuncGraphPtr &
MS_EXCEPTION_IF_NULL(node);
// Simply insert ScalarToTensor op between 'input' and 'node'.
auto prim = NewValueNode(prim::kPrimScalarToTensor);
auto prim = NewValueNode(std::make_shared<Primitive>(prim::kScalarToTensor));
MS_EXCEPTION_IF_NULL(prim);
AnfNodePtrList inputs = {prim, input};
CNodePtr scalar_to_tensor = func_graph->NewCNode(inputs);
@ -729,7 +749,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessScalarToTensor(const FuncGraphPtr &
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
common::AnfAlgo::SetNodeAttr("dtype", TypeIdToType(data_type), scalar_to_tensor);
auto abs = GenerateAbsByOpInfer(prim::kPrimScalarToTensor, {input});
auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(scalar_to_tensor), {input});
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(DEBUG) << "Abstract for ScalarToTensor op is " << abs->ToString();
scalar_to_tensor->set_abstract(abs);
@ -745,13 +765,13 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTensorToTuple(const FuncGraphPtr &f
MS_EXCEPTION_IF_NULL(node);
// Create TensorToTuple op.
auto prim = NewValueNode(prim::kPrimTensorToTuple);
auto prim = NewValueNode(std::make_shared<Primitive>(prim::kTensorToTuple));
MS_EXCEPTION_IF_NULL(prim);
AnfNodePtrList inputs = {prim, input};
CNodePtr tensor_to_tuple = func_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(tensor_to_tuple);
auto abs = GenerateAbsByOpInfer(prim::kPrimTensorToTuple, {input});
auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tensor_to_tuple), {input});
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(DEBUG) << "Abstract for TensorToTuple op is " << abs->ToString();
tensor_to_tuple->set_abstract(abs);
@ -767,13 +787,13 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTensorToScalar(const FuncGraphPtr &
MS_EXCEPTION_IF_NULL(node);
// Create TensorToScalar op.
auto prim = NewValueNode(prim::kPrimTensorToScalar);
auto prim = NewValueNode(std::make_shared<Primitive>(prim::kTensorToScalar));
MS_EXCEPTION_IF_NULL(prim);
AnfNodePtrList inputs = {prim, input};
CNodePtr tensor_to_scalar = func_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(tensor_to_scalar);
auto abs = GenerateAbsByOpInfer(prim::kPrimTensorToScalar, {input});
auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tensor_to_scalar), {input});
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(DEBUG) << "Abstract for TensorToScalar op is " << abs->ToString();
tensor_to_scalar->set_abstract(abs);