forked from mindspore-Ecosystem/mindspore
fallback ops before ai cpu kernel select
This commit is contained in:
parent 6b91d6a2a1
commit 22a80730d9
@@ -415,7 +415,7 @@
     "TransData": "support bool",
     "ScatterNdD": "Accuracy issues",
     "Trace": "Hadn't adapted tbe implementation",
-    "AssignAdd" : "Frac_nz in pangu not support"
+    "AssignAdd": "Frac_nz in pangu not support"
   },
   "SkipNodes": [
     "BroadcastTo",
@@ -445,5 +445,11 @@
     "TransData",
     "ScatterNdD",
     "AssignAdd"
-  ]
+  ],
+  "FallbackOps": {
+    "DeformableOffsets": [
+      1,
+      2
+    ]
+  }
 }
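The new "FallbackOps" section maps an op name to the input indices that must be kept when the op falls back to AI_CPU; any input not listed (for DeformableOffsets, the TBE assist input) is dropped before kernel selection. A minimal sketch of that lookup, using a hypothetical standalone table rather than the real `SuperBar` class:

```cpp
// Hypothetical standalone mock of the SuperBar fallback table; the real class
// loads this map from the JSON config above.
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <vector>

std::map<std::string, std::vector<size_t>> fallback_ops = {{"DeformableOffsets", {1, 2}}};

// Mirrors SuperBar::GetSBFallbackOpIndex: an empty result means "leave the node alone".
std::vector<size_t> GetFallbackOpIndex(const std::string &op_name) {
  auto iter = fallback_ops.find(op_name);
  return iter == fallback_ops.end() ? std::vector<size_t>{} : iter->second;
}

int main() {
  for (size_t idx : GetFallbackOpIndex("DeformableOffsets")) {
    std::cout << idx << '\n';  // prints 1 then 2: keep inputs 1 and 2, drop the rest
  }
  return 0;
}
```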
@@ -26,6 +26,7 @@ constexpr auto kAttrDefaultValue = "AttrDefaultValue";
 constexpr auto kNodeName = "NodeName";
 constexpr auto kInputOrders = "InputOrders";
 constexpr auto kSkipNodes = "SkipNodes";
+constexpr auto kFallbackOps = "FallbackOps";
 constexpr auto kSkipDynamicCompileStatic = "SkipDynamicCompileStatic";
 bool SuperBar::LoadSBConfig(const nlohmann::json &js) {
   if (!LoadSBNodeAttr(js)) {
@@ -43,6 +44,9 @@ bool SuperBar::LoadSBConfig(const nlohmann::json &js) {
   if (!LoadSBSkipNodes(js)) {
     return false;
   }
+  if (!LoadSBFallbackOps(js)) {
+    return false;
+  }
   return true;
 }
 
@@ -102,6 +106,14 @@ bool SuperBar::IsSkipNode(const std::string &op_name) {
   return (std::find(skip_nodes_.begin(), skip_nodes_.end(), op_name) != skip_nodes_.end());
 }
 
+std::vector<size_t> SuperBar::GetSBFallbackOpIndex(const std::string &op_name) {
+  auto iter = fallback_ops_.find(op_name);
+  if (iter == fallback_ops_.end()) {
+    return {};
+  }
+  return iter->second;
+}
+
 bool SuperBar::IsSkipDynamicCompileStaticNode(const std::string &op_name) {
   return (std::find(skip_dynamic_compile_static_nodes_.begin(), skip_dynamic_compile_static_nodes_.end(), op_name) !=
           skip_dynamic_compile_static_nodes_.end());
@@ -166,6 +178,21 @@ bool SuperBar::LoadSBNodeInput(const nlohmann::json &js) {
   return true;
 }
 
+bool SuperBar::LoadSBFallbackOps(const nlohmann::json &js) {
+  // Some ops, like "DeformableOffsets", need their assist input deleted before AI_CPU kernel select.
+  auto js_iter = js.find(kFallbackOps);
+  if (js_iter == js.end()) {
+    MS_LOG(ERROR) << "Find fallback node failed, json: " << js.dump();
+    return false;
+  }
+  const auto &fallback_nodes = js_iter->get<nlohmann::json>();
+  for (auto iter = fallback_nodes.begin(); iter != fallback_nodes.end(); ++iter) {
+    const auto &node_name = iter.key();
+    fallback_ops_[node_name] = iter->get<std::vector<size_t>>();
+  }
+  return true;
+}
+
 bool SuperBar::LoadSBSkipNodes(const nlohmann::json &js) {
   if (js.find(kSkipNodes) == js.end()) {
     MS_LOG(ERROR) << "Find skip node failed, json: " << js.dump();
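`LoadSBFallbackOps` leans on nlohmann::json's object iteration: `iter.key()` yields the op name and `iter->get<std::vector<size_t>>()` converts the JSON array into the index list. A self-contained sketch of the same parse, assuming nlohmann/json is available (variable names are illustrative, not the real members):

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

int main() {
  // Same shape as the "FallbackOps" section added to the config above.
  const auto js = nlohmann::json::parse(R"({"FallbackOps": {"DeformableOffsets": [1, 2]}})");
  std::map<std::string, std::vector<size_t>> fallback_ops;
  auto js_iter = js.find("FallbackOps");
  if (js_iter == js.end()) {
    std::cerr << "Find fallback node failed\n";
    return 1;
  }
  // Object iteration: key is the op name, value is the JSON index array.
  for (auto iter = js_iter->begin(); iter != js_iter->end(); ++iter) {
    fallback_ops[iter.key()] = iter->get<std::vector<size_t>>();
  }
  std::cout << fallback_ops["DeformableOffsets"].size() << '\n';  // 2
  return 0;
}
```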
@@ -36,12 +36,14 @@ class BACKEND_EXPORT SuperBar {
   static std::optional<std::map<size_t, size_t>> GetGraphIdxToKernelIdx(const std::string &op_name);
   static bool IsSkipNode(const std::string &op_name);
   static bool IsSkipDynamicCompileStaticNode(const std::string &op_name);
+  static std::vector<size_t> GetSBFallbackOpIndex(const std::string &op_name);
 
  private:
   static bool LoadSBNodeAttr(const nlohmann::json &js);
   static bool LoadSBNodeAttrDefaultValue(const nlohmann::json &js);
   static bool LoadSBNodeInput(const nlohmann::json &js);
   static bool LoadSBSkipNodes(const nlohmann::json &js);
+  static bool LoadSBFallbackOps(const nlohmann::json &js);
   static bool LoadSBSkipDynamicCompileStaticNode(const nlohmann::json &js);
   inline static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> node_input_order_ =
     {};
@@ -49,6 +51,7 @@ class BACKEND_EXPORT SuperBar {
   inline static std::map<std::string, std::map<std::string, std::string>> node_attr_ms_to_kernel_;
   inline static std::map<std::string, std::map<std::string, std::string>> node_attr_default_value_map_ = {};
   inline static std::vector<std::string> skip_nodes_;
+  inline static std::map<std::string, std::vector<size_t>> fallback_ops_;
   inline static std::vector<std::string> skip_dynamic_compile_static_nodes_;
 };
 }  // namespace mindspore::kernel
@@ -27,6 +27,7 @@
 #include <algorithm>
 #include "plugin/device/ascend/kernel/kernel_query.h"
 #include "kernel/oplib/oplib.h"
+#include "kernel/oplib/super_bar.h"
 #include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
 #include "plugin/device/ascend/kernel/aicpu/aicpu_attr_to_input_registry.h"
 #include "plugin/device/ascend/kernel/aicpu/aicpu_input_to_attr_registry.h"
@@ -81,6 +82,25 @@ mindspore::HashSet<std::string> kHighPrecisionOp = {kConv2DOpName,
                                                     kBiasAddGradOpName,
                                                     kSigmoidCrossEntropyWithLogitsV2OpName};
 
+void FallbackOps(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  auto op_name = common::AnfAlgo::GetCNodeName(kernel_node);
+  auto inputs = kernel_node->inputs();
+  const auto &fallback_idx = kernel::SuperBar::GetSBFallbackOpIndex(op_name);
+  if (fallback_idx.empty() || inputs.empty()) {
+    return;
+  }
+  AnfNodePtrList new_inputs = {inputs[0]};
+  for (const auto &idx : fallback_idx) {
+    if (idx >= inputs.size()) {
+      MS_LOG(EXCEPTION) << "Invalid idx: " << idx << ", node: " << kernel_node->fullname_with_scope()
+                        << ", total input size: " << inputs.size();
+    }
+    (void)new_inputs.emplace_back(inputs[idx]);
+  }
+  kernel_node->set_inputs(new_inputs);
+}
+
 bool MatchUnfoldInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfoPtr &kernel_build_info) {
   MS_EXCEPTION_IF_NULL(cnode);
   // Check input data type
@@ -1221,6 +1241,7 @@ std::tuple<KernelSelectStatus, std::string, ExceptionType> SelectKernelInfoWithMsg(
     ConvertConstInputToAttr(kernel_node, input_to_attr_info);
   }
 
+  FallbackOps(kernel_node);
   kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list);
   select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list);
   common::AnfAlgo::SetNodeAttr(kAttrIsAiCpuKernel, MakeValue(true), kernel_node);
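The call site matters: `FallbackOps` runs after `ConvertConstInputToAttr` and immediately before `kernel::AICPUQuery`, so the AI_CPU candidate matching already sees the reduced input list. Index 0 of a CNode's input list is the primitive, so for DeformableOffsets the rebuild keeps `{inputs[0], inputs[1], inputs[2]}` and drops the trailing assist input. A toy sketch with plain strings standing in for AnfNode pointers (the input layout shown is an assumption for illustration):

```cpp
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Toy stand-in for the CNode input rebuild in FallbackOps: keep inputs[0]
// (the primitive) plus the configured indices, drop everything else.
std::vector<std::string> Rebuild(const std::vector<std::string> &inputs,
                                 const std::vector<size_t> &fallback_idx) {
  if (fallback_idx.empty() || inputs.empty()) {
    return inputs;  // no config entry: leave the node untouched
  }
  std::vector<std::string> new_inputs = {inputs[0]};
  for (size_t idx : fallback_idx) {
    if (idx >= inputs.size()) {
      throw std::out_of_range("Invalid idx");
    }
    new_inputs.emplace_back(inputs[idx]);
  }
  return new_inputs;
}

int main() {
  // Assumed lowering for TBE: primitive, x, offsets, assist helper input.
  const std::vector<std::string> inputs = {"prim::DeformableOffsets", "x", "offsets", "assist"};
  for (const auto &in : Rebuild(inputs, {1, 2})) {
    std::cout << in << '\n';  // prim::DeformableOffsets, x, offsets
  }
  return 0;
}
```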
@@ -893,7 +893,7 @@ bool FormatTransfer::TransDataForwardCore(const FormatArgs &args, void *result,
     return false;
   }
   if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
-    return NCHW_TO_FRAC_Z_WITH_GROPUS(args, result, true, groups);
+    return NCHW_TO_FRAC_Z_WITH_GROUPS(args, result, true, groups);
   }
   auto iter = format_trans_fp_map.find(args.device_format);
   if (iter == format_trans_fp_map.end()) {
@@ -1392,7 +1392,7 @@ bool FormatTransfer::NCDHW_TO_FRAC_Z3D(const FormatArgs &args, void *result) {
   return true;
 }
 
-bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROPUS(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
+bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROUPS(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
   MS_EXCEPTION_IF_NULL(result);
   auto size = Common4DCheck(args);
   auto n_dim = args.host_shape[kN];
@@ -1765,7 +1765,7 @@ bool FormatTransfer::NDC1HWC0_TO_NCDHW(const FormatArgs &args, void *result) {
 
 bool FormatTransfer::FRAC_Z_TO_NCHW_WITH_GROUPS(const FormatArgs &args, void *result, int64_t groups) {
   MS_LOG(DEBUG) << "Trans format from frac_z to nchw with groups=" << groups;
-  return NCHW_TO_FRAC_Z_WITH_GROPUS(args, result, false, groups);
+  return NCHW_TO_FRAC_Z_WITH_GROUPS(args, result, false, groups);
 }
 
 int64_t FormatTransfer::Common4DCheck(const FormatArgs &args) {
@@ -241,7 +241,7 @@ class FormatTransfer {
   static bool NCHW_TO_C1HWNCOC0(const FormatArgs &args, void *result);
   static bool NCDHW_TO_NDC1HWC0(const FormatArgs &args, void *result);
   static bool NCDHW_TO_FRAC_Z3D(const FormatArgs &args, void *result);
-  static bool NCHW_TO_FRAC_Z_WITH_GROPUS(const FormatArgs &args, void *result, bool to_device, int64_t groups);
+  static bool NCHW_TO_FRAC_Z_WITH_GROUPS(const FormatArgs &args, void *result, bool to_device, int64_t groups);
 
   // DEVICE TO HOST
   static bool TO_NCHW(const FormatArgs &args, void *result);
@@ -0,0 +1,32 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""DeformableOffsets op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+deformable_offsets_op_info = AiCPURegOp("DeformableOffsets") \
+    .fusion_type("OPAQUE") \
+    .input(0, "x", "required") \
+    .input(1, "offsets", "required") \
+    .output(0, "y", "required") \
+    .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC, DataType.F16_NHWC) \
+    .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \
+    .get_op_info()
+
+
+@op_info_register(deformable_offsets_op_info)
+def _deformable_offsets_aicpu():
+    """DeformableOffsets AiCPU register"""
+    return