forked from mindspore-Ecosystem/mindspore
Update tuple add name
This commit is contained in:
parent
05d1a1df96
commit
23c0d7de3c
|
@ -441,14 +441,16 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraph
|
|||
MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
|
||||
tuple_to_tensor->set_abstract(abs);
|
||||
|
||||
// Need to check if input is MakeTuple node, the format and device type could be abtained or not.
|
||||
SetKernelInfoForNewCNode(tuple_to_tensor);
|
||||
// Data type of the tensor should be set as an attr of TupleToTensor op.
|
||||
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
|
||||
// Attr name is to be confirmed.
|
||||
common::AnfAlgo::SetNodeAttr("tensor_type", MakeValue(static_cast<int64_t>(data_type)), tuple_to_tensor);
|
||||
|
||||
SetKernelInfoForNewCNode(tuple_to_tensor);
|
||||
// Set object type to TUPLE for TupleUnfoldToTuple pattern to be matched.
|
||||
KernelBuildInfoPtr tuple_to_tensor_build_info = AnfAlgo::GetSelectKernelBuildInfo(tuple_to_tensor);
|
||||
MS_EXCEPTION_IF_NULL(tuple_to_tensor_build_info);
|
||||
tuple_to_tensor_build_info->SetInputsKernelObjectType({KernelObjectType::TUPLE});
|
||||
|
||||
return {tuple_to_tensor};
|
||||
}
|
||||
|
||||
|
@ -470,6 +472,12 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &f
|
|||
MS_EXCEPTION_IF_NULL(abs);
|
||||
MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString();
|
||||
tuple_to_tensor->set_abstract(abs);
|
||||
|
||||
// Data type of the tensor should be set as an attr of TupleToTensor op.
|
||||
auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
|
||||
// Attr name is to be confirmed.
|
||||
common::AnfAlgo::SetNodeAttr("tensor_type", MakeValue(static_cast<int64_t>(data_type)), tuple_to_tensor);
|
||||
|
||||
SetKernelInfoForNewCNode(tuple_to_tensor);
|
||||
return {tuple_to_tensor};
|
||||
}
|
||||
|
|
|
@ -461,7 +461,6 @@ constexpr auto kRealTupleGetItem = "RealTupleGetItem";
|
|||
constexpr auto kRealMakeList = "RealMakeList";
|
||||
constexpr auto kRealListGetItem = "RealListGetItem";
|
||||
constexpr auto kTupleSetItem = "TupleSetItem";
|
||||
constexpr auto kTupleAdd = "TupleAdd";
|
||||
|
||||
GVAR_DEF(PrimitivePtr, kPrimExtractGlimpse, std::make_shared<Primitive>(kExtractGlimpse));
|
||||
//
|
||||
|
@ -1738,7 +1737,6 @@ GVAR_DEF(PrimitivePtr, kPrimOCRDetectionPreHandle, std::make_shared<Primitive>("
|
|||
GVAR_DEF(PrimitivePtr, kPrimTupleToTensor, std::make_shared<Primitive>(kTupleToTensor));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorToTuple, std::make_shared<Primitive>(kTensorToTuple));
|
||||
GVAR_DEF(PrimitivePtr, kPrimRealMakeTuple, std::make_shared<Primitive>(kRealMakeTuple));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTupleAdd, std::make_shared<Primitive>(kTupleAdd));
|
||||
|
||||
// AdamApplyOne
|
||||
GVAR_DEF(PrimitivePtr, kPrimAdamApplyOne, std::make_shared<Primitive>("AdamApplyOne"));
|
||||
|
|
|
@ -44,10 +44,10 @@ class TestInsertTypeTransformOp : public BackendCommon {
|
|||
void SetTupleUnfoldToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *split1_ptr, AnfNodePtr *addn1_ptr,
|
||||
AnfNodePtr *split2_ptr, AnfNodePtr *addn2_ptr);
|
||||
void SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr, AnfNodePtr *split_ptr,
|
||||
AnfNodePtr *tuple_add1_ptr, AnfNodePtr *tuple_add2_ptr);
|
||||
AnfNodePtr *seq_add1_ptr, AnfNodePtr *seq_add2_ptr);
|
||||
void SetTupleUnfoldToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple, AnfNodePtr *reshape);
|
||||
void SetTupleToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *reshape_ptr);
|
||||
void SetTensorToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *tuple_add);
|
||||
void SetTensorToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *seq_add);
|
||||
|
||||
void SetKernelBuildInfo(const AnfNodePtr &node, const std::vector<std::string> &input_formats,
|
||||
const std::vector<TypeId> &input_types, const std::vector<std::string> &output_formats,
|
||||
|
@ -121,27 +121,33 @@ void TestInsertTypeTransformOp::SetTupleUnfoldToTupleUnfoldKernelBuildInfo(
|
|||
}
|
||||
|
||||
void TestInsertTypeTransformOp::SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr,
|
||||
AnfNodePtr *split_ptr, AnfNodePtr *tuple_add1_ptr,
|
||||
AnfNodePtr *tuple_add2_ptr) {
|
||||
AnfNodePtr *split_ptr, AnfNodePtr *seq_add1_ptr,
|
||||
AnfNodePtr *seq_add2_ptr) {
|
||||
auto ret = g->get_return();
|
||||
EXPECT_NE(ret->input(1), nullptr);
|
||||
auto tuple_add2 = ret->input(1)->cast<CNodePtr>();
|
||||
*tuple_add2_ptr = tuple_add2;
|
||||
SetKernelBuildInfo(tuple_add2, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32},
|
||||
{KernelObjectType::TUPLE}, {KernelObjectType::TENSOR});
|
||||
auto seq_add2 = ret->input(1)->cast<CNodePtr>();
|
||||
*seq_add2_ptr = seq_add2;
|
||||
SetKernelBuildInfo(seq_add2, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
|
||||
{kNumberTypeFloat32}, {KernelObjectType::TUPLE, KernelObjectType::TUPLE},
|
||||
{KernelObjectType::TENSOR});
|
||||
|
||||
auto split = tuple_add2->input(1)->cast<CNodePtr>();
|
||||
auto split = seq_add2->input(1)->cast<CNodePtr>();
|
||||
*split_ptr = split;
|
||||
MS_LOG(INFO) << "split is " << split->fullname_with_scope();
|
||||
SetKernelBuildInfo(split, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32},
|
||||
{KernelObjectType::TENSOR}, {KernelObjectType::TUPLE_UNFOLD});
|
||||
|
||||
auto tuple_add1 = split->input(1)->cast<CNodePtr>();
|
||||
*tuple_add1_ptr = tuple_add1;
|
||||
SetKernelBuildInfo(tuple_add1, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32},
|
||||
{KernelObjectType::TUPLE}, {KernelObjectType::TENSOR});
|
||||
auto seq_add1 = split->input(1)->cast<CNodePtr>();
|
||||
*seq_add1_ptr = seq_add1;
|
||||
SetKernelBuildInfo(seq_add1, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
|
||||
{kNumberTypeFloat32}, {KernelObjectType::TUPLE, KernelObjectType::TUPLE},
|
||||
{KernelObjectType::TENSOR});
|
||||
|
||||
auto make_tuple = tuple_add1->input(1)->cast<CNodePtr>();
|
||||
auto input_x = seq_add1->input(2);
|
||||
SetKernelBuildInfo(input_x, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32}, {KernelObjectType::TUPLE},
|
||||
{KernelObjectType::TUPLE});
|
||||
|
||||
auto make_tuple = seq_add1->input(1)->cast<CNodePtr>();
|
||||
*make_tuple_ptr = make_tuple;
|
||||
MS_LOG(INFO) << "make_tuple is " << make_tuple->fullname_with_scope();
|
||||
SetKernelBuildInfo(make_tuple, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
|
||||
|
@ -203,19 +209,19 @@ void TestInsertTypeTransformOp::SetTupleToTensorKernelBuildInfo(const FuncGraphP
|
|||
{KernelObjectType::TENSOR});
|
||||
}
|
||||
|
||||
void TestInsertTypeTransformOp::SetTensorToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *tuple_add_ptr) {
|
||||
void TestInsertTypeTransformOp::SetTensorToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *seq_add_ptr) {
|
||||
auto ret = g->get_return();
|
||||
EXPECT_NE(ret->input(1), nullptr);
|
||||
auto tuple_add = ret->input(1)->cast<CNodePtr>();
|
||||
*tuple_add_ptr = tuple_add;
|
||||
SetKernelBuildInfo(tuple_add, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
|
||||
auto seq_add = ret->input(1)->cast<CNodePtr>();
|
||||
*seq_add_ptr = seq_add;
|
||||
SetKernelBuildInfo(seq_add, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
|
||||
{kNumberTypeFloat32}, {KernelObjectType::TUPLE, KernelObjectType::TUPLE},
|
||||
{KernelObjectType::TUPLE});
|
||||
auto input2 = tuple_add->input(2);
|
||||
auto input2 = seq_add->input(2);
|
||||
SetKernelBuildInfo(input2, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32}, {KernelObjectType::TENSOR},
|
||||
{KernelObjectType::TENSOR});
|
||||
|
||||
auto input1 = tuple_add->input(1);
|
||||
auto input1 = seq_add->input(1);
|
||||
SetKernelBuildInfo(input1, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32}, {KernelObjectType::TENSOR},
|
||||
{KernelObjectType::TENSOR});
|
||||
}
|
||||
|
@ -292,15 +298,19 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_unfold_transform) {
|
|||
TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) {
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_transform", "before");
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
std::vector<int64_t> shp_x{2, 4};
|
||||
std::vector<int64_t> shp_1{2, 4};
|
||||
auto abstract_1 = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_1);
|
||||
std::vector<int64_t> shp_2{2, 4};
|
||||
auto abstract_2 = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_2);
|
||||
std::vector<int64_t> shp_x{1, 3};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
std::vector<int64_t> shp_y{2, 4};
|
||||
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
|
||||
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
|
||||
AbstractBasePtrList abstract_list = {x_abstract};
|
||||
auto x_tuple_abs = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
AbstractBasePtrList args_spec_list{abstract_1, abstract_2, x_tuple_abs};
|
||||
auto func_graph = GetFuncGraph(g, args_spec_list);
|
||||
ASSERT_TRUE(func_graph != nullptr);
|
||||
AnfNodePtr make_tuple, split, tuple_add1, tuple_add2;
|
||||
SetTupleUnfoldToTupleKernelBuildInfo(func_graph, &make_tuple, &split, &tuple_add1, &tuple_add2);
|
||||
AnfNodePtr make_tuple, split, seq_add1, seq_add2;
|
||||
SetTupleUnfoldToTupleKernelBuildInfo(func_graph, &make_tuple, &split, &seq_add1, &seq_add2);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
|
@ -383,8 +393,8 @@ TEST_F(TestInsertTypeTransformOp, test_tensor_to_tuple_transform) {
|
|||
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
|
||||
auto func_graph = GetFuncGraph(g, args_spec_list);
|
||||
ASSERT_TRUE(func_graph != nullptr);
|
||||
AnfNodePtr tuple_add;
|
||||
SetTensorToTupleKernelBuildInfo(func_graph, &tuple_add);
|
||||
AnfNodePtr seq_add;
|
||||
SetTensorToTupleKernelBuildInfo(func_graph, &seq_add);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
|
|
|
@ -18,6 +18,7 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import _constants as Constants
|
||||
|
||||
tuple_get_item = Primitive(Constants.kTupleGetItem)
|
||||
|
||||
make_tuple = Primitive('MakeTuple')
|
||||
|
@ -74,26 +75,26 @@ def test_tuple_unfold_to_tuple_transform(tag):
|
|||
Expectation: The 'after' graph is identical to the graph after this pass.
|
||||
"""
|
||||
fns = FnDict()
|
||||
# Need to change AddN to TupleAdd in later version.
|
||||
tuple_add1 = P.AddN()
|
||||
tuple_add2 = P.AddN()
|
||||
# Need to change AddN to SequenceAdd in later version. This case is just used to cover this pattern.
|
||||
seq_add1 = P.AddN()
|
||||
seq_add2 = P.AddN()
|
||||
real_make_tuple = Primitive('RealMakeTuple')
|
||||
|
||||
@fns
|
||||
def before(input_1, input_2):
|
||||
def before(input_1, input_2, x):
|
||||
res = make_tuple(input_1, input_2)
|
||||
res = tuple_add1(res)
|
||||
res = seq_add1(res, x)
|
||||
res = split1(res)
|
||||
res = tuple_add2(res)
|
||||
res = seq_add2(res, x)
|
||||
return res
|
||||
|
||||
@fns
|
||||
def after(input_1, input_2):
|
||||
def after(input_1, input_2, x):
|
||||
res = real_make_tuple(input_1, input_2)
|
||||
res = tuple_add1(res)
|
||||
res = seq_add1(res, x)
|
||||
res = split1(res)
|
||||
res = real_make_tuple(tuple_get_item(res, 0), tuple_get_item(res, 1))
|
||||
res = tuple_add2(res)
|
||||
res = seq_add2(res, x)
|
||||
return res
|
||||
|
||||
return fns[tag]
|
||||
|
@ -157,19 +158,20 @@ def test_tensor_to_tuple_transform(tag):
|
|||
Expectation: The 'after' graph is identical to the graph after this pass.
|
||||
"""
|
||||
fns = FnDict()
|
||||
tuple_add = P.Add()
|
||||
# Need to change Add to SequenceAdd in later version. This case is just used to cover this pattern.
|
||||
seq_add = P.Add()
|
||||
tensor_to_tuple = Primitive('TensorToTuple')
|
||||
|
||||
@fns
|
||||
def before(x, y):
|
||||
res = tuple_add(x, y)
|
||||
res = seq_add(x, y)
|
||||
return res
|
||||
|
||||
@fns
|
||||
def after(x, y):
|
||||
input1 = tensor_to_tuple(x)
|
||||
input2 = tensor_to_tuple(y)
|
||||
res = tuple_add(input1, input2)
|
||||
res = seq_add(input1, input2)
|
||||
return res
|
||||
|
||||
return fns[tag]
|
||||
|
|
Loading…
Reference in New Issue