From e498c96a203a3b8b680928bfa3f6e1b6339b5985 Mon Sep 17 00:00:00 2001 From: hwjiaorui Date: Mon, 2 Aug 2021 17:25:00 +0800 Subject: [PATCH] tbe json creator UT --- .jenkins/check/config/filter_pylint.txt | 1 + .../tbe/tbe_json/tbe_json_creator.cc | 6 +- .../tbe/tbe_json/tbe_json_creator.h | 2 + tests/ut/cpp/CMakeLists.txt | 1 + .../gtest_input/tbe/tbe_json_creator_test.py | 74 ++++ tests/ut/cpp/runtest.sh | 2 + tests/ut/cpp/tbe/tbe_json_creator_test.cc | 371 ++++++++++++++++++ 7 files changed, 454 insertions(+), 3 deletions(-) create mode 100644 tests/ut/cpp/python_input/gtest_input/tbe/tbe_json_creator_test.py create mode 100644 tests/ut/cpp/tbe/tbe_json_creator_test.cc diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt index e709efe2b3c..1b8f94c2700 100644 --- a/.jenkins/check/config/filter_pylint.txt +++ b/.jenkins/check/config/filter_pylint.txt @@ -52,6 +52,7 @@ "mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/infer/primitive_test.py" "super-init-not-called" "mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_primitive.py" "super-init-not-called" "mindspore/tests/ut/cpp/python_input/gtest_input/pre_activate" "unused-variable" +"mindspore/tests/ut/cpp/python_input/gtest_input/tbe" "unused-variable" "mindspore/tests/ut/python/train/summary/test_summary_abnormal_input.py" "bare-except" "mindspore/tests/ut/python/train/summary/test_graph_summary.py" "protected-access" "mindspore/tests/ut/python/parameter_feature/test_parameter.py" "unused-variable" diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/tbe_json_creator.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/tbe_json_creator.cc index 6d230e078b8..f194b8f2a81 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/tbe_json_creator.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/tbe_json_creator.cc @@ -192,7 +192,7 @@ bool TbeJsonCreator::GenComputeJson(const AnfNodePtr &anf_node, nlohmann::json * void TbeJsonCreator::GenFusionOpName(nlohmann::json *kernel_json, std::string prefix) { json_name_.clear(); - size_t hash_id = GenJsonHash((*kernel_json)); + json_hash_ = GenJsonHash((*kernel_json)); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); json_name_ = std::move(prefix); @@ -203,7 +203,7 @@ void TbeJsonCreator::GenFusionOpName(nlohmann::json *kernel_json, std::string pr json_name_.append("_"); } } - json_name_ = json_name_ + std::to_string(hash_id) + "_" + std::to_string(device_id); + json_name_ = json_name_ + std::to_string(json_hash_) + "_" + std::to_string(device_id); MS_LOG(DEBUG) << "Generate Json name: " << json_name_; (*kernel_json)[kJFusionOpName] = json_name_; } @@ -231,7 +231,7 @@ size_t TbeJsonCreator::GenJsonHash(nlohmann::json tbe_json) { DeleteDescName(&op.at(kJInputDesc)); } } - return std::hash()(tbe_json.dump()); + return std::hash()(op_lists.dump()); } void TbeJsonCreator::AddOpNameForComputeNode(nlohmann::json *kernel_json) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/tbe_json_creator.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/tbe_json_creator.h index 83c3bfdc90f..e71838dfa0e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/tbe_json_creator.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/tbe_json_creator.h @@ -48,6 +48,7 @@ class TbeJsonCreator { virtual bool GenJson(const AnfNodePtr &anf_node, nlohmann::json *kernel_json) { return false; } virtual bool GenJson(const FusionScopeInfo &fusion_scope_info, nlohmann::json *fusion_json) { return false; } std::string GetJsonName() { return json_name_; } + size_t GetJsonHash() { return json_hash_; } protected: bool GenComputeJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json); @@ -72,6 +73,7 @@ class TbeJsonCreator { private: std::string json_name_; + size_t json_hash_; }; } // namespace mindspore::kernel diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 58288960327..86d21eef618 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -68,6 +68,7 @@ if(ENABLE_MINDDATA) ./ps/*.cc ./fl/*.cc ./cxx_api/*.cc + ./tbe/*.cc ) if(NOT ENABLE_PYTHON) diff --git a/tests/ut/cpp/python_input/gtest_input/tbe/tbe_json_creator_test.py b/tests/ut/cpp/python_input/gtest_input/tbe/tbe_json_creator_test.py new file mode 100644 index 00000000000..e860816cfc6 --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/tbe/tbe_json_creator_test.py @@ -0,0 +1,74 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from mindspore.ops.operations import _grad_ops as G + +Relu = P.ReLU() +Cast = P.Cast() +Add = P.Add() +Conv2DBackpropFilter = G.Conv2DBackpropFilter(out_channel=4, + kernel_size=1, + pad_mode="valid", + pad=0, + mode=1, + stride=1, + dilation=1, + group=1) +DynamicRNN = P.DynamicRNN(forget_bias=0.0) +LayerNorm = P.LayerNorm() +Conv2D = P.Conv2D(out_channel=32, kernel_size=3) + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_tbe_json_creator(tag): + fns = FnDict() + + @fns + def func_relu_relu_cast(x): + relu1 = Relu(x) + relu2 = Relu(relu1) + res = Cast(relu2, mstype.float16) + return res + + @fns + def func_conv2d_backprop_filter(x, out, shape): + return Conv2DBackpropFilter(x, out, shape) + + @fns + def func_dynamic_rnn(x, w, b, seq_length, init_h, init_c): + return DynamicRNN(x, w, b, seq_length, init_h, init_c) + + @fns + def func_layer_norm(input_x, gamma, beta): + return LayerNorm(input_x, gamma, beta) + + @fns + def fusion_add_conv2d(x, y, z): + add = Add(x, y) + return Conv2D(add, z) + + return fns[tag] diff --git a/tests/ut/cpp/runtest.sh b/tests/ut/cpp/runtest.sh index e4c5f6cdf2f..df1f81e9bd2 100755 --- a/tests/ut/cpp/runtest.sh +++ b/tests/ut/cpp/runtest.sh @@ -32,6 +32,8 @@ ${PROJECT_PATH}/graphengine/third_party/prebuild/aarch64:${LD_LIBRARY_PATH} export PYTHONPATH=${PROJECT_PATH}/tests/ut/cpp/python_input:$PYTHONPATH:${PROJECT_PATH} export GLOG_v=2 export GC_COLLECT_IN_CELL=1 +## set op info config path +export MINDSPORE_OP_INFO_PATH=${PROJECT_PATH}/config/op_info.config ## prepare data for dataset & mindrecord cp -fr $PROJECT_PATH/tests/ut/data ${PROJECT_PATH}/build/mindspore/tests/ut/cpp/ diff --git a/tests/ut/cpp/tbe/tbe_json_creator_test.cc b/tests/ut/cpp/tbe/tbe_json_creator_test.cc new file mode 100644 index 00000000000..7af91ab2c94 --- /dev/null +++ b/tests/ut/cpp/tbe/tbe_json_creator_test.cc @@ -0,0 +1,371 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/ms_context.h" +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#include "debug/anf_ir_dump.h" +#include "backend/kernel_compiler/kernel.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h" +#include "backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.h" +#include "backend/kernel_compiler/tbe/tbe_json/fusion_tbe_json_creator.h" + +namespace mindspore::kernel { + +using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; +constexpr int64_t kShape4D = 4; +class TestHWTBEJsonCreator : public BackendCommon { + public: + TestHWTBEJsonCreator() : get_py_fun_("gtest_input.tbe.tbe_json_creator_test", true) {} + ~TestHWTBEJsonCreator() override = default; + + UT::PyFuncGraphFetcher get_py_fun_; +}; + +TEST_F(TestHWTBEJsonCreator, test_tbe_single_common) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + context->set_param(MS_CTX_DEVICE_TARGET, kAscendDevice); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_json_creator", "func_relu_relu_cast"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list{x_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + + auto ret = kg->get_return(); + EXPECT_NE(ret->input(1), nullptr); + auto tuple = ret->input(1); + EXPECT_NE(tuple, nullptr); + auto cast = tuple->cast()->input(1); + EXPECT_NE(cast, nullptr); + auto relu2 = cast->cast()->input(1); + EXPECT_NE(relu2, nullptr); + auto relu1 = relu2->cast()->input(1); + + KernelBuildInfoBuilder builder; + builder.SetInputsFormat({"NC1HWC0"}); + builder.SetOutputsFormat({"NC1HWC0"}); + builder.SetInputsDeviceType({kFloat32->type_id()}); + builder.SetOutputsDeviceType({kFloat32->type_id()}); + builder.SetKernelType(KernelType::TBE_KERNEL); + builder.SetFusionType(kernel::FusionType::ELEMWISE); + builder.SetProcessor(kernel::Processor::AICORE); + builder.SetKernelType(KernelType::TBE_KERNEL); + + relu1->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), relu1.get()); + + auto tbe_json_creator_select = std::make_shared(); + auto tbe_json_creator_check = std::make_shared(); + auto tbe_json_creator_build = std::make_shared(); + nlohmann::json kernel_json; + EXPECT_TRUE(tbe_json_creator_select->GenJson(relu1, &kernel_json)); + EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 7996236493612266030U); + EXPECT_TRUE(tbe_json_creator_check->GenJson(relu1, &kernel_json)); + EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 16463039402039306442U); + EXPECT_TRUE(tbe_json_creator_build->GenJson(relu1, &kernel_json)); + EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 8871925407866693227U); +} + +TEST_F(TestHWTBEJsonCreator, test_tbe_single_conv2d_backprop_filter) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + context->set_param(MS_CTX_DEVICE_TARGET, kAscendDevice); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_json_creator", "func_conv2d_backprop_filter"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList abstract_list = {std::make_shared(kShape4D)}; + auto y_abstract = std::make_shared(abstract_list); + AbstractBasePtrList args_spec_list{x_abstract, x_abstract, y_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + + auto ret = kg->get_return(); + EXPECT_NE(ret->input(1), nullptr); + auto tuple = ret->input(1); + EXPECT_NE(tuple, nullptr); + auto conv2d_backprop_filter = tuple->cast()->input(1); + + KernelBuildInfoBuilder builder; + builder.SetInputsFormat({"NC1HWC0", "NC1HWC0"}); + builder.SetOutputsFormat({"NC1HWC0"}); + builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()}); + builder.SetOutputsDeviceType({kFloat32->type_id()}); + builder.SetKernelType(KernelType::TBE_KERNEL); + builder.SetFusionType(kernel::FusionType::ELEMWISE); + builder.SetProcessor(kernel::Processor::AICORE); + builder.SetKernelType(KernelType::TBE_KERNEL); + + conv2d_backprop_filter->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), conv2d_backprop_filter.get()); + + auto tbe_json_creator_select = std::make_shared(); + auto tbe_json_creator_check = std::make_shared(); + auto tbe_json_creator_build = std::make_shared(); + nlohmann::json kernel_json; + EXPECT_TRUE(tbe_json_creator_select->GenJson(conv2d_backprop_filter, &kernel_json)); + EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 16997569423579290131U); + EXPECT_TRUE(tbe_json_creator_check->GenJson(conv2d_backprop_filter, &kernel_json)); + EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 1051090390656699050U); + EXPECT_TRUE(tbe_json_creator_build->GenJson(conv2d_backprop_filter, &kernel_json)); + EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 3388833908101709327U); +} + +TEST_F(TestHWTBEJsonCreator, test_tbe_single_dynamic_rnn) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + context->set_param(MS_CTX_DEVICE_TARGET, kAscendDevice); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_json_creator", "func_dynamic_rnn"); + std::vector x_shp{2, 16, 64}; + std::vector w_shp{96, 128}; + std::vector b_shp{128}; + std::vector init_h_shp{1, 16, 32}; + std::vector init_c_shp{1, 16, 32}; + auto x_abstract = std::make_shared(kFloat16, x_shp); + auto w_abstract = std::make_shared(kFloat16, w_shp); + auto b_abstract = std::make_shared(kFloat16, b_shp); + auto init_h_abstract = std::make_shared(kFloat16, init_h_shp); + auto init_c_abstract = std::make_shared(kFloat16, init_c_shp); + auto seq_length_abstract = std::make_shared(); + + AbstractBasePtrList args_spec_list{x_abstract, w_abstract, b_abstract, + seq_length_abstract, init_h_abstract, init_c_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + + auto ret = kg->get_return(); + EXPECT_NE(ret->input(1), nullptr); + auto tuple = ret->input(1); + EXPECT_NE(tuple, nullptr); + auto make_tuple = tuple->cast()->input(1); + EXPECT_NE(tuple, nullptr); + auto tuple2 = make_tuple->cast()->input(1); + EXPECT_NE(tuple2, nullptr); + auto dynamic_rnn = tuple2->cast()->input(1); + + KernelBuildInfoBuilder builder; + builder.SetInputsFormat({"NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0"}); + builder.SetOutputsFormat({"NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0"}); + builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(), + kFloat16->type_id(), kFloat16->type_id()}); + builder.SetOutputsDeviceType({kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(), + kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id()}); + builder.SetKernelType(KernelType::TBE_KERNEL); + builder.SetFusionType(kernel::FusionType::ELEMWISE); + builder.SetProcessor(kernel::Processor::AICORE); + builder.SetKernelType(KernelType::TBE_KERNEL); + + dynamic_rnn->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), dynamic_rnn.get()); + + auto tbe_json_creator_select = std::make_shared(); + auto tbe_json_creator_check = std::make_shared(); + auto tbe_json_creator_build = std::make_shared(); + nlohmann::json kernel_json; + EXPECT_TRUE(tbe_json_creator_select->GenJson(dynamic_rnn, &kernel_json)); + EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 599064190761596566U); + EXPECT_TRUE(tbe_json_creator_check->GenJson(dynamic_rnn, &kernel_json)); + EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 13855166034392728379U); + EXPECT_TRUE(tbe_json_creator_build->GenJson(dynamic_rnn, &kernel_json)); + EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 12346685554589275353U); +} + +TEST_F(TestHWTBEJsonCreator, test_tbe_single_layer_norm) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + context->set_param(MS_CTX_DEVICE_TARGET, kAscendDevice); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_json_creator", "func_layer_norm"); + std::vector x_shp{2, 3}; + std::vector gamma_shp{3}; + std::vector beta_shp{3}; + auto x_abstract = std::make_shared(kFloat32, x_shp); + auto gamma_abstract = std::make_shared(kFloat32, gamma_shp); + auto beta_abstract = std::make_shared(kFloat32, beta_shp); + + AbstractBasePtrList args_spec_list{ + x_abstract, + gamma_abstract, + beta_abstract, + }; + auto kg = GetKernelGraph(g, args_spec_list); + + auto ret = kg->get_return(); + EXPECT_NE(ret->input(1), nullptr); + auto tuple = ret->input(1); + EXPECT_NE(tuple, nullptr); + auto make_tuple = tuple->cast()->input(1); + EXPECT_NE(tuple, nullptr); + auto tuple2 = make_tuple->cast()->input(1); + EXPECT_NE(tuple2, nullptr); + auto layer_norm = tuple2->cast()->input(1); + + KernelBuildInfoBuilder builder; + builder.SetInputsFormat({"NC1HWC0", "NC1HWC0", "NC1HWC0"}); + builder.SetOutputsFormat({"NC1HWC0", "NC1HWC0", "NC1HWC0"}); + builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id()}); + builder.SetOutputsDeviceType({kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id()}); + builder.SetKernelType(KernelType::TBE_KERNEL); + builder.SetFusionType(kernel::FusionType::ELEMWISE); + builder.SetProcessor(kernel::Processor::AICORE); + builder.SetKernelType(KernelType::TBE_KERNEL); + + layer_norm->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), layer_norm.get()); + + auto tbe_json_creator_select = std::make_shared(); + auto tbe_json_creator_check = std::make_shared(); + auto tbe_json_creator_build = std::make_shared(); + nlohmann::json kernel_json; + EXPECT_TRUE(tbe_json_creator_select->GenJson(layer_norm, &kernel_json)); + EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 13056848426482724958U); + EXPECT_TRUE(tbe_json_creator_check->GenJson(layer_norm, &kernel_json)); + EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 3069436317069842619U); + EXPECT_TRUE(tbe_json_creator_build->GenJson(layer_norm, &kernel_json)); + EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 18320131482846743097U); +} + +TEST_F(TestHWTBEJsonCreator, test_tbe_fusion_common) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + context->set_param(MS_CTX_DEVICE_TARGET, kAscendDevice); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_json_creator", "func_relu_relu_cast"); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list{x_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + + auto ret = kg->get_return(); + EXPECT_NE(ret->input(1), nullptr); + auto tuple = ret->input(1); + EXPECT_NE(tuple, nullptr); + auto cast = tuple->cast()->input(1); + EXPECT_NE(cast, nullptr); + auto relu2 = cast->cast()->input(1); + EXPECT_NE(relu2, nullptr); + auto relu1 = relu2->cast()->input(1); + + KernelBuildInfoBuilder builder; + builder.SetInputsFormat({"NC1HWC0"}); + builder.SetOutputsFormat({"NC1HWC0"}); + builder.SetInputsDeviceType({kFloat32->type_id()}); + builder.SetOutputsDeviceType({kFloat32->type_id()}); + builder.SetKernelType(KernelType::TBE_KERNEL); + builder.SetFusionType(kernel::FusionType::ELEMWISE); + builder.SetProcessor(kernel::Processor::AICORE); + builder.SetKernelType(KernelType::TBE_KERNEL); + + relu1->set_kernel_info(std::make_shared()); + relu2->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), relu1.get()); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), relu2.get()); + + KernelBuildInfoBuilder builder1; + builder1.SetInputsFormat({"NC1HWC0"}); + builder1.SetOutputsFormat({"NC1HWC0"}); + builder1.SetInputsDeviceType({kFloat32->type_id()}); + builder1.SetOutputsDeviceType({kFloat16->type_id()}); + builder1.SetKernelType(KernelType::TBE_KERNEL); + builder1.SetFusionType(kernel::FusionType::OPAQUE); + builder1.SetProcessor(kernel::Processor::AICORE); + builder1.SetKernelType(KernelType::TBE_KERNEL); + + cast->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast.get()); + + std::vector input_nodes; + std::vector compute_nodes = {relu1, relu2}; + std::string full_name = "FusionOp_" + AnfAlgo::GetCNodeName(relu1) + "_" + AnfAlgo::GetCNodeName(relu2); + for (auto &node : compute_nodes) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) { + auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0); + if (std::find(compute_nodes.begin(), compute_nodes.end(), real_input.first) == compute_nodes.end()) { + if (auto in = cnode->input(idx); std::find(input_nodes.begin(), input_nodes.end(), in) == input_nodes.end()) { + input_nodes.push_back(in); + } + } + } + } + + FusionScopeInfo fusion_scope_info(0, full_name, input_nodes, compute_nodes, {}); + nlohmann::json fusion_json; + auto tbe_json_creator = std::make_shared(); + EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json)); + EXPECT_EQ(tbe_json_creator->GetJsonHash(), 4464178465553346953U); +} + +TEST_F(TestHWTBEJsonCreator, test_fusion_add_conv2d) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + context->set_param(MS_CTX_DEVICE_TARGET, kAscendDevice); + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_json_creator", "fusion_add_conv2d"); + std::vector x_shp{10, 32, 32, 32}; + std::vector z_shp{32, 32, 3, 3}; + auto x_abstract = std::make_shared(kFloat32, x_shp); + auto z_abstract = std::make_shared(kFloat32, z_shp); + AbstractBasePtrList args_spec_list{x_abstract, x_abstract, z_abstract}; + auto kg = GetKernelGraph(g, args_spec_list); + + auto ret = kg->get_return(); + EXPECT_NE(ret->input(1), nullptr); + auto tuple = ret->input(1); + EXPECT_NE(tuple, nullptr); + auto conv2d = tuple->cast()->input(1); + EXPECT_NE(conv2d, nullptr); + auto add = conv2d->cast()->input(1); + EXPECT_NE(add, nullptr); + + KernelBuildInfoBuilder builder; + builder.SetInputsFormat({"NC1HWC0", "NC1HWC0"}); + builder.SetOutputsFormat({"NC1HWC0"}); + builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()}); + builder.SetOutputsDeviceType({kFloat32->type_id()}); + builder.SetKernelType(KernelType::TBE_KERNEL); + builder.SetFusionType(kernel::FusionType::ELEMWISE); + builder.SetProcessor(kernel::Processor::AICORE); + builder.SetKernelType(KernelType::TBE_KERNEL); + + add->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), add.get()); + conv2d->set_kernel_info(std::make_shared()); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), conv2d.get()); + + std::vector input_nodes; + std::vector compute_nodes = {add, conv2d}; + std::string full_name = "FusionOp_" + AnfAlgo::GetCNodeName(add) + "_" + AnfAlgo::GetCNodeName(conv2d); + for (auto &node : compute_nodes) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) { + auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0); + if (std::find(compute_nodes.begin(), compute_nodes.end(), real_input.first) == compute_nodes.end()) { + if (auto in = cnode->input(idx); std::find(input_nodes.begin(), input_nodes.end(), in) == input_nodes.end()) { + input_nodes.push_back(in); + } + } + } + } + + FusionScopeInfo fusion_scope_info(0, full_name, input_nodes, compute_nodes, {}); + nlohmann::json fusion_json; + auto tbe_json_creator = std::make_shared(); + EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json)); + EXPECT_EQ(tbe_json_creator->GetJsonHash(), 6707165667078013944U); +} + +} // namespace mindspore::kernel