fallback ops before AI_CPU kernel select

lby 2023-01-28 17:43:38 +08:00
parent 6b91d6a2a1
commit 22a80730d9
7 changed files with 95 additions and 6 deletions

View File

@@ -415,7 +415,7 @@
"TransData ": "support boll",
"ScatterNdD ": "Accuracy issues",
"Trace": "Hadn't adapted tbe implementation",
"AssignAdd" : "Frac_nz in pangu not support"
"AssignAdd": "Frac_nz in pangu not support"
},
"SkipNodes": [
"BroadcastTo",
@@ -445,5 +445,11 @@
"TransData",
"ScatterNdD",
"AssignAdd"
]
],
"FallbackOps": {
"DeformableOffsets": [
1,
2
]
}
}
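
The new FallbackOps table maps an op name to the CNode input indices that should survive AI_CPU kernel selection. Input 0 of a CNode is the primitive itself, so {1, 2} keeps x and offsets for DeformableOffsets and drops any trailing assist input appended by an earlier pass. A minimal standalone sketch of how this shape parses with nlohmann::json (the same library the loader below uses; the literal here is illustrative, not the shipped config):

#include <iostream>
#include <map>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

int main() {
  // Same shape as the "FallbackOps" section above: op name -> input indices to keep.
  const auto js = nlohmann::json::parse(R"({"DeformableOffsets": [1, 2]})");
  std::map<std::string, std::vector<size_t>> fallback_ops;
  for (auto iter = js.begin(); iter != js.end(); ++iter) {
    fallback_ops[iter.key()] = iter->get<std::vector<size_t>>();
  }
  for (const auto &[op_name, indices] : fallback_ops) {
    std::cout << op_name << " keeps " << indices.size() << " inputs" << std::endl;
  }
  return 0;
}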

View File

@@ -26,6 +26,7 @@ constexpr auto kAttrDefaultValue = "AttrDefaultValue";
constexpr auto kNodeName = "NodeName";
constexpr auto kInputOrders = "InputOrders";
constexpr auto kSkipNodes = "SkipNodes";
constexpr auto kFallbackOps = "FallbackOps";
constexpr auto kSkipDynamicCompileStatic = "SkipDynamicCompileStatic";
bool SuperBar::LoadSBConfig(const nlohmann::json &js) {
if (!LoadSBNodeAttr(js)) {
@@ -43,6 +44,9 @@ bool SuperBar::LoadSBConfig(const nlohmann::json &js) {
if (!LoadSBSkipNodes(js)) {
return false;
}
if (!LoadSBFallbackOps(js)) {
return false;
}
return true;
}
@@ -102,6 +106,14 @@ bool SuperBar::IsSkipNode(const std::string &op_name) {
return (std::find(skip_nodes_.begin(), skip_nodes_.end(), op_name) != skip_nodes_.end());
}
std::vector<size_t> SuperBar::GetSBFallbackOpIndex(const std::string &op_name) {
auto iter = fallback_ops_.find(op_name);
if (iter == fallback_ops_.end()) {
return {};
}
return iter->second;
}
bool SuperBar::IsSkipDynamicCompileStaticNode(const std::string &op_name) {
return (std::find(skip_dynamic_compile_static_nodes_.begin(), skip_dynamic_compile_static_nodes_.end(), op_name) !=
skip_dynamic_compile_static_nodes_.end());
@@ -166,6 +178,21 @@ bool SuperBar::LoadSBNodeInput(const nlohmann::json &js) {
return true;
}
bool SuperBar::LoadSBFallbackOps(const nlohmann::json &js) {
// Some ops, such as "DeformableOffsets", need their assist input removed before AI_CPU kernel selection.
auto js_iter = js.find(kFallbackOps);
if (js_iter == js.end()) {
MS_LOG(ERROR) << "Find fallback node failed, json: " << js.dump();
return false;
}
const auto &fallback_nodes = js_iter->get<nlohmann::json>();
for (auto iter = fallback_nodes.begin(); iter != fallback_nodes.end(); ++iter) {
const auto &node_name = iter.key();
fallback_ops_[node_name] = iter->get<std::vector<size_t>>();
}
return true;
}
bool SuperBar::LoadSBSkipNodes(const nlohmann::json &js) {
if (js.find(kSkipNodes) == js.end()) {
MS_LOG(ERROR) << "Find skip node failed, json: " << js.dump();

View File

@@ -36,12 +36,14 @@ class BACKEND_EXPORT SuperBar {
static std::optional<std::map<size_t, size_t>> GetGraphIdxToKernelIdx(const std::string &op_name);
static bool IsSkipNode(const std::string &op_name);
static bool IsSkipDynamicCompileStaticNode(const std::string &op_name);
static std::vector<size_t> GetSBFallbackOpIndex(const std::string &op_name);
private:
static bool LoadSBNodeAttr(const nlohmann::json &js);
static bool LoadSBNodeAttrDefaultValue(const nlohmann::json &js);
static bool LoadSBNodeInput(const nlohmann::json &js);
static bool LoadSBSkipNodes(const nlohmann::json &js);
static bool LoadSBFallbackOps(const nlohmann::json &js);
static bool LoadSBSkipDynamicCompileStaticNode(const nlohmann::json &js);
inline static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> node_input_order_ =
{};
@@ -49,6 +51,7 @@ class BACKEND_EXPORT SuperBar {
inline static std::map<std::string, std::map<std::string, std::string>> node_attr_ms_to_kernel_;
inline static std::map<std::string, std::map<std::string, std::string>> node_attr_default_value_map_ = {};
inline static std::vector<std::string> skip_nodes_;
inline static std::map<std::string, std::vector<size_t>> fallback_ops_;
inline static std::vector<std::string> skip_dynamic_compile_static_nodes_;
};
} // namespace mindspore::kernel

View File

@@ -27,6 +27,7 @@
#include <algorithm>
#include "plugin/device/ascend/kernel/kernel_query.h"
#include "kernel/oplib/oplib.h"
#include "kernel/oplib/super_bar.h"
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
#include "plugin/device/ascend/kernel/aicpu/aicpu_attr_to_input_registry.h"
#include "plugin/device/ascend/kernel/aicpu/aicpu_input_to_attr_registry.h"
@@ -81,6 +82,25 @@ mindspore::HashSet<std::string> kHighPrecisionOp = {kConv2DOpName,
kBiasAddGradOpName,
kSigmoidCrossEntropyWithLogitsV2OpName};
// Rebuild the node's input list from the SuperBar FallbackOps indices, dropping assist
// inputs that the AI_CPU kernel does not expect.
void FallbackOps(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto op_name = common::AnfAlgo::GetCNodeName(kernel_node);
auto inputs = kernel_node->inputs();
const auto &fallback_idx = kernel::SuperBar::GetSBFallbackOpIndex(op_name);
if (fallback_idx.empty() || inputs.empty()) {
return;
}
AnfNodePtrList new_inputs = {inputs[0]};
for (const auto &idx : fallback_idx) {
if (idx >= inputs.size()) {
MS_LOG(EXCEPTION) << "Invalid idx: " << idx << ", node: " << kernel_node->fullname_with_scope()
<< ", total input size: " << inputs.size();
}
(void)new_inputs.emplace_back(inputs[idx]);
}
kernel_node->set_inputs(new_inputs);
}
bool MatchUnfoldInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfoPtr &kernel_build_info) {
MS_EXCEPTION_IF_NULL(cnode);
// Check input data type
@@ -1221,6 +1241,7 @@ std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithM
ConvertConstInputToAttr(kernel_node, input_to_attr_info);
}
// Drop assist inputs configured in SuperBar before querying AI_CPU kernels.
FallbackOps(kernel_node);
kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list);
select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list);
common::AnfAlgo::SetNodeAttr(kAttrIsAiCpuKernel, MakeValue(true), kernel_node);
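
To make the pruning concrete: for a DeformableOffsets node whose inputs are [primitive, x, offsets, assist], the configured indices {1, 2} rebuild the list as [primitive, x, offsets] before AICPUQuery runs. A self-contained sketch of the same index logic, with strings standing in for AnfNodePtr (names hypothetical):

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Input 0 is the primitive; the trailing assist input is what fallback removes.
  std::vector<std::string> inputs = {"Primitive(DeformableOffsets)", "x", "offsets", "assist"};
  const std::vector<size_t> fallback_idx = {1, 2};  // from the FallbackOps config

  std::vector<std::string> new_inputs = {inputs[0]};  // the primitive is always kept
  for (const auto idx : fallback_idx) {
    if (idx >= inputs.size()) {  // same bounds check as FallbackOps above
      std::cerr << "Invalid idx: " << idx << std::endl;
      return 1;
    }
    (void)new_inputs.emplace_back(inputs[idx]);
  }
  for (const auto &input : new_inputs) {
    std::cout << input << std::endl;  // prints: Primitive(DeformableOffsets), x, offsets
  }
  return 0;
}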

View File

@@ -893,7 +893,7 @@ bool FormatTransfer::TransDataForwardCore(const FormatArgs &args, void *result,
return false;
}
if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
return NCHW_TO_FRAC_Z_WITH_GROPUS(args, result, true, groups);
return NCHW_TO_FRAC_Z_WITH_GROUPS(args, result, true, groups);
}
auto iter = format_trans_fp_map.find(args.device_format);
if (iter == format_trans_fp_map.end()) {
@@ -1392,7 +1392,7 @@ bool FormatTransfer::NCDHW_TO_FRAC_Z3D(const FormatArgs &args, void *result) {
return true;
}
bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROPUS(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROUPS(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
MS_EXCEPTION_IF_NULL(result);
auto size = Common4DCheck(args);
auto n_dim = args.host_shape[kN];
@@ -1765,7 +1765,7 @@ bool FormatTransfer::NDC1HWC0_TO_NCDHW(const FormatArgs &args, void *result) {
bool FormatTransfer::FRAC_Z_TO_NCHW_WITH_GROUPS(const FormatArgs &args, void *result, int64_t groups) {
MS_LOG(DEBUG) << "Trans format from frac_z to nchw with groups=" << groups;
return NCHW_TO_FRAC_Z_WITH_GROPUS(args, result, false, groups);
return NCHW_TO_FRAC_Z_WITH_GROUPS(args, result, false, groups);
}
int64_t FormatTransfer::Common4DCheck(const FormatArgs &args) {

View File

@@ -241,7 +241,7 @@ class FormatTransfer {
static bool NCHW_TO_C1HWNCOC0(const FormatArgs &args, void *result);
static bool NCDHW_TO_NDC1HWC0(const FormatArgs &args, void *result);
static bool NCDHW_TO_FRAC_Z3D(const FormatArgs &args, void *result);
static bool NCHW_TO_FRAC_Z_WITH_GROPUS(const FormatArgs &args, void *result, bool to_device, int64_t groups);
static bool NCHW_TO_FRAC_Z_WITH_GROUPS(const FormatArgs &args, void *result, bool to_device, int64_t groups);
// DEVICE TO HOST
static bool TO_NCHW(const FormatArgs &args, void *result);

View File

@@ -0,0 +1,32 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DeformableOffsets op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
deformable_offsets_op_info = AiCPURegOp("DeformableOffsets") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.input(1, "offsets", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F16_NHWC, DataType.F16_NHWC, DataType.F16_NHWC) \
.dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
.get_op_info()
@op_info_register(deformable_offsets_op_info)
def _deformable_offsets_aicpu():
"""DeformableOffsets AiCPU register"""
return
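
Note how this registration lines up with the rest of the commit: only x and offsets are declared as inputs, so the CNode matches this AiCPU signature exactly once FallbackOps has applied the {1, 2} indices from the config, which is presumably why the op info and the fallback entry land in the same change.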