forked from mindspore-Ecosystem/mindspore
!1079 Convert AiCpu kernel when aicore not supported
Merge pull request !1079 from lianliguang/convert-to-AICPU-when-AiCore-not-supported-kernel
This commit is contained in:
commit
86d4797399
|
@ -85,7 +85,7 @@ const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberType
|
|||
} while (0)
|
||||
|
||||
template <typename T>
|
||||
T Ceil(T n1, T n2) {
|
||||
T DivCeil(T n1, T n2) {
|
||||
return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0;
|
||||
}
|
||||
|
||||
|
@ -371,15 +371,48 @@ std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
|
|||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
size_t c0 = 4;
|
||||
size_t first_dim = DivCeil(c0 * shape[2] * shape[3], kCubeSize);
|
||||
size_t no = DivCeil(DivCeil(shape[0], kCubeSize) * kCubeSize, kCubeSize);
|
||||
device_shape.push_back(first_dim);
|
||||
device_shape.push_back(no);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
size_t C1 = 1;
|
||||
size_t C0 = 4;
|
||||
device_shape.push_back(shape[0]);
|
||||
device_shape.push_back(C1);
|
||||
device_shape.push_back(shape[2]);
|
||||
device_shape.push_back(shape[3]);
|
||||
device_shape.push_back(C0);
|
||||
return device_shape;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
|
||||
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
|
||||
const std::map<std::string, DeviceShapeTransfer> device_shape_map{
|
||||
{kOpFormat_NCHW, NchwDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape},
|
||||
{kOpFormat_HWCN, HwchDeviceShape}, {kOpFormat_FRAC_Z, FracZDeviceShape},
|
||||
{kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
|
||||
};
|
||||
const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
|
||||
{kOpFormat_NHWC, NhwcDeviceShape},
|
||||
{kOpFormat_HWCN, HwchDeviceShape},
|
||||
{kOpFormat_FRAC_Z, FracZDeviceShape},
|
||||
{kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
|
||||
{kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
|
||||
{kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
|
||||
{kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}};
|
||||
|
||||
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
|
||||
return shape;
|
||||
|
@ -506,13 +539,13 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
|
|||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
size_t c1 = Ceil(c, c0);
|
||||
size_t c1 = DivCeil(c, c0);
|
||||
size_t hw = h * w;
|
||||
size_t chw = c * hw;
|
||||
size_t hwc0 = hw * c0;
|
||||
size_t nchw = n * chw;
|
||||
|
||||
size_t hf_cnt = Ceil(n, kCubeSize);
|
||||
size_t hf_cnt = DivCeil(n, kCubeSize);
|
||||
size_t vf_cnt = c1 * hw;
|
||||
size_t fractal_ele_cnt = c0 * kCubeSize;
|
||||
size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
|
||||
|
@ -775,7 +808,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
|
|||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
size_t c1 = Ceil(c, c0);
|
||||
size_t c1 = DivCeil(c, c0);
|
||||
size_t hw = h * w;
|
||||
size_t chw = c * hw;
|
||||
size_t c1hwc0 = c1 * hw * c0;
|
||||
|
|
|
@ -34,6 +34,7 @@ namespace ascend {
|
|||
namespace {
|
||||
const float kWegihtBaseScore = 1;
|
||||
const float kFeatureMapBaseScore = 10;
|
||||
constexpr auto kPriChoosenFormat = "pri_format";
|
||||
enum MatchCountPriority : int {
|
||||
MATCH_COUNT_PRIORITY_BEGIN = 0,
|
||||
MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
|
||||
|
@ -85,6 +86,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
|
|||
if (need_change_nd) {
|
||||
priority_matched_format = kOpFormat_DEFAULT;
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
|
||||
return priority_matched_format;
|
||||
}
|
||||
/**
|
||||
|
@ -394,9 +396,9 @@ void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode,
|
|||
std::ostringstream buffer;
|
||||
buffer << cnode->DebugString();
|
||||
if (precision_reduce) {
|
||||
buffer << " reduce precision, node datatype: ";
|
||||
buffer << " reduce precision, node datatype: \n";
|
||||
} else {
|
||||
buffer << " raise precision, node datatype: ";
|
||||
buffer << " raise precision, node datatype: \n";
|
||||
}
|
||||
PrintInputAndOutputInferType(buffer, cnode);
|
||||
buffer << ", select kernel:" << selected_kernel_build_info->ToString();
|
||||
|
@ -464,66 +466,57 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
|
|||
}
|
||||
} // namespace
|
||||
|
||||
std::shared_ptr<kernel::KernelBuildInfo> CanHitKernelInfo(
|
||||
int *status, const CNodePtr &kernel_node,
|
||||
KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
|
||||
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
KernelSelectStatus select_status = kNoMatched;
|
||||
bool precision_reduce = false;
|
||||
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
|
||||
// Matched kernel info
|
||||
// Filter kernel info matched with me infered type
|
||||
auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list);
|
||||
if (!filtered_kernel_info_list.empty()) {
|
||||
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
|
||||
select_status = kStatusAllMatched;
|
||||
} else {
|
||||
// selected kernel info using raised precision or reduce precision
|
||||
filtered_kernel_info_list =
|
||||
FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce);
|
||||
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
|
||||
if (selected_kernel_info == nullptr) {
|
||||
return nullptr;
|
||||
return select_status;
|
||||
} else {
|
||||
PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
|
||||
*status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
|
||||
select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
|
||||
}
|
||||
}
|
||||
return selected_kernel_info;
|
||||
// Set kernel info to the anfnode
|
||||
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
|
||||
// Set format and data type for input tensor.
|
||||
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
|
||||
return select_status;
|
||||
}
|
||||
|
||||
int SelectKernelInfo(const CNodePtr &kernel_node) {
|
||||
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
int status = kStatusAllMatched;
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel::KernelQuery(kernel_node, &kernel_info_list);
|
||||
// filter kernel info matched with me infered type
|
||||
auto selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list);
|
||||
if (selected_kernel_info == nullptr) {
|
||||
auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
|
||||
// If aicore not find valid kernel info reloading aicpu kernel info list to find it
|
||||
if (select_status == kNoMatched) {
|
||||
MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
|
||||
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
|
||||
kernel::AicpuQuery(kernel_node, &kernel_info_list);
|
||||
selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list);
|
||||
kernel::AICpuQuery(kernel_node, &kernel_info_list);
|
||||
select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
|
||||
}
|
||||
if (selected_kernel_info == nullptr) {
|
||||
// The kernel info not finded both in the aicpu kernel list & aicore kernel list
|
||||
if (select_status == kNoMatched) {
|
||||
std::ostringstream buffer;
|
||||
PrintInputAndOutputInferType(buffer, kernel_node);
|
||||
MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
|
||||
<< "] cannot find valid kernel info, not supported the type " << buffer.str();
|
||||
}
|
||||
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
|
||||
// Set format and data type for input tensor.
|
||||
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
|
||||
return status;
|
||||
}
|
||||
|
||||
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
|
||||
const kernel::KernelBuildInfoPtr &new_kernel_build_info) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
kernel::KernelQuery(kernel_node, &kernel_info_list);
|
||||
auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(),
|
||||
[&new_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
return *item == *new_kernel_build_info;
|
||||
});
|
||||
return result != kernel_info_list.end();
|
||||
return select_status;
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
|
|
|
@ -21,8 +21,13 @@
|
|||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
int SelectKernelInfo(const CNodePtr &kernel_node);
|
||||
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, const kernel::KernelBuildInfoPtr &new_kernel_build_info);
|
||||
enum KernelSelectStatus {
|
||||
kNoMatched = -1,
|
||||
kStatusAllMatched = 0,
|
||||
kStatusReducePrecision = 1,
|
||||
kStatusRaisePrecision = 2,
|
||||
};
|
||||
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node);
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,7 +35,7 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
|
|||
std::vector<std::string> input_format, output_format;
|
||||
std::vector<TypeId> input_type, output_type;
|
||||
for (const auto &data_type : data_type_list) {
|
||||
for (const auto &format : k4DSupportFormat) {
|
||||
for (const auto &format : kOpFormatList) {
|
||||
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
input_format.clear();
|
||||
input_format.push_back(format);
|
||||
|
|
|
@ -35,14 +35,18 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
|
|||
return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
|
||||
AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
|
||||
});
|
||||
kernel_info_list->clear();
|
||||
if (!filtered_list.empty()) {
|
||||
kernel_info_list->clear();
|
||||
(void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "node" << kernel_node->DebugString() << "'s output size : ["
|
||||
MS_LOG(WARNING) << "All kernel Info list does not match any kernel info ";
|
||||
for (size_t index; index < kernel_info_list->size(); ++index) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_info_list->at(index));
|
||||
MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString();
|
||||
}
|
||||
MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : ["
|
||||
<< AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
|
||||
<< "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node)
|
||||
<< "] cannot match any kernelInfo !";
|
||||
<< "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !";
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
@ -50,7 +54,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
|
|||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||
TbeMetadataInfo(kernel_node, kernel_info_list);
|
||||
|
||||
if (kernel_info_list->empty()) {
|
||||
AicpuMetadataInfo(kernel_node, kernel_info_list);
|
||||
}
|
||||
|
@ -68,12 +71,41 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
|
|||
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
|
||||
}
|
||||
|
||||
void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
|
||||
void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info_list);
|
||||
kernel_info_list->clear();
|
||||
AicpuMetadataInfo(kernel_node, kernel_info_list);
|
||||
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
|
||||
}
|
||||
bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(select_kernel_build_info);
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
auto cnode = kernel_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
AicpuMetadataInfo(cnode, &kernel_info_list);
|
||||
FilterInvalidKernelInfo(cnode, &kernel_info_list);
|
||||
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
|
||||
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
return *item == *select_kernel_build_info;
|
||||
});
|
||||
}
|
||||
|
||||
bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(select_kernel_build_info);
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
auto cnode = kernel_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
TbeMetadataInfo(cnode, &kernel_info_list);
|
||||
FilterInvalidKernelInfo(cnode, &kernel_info_list);
|
||||
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
|
||||
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
return *item == *select_kernel_build_info;
|
||||
});
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,7 +26,9 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
|
||||
void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
|
||||
void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
|
||||
bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
|
||||
bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_
|
||||
|
|
|
@ -551,11 +551,6 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
|
|||
}
|
||||
|
||||
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
|
||||
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
|
||||
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
|
||||
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
|
||||
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
|
||||
|
||||
// if format is default, it remarkes support all format
|
||||
if (kOpFormatList.find(format) == kOpFormatList.end()) {
|
||||
MS_LOG(EXCEPTION) << "Got the unknown format " << format;
|
||||
|
|
|
@ -54,6 +54,7 @@
|
|||
#include "pre_activate/pass/optimize_dependence.h"
|
||||
#include "pre_activate/pass/erase_visit_attr.h"
|
||||
#include "pre_activate/ascend/format_type/insert_cast.h"
|
||||
#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
|
||||
#include "pre_activate/pass/eliminate_redundant_op.h"
|
||||
#include "pre_activate/pass/common_subexpression_elimination.h"
|
||||
#include "pre_activate/ascend/format_type/merge_cast_to_op.h"
|
||||
|
@ -172,6 +173,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
|
|||
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
|
||||
optimizer->AddPassManager(mixed_precision_pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
|
|
|
@ -268,6 +268,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr
|
|||
}
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
|
||||
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
|
||||
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
|
||||
return cast;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,10 +30,6 @@ class KernelSelect {
|
|||
KernelSelect() = default;
|
||||
virtual ~KernelSelect() = default;
|
||||
virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); }
|
||||
virtual bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
|
||||
const kernel::KernelBuildInfoPtr &new_kernel_build_info) {
|
||||
return device::ascend::CheckKernelAccuracySupported(kernel_node, new_kernel_build_info);
|
||||
}
|
||||
};
|
||||
using KernelSelectPtr = std::shared_ptr<KernelSelect>;
|
||||
|
||||
|
@ -41,8 +37,13 @@ class SupportedChecker {
|
|||
public:
|
||||
SupportedChecker() = default;
|
||||
virtual ~SupportedChecker() = default;
|
||||
virtual bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
|
||||
return kernel::CheckSupported(anf_node, select_kernel_build_info);
|
||||
virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
|
||||
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
|
||||
return kernel::IsSupportedByAiCore(anf_node, select_kernel_build_info);
|
||||
}
|
||||
virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node,
|
||||
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
|
||||
return kernel::IsSupportedByAiCpu(anf_node, select_kernel_build_info);
|
||||
}
|
||||
};
|
||||
using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>;
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
|
||||
#include <memory>
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
#include "kernel/kernel_query.h"
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef ConvertUnSupportNodeToAICPU::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
return VectorRef({X, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraphPtr &,
|
||||
const mindspore::AnfNodePtr &node,
|
||||
const mindspore::EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto node_name = AnfAlgo::GetCNodeName(node);
|
||||
if (node_name != prim::KPrimTransData->name() || node_name != prim::kPrimCast->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
|
||||
if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) {
|
||||
return node;
|
||||
} else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) {
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info);
|
||||
builder->SetKernelType(AICPU_KERNEL);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node ["
|
||||
<< node->DebugString() << "]";
|
||||
}
|
||||
return node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
#include "pre_activate/ascend/ascend_helper.h"
|
||||
#ifndef MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
|
||||
#define MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ConvertUnSupportNodeToAICPU : public PatternProcessPass {
|
||||
public:
|
||||
explicit ConvertUnSupportNodeToAICPU(bool multigraph = true)
|
||||
: PatternProcessPass("convert_unsupported_node_to_aicpu", multigraph),
|
||||
supported_checker_(std::make_shared<SupportedChecker>()) {}
|
||||
~ConvertUnSupportNodeToAICPU() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
SupportedCheckerPtr supported_checker_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_
|
||||
#define MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_
|
||||
#include <string>
|
||||
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
|
@ -32,4 +32,4 @@ class RunOpInsertCast : public PatternProcessPass {
|
|||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_
|
||||
#define MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
@ -41,4 +41,4 @@ class RunOpInsertTransData : public PatternProcessPass {
|
|||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_
|
||||
|
|
|
@ -128,7 +128,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
|
|||
auto indices_const = CreateValueNode(new_cnode);
|
||||
new_cnode->add_input(indices_const);
|
||||
MS_EXCEPTION_IF_NULL(supported_checker_);
|
||||
if (!supported_checker_->CheckSupported(new_cnode, CreateKernelBuildInfo())) {
|
||||
if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap
|
|||
new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor());
|
||||
|
||||
auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName);
|
||||
if (kernel_select_->CheckKernelAccuracySupported(transdata_cnode, new_transdata_builder->Build())) {
|
||||
if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) {
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata),
|
||||
utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
|
||||
auto new_node = func_graph->NewCNode(inputs);
|
||||
|
|
|
@ -34,7 +34,7 @@ class TransposeTransDataFusion : public PatternProcessPass {
|
|||
explicit TransposeTransDataFusion(bool multigraph = true)
|
||||
: PatternProcessPass("transpose_transdata_fusion", multigraph) {
|
||||
input_varptr_ = std::make_shared<Var>();
|
||||
kernel_select_ = std::make_shared<KernelSelect>();
|
||||
supported_checker_ = std::make_shared<SupportedChecker>();
|
||||
}
|
||||
~TransposeTransDataFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
|
@ -42,7 +42,9 @@ class TransposeTransDataFusion : public PatternProcessPass {
|
|||
|
||||
private:
|
||||
VarPtr input_varptr_;
|
||||
KernelSelectPtr kernel_select_;
|
||||
|
||||
private:
|
||||
SupportedCheckerPtr supported_checker_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -329,9 +329,9 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
|
|||
size_t reduce_precision_count = 0;
|
||||
for (const auto &cnode : kernel_graph.execution_order()) {
|
||||
auto status = device::ascend::SelectKernelInfo(cnode);
|
||||
if (status == kStatusRaisePrecision) {
|
||||
if (status == device::ascend::kStatusRaisePrecision) {
|
||||
raise_precision_count++;
|
||||
} else if (status == kStatusReducePrecision) {
|
||||
} else if (status == device::ascend::kStatusReducePrecision) {
|
||||
reduce_precision_count++;
|
||||
}
|
||||
MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString();
|
||||
|
|
|
@ -27,6 +27,8 @@
|
|||
namespace mindspore {
|
||||
namespace session {
|
||||
namespace {
|
||||
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
|
||||
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
|
||||
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(que);
|
||||
|
@ -180,11 +182,24 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
|||
cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
|
||||
// create kernel_info from new parameter
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
std::vector<size_t> feature_map_input_indexs;
|
||||
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
|
||||
// then the node's output is a feature map output
|
||||
if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(),
|
||||
[&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) {
|
||||
for (size_t index = 1; index < inputs.size(); ++index) {
|
||||
auto node = inputs[index];
|
||||
if (AnfAlgo::IsFeatureMapOutput(node)) {
|
||||
feature_map_input_indexs.push_back(index);
|
||||
}
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
|
||||
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
|
||||
}
|
||||
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
|
||||
kernel_info->SetFeatureMapFlag(true);
|
||||
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(true), cnode);
|
||||
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
|
||||
} else {
|
||||
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(false), cnode);
|
||||
}
|
||||
cnode->set_kernel_info(kernel_info);
|
||||
AnfAlgo::SetGraphId(graph_id_, cnode.get());
|
||||
|
|
|
@ -142,6 +142,7 @@ constexpr auto kLabelGotoOpName = "LabelGoto";
|
|||
|
||||
// attr key name
|
||||
constexpr auto kAttrInputNames = "input_names";
|
||||
constexpr auto kIsBackendCast = "is_backed_cast";
|
||||
constexpr auto kAttrOutputNames = "output_names";
|
||||
constexpr auto kAttrVisited = "visited";
|
||||
constexpr auto kAttrShape = "shape";
|
||||
|
@ -201,10 +202,6 @@ constexpr auto kControlDependBehindIndex = 2;
|
|||
// index define of depend
|
||||
constexpr auto kRealInputIndexInDepend = 1;
|
||||
constexpr auto kDependAttachNodeIndex = 2;
|
||||
// status of kernel select result
|
||||
const int kStatusReducePrecision = -1;
|
||||
const int kStatusRaisePrecision = 1;
|
||||
const int kStatusAllMatched = 0;
|
||||
// format
|
||||
constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
|
||||
constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";
|
||||
|
@ -218,18 +215,11 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
|
|||
constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
|
||||
constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
|
||||
constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
|
||||
const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
|
||||
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
|
||||
kOpFormat_C1HWNCoC0};
|
||||
|
||||
const std::set<std::string> k2DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_Z,
|
||||
kOpFormat_NC1KHKWHWC0};
|
||||
const std::set<std::string> k3DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0};
|
||||
const std::set<std::string> k4DSupportFormat = k1DSupportFormat;
|
||||
const std::vector<std::set<std::string>> kShapeSupportFormatMap = {k1DSupportFormat, k2DSupportFormat, k3DSupportFormat,
|
||||
k4DSupportFormat};
|
||||
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
|
||||
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
|
||||
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
|
||||
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
|
||||
const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN};
|
||||
|
||||
const std::set<std::string> kOptOperatorSet = {
|
||||
kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName,
|
||||
kApplyAdagradOpName, kApplyAdagradDAName, kApplyAdamOpName,
|
||||
|
|
|
@ -1,345 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "mindspore/ccsrc/device/ascend/kernel_select_ascend.h"
|
||||
#include "common/common_test.h"
|
||||
#include "session/kernel_graph.h"
|
||||
#include "kernel/kernel.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "utils/utils.h"
|
||||
#include "operator/ops.h"
|
||||
#include "mindspore/ccsrc/device/kernel_info.h"
|
||||
#include "mindspore/ccsrc/kernel/kernel_build_info.h"
|
||||
#include <vector>
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
namespace {
|
||||
using KernelInfo = device::KernelInfo;
|
||||
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
||||
using KernelBuildInfo = kernel::KernelBuildInfo;
|
||||
using KernelGraph = session::KernelGraph;
|
||||
using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>;
|
||||
using KernelBuilderPtr = std::shared_ptr<KernelBuildInfoBuilder>;
|
||||
using Shape = std::vector<size_t>;
|
||||
using ShapeList = std::vector<Shape>;
|
||||
enum MatchCountPriority {
|
||||
MATCH_COUNT_PRIORITY_BEGIN = 0,
|
||||
MATCH_FORMAT_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
|
||||
MATCH_DTYPE_COUNT,
|
||||
MATCH_NZ_FORMAT_COUNT,
|
||||
MATCH_5D_FORMAT_COUNT,
|
||||
MATCH_OUTPUT_DTYPE_COUNT,
|
||||
MATCH_COUNT_PRIORITY_END
|
||||
};
|
||||
|
||||
const std::set<std::string> kOpFormatList = {
|
||||
kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC,
|
||||
kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ};
|
||||
|
||||
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
|
||||
// if format is default,it remarkes support all format
|
||||
if (kOpFormatList.find(format) == kOpFormatList.end()) {
|
||||
MS_EXCEPTION(ArgumentError) << "got the unknow format " << format;
|
||||
}
|
||||
if (format == kOpFormat_DEFAULT) {
|
||||
return true;
|
||||
}
|
||||
// if shape size is 0,the shape will be a scalar
|
||||
if (shape.empty()) {
|
||||
return true;
|
||||
}
|
||||
if (shape.size() > kShapeSupportFormatMap.size()) {
|
||||
return false;
|
||||
}
|
||||
if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
|
||||
return shape[shape.size() - 1] % 16 != 0 && shape[shape.size() - 2] % 16 != 0;
|
||||
}
|
||||
return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
|
||||
}
|
||||
|
||||
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
|
||||
if (!IsShapeMatchFormat(shape, format)) {
|
||||
return false;
|
||||
}
|
||||
for (auto shape_value : shape) {
|
||||
if (shape_value == 0) {
|
||||
MS_EXCEPTION(ArgumentError) << "dimension size of the tensor shape should be a positive integer, but got ["
|
||||
<< shape_value << "]";
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
|
||||
if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
|
||||
if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Check input data type
|
||||
for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
|
||||
AnfNodePtr cur_input = cnode->input(input_index + 1);
|
||||
MS_EXCEPTION_IF_NULL(cur_input);
|
||||
TypeId input_origin_type;
|
||||
if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) {
|
||||
// weight
|
||||
input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0);
|
||||
} else {
|
||||
// feature map
|
||||
input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
|
||||
}
|
||||
if (input_origin_type == kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Check output data type
|
||||
for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
|
||||
if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* compare too vector by priority,select a better vector,like compare too num,first compare highest num location,if
|
||||
* equal then next num location
|
||||
* example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3]
|
||||
*/
|
||||
bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) {
|
||||
MS_EXCEPTION_IF_NULL(best_item);
|
||||
if (cur_item.size() != best_item->size()) {
|
||||
MS_LOG(ERROR) << "item size should be same!";
|
||||
return false;
|
||||
}
|
||||
// Update the best_item by comparing the cur_item and best_item
|
||||
for (size_t i = 0; i < cur_item.size(); i++) {
|
||||
if (cur_item[i] > best_item->at(i)) {
|
||||
*best_item = cur_item;
|
||||
return true;
|
||||
} else if (cur_item[i] == best_item->at(i)) {
|
||||
continue;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
|
||||
std::vector<int> *const cur_kernelinfo_match_counts) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts);
|
||||
if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) {
|
||||
MS_EXCEPTION(ArgumentError) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END;
|
||||
}
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||
AnfNodePtr input_anf_node = kernel_node->input(input_index + 1);
|
||||
MS_EXCEPTION_IF_NULL(input_anf_node);
|
||||
// if a input parameter is a weight with default format, the input shouldn't participate the judge
|
||||
if (input_anf_node->isa<Parameter>()) {
|
||||
auto para = input_anf_node->cast<ParameterPtr>();
|
||||
if (AnfAlgo::IsParameterWeight(para) && AnfAlgo::GetOutputDeviceDataType(para, 0) == kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
|
||||
(*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
|
||||
}
|
||||
if (kernel_build_info.GetInputDeviceType(input_index) ==
|
||||
AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) {
|
||||
(*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT]++;
|
||||
}
|
||||
if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_FRAC_NZ) {
|
||||
(*cur_kernelinfo_match_counts)[MATCH_NZ_FORMAT_COUNT]++;
|
||||
}
|
||||
if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_NC1HWC0) {
|
||||
(*cur_kernelinfo_match_counts)[MATCH_5D_FORMAT_COUNT]++;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
|
||||
// cal count of same output dtype between abstract and kernel info
|
||||
if (kernel_build_info.GetOutputDeviceType(output_index) ==
|
||||
AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) {
|
||||
(*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SetKernelBuildInfo(KernelBuilderPtr builder) {
|
||||
builder->SetFusionType(kernel::OPAQUE);
|
||||
builder->SetKernelType(AUTO_DIFF_KERNEL);
|
||||
builder->SetProcessor(kernel::AICORE);
|
||||
}
|
||||
|
||||
void test_select(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list) {
|
||||
std::vector<int> most_match_counts = {-1, -1, -1, -1, -1};
|
||||
int selected_index = -1;
|
||||
for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
|
||||
std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0};
|
||||
if (!IsValidKernelInfo(kernel_node, *(kernel_info_list[info_index]))) {
|
||||
continue;
|
||||
}
|
||||
if (!MatchInferOutputDataType(kernel_node, *(kernel_info_list[info_index]))) {
|
||||
continue;
|
||||
}
|
||||
std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];
|
||||
UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
|
||||
// Currently the selection policy is the match format count first, and then is datatype counts.
|
||||
if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) {
|
||||
selected_index = SizeToInt(info_index);
|
||||
}
|
||||
}
|
||||
if (selected_index == -1) {
|
||||
MS_EXCEPTION(NotExistsError) << "" << kernel_node->DebugString() << " Cannot find valid kernel Info !";
|
||||
}
|
||||
auto index = IntToSize(selected_index);
|
||||
if (index >= kernel_info_list.size()) {
|
||||
MS_EXCEPTION(ArgumentError) << "index outof range";
|
||||
}
|
||||
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info_ptr = kernel_info_list[index];
|
||||
MS_EXCEPTION_IF_NULL(selected_kernel_info_ptr);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, kernel_node.get());
|
||||
}
|
||||
|
||||
void SetParentAbstract(std::vector<AnfNodePtr> parent_list, std::vector<std::vector<size_t>> shapes,
|
||||
std::vector<TypeId> types) {
|
||||
for (const auto &node : parent_list) {
|
||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, node.get());
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
class AscendKernelSelctTest : public UT::Common {
|
||||
public:
|
||||
AscendKernelSelctTest() = default;
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
};
|
||||
|
||||
TEST_F(AscendKernelSelctTest, TestSelect) {
|
||||
std::vector<KernelBuilderPtr> build_list;
|
||||
std::vector<TypeId> type_list = {kNumberTypeFloat32};
|
||||
for (size_t i = 0; i <= 4; ++i) {
|
||||
build_list.push_back(std::make_shared<KernelBuildInfoBuilder>());
|
||||
SetKernelBuildInfo(build_list[i]);
|
||||
build_list[i]->SetInputsDeviceType(type_list);
|
||||
build_list[i]->SetOutputsDeviceType(type_list);
|
||||
}
|
||||
|
||||
std::vector<std::string> nd_fmt = {kOpFormat_DEFAULT};
|
||||
std::vector<std::string> nz_fmt = {kOpFormat_FRAC_NZ};
|
||||
auto anf_graph = std::make_shared<KernelGraph>();
|
||||
|
||||
// 16's multiple should not chose format NZ
|
||||
Shape nd_shapes = {2, 32, 224, 224};
|
||||
|
||||
Shape nz_shapes = {3, 3, 5, 5};
|
||||
auto add_value = NewValueNode(prim::kPrimTensorAdd);
|
||||
auto a_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value});
|
||||
auto b_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value});
|
||||
std::vector<AnfNodePtr> parent_list = {add_value, a_node, b_node};
|
||||
|
||||
auto c_node = anf_graph->NewCNode(parent_list);
|
||||
|
||||
// a b
|
||||
// \ /
|
||||
// c
|
||||
// a & b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}}
|
||||
// infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
|
||||
// c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3,224, 224}}
|
||||
|
||||
// set a & b's info
|
||||
SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list);
|
||||
// set abstract c
|
||||
AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nd_shapes}, c_node.get());
|
||||
// set format of kernel info
|
||||
build_list[0]->SetOutputsFormat(nz_fmt);
|
||||
build_list[1]->SetOutputsFormat(nz_fmt);
|
||||
|
||||
build_list[2]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nd_fmt[0]});
|
||||
build_list[3]->SetInputsFormat(std::vector<std::string>{nz_fmt[0], nz_fmt[0]});
|
||||
build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
|
||||
build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
|
||||
build_list[2]->SetOutputsFormat(nd_fmt);
|
||||
build_list[3]->SetOutputsFormat(nz_fmt);
|
||||
std::vector<KernelBuildInfoPtr> select_info_list;
|
||||
// set select info list
|
||||
select_info_list.emplace_back(build_list[2]->Build());
|
||||
select_info_list.emplace_back(build_list[3]->Build());
|
||||
|
||||
// set device info for a & b
|
||||
AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get());
|
||||
|
||||
test_select(c_node, select_info_list);
|
||||
EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT);
|
||||
EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_DEFAULT);
|
||||
|
||||
// set a & b's info
|
||||
// a b
|
||||
// \ /
|
||||
// c
|
||||
// a: kernel_info:{output_format:{5d},dtype:{kNumberTypeFloat32}}
|
||||
// infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
|
||||
// b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}}
|
||||
// infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
|
||||
// c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
|
||||
|
||||
// set a & b's info
|
||||
SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list);
|
||||
// set abstract c
|
||||
AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nz_shapes}, c_node.get());
|
||||
// set format of kernel info
|
||||
build_list[0]->SetOutputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0});
|
||||
build_list[1]->SetOutputsFormat(nz_fmt);
|
||||
|
||||
build_list[2]->SetInputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0, nd_fmt[0]});
|
||||
build_list[3]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nz_fmt[0]});
|
||||
build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
|
||||
build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
|
||||
build_list[2]->SetOutputsFormat(nd_fmt);
|
||||
build_list[3]->SetOutputsFormat(nz_fmt);
|
||||
// set select info list
|
||||
select_info_list.emplace_back(build_list[2]->Build());
|
||||
select_info_list.emplace_back(build_list[3]->Build());
|
||||
|
||||
// set device info for a & b
|
||||
AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get());
|
||||
|
||||
test_select(c_node, select_info_list);
|
||||
EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT);
|
||||
EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_FRAC_NZ);
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
|
@ -39,7 +39,7 @@ class MockSupportedChecker : public SupportedChecker {
|
|||
public:
|
||||
MockSupportedChecker() = default;
|
||||
~MockSupportedChecker() override = default;
|
||||
bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
||||
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
||||
return true;
|
||||
}
|
||||
}; // namespace opt
|
||||
|
|
|
@ -37,6 +37,15 @@ class TestHWTransposeTransdataFusion : public BackendCommon {
|
|||
UT::PyFuncGraphFetcher get_py_fun_;
|
||||
};
|
||||
|
||||
class MockSupportedChecker : public SupportedChecker {
|
||||
public:
|
||||
MockSupportedChecker() = default;
|
||||
~MockSupportedChecker() override = default;
|
||||
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
|
||||
public:
|
||||
MockInsertTransOpKernelSelectTrans4Dto5D() = default;
|
||||
|
@ -60,37 +69,6 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
|
|||
}
|
||||
};
|
||||
|
||||
class MockTransposeTransdataFusionKernelSelect : public KernelSelect {
|
||||
public:
|
||||
MockTransposeTransdataFusionKernelSelect() = default;
|
||||
~MockTransposeTransdataFusionKernelSelect() override = default;
|
||||
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
|
||||
const kernel::KernelBuildInfoPtr &new_kernel_build_info) override {
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsFormat({kOpFormat_NCHW});
|
||||
builder.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
builder.SetInputsDeviceType({kNumberTypeFloat16});
|
||||
builder.SetOutputsDeviceType({kNumberTypeFloat16});
|
||||
builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
kernel_info_list.push_back(builder.Build());
|
||||
MS_LOG(INFO) << "transpose transdata fusion success";
|
||||
MS_LOG(INFO) << "new transdata build info input format:" << new_kernel_build_info->GetInputFormat(0)
|
||||
<< ",outputformat:" << new_kernel_build_info->GetOutputFormat(0)
|
||||
<< ",kerneltype:" << new_kernel_build_info->kernel_type()
|
||||
<< ",fusiontype:" << new_kernel_build_info->fusion_type()
|
||||
<< ",process:" << new_kernel_build_info->processor();
|
||||
auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(),
|
||||
[&new_kernel_build_info](kernel::KernelBuildInfoPtr item) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
return *item == *new_kernel_build_info;
|
||||
});
|
||||
return result != kernel_info_list.end();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
|
||||
/*
|
||||
* def before(input0, input1):
|
||||
|
@ -128,7 +106,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
|
|||
insert_trans_op_pass->kernel_select_ = std::make_shared<MockInsertTransOpKernelSelectTrans4Dto5D>();
|
||||
pm->AddPass(insert_trans_op_pass);
|
||||
auto transpose_transdata_pass = std::make_shared<opt::TransposeTransDataFusion>();
|
||||
transpose_transdata_pass->kernel_select_ = std::make_shared<MockTransposeTransdataFusionKernelSelect>();
|
||||
transpose_transdata_pass->supported_checker_ = std::make_shared<MockSupportedChecker>();
|
||||
pm->AddPass(transpose_transdata_pass);
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
|
|
Loading…
Reference in New Issue