forked from mindspore-Ecosystem/mindspore
!46439 Add tuple_unfold to tuple_unfold ut
Merge pull request !46439 from ZPaC/pass-for-tuple
This commit is contained in:
commit
e8ee9f9b40
|
@ -46,6 +46,11 @@ int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_inpu
|
|||
// using for graph kernel
|
||||
auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
|
||||
MS_EXCEPTION_IF_NULL(dyn_input_node);
|
||||
// Handle tuple nested scenes.
|
||||
if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
|
||||
input_size += SplitTupleInputs(graph, dyn_input_node, plant_inputs);
|
||||
continue;
|
||||
}
|
||||
(void)plant_inputs->emplace_back(dyn_input_node);
|
||||
}
|
||||
return input_size;
|
||||
|
@ -64,6 +69,8 @@ AnfNodePtr CreateNewNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &i
|
|||
|
||||
auto new_cnode = NewCNode(input_list, func_graph, {origin_node});
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
// This pass should not have new node whose abstract differs from the original node. So set the original node's
|
||||
// abstract.
|
||||
new_cnode->set_abstract(origin_node->abstract());
|
||||
new_cnode->set_scope(origin_node->scope());
|
||||
new_cnode->set_primal_attrs(origin_node->primal_attrs());
|
||||
|
@ -87,29 +94,33 @@ void UpdateKernelBuildInfo(const CNodePtr &new_cnode, const CNodePtr &origin_nod
|
|||
KernelBuildInfoPtr origin_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(origin_node);
|
||||
auto new_kernel_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(origin_kernel_build_info);
|
||||
|
||||
// Construct new inputs and outputs info and set to the new kernel build info.
|
||||
// Construct new inputs info and set to the new kernel build info.
|
||||
std::vector<std::string> inputs_device_format;
|
||||
std::vector<std::string> outputs_device_format;
|
||||
std::vector<TypeId> inputs_device_type;
|
||||
std::vector<TypeId> outputs_device_type;
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(new_cnode);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_device_format.push_back(AnfAlgo::GetInputFormat(new_cnode, input_index));
|
||||
inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(new_cnode, input_index));
|
||||
inputs_device_format.push_back(AnfAlgo::GetOutputFormat(new_cnode->input(input_index + kSizeOne), kIndex0));
|
||||
inputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(new_cnode->input(input_index + kSizeOne), kIndex0));
|
||||
}
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(new_cnode);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_device_format.push_back(AnfAlgo::GetOutputFormat(new_cnode, output_index));
|
||||
outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(new_cnode, output_index));
|
||||
}
|
||||
|
||||
new_kernel_builder->SetInputsFormat(inputs_device_format);
|
||||
new_kernel_builder->SetOutputsFormat(outputs_device_format);
|
||||
new_kernel_builder->SetInputsDeviceType(inputs_device_type);
|
||||
new_kernel_builder->SetOutputsDeviceType(outputs_device_type);
|
||||
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
new_cnode->set_kernel_info(kernel_info);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(new_kernel_builder->Build(), new_cnode.get());
|
||||
}
|
||||
|
||||
// A map of kernel object type pairs to processing functions.
|
||||
static std::map<ObjectTypePair, ProcessTypeTransformFunc> kTypePairToProcessFunc;
|
||||
|
||||
InsertTypeTransformOp::InsertTypeTransformOp(bool multigraph)
|
||||
: PatternProcessPass("insert_type_transform_op", multigraph) {
|
||||
kTypePairToProcessFunc[{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TUPLE_UNFOLD}] =
|
||||
std::bind(&InsertTypeTransformOp::ProcessTupleUnfoldToTupleUnfold, this, std::placeholders::_1,
|
||||
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
|
||||
}
|
||||
|
||||
const AnfNodePtr InsertTypeTransformOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -128,7 +139,8 @@ const AnfNodePtr InsertTypeTransformOp::Process(const FuncGraphPtr &func_graph,
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
AnfNodePtrList new_input_list = {common::AnfAlgo::GetCNodePrimitiveNode(cnode)};
|
||||
for (size_t i = kIndex1; i < cnode->inputs().size(); i++) {
|
||||
MS_LOG(DEBUG) << "Kernel object type of index " << i << " is " << kObjectTypeString[needed_input_type_list[i - 1]];
|
||||
MS_LOG(DEBUG) << "Kernel object type of index " << i << " is "
|
||||
<< kObjectTypeToString[needed_input_type_list[i - 1]];
|
||||
// Get actually needed input kernel object type.
|
||||
KernelObjectType needed_input_type = needed_input_type_list[i - 1];
|
||||
|
||||
|
|
|
@ -28,10 +28,10 @@ namespace opt {
|
|||
using kernel::KernelBuildInfoPtr;
|
||||
using kernel::KernelObjectType;
|
||||
|
||||
std::map<KernelObjectType, std::string> kObjectTypeString = {{KernelObjectType::TENSOR, "tensor"},
|
||||
{KernelObjectType::SCALAR, "scalar"},
|
||||
{KernelObjectType::TUPLE, "tuple"},
|
||||
{KernelObjectType::TUPLE_UNFOLD, "tuple_unfold"}};
|
||||
static std::map<KernelObjectType, std::string> kObjectTypeToString = {{KernelObjectType::TENSOR, "tensor"},
|
||||
{KernelObjectType::SCALAR, "scalar"},
|
||||
{KernelObjectType::TUPLE, "tuple"},
|
||||
{KernelObjectType::TUPLE_UNFOLD, "tuple_unfold"}};
|
||||
|
||||
// Kernel object type pair of:
|
||||
// 1. One node's input kernel object type.
|
||||
|
@ -41,16 +41,18 @@ struct ObjectTypePair {
|
|||
KernelObjectType needed_input_type;
|
||||
|
||||
std::string to_string() const {
|
||||
if (kObjectTypeString.find(current_input_type) == kObjectTypeString.end() ||
|
||||
kObjectTypeString.find(needed_input_type) == kObjectTypeString.end()) {
|
||||
if (kObjectTypeToString.find(current_input_type) == kObjectTypeToString.end() ||
|
||||
kObjectTypeToString.find(needed_input_type) == kObjectTypeToString.end()) {
|
||||
MS_LOG(EXCEPTION) << "The current input object type " << current_input_type << " or needed input object type "
|
||||
<< needed_input_type << " is not valid.";
|
||||
}
|
||||
|
||||
return kObjectTypeString[current_input_type] + "->" + kObjectTypeString[needed_input_type];
|
||||
return kObjectTypeToString[current_input_type] + "->" + kObjectTypeToString[needed_input_type];
|
||||
}
|
||||
|
||||
bool operator<(const ObjectTypePair &t) const { return to_string() < t.to_string(); }
|
||||
|
||||
bool operator==(const ObjectTypePair &t) const { return to_string() == t.to_string(); }
|
||||
};
|
||||
|
||||
// For each unmatched type pair, a processing method is required to correct the types by inserting type transforming
|
||||
|
@ -67,7 +69,6 @@ struct ObjectTypePair {
|
|||
*/
|
||||
using ProcessTypeTransformFunc = std::function<AnfNodePtrList(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
||||
const CNodePtr &node, bool *new_prim)>;
|
||||
std::map<ObjectTypePair, ProcessTypeTransformFunc> kTypePairToProcessFunc;
|
||||
|
||||
// SplitTupleInputs methods refer to the pass ConvertTupleInputToDynamicInput. It unfolds tuple inputs and returns the
|
||||
// unfolded inputs nodes.
|
||||
|
@ -85,11 +86,7 @@ void UpdateKernelBuildInfo(const CNodePtr &new_cnode, const CNodePtr &origin_nod
|
|||
// node's output type). We need this pass to transform these types to valid types.
|
||||
class BACKEND_EXPORT InsertTypeTransformOp : public PatternProcessPass {
|
||||
public:
|
||||
explicit InsertTypeTransformOp(bool multigraph = true) : PatternProcessPass("insert_type_transform_op", multigraph) {
|
||||
kTypePairToProcessFunc[{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TUPLE_UNFOLD}] =
|
||||
std::bind(&InsertTypeTransformOp::ProcessTupleUnfoldToTupleUnfold, this, std::placeholders::_1,
|
||||
std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
|
||||
}
|
||||
explicit InsertTypeTransformOp(bool multigraph = true);
|
||||
~InsertTypeTransformOp() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/backend_common_test.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/pass_manager.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
#define private public
|
||||
#define protected public
|
||||
#include "backend/common/pass/insert_type_transform_op.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
using kernel::KernelObjectType;
|
||||
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
||||
|
||||
class TestInsertTypeTransformOp : public BackendCommon {
|
||||
public:
|
||||
TestInsertTypeTransformOp() : getPyFun_("gtest_input.pre_activate.insert_type_transform_op_test", true) {}
|
||||
~TestInsertTypeTransformOp() override = default;
|
||||
|
||||
public:
|
||||
void SetTupleUnfoldToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &func_graph);
|
||||
void SetKernelBuildInfo(const AnfNodePtr &node, const std::vector<std::string> &input_formats,
|
||||
const std::vector<TypeId> &input_types, const std::vector<std::string> &output_formats,
|
||||
const std::vector<TypeId> &output_types, const std::vector<KernelObjectType> &input_obj_types,
|
||||
const std::vector<KernelObjectType> &output_obj_types);
|
||||
UT::PyFuncGraphFetcher getPyFun_;
|
||||
};
|
||||
|
||||
void TestInsertTypeTransformOp::SetTupleUnfoldToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &g) {
|
||||
auto ret = g->get_return();
|
||||
EXPECT_NE(ret->input(1), nullptr);
|
||||
auto addn2 = ret->input(1)->cast<CNodePtr>();
|
||||
MS_LOG(INFO) << "addn2 is " << addn2->fullname_with_scope();
|
||||
SetKernelBuildInfo(addn2, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32},
|
||||
{KernelObjectType::TUPLE_UNFOLD}, {KernelObjectType::TUPLE_UNFOLD});
|
||||
|
||||
auto split2_input_make_tuple = addn2->input(1)->cast<CNodePtr>();
|
||||
MS_LOG(INFO) << "split2_input_make_tuple is " << split2_input_make_tuple->fullname_with_scope();
|
||||
SetKernelBuildInfo(split2_input_make_tuple, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"},
|
||||
{kNumberTypeFloat32}, {KernelObjectType::TENSOR, KernelObjectType::TENSOR},
|
||||
{KernelObjectType::TUPLE_UNFOLD});
|
||||
|
||||
auto split2_get_item1 = split2_input_make_tuple->input(1)->cast<CNodePtr>();
|
||||
MS_LOG(INFO) << "split2_get_item1 is " << split2_get_item1->fullname_with_scope();
|
||||
SetKernelBuildInfo(split2_get_item1, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeInt64}, {"NCHW"},
|
||||
{kNumberTypeFloat32}, {KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TENSOR},
|
||||
{KernelObjectType::TENSOR});
|
||||
|
||||
auto split2_get_item2 = split2_input_make_tuple->input(1)->cast<CNodePtr>();
|
||||
MS_LOG(INFO) << "split2_get_item2 is " << split2_get_item2->fullname_with_scope();
|
||||
SetKernelBuildInfo(split2_get_item2, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeInt64}, {"NCHW"},
|
||||
{kNumberTypeFloat32}, {KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TENSOR},
|
||||
{KernelObjectType::TENSOR});
|
||||
|
||||
auto split2_1 = split2_get_item2->input(1)->cast<CNodePtr>();
|
||||
auto split2_2 = split2_get_item2->input(1)->cast<CNodePtr>();
|
||||
ASSERT_TRUE(split2_1 == split2_2);
|
||||
MS_LOG(INFO) << "split2 is " << split2_1->fullname_with_scope();
|
||||
SetKernelBuildInfo(split2_2, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW", "NCHW"},
|
||||
{kNumberTypeFloat32, kNumberTypeFloat32}, {KernelObjectType::TUPLE_UNFOLD},
|
||||
{KernelObjectType::TENSOR});
|
||||
|
||||
auto addn1 = split2_2->input(1)->cast<CNodePtr>();
|
||||
MS_LOG(INFO) << "addn1 is " << addn1->fullname_with_scope();
|
||||
SetKernelBuildInfo(addn1, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32},
|
||||
{KernelObjectType::TUPLE_UNFOLD}, {KernelObjectType::TENSOR});
|
||||
|
||||
auto split1 = addn1->input(1)->cast<CNodePtr>();
|
||||
MS_LOG(INFO) << "split1 is " << split1->fullname_with_scope();
|
||||
SetKernelBuildInfo(split1, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32},
|
||||
{KernelObjectType::TENSOR}, {KernelObjectType::TUPLE_UNFOLD});
|
||||
|
||||
// The input is a value.
|
||||
auto input_node = split1->input(1);
|
||||
MS_LOG(INFO) << "input_node is " << input_node->fullname_with_scope();
|
||||
SetKernelBuildInfo(input_node, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32},
|
||||
{KernelObjectType::TENSOR}, {KernelObjectType::TENSOR});
|
||||
}
|
||||
|
||||
void TestInsertTypeTransformOp::SetKernelBuildInfo(
|
||||
const AnfNodePtr &node, const std::vector<std::string> &input_formats, const std::vector<TypeId> &input_types,
|
||||
const std::vector<std::string> &output_formats, const std::vector<TypeId> &output_types,
|
||||
const std::vector<KernelObjectType> &input_obj_types, const std::vector<KernelObjectType> &output_obj_types) {
|
||||
KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsFormat(input_formats);
|
||||
builder.SetInputsDeviceType(input_types);
|
||||
builder.SetOutputsFormat(output_formats);
|
||||
builder.SetOutputsDeviceType(output_types);
|
||||
builder.SetInputsKernelObjectType(input_obj_types);
|
||||
builder.SetOutputsKernelObjectType(output_obj_types);
|
||||
node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
|
||||
}
|
||||
|
||||
/// Feature: Dynamic shape.
|
||||
/// Description: Test TupleUnfold to TupleUnfold type transforming pass.
|
||||
/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python.
|
||||
TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_unfold_transform) {
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_unfold_transform", "before");
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
std::vector<int64_t> shp_x{2, 4};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
AbstractBasePtrList args_spec_list{x_abstract};
|
||||
auto func_graph = GetFuncGraph(g, args_spec_list);
|
||||
ASSERT_TRUE(func_graph != nullptr);
|
||||
SetTupleUnfoldToTupleUnfoldKernelBuildInfo(func_graph);
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>());
|
||||
optimizer->AddPassManager(pm);
|
||||
optimizer->Optimize(func_graph);
|
||||
|
||||
FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_unfold_transform", "after");
|
||||
ASSERT_TRUE(g_after != nullptr);
|
||||
EXPECT_TRUE(CheckEqualGraph(func_graph, g_after));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import mindspore
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import _constants as Constants
|
||||
tuple_get_item = Primitive(Constants.kTupleGetItem)
|
||||
|
||||
make_tuple = Primitive('MakeTuple')
|
||||
split1 = P.Split(1, 2)
|
||||
split2 = P.Split(0, 2)
|
||||
add1 = P.AddN()
|
||||
add2 = P.AddN()
|
||||
|
||||
input_x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]), mindspore.int32)
|
||||
|
||||
|
||||
class FnDict:
|
||||
def __init__(self):
|
||||
self.fn_dict = {}
|
||||
|
||||
def __call__(self, fn):
|
||||
self.fn_dict[fn.__name__] = fn
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.fn_dict.get(name)
|
||||
|
||||
|
||||
def test_tuple_unfold_to_tuple_unfold_transform(tag):
|
||||
"""
|
||||
Feature: Dynamic shape.
|
||||
Description: Test TupleUnfold to TupleUnfold transforming pass.
|
||||
Expectation: The 'after' graph is identical to the graph after this pass.
|
||||
"""
|
||||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def before(x):
|
||||
res = split1(x)
|
||||
res = add1(res)
|
||||
res = split2(res)
|
||||
res = add2((tuple_get_item(res, 0), tuple_get_item(res, 1)))
|
||||
return res
|
||||
|
||||
@fns
|
||||
def after(x):
|
||||
res = split1(x)
|
||||
res = add1(tuple_get_item(res, 0), tuple_get_item(res, 1))
|
||||
res = split2(res)
|
||||
res = add2(tuple_get_item(res, 0), tuple_get_item(res, 1))
|
||||
return res
|
||||
|
||||
return fns[tag]
|
Loading…
Reference in New Issue