forked from mindspore-Ecosystem/mindspore
!48538 Reset RealMakeTuple objtypes
Merge pull request !48538 from ZPaC/set-real-make-tuple-scalar-input
This commit is contained in:
commit
f0a48a7b6f
|
@ -137,6 +137,17 @@ AnfNodePtr CreateRealMakeTupleByMakeTuple(const FuncGraphPtr &func_graph, const
|
|||
real_make_tuple->set_abstract(make_tuple_node->abstract());
|
||||
|
||||
SetKernelInfoForNewCNode(real_make_tuple);
|
||||
|
||||
// RealMakeTuple's inputs must be scalar. To avoid failing to select kernel, we must override RealMakeTuple's
|
||||
// KernelObjectTypes, which is created from MakeTuple.
|
||||
KernelBuildInfoPtr real_make_tuple_build_info = AnfAlgo::GetSelectKernelBuildInfo(real_make_tuple);
|
||||
MS_EXCEPTION_IF_NULL(real_make_tuple_build_info);
|
||||
auto inputs_obj_types = real_make_tuple_build_info->GetAllInputKernelObjectTypes();
|
||||
auto new_obj_types = inputs_obj_types;
|
||||
std::transform(new_obj_types.begin(), new_obj_types.end(), new_obj_types.begin(),
|
||||
[](const auto &obj_type) { return KernelObjectType::SCALAR; });
|
||||
real_make_tuple_build_info->SetInputsKernelObjectType(new_obj_types);
|
||||
MS_LOG(DEBUG) << "Override RealMakeTuple input kernel object types from " << inputs_obj_types << " " << new_obj_types;
|
||||
return real_make_tuple;
|
||||
}
|
||||
|
||||
|
|
|
@ -45,8 +45,8 @@ class TestInsertTypeTransformOp : public BackendCommon {
|
|||
public:
|
||||
void SetTupleUnfoldToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *split1_ptr, AnfNodePtr *addn1_ptr,
|
||||
AnfNodePtr *split2_ptr, AnfNodePtr *addn2_ptr);
|
||||
void SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr, AnfNodePtr *split_ptr,
|
||||
AnfNodePtr *seq_add1_ptr, AnfNodePtr *seq_add2_ptr);
|
||||
void SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr,
|
||||
AnfNodePtr *seq_add1_ptr);
|
||||
void SetTupleUnfoldToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple, AnfNodePtr *reshape_ptr);
|
||||
void SetTupleToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *shape_ptr);
|
||||
void SetTupleToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *reshape_ptr);
|
||||
|
@ -126,23 +126,10 @@ void TestInsertTypeTransformOp::SetTupleUnfoldToTupleUnfoldKernelBuildInfo(
|
|||
}
|
||||
|
||||
void TestInsertTypeTransformOp::SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr,
|
||||
AnfNodePtr *split_ptr, AnfNodePtr *seq_add1_ptr,
|
||||
AnfNodePtr *seq_add2_ptr) {
|
||||
AnfNodePtr *seq_add1_ptr) {
|
||||
auto ret = g->get_return();
|
||||
EXPECT_NE(ret->input(1), nullptr);
|
||||
auto seq_add2 = ret->input(1)->cast<CNodePtr>();
|
||||
*seq_add2_ptr = seq_add2;
|
||||
SetKernelBuildInfo(seq_add2, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
|
||||
{kNumberTypeFloat32}, {KernelObjectType::TUPLE, KernelObjectType::TUPLE},
|
||||
{KernelObjectType::TENSOR});
|
||||
|
||||
auto split = seq_add2->input(1)->cast<CNodePtr>();
|
||||
*split_ptr = split;
|
||||
MS_LOG(INFO) << "split is " << split->fullname_with_scope();
|
||||
SetKernelBuildInfo(split, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32},
|
||||
{KernelObjectType::TENSOR}, {KernelObjectType::TUPLE_UNFOLD});
|
||||
|
||||
auto seq_add1 = split->input(1)->cast<CNodePtr>();
|
||||
auto seq_add1 = ret->input(1)->cast<CNodePtr>();
|
||||
*seq_add1_ptr = seq_add1;
|
||||
SetKernelBuildInfo(seq_add1, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
|
||||
{kNumberTypeFloat32}, {KernelObjectType::TUPLE, KernelObjectType::TUPLE},
|
||||
|
@ -336,9 +323,9 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_unfold_transform) {
|
|||
TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_transform", "before");
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
std::vector<int64_t> shp_1{2, 4};
|
||||
std::vector<int64_t> shp_1{};
|
||||
auto abstract_1 = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_1);
|
||||
std::vector<int64_t> shp_2{2, 4};
|
||||
std::vector<int64_t> shp_2{};
|
||||
auto abstract_2 = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_2);
|
||||
std::vector<int64_t> shp_x{1, 3};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
|
@ -347,8 +334,8 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
|
|||
AbstractBasePtrList args_spec_list{abstract_1, abstract_2, x_tuple_abs};
|
||||
auto func_graph = GetFuncGraph(g, args_spec_list);
|
||||
ASSERT_TRUE(func_graph != nullptr);
|
||||
AnfNodePtr make_tuple, split, seq_add1, seq_add2;
|
||||
SetTupleUnfoldToTupleKernelBuildInfo(func_graph, &make_tuple, &split, &seq_add1, &seq_add2);
|
||||
AnfNodePtr make_tuple, seq_add1;
|
||||
SetTupleUnfoldToTupleKernelBuildInfo(func_graph, &make_tuple, &seq_add1);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
|
@ -374,9 +361,9 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
|
|||
TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tensor_transform) {
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tensor_transform", "before");
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
std::vector<int64_t> shp_x{4};
|
||||
std::vector<int64_t> shp_x{};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
std::vector<int64_t> shp_y{2};
|
||||
std::vector<int64_t> shp_y{};
|
||||
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
|
||||
std::vector<int64_t> shp_z{2, 4};
|
||||
auto z_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_z);
|
||||
|
|
|
@ -18,6 +18,7 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import _constants as Constants
|
||||
from mindspore.ops.operations import _sequence_ops as seq
|
||||
|
||||
tuple_get_item = Primitive(Constants.kTupleGetItem)
|
||||
|
||||
|
@ -76,25 +77,22 @@ def test_tuple_unfold_to_tuple_transform(tag):
|
|||
"""
|
||||
fns = FnDict()
|
||||
# Need to change AddN to SequenceAdd in later version. This case is just used to cover this pattern.
|
||||
seq_add1 = P.AddN()
|
||||
seq_add2 = P.AddN()
|
||||
seq_add1 = seq.SequenceAdd()
|
||||
tensor_to_scalar = Primitive('TensorToScalar')
|
||||
real_make_tuple = Primitive('RealMakeTuple')
|
||||
|
||||
@fns
|
||||
def before(input_1, input_2, x):
|
||||
res = make_tuple(input_1, input_2)
|
||||
res = seq_add1(res, x)
|
||||
res = split1(res)
|
||||
res = seq_add2(res, x)
|
||||
return res
|
||||
|
||||
@fns
|
||||
def after(input_1, input_2, x):
|
||||
input_1 = tensor_to_scalar(input_1)
|
||||
input_2 = tensor_to_scalar(input_2)
|
||||
res = real_make_tuple(input_1, input_2)
|
||||
res = seq_add1(res, x)
|
||||
res = split1(res)
|
||||
res = real_make_tuple(tuple_get_item(res, 0), tuple_get_item(res, 1))
|
||||
res = seq_add2(res, x)
|
||||
return res
|
||||
|
||||
return fns[tag]
|
||||
|
@ -110,6 +108,7 @@ def test_tuple_unfold_to_tensor_transform(tag):
|
|||
reshape = P.Reshape()
|
||||
real_make_tuple = Primitive('RealMakeTuple')
|
||||
tuple_to_tensor = Primitive('TupleToTensor')
|
||||
tensor_to_scalar = Primitive('TensorToScalar')
|
||||
|
||||
@fns
|
||||
def before(input_1, input_2, input_3):
|
||||
|
@ -119,6 +118,8 @@ def test_tuple_unfold_to_tensor_transform(tag):
|
|||
|
||||
@fns
|
||||
def after(input_1, input_2, input_3):
|
||||
input_1 = tensor_to_scalar(input_1)
|
||||
input_2 = tensor_to_scalar(input_2)
|
||||
res = real_make_tuple(input_1, input_2)
|
||||
res = tuple_to_tensor(res)
|
||||
res = reshape(input_3, res)
|
||||
|
|
Loading…
Reference in New Issue