Merge pull request !24692 from hwjiaorui/master
This commit is contained in:
i-robot 2021-10-12 07:33:51 +00:00 committed by Gitee
commit ef44a0a981
7 changed files with 103 additions and 134 deletions

View File

@ -36,41 +36,6 @@ namespace mindspore {
namespace kernel {
constexpr char kAxis[] = "axis";
constexpr char kTypeInt32[] = "Int32";
// Translation table: device dtype string -> mindspore TypeId.
// "float" is accepted as an alias of float32; "int"/"uint" resolve to the
// generic integer type ids.
const std::unordered_map<std::string, TypeId> type_id_maps = {
  {"float", TypeId::kNumberTypeFloat32},   {"float16", TypeId::kNumberTypeFloat16},
  {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64},
  {"int", TypeId::kNumberTypeInt},         {"int8", TypeId::kNumberTypeInt8},
  {"int16", TypeId::kNumberTypeInt16},     {"int32", TypeId::kNumberTypeInt32},
  {"int64", TypeId::kNumberTypeInt64},     {"uint", TypeId::kNumberTypeUInt},
  {"uint8", TypeId::kNumberTypeUInt8},     {"uint16", TypeId::kNumberTypeUInt16},
  {"uint32", TypeId::kNumberTypeUInt32},   {"uint64", TypeId::kNumberTypeUInt64},
  {"bool", TypeId::kNumberTypeBool},       {"complex64", TypeId::kNumberTypeComplex64},
  {"complex128", TypeId::kNumberTypeComplex128},
};
// Reverse translation table: TypeId -> its canonical dtype string.
const std::map<TypeId, std::string> type_id_str_map = {
  {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
  {TypeId::kNumberTypeFloat, "float"},     {TypeId::kNumberTypeFloat64, "float64"},
  {TypeId::kNumberTypeInt, "int"},         {TypeId::kNumberTypeInt8, "int8"},
  {TypeId::kNumberTypeInt16, "int16"},     {TypeId::kNumberTypeInt32, "int32"},
  {TypeId::kNumberTypeInt64, "int64"},     {TypeId::kNumberTypeUInt, "uint"},
  {TypeId::kNumberTypeUInt8, "uint8"},     {TypeId::kNumberTypeUInt16, "uint16"},
  {TypeId::kNumberTypeUInt32, "uint32"},   {TypeId::kNumberTypeUInt64, "uint64"},
  {TypeId::kNumberTypeBool, "bool"},       {TypeId::kNumberTypeComplex64, "complex64"},
  {TypeId::kNumberTypeComplex128, "complex128"},
};
const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
{"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},

View File

@ -47,6 +47,48 @@ constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600;
constexpr auto kArgDataformat = "data_format";
const std::vector<std::string> support_devices = {"aicore", "aicpu", "cuda"};
// Translation table: device dtype string -> TypeId. Besides the usual numeric
// types this variant also accepts "int4" and maps the empty string to
// kMetaTypeNone.
const std::unordered_map<std::string, TypeId> type_id_maps = {
  {"float", TypeId::kNumberTypeFloat32},   {"float16", TypeId::kNumberTypeFloat16},
  {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64},
  {"int", TypeId::kNumberTypeInt},         {"int8", TypeId::kNumberTypeInt8},
  {"int16", TypeId::kNumberTypeInt16},     {"int32", TypeId::kNumberTypeInt32},
  {"int64", TypeId::kNumberTypeInt64},     {"uint", TypeId::kNumberTypeUInt},
  {"uint8", TypeId::kNumberTypeUInt8},     {"uint16", TypeId::kNumberTypeUInt16},
  {"uint32", TypeId::kNumberTypeUInt32},   {"uint64", TypeId::kNumberTypeUInt64},
  {"bool", TypeId::kNumberTypeBool},       {"int4", TypeId::kNumberTypeInt4},
  {"complex64", TypeId::kNumberTypeComplex64},
  {"complex128", TypeId::kNumberTypeComplex128},
  {"", TypeId::kMetaTypeNone},
};
// Reverse translation table: TypeId -> device dtype string.
// NOTE(review): this table renders kNumberTypeFloat as "float32" and
// kNumberTypeBool as "int8" — presumably because the device carries bool
// tensors as int8; confirm before "fixing" the apparent asymmetry with
// type_id_maps.
const std::map<TypeId, std::string> type_id_str_map = {
  {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
  {TypeId::kNumberTypeFloat, "float32"},   {TypeId::kNumberTypeFloat64, "float64"},
  {TypeId::kNumberTypeInt, "int"},         {TypeId::kNumberTypeInt8, "int8"},
  {TypeId::kNumberTypeInt16, "int16"},     {TypeId::kNumberTypeInt32, "int32"},
  {TypeId::kNumberTypeInt64, "int64"},     {TypeId::kNumberTypeUInt, "uint"},
  {TypeId::kNumberTypeUInt8, "uint8"},     {TypeId::kNumberTypeUInt16, "uint16"},
  {TypeId::kNumberTypeUInt32, "uint32"},   {TypeId::kNumberTypeUInt64, "uint64"},
  {TypeId::kNumberTypeBool, "int8"},       {TypeId::kNumberTypeInt4, "int4"},
  {TypeId::kNumberTypeComplex64, "complex64"},
  {TypeId::kNumberTypeComplex128, "complex128"},
  {TypeId::kMetaTypeNone, ""},
};
struct KernelMetaInfo {
uintptr_t func_stub_;

View File

@ -62,7 +62,6 @@ int TypeStrToDstType(const std::string &type_str) {
}
return kInvalid;
}
} // namespace
std::unordered_set<std::string> TbeAdapter::input_order_adjusted_ops_ = {kConv2DBackpropInputOpName,
kConv2DBackpropFilterOpName,

View File

@ -22,53 +22,11 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace kernel {
namespace tbe {
// Translation table: dtype string -> TypeId, including "int4" and the empty
// string (mapped to kMetaTypeNone). "float" is an alias of float32.
const std::unordered_map<std::string, TypeId> type_str_id_maps = {
  {"float", TypeId::kNumberTypeFloat32},   {"float16", TypeId::kNumberTypeFloat16},
  {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64},
  {"int", TypeId::kNumberTypeInt},         {"int8", TypeId::kNumberTypeInt8},
  {"int16", TypeId::kNumberTypeInt16},     {"int32", TypeId::kNumberTypeInt32},
  {"int64", TypeId::kNumberTypeInt64},     {"uint", TypeId::kNumberTypeUInt},
  {"uint8", TypeId::kNumberTypeUInt8},     {"uint16", TypeId::kNumberTypeUInt16},
  {"uint32", TypeId::kNumberTypeUInt32},   {"uint64", TypeId::kNumberTypeUInt64},
  {"bool", TypeId::kNumberTypeBool},       {"int4", TypeId::kNumberTypeInt4},
  {"complex64", TypeId::kNumberTypeComplex64},
  {"complex128", TypeId::kNumberTypeComplex128},
  {"", TypeId::kMetaTypeNone},
};
// Reverse translation table: TypeId -> device dtype string.
// NOTE(review): kNumberTypeFloat renders as "float32" and kNumberTypeBool as
// "int8" — presumably the device-side representation; confirm before changing.
const std::map<TypeId, std::string> type_id_str_maps = {
  {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
  {TypeId::kNumberTypeFloat, "float32"},   {TypeId::kNumberTypeFloat64, "float64"},
  {TypeId::kNumberTypeInt, "int"},         {TypeId::kNumberTypeInt8, "int8"},
  {TypeId::kNumberTypeInt16, "int16"},     {TypeId::kNumberTypeInt32, "int32"},
  {TypeId::kNumberTypeInt64, "int64"},     {TypeId::kNumberTypeUInt, "uint"},
  {TypeId::kNumberTypeUInt8, "uint8"},     {TypeId::kNumberTypeUInt16, "uint16"},
  {TypeId::kNumberTypeUInt32, "uint32"},   {TypeId::kNumberTypeUInt64, "uint64"},
  {TypeId::kNumberTypeBool, "int8"},       {TypeId::kNumberTypeInt4, "int4"},
  {TypeId::kNumberTypeComplex64, "complex64"},
  {TypeId::kNumberTypeComplex128, "complex128"},
  {TypeId::kMetaTypeNone, ""},
};
const std::unordered_map<std::string, size_t> type_nbyte_maps = {
{"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
@ -79,16 +37,16 @@ const std::unordered_map<std::string, size_t> type_nbyte_maps = {
};
// Translates a dtype string into the matching TypeId; raises via
// MS_LOG(EXCEPTION) on unknown input.
// NOTE(review): this span looks like a unified-diff fusion — both the old
// (type_str_id_maps) and new (type_id_maps) lookup lines are present, so the
// braces do not balance as plain C++; only one of the two lookups belongs in
// the applied revision. Verify against the merged source before editing.
TypeId DtypeToTypeId(const std::string &dtypes) {
auto iter = type_str_id_maps.find(dtypes);
if (iter == type_str_id_maps.end()) {
auto iter = type_id_maps.find(dtypes);
if (iter == type_id_maps.end()) {
MS_LOG(EXCEPTION) << "Illegal input device dtype: " << dtypes;
}
return iter->second;
}
std::string TypeIdToString(TypeId type_id) {
auto iter = type_id_str_maps.find(type_id);
if (iter == type_id_str_maps.end()) {
auto iter = type_id_str_map.find(type_id);
if (iter == type_id_str_map.end()) {
MS_LOG(EXCEPTION) << "Illegal input dtype: " << TypeIdLabel(type_id);
}
return iter->second;

View File

@ -55,8 +55,7 @@ void GetRealInputSize(const nlohmann::json &input_json, std::vector<size_t> *inp
input_size_list->push_back((*size_i));
}
void GetInputSizeList(const nlohmann::json &input_json, std::vector<size_t> *input_size_list,
const AnfNodePtr &anf_node) {
void GetInputSizeList(const nlohmann::json &input_json, std::vector<size_t> *input_size_list) {
for (size_t i = 0; i < input_json.size(); i++) {
if (input_json[i].is_array()) {
for (size_t m = 0; m < input_json[i].size(); m++) {
@ -103,8 +102,7 @@ void GetRealOutputSize(const nlohmann::json &output_json, std::vector<size_t> *o
output_size_list->push_back((*size_i));
}
void GetOutputSizeList(const nlohmann::json &output_json, std::vector<size_t> *output_size_list,
const AnfNodePtr &anf_node) {
void GetOutputSizeList(const nlohmann::json &output_json, std::vector<size_t> *output_size_list) {
for (size_t i = 0; i < output_json.size(); i++) {
if (output_json[i].is_array()) {
for (size_t m = 0; m < output_json[i].size(); m++) {
@ -140,8 +138,8 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<si
for (size_t i = 0; i < op_list.size(); i++) {
auto op_info = op_list[i];
if (op_info["type"] != "Data") {
GetInputSizeList(op_info["input_desc"], input_size_list, anf_node);
GetOutputSizeList(op_info["output_desc"], output_size_list, anf_node);
GetInputSizeList(op_info["input_desc"], input_size_list);
GetOutputSizeList(op_info["output_desc"], output_size_list);
}
}
return true;

View File

@ -1448,56 +1448,59 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPt
MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.cur_event_num();
}
// Collects, per stream id, the execution-order indices of every hcom node that
// belongs to the given communication group and graph. Returns a list of
// (stream id, node indices) pairs in first-seen stream order.
std::vector<std::pair<uint32_t, vector<size_t>>> AscendStreamAssign::GetStreamIDHcomMap(
  std::vector<CNodePtr> cnode_ptr_list, std::string group, size_t graph_id) {
  std::vector<std::pair<uint32_t, vector<size_t>>> hcom_index_map;
  for (size_t idx = 0; idx < cnode_ptr_list.size(); ++idx) {
    const auto &node = cnode_ptr_list[idx];
    if (!IsHcom(node)) {
      continue;
    }
    uint32_t stream_id = AnfAlgo::GetStreamId(node);
    auto node_group = GetHcomGroup(node);
    auto node_graph_id = AnfAlgo::GetGraphId(node.get());
    MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(node) << "; group:" << node_group
                 << "; stream id:" << stream_id;
    if (node_group != group || node_graph_id != graph_id) {
      continue;
    }
    // Append to the existing bucket for this stream, or open a new one.
    std::vector<size_t> *bucket = nullptr;
    for (auto &entry : hcom_index_map) {
      if (entry.first == stream_id) {
        bucket = &entry.second;
        break;
      }
    }
    if (bucket == nullptr) {
      hcom_index_map.emplace_back(stream_id, std::vector<size_t>{idx});
    } else {
      bucket->emplace_back(idx);
    }
  }
  return hcom_index_map;
}
void AscendStreamAssign::InsertEventHcomDependHcomAtSameGroup(
const NotNull<KernelGraphPtr> &graph_ptr, std::pair<std::string, std::map<uint32_t, std::set<uint32_t>>> group_item) {
for (const auto &graph_item : group_item.second) {
auto stream_indices = GetStreamIDHcomMap(graph_ptr->execution_order(), group_item.first, graph_item.first);
constexpr size_t kStreamMax = 2;
if (stream_indices.size() < kStreamMax) {
MS_LOG(INFO) << "Group:" << group_item.first << ", Graph: " << graph_item.first
<< " different stream hcom size is less than 2, no need insert event between them";
continue;
}
InsertEventBetweenHcom(graph_ptr, stream_indices);
}
}
// Orders hcom kernels of the same communication group that run on different
// streams by inserting events between them, for every group in
// group_hcom_graph_map_.
// NOTE(review): the inline per-graph scan below appears to duplicate the work
// delegated to InsertEventHcomDependHcomAtSameGroup (called at the end of the
// outer loop) — this looks like pre- and post-refactor bodies fused by the
// diff rendering; confirm against the applied revision before relying on it.
void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
if (group_hcom_graph_map_.empty()) {
return;
}
for (const auto &group_item : group_hcom_graph_map_) {
auto group = group_item.first;
for (const auto &graph_item : group_item.second) {
auto graph_id = graph_item.first;
auto cnode_ptr_list = graph_ptr->execution_order();
// Buckets of execution-order indices, keyed by stream id.
std::vector<std::pair<uint32_t, vector<size_t>>> stream_indices;
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
auto cur_cnode = cnode_ptr_list[i];
// Only hcom (communication) nodes participate.
if (!IsHcom(cur_cnode)) {
continue;
}
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
auto group_name = GetHcomGroup(cur_cnode);
auto cur_graph_id = AnfAlgo::GetGraphId(cur_cnode.get());
MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name
<< "; stream id:" << cur_stream_id;
// Skip nodes belonging to another group or another graph.
if (group_name != group || cur_graph_id != graph_id) {
continue;
}
if (stream_indices.empty()) {
stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
} else {
// Append to an existing stream bucket if one matches, else add a new one.
bool exit = false;
for (auto &item : stream_indices) {
if (item.first == cur_stream_id) {
item.second.emplace_back(i);
exit = true;
break;
}
}
if (!exit) {
stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
}
}
}
// Events are only needed when hcom nodes span at least two streams.
constexpr size_t kStreamMax = 2;
if (stream_indices.size() < kStreamMax) {
MS_LOG(INFO) << "Group:" << group
<< "; different stream hcom size is less than 2, no need insert event between them";
continue;
}
InsertEventBetweenHcom(graph_ptr, stream_indices);
}
InsertEventHcomDependHcomAtSameGroup(graph_ptr, group_item);
}
}

View File

@ -98,6 +98,10 @@ class AscendStreamAssign {
void InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr,
const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index);
void InsertEventHcomDependHcomAtSameGroup(const NotNull<KernelGraphPtr> &graph_ptr,
std::pair<std::string, std::map<uint32_t, std::set<uint32_t>>> group_item);
std::vector<std::pair<uint32_t, vector<size_t>>> GetStreamIDHcomMap(std::vector<CNodePtr> cnode_ptr_list,
std::string group, size_t graph_id);
void AdjustAtomicAddrCleanOrder(const NotNull<KernelGraphPtr> &graph_ptr);
vector<CNodePtr> GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &cur_cnode_ptr);