diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.cc
index 64d9262a133..04bc63d95ed 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,16 +26,27 @@ namespace mindspore {
 namespace kernel {
 namespace {
 constexpr size_t kNcdhwShapeSize = 5;
+
+bool CheckValidInputAndHiddenSize(const AnfNodePtr &node) {
+  if (node->isa<Parameter>()) {
+    auto param = node->cast<ParameterPtr>();
+    return param->input_size() > 0 && param->hidden_size() > 0;
+  }
+  if (node->isa<CNode>()) {
+    auto cnode = node->cast<CNodePtr>();
+    return common::AnfAlgo::HasNodeAttr(kAttrInputSize, cnode) && common::AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode);
+  }
+  return false;
+}
 }  // namespace
 
 bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
   size_t real_input_num = common::AnfAlgo::GetInputTensorNum(node);
   for (size_t i = 0; i < real_input_num; i++) {
-    session::KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i);
     auto format = AnfAlgo::GetInputFormat(node, i);
-    if (!CheckValidOutputDeviceShape(kernel_with_index.first, kernel_with_index.second, format)) {
+    if (!CheckValidInOutDeviceShape(node, i, false, format)) {
       MS_LOG(WARNING) << "TBE Host check input device shape failed, node:" << node->fullname_with_scope()
-                      << ", input node: " << kernel_with_index.first->DebugString() << ", format:" << format;
+                      << ", format:" << format;
       return false;
     }
   }
@@ -43,7 +54,7 @@ bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
   size_t real_output_num = common::AnfAlgo::GetOutputTensorNum(node);
   for (size_t i = 0; i < real_output_num; i++) {
     auto format = AnfAlgo::GetOutputFormat(node, i);
-    if (!CheckValidOutputDeviceShape(node, i, format)) {
+    if (!CheckValidInOutDeviceShape(node, i, true, format)) {
       MS_LOG(WARNING) << "TBE Host check output device shape failed, node:" << node->fullname_with_scope()
                       << ", format:" << format;
       return false;
@@ -52,12 +63,13 @@ bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
   return true;
 }
 
-std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, const size_t output_idx,
+std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
                                                    const std::string &format) {
-  auto output_shape = common::AnfAlgo::GetOutputDetailShape(node, output_idx);
+  auto shape = is_output ? common::AnfAlgo::GetOutputDetailShape(node, index)
+                         : common::AnfAlgo::GetPrevNodeOutputDetailShape(node, index);
   std::vector<int64_t> infer_shape;
-  if (output_shape->isa<abstract::Shape>()) {
-    auto shape_ptr = output_shape->cast<abstract::ShapePtr>();
+  if (shape->isa<abstract::Shape>()) {
+    auto shape_ptr = shape->cast<abstract::ShapePtr>();
     MS_EXCEPTION_IF_NULL(shape_ptr);
     infer_shape = shape_ptr->shape();
   }
@@ -66,7 +78,9 @@ std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, const
   }
 
   if (trans::IsNeedPadding(format, infer_shape.size())) {
-    infer_shape = trans::PaddingShape(infer_shape, format, AnfAlgo::GetOutputReshapeType(node, output_idx), node);
+    auto reshape_type =
+      is_output ? AnfAlgo::GetOutputReshapeType(node, index) : AnfAlgo::GetInputReshapeType(node, index);
+    infer_shape = trans::PaddingShape(infer_shape, format, reshape_type, node);
   }
 
   auto temp_shape = infer_shape;
@@ -81,9 +95,9 @@ std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, const
   return temp_shape;
 }
 
-bool HostCheck::CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t output_idx,
-                                            const std::string &format) {
-  auto infer_shape = GetFinalInferShape(node, output_idx, format);
+bool HostCheck::CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
+                                           const std::string &format) {
+  auto infer_shape = GetFinalInferShape(node, index, is_output, format);
   if (infer_shape.empty()) {
     return true;
   }
@@ -105,11 +119,11 @@ bool HostCheck::CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t
   }
 
   if (format == kOpFormat_FRACTAL_ZN_RNN) {
-    return infer_shape.size() >= kShape2dDims;
+    return infer_shape.size() >= kShape2dDims && CheckValidInputAndHiddenSize(node);
   }
 
   if (format == kOpFormat_ND_RNN_BIAS) {
-    return infer_shape.size() > 0;
+    return infer_shape.size() > 0 && CheckValidInputAndHiddenSize(node);
   }
   return true;
 }
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h
index 305f67de083..21588bf9470 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -35,8 +35,9 @@ class HostCheck {
   static bool CheckValidDeviceShape(const AnfNodePtr &node);
 
  private:
-  static bool CheckValidOutputDeviceShape(const AnfNodePtr &node, const size_t output_idx, const std::string &format);
-  static std::vector<int64_t> GetFinalInferShape(const AnfNodePtr &node, const size_t output_idx,
+  static bool CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
+                                         const std::string &format);
+  static std::vector<int64_t> GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
                                                  const std::string &format);
 };