!23742 fix tbe check bug

Merge pull request !23742 from hwjiaorui/tbe-check
This commit is contained in:
i-robot 2021-09-18 06:27:07 +00:00 committed by Gitee
commit 80afc4b41c
2 changed files with 12 additions and 20 deletions

View File

@ -354,11 +354,11 @@ void SelectTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t r
nlohmann::json *input_desc) {
MS_EXCEPTION_IF_NULL(anf_node);
GenDesJsonCommon(input_desc);
auto shape = TbeJsonUtils::GetInputOriShapeForTbeBuild(anf_node, real_input_index);
if (shape.empty()) {
shape.emplace_back(1);
auto ori_shape = TbeJsonUtils::GetInputOriShapeForTbeBuild(anf_node, real_input_index);
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
}
auto ori_shape = shape;
auto shape = ori_shape;
auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
auto format = def_format;
@ -390,13 +390,9 @@ void CheckTbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t node_ou
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
}
shape = TbeJsonUtils::GetOutputDeviceShapeForTbeBuild(anf_node, node_out_idx);
if (shape.empty()) {
shape.emplace_back(1);
}
shape = ori_shape;
auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx);
format = TbeAdapter::FormatPass(format, ori_shape.size());
auto format = def_format;
(*output_desc)[kJDataType] = tbe::TypeIdToString(AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx));
(*output_desc)[kJDtype] = GetJsonValue<std::string>(*output_desc, kJDataType);
@ -415,14 +411,10 @@ void CheckTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t re
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
}
auto shape = TbeJsonUtils::GetInputDeviceShapeForTbeBuild(anf_node, real_input_index);
if (shape.empty()) {
shape.emplace_back(1);
}
auto shape = ori_shape;
auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
auto format = AnfAlgo::GetInputFormat(anf_node, real_input_index);
format = TbeAdapter::FormatPass(format, ori_shape.size());
auto format = def_format;
(*input_desc)[kJDtype] = tbe::TypeIdToString(AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index));
(*input_desc)[kJDataType] = GetJsonValue<std::string>(*input_desc, kJDtype);
(*input_desc)[kJOriShape] = ori_shape;

View File

@ -78,7 +78,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_common) {
EXPECT_TRUE(tbe_json_creator_select->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 4297213426602035622U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 6011570351795510237U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 4297213426602035622U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 6011570351795510237U);
}
@ -121,7 +121,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_conv2d_backprop_filter) {
EXPECT_TRUE(tbe_json_creator_select->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 3804649253898608226U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 4580870880229487185U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 3804649253898608226U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 4580870880229487185U);
}
@ -179,7 +179,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_dynamic_rnn) {
EXPECT_TRUE(tbe_json_creator_select->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 13058640182660031121U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 4729701784171992376U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 13058640182660031121U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 4729701784171992376U);
}
@ -233,7 +233,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_layer_norm) {
EXPECT_TRUE(tbe_json_creator_select->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 1114128635775386802U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 9247341733773157591U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 2175017419293202360U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 9247341733773157591U);
}