forked from mindspore-Ecosystem/mindspore
code clean r1.3
This commit is contained in:
parent
65dabb58ef
commit
214fdc2370
|
@ -113,6 +113,7 @@ std::map<std::string, uint32_t> AscendKernelRuntime::overflow_tasks_;
|
||||||
AscendKernelRuntime::~AscendKernelRuntime() {
|
AscendKernelRuntime::~AscendKernelRuntime() {
|
||||||
graph_model_map_.clear();
|
graph_model_map_.clear();
|
||||||
current_graph_ = nullptr;
|
current_graph_ = nullptr;
|
||||||
|
rt_context_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void AscendKernelRuntime::SetContext() {
|
void AscendKernelRuntime::SetContext() {
|
||||||
|
@ -405,7 +406,9 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
|
||||||
vector<std::shared_ptr<TaskInfo>> task_info_list;
|
vector<std::shared_ptr<TaskInfo>> task_info_list;
|
||||||
auto anf_node_list = graph->execution_order();
|
auto anf_node_list = graph->execution_order();
|
||||||
auto task_generator = TaskGenerator();
|
auto task_generator = TaskGenerator();
|
||||||
task_generator.GenTasks(anf_node_list, &task_info_list, graph->graph_id());
|
if (!task_generator.GenTasks(anf_node_list, &task_info_list, graph->graph_id())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
// Store the task_info_list
|
// Store the task_info_list
|
||||||
auto insert_ret = task_map_.insert(std::make_pair(graph->graph_id(), task_info_list));
|
auto insert_ret = task_map_.insert(std::make_pair(graph->graph_id(), task_info_list));
|
||||||
if (!insert_ret.second) {
|
if (!insert_ret.second) {
|
||||||
|
@ -741,7 +744,7 @@ bool AscendKernelRuntime::MemcpyAsync(void *dst, const void *src, uint64_t size,
|
||||||
|
|
||||||
void AscendKernelRuntime::CreateContext() {
|
void AscendKernelRuntime::CreateContext() {
|
||||||
if (rt_context_ == nullptr) {
|
if (rt_context_ == nullptr) {
|
||||||
auto ret = rtCtxCreate(&rt_context_, 0, device_id_);
|
auto ret = rtCtxCreate(&rt_context_, 0, UintToInt(device_id_));
|
||||||
if (ret != RT_ERROR_NONE) {
|
if (ret != RT_ERROR_NONE) {
|
||||||
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
|
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
|
||||||
}
|
}
|
||||||
|
@ -756,7 +759,7 @@ bool AscendKernelRuntime::InitDevice() {
|
||||||
MS_EXCEPTION(DeviceProcessError) << "Call rtGetDeviceCount, ret[" << static_cast<int>(ret) << "]";
|
MS_EXCEPTION(DeviceProcessError) << "Call rtGetDeviceCount, ret[" << static_cast<int>(ret) << "]";
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = rtSetDevice(device_id_);
|
ret = rtSetDevice(UintToInt(device_id_));
|
||||||
if (ret != RT_ERROR_NONE) {
|
if (ret != RT_ERROR_NONE) {
|
||||||
MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]";
|
MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]";
|
||||||
}
|
}
|
||||||
|
@ -811,7 +814,7 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
|
||||||
communication_stream_ = nullptr;
|
communication_stream_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = rtDeviceReset(device_id);
|
ret = rtDeviceReset(UintToInt(device_id));
|
||||||
if (ret != RT_ERROR_NONE) {
|
if (ret != RT_ERROR_NONE) {
|
||||||
MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]";
|
MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]";
|
||||||
}
|
}
|
||||||
|
@ -924,12 +927,14 @@ uint64_t AscendKernelRuntime::GetAvailableMemMaxSize() const {
|
||||||
return ascend_mem_manager->GetDeviceMemSize();
|
return ascend_mem_manager->GetDeviceMemSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AscendKernelRuntime::DeleteDumpDir(std::string path) {
|
bool AscendKernelRuntime::DeleteDumpDir(const std::string &path) {
|
||||||
string real_path = GetRealPath(path);
|
string real_path = GetRealPath(path);
|
||||||
if (DeleteDumpFile(real_path) == -1) {
|
if (DeleteDumpFile(real_path) == -1) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
rmdir(real_path.c_str());
|
if (rmdir(real_path.c_str()) == -1) {
|
||||||
|
MS_LOG(WARNING) << "Delete dir " << real_path << " failed!";
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -959,12 +964,14 @@ int AscendKernelRuntime::DeleteDumpFile(std::string path) {
|
||||||
rmdir(filepath.c_str());
|
rmdir(filepath.c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
closedir(dir);
|
if (closedir(dir) == -1) {
|
||||||
|
MS_LOG(WARNING) << "Dump dir " << path << " close failed!";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AscendKernelRuntime::GetRealPath(std::string path) {
|
std::string AscendKernelRuntime::GetRealPath(const std::string &path) {
|
||||||
char real_path_mem[kPathMax] = {0};
|
char real_path_mem[kPathMax] = {0};
|
||||||
char *real_path_ret = realpath(path.c_str(), real_path_mem);
|
char *real_path_ret = realpath(path.c_str(), real_path_mem);
|
||||||
if (real_path_ret == nullptr) {
|
if (real_path_ret == nullptr) {
|
||||||
|
|
|
@ -89,9 +89,9 @@ class AscendKernelRuntime : public KernelRuntime {
|
||||||
static void DumpTaskExceptionInfo(const session::KernelGraph *graph);
|
static void DumpTaskExceptionInfo(const session::KernelGraph *graph);
|
||||||
static void TaskFailCallback(rtExceptionInfo *task_fail_info);
|
static void TaskFailCallback(rtExceptionInfo *task_fail_info);
|
||||||
void ReportProfilingData();
|
void ReportProfilingData();
|
||||||
static bool DeleteDumpDir(std::string path);
|
static bool DeleteDumpDir(const std::string &path);
|
||||||
static int DeleteDumpFile(std::string path);
|
static int DeleteDumpFile(std::string path);
|
||||||
static std::string GetRealPath(std::string path);
|
static std::string GetRealPath(const std::string &path);
|
||||||
|
|
||||||
rtContext_t rt_context_{nullptr};
|
rtContext_t rt_context_{nullptr};
|
||||||
bool initialized_{false};
|
bool initialized_{false};
|
||||||
|
|
|
@ -137,7 +137,7 @@ std::vector<size_t> GetReducedFracNZShape(const std::vector<size_t> &ori_shape,
|
||||||
std::vector<size_t> result;
|
std::vector<size_t> result;
|
||||||
std::set<size_t> positive_idx;
|
std::set<size_t> positive_idx;
|
||||||
for (const auto &a : axis) {
|
for (const auto &a : axis) {
|
||||||
positive_idx.insert(a >= 0 ? a : ori_shape.size() + a);
|
positive_idx.insert(a >= 0 ? LongToSize(a) : ori_shape.size() + LongToSize(a));
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < ori_shape.size(); ++i) {
|
for (size_t i = 0; i < ori_shape.size(); ++i) {
|
||||||
if (positive_idx.count(i) == 0) {
|
if (positive_idx.count(i) == 0) {
|
||||||
|
@ -263,9 +263,9 @@ void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (can_convert) {
|
if (can_convert) {
|
||||||
graph_input_format->push_back(default_format);
|
graph_input_format->emplace_back(default_format);
|
||||||
} else {
|
} else {
|
||||||
graph_input_format->push_back(kOpFormat_DEFAULT);
|
graph_input_format->emplace_back(kOpFormat_DEFAULT);
|
||||||
}
|
}
|
||||||
graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
|
graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
|
||||||
continue;
|
continue;
|
||||||
|
@ -302,8 +302,7 @@ void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_index,
|
void UpdateEquivFormat(const std::vector<AnfNodePtr> &node_list, const FuncGraphPtr &func_graph,
|
||||||
const std::vector<AnfNodePtr> &node_list, const FuncGraphPtr &func_graph,
|
|
||||||
const FuncGraphManagerPtr &mng) {
|
const FuncGraphManagerPtr &mng) {
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
for (size_t i = 0; i < node_list.size(); ++i) {
|
for (size_t i = 0; i < node_list.size(); ++i) {
|
||||||
|
@ -324,7 +323,6 @@ void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_
|
||||||
if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) {
|
if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto infer_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
|
|
||||||
// Insert EquivFormat node, then select kernel info again
|
// Insert EquivFormat node, then select kernel info again
|
||||||
std::vector<AnfNodePtr> trans_inputs;
|
std::vector<AnfNodePtr> trans_inputs;
|
||||||
trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat));
|
trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat));
|
||||||
|
@ -510,8 +508,7 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func
|
||||||
if (mng == nullptr) {
|
if (mng == nullptr) {
|
||||||
mng = Manage(func_graph, true);
|
mng = Manage(func_graph, true);
|
||||||
}
|
}
|
||||||
auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
|
UpdateEquivFormat(node_list, func_graph, mng);
|
||||||
UpdateEquivFormat(output_index, node_list, func_graph, mng);
|
|
||||||
node_list.clear();
|
node_list.clear();
|
||||||
input_list.clear();
|
input_list.clear();
|
||||||
output_list.clear();
|
output_list.clear();
|
||||||
|
@ -526,7 +523,7 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func
|
||||||
// set fix_precision for kernel when the me prim has fix_precision attr
|
// set fix_precision for kernel when the me prim has fix_precision attr
|
||||||
UpdateKernelInfo(node_list);
|
UpdateKernelInfo(node_list);
|
||||||
|
|
||||||
output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
|
auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
|
||||||
SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type);
|
SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type);
|
||||||
}
|
}
|
||||||
} // namespace ascend
|
} // namespace ascend
|
||||||
|
|
Loading…
Reference in New Issue