!25447 Check Device Shape

Merge pull request !25447 from hwjiaorui/check-support-json
This commit is contained in:
i-robot 2021-10-28 12:38:06 +00:00 committed by Gitee
commit 036bc4329f
6 changed files with 137 additions and 60 deletions

View File

@ -15,6 +15,8 @@
*/
#include "backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.h"
#include <algorithm>
#include <string>
#include "frontend/parallel/ops_info/ops_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/tbe/tbe_adapter.h"
@ -375,50 +377,4 @@ bool SelectTbeJsonCreator::AttrsJsonPostProcessing(const AnfNodePtr &anf_node, c
tbe::TbeAdapter::LayerNormAttrJsonPost(anf_node, attrs_json);
return true;
}
void CheckTbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t node_out_idx, size_t desc_output_idx,
nlohmann::json *output_desc) {
MS_EXCEPTION_IF_NULL(anf_node);
GenDesJsonCommon(output_desc);
std::vector<int64_t> shape;
std::vector<int64_t> ori_shape;
ori_shape = TbeJsonUtils::GetOutputOriShapeForTbeBuild(anf_node, node_out_idx);
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
}
shape = ori_shape;
auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
auto format = def_format;
(*output_desc)[kJDataType] = tbe::TypeIdToString(AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx));
(*output_desc)[kJDtype] = GetJsonValue<std::string>(*output_desc, kJDataType);
(*output_desc)[kJFormat] = format;
(*output_desc)[kJOriFormat] = def_format;
(*output_desc)[kJOriShape] = ori_shape;
(*output_desc)[kJShape] = shape;
(*output_desc)[kJOutputIndex] = desc_output_idx;
}
void CheckTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t real_input_index,
nlohmann::json *input_desc) {
MS_EXCEPTION_IF_NULL(anf_node);
GenDesJsonCommon(input_desc);
auto ori_shape = TbeJsonUtils::GetInputOriShapeForTbeBuild(anf_node, real_input_index);
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
}
auto shape = ori_shape;
auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
auto format = def_format;
(*input_desc)[kJDtype] = tbe::TypeIdToString(AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index));
(*input_desc)[kJDataType] = GetJsonValue<std::string>(*input_desc, kJDtype);
(*input_desc)[kJOriShape] = ori_shape;
(*input_desc)[kJOriFormat] = def_format;
(*input_desc)[kJShape] = shape;
(*input_desc)[kJFormat] = format;
(*input_desc)[kJValid] = true;
(*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format);
GenInputConstValue(anf_node, real_input_index, input_desc);
}
} // namespace mindspore::kernel

View File

@ -46,11 +46,6 @@ class CheckTbeJsonCreator : public SingleTbeJsonCreator {
public:
CheckTbeJsonCreator() = default;
~CheckTbeJsonCreator() override = default;
protected:
void GenDescJson(const AnfNodePtr &anf_node, size_t node_out_idx, size_t desc_output_idx,
nlohmann::json *output_desc) override;
void GenInputDescJson(const AnfNodePtr &anf_node, size_t real_input_index, nlohmann::json *input_desc) override;
};
class SelectTbeJsonCreator : public SingleTbeJsonCreator {
public:

View File

@ -0,0 +1,114 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h"
#include <set>
#include <string>
#include "base/base.h"
#include "common/trans.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
size_t real_input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t i = 0; i < real_input_num; i++) {
session::KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
auto format = AnfAlgo::GetInputFormat(node, i);
if (!CheckValidOutputDeviceShape(kernel_with_index.first, kernel_with_index.second, format)) {
MS_LOG(WARNING) << "TBE Host check input device shape failed, node:" << node->fullname_with_scope()
<< ", input node: " << kernel_with_index.first->DebugString() << ", format:" << format;
return false;
}
}
size_t real_output_num = AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < real_output_num; i++) {
auto format = AnfAlgo::GetOutputFormat(node, i);
if (!CheckValidOutputDeviceShape(node, i, format)) {
MS_LOG(WARNING) << "TBE Host check output device shape failed, node:" << node->fullname_with_scope()
<< ", format:" << format;
return false;
}
}
return true;
}
std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, const size_t output_idx,
const std::string &format) {
auto output_shape = AnfAlgo::GetOutputDetailShape(node, output_idx);
std::vector<int64_t> infer_shape;
if (output_shape->isa<abstract::Shape>()) {
auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_ptr);
infer_shape = shape_ptr->shape();
}
if (infer_shape.empty()) {
return infer_shape;
}
if (trans::IsNeedPadding(format, infer_shape.size())) {
infer_shape = trans::PaddingShape(infer_shape, format, AnfAlgo::GetOutputReshapeType(node, output_idx));
}
auto temp_shape = infer_shape;
if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
infer_shape.size() < kShape4dDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
temp_shape = trans::PaddingShapeTo4dDefault(infer_shape);
}
if (infer_shape.size() != trans::kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
temp_shape = trans::PaddingShapeTo5dDefault(infer_shape);
}
return temp_shape;
}
bool HostCheck::CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t output_idx,
const std::string &format) {
auto infer_shape = GetFinalInferShape(node, output_idx, format);
if (infer_shape.empty()) {
return true;
}
std::set<std::string> check_4D_format = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_FRAC_Z,
kOpFormat_NC1HWC0, kOpFormat_C1HWNCoC0, kOpFormat_FRACTAL_Z_C04,
kOpFormat_NC1HWC0_C04};
std::set<std::string> check_5D_format = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
if (check_4D_format.find(format) != check_4D_format.end()) {
return infer_shape.size() == kShape4dDims;
}
if (check_5D_format.find(format) != check_5D_format.end()) {
return infer_shape.size() == kShape5dDims;
}
if (format == kOpFormat_FRAC_NZ) {
return infer_shape.size() >= kShape2dDims ||
(infer_shape.size() == 1 && (infer_shape[0] == 1 || (infer_shape[0] % SizeToLong(kCubeSize) == 0)));
}
if (format == kOpFormat_FRACTAL_ZN_RNN) {
return infer_shape.size() >= kShape2dDims;
}
if (format == kOpFormat_ND_RNN_BIAS) {
return infer_shape.size() > 0;
}
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,8 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_SELECT_COMMON_UTILS_H_
#include <string>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
struct SupportFormat {
@ -24,6 +26,19 @@ struct SupportFormat {
std::vector<std::vector<std::string>> output_format;
};
using SupportFormatItem = std::vector<std::string>;
class HostCheck {
public:
HostCheck() = default;
~HostCheck() = default;
static bool CheckValidDeviceShape(const AnfNodePtr &node);
private:
static bool CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t output_idx, const std::string &format);
static std::vector<int64_t> GetFinalInferShape(const AnfNodePtr &node, const size_t output_idx,
const std::string &format);
};
} // namespace kernel
} // namespace mindspore

View File

@ -269,11 +269,8 @@ bool TbeKernelSelect::TbeCheckSupported(const KernelBuildInfoIter &kernel_build_
// replace kernel_info with current kernel info
auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_);
AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get());
bool ret = true;
auto &build_manager = kernel::ascend::TbeKernelCompileManager::GetInstance();
if (!build_manager.TbeOpCheckSupported(cnode_ptr_)) {
ret = false;
}
auto ret = HostCheck::CheckValidDeviceShape(cnode_ptr_) && build_manager.TbeOpCheckSupported(cnode_ptr_);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get());
return ret;
}

View File

@ -78,7 +78,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_common) {
EXPECT_TRUE(tbe_json_creator_select->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 13146561810461380838U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 13146561810461380838U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 17413190217831512531U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(relu1, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 17413190217831512531U);
}
@ -121,7 +121,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_conv2d_backprop_filter) {
EXPECT_TRUE(tbe_json_creator_select->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 751738472046426254U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 751738472046426254U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 8516089404045447470U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(conv2d_backprop_filter, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 8516089404045447470U);
}
@ -179,7 +179,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_dynamic_rnn) {
EXPECT_TRUE(tbe_json_creator_select->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 13288675099420394285U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 13288675099420394285U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 17084598473306810717U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(dynamic_rnn, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 17084598473306810717U);
}
@ -233,7 +233,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_layer_norm) {
EXPECT_TRUE(tbe_json_creator_select->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 6545088373747371515U);
EXPECT_TRUE(tbe_json_creator_check->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 6866520754867840453U);
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 10583210293426000299U);
EXPECT_TRUE(tbe_json_creator_build->GenJson(layer_norm, &kernel_json));
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 10583210293426000299U);
}