Add tuple to tensor pattern

ZPaC 2022-12-17 17:51:34 +08:00
parent f4622c3f38
commit bf3db898b8
5 changed files with 110 additions and 6 deletions

View File

@@ -260,6 +260,9 @@ InsertTypeTransformOp::InsertTypeTransformOp(bool multigraph)
kTypePairToProcessFunc[{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TENSOR}] =
std::bind(&InsertTypeTransformOp::ProcessTupleUnfoldToTensor, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3, std::placeholders::_4);
kTypePairToProcessFunc[{KernelObjectType::TUPLE, KernelObjectType::TENSOR}] =
std::bind(&InsertTypeTransformOp::ProcessTupleToTensor, this, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3, std::placeholders::_4);
}
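For context on the registration style in this constructor: each supported (source, target) kernel object type pair maps to a member handler, bound with std::bind and the four placeholders so every handler can be invoked through one callable type. Below is a minimal, self-contained sketch of that pattern with simplified stand-in types; the class, map, and Dispatch names are illustrative only and are not the actual MindSpore declarations.

#include <cstdio>
#include <functional>
#include <map>
#include <utility>

// Simplified stand-ins for the real MindSpore types.
enum class KernelObjectType { TUPLE_UNFOLD, TUPLE, TENSOR };
using ObjectTypePair = std::pair<KernelObjectType, KernelObjectType>;
using ProcessFunc = std::function<void(int /*graph*/, int /*input*/, int /*node*/, bool *new_prim)>;

class TypeTransformPass {
 public:
  TypeTransformPass() {
    // Same idea as kTypePairToProcessFunc: bind member handlers and key them by the
    // (source object type, target object type) pair.
    handlers_[{KernelObjectType::TUPLE, KernelObjectType::TENSOR}] =
      std::bind(&TypeTransformPass::ProcessTupleToTensor, this, std::placeholders::_1,
                std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
  }

  // Look up and run the handler registered for the given pair, if any.
  void Dispatch(ObjectTypePair pair, int graph, int input, int node, bool *new_prim) {
    auto it = handlers_.find(pair);
    if (it != handlers_.end()) {
      it->second(graph, input, node, new_prim);
    }
  }

 private:
  void ProcessTupleToTensor(int, int, int, bool *) { std::puts("Tuple -> Tensor handler"); }
  std::map<ObjectTypePair, ProcessFunc> handlers_;
};

int main() {
  bool new_prim = false;
  TypeTransformPass pass;
  pass.Dispatch({KernelObjectType::TUPLE, KernelObjectType::TENSOR}, 0, 0, 0, &new_prim);
  return 0;
}

Keying the table by the type pair keeps the Process entry point free of per-pair branching; adding a new conversion, as this commit does for Tuple to Tensor, only needs one new handler plus one registration line.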
const AnfNodePtr InsertTypeTransformOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
@@ -414,5 +417,27 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraph
return {tuple_to_tensor};
}
AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
const CNodePtr &node, bool *new_prim) {
MS_EXCEPTION_IF_NULL(input);
MS_EXCEPTION_IF_NULL(node);
// Simply insert TupleToTensor op between 'input' and 'node'.
auto prim = NewValueNode(prim::kPrimTupleToTensor);
MS_EXCEPTION_IF_NULL(prim);
AnfNodePtrList inputs = {prim, input};
CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(tuple_to_tensor);
// Set abstract for TupleToTensor op according to user node's input shape and type.
size_t input_index = GetInputNodeIndex(input, node);
auto abs = GenerateAbsByUserNodeInput(node, input_index);
MS_EXCEPTION_IF_NULL(abs);
MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
tuple_to_tensor->set_abstract(abs);
SetKernelInfoForNewCNode(tuple_to_tensor);
return {tuple_to_tensor};
}
} // namespace opt
} // namespace mindspore

View File

@@ -127,6 +127,10 @@ class BACKEND_EXPORT InsertTypeTransformOp : public PatternProcessPass {
// Convert TupleUnfold output to Tensor. Firstly insert TupleToTensor op. Then transform TupleUnfold to Tuple.
AnfNodePtrList ProcessTupleUnfoldToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
const CNodePtr &node, bool *new_prim);
// Convert Tuple output to Tensor. Simply add TupleToTensor op.
AnfNodePtrList ProcessTupleToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const CNodePtr &node,
bool *new_prim);
};
} // namespace opt
} // namespace mindspore

View File

@@ -460,6 +460,7 @@ constexpr auto kRealTupleGetItem = "RealTupleGetItem";
constexpr auto kRealMakeList = "RealMakeList";
constexpr auto kRealListGetItem = "RealListGetItem";
constexpr auto kTupleSetItem = "TupleSetItem";
constexpr auto kTupleAdd = "TupleAdd";
GVAR_DEF(PrimitivePtr, kPrimExtractGlimpse, std::make_shared<Primitive>(kExtractGlimpse));
//
@@ -1479,7 +1480,6 @@ GVAR_DEF(PrimitivePtr, kPrimRaise, std::make_shared<Primitive>("raise"));
GVAR_DEF(PrimitivePtr, kPrimJoinedStr, std::make_shared<Primitive>("joinedstr"));
GVAR_DEF(PrimitivePtr, kPrimMakeTuple, std::make_shared<Primitive>(kMakeTuple));
GVAR_DEF(PrimitivePtr, kPrimRealMakeTuple, std::make_shared<Primitive>(kRealMakeTuple));
GVAR_DEF(PrimitivePtr, kPrimMakeSlice, std::make_shared<Primitive>("make_slice"));
GVAR_DEF(PrimitivePtr, kPrimTupleGetItem, std::make_shared<Primitive>(kTupleGetItem));
GVAR_DEF(PrimitivePtr, kPrimSliceGetItem, std::make_shared<Primitive>(kSliceGetItem));
@@ -1681,7 +1681,6 @@ GVAR_DEF(PrimitivePtr, kPrimTupleReversed, std::make_shared<Primitive>("tuple_re
GVAR_DEF(PrimitivePtr, kPrimReducedShape, std::make_shared<Primitive>("reduced_shape"));
GVAR_DEF(PrimitivePtr, kPrimTupleDiv, std::make_shared<Primitive>("tuple_div"));
GVAR_DEF(PrimitivePtr, kPrimTupleToArray, std::make_shared<Primitive>("tuple_to_array"));
GVAR_DEF(PrimitivePtr, kPrimTupleToTensor, std::make_shared<Primitive>(kTupleToTensor));
GVAR_DEF(PrimitivePtr, kPrimShapeMul, std::make_shared<Primitive>("shape_mul"));
GVAR_DEF(PrimitivePtr, kPrimTupleEqual, std::make_shared<Primitive>("tuple_equal"));
GVAR_DEF(PrimitivePtr, kPrimListEqual, std::make_shared<Primitive>("list_equal"));
@@ -1730,6 +1729,11 @@ GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferSample, std::make_shared<Primit
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferDestroy, std::make_shared<Primitive>("ReservoirReplayBufferDestroy"));
GVAR_DEF(PrimitivePtr, kPrimOCRDetectionPreHandle, std::make_shared<Primitive>("OCRDetectionPreHandle"));
// Real tuple and list ops.
GVAR_DEF(PrimitivePtr, kPrimTupleToTensor, std::make_shared<Primitive>(kTupleToTensor));
GVAR_DEF(PrimitivePtr, kPrimRealMakeTuple, std::make_shared<Primitive>(kRealMakeTuple));
GVAR_DEF(PrimitivePtr, kPrimTupleAdd, std::make_shared<Primitive>(kTupleAdd));
// AdamApplyOne
GVAR_DEF(PrimitivePtr, kPrimAdamApplyOne, std::make_shared<Primitive>("AdamApplyOne"));
GVAR_DEF(PrimitivePtr, kPrimAdamApplyOneAssign, std::make_shared<Primitive>("AdamApplyOneAssign"));
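The GVAR_DEF lines above define shared primitive singletons from their name strings; the macro itself is outside this diff. The following is a rough sketch of what such a macro might expand to, assuming a C++17 inline variable (the real MindSpore definition, including any export attribute, may differ), together with a stand-in Primitive class so the snippet compiles on its own.

#include <memory>
#include <string>

// Minimal stand-in for mindspore::Primitive.
class Primitive {
 public:
  explicit Primitive(const std::string &name) : name_(name) {}
  const std::string &name() const { return name_; }

 private:
  std::string name_;
};
using PrimitivePtr = std::shared_ptr<Primitive>;

// Assumed expansion: a header-safe inline global so every translation unit shares one
// instance. The actual MindSpore macro is not shown in this diff and may differ.
#define GVAR_DEF(type, name, value) inline const type name = value

constexpr auto kTupleToTensor = "TupleToTensor";
GVAR_DEF(PrimitivePtr, kPrimTupleToTensor, std::make_shared<Primitive>(kTupleToTensor));

int main() { return kPrimTupleToTensor->name() == kTupleToTensor ? 0 : 1; }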

View File

@@ -46,6 +46,7 @@ class TestInsertTypeTransformOp : public BackendCommon {
void SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr, AnfNodePtr *split_ptr,
AnfNodePtr *tuple_add1_ptr, AnfNodePtr *tuple_add2_ptr);
void SetTupleUnfoldToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple, AnfNodePtr *reshape);
void SetTupleToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *reshape_ptr);
void SetKernelBuildInfo(const AnfNodePtr &node, const std::vector<std::string> &input_formats,
const std::vector<TypeId> &input_types, const std::vector<std::string> &output_formats,
@@ -165,7 +166,7 @@ void TestInsertTypeTransformOp::SetTupleUnfoldToTensorKernelBuildInfo(const Func
{kNumberTypeFloat32}, {KernelObjectType::TENSOR, KernelObjectType::TENSOR},
{KernelObjectType::TENSOR});
auto input_node3 = reshape->input(0);
auto input_node3 = reshape->input(1);
SetKernelBuildInfo(input_node3, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32},
{KernelObjectType::TENSOR}, {KernelObjectType::TENSOR});
@@ -184,6 +185,23 @@ void TestInsertTypeTransformOp::SetTupleUnfoldToTensorKernelBuildInfo(const Func
{KernelObjectType::TENSOR}, {KernelObjectType::TENSOR});
}
void TestInsertTypeTransformOp::SetTupleToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *reshape_ptr) {
auto ret = g->get_return();
EXPECT_NE(ret->input(1), nullptr);
auto reshape = ret->input(1)->cast<CNodePtr>();
*reshape_ptr = reshape;
SetKernelBuildInfo(reshape, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
{kNumberTypeFloat32}, {KernelObjectType::TENSOR, KernelObjectType::TENSOR},
{KernelObjectType::TENSOR});
auto input2 = reshape->input(2);
SetKernelBuildInfo(input2, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32}, {KernelObjectType::TENSOR},
{KernelObjectType::TUPLE});
auto input1 = reshape->input(1);
SetKernelBuildInfo(input1, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32}, {KernelObjectType::TENSOR},
{KernelObjectType::TENSOR});
}
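The build info set up here is what makes the pass fire: input 2 of Reshape is marked as producing a TUPLE object while Reshape itself declares TENSOR inputs, so the pass sees a {TUPLE, TENSOR} pair on that edge. A rough, self-contained sketch of that mismatch check follows; NeedTransform and its exact placement are assumptions for illustration, not code from this commit.

#include <cstdio>

// Simplified stand-in for KernelObjectType.
enum class KernelObjectType { TUPLE_UNFOLD, TUPLE, TENSOR };

struct ObjectTypePair {
  KernelObjectType from;
  KernelObjectType to;
};

// Sketch of the per-edge check: when the producer's output object type differs from what
// the consumer expects, the (from, to) pair selects a handler such as ProcessTupleToTensor.
bool NeedTransform(KernelObjectType producer_out, KernelObjectType consumer_in, ObjectTypePair *pair) {
  if (producer_out == consumer_in) {
    return false;
  }
  *pair = {producer_out, consumer_in};
  return true;
}

int main() {
  ObjectTypePair pair;
  // Mirrors the unit test: input 2 of Reshape outputs TUPLE, Reshape expects TENSOR.
  if (NeedTransform(KernelObjectType::TUPLE, KernelObjectType::TENSOR, &pair)) {
    std::puts("Insert TupleToTensor between the tuple producer and Reshape.");
  }
  return 0;
}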
void TestInsertTypeTransformOp::SetKernelBuildInfo(
const AnfNodePtr &node, const std::vector<std::string> &input_formats, const std::vector<TypeId> &input_types,
const std::vector<std::string> &output_formats, const std::vector<TypeId> &output_types,
@@ -219,7 +237,8 @@ void TestInsertTypeTransformOp::CheckOutputKernelInfo(const AnfNodePtr &node, si
/// Feature: Dynamic shape.
/// Description: Test TupleUnfold to TupleUnfold type transforming pass.
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph
/// expressed by python.
TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_unfold_transform) {
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_unfold_transform", "before");
ASSERT_TRUE(g != nullptr);
@@ -250,7 +269,8 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_unfold_transform) {
/// Feature: Dynamic shape.
/// Description: Test TupleUnfold to Tuple type transforming pass.
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph
/// expressed by python.
TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_transform", "before");
ASSERT_TRUE(g != nullptr);
@@ -277,7 +297,8 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
/// Feature: Dynamic shape.
/// Description: Test TupleUnfold to Tensor type transforming pass.
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph
/// expressed by python.
TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tensor_transform) {
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tensor_transform", "before");
ASSERT_TRUE(g != nullptr);
@@ -303,5 +324,31 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tensor_transform) {
ASSERT_TRUE(g_after != nullptr);
EXPECT_TRUE(CheckEqualGraph(func_graph, g_after));
}
/// Feature: Dynamic shape.
/// Description: Test Tuple to Tensor type transforming pass.
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
TEST_F(TestInsertTypeTransformOp, test_tuple_to_tensor_transform) {
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_to_tensor_transform", "before");
ASSERT_TRUE(g != nullptr);
std::vector<int64_t> shp_x{4, 2};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
std::vector<int64_t> shp_y{2};
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
auto func_graph = GetFuncGraph(g, args_spec_list);
ASSERT_TRUE(func_graph != nullptr);
AnfNodePtr reshape;
SetTupleToTensorKernelBuildInfo(func_graph, &reshape);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>());
optimizer->AddPassManager(pm);
optimizer->Optimize(func_graph);
FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_tuple_to_tensor_transform", "after");
ASSERT_TRUE(g_after != nullptr);
EXPECT_TRUE(CheckEqualGraph(func_graph, g_after));
}
} // namespace opt
} // namespace mindspore

View File

@@ -124,3 +124,27 @@ def test_tuple_unfold_to_tensor_transform(tag):
return res
return fns[tag]
def test_tuple_to_tensor_transform(tag):
"""
Feature: Dynamic shape.
Description: Test Tuple to Tensor transforming pass.
Expectation: The 'after' graph is identical to the graph after this pass.
"""
fns = FnDict()
reshape = P.Reshape()
tuple_to_tensor = Primitive('TupleToTensor')
@fns
def before(x, y):
res = reshape(x, y)
return res
@fns
def after(x, y):
res = tuple_to_tensor(y)
res = reshape(x, res)
return res
return fns[tag]