forked from mindspore-Ecosystem/mindspore
!48536 Optimizer reusing infer op
Merge pull request !48536 from ZPaC/optimize-reusing-infer-op
This commit is contained in:
commit
9a2cdb8dd9
|
@ -615,7 +615,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraph
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
// Use TupleToTensor op as the input of this node. Then TupleUnfoldToTuple pattern will be matched.
|
||||
auto prim = NewValueNode(prim::kPrimTupleToTensor);
|
||||
auto prim = NewValueNode(std::make_shared<Primitive>(prim::kTupleToTensor));
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
AnfNodePtrList inputs = {prim, input};
|
||||
CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
|
||||
|
@ -624,10 +624,20 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraph
|
|||
// Data type of the tensor should be set as an attr of TupleToTensor op.
|
||||
size_t input_index = GetInputNodeIndex(input, node);
|
||||
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
|
||||
// There might be nested tuples, we need to find one step further to get element's data type.
|
||||
if (data_type == kObjectTypeTuple) {
|
||||
auto seq_abs = input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(seq_abs);
|
||||
if (!seq_abs->isa<abstract::AbstractSequence>()) {
|
||||
MS_LOG(EXCEPTION) << "Input " << input->DebugString() << " is not tuple output";
|
||||
}
|
||||
data_type = seq_abs->cast<abstract::AbstractSequencePtr>()->ElementsType()[kIndex0]->type_id();
|
||||
MS_LOG(DEBUG) << "Input " << input->DebugString() << " real data type is " << data_type;
|
||||
}
|
||||
common::AnfAlgo::SetNodeAttr(kAttrDType, TypeIdToType(data_type), tuple_to_tensor);
|
||||
|
||||
// Set abstract for TupleToTensor op according to user node's input shape and type.
|
||||
auto abs = GenerateAbsByOpInfer(prim::kPrimTupleToTensor, {input});
|
||||
auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tuple_to_tensor), {input});
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
|
||||
tuple_to_tensor->set_abstract(abs);
|
||||
|
@ -690,7 +700,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &f
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
// Simply insert TupleToTensor op between 'input' and 'node'.
|
||||
auto prim = NewValueNode(prim::kPrimTupleToTensor);
|
||||
auto prim = NewValueNode(std::make_shared<Primitive>(prim::kTupleToTensor));
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
AnfNodePtrList inputs = {prim, input};
|
||||
CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
|
||||
|
@ -699,10 +709,20 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &f
|
|||
// Data type of the tensor should be set as an attr of TupleToTensor op.
|
||||
size_t input_index = GetInputNodeIndex(input, node);
|
||||
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
|
||||
// There might be nested tuples, we need to find one step further to get element's data type.
|
||||
if (data_type == kObjectTypeTuple) {
|
||||
auto seq_abs = input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(seq_abs);
|
||||
if (!seq_abs->isa<abstract::AbstractSequence>()) {
|
||||
MS_LOG(EXCEPTION) << "Input " << input->DebugString() << " is not tuple output";
|
||||
}
|
||||
data_type = seq_abs->cast<abstract::AbstractSequencePtr>()->ElementsType()[kIndex0]->type_id();
|
||||
MS_LOG(DEBUG) << "Input " << input->DebugString() << " real data type is " << data_type;
|
||||
}
|
||||
common::AnfAlgo::SetNodeAttr(kAttrDType, TypeIdToType(data_type), tuple_to_tensor);
|
||||
|
||||
// Set abstract for TupleToTensor op according to user node's input shape and type.
|
||||
auto abs = GenerateAbsByOpInfer(prim::kPrimTupleToTensor, {input});
|
||||
auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tuple_to_tensor), {input});
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
|
||||
tuple_to_tensor->set_abstract(abs);
|
||||
|
@ -718,7 +738,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessScalarToTensor(const FuncGraphPtr &
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
// Simply insert ScalarToTensor op between 'input' and 'node'.
|
||||
auto prim = NewValueNode(prim::kPrimScalarToTensor);
|
||||
auto prim = NewValueNode(std::make_shared<Primitive>(prim::kScalarToTensor));
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
AnfNodePtrList inputs = {prim, input};
|
||||
CNodePtr scalar_to_tensor = func_graph->NewCNode(inputs);
|
||||
|
@ -729,7 +749,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessScalarToTensor(const FuncGraphPtr &
|
|||
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
|
||||
common::AnfAlgo::SetNodeAttr("dtype", TypeIdToType(data_type), scalar_to_tensor);
|
||||
|
||||
auto abs = GenerateAbsByOpInfer(prim::kPrimScalarToTensor, {input});
|
||||
auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(scalar_to_tensor), {input});
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for ScalarToTensor op is " << abs->ToString();
|
||||
scalar_to_tensor->set_abstract(abs);
|
||||
|
@ -745,13 +765,13 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTensorToTuple(const FuncGraphPtr &f
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
// Create TensorToTuple op.
|
||||
auto prim = NewValueNode(prim::kPrimTensorToTuple);
|
||||
auto prim = NewValueNode(std::make_shared<Primitive>(prim::kTensorToTuple));
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
AnfNodePtrList inputs = {prim, input};
|
||||
CNodePtr tensor_to_tuple = func_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(tensor_to_tuple);
|
||||
|
||||
auto abs = GenerateAbsByOpInfer(prim::kPrimTensorToTuple, {input});
|
||||
auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tensor_to_tuple), {input});
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for TensorToTuple op is " << abs->ToString();
|
||||
tensor_to_tuple->set_abstract(abs);
|
||||
|
@ -767,13 +787,13 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTensorToScalar(const FuncGraphPtr &
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
// Create TensorToScalar op.
|
||||
auto prim = NewValueNode(prim::kPrimTensorToScalar);
|
||||
auto prim = NewValueNode(std::make_shared<Primitive>(prim::kTensorToScalar));
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
AnfNodePtrList inputs = {prim, input};
|
||||
CNodePtr tensor_to_scalar = func_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(tensor_to_scalar);
|
||||
|
||||
auto abs = GenerateAbsByOpInfer(prim::kPrimTensorToScalar, {input});
|
||||
auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tensor_to_scalar), {input});
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for TensorToScalar op is " << abs->ToString();
|
||||
tensor_to_scalar->set_abstract(abs);
|
||||
|
|
Loading…
Reference in New Issue