tbe json creator UT

This commit is contained in:
hwjiaorui 2021-08-02 17:25:00 +08:00
parent 076aa59dab
commit e498c96a20
7 changed files with 454 additions and 3 deletions

View File

@ -52,6 +52,7 @@
"mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/infer/primitive_test.py" "super-init-not-called"
"mindspore/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parse_primitive.py" "super-init-not-called"
"mindspore/tests/ut/cpp/python_input/gtest_input/pre_activate" "unused-variable"
"mindspore/tests/ut/cpp/python_input/gtest_input/tbe" "unused-variable"
"mindspore/tests/ut/python/train/summary/test_summary_abnormal_input.py" "bare-except"
"mindspore/tests/ut/python/train/summary/test_graph_summary.py" "protected-access"
"mindspore/tests/ut/python/parameter_feature/test_parameter.py" "unused-variable"

View File

@ -192,7 +192,7 @@ bool TbeJsonCreator::GenComputeJson(const AnfNodePtr &anf_node, nlohmann::json *
void TbeJsonCreator::GenFusionOpName(nlohmann::json *kernel_json, std::string prefix) {
json_name_.clear();
size_t hash_id = GenJsonHash((*kernel_json));
json_hash_ = GenJsonHash((*kernel_json));
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
json_name_ = std::move(prefix);
@ -203,7 +203,7 @@ void TbeJsonCreator::GenFusionOpName(nlohmann::json *kernel_json, std::string pr
json_name_.append("_");
}
}
json_name_ = json_name_ + std::to_string(hash_id) + "_" + std::to_string(device_id);
json_name_ = json_name_ + std::to_string(json_hash_) + "_" + std::to_string(device_id);
MS_LOG(DEBUG) << "Generate Json name: " << json_name_;
(*kernel_json)[kJFusionOpName] = json_name_;
}
@ -231,7 +231,7 @@ size_t TbeJsonCreator::GenJsonHash(nlohmann::json tbe_json) {
DeleteDescName(&op.at(kJInputDesc));
}
}
return std::hash<std::string>()(tbe_json.dump());
return std::hash<std::string>()(op_lists.dump());
}
void TbeJsonCreator::AddOpNameForComputeNode(nlohmann::json *kernel_json) {

View File

@ -48,6 +48,7 @@ class TbeJsonCreator {
virtual bool GenJson(const AnfNodePtr &anf_node, nlohmann::json *kernel_json) { return false; }
virtual bool GenJson(const FusionScopeInfo &fusion_scope_info, nlohmann::json *fusion_json) { return false; }
std::string GetJsonName() { return json_name_; }
size_t GetJsonHash() { return json_hash_; }
protected:
bool GenComputeJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json);
@ -72,6 +73,7 @@ class TbeJsonCreator {
private:
std::string json_name_;
size_t json_hash_;
};
} // namespace mindspore::kernel

View File

@ -68,6 +68,7 @@ if(ENABLE_MINDDATA)
./ps/*.cc
./fl/*.cc
./cxx_api/*.cc
./tbe/*.cc
)
if(NOT ENABLE_PYTHON)

View File

@ -0,0 +1,74 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
Relu = P.ReLU()
Cast = P.Cast()
Add = P.Add()
Conv2DBackpropFilter = G.Conv2DBackpropFilter(out_channel=4,
kernel_size=1,
pad_mode="valid",
pad=0,
mode=1,
stride=1,
dilation=1,
group=1)
DynamicRNN = P.DynamicRNN(forget_bias=0.0)
LayerNorm = P.LayerNorm()
Conv2D = P.Conv2D(out_channel=32, kernel_size=3)
class FnDict:
def __init__(self):
self.fnDict = {}
def __call__(self, fn):
self.fnDict[fn.__name__] = fn
def __getitem__(self, name):
return self.fnDict[name]
def test_tbe_json_creator(tag):
fns = FnDict()
@fns
def func_relu_relu_cast(x):
relu1 = Relu(x)
relu2 = Relu(relu1)
res = Cast(relu2, mstype.float16)
return res
@fns
def func_conv2d_backprop_filter(x, out, shape):
return Conv2DBackpropFilter(x, out, shape)
@fns
def func_dynamic_rnn(x, w, b, seq_length, init_h, init_c):
return DynamicRNN(x, w, b, seq_length, init_h, init_c)
@fns
def func_layer_norm(input_x, gamma, beta):
return LayerNorm(input_x, gamma, beta)
@fns
def fusion_add_conv2d(x, y, z):
add = Add(x, y)
return Conv2D(add, z)
return fns[tag]

View File

@ -32,6 +32,8 @@ ${PROJECT_PATH}/graphengine/third_party/prebuild/aarch64:${LD_LIBRARY_PATH}
export PYTHONPATH=${PROJECT_PATH}/tests/ut/cpp/python_input:$PYTHONPATH:${PROJECT_PATH}
export GLOG_v=2
export GC_COLLECT_IN_CELL=1
## set op info config path
export MINDSPORE_OP_INFO_PATH=${PROJECT_PATH}/config/op_info.config
## prepare data for dataset & mindrecord
cp -fr $PROJECT_PATH/tests/ut/data ${PROJECT_PATH}/build/mindspore/tests/ut/cpp/

View File

@ -0,0 +1,371 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/ms_context.h"
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "debug/anf_ir_dump.h"
#include "backend/kernel_compiler/kernel.h"
#include "runtime/device/kernel_info.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h"
#include "backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.h"
#include "backend/kernel_compiler/tbe/tbe_json/fusion_tbe_json_creator.h"
namespace mindspore::kernel {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
constexpr int64_t kShape4D = 4;
class TestHWTBEJsonCreator : public BackendCommon {
public:
TestHWTBEJsonCreator() : get_py_fun_("gtest_input.tbe.tbe_json_creator_test", true) {}
~TestHWTBEJsonCreator() override = default;
UT::PyFuncGraphFetcher get_py_fun_;
};
TEST_F(TestHWTBEJsonCreator, test_tbe_single_common) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_json_creator", "func_relu_relu_cast");
std::vector<int64_t> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list{x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
auto ret = kg->get_return();
EXPECT_NE(ret->input(1), nullptr);
auto tuple = ret->input(1);
EXPECT_NE(tuple, nullptr);
auto cast = tuple->cast<CNodePtr>()->input(1);
EXPECT_NE(cast, nullptr);
auto relu2 = cast->cast<CNodePtr>()->input(1);
EXPECT_NE(relu2, nullptr);
auto relu1 = relu2->cast<CNodePtr>()->input(1);
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({"NC1HWC0"});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
relu1->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), relu1.get());
auto tbe_json_creator_select = std::make_shared<SelectTbeJsonCreator>();
auto tbe_json_creator_check = std::make_shared<CheckTbeJsonCreator>();
auto tbe_json_creator_build = std::make_shared<BuildTbeJsonCreator>();
nlohmann::json kernel_json;
EXPECT_TRUE(tbe_json_creator_select->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 7996236493612266030U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 16463039402039306442U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 8871925407866693227U);
}
TEST_F(TestHWTBEJsonCreator, test_tbe_single_conv2d_backprop_filter) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_json_creator", "func_conv2d_backprop_filter");
std::vector<int64_t> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList abstract_list = {std::make_shared<abstract::AbstractScalar>(kShape4D)};
auto y_abstract = std::make_shared<abstract::AbstractTuple>(abstract_list);
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, y_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
auto ret = kg->get_return();
EXPECT_NE(ret->input(1), nullptr);
auto tuple = ret->input(1);
EXPECT_NE(tuple, nullptr);
auto conv2d_backprop_filter = tuple->cast<CNodePtr>()->input(1);
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({"NC1HWC0", "NC1HWC0"});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
conv2d_backprop_filter->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), conv2d_backprop_filter.get());
auto tbe_json_creator_select = std::make_shared<SelectTbeJsonCreator>();
auto tbe_json_creator_check = std::make_shared<CheckTbeJsonCreator>();
auto tbe_json_creator_build = std::make_shared<BuildTbeJsonCreator>();
nlohmann::json kernel_json;
EXPECT_TRUE(tbe_json_creator_select->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 16997569423579290131U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 1051090390656699050U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 3388833908101709327U);
}
TEST_F(TestHWTBEJsonCreator, test_tbe_single_dynamic_rnn) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_json_creator", "func_dynamic_rnn");
std::vector<int64_t> x_shp{2, 16, 64};
std::vector<int64_t> w_shp{96, 128};
std::vector<int64_t> b_shp{128};
std::vector<int64_t> init_h_shp{1, 16, 32};
std::vector<int64_t> init_c_shp{1, 16, 32};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, x_shp);
auto w_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, w_shp);
auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, b_shp);
auto init_h_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, init_h_shp);
auto init_c_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, init_c_shp);
auto seq_length_abstract = std::make_shared<abstract::AbstractNone>();
AbstractBasePtrList args_spec_list{x_abstract, w_abstract, b_abstract,
seq_length_abstract, init_h_abstract, init_c_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
auto ret = kg->get_return();
EXPECT_NE(ret->input(1), nullptr);
auto tuple = ret->input(1);
EXPECT_NE(tuple, nullptr);
auto make_tuple = tuple->cast<CNodePtr>()->input(1);
EXPECT_NE(tuple, nullptr);
auto tuple2 = make_tuple->cast<CNodePtr>()->input(1);
EXPECT_NE(tuple2, nullptr);
auto dynamic_rnn = tuple2->cast<CNodePtr>()->input(1);
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({"NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0"});
builder.SetOutputsFormat({"NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(),
kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(),
kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
dynamic_rnn->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), dynamic_rnn.get());
auto tbe_json_creator_select = std::make_shared<SelectTbeJsonCreator>();
auto tbe_json_creator_check = std::make_shared<CheckTbeJsonCreator>();
auto tbe_json_creator_build = std::make_shared<BuildTbeJsonCreator>();
nlohmann::json kernel_json;
EXPECT_TRUE(tbe_json_creator_select->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 599064190761596566U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 13855166034392728379U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 12346685554589275353U);
}
TEST_F(TestHWTBEJsonCreator, test_tbe_single_layer_norm) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_json_creator", "func_layer_norm");
std::vector<int64_t> x_shp{2, 3};
std::vector<int64_t> gamma_shp{3};
std::vector<int64_t> beta_shp{3};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, x_shp);
auto gamma_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, gamma_shp);
auto beta_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, beta_shp);
AbstractBasePtrList args_spec_list{
x_abstract,
gamma_abstract,
beta_abstract,
};
auto kg = GetKernelGraph(g, args_spec_list);
auto ret = kg->get_return();
EXPECT_NE(ret->input(1), nullptr);
auto tuple = ret->input(1);
EXPECT_NE(tuple, nullptr);
auto make_tuple = tuple->cast<CNodePtr>()->input(1);
EXPECT_NE(tuple, nullptr);
auto tuple2 = make_tuple->cast<CNodePtr>()->input(1);
EXPECT_NE(tuple2, nullptr);
auto layer_norm = tuple2->cast<CNodePtr>()->input(1);
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({"NC1HWC0", "NC1HWC0", "NC1HWC0"});
builder.SetOutputsFormat({"NC1HWC0", "NC1HWC0", "NC1HWC0"});
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
layer_norm->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), layer_norm.get());
auto tbe_json_creator_select = std::make_shared<SelectTbeJsonCreator>();
auto tbe_json_creator_check = std::make_shared<CheckTbeJsonCreator>();
auto tbe_json_creator_build = std::make_shared<BuildTbeJsonCreator>();
nlohmann::json kernel_json;
EXPECT_TRUE(tbe_json_creator_select->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 13056848426482724958U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 3069436317069842619U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 18320131482846743097U);
}
TEST_F(TestHWTBEJsonCreator, test_tbe_fusion_common) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_json_creator", "func_relu_relu_cast");
std::vector<int64_t> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list{x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
auto ret = kg->get_return();
EXPECT_NE(ret->input(1), nullptr);
auto tuple = ret->input(1);
EXPECT_NE(tuple, nullptr);
auto cast = tuple->cast<CNodePtr>()->input(1);
EXPECT_NE(cast, nullptr);
auto relu2 = cast->cast<CNodePtr>()->input(1);
EXPECT_NE(relu2, nullptr);
auto relu1 = relu2->cast<CNodePtr>()->input(1);
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({"NC1HWC0"});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
relu1->set_kernel_info(std::make_shared<device::KernelInfo>());
relu2->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), relu1.get());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), relu2.get());
KernelBuildInfoBuilder builder1;
builder1.SetInputsFormat({"NC1HWC0"});
builder1.SetOutputsFormat({"NC1HWC0"});
builder1.SetInputsDeviceType({kFloat32->type_id()});
builder1.SetOutputsDeviceType({kFloat16->type_id()});
builder1.SetKernelType(KernelType::TBE_KERNEL);
builder1.SetFusionType(kernel::FusionType::OPAQUE);
builder1.SetProcessor(kernel::Processor::AICORE);
builder1.SetKernelType(KernelType::TBE_KERNEL);
cast->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast.get());
std::vector<AnfNodePtr> input_nodes;
std::vector<AnfNodePtr> compute_nodes = {relu1, relu2};
std::string full_name = "FusionOp_" + AnfAlgo::GetCNodeName(relu1) + "_" + AnfAlgo::GetCNodeName(relu2);
for (auto &node : compute_nodes) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) {
auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0);
if (std::find(compute_nodes.begin(), compute_nodes.end(), real_input.first) == compute_nodes.end()) {
if (auto in = cnode->input(idx); std::find(input_nodes.begin(), input_nodes.end(), in) == input_nodes.end()) {
input_nodes.push_back(in);
}
}
}
}
FusionScopeInfo fusion_scope_info(0, full_name, input_nodes, compute_nodes, {});
nlohmann::json fusion_json;
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 4464178465553346953U);
}
TEST_F(TestHWTBEJsonCreator, test_fusion_add_conv2d) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_json_creator", "fusion_add_conv2d");
std::vector<int64_t> x_shp{10, 32, 32, 32};
std::vector<int64_t> z_shp{32, 32, 3, 3};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, x_shp);
auto z_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, z_shp);
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, z_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
auto ret = kg->get_return();
EXPECT_NE(ret->input(1), nullptr);
auto tuple = ret->input(1);
EXPECT_NE(tuple, nullptr);
auto conv2d = tuple->cast<CNodePtr>()->input(1);
EXPECT_NE(conv2d, nullptr);
auto add = conv2d->cast<CNodePtr>()->input(1);
EXPECT_NE(add, nullptr);
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({"NC1HWC0", "NC1HWC0"});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), add.get());
conv2d->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), conv2d.get());
std::vector<AnfNodePtr> input_nodes;
std::vector<AnfNodePtr> compute_nodes = {add, conv2d};
std::string full_name = "FusionOp_" + AnfAlgo::GetCNodeName(add) + "_" + AnfAlgo::GetCNodeName(conv2d);
for (auto &node : compute_nodes) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) {
auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0);
if (std::find(compute_nodes.begin(), compute_nodes.end(), real_input.first) == compute_nodes.end()) {
if (auto in = cnode->input(idx); std::find(input_nodes.begin(), input_nodes.end(), in) == input_nodes.end()) {
input_nodes.push_back(in);
}
}
}
}
FusionScopeInfo fusion_scope_info(0, full_name, input_nodes, compute_nodes, {});
nlohmann::json fusion_json;
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 6707165667078013944U);
}
} // namespace mindspore::kernel