fix bug of filter kernel info

This commit is contained in:
William Lian 2020-10-19 19:48:56 +08:00
parent b6715eb790
commit 70c2920615
1 changed files with 12 additions and 10 deletions

View File

@ -15,23 +15,24 @@
*/
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h"
#include <memory>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
#include "nlohmann/json.hpp"
#include "backend/optimizer/common/helper.h"
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_build_client.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "nlohmann/json.hpp"
namespace mindspore {
namespace kernel {
@ -276,12 +277,13 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const
MS_LOG(INFO) << "Warning: Server not support format with C04 suffix.";
return false;
}
if (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) {
return true;
}
// not support format:
// 1 NDHWC with shape size != 5
// 2 FRAC_NZ with shape size < 2
// 3 !NDHWC with shape size > 4
if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) ||
(format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) ||
(format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) {
MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size();
return false;