forked from mindspore-Ecosystem/mindspore
add aicpu kernel info select
This commit is contained in:
parent
af7c54b12a
commit
921ffccdbb
|
@ -1,4 +1,3 @@
|
|||
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
|
@ -95,33 +94,53 @@ enum DataTypeTransMode {
|
|||
FROM_FLOAT_TO_INT32,
|
||||
FROM_FLOAT16_TO_FLOAT,
|
||||
FROM_FLOAT16_TO_INT32,
|
||||
FROM_FLOAT16_TO_UINT8,
|
||||
FROM_INT32_TO_FLOAT,
|
||||
FROM_INT32_TO_FLOAT16,
|
||||
FROM_INT32_TO_UINT8,
|
||||
FROM_INT32_TO_INT8,
|
||||
FROM_INT32_TO_BOOL,
|
||||
FROM_UINT8_TO_FLOAT,
|
||||
FROM_UINT8_TO_INT32,
|
||||
FROM_UINT8_TO_FLOAT16,
|
||||
FROM_INT8_TO_FLOAT,
|
||||
FROM_INT8_TO_FLOAT16,
|
||||
FROM_INT8_TO_INT32,
|
||||
FROM_INT64_TO_INT32,
|
||||
FROM_UINT16_TO_INT32,
|
||||
FROM_BOOL_TO_FLOAT,
|
||||
FROM_BOOL_TO_INT32,
|
||||
FROM_BOOL_TO_UINT8,
|
||||
FROM_BOOL_TO_FLOAT16,
|
||||
FROM_FLOAT64_TO_FLOAT32,
|
||||
FROM_FLOAT32_TO_FLOAT64
|
||||
};
|
||||
|
||||
const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeFloat64, kNumberTypeFloat32), FROM_FLOAT64_TO_FLOAT32},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat64), FROM_FLOAT32_TO_FLOAT64},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat16), FROM_FLOAT_TO_FLOAT16},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeInt32), FROM_FLOAT_TO_INT32},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeFloat32), FROM_FLOAT16_TO_FLOAT},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeInt32), FROM_FLOAT16_TO_INT32},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeUInt8), FROM_FLOAT16_TO_UINT8},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat32), FROM_INT32_TO_FLOAT},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat16), FROM_INT32_TO_FLOAT16},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt8), FROM_INT32_TO_UINT8},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt8), FROM_INT32_TO_INT8},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeBool), FROM_INT32_TO_BOOL},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat32), FROM_UINT8_TO_FLOAT},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), FROM_UINT8_TO_INT32},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat16), FROM_UINT8_TO_FLOAT16},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), FROM_INT8_TO_FLOAT},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat16), FROM_INT8_TO_FLOAT16},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), FROM_INT8_TO_INT32},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32}};
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeInt32), FROM_BOOL_TO_INT32},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat), FROM_BOOL_TO_FLOAT},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeUInt8), FROM_BOOL_TO_UINT8},
|
||||
{std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), FROM_BOOL_TO_FLOAT16}};
|
||||
|
||||
void CheckMemSize(const TypeIdArgs &args) {
|
||||
auto src_type_size = TypeIdSize(args.host_data_type);
|
||||
|
@ -154,54 +173,46 @@ void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size
|
|||
}
|
||||
|
||||
bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) {
|
||||
switch (mode) {
|
||||
case FROM_FLOAT_TO_FLOAT16:
|
||||
device::FloatToHalf(dst, args.data, data_size);
|
||||
break;
|
||||
case FROM_INT32_TO_FLOAT16:
|
||||
TransDataSrc2Fp16<int32_t>(args, dst, data_size);
|
||||
break;
|
||||
case FROM_FLOAT16_TO_FLOAT:
|
||||
device::HalfToFloat(dst, args.data, data_size);
|
||||
break;
|
||||
case FROM_FLOAT_TO_INT32:
|
||||
TransDataSrc2Dst<float, int32_t>(args, dst, data_size);
|
||||
break;
|
||||
case FROM_FLOAT16_TO_INT32:
|
||||
TransDataSrc2Dst<float16, int32_t>(args, dst, data_size);
|
||||
break;
|
||||
case FROM_INT32_TO_FLOAT:
|
||||
TransDataSrc2Dst<int32_t, float>(args, dst, data_size);
|
||||
break;
|
||||
case FROM_INT32_TO_INT8:
|
||||
TransDataSrc2Dst<int32_t, int8_t>(args, dst, data_size);
|
||||
break;
|
||||
case FROM_INT32_TO_UINT8:
|
||||
TransDataSrc2Dst<int32_t, uint8_t>(args, dst, data_size);
|
||||
break;
|
||||
case FROM_UINT8_TO_INT32:
|
||||
TransDataSrc2Dst<uint8_t, int32_t>(args, dst, data_size);
|
||||
break;
|
||||
case FROM_UINT8_TO_FLOAT:
|
||||
TransDataSrc2Dst<uint8_t, float>(args, dst, data_size);
|
||||
break;
|
||||
case FROM_INT8_TO_FLOAT:
|
||||
TransDataSrc2Dst<int8_t, float>(args, dst, data_size);
|
||||
break;
|
||||
case FROM_INT8_TO_INT32:
|
||||
TransDataSrc2Dst<int8_t, int32_t>(args, dst, data_size);
|
||||
break;
|
||||
case FROM_INT64_TO_INT32:
|
||||
TransDataSrc2Dst<int64_t, int32_t>(args, dst, data_size);
|
||||
break;
|
||||
case FROM_UINT16_TO_INT32:
|
||||
TransDataSrc2Dst<uint16_t, int32_t>(args, dst, data_size);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unsupported datatype trans";
|
||||
return false;
|
||||
using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const size_t)>;
|
||||
const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{
|
||||
{FROM_FLOAT_TO_INT32, TransDataSrc2Dst<float, int32_t>},
|
||||
{FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>},
|
||||
{FROM_FLOAT32_TO_FLOAT64, TransDataSrc2Dst<float, double>},
|
||||
{FROM_FLOAT16_TO_INT32, TransDataSrc2Dst<float16, int32_t>},
|
||||
{FROM_FLOAT16_TO_UINT8, TransDataSrc2Dst<float16, uint8_t>},
|
||||
{FROM_INT32_TO_FLOAT, TransDataSrc2Dst<int32_t, float>},
|
||||
{FROM_INT32_TO_INT8, TransDataSrc2Dst<int32_t, int8_t>},
|
||||
{FROM_INT32_TO_UINT8, TransDataSrc2Dst<int32_t, uint8_t>},
|
||||
{FROM_INT32_TO_BOOL, TransDataSrc2Dst<int32_t, int8_t>},
|
||||
{FROM_INT32_TO_FLOAT16, TransDataSrc2Fp16<int32_t>},
|
||||
{FROM_UINT8_TO_FLOAT, TransDataSrc2Dst<uint8_t, float>},
|
||||
{FROM_UINT8_TO_INT32, TransDataSrc2Dst<uint8_t, int32_t>},
|
||||
{FROM_UINT8_TO_FLOAT16, TransDataSrc2Fp16<uint8_t>},
|
||||
{FROM_INT8_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
|
||||
{FROM_INT8_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
|
||||
{FROM_INT8_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
|
||||
{FROM_INT64_TO_INT32, TransDataSrc2Dst<int64_t, int32_t>},
|
||||
{FROM_UINT16_TO_INT32, TransDataSrc2Dst<uint16_t, int32_t>},
|
||||
{FROM_BOOL_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
|
||||
{FROM_BOOL_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
|
||||
{FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>},
|
||||
{FROM_BOOL_TO_FLOAT16, TransDataSrc2Fp16<int8_t>}};
|
||||
|
||||
if (mode == FROM_FLOAT_TO_FLOAT16) {
|
||||
device::FloatToHalf(dst, args.data, data_size);
|
||||
return true;
|
||||
} else if (mode == FROM_FLOAT16_TO_FLOAT) {
|
||||
device::HalfToFloat(dst, args.data, data_size);
|
||||
return true;
|
||||
}
|
||||
auto iter = cast_kernel_map.find(mode);
|
||||
if (iter != cast_kernel_map.end()) {
|
||||
iter->second(args, dst, data_size);
|
||||
return true;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported datatype trans";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t CubeSizeByType(const TypeId data_type) {
|
||||
|
|
|
@ -464,14 +464,12 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
|
|||
}
|
||||
} // namespace
|
||||
|
||||
int SelectKernelInfo(const CNodePtr &kernel_node) {
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
int status = kStatusAllMatched;
|
||||
std::shared_ptr<kernel::KernelBuildInfo> CanHitKernelInfo(
|
||||
int *status, const CNodePtr &kernel_node,
|
||||
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
bool precision_reduce = false;
|
||||
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
|
||||
kernel::KernelQuery(kernel_node, &kernel_info_list);
|
||||
// filter kernel info matched with me inferred type
|
||||
auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list);
|
||||
if (!filtered_kernel_info_list.empty()) {
|
||||
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
|
||||
|
@ -481,15 +479,34 @@ int SelectKernelInfo(const CNodePtr &kernel_node) {
|
|||
FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce);
|
||||
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
|
||||
if (selected_kernel_info == nullptr) {
|
||||
std::ostringstream buffer;
|
||||
PrintInputAndOutputInferType(buffer, kernel_node);
|
||||
MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
|
||||
<< "] cannot find valid kernel info, not supported the type" << buffer.str();
|
||||
return nullptr;
|
||||
} else {
|
||||
PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
|
||||
status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
|
||||
*status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
|
||||
}
|
||||
}
|
||||
return selected_kernel_info;
|
||||
}
|
||||
|
||||
int SelectKernelInfo(const CNodePtr &kernel_node) {
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
int status = kStatusAllMatched;
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel::KernelQuery(kernel_node, &kernel_info_list);
|
||||
// filter kernel info matched with me inferred type
|
||||
auto selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list);
|
||||
if (selected_kernel_info == nullptr) {
|
||||
MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
|
||||
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
|
||||
kernel::AicpuQuery(kernel_node, &kernel_info_list);
|
||||
selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list);
|
||||
}
|
||||
if (selected_kernel_info == nullptr) {
|
||||
std::ostringstream buffer;
|
||||
PrintInputAndOutputInferType(buffer, kernel_node);
|
||||
MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
|
||||
<< "] cannot find valid kernel info, not supported the type " << buffer.str();
|
||||
}
|
||||
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
|
||||
// Set format and data type for input tensor.
|
||||
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
|
||||
|
|
|
@ -67,5 +67,13 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
|
|||
}
|
||||
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
|
||||
}
|
||||
|
||||
// Rebuilds *kernel_info_list for kernel_node from AICPU metadata.
// Both arguments are null-checked via MS_EXCEPTION_IF_NULL. Any entries already in the
// list are discarded before AicpuMetadataInfo populates it, and the result is then
// passed through FilterInvalidKernelInfo.
void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);

  // Start from an empty list so results of an earlier query are not mixed in.
  kernel_info_list->clear();
  AicpuMetadataInfo(kernel_node, kernel_info_list);
  FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Queries candidate kernel build infos for kernel_node into *kernel_info_list.
// The implementation finishes by passing the list through FilterInvalidKernelInfo.
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
|
||||
// Clears *kernel_info_list, refills it for kernel_node via AicpuMetadataInfo, and then
// passes it through FilterInvalidKernelInfo. Both arguments are null-checked.
void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_
|
||||
|
|
Loading…
Reference in New Issue