!46439 Add unit test for the TupleUnfold-to-TupleUnfold transform pass

Merge pull request !46439 from ZPaC/pass-for-tuple
This commit is contained in:
i-robot 2022-12-07 10:23:22 +00:00 committed by Gitee
commit e8ee9f9b40
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 243 additions and 27 deletions

View File

@ -46,6 +46,11 @@ int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_inpu
// using for graph kernel
auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
MS_EXCEPTION_IF_NULL(dyn_input_node);
// Handle tuple nested scenes.
if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
input_size += SplitTupleInputs(graph, dyn_input_node, plant_inputs);
continue;
}
(void)plant_inputs->emplace_back(dyn_input_node);
}
return input_size;
@ -64,6 +69,8 @@ AnfNodePtr CreateNewNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &i
auto new_cnode = NewCNode(input_list, func_graph, {origin_node});
MS_EXCEPTION_IF_NULL(new_cnode);
// This pass should not have new node whose abstract differs from the original node. So set the original node's
// abstract.
new_cnode->set_abstract(origin_node->abstract());
new_cnode->set_scope(origin_node->scope());
new_cnode->set_primal_attrs(origin_node->primal_attrs());
@ -87,29 +94,33 @@ void UpdateKernelBuildInfo(const CNodePtr &new_cnode, const CNodePtr &origin_nod
KernelBuildInfoPtr origin_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(origin_node);
auto new_kernel_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(origin_kernel_build_info);
// Construct new inputs and outputs info and set to the new kernel build info.
// Construct new inputs info and set to the new kernel build info.
std::vector<std::string> inputs_device_format;
std::vector<std::string> outputs_device_format;
std::vector<TypeId> inputs_device_type;
std::vector<TypeId> outputs_device_type;
size_t input_num = common::AnfAlgo::GetInputTensorNum(new_cnode);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_device_format.push_back(AnfAlgo::GetInputFormat(new_cnode, input_index));
inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(new_cnode, input_index));
inputs_device_format.push_back(AnfAlgo::GetOutputFormat(new_cnode->input(input_index + kSizeOne), kIndex0));
inputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(new_cnode->input(input_index + kSizeOne), kIndex0));
}
size_t output_num = common::AnfAlgo::GetOutputTensorNum(new_cnode);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_device_format.push_back(AnfAlgo::GetOutputFormat(new_cnode, output_index));
outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(new_cnode, output_index));
}
new_kernel_builder->SetInputsFormat(inputs_device_format);
new_kernel_builder->SetOutputsFormat(outputs_device_format);
new_kernel_builder->SetInputsDeviceType(inputs_device_type);
new_kernel_builder->SetOutputsDeviceType(outputs_device_type);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
new_cnode->set_kernel_info(kernel_info);
AnfAlgo::SetSelectKernelBuildInfo(new_kernel_builder->Build(), new_cnode.get());
}
// A map of kernel object type pairs to processing functions.
static std::map<ObjectTypePair, ProcessTypeTransformFunc> kTypePairToProcessFunc;
// Registers the processing function for each unmatched kernel object type pair.
// Currently only the TupleUnfold->TupleUnfold transform is supported; more pairs
// are expected to be registered here as the pass is extended.
InsertTypeTransformOp::InsertTypeTransformOp(bool multigraph)
    : PatternProcessPass("insert_type_transform_op", multigraph) {
  // Prefer a lambda over std::bind: the parameter types are explicit and the call
  // is inlinable (clang-tidy modernize-avoid-bind).
  kTypePairToProcessFunc[{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TUPLE_UNFOLD}] =
    [this](const FuncGraphPtr &func_graph, const AnfNodePtr &input, const CNodePtr &node, bool *new_prim) {
      return ProcessTupleUnfoldToTupleUnfold(func_graph, input, node, new_prim);
    };
}
const AnfNodePtr InsertTypeTransformOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
@ -128,7 +139,8 @@ const AnfNodePtr InsertTypeTransformOp::Process(const FuncGraphPtr &func_graph,
MS_EXCEPTION_IF_NULL(cnode);
AnfNodePtrList new_input_list = {common::AnfAlgo::GetCNodePrimitiveNode(cnode)};
for (size_t i = kIndex1; i < cnode->inputs().size(); i++) {
MS_LOG(DEBUG) << "Kernel object type of index " << i << " is " << kObjectTypeString[needed_input_type_list[i - 1]];
MS_LOG(DEBUG) << "Kernel object type of index " << i << " is "
<< kObjectTypeToString[needed_input_type_list[i - 1]];
// Get actually needed input kernel object type.
KernelObjectType needed_input_type = needed_input_type_list[i - 1];

View File

@ -28,10 +28,10 @@ namespace opt {
using kernel::KernelBuildInfoPtr;
using kernel::KernelObjectType;
std::map<KernelObjectType, std::string> kObjectTypeString = {{KernelObjectType::TENSOR, "tensor"},
{KernelObjectType::SCALAR, "scalar"},
{KernelObjectType::TUPLE, "tuple"},
{KernelObjectType::TUPLE_UNFOLD, "tuple_unfold"}};
static std::map<KernelObjectType, std::string> kObjectTypeToString = {{KernelObjectType::TENSOR, "tensor"},
{KernelObjectType::SCALAR, "scalar"},
{KernelObjectType::TUPLE, "tuple"},
{KernelObjectType::TUPLE_UNFOLD, "tuple_unfold"}};
// Kernel object type pair of:
// 1. One node's input kernel object type.
@ -41,16 +41,18 @@ struct ObjectTypePair {
KernelObjectType needed_input_type;
std::string to_string() const {
if (kObjectTypeString.find(current_input_type) == kObjectTypeString.end() ||
kObjectTypeString.find(needed_input_type) == kObjectTypeString.end()) {
if (kObjectTypeToString.find(current_input_type) == kObjectTypeToString.end() ||
kObjectTypeToString.find(needed_input_type) == kObjectTypeToString.end()) {
MS_LOG(EXCEPTION) << "The current input object type " << current_input_type << " or needed input object type "
<< needed_input_type << " is not valid.";
}
return kObjectTypeString[current_input_type] + "->" + kObjectTypeString[needed_input_type];
return kObjectTypeToString[current_input_type] + "->" + kObjectTypeToString[needed_input_type];
}
bool operator<(const ObjectTypePair &t) const { return to_string() < t.to_string(); }
bool operator==(const ObjectTypePair &t) const { return to_string() == t.to_string(); }
};
// For each unmatched type pair, a processing method is required to correct the types by inserting type transforming
@ -67,7 +69,6 @@ struct ObjectTypePair {
*/
using ProcessTypeTransformFunc = std::function<AnfNodePtrList(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
const CNodePtr &node, bool *new_prim)>;
std::map<ObjectTypePair, ProcessTypeTransformFunc> kTypePairToProcessFunc;
// SplitTupleInputs methods refer to the pass ConvertTupleInputToDynamicInput. It unfolds tuple inputs and returns the
// unfolded inputs nodes.
@ -85,11 +86,7 @@ void UpdateKernelBuildInfo(const CNodePtr &new_cnode, const CNodePtr &origin_nod
// node's output type). We need this pass to transform these types to valid types.
class BACKEND_EXPORT InsertTypeTransformOp : public PatternProcessPass {
public:
explicit InsertTypeTransformOp(bool multigraph = true) : PatternProcessPass("insert_type_transform_op", multigraph) {
kTypePairToProcessFunc[{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TUPLE_UNFOLD}] =
std::bind(&InsertTypeTransformOp::ProcessTupleUnfoldToTupleUnfold, this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
}
explicit InsertTypeTransformOp(bool multigraph = true);
~InsertTypeTransformOp() override = default;
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;

View File

@ -0,0 +1,140 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "include/common/debug/anf_ir_dump.h"
#include "common/py_func_graph_fetcher.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/pass_manager.h"
#include "include/common/utils/utils.h"
#define private public
#define protected public
#include "backend/common/pass/insert_type_transform_op.h"
#undef private
#undef protected
namespace mindspore {
namespace opt {
using kernel::KernelObjectType;
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
// Test fixture for the InsertTypeTransformOp pass. The test graphs are fetched
// from the python file gtest_input.pre_activate.insert_type_transform_op_test.
class TestInsertTypeTransformOp : public BackendCommon {
public:
TestInsertTypeTransformOp() : getPyFun_("gtest_input.pre_activate.insert_type_transform_op_test", true) {}
~TestInsertTypeTransformOp() override = default;
public:
// Attaches kernel build info (formats, device types, kernel object types) to every
// node of the TupleUnfold->TupleUnfold test graph so the pass has type info to read.
void SetTupleUnfoldToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &func_graph);
// Helper: builds a KernelBuildInfo from the given vectors and attaches it to 'node'.
void SetKernelBuildInfo(const AnfNodePtr &node, const std::vector<std::string> &input_formats,
const std::vector<TypeId> &input_types, const std::vector<std::string> &output_formats,
const std::vector<TypeId> &output_types, const std::vector<KernelObjectType> &input_obj_types,
const std::vector<KernelObjectType> &output_obj_types);
UT::PyFuncGraphFetcher getPyFun_;
};
// Walks the 'before' test graph top-down and attaches kernel build info to each node.
// Graph shape (from the python test input):
//   return -> addn2 -> make_tuple -> (get_item1, get_item2) -> split2 -> addn1 -> split1 -> input
void TestInsertTypeTransformOp::SetTupleUnfoldToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &g) {
  auto ret = g->get_return();
  EXPECT_NE(ret->input(1), nullptr);
  auto addn2 = ret->input(1)->cast<CNodePtr>();
  MS_LOG(INFO) << "addn2 is " << addn2->fullname_with_scope();
  SetKernelBuildInfo(addn2, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32},
                     {KernelObjectType::TUPLE_UNFOLD}, {KernelObjectType::TUPLE_UNFOLD});

  auto split2_input_make_tuple = addn2->input(1)->cast<CNodePtr>();
  MS_LOG(INFO) << "split2_input_make_tuple is " << split2_input_make_tuple->fullname_with_scope();
  SetKernelBuildInfo(split2_input_make_tuple, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
                     {kNumberTypeFloat32}, {KernelObjectType::TENSOR, KernelObjectType::TENSOR},
                     {KernelObjectType::TUPLE_UNFOLD});

  auto split2_get_item1 = split2_input_make_tuple->input(1)->cast<CNodePtr>();
  MS_LOG(INFO) << "split2_get_item1 is " << split2_get_item1->fullname_with_scope();
  SetKernelBuildInfo(split2_get_item1, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeInt64}, {"NCHW"},
                     {kNumberTypeFloat32}, {KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TENSOR},
                     {KernelObjectType::TENSOR});

  // BUG FIX: the second TupleGetItem is the make_tuple's *second* input. The original
  // code read input(1) again, so split2_get_item2 aliased split2_get_item1 and the real
  // second TupleGetItem node never received kernel build info.
  auto split2_get_item2 = split2_input_make_tuple->input(2)->cast<CNodePtr>();
  MS_LOG(INFO) << "split2_get_item2 is " << split2_get_item2->fullname_with_scope();
  SetKernelBuildInfo(split2_get_item2, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeInt64}, {"NCHW"},
                     {kNumberTypeFloat32}, {KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TENSOR},
                     {KernelObjectType::TENSOR});

  // BUG FIX: both TupleGetItem nodes must read from the same Split node. The original
  // compared split2_get_item2->input(1) with itself, making the assert vacuous.
  auto split2_1 = split2_get_item1->input(1)->cast<CNodePtr>();
  auto split2_2 = split2_get_item2->input(1)->cast<CNodePtr>();
  ASSERT_TRUE(split2_1 == split2_2);
  MS_LOG(INFO) << "split2 is " << split2_1->fullname_with_scope();
  // NOTE(review): split2 is given input obj type TUPLE_UNFOLD and a single TENSOR output
  // obj type for two outputs, while split1 below uses TENSOR in / TUPLE_UNFOLD out —
  // looks inconsistent; confirm against the pass's expectations.
  SetKernelBuildInfo(split2_2, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW", "NCHW"},
                     {kNumberTypeFloat32, kNumberTypeFloat32}, {KernelObjectType::TUPLE_UNFOLD},
                     {KernelObjectType::TENSOR});

  auto addn1 = split2_2->input(1)->cast<CNodePtr>();
  MS_LOG(INFO) << "addn1 is " << addn1->fullname_with_scope();
  SetKernelBuildInfo(addn1, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32},
                     {KernelObjectType::TUPLE_UNFOLD}, {KernelObjectType::TENSOR});

  auto split1 = addn1->input(1)->cast<CNodePtr>();
  MS_LOG(INFO) << "split1 is " << split1->fullname_with_scope();
  SetKernelBuildInfo(split1, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32},
                     {KernelObjectType::TENSOR}, {KernelObjectType::TUPLE_UNFOLD});

  // The input is a value (graph parameter), not a CNode, so no cast is needed.
  auto input_node = split1->input(1);
  MS_LOG(INFO) << "input_node is " << input_node->fullname_with_scope();
  SetKernelBuildInfo(input_node, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32},
                     {KernelObjectType::TENSOR}, {KernelObjectType::TENSOR});
}
// Builds a KernelBuildInfo from the given formats, device data types and kernel
// object types, then attaches it to 'node' through a freshly created KernelInfo.
void TestInsertTypeTransformOp::SetKernelBuildInfo(
const AnfNodePtr &node, const std::vector<std::string> &input_formats, const std::vector<TypeId> &input_types,
const std::vector<std::string> &output_formats, const std::vector<TypeId> &output_types,
const std::vector<KernelObjectType> &input_obj_types, const std::vector<KernelObjectType> &output_obj_types) {
KernelBuildInfoBuilder builder;
builder.SetInputsFormat(input_formats);
builder.SetInputsDeviceType(input_types);
builder.SetOutputsFormat(output_formats);
builder.SetOutputsDeviceType(output_types);
builder.SetInputsKernelObjectType(input_obj_types);
builder.SetOutputsKernelObjectType(output_obj_types);
// The node needs a KernelInfo container before build info can be attached to it.
node->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
}
/// Feature: Dynamic shape.
/// Description: Test TupleUnfold to TupleUnfold type transforming pass.
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_unfold_transform) {
// Fetch the 'before' graph from the python test-input file and infer it with a 2x4 float32 input.
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_unfold_transform", "before");
ASSERT_TRUE(g != nullptr);
std::vector<int64_t> shp_x{2, 4};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list{x_abstract};
auto func_graph = GetFuncGraph(g, args_spec_list);
ASSERT_TRUE(func_graph != nullptr);
// Attach kernel build info to every node so the pass can read kernel object types.
SetTupleUnfoldToTupleUnfoldKernelBuildInfo(func_graph);
// Run only the InsertTypeTransformOp pass.
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>());
optimizer->AddPassManager(pm);
optimizer->Optimize(func_graph);
// The optimized graph must be structurally identical to the python 'after' graph.
FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_unfold_transform", "after");
ASSERT_TRUE(g_after != nullptr);
EXPECT_TRUE(CheckEqualGraph(func_graph, g_after));
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,67 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore
from mindspore.common.tensor import Tensor
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops import _constants as Constants
# Primitives shared by the graph-builder functions below. These module-level
# names are referenced by name when the C++ UT parses the python graphs.
tuple_get_item = Primitive(Constants.kTupleGetItem)
make_tuple = Primitive('MakeTuple')
split1 = P.Split(1, 2)  # split along axis 1 into 2 outputs
split2 = P.Split(0, 2)  # split along axis 0 into 2 outputs
add1 = P.AddN()
add2 = P.AddN()
# NOTE(review): input_x appears unused in this file — confirm whether it is
# referenced by the test harness or can be removed.
input_x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]), mindspore.int32)
class FnDict:
    """Small registry mapping function names to functions, usable as a decorator.

    Usage:
        fns = FnDict()

        @fns
        def before(x): ...

        fn = fns["before"]
    """

    def __init__(self):
        self.fn_dict = {}

    def __call__(self, fn):
        """Register ``fn`` under its ``__name__`` and return it unchanged.

        Returning ``fn`` (the original implicitly returned None) keeps the
        decorated name bound to the function, as a decorator is expected to do;
        lookups through ``fns[name]`` behave exactly as before.
        """
        self.fn_dict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        # .get() yields None for unknown tags instead of raising KeyError.
        return self.fn_dict.get(name)
def test_tuple_unfold_to_tuple_unfold_transform(tag):
    """
    Feature: Dynamic shape.
    Description: Test TupleUnfold to TupleUnfold transforming pass.
    Expectation: The 'after' graph is identical to the graph after this pass.

    Returns the graph-builder function registered under ``tag``
    ('before' or 'after'); the C++ UT parses it into a FuncGraph.
    """
    fns = FnDict()

    @fns
    def before(x):
        # Split -> AddN -> Split -> AddN, where the final AddN consumes a tuple
        # built explicitly from the two Split outputs.
        res = split1(x)
        res = add1(res)
        res = split2(res)
        res = add2((tuple_get_item(res, 0), tuple_get_item(res, 1)))
        return res

    @fns
    def after(x):
        # Expected graph after the pass: tuple inputs are unfolded into
        # individual TupleGetItem inputs (no MakeTuple wrapper).
        res = split1(x)
        res = add1(tuple_get_item(res, 0), tuple_get_item(res, 1))
        res = split2(res)
        res = add2(tuple_get_item(res, 0), tuple_get_item(res, 1))
        return res

    return fns[tag]