forked from mindspore-Ecosystem/mindspore
!23742 fix tbe check bug
Merge pull request !23742 from hwjiaorui/tbe-check
This commit is contained in:
commit
80afc4b41c
|
@ -354,11 +354,11 @@ void SelectTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t r
|
|||
nlohmann::json *input_desc) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
GenDesJsonCommon(input_desc);
|
||||
auto shape = TbeJsonUtils::GetInputOriShapeForTbeBuild(anf_node, real_input_index);
|
||||
if (shape.empty()) {
|
||||
shape.emplace_back(1);
|
||||
auto ori_shape = TbeJsonUtils::GetInputOriShapeForTbeBuild(anf_node, real_input_index);
|
||||
if (ori_shape.empty()) {
|
||||
ori_shape.emplace_back(1);
|
||||
}
|
||||
auto ori_shape = shape;
|
||||
auto shape = ori_shape;
|
||||
|
||||
auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
|
||||
auto format = def_format;
|
||||
|
@ -390,13 +390,9 @@ void CheckTbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t node_ou
|
|||
if (ori_shape.empty()) {
|
||||
ori_shape.emplace_back(1);
|
||||
}
|
||||
shape = TbeJsonUtils::GetOutputDeviceShapeForTbeBuild(anf_node, node_out_idx);
|
||||
if (shape.empty()) {
|
||||
shape.emplace_back(1);
|
||||
}
|
||||
shape = ori_shape;
|
||||
auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
|
||||
auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx);
|
||||
format = TbeAdapter::FormatPass(format, ori_shape.size());
|
||||
auto format = def_format;
|
||||
|
||||
(*output_desc)[kJDataType] = tbe::TypeIdToString(AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx));
|
||||
(*output_desc)[kJDtype] = GetJsonValue<std::string>(*output_desc, kJDataType);
|
||||
|
@ -415,14 +411,10 @@ void CheckTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t re
|
|||
if (ori_shape.empty()) {
|
||||
ori_shape.emplace_back(1);
|
||||
}
|
||||
auto shape = TbeJsonUtils::GetInputDeviceShapeForTbeBuild(anf_node, real_input_index);
|
||||
if (shape.empty()) {
|
||||
shape.emplace_back(1);
|
||||
}
|
||||
auto shape = ori_shape;
|
||||
|
||||
auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
|
||||
auto format = AnfAlgo::GetInputFormat(anf_node, real_input_index);
|
||||
format = TbeAdapter::FormatPass(format, ori_shape.size());
|
||||
auto format = def_format;
|
||||
(*input_desc)[kJDtype] = tbe::TypeIdToString(AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index));
|
||||
(*input_desc)[kJDataType] = GetJsonValue<std::string>(*input_desc, kJDtype);
|
||||
(*input_desc)[kJOriShape] = ori_shape;
|
||||
|
|
|
@ -78,7 +78,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_common) {
|
|||
EXPECT_TRUE(tbe_json_creator_select->GenJson(relu1, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 4297213426602035622U);
|
||||
EXPECT_TRUE(tbe_json_creator_check->GenJson(relu1, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 6011570351795510237U);
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 4297213426602035622U);
|
||||
EXPECT_TRUE(tbe_json_creator_build->GenJson(relu1, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 6011570351795510237U);
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_conv2d_backprop_filter) {
|
|||
EXPECT_TRUE(tbe_json_creator_select->GenJson(conv2d_backprop_filter, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 3804649253898608226U);
|
||||
EXPECT_TRUE(tbe_json_creator_check->GenJson(conv2d_backprop_filter, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 4580870880229487185U);
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 3804649253898608226U);
|
||||
EXPECT_TRUE(tbe_json_creator_build->GenJson(conv2d_backprop_filter, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 4580870880229487185U);
|
||||
}
|
||||
|
@ -179,7 +179,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_dynamic_rnn) {
|
|||
EXPECT_TRUE(tbe_json_creator_select->GenJson(dynamic_rnn, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 13058640182660031121U);
|
||||
EXPECT_TRUE(tbe_json_creator_check->GenJson(dynamic_rnn, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 4729701784171992376U);
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 13058640182660031121U);
|
||||
EXPECT_TRUE(tbe_json_creator_build->GenJson(dynamic_rnn, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 4729701784171992376U);
|
||||
}
|
||||
|
@ -233,7 +233,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_layer_norm) {
|
|||
EXPECT_TRUE(tbe_json_creator_select->GenJson(layer_norm, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 1114128635775386802U);
|
||||
EXPECT_TRUE(tbe_json_creator_check->GenJson(layer_norm, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 9247341733773157591U);
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 2175017419293202360U);
|
||||
EXPECT_TRUE(tbe_json_creator_build->GenJson(layer_norm, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 9247341733773157591U);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue