forked from mindspore-Ecosystem/mindspore
Add tuple to tensor pattern
This commit is contained in:
parent
f4622c3f38
commit
bf3db898b8
|
@ -260,6 +260,9 @@ InsertTypeTransformOp::InsertTypeTransformOp(bool multigraph)
|
|||
kTypePairToProcessFunc[{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TENSOR}] =
|
||||
std::bind(&InsertTypeTransformOp::ProcessTupleUnfoldToTensor, this, std::placeholders::_1, std::placeholders::_2,
|
||||
std::placeholders::_3, std::placeholders::_4);
|
||||
kTypePairToProcessFunc[{KernelObjectType::TUPLE, KernelObjectType::TENSOR}] =
|
||||
std::bind(&InsertTypeTransformOp::ProcessTupleToTensor, this, std::placeholders::_1, std::placeholders::_2,
|
||||
std::placeholders::_3, std::placeholders::_4);
|
||||
}
|
||||
|
||||
const AnfNodePtr InsertTypeTransformOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
|
@ -414,5 +417,27 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraph
|
|||
|
||||
return {tuple_to_tensor};
|
||||
}
|
||||
|
||||
AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
||||
const CNodePtr &node, bool *new_prim) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
// Simply insert TupleToTensor op between 'input' and 'node'.
|
||||
auto prim = NewValueNode(prim::kPrimTupleToTensor);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
AnfNodePtrList inputs = {prim, input};
|
||||
CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(tuple_to_tensor);
|
||||
|
||||
// Set abstract for TupleToTensor op according to user node's input shape and type.
|
||||
size_t input_index = GetInputNodeIndex(input, node);
|
||||
auto abs = GenerateAbsByUserNodeInput(node, input_index);
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
|
||||
tuple_to_tensor->set_abstract(abs);
|
||||
SetKernelInfoForNewCNode(tuple_to_tensor);
|
||||
return {tuple_to_tensor};
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -127,6 +127,10 @@ class BACKEND_EXPORT InsertTypeTransformOp : public PatternProcessPass {
|
|||
// Convert TupleUnfold output to Tensor. Firstly insert TupleToTensor op. Then transform TupleUnfold to Tuple.
|
||||
AnfNodePtrList ProcessTupleUnfoldToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
||||
const CNodePtr &node, bool *new_prim);
|
||||
|
||||
// Convert Tuple output to Tensor. Simply add TupleToTensor op.
|
||||
AnfNodePtrList ProcessTupleToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const CNodePtr &node,
|
||||
bool *new_prim);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -460,6 +460,7 @@ constexpr auto kRealTupleGetItem = "RealTupleGetItem";
|
|||
constexpr auto kRealMakeList = "RealMakeList";
|
||||
constexpr auto kRealListGetItem = "RealListGetItem";
|
||||
constexpr auto kTupleSetItem = "TupleSetItem";
|
||||
constexpr auto kTupleAdd = "TupleAdd";
|
||||
|
||||
GVAR_DEF(PrimitivePtr, kPrimExtractGlimpse, std::make_shared<Primitive>(kExtractGlimpse));
|
||||
//
|
||||
|
@ -1479,7 +1480,6 @@ GVAR_DEF(PrimitivePtr, kPrimRaise, std::make_shared<Primitive>("raise"));
|
|||
GVAR_DEF(PrimitivePtr, kPrimJoinedStr, std::make_shared<Primitive>("joinedstr"));
|
||||
|
||||
GVAR_DEF(PrimitivePtr, kPrimMakeTuple, std::make_shared<Primitive>(kMakeTuple));
|
||||
GVAR_DEF(PrimitivePtr, kPrimRealMakeTuple, std::make_shared<Primitive>(kRealMakeTuple));
|
||||
GVAR_DEF(PrimitivePtr, kPrimMakeSlice, std::make_shared<Primitive>("make_slice"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTupleGetItem, std::make_shared<Primitive>(kTupleGetItem));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSliceGetItem, std::make_shared<Primitive>(kSliceGetItem));
|
||||
|
@ -1681,7 +1681,6 @@ GVAR_DEF(PrimitivePtr, kPrimTupleReversed, std::make_shared<Primitive>("tuple_re
|
|||
GVAR_DEF(PrimitivePtr, kPrimReducedShape, std::make_shared<Primitive>("reduced_shape"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTupleDiv, std::make_shared<Primitive>("tuple_div"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTupleToArray, std::make_shared<Primitive>("tuple_to_array"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTupleToTensor, std::make_shared<Primitive>(kTupleToTensor));
|
||||
GVAR_DEF(PrimitivePtr, kPrimShapeMul, std::make_shared<Primitive>("shape_mul"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTupleEqual, std::make_shared<Primitive>("tuple_equal"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimListEqual, std::make_shared<Primitive>("list_equal"));
|
||||
|
@ -1730,6 +1729,11 @@ GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferSample, std::make_shared<Primit
|
|||
GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferDestroy, std::make_shared<Primitive>("ReservoirReplayBufferDestroy"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimOCRDetectionPreHandle, std::make_shared<Primitive>("OCRDetectionPreHandle"));
|
||||
|
||||
// Real tuple and list ops.
|
||||
GVAR_DEF(PrimitivePtr, kPrimTupleToTensor, std::make_shared<Primitive>(kTupleToTensor));
|
||||
GVAR_DEF(PrimitivePtr, kPrimRealMakeTuple, std::make_shared<Primitive>(kRealMakeTuple));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTupleAdd, std::make_shared<Primitive>(kTupleAdd));
|
||||
|
||||
// AdamApplyOne
|
||||
GVAR_DEF(PrimitivePtr, kPrimAdamApplyOne, std::make_shared<Primitive>("AdamApplyOne"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimAdamApplyOneAssign, std::make_shared<Primitive>("AdamApplyOneAssign"));
|
||||
|
|
|
@ -46,6 +46,7 @@ class TestInsertTypeTransformOp : public BackendCommon {
|
|||
void SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr, AnfNodePtr *split_ptr,
|
||||
AnfNodePtr *tuple_add1_ptr, AnfNodePtr *tuple_add2_ptr);
|
||||
void SetTupleUnfoldToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple, AnfNodePtr *reshape);
|
||||
void SetTupleToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *reshape_ptr);
|
||||
|
||||
void SetKernelBuildInfo(const AnfNodePtr &node, const std::vector<std::string> &input_formats,
|
||||
const std::vector<TypeId> &input_types, const std::vector<std::string> &output_formats,
|
||||
|
@ -165,7 +166,7 @@ void TestInsertTypeTransformOp::SetTupleUnfoldToTensorKernelBuildInfo(const Func
|
|||
{kNumberTypeFloat32}, {KernelObjectType::TENSOR, KernelObjectType::TENSOR},
|
||||
{KernelObjectType::TENSOR});
|
||||
|
||||
auto input_node3 = reshape->input(0);
|
||||
auto input_node3 = reshape->input(1);
|
||||
SetKernelBuildInfo(input_node3, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32},
|
||||
{KernelObjectType::TENSOR}, {KernelObjectType::TENSOR});
|
||||
|
||||
|
@ -184,6 +185,23 @@ void TestInsertTypeTransformOp::SetTupleUnfoldToTensorKernelBuildInfo(const Func
|
|||
{KernelObjectType::TENSOR}, {KernelObjectType::TENSOR});
|
||||
}
|
||||
|
||||
void TestInsertTypeTransformOp::SetTupleToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *reshape_ptr) {
|
||||
auto ret = g->get_return();
|
||||
EXPECT_NE(ret->input(1), nullptr);
|
||||
auto reshape = ret->input(1)->cast<CNodePtr>();
|
||||
*reshape_ptr = reshape;
|
||||
SetKernelBuildInfo(reshape, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
|
||||
{kNumberTypeFloat32}, {KernelObjectType::TENSOR, KernelObjectType::TENSOR},
|
||||
{KernelObjectType::TENSOR});
|
||||
auto input2 = reshape->input(2);
|
||||
SetKernelBuildInfo(input2, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32}, {KernelObjectType::TENSOR},
|
||||
{KernelObjectType::TUPLE});
|
||||
|
||||
auto input1 = reshape->input(1);
|
||||
SetKernelBuildInfo(input1, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32}, {KernelObjectType::TENSOR},
|
||||
{KernelObjectType::TENSOR});
|
||||
}
|
||||
|
||||
void TestInsertTypeTransformOp::SetKernelBuildInfo(
|
||||
const AnfNodePtr &node, const std::vector<std::string> &input_formats, const std::vector<TypeId> &input_types,
|
||||
const std::vector<std::string> &output_formats, const std::vector<TypeId> &output_types,
|
||||
|
@ -219,7 +237,8 @@ void TestInsertTypeTransformOp::CheckOutputKernelInfo(const AnfNodePtr &node, si
|
|||
|
||||
/// Feature: Dynamic shape.
|
||||
/// Description: Test TupleUnfold to TupleUnfold type transforming pass.
|
||||
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
|
||||
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph
|
||||
/// expressed by python.
|
||||
TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_unfold_transform) {
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_unfold_transform", "before");
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
|
@ -250,7 +269,8 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_unfold_transform) {
|
|||
|
||||
/// Feature: Dynamic shape.
|
||||
/// Description: Test TupleUnfold to Tuple type transforming pass.
|
||||
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
|
||||
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph
|
||||
/// expressed by python.
|
||||
TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_transform", "before");
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
|
@ -277,7 +297,8 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
|
|||
|
||||
/// Feature: Dynamic shape.
|
||||
/// Description: Test TupleUnfold to Tensor type transforming pass.
|
||||
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
|
||||
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph
|
||||
/// expressed by python.
|
||||
TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tensor_transform) {
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tensor_transform", "before");
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
|
@ -303,5 +324,31 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tensor_transform) {
|
|||
ASSERT_TRUE(g_after != nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(func_graph, g_after));
|
||||
}
|
||||
|
||||
/// Description: Test Tuple to Tensor type transforming pass.
|
||||
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
|
||||
TEST_F(TestInsertTypeTransformOp, test_tuple_to_tensor_transform) {
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_to_tensor_transform", "before");
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
std::vector<int64_t> shp_x{4, 2};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
std::vector<int64_t> shp_y{2};
|
||||
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
|
||||
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
|
||||
auto func_graph = GetFuncGraph(g, args_spec_list);
|
||||
ASSERT_TRUE(func_graph != nullptr);
|
||||
AnfNodePtr reshape;
|
||||
SetTupleToTensorKernelBuildInfo(func_graph, &reshape);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>());
|
||||
optimizer->AddPassManager(pm);
|
||||
optimizer->Optimize(func_graph);
|
||||
|
||||
FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_tuple_to_tensor_transform", "after");
|
||||
ASSERT_TRUE(g_after != nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(func_graph, g_after));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -124,3 +124,27 @@ def test_tuple_unfold_to_tensor_transform(tag):
|
|||
return res
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_tuple_to_tensor_transform(tag):
|
||||
"""
|
||||
Feature: Dynamic shape.
|
||||
Description: Test Tuple to Tensor transforming pass.
|
||||
Expectation: The 'after' graph is identical to the graph after this pass.
|
||||
"""
|
||||
fns = FnDict()
|
||||
reshape = P.Reshape()
|
||||
tuple_to_tensor = Primitive('TupleToTensor')
|
||||
|
||||
@fns
|
||||
def before(x, y):
|
||||
res = reshape(x, y)
|
||||
return res
|
||||
|
||||
@fns
|
||||
def after(x, y):
|
||||
res = tuple_to_tensor(y)
|
||||
res = reshape(x, res)
|
||||
return res
|
||||
|
||||
return fns[tag]
|
||||
|
|
Loading…
Reference in New Issue