diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.cc index 8455a8c5166..504d0e2c311 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.cc @@ -15,6 +15,8 @@ */ #include "backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.h" +#include +#include #include "frontend/parallel/ops_info/ops_utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/tbe/tbe_adapter.h" @@ -375,50 +377,4 @@ bool SelectTbeJsonCreator::AttrsJsonPostProcessing(const AnfNodePtr &anf_node, c tbe::TbeAdapter::LayerNormAttrJsonPost(anf_node, attrs_json); return true; } - -void CheckTbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t node_out_idx, size_t desc_output_idx, - nlohmann::json *output_desc) { - MS_EXCEPTION_IF_NULL(anf_node); - GenDesJsonCommon(output_desc); - std::vector shape; - std::vector ori_shape; - ori_shape = TbeJsonUtils::GetOutputOriShapeForTbeBuild(anf_node, node_out_idx); - if (ori_shape.empty()) { - ori_shape.emplace_back(1); - } - shape = ori_shape; - auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW; - auto format = def_format; - - (*output_desc)[kJDataType] = tbe::TypeIdToString(AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx)); - (*output_desc)[kJDtype] = GetJsonValue(*output_desc, kJDataType); - (*output_desc)[kJFormat] = format; - (*output_desc)[kJOriFormat] = def_format; - (*output_desc)[kJOriShape] = ori_shape; - (*output_desc)[kJShape] = shape; - (*output_desc)[kJOutputIndex] = desc_output_idx; -} - -void CheckTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t real_input_index, - nlohmann::json *input_desc) { - MS_EXCEPTION_IF_NULL(anf_node); - GenDesJsonCommon(input_desc); - auto ori_shape = TbeJsonUtils::GetInputOriShapeForTbeBuild(anf_node, real_input_index); - if (ori_shape.empty()) { - ori_shape.emplace_back(1); - } - auto shape = ori_shape; - - auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW; - auto format = def_format; - (*input_desc)[kJDtype] = tbe::TypeIdToString(AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index)); - (*input_desc)[kJDataType] = GetJsonValue(*input_desc, kJDtype); - (*input_desc)[kJOriShape] = ori_shape; - (*input_desc)[kJOriFormat] = def_format; - (*input_desc)[kJShape] = shape; - (*input_desc)[kJFormat] = format; - (*input_desc)[kJValid] = true; - (*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format); - GenInputConstValue(anf_node, real_input_index, input_desc); -} } // namespace mindspore::kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.h index 89190981a9f..aea3df6bd4d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.h @@ -46,11 +46,6 @@ class CheckTbeJsonCreator : public SingleTbeJsonCreator { public: CheckTbeJsonCreator() = default; ~CheckTbeJsonCreator() override = default; - - protected: - void GenDescJson(const AnfNodePtr &anf_node, size_t node_out_idx, size_t desc_output_idx, - nlohmann::json *output_desc) override; - void GenInputDescJson(const AnfNodePtr &anf_node, size_t real_input_index, nlohmann::json *input_desc) override; }; class SelectTbeJsonCreator : public SingleTbeJsonCreator { public: diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.cc new file mode 100644 index 00000000000..47b9f27c28f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" +#include +#include +#include "base/base.h" +#include "common/trans.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace kernel { +bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) { + size_t real_input_num = AnfAlgo::GetInputTensorNum(node); + for (size_t i = 0; i < real_input_num; i++) { + session::KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, i); + auto format = AnfAlgo::GetInputFormat(node, i); + if (!CheckValidOutputDeviceShape(kernel_with_index.first, kernel_with_index.second, format)) { + MS_LOG(WARNING) << "TBE Host check input device shape failed, node:" << node->fullname_with_scope() + << ", input node: " << kernel_with_index.first->DebugString() << ", format:" << format; + return false; + } + } + + size_t real_output_num = AnfAlgo::GetOutputTensorNum(node); + for (size_t i = 0; i < real_output_num; i++) { + auto format = AnfAlgo::GetOutputFormat(node, i); + if (!CheckValidOutputDeviceShape(node, i, format)) { + MS_LOG(WARNING) << "TBE Host check output device shape failed, node:" << node->fullname_with_scope() + << ", format:" << format; + return false; + } + } + return true; +} + +std::vector HostCheck::GetFinalInferShape(const AnfNodePtr &node, const size_t output_idx, + const std::string &format) { + auto output_shape = AnfAlgo::GetOutputDetailShape(node, output_idx); + std::vector infer_shape; + if (output_shape->isa()) { + auto shape_ptr = output_shape->cast(); + MS_EXCEPTION_IF_NULL(shape_ptr); + infer_shape = shape_ptr->shape(); + } + if (infer_shape.empty()) { + return infer_shape; + } + + if (trans::IsNeedPadding(format, infer_shape.size())) { + infer_shape = trans::PaddingShape(infer_shape, format, AnfAlgo::GetOutputReshapeType(node, output_idx)); + } + + auto temp_shape = infer_shape; + if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM && + infer_shape.size() < kShape4dDims && k3DFormatSet.find(format) == k3DFormatSet.end()) { + MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; + temp_shape = trans::PaddingShapeTo4dDefault(infer_shape); + } + if (infer_shape.size() != trans::kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) { + temp_shape = trans::PaddingShapeTo5dDefault(infer_shape); + } + return temp_shape; +} + +bool HostCheck::CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t output_idx, + const std::string &format) { + auto infer_shape = GetFinalInferShape(node, output_idx, format); + + if (infer_shape.empty()) { + return true; + } + + std::set check_4D_format = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_FRAC_Z, + kOpFormat_NC1HWC0, kOpFormat_C1HWNCoC0, kOpFormat_FRACTAL_Z_C04, + kOpFormat_NC1HWC0_C04}; + std::set check_5D_format = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; + if (check_4D_format.find(format) != check_4D_format.end()) { + return infer_shape.size() == kShape4dDims; + } + if (check_5D_format.find(format) != check_5D_format.end()) { + return infer_shape.size() == kShape5dDims; + } + + if (format == kOpFormat_FRAC_NZ) { + return infer_shape.size() >= kShape2dDims || + (infer_shape.size() == 1 && (infer_shape[0] == 1 || (infer_shape[0] % SizeToLong(kCubeSize) == 0))); + } + + if (format == kOpFormat_FRACTAL_ZN_RNN) { + return infer_shape.size() >= kShape2dDims; + } + + if (format == kOpFormat_ND_RNN_BIAS) { + return infer_shape.size() > 0; + } + return true; +} + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h index b0002a70333..e3ba23d0405 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_SELECT_COMMON_UTILS_H_ #include #include +#include "backend/session/anf_runtime_algorithm.h" + namespace mindspore { namespace kernel { struct SupportFormat { @@ -24,6 +26,19 @@ struct SupportFormat { std::vector> output_format; }; using SupportFormatItem = std::vector; + +class HostCheck { + public: + HostCheck() = default; + ~HostCheck() = default; + static bool CheckValidDeviceShape(const AnfNodePtr &node); + + private: + static bool CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t output_idx, const std::string &format); + static std::vector GetFinalInferShape(const AnfNodePtr &node, const size_t output_idx, + const std::string &format); +}; + } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc index 1e2b043c253..de4acba52cb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc @@ -269,11 +269,8 @@ bool TbeKernelSelect::TbeCheckSupported(const KernelBuildInfoIter &kernel_build_ // replace kernel_info with current kernel info auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_); AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get()); - bool ret = true; auto &build_manager = kernel::ascend::TbeKernelCompileManager::GetInstance(); - if (!build_manager.TbeOpCheckSupported(cnode_ptr_)) { - ret = false; - } + auto ret = HostCheck::CheckValidDeviceShape(cnode_ptr_) && build_manager.TbeOpCheckSupported(cnode_ptr_); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get()); return ret; } diff --git a/tests/ut/cpp/tbe/tbe_json_creator_test.cc b/tests/ut/cpp/tbe/tbe_json_creator_test.cc index 5c697aff80f..f821b8287c6 100644 --- a/tests/ut/cpp/tbe/tbe_json_creator_test.cc +++ b/tests/ut/cpp/tbe/tbe_json_creator_test.cc @@ -78,7 +78,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_common) { EXPECT_TRUE(tbe_json_creator_select->GenJson(relu1, &kernel_json)); EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 13146561810461380838U); EXPECT_TRUE(tbe_json_creator_check->GenJson(relu1, &kernel_json)); - EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 13146561810461380838U); + EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 17413190217831512531U); EXPECT_TRUE(tbe_json_creator_build->GenJson(relu1, &kernel_json)); EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 17413190217831512531U); } @@ -121,7 +121,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_conv2d_backprop_filter) { EXPECT_TRUE(tbe_json_creator_select->GenJson(conv2d_backprop_filter, &kernel_json)); EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 751738472046426254U); EXPECT_TRUE(tbe_json_creator_check->GenJson(conv2d_backprop_filter, &kernel_json)); - EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 751738472046426254U); + EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 8516089404045447470U); EXPECT_TRUE(tbe_json_creator_build->GenJson(conv2d_backprop_filter, &kernel_json)); EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 8516089404045447470U); } @@ -179,7 +179,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_dynamic_rnn) { EXPECT_TRUE(tbe_json_creator_select->GenJson(dynamic_rnn, &kernel_json)); EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 13288675099420394285U); EXPECT_TRUE(tbe_json_creator_check->GenJson(dynamic_rnn, &kernel_json)); - EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 13288675099420394285U); + EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 17084598473306810717U); EXPECT_TRUE(tbe_json_creator_build->GenJson(dynamic_rnn, &kernel_json)); EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 17084598473306810717U); } @@ -233,7 +233,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_layer_norm) { EXPECT_TRUE(tbe_json_creator_select->GenJson(layer_norm, &kernel_json)); EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 6545088373747371515U); EXPECT_TRUE(tbe_json_creator_check->GenJson(layer_norm, &kernel_json)); - EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 6866520754867840453U); + EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 10583210293426000299U); EXPECT_TRUE(tbe_json_creator_build->GenJson(layer_norm, &kernel_json)); EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 10583210293426000299U); }