Reset RealMakeTuple objtypes

This commit is contained in:
ZPaC 2023-02-07 19:27:32 +08:00
parent f40e159dd1
commit 50f90fb8df
3 changed files with 29 additions and 30 deletions

View File

@ -139,6 +139,17 @@ AnfNodePtr CreateRealMakeTupleByMakeTuple(const FuncGraphPtr &func_graph, const
real_make_tuple->set_abstract(make_tuple_node->abstract());
SetKernelInfoForNewCNode(real_make_tuple);
// RealMakeTuple's inputs must be scalar. To avoid failing to select kernel, we must override RealMakeTuple's
// KernelObjectTypes, which is created from MakeTuple.
KernelBuildInfoPtr real_make_tuple_build_info = AnfAlgo::GetSelectKernelBuildInfo(real_make_tuple);
MS_EXCEPTION_IF_NULL(real_make_tuple_build_info);
auto inputs_obj_types = real_make_tuple_build_info->GetAllInputKernelObjectTypes();
auto new_obj_types = inputs_obj_types;
std::transform(new_obj_types.begin(), new_obj_types.end(), new_obj_types.begin(),
[](const auto &obj_type) { return KernelObjectType::SCALAR; });
real_make_tuple_build_info->SetInputsKernelObjectType(new_obj_types);
MS_LOG(DEBUG) << "Override RealMakeTuple input kernel object types from " << inputs_obj_types << " " << new_obj_types;
return real_make_tuple;
}

View File

@ -45,8 +45,8 @@ class TestInsertTypeTransformOp : public BackendCommon {
public:
void SetTupleUnfoldToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *split1_ptr, AnfNodePtr *addn1_ptr,
AnfNodePtr *split2_ptr, AnfNodePtr *addn2_ptr);
void SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr, AnfNodePtr *split_ptr,
AnfNodePtr *seq_add1_ptr, AnfNodePtr *seq_add2_ptr);
void SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr,
AnfNodePtr *seq_add1_ptr);
void SetTupleUnfoldToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple, AnfNodePtr *reshape_ptr);
void SetTupleToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *shape_ptr);
void SetTupleToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *reshape_ptr);
@ -126,23 +126,10 @@ void TestInsertTypeTransformOp::SetTupleUnfoldToTupleUnfoldKernelBuildInfo(
}
void TestInsertTypeTransformOp::SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr,
AnfNodePtr *split_ptr, AnfNodePtr *seq_add1_ptr,
AnfNodePtr *seq_add2_ptr) {
AnfNodePtr *seq_add1_ptr) {
auto ret = g->get_return();
EXPECT_NE(ret->input(1), nullptr);
auto seq_add2 = ret->input(1)->cast<CNodePtr>();
*seq_add2_ptr = seq_add2;
SetKernelBuildInfo(seq_add2, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
{kNumberTypeFloat32}, {KernelObjectType::TUPLE, KernelObjectType::TUPLE},
{KernelObjectType::TENSOR});
auto split = seq_add2->input(1)->cast<CNodePtr>();
*split_ptr = split;
MS_LOG(INFO) << "split is " << split->fullname_with_scope();
SetKernelBuildInfo(split, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32},
{KernelObjectType::TENSOR}, {KernelObjectType::TUPLE_UNFOLD});
auto seq_add1 = split->input(1)->cast<CNodePtr>();
auto seq_add1 = ret->input(1)->cast<CNodePtr>();
*seq_add1_ptr = seq_add1;
SetKernelBuildInfo(seq_add1, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
{kNumberTypeFloat32}, {KernelObjectType::TUPLE, KernelObjectType::TUPLE},
@ -336,9 +323,9 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_unfold_transform) {
TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_transform", "before");
ASSERT_TRUE(g != nullptr);
std::vector<int64_t> shp_1{2, 4};
std::vector<int64_t> shp_1{};
auto abstract_1 = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_1);
std::vector<int64_t> shp_2{2, 4};
std::vector<int64_t> shp_2{};
auto abstract_2 = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_2);
std::vector<int64_t> shp_x{1, 3};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
@ -347,8 +334,8 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
AbstractBasePtrList args_spec_list{abstract_1, abstract_2, x_tuple_abs};
auto func_graph = GetFuncGraph(g, args_spec_list);
ASSERT_TRUE(func_graph != nullptr);
AnfNodePtr make_tuple, split, seq_add1, seq_add2;
SetTupleUnfoldToTupleKernelBuildInfo(func_graph, &make_tuple, &split, &seq_add1, &seq_add2);
AnfNodePtr make_tuple, seq_add1;
SetTupleUnfoldToTupleKernelBuildInfo(func_graph, &make_tuple, &seq_add1);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
@ -374,9 +361,9 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tensor_transform) {
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tensor_transform", "before");
ASSERT_TRUE(g != nullptr);
std::vector<int64_t> shp_x{4};
std::vector<int64_t> shp_x{};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
std::vector<int64_t> shp_y{2};
std::vector<int64_t> shp_y{};
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
std::vector<int64_t> shp_z{2, 4};
auto z_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_z);

View File

@ -18,6 +18,7 @@ from mindspore.common.tensor import Tensor
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants
from mindspore.ops.operations import _sequence_ops as seq
tuple_get_item = Primitive(Constants.kTupleGetItem)
@ -76,25 +77,22 @@ def test_tuple_unfold_to_tuple_transform(tag):
"""
fns = FnDict()
# Need to change AddN to SequenceAdd in later version. This case is just used to cover this pattern.
seq_add1 = P.AddN()
seq_add2 = P.AddN()
seq_add1 = seq.SequenceAdd()
tensor_to_scalar = Primitive('TensorToScalar')
real_make_tuple = Primitive('RealMakeTuple')
@fns
def before(input_1, input_2, x):
res = make_tuple(input_1, input_2)
res = seq_add1(res, x)
res = split1(res)
res = seq_add2(res, x)
return res
@fns
def after(input_1, input_2, x):
input_1 = tensor_to_scalar(input_1)
input_2 = tensor_to_scalar(input_2)
res = real_make_tuple(input_1, input_2)
res = seq_add1(res, x)
res = split1(res)
res = real_make_tuple(tuple_get_item(res, 0), tuple_get_item(res, 1))
res = seq_add2(res, x)
return res
return fns[tag]
@ -110,6 +108,7 @@ def test_tuple_unfold_to_tensor_transform(tag):
reshape = P.Reshape()
real_make_tuple = Primitive('RealMakeTuple')
tuple_to_tensor = Primitive('TupleToTensor')
tensor_to_scalar = Primitive('TensorToScalar')
@fns
def before(input_1, input_2, input_3):
@ -119,6 +118,8 @@ def test_tuple_unfold_to_tensor_transform(tag):
@fns
def after(input_1, input_2, input_3):
input_1 = tensor_to_scalar(input_1)
input_2 = tensor_to_scalar(input_2)
res = real_make_tuple(input_1, input_2)
res = tuple_to_tensor(res)
res = reshape(input_3, res)