code clean r1.3

This commit is contained in:
zhaosida 2021-08-31 16:55:29 +08:00
parent 65dabb58ef
commit 214fdc2370
3 changed files with 23 additions and 19 deletions

View File

@ -113,6 +113,7 @@ std::map<std::string, uint32_t> AscendKernelRuntime::overflow_tasks_;
AscendKernelRuntime::~AscendKernelRuntime() { AscendKernelRuntime::~AscendKernelRuntime() {
graph_model_map_.clear(); graph_model_map_.clear();
current_graph_ = nullptr; current_graph_ = nullptr;
rt_context_ = nullptr;
} }
void AscendKernelRuntime::SetContext() { void AscendKernelRuntime::SetContext() {
@ -405,7 +406,9 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
vector<std::shared_ptr<TaskInfo>> task_info_list; vector<std::shared_ptr<TaskInfo>> task_info_list;
auto anf_node_list = graph->execution_order(); auto anf_node_list = graph->execution_order();
auto task_generator = TaskGenerator(); auto task_generator = TaskGenerator();
task_generator.GenTasks(anf_node_list, &task_info_list, graph->graph_id()); if (!task_generator.GenTasks(anf_node_list, &task_info_list, graph->graph_id())) {
return false;
}
// Store the task_info_list // Store the task_info_list
auto insert_ret = task_map_.insert(std::make_pair(graph->graph_id(), task_info_list)); auto insert_ret = task_map_.insert(std::make_pair(graph->graph_id(), task_info_list));
if (!insert_ret.second) { if (!insert_ret.second) {
@ -741,7 +744,7 @@ bool AscendKernelRuntime::MemcpyAsync(void *dst, const void *src, uint64_t size,
void AscendKernelRuntime::CreateContext() { void AscendKernelRuntime::CreateContext() {
if (rt_context_ == nullptr) { if (rt_context_ == nullptr) {
auto ret = rtCtxCreate(&rt_context_, 0, device_id_); auto ret = rtCtxCreate(&rt_context_, 0, UintToInt(device_id_));
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]"; MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
} }
@ -756,7 +759,7 @@ bool AscendKernelRuntime::InitDevice() {
MS_EXCEPTION(DeviceProcessError) << "Call rtGetDeviceCount, ret[" << static_cast<int>(ret) << "]"; MS_EXCEPTION(DeviceProcessError) << "Call rtGetDeviceCount, ret[" << static_cast<int>(ret) << "]";
} }
ret = rtSetDevice(device_id_); ret = rtSetDevice(UintToInt(device_id_));
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]"; MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]";
} }
@ -811,7 +814,7 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
communication_stream_ = nullptr; communication_stream_ = nullptr;
} }
ret = rtDeviceReset(device_id); ret = rtDeviceReset(UintToInt(device_id));
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]"; MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]";
} }
@ -924,12 +927,14 @@ uint64_t AscendKernelRuntime::GetAvailableMemMaxSize() const {
return ascend_mem_manager->GetDeviceMemSize(); return ascend_mem_manager->GetDeviceMemSize();
} }
bool AscendKernelRuntime::DeleteDumpDir(std::string path) { bool AscendKernelRuntime::DeleteDumpDir(const std::string &path) {
string real_path = GetRealPath(path); string real_path = GetRealPath(path);
if (DeleteDumpFile(real_path) == -1) { if (DeleteDumpFile(real_path) == -1) {
return false; return false;
} }
rmdir(real_path.c_str()); if (rmdir(real_path.c_str()) == -1) {
MS_LOG(WARNING) << "Delete dir " << real_path << " failed!";
}
return true; return true;
} }
@ -959,12 +964,14 @@ int AscendKernelRuntime::DeleteDumpFile(std::string path) {
rmdir(filepath.c_str()); rmdir(filepath.c_str());
} }
} }
closedir(dir); if (closedir(dir) == -1) {
MS_LOG(WARNING) << "Dump dir " << path << " close failed!";
}
} }
return result; return result;
} }
std::string AscendKernelRuntime::GetRealPath(std::string path) { std::string AscendKernelRuntime::GetRealPath(const std::string &path) {
char real_path_mem[kPathMax] = {0}; char real_path_mem[kPathMax] = {0};
char *real_path_ret = realpath(path.c_str(), real_path_mem); char *real_path_ret = realpath(path.c_str(), real_path_mem);
if (real_path_ret == nullptr) { if (real_path_ret == nullptr) {

View File

@ -89,9 +89,9 @@ class AscendKernelRuntime : public KernelRuntime {
static void DumpTaskExceptionInfo(const session::KernelGraph *graph); static void DumpTaskExceptionInfo(const session::KernelGraph *graph);
static void TaskFailCallback(rtExceptionInfo *task_fail_info); static void TaskFailCallback(rtExceptionInfo *task_fail_info);
void ReportProfilingData(); void ReportProfilingData();
static bool DeleteDumpDir(std::string path); static bool DeleteDumpDir(const std::string &path);
static int DeleteDumpFile(std::string path); static int DeleteDumpFile(std::string path);
static std::string GetRealPath(std::string path); static std::string GetRealPath(const std::string &path);
rtContext_t rt_context_{nullptr}; rtContext_t rt_context_{nullptr};
bool initialized_{false}; bool initialized_{false};

View File

@ -137,7 +137,7 @@ std::vector<size_t> GetReducedFracNZShape(const std::vector<size_t> &ori_shape,
std::vector<size_t> result; std::vector<size_t> result;
std::set<size_t> positive_idx; std::set<size_t> positive_idx;
for (const auto &a : axis) { for (const auto &a : axis) {
positive_idx.insert(a >= 0 ? a : ori_shape.size() + a); positive_idx.insert(a >= 0 ? LongToSize(a) : ori_shape.size() + LongToSize(a));
} }
for (size_t i = 0; i < ori_shape.size(); ++i) { for (size_t i = 0; i < ori_shape.size(); ++i) {
if (positive_idx.count(i) == 0) { if (positive_idx.count(i) == 0) {
@ -263,9 +263,9 @@ void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNo
} }
} }
if (can_convert) { if (can_convert) {
graph_input_format->push_back(default_format); graph_input_format->emplace_back(default_format);
} else { } else {
graph_input_format->push_back(kOpFormat_DEFAULT); graph_input_format->emplace_back(kOpFormat_DEFAULT);
} }
graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i)); graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
continue; continue;
@ -302,8 +302,7 @@ void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNo
} }
} }
void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_index, void UpdateEquivFormat(const std::vector<AnfNodePtr> &node_list, const FuncGraphPtr &func_graph,
const std::vector<AnfNodePtr> &node_list, const FuncGraphPtr &func_graph,
const FuncGraphManagerPtr &mng) { const FuncGraphManagerPtr &mng) {
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
for (size_t i = 0; i < node_list.size(); ++i) { for (size_t i = 0; i < node_list.size(); ++i) {
@ -324,7 +323,6 @@ void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_
if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) { if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) {
continue; continue;
} }
auto infer_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
// Insert EquivFormat node, then select kernel info again // Insert EquivFormat node, then select kernel info again
std::vector<AnfNodePtr> trans_inputs; std::vector<AnfNodePtr> trans_inputs;
trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat)); trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat));
@ -510,8 +508,7 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func
if (mng == nullptr) { if (mng == nullptr) {
mng = Manage(func_graph, true); mng = Manage(func_graph, true);
} }
auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list); UpdateEquivFormat(node_list, func_graph, mng);
UpdateEquivFormat(output_index, node_list, func_graph, mng);
node_list.clear(); node_list.clear();
input_list.clear(); input_list.clear();
output_list.clear(); output_list.clear();
@ -526,7 +523,7 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func
// set fix_precision for kernel when the me prim has fix_precision attr // set fix_precision for kernel when the me prim has fix_precision attr
UpdateKernelInfo(node_list); UpdateKernelInfo(node_list);
output_index = kernel::GetOutputIndex(node_list, input_list, output_list); auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type); SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type);
} }
} // namespace ascend } // namespace ascend