!40327 add check for FRAC_ZN_RNN in CheckValidDeviceShape

Merge pull request !40327 from yuchaojie/op_select
This commit is contained in:
i-robot 2022-08-15 12:09:05 +00:00 committed by Gitee
commit 30dd25c647
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 33 additions and 18 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,16 +26,27 @@ namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kNcdhwShapeSize = 5;
bool CheckValidInputAndHiddenSize(const AnfNodePtr &node) {
if (node->isa<Parameter>()) {
auto param = node->cast<ParameterPtr>();
return param->input_size() > 0 && param->hidden_size() > 0;
}
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
return common::AnfAlgo::HasNodeAttr(kAttrInputSize, cnode) && common::AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode);
}
return false;
}
} // namespace
bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
size_t real_input_num = common::AnfAlgo::GetInputTensorNum(node);
for (size_t i = 0; i < real_input_num; i++) {
session::KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i);
auto format = AnfAlgo::GetInputFormat(node, i);
if (!CheckValidOutputDeviceShape(kernel_with_index.first, kernel_with_index.second, format)) {
if (!CheckValidInOutDeviceShape(node, i, false, format)) {
MS_LOG(WARNING) << "TBE Host check input device shape failed, node:" << node->fullname_with_scope()
<< ", input node: " << kernel_with_index.first->DebugString() << ", format:" << format;
<< ", format:" << format;
return false;
}
}
@ -43,7 +54,7 @@ bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
size_t real_output_num = common::AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < real_output_num; i++) {
auto format = AnfAlgo::GetOutputFormat(node, i);
if (!CheckValidOutputDeviceShape(node, i, format)) {
if (!CheckValidInOutDeviceShape(node, i, true, format)) {
MS_LOG(WARNING) << "TBE Host check output device shape failed, node:" << node->fullname_with_scope()
<< ", format:" << format;
return false;
@ -52,12 +63,13 @@ bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
return true;
}
std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, const size_t output_idx,
std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
const std::string &format) {
auto output_shape = common::AnfAlgo::GetOutputDetailShape(node, output_idx);
auto shape = is_output ? common::AnfAlgo::GetOutputDetailShape(node, index)
: common::AnfAlgo::GetPrevNodeOutputDetailShape(node, index);
std::vector<int64_t> infer_shape;
if (output_shape->isa<abstract::Shape>()) {
auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
if (shape->isa<abstract::Shape>()) {
auto shape_ptr = shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_ptr);
infer_shape = shape_ptr->shape();
}
@ -66,7 +78,9 @@ std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, const
}
if (trans::IsNeedPadding(format, infer_shape.size())) {
infer_shape = trans::PaddingShape(infer_shape, format, AnfAlgo::GetOutputReshapeType(node, output_idx), node);
auto reshape_type =
is_output ? AnfAlgo::GetOutputReshapeType(node, index) : AnfAlgo::GetInputReshapeType(node, index);
infer_shape = trans::PaddingShape(infer_shape, format, reshape_type, node);
}
auto temp_shape = infer_shape;
@ -81,9 +95,9 @@ std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, const
return temp_shape;
}
bool HostCheck::CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t output_idx,
const std::string &format) {
auto infer_shape = GetFinalInferShape(node, output_idx, format);
bool HostCheck::CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
const std::string &format) {
auto infer_shape = GetFinalInferShape(node, index, is_output, format);
if (infer_shape.empty()) {
return true;
}
@ -105,11 +119,11 @@ bool HostCheck::CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t
}
if (format == kOpFormat_FRACTAL_ZN_RNN) {
return infer_shape.size() >= kShape2dDims;
return infer_shape.size() >= kShape2dDims && CheckValidInputAndHiddenSize(node);
}
if (format == kOpFormat_ND_RNN_BIAS) {
return infer_shape.size() > 0;
return infer_shape.size() > 0 && CheckValidInputAndHiddenSize(node);
}
return true;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -35,8 +35,9 @@ class HostCheck {
static bool CheckValidDeviceShape(const AnfNodePtr &node);
private:
static bool CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t output_idx, const std::string &format);
static std::vector<int64_t> GetFinalInferShape(const AnfNodePtr &node, const size_t output_idx,
static bool CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
const std::string &format);
static std::vector<int64_t> GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
const std::string &format);
};