forked from mindspore-Ecosystem/mindspore
commit
ef44a0a981
|
@ -36,41 +36,6 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
constexpr char kAxis[] = "axis";
|
||||
constexpr char kTypeInt32[] = "Int32";
|
||||
// Lookup table from a dtype name string to the corresponding TypeId.
// Both "float" and "float32" resolve to TypeId::kNumberTypeFloat32.
const std::unordered_map<std::string, TypeId> type_id_maps = {{"float", TypeId::kNumberTypeFloat32},
|
||||
{"float16", TypeId::kNumberTypeFloat16},
|
||||
{"float32", TypeId::kNumberTypeFloat32},
|
||||
{"float64", TypeId::kNumberTypeFloat64},
|
||||
{"int", TypeId::kNumberTypeInt},
|
||||
{"int8", TypeId::kNumberTypeInt8},
|
||||
{"int16", TypeId::kNumberTypeInt16},
|
||||
{"int32", TypeId::kNumberTypeInt32},
|
||||
{"int64", TypeId::kNumberTypeInt64},
|
||||
{"uint", TypeId::kNumberTypeUInt},
|
||||
{"uint8", TypeId::kNumberTypeUInt8},
|
||||
{"uint16", TypeId::kNumberTypeUInt16},
|
||||
{"uint32", TypeId::kNumberTypeUInt32},
|
||||
{"uint64", TypeId::kNumberTypeUInt64},
|
||||
{"bool", TypeId::kNumberTypeBool},
|
||||
{"complex64", TypeId::kNumberTypeComplex64},
|
||||
{"complex128", TypeId::kNumberTypeComplex128}};
|
||||
|
||||
// Lookup table from a TypeId to its dtype name string.
// In this table kNumberTypeFloat maps to "float" and kNumberTypeBool maps to "bool".
const std::map<TypeId, std::string> type_id_str_map = {{TypeId::kNumberTypeFloat32, "float32"},
|
||||
{TypeId::kNumberTypeFloat16, "float16"},
|
||||
{TypeId::kNumberTypeFloat, "float"},
|
||||
{TypeId::kNumberTypeFloat64, "float64"},
|
||||
{TypeId::kNumberTypeInt, "int"},
|
||||
{TypeId::kNumberTypeInt8, "int8"},
|
||||
{TypeId::kNumberTypeInt16, "int16"},
|
||||
{TypeId::kNumberTypeInt32, "int32"},
|
||||
{TypeId::kNumberTypeInt64, "int64"},
|
||||
{TypeId::kNumberTypeUInt, "uint"},
|
||||
{TypeId::kNumberTypeUInt8, "uint8"},
|
||||
{TypeId::kNumberTypeUInt16, "uint16"},
|
||||
{TypeId::kNumberTypeUInt32, "uint32"},
|
||||
{TypeId::kNumberTypeUInt64, "uint64"},
|
||||
{TypeId::kNumberTypeBool, "bool"},
|
||||
{TypeId::kNumberTypeComplex64, "complex64"},
|
||||
{TypeId::kNumberTypeComplex128, "complex128"}};
|
||||
|
||||
const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
|
||||
{"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
|
||||
|
|
|
@ -47,6 +47,48 @@ constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600;
|
|||
constexpr auto kArgDataformat = "data_format";
|
||||
|
||||
const std::vector<std::string> support_devices = {"aicore", "aicpu", "cuda"};
|
||||
// Lookup table from a dtype name string to the corresponding TypeId.
// "float" and "float32" both resolve to kNumberTypeFloat32; the empty string resolves to kMetaTypeNone.
// NOTE(review): if this definition lives in a header, a namespace-scope const object has internal
// linkage, so every including translation unit gets its own copy -- confirm that is acceptable.
const std::unordered_map<std::string, TypeId> type_id_maps = {
|
||||
{"float", TypeId::kNumberTypeFloat32},
|
||||
{"float16", TypeId::kNumberTypeFloat16},
|
||||
{"float32", TypeId::kNumberTypeFloat32},
|
||||
{"float64", TypeId::kNumberTypeFloat64},
|
||||
{"int", TypeId::kNumberTypeInt},
|
||||
{"int8", TypeId::kNumberTypeInt8},
|
||||
{"int16", TypeId::kNumberTypeInt16},
|
||||
{"int32", TypeId::kNumberTypeInt32},
|
||||
{"int64", TypeId::kNumberTypeInt64},
|
||||
{"uint", TypeId::kNumberTypeUInt},
|
||||
{"uint8", TypeId::kNumberTypeUInt8},
|
||||
{"uint16", TypeId::kNumberTypeUInt16},
|
||||
{"uint32", TypeId::kNumberTypeUInt32},
|
||||
{"uint64", TypeId::kNumberTypeUInt64},
|
||||
{"bool", TypeId::kNumberTypeBool},
|
||||
{"int4", TypeId::kNumberTypeInt4},
|
||||
{"complex64", TypeId::kNumberTypeComplex64},
|
||||
{"complex128", TypeId::kNumberTypeComplex128},
|
||||
{"", TypeId::kMetaTypeNone},
|
||||
};
|
||||
// Lookup table from a TypeId to a dtype name string.
// This is not a strict inverse of type_id_maps: kNumberTypeFloat is reported as "float32" and
// kNumberTypeBool is reported as "int8"; kMetaTypeNone maps to the empty string.
// NOTE(review): callers that expect "bool" / "float" back for those ids will see different
// strings -- confirm every user of this table wants the "int8" / "float32" spelling.
const std::map<TypeId, std::string> type_id_str_map = {
|
||||
{TypeId::kNumberTypeFloat32, "float32"},
|
||||
{TypeId::kNumberTypeFloat16, "float16"},
|
||||
{TypeId::kNumberTypeFloat, "float32"},
|
||||
{TypeId::kNumberTypeFloat64, "float64"},
|
||||
{TypeId::kNumberTypeInt, "int"},
|
||||
{TypeId::kNumberTypeInt8, "int8"},
|
||||
{TypeId::kNumberTypeInt16, "int16"},
|
||||
{TypeId::kNumberTypeInt32, "int32"},
|
||||
{TypeId::kNumberTypeInt64, "int64"},
|
||||
{TypeId::kNumberTypeUInt, "uint"},
|
||||
{TypeId::kNumberTypeUInt8, "uint8"},
|
||||
{TypeId::kNumberTypeUInt16, "uint16"},
|
||||
{TypeId::kNumberTypeUInt32, "uint32"},
|
||||
{TypeId::kNumberTypeUInt64, "uint64"},
|
||||
{TypeId::kNumberTypeBool, "int8"},
|
||||
{TypeId::kNumberTypeInt4, "int4"},
|
||||
{TypeId::kNumberTypeComplex64, "complex64"},
|
||||
{TypeId::kNumberTypeComplex128, "complex128"},
|
||||
{TypeId::kMetaTypeNone, ""},
|
||||
};
|
||||
|
||||
struct KernelMetaInfo {
|
||||
uintptr_t func_stub_;
|
||||
|
|
|
@ -62,7 +62,6 @@ int TypeStrToDstType(const std::string &type_str) {
|
|||
}
|
||||
return kInvalid;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
std::unordered_set<std::string> TbeAdapter::input_order_adjusted_ops_ = {kConv2DBackpropInputOpName,
|
||||
kConv2DBackpropFilterOpName,
|
||||
|
|
|
@ -22,53 +22,11 @@
|
|||
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace tbe {
|
||||
// Lookup table (namespace tbe) from a dtype name string to the corresponding TypeId.
// "float" and "float32" both resolve to kNumberTypeFloat32; the empty string resolves to kMetaTypeNone.
const std::unordered_map<std::string, TypeId> type_str_id_maps = {
|
||||
{"float", TypeId::kNumberTypeFloat32},
|
||||
{"float16", TypeId::kNumberTypeFloat16},
|
||||
{"float32", TypeId::kNumberTypeFloat32},
|
||||
{"float64", TypeId::kNumberTypeFloat64},
|
||||
{"int", TypeId::kNumberTypeInt},
|
||||
{"int8", TypeId::kNumberTypeInt8},
|
||||
{"int16", TypeId::kNumberTypeInt16},
|
||||
{"int32", TypeId::kNumberTypeInt32},
|
||||
{"int64", TypeId::kNumberTypeInt64},
|
||||
{"uint", TypeId::kNumberTypeUInt},
|
||||
{"uint8", TypeId::kNumberTypeUInt8},
|
||||
{"uint16", TypeId::kNumberTypeUInt16},
|
||||
{"uint32", TypeId::kNumberTypeUInt32},
|
||||
{"uint64", TypeId::kNumberTypeUInt64},
|
||||
{"bool", TypeId::kNumberTypeBool},
|
||||
{"int4", TypeId::kNumberTypeInt4},
|
||||
{"complex64", TypeId::kNumberTypeComplex64},
|
||||
{"complex128", TypeId::kNumberTypeComplex128},
|
||||
{"", TypeId::kMetaTypeNone},
|
||||
};
|
||||
|
||||
// Lookup table (namespace tbe) from a TypeId to a dtype name string.
// kNumberTypeFloat is reported as "float32" and kNumberTypeBool as "int8";
// kMetaTypeNone maps to the empty string.
const std::map<TypeId, std::string> type_id_str_maps = {
|
||||
{TypeId::kNumberTypeFloat32, "float32"},
|
||||
{TypeId::kNumberTypeFloat16, "float16"},
|
||||
{TypeId::kNumberTypeFloat, "float32"},
|
||||
{TypeId::kNumberTypeFloat64, "float64"},
|
||||
{TypeId::kNumberTypeInt, "int"},
|
||||
{TypeId::kNumberTypeInt8, "int8"},
|
||||
{TypeId::kNumberTypeInt16, "int16"},
|
||||
{TypeId::kNumberTypeInt32, "int32"},
|
||||
{TypeId::kNumberTypeInt64, "int64"},
|
||||
{TypeId::kNumberTypeUInt, "uint"},
|
||||
{TypeId::kNumberTypeUInt8, "uint8"},
|
||||
{TypeId::kNumberTypeUInt16, "uint16"},
|
||||
{TypeId::kNumberTypeUInt32, "uint32"},
|
||||
{TypeId::kNumberTypeUInt64, "uint64"},
|
||||
{TypeId::kNumberTypeBool, "int8"},
|
||||
{TypeId::kNumberTypeInt4, "int4"},
|
||||
{TypeId::kNumberTypeComplex64, "complex64"},
|
||||
{TypeId::kNumberTypeComplex128, "complex128"},
|
||||
{TypeId::kMetaTypeNone, ""},
|
||||
};
|
||||
|
||||
const std::unordered_map<std::string, size_t> type_nbyte_maps = {
|
||||
{"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
|
||||
|
@ -79,16 +37,16 @@ const std::unordered_map<std::string, size_t> type_nbyte_maps = {
|
|||
};
|
||||
|
||||
// Converts a device dtype name (e.g. "float16") into its TypeId.
//
// The lookup goes through the shared type_id_maps table only. The previous text performed the
// lookup twice (type_str_id_maps and type_id_maps), which left a shadowed `iter` and an unclosed
// `if` block.
//
// @param dtypes  dtype name to look up.
// @return        the TypeId registered for `dtypes`.
// Reports through MS_LOG(EXCEPTION) when `dtypes` is not present in the table.
TypeId DtypeToTypeId(const std::string &dtypes) {
  const auto iter = type_id_maps.find(dtypes);
  if (iter == type_id_maps.end()) {
    // NOTE(review): the dereference below relies on MS_LOG(EXCEPTION) not returning -- confirm.
    MS_LOG(EXCEPTION) << "Illegal input device dtype: " << dtypes;
  }
  return iter->second;
}
|
||||
|
||||
std::string TypeIdToString(TypeId type_id) {
|
||||
auto iter = type_id_str_maps.find(type_id);
|
||||
if (iter == type_id_str_maps.end()) {
|
||||
auto iter = type_id_str_map.find(type_id);
|
||||
if (iter == type_id_str_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Illegal input dtype: " << TypeIdLabel(type_id);
|
||||
}
|
||||
return iter->second;
|
||||
|
|
|
@ -55,8 +55,7 @@ void GetRealInputSize(const nlohmann::json &input_json, std::vector<size_t> *inp
|
|||
input_size_list->push_back((*size_i));
|
||||
}
|
||||
|
||||
void GetInputSizeList(const nlohmann::json &input_json, std::vector<size_t> *input_size_list,
|
||||
const AnfNodePtr &anf_node) {
|
||||
void GetInputSizeList(const nlohmann::json &input_json, std::vector<size_t> *input_size_list) {
|
||||
for (size_t i = 0; i < input_json.size(); i++) {
|
||||
if (input_json[i].is_array()) {
|
||||
for (size_t m = 0; m < input_json[i].size(); m++) {
|
||||
|
@ -103,8 +102,7 @@ void GetRealOutputSize(const nlohmann::json &output_json, std::vector<size_t> *o
|
|||
output_size_list->push_back((*size_i));
|
||||
}
|
||||
|
||||
void GetOutputSizeList(const nlohmann::json &output_json, std::vector<size_t> *output_size_list,
|
||||
const AnfNodePtr &anf_node) {
|
||||
void GetOutputSizeList(const nlohmann::json &output_json, std::vector<size_t> *output_size_list) {
|
||||
for (size_t i = 0; i < output_json.size(); i++) {
|
||||
if (output_json[i].is_array()) {
|
||||
for (size_t m = 0; m < output_json[i].size(); m++) {
|
||||
|
@ -140,8 +138,8 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<si
|
|||
for (size_t i = 0; i < op_list.size(); i++) {
|
||||
auto op_info = op_list[i];
|
||||
if (op_info["type"] != "Data") {
|
||||
GetInputSizeList(op_info["input_desc"], input_size_list, anf_node);
|
||||
GetOutputSizeList(op_info["output_desc"], output_size_list, anf_node);
|
||||
GetInputSizeList(op_info["input_desc"], input_size_list);
|
||||
GetOutputSizeList(op_info["output_desc"], output_size_list);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -1448,56 +1448,59 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPt
|
|||
MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.cur_event_num();
|
||||
}
|
||||
|
||||
// Buckets the hcom nodes of one communication group / graph by the stream they run on.
//
// @param cnode_ptr_list  nodes to scan; the indices stored in the result refer to this list.
// @param group           only hcom nodes whose GetHcomGroup() equals this name are kept.
// @param graph_id        only hcom nodes whose AnfAlgo::GetGraphId() equals this id are kept.
// @return  one (stream id, indices) entry per distinct stream, in the order each stream is first
//          encountered; indices inside an entry are ascending.
std::vector<std::pair<uint32_t, vector<size_t>>> AscendStreamAssign::GetStreamIDHcomMap(
  std::vector<CNodePtr> cnode_ptr_list, std::string group, size_t graph_id) {
  std::vector<std::pair<uint32_t, vector<size_t>>> hcom_by_stream;
  const size_t node_count = cnode_ptr_list.size();
  for (size_t idx = 0; idx < node_count; ++idx) {
    auto node = cnode_ptr_list[idx];
    if (!IsHcom(node)) {
      continue;
    }

    uint32_t stream_id = AnfAlgo::GetStreamId(node);
    auto node_group = GetHcomGroup(node);
    auto node_graph_id = AnfAlgo::GetGraphId(node.get());
    // Every hcom node is logged, including the ones filtered out just below.
    MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(node) << "; group:" << node_group
                 << "; stream id:" << stream_id;
    const bool belongs_here = (node_group == group) && (node_graph_id == graph_id);
    if (!belongs_here) {
      continue;
    }

    // Linear search keeps the buckets in first-seen order.
    auto bucket = hcom_by_stream.begin();
    while (bucket != hcom_by_stream.end() && bucket->first != stream_id) {
      ++bucket;
    }
    if (bucket == hcom_by_stream.end()) {
      hcom_by_stream.emplace_back(std::make_pair(stream_id, std::vector<size_t>{idx}));
    } else {
      bucket->second.emplace_back(idx);
    }
  }
  return hcom_by_stream;
}
|
||||
|
||||
void AscendStreamAssign::InsertEventHcomDependHcomAtSameGroup(
|
||||
const NotNull<KernelGraphPtr> &graph_ptr, std::pair<std::string, std::map<uint32_t, std::set<uint32_t>>> group_item) {
|
||||
for (const auto &graph_item : group_item.second) {
|
||||
auto stream_indices = GetStreamIDHcomMap(graph_ptr->execution_order(), group_item.first, graph_item.first);
|
||||
constexpr size_t kStreamMax = 2;
|
||||
if (stream_indices.size() < kStreamMax) {
|
||||
MS_LOG(INFO) << "Group:" << group_item.first << ", Graph: " << graph_item.first
|
||||
<< " different stream hcom size is less than 2, no need insert event between them";
|
||||
continue;
|
||||
}
|
||||
InsertEventBetweenHcom(graph_ptr, stream_indices);
|
||||
}
|
||||
}
|
||||
|
||||
void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||
if (group_hcom_graph_map_.empty()) {
|
||||
return;
|
||||
}
|
||||
for (const auto &group_item : group_hcom_graph_map_) {
|
||||
auto group = group_item.first;
|
||||
for (const auto &graph_item : group_item.second) {
|
||||
auto graph_id = graph_item.first;
|
||||
auto cnode_ptr_list = graph_ptr->execution_order();
|
||||
std::vector<std::pair<uint32_t, vector<size_t>>> stream_indices;
|
||||
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
|
||||
auto cur_cnode = cnode_ptr_list[i];
|
||||
if (!IsHcom(cur_cnode)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
|
||||
auto group_name = GetHcomGroup(cur_cnode);
|
||||
auto cur_graph_id = AnfAlgo::GetGraphId(cur_cnode.get());
|
||||
MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name
|
||||
<< "; stream id:" << cur_stream_id;
|
||||
if (group_name != group || cur_graph_id != graph_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (stream_indices.empty()) {
|
||||
stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
|
||||
} else {
|
||||
bool exit = false;
|
||||
for (auto &item : stream_indices) {
|
||||
if (item.first == cur_stream_id) {
|
||||
item.second.emplace_back(i);
|
||||
exit = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!exit) {
|
||||
stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constexpr size_t kStreamMax = 2;
|
||||
if (stream_indices.size() < kStreamMax) {
|
||||
MS_LOG(INFO) << "Group:" << group
|
||||
<< "; different stream hcom size is less than 2, no need insert event between them";
|
||||
continue;
|
||||
}
|
||||
InsertEventBetweenHcom(graph_ptr, stream_indices);
|
||||
}
|
||||
InsertEventHcomDependHcomAtSameGroup(graph_ptr, group_item);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -98,6 +98,10 @@ class AscendStreamAssign {
|
|||
void InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||
void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr,
|
||||
const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index);
|
||||
void InsertEventHcomDependHcomAtSameGroup(const NotNull<KernelGraphPtr> &graph_ptr,
|
||||
std::pair<std::string, std::map<uint32_t, std::set<uint32_t>>> group_item);
|
||||
std::vector<std::pair<uint32_t, vector<size_t>>> GetStreamIDHcomMap(std::vector<CNodePtr> cnode_ptr_list,
|
||||
std::string group, size_t graph_id);
|
||||
|
||||
void AdjustAtomicAddrCleanOrder(const NotNull<KernelGraphPtr> &graph_ptr);
|
||||
vector<CNodePtr> GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &cur_cnode_ptr);
|
||||
|
|
Loading…
Reference in New Issue