forked from mindspore-Ecosystem/mindspore
Enhance ut test cast
This commit is contained in:
parent
ff541f82d0
commit
8d70549e4b
|
@ -134,6 +134,18 @@ AnfNodePtr CreateRealMakeTupleByTupleUnfoldInput(const FuncGraphPtr &func_graph,
|
|||
KernelBuildInfoPtr real_make_tuple_build_info = AnfAlgo::GetSelectKernelBuildInfo(real_make_tuple);
|
||||
MS_EXCEPTION_IF_NULL(real_make_tuple_build_info);
|
||||
real_make_tuple_build_info->SetInputsKernelObjectType({KernelObjectType::TUPLE_UNFOLD});
|
||||
|
||||
// Extend tuple_unfold inputs.
|
||||
abstract::AbstractTuplePtr tuple_unfold_abs =
|
||||
node_with_tuple_unfold_output->abstract()->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_unfold_abs);
|
||||
auto builder = AnfAlgo::GetSelectKernelBuildInfo(real_make_tuple);
|
||||
MS_EXCEPTION_IF_NULL(builder);
|
||||
std::vector<std::string> inputs_format{tuple_unfold_abs->size(), builder->GetInputFormat(kIndex0)};
|
||||
std::vector<TypeId> inputs_type{tuple_unfold_abs->size(), builder->GetInputDeviceType(kIndex0)};
|
||||
builder->SetInputsFormat(inputs_format);
|
||||
builder->SetInputsDeviceType(inputs_type);
|
||||
|
||||
return real_make_tuple;
|
||||
}
|
||||
|
||||
|
@ -257,16 +269,6 @@ abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive, co
|
|||
return abs;
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr GenerateAbsByUserNodeInput(const CNodePtr &user_node, size_t input_index) {
|
||||
MS_EXCEPTION_IF_NULL(user_node);
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(user_node, input_index);
|
||||
auto type_id = AnfAlgo::GetInputDeviceDataType(user_node, input_index);
|
||||
// Defaultly the input is a tensor. Other cases should be handled respectively.
|
||||
auto abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), shape);
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
return abs;
|
||||
}
|
||||
|
||||
std::string GenerateOutputFormatForNewCNode(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimRealMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimTupleToTensor)) {
|
||||
|
@ -559,7 +561,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraph
|
|||
// Data type of the tensor should be set as an attr of TupleToTensor op.
|
||||
size_t input_index = GetInputNodeIndex(input, node);
|
||||
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
|
||||
common::AnfAlgo::SetNodeAttr("dtype", TypeIdToType(data_type), tuple_to_tensor);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrDType, TypeIdToType(data_type), tuple_to_tensor);
|
||||
|
||||
// Set abstract for TupleToTensor op according to user node's input shape and type.
|
||||
auto abs = GenerateAbsByOpInfer(prim::kPrimTupleToTensor, {input});
|
||||
|
@ -625,7 +627,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &f
|
|||
// Data type of the tensor should be set as an attr of TupleToTensor op.
|
||||
size_t input_index = GetInputNodeIndex(input, node);
|
||||
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
|
||||
common::AnfAlgo::SetNodeAttr("dtype", TypeIdToType(data_type), tuple_to_tensor);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrDType, TypeIdToType(data_type), tuple_to_tensor);
|
||||
|
||||
// Set abstract for TupleToTensor op according to user node's input shape and type.
|
||||
auto abs = GenerateAbsByOpInfer(prim::kPrimTupleToTensor, {input});
|
||||
|
|
|
@ -105,9 +105,6 @@ abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive);
|
|||
|
||||
// Generate abstract, format and object type for newly created node.
|
||||
// They can be generated in multiple ways because new node is not processed by kernel selecting method.
|
||||
|
||||
// Generate abstract according to input type and shape which is already set to the user node.
|
||||
abstract::AbstractBasePtr GenerateAbsByUserNodeInput(const CNodePtr &user_node, size_t input_index);
|
||||
std::string GenerateOutputFormatForNewCNode(const CNodePtr &cnode);
|
||||
void GenerateKernelObjectTypeForNewCNode(const CNodePtr &cnode, std::vector<KernelObjectType> *input_obj_type,
|
||||
std::vector<KernelObjectType> *output_obj_type);
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/pass_manager.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
|
||||
#define private public
|
||||
#define protected public
|
||||
|
@ -45,9 +47,12 @@ class TestInsertTypeTransformOp : public BackendCommon {
|
|||
AnfNodePtr *split2_ptr, AnfNodePtr *addn2_ptr);
|
||||
void SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr, AnfNodePtr *split_ptr,
|
||||
AnfNodePtr *seq_add1_ptr, AnfNodePtr *seq_add2_ptr);
|
||||
void SetTupleUnfoldToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple, AnfNodePtr *reshape);
|
||||
void SetTupleUnfoldToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple, AnfNodePtr *reshape_ptr);
|
||||
void SetTupleToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *shape_ptr);
|
||||
void SetTupleToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *reshape_ptr);
|
||||
void SetTensorToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *seq_add);
|
||||
void SetScalarToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *add_ptr);
|
||||
void SetTensorToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *seq_add_ptr);
|
||||
void SetTensorToScalarKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *scalar_to_tensor_ptr);
|
||||
|
||||
void SetKernelBuildInfo(const AnfNodePtr &node, const std::vector<std::string> &input_formats,
|
||||
const std::vector<TypeId> &input_types, const std::vector<std::string> &output_formats,
|
||||
|
@ -192,6 +197,20 @@ void TestInsertTypeTransformOp::SetTupleUnfoldToTensorKernelBuildInfo(const Func
|
|||
{KernelObjectType::TENSOR}, {KernelObjectType::TENSOR});
|
||||
}
|
||||
|
||||
void TestInsertTypeTransformOp::SetTupleToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *shape_ptr) {
|
||||
auto ret = g->get_return();
|
||||
EXPECT_NE(ret->input(1), nullptr);
|
||||
auto tuple_get_item = ret->input(1)->cast<CNodePtr>();
|
||||
SetKernelBuildInfo(tuple_get_item, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeInt64}, {"NCHW"},
|
||||
{kNumberTypeFloat32}, {KernelObjectType::TUPLE_UNFOLD, KernelObjectType::SCALAR},
|
||||
{KernelObjectType::SCALAR});
|
||||
|
||||
auto shape = tuple_get_item->input(1)->cast<CNodePtr>();
|
||||
*shape_ptr = shape;
|
||||
SetKernelBuildInfo(shape, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeInt64}, {KernelObjectType::SCALAR},
|
||||
{KernelObjectType::TUPLE});
|
||||
}
|
||||
|
||||
void TestInsertTypeTransformOp::SetTupleToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *reshape_ptr) {
|
||||
auto ret = g->get_return();
|
||||
EXPECT_NE(ret->input(1), nullptr);
|
||||
|
@ -209,6 +228,15 @@ void TestInsertTypeTransformOp::SetTupleToTensorKernelBuildInfo(const FuncGraphP
|
|||
{KernelObjectType::TENSOR});
|
||||
}
|
||||
|
||||
void TestInsertTypeTransformOp::SetScalarToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *add_ptr) {
|
||||
auto ret = g->get_return();
|
||||
EXPECT_NE(ret->input(1), nullptr);
|
||||
auto add = ret->input(1)->cast<CNodePtr>();
|
||||
*add_ptr = add;
|
||||
SetKernelBuildInfo(add, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32},
|
||||
{KernelObjectType::TENSOR, KernelObjectType::TENSOR}, {KernelObjectType::TENSOR});
|
||||
}
|
||||
|
||||
void TestInsertTypeTransformOp::SetTensorToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *seq_add_ptr) {
|
||||
auto ret = g->get_return();
|
||||
EXPECT_NE(ret->input(1), nullptr);
|
||||
|
@ -226,6 +254,16 @@ void TestInsertTypeTransformOp::SetTensorToTupleKernelBuildInfo(const FuncGraphP
|
|||
{KernelObjectType::TENSOR});
|
||||
}
|
||||
|
||||
void TestInsertTypeTransformOp::SetTensorToScalarKernelBuildInfo(const FuncGraphPtr &g,
|
||||
AnfNodePtr *scalar_to_tensor_ptr) {
|
||||
auto ret = g->get_return();
|
||||
EXPECT_NE(ret->input(1), nullptr);
|
||||
auto scalar_to_tensor = ret->input(1)->cast<CNodePtr>();
|
||||
*scalar_to_tensor_ptr = scalar_to_tensor;
|
||||
SetKernelBuildInfo(scalar_to_tensor, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32},
|
||||
{KernelObjectType::SCALAR}, {KernelObjectType::TENSOR});
|
||||
}
|
||||
|
||||
void TestInsertTypeTransformOp::SetKernelBuildInfo(
|
||||
const AnfNodePtr &node, const std::vector<std::string> &input_formats, const std::vector<TypeId> &input_types,
|
||||
const std::vector<std::string> &output_formats, const std::vector<TypeId> &output_types,
|
||||
|
@ -318,6 +356,12 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
|
|||
optimizer->AddPassManager(pm);
|
||||
optimizer->Optimize(func_graph);
|
||||
|
||||
auto real_make_tuple2 = func_graph->get_return()->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
|
||||
ASSERT_TRUE(IsPrimitiveCNode(real_make_tuple2, prim::kPrimRealMakeTuple));
|
||||
ASSERT_TRUE(real_make_tuple2->abstract()->isa<abstract::AbstractTuple>());
|
||||
auto obj_type = AnfAlgo::GetOutputKernelObjectType(real_make_tuple2, 0);
|
||||
ASSERT_TRUE(obj_type == KernelObjectType::TUPLE);
|
||||
|
||||
FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_transform", "after");
|
||||
ASSERT_TRUE(g_after != nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(func_graph, g_after));
|
||||
|
@ -348,11 +392,51 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tensor_transform) {
|
|||
optimizer->AddPassManager(pm);
|
||||
optimizer->Optimize(func_graph);
|
||||
|
||||
auto tuple_to_tensor = func_graph->get_return()->input(1)->cast<CNodePtr>()->input(2)->cast<CNodePtr>();
|
||||
ASSERT_TRUE(IsPrimitiveCNode(tuple_to_tensor, prim::kPrimTupleToTensor));
|
||||
auto real_make_tuple = tuple_to_tensor->input(1);
|
||||
ASSERT_TRUE(IsPrimitiveCNode(real_make_tuple, prim::kPrimRealMakeTuple));
|
||||
ASSERT_TRUE(real_make_tuple->abstract()->isa<abstract::AbstractTuple>());
|
||||
auto obj_type = AnfAlgo::GetOutputKernelObjectType(real_make_tuple, 0);
|
||||
ASSERT_TRUE(obj_type == KernelObjectType::TUPLE);
|
||||
|
||||
FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tensor_transform", "after");
|
||||
ASSERT_TRUE(g_after != nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(func_graph, g_after));
|
||||
}
|
||||
|
||||
/// Feature: Dynamic shape.
|
||||
/// Description: Test Tuple to TupleUnfold type transforming pass.
|
||||
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
|
||||
TEST_F(TestInsertTypeTransformOp, DISABLED_test_tuple_to_tuple_unfold_transform) {
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_to_tuple_unfold_transform", "before");
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
std::vector<int64_t> shp_x{4, 2};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
AbstractBasePtrList args_spec_list{x_abstract};
|
||||
auto func_graph = GetFuncGraph(g, args_spec_list);
|
||||
ASSERT_TRUE(func_graph != nullptr);
|
||||
AnfNodePtr shape;
|
||||
SetTupleToTupleUnfoldKernelBuildInfo(func_graph, &shape);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>());
|
||||
optimizer->AddPassManager(pm);
|
||||
optimizer->Optimize(func_graph);
|
||||
|
||||
auto ret = func_graph->get_return();
|
||||
auto real_tuple_get_item = ret->input(1)->cast<CNodePtr>();
|
||||
ASSERT_TRUE(IsPrimitiveCNode(real_tuple_get_item, prim::kPrimRealTupleGetItem));
|
||||
ASSERT_TRUE(real_tuple_get_item->abstract()->isa<abstract::AbstractScalar>());
|
||||
auto obj_type = AnfAlgo::GetOutputKernelObjectType(real_tuple_get_item, 0);
|
||||
ASSERT_TRUE(obj_type == KernelObjectType::SCALAR);
|
||||
|
||||
FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_tuple_to_tuple_unfold_transform", "after");
|
||||
ASSERT_TRUE(g_after != nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(func_graph, g_after));
|
||||
}
|
||||
|
||||
/// Feature: Dynamic shape.
|
||||
/// Description: Test Tuple to Tensor type transforming pass.
|
||||
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
|
||||
|
@ -377,11 +461,52 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_to_tensor_transform) {
|
|||
optimizer->AddPassManager(pm);
|
||||
optimizer->Optimize(func_graph);
|
||||
|
||||
auto tuple_to_tensor = func_graph->get_return()->input(1)->cast<CNodePtr>()->input(2)->cast<CNodePtr>();
|
||||
ASSERT_TRUE(IsPrimitiveCNode(tuple_to_tensor, prim::kPrimTupleToTensor));
|
||||
auto dtype = common::AnfAlgo::GetNodeAttr<TypePtr>(tuple_to_tensor, kAttrDType);
|
||||
ASSERT_TRUE(dtype->type_id() == kNumberTypeFloat32);
|
||||
ASSERT_TRUE(tuple_to_tensor->abstract()->isa<abstract::AbstractTensor>());
|
||||
auto obj_type = AnfAlgo::GetOutputKernelObjectType(tuple_to_tensor, 0);
|
||||
ASSERT_TRUE(obj_type == KernelObjectType::TENSOR);
|
||||
|
||||
FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_tuple_to_tensor_transform", "after");
|
||||
ASSERT_TRUE(g_after != nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(func_graph, g_after));
|
||||
}
|
||||
|
||||
/// Feature: Dynamic shape.
|
||||
/// Description: Test Scalar to Tensor type transforming pass.
|
||||
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
|
||||
TEST_F(TestInsertTypeTransformOp, DISABLED_test_scalar_to_tensor_transform) {
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_scalar_to_tensor_transform", "before");
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
auto x_abstract = std::make_shared<abstract::AbstractScalar>(3);
|
||||
auto y_abstract = std::make_shared<abstract::AbstractScalar>(4);
|
||||
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
|
||||
auto func_graph = GetFuncGraph(g, args_spec_list);
|
||||
ASSERT_TRUE(func_graph != nullptr);
|
||||
AnfNodePtr add;
|
||||
SetScalarToTensorKernelBuildInfo(func_graph, &add);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>());
|
||||
optimizer->AddPassManager(pm);
|
||||
optimizer->Optimize(func_graph);
|
||||
|
||||
auto scalar_to_tensor1 = func_graph->get_return()->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
|
||||
ASSERT_TRUE(IsPrimitiveCNode(scalar_to_tensor1, prim::kPrimScalarToTensor));
|
||||
auto dtype = common::AnfAlgo::GetNodeAttr<TypePtr>(scalar_to_tensor1, kAttrDType);
|
||||
ASSERT_TRUE(dtype->type_id() == kNumberTypeInt64);
|
||||
ASSERT_TRUE(scalar_to_tensor1->abstract()->isa<abstract::AbstractTensor>());
|
||||
auto obj_type = AnfAlgo::GetOutputKernelObjectType(scalar_to_tensor1, 0);
|
||||
ASSERT_TRUE(obj_type == KernelObjectType::TENSOR);
|
||||
|
||||
FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_scalar_to_tensor_transform", "after");
|
||||
ASSERT_TRUE(g_after != nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(func_graph, g_after));
|
||||
}
|
||||
|
||||
/// Feature: Dynamic shape.
|
||||
/// Description: Test Tensor to Tuple type transforming pass.
|
||||
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
|
||||
|
@ -404,9 +529,46 @@ TEST_F(TestInsertTypeTransformOp, test_tensor_to_tuple_transform) {
|
|||
optimizer->AddPassManager(pm);
|
||||
optimizer->Optimize(func_graph);
|
||||
|
||||
auto tensor_to_tuple1 = func_graph->get_return()->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
|
||||
ASSERT_TRUE(IsPrimitiveCNode(tensor_to_tuple1, prim::kPrimTensorToTuple));
|
||||
ASSERT_TRUE(tensor_to_tuple1->abstract()->isa<abstract::AbstractTuple>());
|
||||
auto obj_type = AnfAlgo::GetOutputKernelObjectType(tensor_to_tuple1, 0);
|
||||
ASSERT_TRUE(obj_type == KernelObjectType::TUPLE);
|
||||
|
||||
FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_tensor_to_tuple_transform", "after");
|
||||
ASSERT_TRUE(g_after != nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(func_graph, g_after));
|
||||
}
|
||||
|
||||
/// Feature: Dynamic shape.
|
||||
/// Description: Test Tensor to Scalar type transforming pass.
|
||||
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
|
||||
TEST_F(TestInsertTypeTransformOp, DISABLED_test_tensor_to_scalar_transform) {
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tensor_to_scalar_transform", "before");
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
std::vector<int64_t> shp_x{4};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
AbstractBasePtrList args_spec_list{x_abstract};
|
||||
auto func_graph = GetFuncGraph(g, args_spec_list);
|
||||
ASSERT_TRUE(func_graph != nullptr);
|
||||
AnfNodePtr scalar_to_tensor;
|
||||
SetTensorToScalarKernelBuildInfo(func_graph, &scalar_to_tensor);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>());
|
||||
optimizer->AddPassManager(pm);
|
||||
optimizer->Optimize(func_graph);
|
||||
|
||||
auto tensor_to_scalar = func_graph->get_return()->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
|
||||
ASSERT_TRUE(IsPrimitiveCNode(tensor_to_scalar, prim::kPrimTensorToScalar));
|
||||
ASSERT_TRUE(tensor_to_scalar->abstract()->isa<abstract::AbstractScalar>());
|
||||
auto obj_type = AnfAlgo::GetOutputKernelObjectType(tensor_to_scalar, 0);
|
||||
ASSERT_TRUE(obj_type == KernelObjectType::SCALAR);
|
||||
|
||||
FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_tensor_to_scalar_transform", "after");
|
||||
ASSERT_TRUE(g_after != nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(func_graph, g_after));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -127,6 +127,31 @@ def test_tuple_unfold_to_tensor_transform(tag):
|
|||
return fns[tag]
|
||||
|
||||
|
||||
def test_tuple_to_tuple_unfold_transform(tag):
|
||||
"""
|
||||
Feature: Dynamic shape.
|
||||
Description: Test Tuple to TupleUnfold transforming pass.
|
||||
Expectation: The 'after' graph is identical to the graph after this pass.
|
||||
"""
|
||||
fns = FnDict()
|
||||
shape = P.Shape()
|
||||
real_tuple_get_item = Primitive('RealTupleGetItem')
|
||||
|
||||
@fns
|
||||
def before(x):
|
||||
res = shape(x)
|
||||
res = res[0]
|
||||
return res
|
||||
|
||||
@fns
|
||||
def after(x):
|
||||
res = shape(x, y)
|
||||
res = real_tuple_get_item(res, 0)
|
||||
return res
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_tuple_to_tensor_transform(tag):
|
||||
"""
|
||||
Feature: Dynamic shape.
|
||||
|
@ -151,6 +176,31 @@ def test_tuple_to_tensor_transform(tag):
|
|||
return fns[tag]
|
||||
|
||||
|
||||
def test_scalar_to_tensor_transform(tag):
|
||||
"""
|
||||
Feature: Dynamic shape.
|
||||
Description: Test Scalar to Tensor transforming pass.
|
||||
Expectation: The 'after' graph is identical to the graph after this pass.
|
||||
"""
|
||||
fns = FnDict()
|
||||
add = P.Add()
|
||||
scalar_to_tensor = Primitive('ScalarToTensor')
|
||||
|
||||
@fns
|
||||
def before(x, y):
|
||||
res = add(x, y)
|
||||
return res
|
||||
|
||||
@fns
|
||||
def after(x, y):
|
||||
x = scalar_to_tensor(x)
|
||||
y = scalar_to_tensor(y)
|
||||
res = add(x, y)
|
||||
return res
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_tensor_to_tuple_transform(tag):
|
||||
"""
|
||||
Feature: Dynamic shape.
|
||||
|
@ -175,3 +225,27 @@ def test_tensor_to_tuple_transform(tag):
|
|||
return res
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_tensor_to_scalar_transform(tag):
|
||||
"""
|
||||
Feature: Dynamic shape.
|
||||
Description: Test Tensor to Scalar transforming pass.
|
||||
Expectation: The 'after' graph is identical to the graph after this pass.
|
||||
"""
|
||||
fns = FnDict()
|
||||
scalar_to_tensor = P.ScalarToTensor()
|
||||
tensor_to_scalar = Primitive('TensorToScalar')
|
||||
|
||||
@fns
|
||||
def before(x):
|
||||
res = scalar_to_tensor(x)
|
||||
return res
|
||||
|
||||
@fns
|
||||
def after(x):
|
||||
res = tensor_to_scalar(x)
|
||||
res = scalar_to_tensor(res)
|
||||
return res
|
||||
|
||||
return fns[tag]
|
||||
|
|
Loading…
Reference in New Issue