forked from mindspore-Ecosystem/mindspore

Reset RealMakeTuple objtypes

This commit is contained in:
parent f40e159dd1
commit 50f90fb8df
@@ -139,6 +139,17 @@ AnfNodePtr CreateRealMakeTupleByMakeTuple(const FuncGraphPtr &func_graph, const
   real_make_tuple->set_abstract(make_tuple_node->abstract());

   SetKernelInfoForNewCNode(real_make_tuple);

+  // RealMakeTuple's inputs must be scalar. To avoid failing to select a kernel, we must override RealMakeTuple's
+  // KernelObjectTypes, which are created from MakeTuple.
+  KernelBuildInfoPtr real_make_tuple_build_info = AnfAlgo::GetSelectKernelBuildInfo(real_make_tuple);
+  MS_EXCEPTION_IF_NULL(real_make_tuple_build_info);
+  auto inputs_obj_types = real_make_tuple_build_info->GetAllInputKernelObjectTypes();
+  auto new_obj_types = inputs_obj_types;
+  std::transform(new_obj_types.begin(), new_obj_types.end(), new_obj_types.begin(),
+                 [](const auto &obj_type) { return KernelObjectType::SCALAR; });
+  real_make_tuple_build_info->SetInputsKernelObjectType(new_obj_types);
+  MS_LOG(DEBUG) << "Override RealMakeTuple input kernel object types from " << inputs_obj_types << " to " << new_obj_types;
   return real_make_tuple;
 }
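
The new block rewrites every entry of RealMakeTuple's input object-type vector to SCALAR and writes the result back into the selected kernel build info. Below is a minimal standalone sketch of that std::transform pattern, using a stand-in KernelObjectType enum rather than MindSpore's actual kernel-info API:

#include <algorithm>
#include <iostream>
#include <vector>

// Stand-in for the backend's kernel object type enum (illustrative only).
enum class KernelObjectType { TENSOR, SCALAR, TUPLE, TUPLE_UNFOLD };

int main() {
  // Object types as they would be copied over from MakeTuple: tensor inputs.
  std::vector<KernelObjectType> obj_types{KernelObjectType::TENSOR, KernelObjectType::TENSOR};

  // Same pattern as the pass: overwrite every input object type with SCALAR in place.
  std::transform(obj_types.begin(), obj_types.end(), obj_types.begin(),
                 [](const auto &) { return KernelObjectType::SCALAR; });

  for (const auto &t : obj_types) {
    std::cout << (t == KernelObjectType::SCALAR ? "SCALAR" : "other") << '\n';
  }
  return 0;
}
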
@@ -45,8 +45,8 @@ class TestInsertTypeTransformOp : public BackendCommon {
  public:
   void SetTupleUnfoldToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *split1_ptr, AnfNodePtr *addn1_ptr,
                                                   AnfNodePtr *split2_ptr, AnfNodePtr *addn2_ptr);
-  void SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr, AnfNodePtr *split_ptr,
-                                            AnfNodePtr *seq_add1_ptr, AnfNodePtr *seq_add2_ptr);
+  void SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr,
+                                            AnfNodePtr *seq_add1_ptr);
   void SetTupleUnfoldToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple, AnfNodePtr *reshape_ptr);
   void SetTupleToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *shape_ptr);
   void SetTupleToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *reshape_ptr);
@@ -126,23 +126,10 @@ void TestInsertTypeTransformOp::SetTupleUnfoldToTupleUnfoldKernelBuildInfo(
 }

 void TestInsertTypeTransformOp::SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr,
-                                                                     AnfNodePtr *split_ptr, AnfNodePtr *seq_add1_ptr,
-                                                                     AnfNodePtr *seq_add2_ptr) {
+                                                                     AnfNodePtr *seq_add1_ptr) {
   auto ret = g->get_return();
   EXPECT_NE(ret->input(1), nullptr);
-  auto seq_add2 = ret->input(1)->cast<CNodePtr>();
-  *seq_add2_ptr = seq_add2;
-  SetKernelBuildInfo(seq_add2, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
-                     {kNumberTypeFloat32}, {KernelObjectType::TUPLE, KernelObjectType::TUPLE},
-                     {KernelObjectType::TENSOR});
-
-  auto split = seq_add2->input(1)->cast<CNodePtr>();
-  *split_ptr = split;
-  MS_LOG(INFO) << "split is " << split->fullname_with_scope();
-  SetKernelBuildInfo(split, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32},
-                     {KernelObjectType::TENSOR}, {KernelObjectType::TUPLE_UNFOLD});
-
-  auto seq_add1 = split->input(1)->cast<CNodePtr>();
+  auto seq_add1 = ret->input(1)->cast<CNodePtr>();
   *seq_add1_ptr = seq_add1;
   SetKernelBuildInfo(seq_add1, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
                      {kNumberTypeFloat32}, {KernelObjectType::TUPLE, KernelObjectType::TUPLE},
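
Judging from the call sites in this helper, SetKernelBuildInfo takes parallel per-input and per-output vectors: formats, data types, and kernel object types. The sketch below mirrors that grouping with stand-in types so the updated seq_add1 call is easier to read; the struct, the enum, and the placeholder dtype constant are assumptions for illustration, not the real test utilities:

#include <cassert>
#include <string>
#include <vector>

// Illustrative stand-ins only; the real helper and enums live in the MindSpore test sources.
enum class KernelObjectType { TENSOR, SCALAR, TUPLE, TUPLE_UNFOLD };
constexpr int kNumberTypeFloat32 = 43;  // placeholder value, not MindSpore's real type id

struct FakeKernelBuildInfo {
  std::vector<std::string> input_formats;
  std::vector<int> input_dtypes;
  std::vector<std::string> output_formats;
  std::vector<int> output_dtypes;
  std::vector<KernelObjectType> input_obj_types;
  std::vector<KernelObjectType> output_obj_types;
};

int main() {
  // Mirrors the updated seq_add1 call: two TUPLE inputs, one TENSOR output.
  FakeKernelBuildInfo info{{"NCHW", "NCHW"},
                           {kNumberTypeFloat32, kNumberTypeFloat32},
                           {"NCHW"},
                           {kNumberTypeFloat32},
                           {KernelObjectType::TUPLE, KernelObjectType::TUPLE},
                           {KernelObjectType::TENSOR}};
  // The per-input vectors must stay the same length, and likewise for the outputs.
  assert(info.input_formats.size() == info.input_dtypes.size());
  assert(info.input_dtypes.size() == info.input_obj_types.size());
  assert(info.output_formats.size() == info.output_obj_types.size());
  return 0;
}
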
@@ -336,9 +323,9 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_unfold_transform) {
 TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
   FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_transform", "before");
   ASSERT_TRUE(g != nullptr);
-  std::vector<int64_t> shp_1{2, 4};
+  std::vector<int64_t> shp_1{};
   auto abstract_1 = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_1);
-  std::vector<int64_t> shp_2{2, 4};
+  std::vector<int64_t> shp_2{};
   auto abstract_2 = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_2);
   std::vector<int64_t> shp_x{1, 3};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
@@ -347,8 +334,8 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
   AbstractBasePtrList args_spec_list{abstract_1, abstract_2, x_tuple_abs};
   auto func_graph = GetFuncGraph(g, args_spec_list);
   ASSERT_TRUE(func_graph != nullptr);
-  AnfNodePtr make_tuple, split, seq_add1, seq_add2;
-  SetTupleUnfoldToTupleKernelBuildInfo(func_graph, &make_tuple, &split, &seq_add1, &seq_add2);
+  AnfNodePtr make_tuple, seq_add1;
+  SetTupleUnfoldToTupleKernelBuildInfo(func_graph, &make_tuple, &seq_add1);

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
@@ -374,9 +361,9 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
 TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tensor_transform) {
   FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tensor_transform", "before");
   ASSERT_TRUE(g != nullptr);
-  std::vector<int64_t> shp_x{4};
+  std::vector<int64_t> shp_x{};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
-  std::vector<int64_t> shp_y{2};
+  std::vector<int64_t> shp_y{};
   auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
   std::vector<int64_t> shp_z{2, 4};
   auto z_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_z);
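
The shape edits in the two tests above replace multi-element shapes with empty shape vectors, i.e. rank-0 (scalar) tensors. That lines up with the pass now treating RealMakeTuple inputs as scalars and with the TensorToScalar nodes added to the Python "after" graphs further down. As a reminder of what an empty shape implies, here is a small self-contained check (nothing MindSpore-specific):

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// A tensor's element count is the product of its dimensions; an empty shape
// (rank 0) therefore holds exactly one element -- a scalar.
int64_t ElementCount(const std::vector<int64_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
}

int main() {
  std::vector<int64_t> shp_1{};      // rank-0 shape, as used in the updated tests
  std::vector<int64_t> shp_x{1, 3};  // rank-2 shape the test keeps
  std::cout << "shp_1 elements: " << ElementCount(shp_1) << '\n';  // prints 1
  std::cout << "shp_x elements: " << ElementCount(shp_x) << '\n';  // prints 3
  return 0;
}
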
@@ -18,6 +18,7 @@ from mindspore.common.tensor import Tensor
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P
 from mindspore.ops import _constants as Constants
+from mindspore.ops.operations import _sequence_ops as seq

 tuple_get_item = Primitive(Constants.kTupleGetItem)

@@ -76,25 +77,22 @@ def test_tuple_unfold_to_tuple_transform(tag):
     """
     fns = FnDict()
     # Need to change AddN to SequenceAdd in later version. This case is just used to cover this pattern.
-    seq_add1 = P.AddN()
-    seq_add2 = P.AddN()
+    seq_add1 = seq.SequenceAdd()
+    tensor_to_scalar = Primitive('TensorToScalar')
     real_make_tuple = Primitive('RealMakeTuple')

     @fns
     def before(input_1, input_2, x):
         res = make_tuple(input_1, input_2)
         res = seq_add1(res, x)
-        res = split1(res)
-        res = seq_add2(res, x)
         return res

     @fns
     def after(input_1, input_2, x):
+        input_1 = tensor_to_scalar(input_1)
+        input_2 = tensor_to_scalar(input_2)
         res = real_make_tuple(input_1, input_2)
         res = seq_add1(res, x)
-        res = split1(res)
-        res = real_make_tuple(tuple_get_item(res, 0), tuple_get_item(res, 1))
-        res = seq_add2(res, x)
         return res

     return fns[tag]
@@ -110,6 +108,7 @@ def test_tuple_unfold_to_tensor_transform(tag):
     reshape = P.Reshape()
     real_make_tuple = Primitive('RealMakeTuple')
     tuple_to_tensor = Primitive('TupleToTensor')
+    tensor_to_scalar = Primitive('TensorToScalar')

     @fns
     def before(input_1, input_2, input_3):
@@ -119,6 +118,8 @@ def test_tuple_unfold_to_tensor_transform(tag):

     @fns
     def after(input_1, input_2, input_3):
+        input_1 = tensor_to_scalar(input_1)
+        input_2 = tensor_to_scalar(input_2)
         res = real_make_tuple(input_1, input_2)
         res = tuple_to_tensor(res)
         res = reshape(input_3, res)