!40327 add check for FRAC_ZN_RNN in CheckValidDeviceShape
Merge pull request !40327 from yuchaojie/op_select
This commit is contained in:
commit
30dd25c647
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -26,16 +26,27 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kNcdhwShapeSize = 5;
|
||||
|
||||
bool CheckValidInputAndHiddenSize(const AnfNodePtr &node) {
|
||||
if (node->isa<Parameter>()) {
|
||||
auto param = node->cast<ParameterPtr>();
|
||||
return param->input_size() > 0 && param->hidden_size() > 0;
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
return common::AnfAlgo::HasNodeAttr(kAttrInputSize, cnode) && common::AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
|
||||
size_t real_input_num = common::AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t i = 0; i < real_input_num; i++) {
|
||||
session::KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i);
|
||||
auto format = AnfAlgo::GetInputFormat(node, i);
|
||||
if (!CheckValidOutputDeviceShape(kernel_with_index.first, kernel_with_index.second, format)) {
|
||||
if (!CheckValidInOutDeviceShape(node, i, false, format)) {
|
||||
MS_LOG(WARNING) << "TBE Host check input device shape failed, node:" << node->fullname_with_scope()
|
||||
<< ", input node: " << kernel_with_index.first->DebugString() << ", format:" << format;
|
||||
<< ", format:" << format;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -43,7 +54,7 @@ bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
|
|||
size_t real_output_num = common::AnfAlgo::GetOutputTensorNum(node);
|
||||
for (size_t i = 0; i < real_output_num; i++) {
|
||||
auto format = AnfAlgo::GetOutputFormat(node, i);
|
||||
if (!CheckValidOutputDeviceShape(node, i, format)) {
|
||||
if (!CheckValidInOutDeviceShape(node, i, true, format)) {
|
||||
MS_LOG(WARNING) << "TBE Host check output device shape failed, node:" << node->fullname_with_scope()
|
||||
<< ", format:" << format;
|
||||
return false;
|
||||
|
@ -52,12 +63,13 @@ bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
|
|||
return true;
|
||||
}
|
||||
|
||||
std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, const size_t output_idx,
|
||||
std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
|
||||
const std::string &format) {
|
||||
auto output_shape = common::AnfAlgo::GetOutputDetailShape(node, output_idx);
|
||||
auto shape = is_output ? common::AnfAlgo::GetOutputDetailShape(node, index)
|
||||
: common::AnfAlgo::GetPrevNodeOutputDetailShape(node, index);
|
||||
std::vector<int64_t> infer_shape;
|
||||
if (output_shape->isa<abstract::Shape>()) {
|
||||
auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
|
||||
if (shape->isa<abstract::Shape>()) {
|
||||
auto shape_ptr = shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
infer_shape = shape_ptr->shape();
|
||||
}
|
||||
|
@ -66,7 +78,9 @@ std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, const
|
|||
}
|
||||
|
||||
if (trans::IsNeedPadding(format, infer_shape.size())) {
|
||||
infer_shape = trans::PaddingShape(infer_shape, format, AnfAlgo::GetOutputReshapeType(node, output_idx), node);
|
||||
auto reshape_type =
|
||||
is_output ? AnfAlgo::GetOutputReshapeType(node, index) : AnfAlgo::GetInputReshapeType(node, index);
|
||||
infer_shape = trans::PaddingShape(infer_shape, format, reshape_type, node);
|
||||
}
|
||||
|
||||
auto temp_shape = infer_shape;
|
||||
|
@ -81,9 +95,9 @@ std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, const
|
|||
return temp_shape;
|
||||
}
|
||||
|
||||
bool HostCheck::CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t output_idx,
|
||||
bool HostCheck::CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
|
||||
const std::string &format) {
|
||||
auto infer_shape = GetFinalInferShape(node, output_idx, format);
|
||||
auto infer_shape = GetFinalInferShape(node, index, is_output, format);
|
||||
if (infer_shape.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
@ -105,11 +119,11 @@ bool HostCheck::CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t
|
|||
}
|
||||
|
||||
if (format == kOpFormat_FRACTAL_ZN_RNN) {
|
||||
return infer_shape.size() >= kShape2dDims;
|
||||
return infer_shape.size() >= kShape2dDims && CheckValidInputAndHiddenSize(node);
|
||||
}
|
||||
|
||||
if (format == kOpFormat_ND_RNN_BIAS) {
|
||||
return infer_shape.size() > 0;
|
||||
return infer_shape.size() > 0 && CheckValidInputAndHiddenSize(node);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -35,8 +35,9 @@ class HostCheck {
|
|||
static bool CheckValidDeviceShape(const AnfNodePtr &node);
|
||||
|
||||
private:
|
||||
static bool CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t output_idx, const std::string &format);
|
||||
static std::vector<int64_t> GetFinalInferShape(const AnfNodePtr &node, const size_t output_idx,
|
||||
static bool CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
|
||||
const std::string &format);
|
||||
static std::vector<int64_t> GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
|
||||
const std::string &format);
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue