forked from mindspore-Ecosystem/mindspore
!25447 Check Device Shape
Merge pull request !25447 from hwjiaorui/check-support-json
This commit is contained in:
commit
036bc4329f
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
|
||||
#include "backend/kernel_compiler/tbe/tbe_json/single_tbe_json_creator.h"
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include "frontend/parallel/ops_info/ops_utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/tbe/tbe_adapter.h"
|
||||
|
@ -375,50 +377,4 @@ bool SelectTbeJsonCreator::AttrsJsonPostProcessing(const AnfNodePtr &anf_node, c
|
|||
tbe::TbeAdapter::LayerNormAttrJsonPost(anf_node, attrs_json);
|
||||
return true;
|
||||
}
|
||||
|
||||
void CheckTbeJsonCreator::GenDescJson(const AnfNodePtr &anf_node, size_t node_out_idx, size_t desc_output_idx,
|
||||
nlohmann::json *output_desc) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
GenDesJsonCommon(output_desc);
|
||||
std::vector<int64_t> shape;
|
||||
std::vector<int64_t> ori_shape;
|
||||
ori_shape = TbeJsonUtils::GetOutputOriShapeForTbeBuild(anf_node, node_out_idx);
|
||||
if (ori_shape.empty()) {
|
||||
ori_shape.emplace_back(1);
|
||||
}
|
||||
shape = ori_shape;
|
||||
auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
|
||||
auto format = def_format;
|
||||
|
||||
(*output_desc)[kJDataType] = tbe::TypeIdToString(AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx));
|
||||
(*output_desc)[kJDtype] = GetJsonValue<std::string>(*output_desc, kJDataType);
|
||||
(*output_desc)[kJFormat] = format;
|
||||
(*output_desc)[kJOriFormat] = def_format;
|
||||
(*output_desc)[kJOriShape] = ori_shape;
|
||||
(*output_desc)[kJShape] = shape;
|
||||
(*output_desc)[kJOutputIndex] = desc_output_idx;
|
||||
}
|
||||
|
||||
void CheckTbeJsonCreator::GenInputDescJson(const AnfNodePtr &anf_node, size_t real_input_index,
|
||||
nlohmann::json *input_desc) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
GenDesJsonCommon(input_desc);
|
||||
auto ori_shape = TbeJsonUtils::GetInputOriShapeForTbeBuild(anf_node, real_input_index);
|
||||
if (ori_shape.empty()) {
|
||||
ori_shape.emplace_back(1);
|
||||
}
|
||||
auto shape = ori_shape;
|
||||
|
||||
auto def_format = TbeJsonUtils::IsNeedChangeDefaultFormat(anf_node) ? kOpFormat_NCDHW : kOpFormat_NCHW;
|
||||
auto format = def_format;
|
||||
(*input_desc)[kJDtype] = tbe::TypeIdToString(AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index));
|
||||
(*input_desc)[kJDataType] = GetJsonValue<std::string>(*input_desc, kJDtype);
|
||||
(*input_desc)[kJOriShape] = ori_shape;
|
||||
(*input_desc)[kJOriFormat] = def_format;
|
||||
(*input_desc)[kJShape] = shape;
|
||||
(*input_desc)[kJFormat] = format;
|
||||
(*input_desc)[kJValid] = true;
|
||||
(*input_desc)[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index, format);
|
||||
GenInputConstValue(anf_node, real_input_index, input_desc);
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -46,11 +46,6 @@ class CheckTbeJsonCreator : public SingleTbeJsonCreator {
|
|||
public:
|
||||
CheckTbeJsonCreator() = default;
|
||||
~CheckTbeJsonCreator() override = default;
|
||||
|
||||
protected:
|
||||
void GenDescJson(const AnfNodePtr &anf_node, size_t node_out_idx, size_t desc_output_idx,
|
||||
nlohmann::json *output_desc) override;
|
||||
void GenInputDescJson(const AnfNodePtr &anf_node, size_t real_input_index, nlohmann::json *input_desc) override;
|
||||
};
|
||||
class SelectTbeJsonCreator : public SingleTbeJsonCreator {
|
||||
public:
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h"
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "base/base.h"
|
||||
#include "common/trans.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
|
||||
size_t real_input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t i = 0; i < real_input_num; i++) {
|
||||
session::KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
|
||||
auto format = AnfAlgo::GetInputFormat(node, i);
|
||||
if (!CheckValidOutputDeviceShape(kernel_with_index.first, kernel_with_index.second, format)) {
|
||||
MS_LOG(WARNING) << "TBE Host check input device shape failed, node:" << node->fullname_with_scope()
|
||||
<< ", input node: " << kernel_with_index.first->DebugString() << ", format:" << format;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
size_t real_output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
for (size_t i = 0; i < real_output_num; i++) {
|
||||
auto format = AnfAlgo::GetOutputFormat(node, i);
|
||||
if (!CheckValidOutputDeviceShape(node, i, format)) {
|
||||
MS_LOG(WARNING) << "TBE Host check output device shape failed, node:" << node->fullname_with_scope()
|
||||
<< ", format:" << format;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, const size_t output_idx,
|
||||
const std::string &format) {
|
||||
auto output_shape = AnfAlgo::GetOutputDetailShape(node, output_idx);
|
||||
std::vector<int64_t> infer_shape;
|
||||
if (output_shape->isa<abstract::Shape>()) {
|
||||
auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
infer_shape = shape_ptr->shape();
|
||||
}
|
||||
if (infer_shape.empty()) {
|
||||
return infer_shape;
|
||||
}
|
||||
|
||||
if (trans::IsNeedPadding(format, infer_shape.size())) {
|
||||
infer_shape = trans::PaddingShape(infer_shape, format, AnfAlgo::GetOutputReshapeType(node, output_idx));
|
||||
}
|
||||
|
||||
auto temp_shape = infer_shape;
|
||||
if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
|
||||
infer_shape.size() < kShape4dDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
|
||||
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
|
||||
temp_shape = trans::PaddingShapeTo4dDefault(infer_shape);
|
||||
}
|
||||
if (infer_shape.size() != trans::kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
|
||||
temp_shape = trans::PaddingShapeTo5dDefault(infer_shape);
|
||||
}
|
||||
return temp_shape;
|
||||
}
|
||||
|
||||
bool HostCheck::CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t output_idx,
|
||||
const std::string &format) {
|
||||
auto infer_shape = GetFinalInferShape(node, output_idx, format);
|
||||
|
||||
if (infer_shape.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::set<std::string> check_4D_format = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_FRAC_Z,
|
||||
kOpFormat_NC1HWC0, kOpFormat_C1HWNCoC0, kOpFormat_FRACTAL_Z_C04,
|
||||
kOpFormat_NC1HWC0_C04};
|
||||
std::set<std::string> check_5D_format = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
|
||||
if (check_4D_format.find(format) != check_4D_format.end()) {
|
||||
return infer_shape.size() == kShape4dDims;
|
||||
}
|
||||
if (check_5D_format.find(format) != check_5D_format.end()) {
|
||||
return infer_shape.size() == kShape5dDims;
|
||||
}
|
||||
|
||||
if (format == kOpFormat_FRAC_NZ) {
|
||||
return infer_shape.size() >= kShape2dDims ||
|
||||
(infer_shape.size() == 1 && (infer_shape[0] == 1 || (infer_shape[0] % SizeToLong(kCubeSize) == 0)));
|
||||
}
|
||||
|
||||
if (format == kOpFormat_FRACTAL_ZN_RNN) {
|
||||
return infer_shape.size() >= kShape2dDims;
|
||||
}
|
||||
|
||||
if (format == kOpFormat_ND_RNN_BIAS) {
|
||||
return infer_shape.size() > 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,6 +17,8 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_SELECT_COMMON_UTILS_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
struct SupportFormat {
|
||||
|
@ -24,6 +26,19 @@ struct SupportFormat {
|
|||
std::vector<std::vector<std::string>> output_format;
|
||||
};
|
||||
using SupportFormatItem = std::vector<std::string>;
|
||||
|
||||
class HostCheck {
|
||||
public:
|
||||
HostCheck() = default;
|
||||
~HostCheck() = default;
|
||||
static bool CheckValidDeviceShape(const AnfNodePtr &node);
|
||||
|
||||
private:
|
||||
static bool CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t output_idx, const std::string &format);
|
||||
static std::vector<int64_t> GetFinalInferShape(const AnfNodePtr &node, const size_t output_idx,
|
||||
const std::string &format);
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -269,11 +269,8 @@ bool TbeKernelSelect::TbeCheckSupported(const KernelBuildInfoIter &kernel_build_
|
|||
// replace kernel_info with current kernel info
|
||||
auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get());
|
||||
bool ret = true;
|
||||
auto &build_manager = kernel::ascend::TbeKernelCompileManager::GetInstance();
|
||||
if (!build_manager.TbeOpCheckSupported(cnode_ptr_)) {
|
||||
ret = false;
|
||||
}
|
||||
auto ret = HostCheck::CheckValidDeviceShape(cnode_ptr_) && build_manager.TbeOpCheckSupported(cnode_ptr_);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get());
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -78,7 +78,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_common) {
|
|||
EXPECT_TRUE(tbe_json_creator_select->GenJson(relu1, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 13146561810461380838U);
|
||||
EXPECT_TRUE(tbe_json_creator_check->GenJson(relu1, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 13146561810461380838U);
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 17413190217831512531U);
|
||||
EXPECT_TRUE(tbe_json_creator_build->GenJson(relu1, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 17413190217831512531U);
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_conv2d_backprop_filter) {
|
|||
EXPECT_TRUE(tbe_json_creator_select->GenJson(conv2d_backprop_filter, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 751738472046426254U);
|
||||
EXPECT_TRUE(tbe_json_creator_check->GenJson(conv2d_backprop_filter, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 751738472046426254U);
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 8516089404045447470U);
|
||||
EXPECT_TRUE(tbe_json_creator_build->GenJson(conv2d_backprop_filter, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 8516089404045447470U);
|
||||
}
|
||||
|
@ -179,7 +179,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_dynamic_rnn) {
|
|||
EXPECT_TRUE(tbe_json_creator_select->GenJson(dynamic_rnn, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 13288675099420394285U);
|
||||
EXPECT_TRUE(tbe_json_creator_check->GenJson(dynamic_rnn, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 13288675099420394285U);
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 17084598473306810717U);
|
||||
EXPECT_TRUE(tbe_json_creator_build->GenJson(dynamic_rnn, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 17084598473306810717U);
|
||||
}
|
||||
|
@ -233,7 +233,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_layer_norm) {
|
|||
EXPECT_TRUE(tbe_json_creator_select->GenJson(layer_norm, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_select->GetJsonHash(), 6545088373747371515U);
|
||||
EXPECT_TRUE(tbe_json_creator_check->GenJson(layer_norm, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 6866520754867840453U);
|
||||
EXPECT_EQ(tbe_json_creator_check->GetJsonHash(), 10583210293426000299U);
|
||||
EXPECT_TRUE(tbe_json_creator_build->GenJson(layer_norm, &kernel_json));
|
||||
EXPECT_EQ(tbe_json_creator_build->GetJsonHash(), 10583210293426000299U);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue