Profiler: add PyNative-mode Python profiling code

臧庆香 2022-03-06 12:02:33 +08:00
parent ad808b3b30
commit 1dd489096e
8 changed files with 531 additions and 254 deletions

View File

@ -848,11 +848,13 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<Addr
auto stream = GetKernelStream(kernel);
#ifndef ENABLE_SECURITY
auto profiler_inst = profiler::ascend::PynativeProfiler::GetInstance();
(void)profiler_inst->OpDataProducerBegin(runtime_instance_, stream, kernel->fullname_with_scope());
MS_EXCEPTION_IF_NULL(profiler_inst);
std::thread::id t_id = std::this_thread::get_id();
(void)profiler_inst->OpDataProducerBegin(runtime_instance_, stream, t_id, kernel->fullname_with_scope());
#endif
ret = kernel_mod->Launch(real_inputs, workspace, outputs, stream);
#ifndef ENABLE_SECURITY
(void)profiler_inst->OpDataProducerEnd();
(void)profiler_inst->OpDataProducerEnd(t_id);
#endif
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed, kernel full name: " << kernel->fullname_with_scope();

View File

@ -19,6 +19,7 @@
#include <memory>
#include <algorithm>
#include "include/common/utils/utils.h"
#include "profiler/device/profiling.h"
#include "profiler/device/ascend/pynative_profiling.h"
#include "pybind_api/api_register.h"
@ -65,7 +66,7 @@ void PynativeProfiler::WriteStartTime() {
MS_LOG(ERROR) << "Write " << file_path << "failed:" << e.what();
}
ofs.close();
ChangeFileMode(file_path);
ChangeFileMode(file_path, S_IRUSR | S_IWUSR);
MS_LOG(INFO) << "Write profiler start time infos into file: " << file_path;
}
@ -73,50 +74,59 @@ void PynativeProfiler::SaveProfileData() { WriteOpDetail(profile_data_path_); }
void PynativeProfiler::ClearInst() { pynative_op_info_.clear(); }
void PynativeProfiler::OpDataProducerEnd() {}
void PynativeProfiler::OpDataProducerBegin(AscendKernelRuntime *runtime_instance_, void *stream,
const std::string &op_name) {
std::thread::id thread_id, const std::string &op_name) {
if (enable_flag_ == false) {
return;
}
MS_EXCEPTION_IF_NULL(runtime_instance_);
start = runtime_instance_->CreateDeviceTimeEvent();
end = runtime_instance_->CreateDeviceTimeEvent();
MS_EXCEPTION_IF_NULL(stream);
std::shared_ptr<DeviceEvent> start = runtime_instance_->CreateDeviceTimeEvent();
std::shared_ptr<DeviceEvent> end = runtime_instance_->CreateDeviceTimeEvent();
MS_EXCEPTION_IF_NULL(start);
MS_EXCEPTION_IF_NULL(end);
start->set_record_stream(stream);
end->set_record_stream(stream);
start->RecordEvent();
op_name_ = op_name;
stream_ = stream;
PynativeOpInfo op_info;
op_info.start = start;
op_info.end = end;
op_info.op_name = op_name;
op_info.stream = stream;
if (thread_op_info_map_.find(thread_id) == thread_op_info_map_.end()) {
op_info.thread_index = NewThreadIndex();
} else {
op_info.thread_index = thread_op_info_map_[thread_id].thread_index;
}
thread_op_info_map_[thread_id] = op_info;
}
void PynativeProfiler::StepProfilingEnable(const bool enable_flag) { enable_flag_ = enable_flag; }
void PynativeProfiler::OpDataProducerEnd() {
void PynativeProfiler::OpDataProducerEnd(std::thread::id thread_id) {
if (enable_flag_ == false) {
return;
}
if (start == nullptr || end == nullptr) {
MS_LOG(WARNING) << "Pynative profiling, the start or end time of op is null"
if (thread_op_info_map_.find(thread_id) == thread_op_info_map_.end()) {
MS_LOG(WARNING) << "Pynative profiling, the start time of op is null"
<< ", please call the OpDataProducerBegin function first.";
return;
}
PynativeOpInfo op_info = thread_op_info_map_[thread_id];
float cost_time = 0;
// Operator execution is asynchronous; synchronize events here to measure the elapsed time
end->RecordEvent();
start->SyncEvent();
end->SyncEvent();
start->ElapsedTime(&cost_time, end.get());
op_info.end->RecordEvent();
op_info.start->SyncEvent();
op_info.end->SyncEvent();
op_info.start->ElapsedTime(&cost_time, op_info.end.get());
PynativeOpInfo op_info;
op_info.op_name = op_name_;
op_info.stream = stream_;
op_info.duration = cost_time;
int64_t milli_second_ratio = 1000;
int64_t end_timestamp = GetRealTimeStamp();
int64_t start_timestamp = end_timestamp - static_cast<int64_t>(cost_time * milli_second_ratio);
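OpDataProducerEnd derives the host-side start timestamp by subtracting the device-measured duration from the host end time; cost_time is in milliseconds while the timestamp appears to be in microseconds, hence the factor of 1000. A worked example under that assumption (values illustrative):

    # Assuming GetRealTimeStamp() returns microseconds and cost_time is in ms:
    end_timestamp = 1_646_535_000_123_456            # us
    cost_time = 2.5                                  # ms, from start->ElapsedTime(...)
    start_timestamp = end_timestamp - int(cost_time * 1000)
    # -> 1_646_535_000_120_956 us, i.e. 2.5 ms before the end timestamp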
@ -144,23 +154,18 @@ void PynativeProfiler::WriteOpDetail(const std::string &out_path_dir) {
std::sort(pynative_op_info_.begin(), pynative_op_info_.end(),
[](const auto &op1, const auto &op2) { return op1.start_timestamp < op2.start_timestamp; });
for (PynativeOpInfo op_info : pynative_op_info_) {
ofs << op_info.op_name << ",0," << std::to_string(op_info.start_timestamp) << "," << op_info.duration
<< std::endl;
ofs << op_info.op_name << "," << op_info.thread_index << "," << std::to_string(op_info.start_timestamp) << ","
<< op_info.duration << std::endl;
}
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Write " << file_path << "failed: " << e.what();
}
ofs.close();
ChangeFileMode(file_path);
ChangeFileMode(file_path, S_IRUSR | S_IWUSR);
MS_LOG(INFO) << "Write " << pynative_op_info_.size() << " op detail infos into file: " << file_path;
}
void PynativeProfiler::ChangeFileMode(const std::string &file_path) const {
if (chmod(common::SafeCStr(file_path), S_IRUSR | S_IWUSR) == -1) {
MS_LOG(WARNING) << "Modify file: " << file_path << " to rw fail.";
return;
}
}
int PynativeProfiler::NewThreadIndex() { return thread_op_info_map_.size() + 1; }
REGISTER_PYBIND_DEFINE(PynativeProfiler_, ([](const py::module *m) {
(void)py::class_<PynativeProfiler, std::shared_ptr<PynativeProfiler>>(*m, "PynativeProfiler")
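The REGISTER_PYBIND_DEFINE block exposes the C++ singleton to Python as c_expression.PynativeProfiler. Based on the profiling.py changes later in this commit, the Python side drives it roughly like this:

    pynative_profiler = c_expression.PynativeProfiler.get_instance()
    pynative_profiler.init(output_path)   # presumably records the start time and enables collection
    # ... run PyNative ops ...
    pynative_profiler.stop()              # flushes the per-op details to the output directory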

View File

@ -33,11 +33,14 @@ using mindspore::device::ascend::AscendKernelRuntime;
struct PynativeOpInfo {
std::string op_name;
int thread_index;
// the unit is ms
double_t start_timestamp = 0l;
// the unit is ms
double_t duration = 0l;
void *stream;
void *stream{nullptr};
std::shared_ptr<DeviceEvent> start;
std::shared_ptr<DeviceEvent> end;
};
class MS_CORE_API PynativeProfiler : public Profiler {
@ -47,28 +50,26 @@ class MS_CORE_API PynativeProfiler : public Profiler {
~PynativeProfiler() {}
void Init(const std::string &profileDataPath) override;
void Stop() override;
void OpDataProducerBegin(AscendKernelRuntime *runtime_instance_, void *stream, const std::string &op_name);
void OpDataProducerEnd();
void OpDataProducerBegin(AscendKernelRuntime *runtime_instance_, void *stream, std::thread::id thread_id,
const std::string &op_name);
void OpDataProducerEnd() override;
void OpDataProducerEnd(std::thread::id thread_id);
void StepProfilingEnable(const bool enable_flag) override;
private:
void WriteOpDetail(const std::string &out_path_dir);
void ChangeFileMode(const std::string &file_path) const;
void WriteStartTime();
void SaveProfileData() override;
void ClearInst() override;
int NewThreadIndex();
static std::shared_ptr<PynativeProfiler> profiler_inst_;
std::shared_ptr<DeviceEvent> start;
std::shared_ptr<DeviceEvent> end;
std::int32_t rank_id_;
std::vector<PynativeOpInfo> pynative_op_info_;
std::string op_name_;
bool enable_flag_ = false;
float start_timestamp;
void *stream_;
const uint64_t kUSecondInSecond = 1000000;
const uint64_t milli_second_ratio = 1000;
std::map<std::thread::id, PynativeOpInfo> thread_op_info_map_;
};
} // namespace ascend
} // namespace profiler

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@ from mindspore.profiler.common.exceptions.exceptions import ProfilerIOException,
from mindspore.profiler.common.util import query_latest_trace_time_file, to_int, to_millisecond
from mindspore.profiler.common.validator.validate_path import validate_and_normalize_path
from mindspore.profiler.parser.container import TimelineContainer
from mindspore.profiler.parser.op_intermediate_parser import OPIntermediateParser
SIZE_LIMIT_DEFAULT = 20 * 1024 * 1024 # 20MB
@ -72,21 +73,26 @@ class Integrator:
self._parse_aicpu_time()
def get_aicore_data(self):
"""Get ai core data."""
self._aicore_data_load()
return self._aicore_data
def get_aicore_detail_data(self):
"""Get ai core detail data."""
self._aicore_detail_data_load()
return self._aicore_detail_data
def get_aicore_trace_data(self):
"""Get ai core trace data."""
self._aicore_trace_data_load()
return self._aicore_trace_data
def query_for_all_reduce(self):
"""Query all reduce data."""
return self._query_for_all_reduce()
def query_and_sort_by_op_type(self, filter_condition, op_type_order):
"""Query and sort by op type."""
return self._query_and_sort_by_op_type(filter_condition, op_type_order)
def _parse_aicore_type_time(self):
@ -103,18 +109,19 @@ class Integrator:
with open(framework_file, 'r') as src_file:
csv_reader = csv.reader(src_file)
_ = next(csv_reader)
for row in csv_reader:
op_name_type_cache[row[3]] = row[5]
op_type_time_cache = {}
for full_op_name, op_time in self._op_time_cache.items():
op_type = op_name_type_cache.get(full_op_name)
if op_type_time_cache.get(op_type) is None:
op_type_time_cache[op_type] = [op_time, 1]
op_type_time = op_type_time_cache.get(op_type)
if not op_type_time:
op_type_time = [op_time, 1]
op_type_time_cache[op_type] = op_type_time
else:
op_type_time_cache[op_type][0] += op_time
op_type_time_cache[op_type][1] += 1
op_type_time[0] += op_time
op_type_time[1] += 1
if self._total_time == 0:
raise ValueError("The total time of operations can not be 0.")
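The refactor above fetches the [total_time, count] bucket once and mutates it in place instead of indexing the dict twice per row. An equivalent standalone sketch with illustrative data:

    op_times = {"MatMul-1": 3.0, "MatMul-2": 5.0, "ReLU-1": 1.0}
    op_types = {"MatMul-1": "MatMul", "MatMul-2": "MatMul", "ReLU-1": "ReLU"}

    op_type_time_cache = {}
    for full_op_name, op_time in op_times.items():
        op_type = op_types[full_op_name]
        op_type_time = op_type_time_cache.get(op_type)
        if not op_type_time:
            op_type_time_cache[op_type] = [op_time, 1]
        else:
            op_type_time[0] += op_time
            op_type_time[1] += 1
    # op_type_time_cache == {"MatMul": [8.0, 2], "ReLU": [1.0, 1]}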
@ -529,6 +536,9 @@ class BaseTimelineGenerator:
_SCOPE_NAME_TID = 100001
_GPU_OP_TID = 100002
_HOST_CPU_OP_TID = 100003
_SINGLE_TID = 0
_STEPS_SORT_INDEX = -1
_map_tid_name_to_int = {
"Steps": (-4, _STEPS_TID),
@ -547,6 +557,11 @@ class BaseTimelineGenerator:
_max_scope_name_num = 0
_host_cpu_op_label = 'HostCpuOps'
_device_id = 0
_profiling_dir = ""
_timeline_summary_filename = ""
_display_filename = ""
def __init__(self):
self._tid_dict = {
"computation_op": (self._MERGED_COMPUTATION_TID, self._OP_OVERLAP_PID),
@ -584,6 +599,10 @@ class BaseTimelineGenerator:
"args": {"name": "Merged Communication Op"}},
{"name": "thread_name", "ph": "M", "pid": self._OP_OVERLAP_PID, "tid": self._FREE_TIME_TID,
"args": {"name": "Free Time"}},
{"name": "thread_name", "ph": "M", "pid": self._device_id, "tid": self._STEPS_TID,
"args": {"name": "Steps"}},
{"name": "thread_name", "ph": "M", "pid": self._device_id, "tid": self._SINGLE_TID,
"args": {"name": "Ops"}},
{"name": "thread_sort_index", "ph": "M", "pid": self._OP_OVERLAP_PID, "tid": self._MERGED_COMPUTATION_TID,
"args": {"sort_index": self._MERGED_COMPUTATION_TID}},
@ -592,7 +611,9 @@ class BaseTimelineGenerator:
{"name": "thread_sort_index", "ph": "M", "pid": self._OP_OVERLAP_PID, "tid": self._MERGED_COMMUNICATION_TID,
"args": {"sort_index": self._MERGED_COMMUNICATION_TID}},
{"name": "thread_sort_index", "ph": "M", "pid": self._OP_OVERLAP_PID, "tid": self._FREE_TIME_TID,
"args": {"sort_index": self._FREE_TIME_TID}}
"args": {"sort_index": self._FREE_TIME_TID}},
{"name": "thread_sort_index", "ph": "M", "pid": self._device_id, "tid": self._STEPS_TID,
"args": {"sort_index": self._STEPS_SORT_INDEX}},
]
def _get_merged_time_list(self, time_list, get_interval_time=False, display_name="computation_op", factor=1):
@ -845,6 +866,58 @@ class BaseTimelineGenerator:
return step_time_list
@staticmethod
def get_parallel_context():
"""Get parallel context."""
try:
parallel_mode = get_auto_parallel_context("parallel_mode")
stage_num = get_auto_parallel_context("pipeline_stages")
except RuntimeError:
logger.warning("[profiler] the feature of cluster bottleneck analyse "
"is not supported in offline parse mode.")
parallel_mode = "data_parallel"
stage_num = 1
if stage_num > 1:
parallel_mode = "pipeline-parallel"
elif parallel_mode != "data_parallel":
parallel_mode = "model-parallel"
else:
parallel_mode = "data-parallel"
return parallel_mode, stage_num
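Hoisting this resolution into a static helper lets the GPU and Ascend generators share it instead of each carrying a private copy (the duplicated copies are deleted further down in this commit). Typical usage:

    parallel_mode, stage_num = BaseTimelineGenerator.get_parallel_context()
    # e.g. ("pipeline-parallel", 4) when pipeline_stages > 1,
    # or ("data-parallel", 1) in the offline-parse fallback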
def _write_cluster_metrices(self, metrices, is_pipeline_parallel, device_target, dev_id):
"""Write cluster metric."""
# Note that cluster bottleneck analysis is not supported in offline parse mode,
# because the parallel context is not set.
parallel_mode, stage_num = BaseTimelineGenerator.get_parallel_context()
unit = 1 if device_target == "Ascend" else 1e3
time_decimal_digits = 4
cluster_analyse_file_path = os.path.join(
self._profiling_dir,
self._cluster_analyse_filename.format(parallel_mode, stage_num, self._rank_size, dev_id)
)
cluster_analyse_file_path = validate_and_normalize_path(cluster_analyse_file_path)
try:
with open(cluster_analyse_file_path, 'w') as file_handle:
csv_writer = csv.writer(file_handle)
if is_pipeline_parallel:
header = ['computation_time', 'communication_alone_time', 'stage_time',
'receive_alone_time', 'collective_communication_alone_time']
zip_metrices = zip(metrices[0], metrices[1], metrices[2], metrices[3], metrices[4])
else:
header = ['computation_time', 'communication_alone_time']
zip_metrices = zip(metrices[0], metrices[1])
csv_writer.writerow(header)
for row_data in zip_metrices:
row_data = [round(val / unit, time_decimal_digits) for val in row_data]
csv_writer.writerow(row_data)
os.chmod(cluster_analyse_file_path, stat.S_IREAD | stat.S_IWRITE)
except (IOError, OSError) as err:
logger.warning(f'Failed to save {cluster_analyse_file_path}. {err}')
raise ProfilerIOException
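With device_target and dev_id passed in, one writer now serves both backends; only the time unit differs (raw values for Ascend, divided by 1e3 for GPU). The pipeline-parallel CSV it produces looks like this (values illustrative):

    computation_time,communication_alone_time,stage_time,receive_alone_time,collective_communication_alone_time
    95.1234,4.5678,99.0001,0.5123,4.0551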
class GpuTimelineGenerator(BaseTimelineGenerator):
"""Generate gpu Timeline data from file."""
@ -1243,7 +1316,7 @@ class GpuTimelineGenerator(BaseTimelineGenerator):
if step_num > 1:
for metric in metrices_per_step_list:
metric.append(sum(metric[1:]) / (step_num - 1))
self._write_cluster_metrices(metrices_per_step_list, is_pipeline_parallel)
self._write_cluster_metrices(metrices_per_step_list, is_pipeline_parallel, "Gpu", self._device_id)
res_timeline = []
res_timeline.extend(comm_not_overlapped_timeline)
@ -1252,49 +1325,6 @@ class GpuTimelineGenerator(BaseTimelineGenerator):
res_timeline.extend(free_timeline)
return res_timeline
def _write_cluster_metrices(self, metrices, is_pipeline_parallel):
"""Write cluster metric."""
# Note that the feature of cluster bottleneck analyse is not supported in offline parse mode,
# due to that parallel context is not set.
try:
parallel_mode = get_auto_parallel_context("parallel_mode")
stage_num = get_auto_parallel_context("pipeline_stages")
except RuntimeError:
logger.warning("[profiler] the feature of cluster bottleneck analyse "
"is not supported in offline parse mode.")
parallel_mode = "data_parallel"
stage_num = 1
if stage_num > 1:
parallel_mode = "pipeline-parallel"
elif parallel_mode != "data_parallel":
parallel_mode = "model-parallel"
else:
parallel_mode = "data-parallel"
cluster_analyse_file_path = os.path.join(
self._profiling_dir,
self._cluster_analyse_filename.format(parallel_mode, stage_num, self._rank_size, self._device_id))
cluster_analyse_file_path = validate_and_normalize_path(cluster_analyse_file_path)
try:
with open(cluster_analyse_file_path, 'w') as file_handle:
csv_writer = csv.writer(file_handle)
if is_pipeline_parallel:
header = ['computation_time', 'communication_alone_time', 'stage_time',
'receive_alone_time', 'collective_communication_alone_time']
zip_metrices = zip(metrices[0], metrices[1], metrices[2], metrices[3], metrices[4])
else:
header = ['computation_time', 'communication_alone_time']
zip_metrices = zip(metrices[0], metrices[1])
csv_writer.writerow(header)
for row_data in zip_metrices:
row_data = [round(val / 1e3, 4) for val in row_data]
csv_writer.writerow(row_data)
os.chmod(cluster_analyse_file_path, stat.S_IREAD | stat.S_IWRITE)
except (IOError, OSError) as err:
logger.warning(f'Failed to save {cluster_analyse_file_path}. {err}')
raise ProfilerIOException
def _compute_time_inside_step(self, metric_timeline, step_time_list):
"""Compute per step time of metric_timeline."""
per_step_time_list = []
@ -1379,33 +1409,6 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
self._display_filename = self._display_filename.format(rank_id)
self._timeline_summary_filename = self._timeline_summary_filename.format(rank_id)
def _load_timeline_data(self, all_reduce_names=None):
"""Load timeline data from file."""
all_reduce_names = all_reduce_names or []
file_path = os.path.join(
self._profiling_dir,
self._output_timeline_data_file_path.format(self._rank_id)
)
file_path = validate_and_normalize_path(file_path)
if not os.path.exists(file_path):
logger.critical("Failed to find parsed timeline file.")
raise ProfilerFileNotFoundException('parsed timeline file')
timeline_list = []
try:
with open(file_path, 'r') as f_obj:
for line in f_obj:
line_list = line.strip('\n').split(',')
if line_list[0] == 'op_name' or line_list[0] in all_reduce_names:
continue
line_list[self._tid_idx] = f"Stream #{line_list[self._tid_idx]}"
timeline_list.append(line_list)
except (IOError, OSError) as err:
logger.critical('Error occurred when read timeline intermediate file: %s', err)
raise ProfilerIOException()
return timeline_list
def _parse_timeline_data(self, timeline, min_cycle_counter):
"""Parse timeline data."""
# factor to convert the time unit from 1ms to 1us for timeline display
@ -1436,6 +1439,34 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
self._update_format_meta_data(timeline_dict)
self._timeline_meta.append(timeline_dict)
@staticmethod
def _get_all_reduce_names(communication_info):
names = []
for info in communication_info:
# all_reduce_name format: stream_{stream_id}_{stream_op_index}_{opname}
all_reduce_name = info[0][info[0].rindex('_') + 1:]
if all_reduce_name not in names:
names.append(all_reduce_name)
return names
def _get_op_timeline(self, communication_info, source_path):
"""get ai_core and cpu timeline."""
all_reduce_names = AscendTimelineGenerator._get_all_reduce_names(communication_info)
timeline_list = OPIntermediateParser(self._profiling_dir, self._rank_id).get_timeline_data(all_reduce_names)
for timeline in timeline_list:
timeline[self._tid_idx] = f"Stream #{timeline[self._tid_idx]}"
cpu_timeline_generator = CpuTimelineGenerator(self._profiling_dir, self._rank_id, self._rank_size)
cpu_timeline_list = cpu_timeline_generator.get_timeline_data()
if cpu_timeline_list:
self._clock_synchronize_to_device(cpu_timeline_list, source_path)
timeline_list.extend(cpu_timeline_list)
timeline_list.sort(key=lambda x: float(x[self._start_time_idx]))
self._max_scope_name_num = self._get_max_scope_name_num(timeline_list)
self._timeline_summary['op_exe_times'] = len(timeline_list)
self._timeline_summary['max_scope_name_num'] = self._max_scope_name_num
return timeline_list
def init_timeline(self, communication_info, framework_info, aicpu_info, min_cycle_counter, source_path):
"""
Init timeline metadata, adding all collected info.
@ -1451,23 +1482,9 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
min_cycle_counter = 0
logger.info('Initiating timeline...')
all_reduce_names = []
for info in communication_info:
# stream_{stream_id}_{stream_op_index}_{opname}
all_reduce_name = info[0][info[0].rindex('_') + 1:]
if all_reduce_name not in all_reduce_names:
all_reduce_names.append(all_reduce_name)
timeline_list = self._load_timeline_data(all_reduce_names)
cpu_timeline_generator = CpuTimelineGenerator(self._profiling_dir, self._rank_id, self._rank_size)
cpu_timeline_list = cpu_timeline_generator.get_timeline_data()
if cpu_timeline_list:
self._clock_synchronize_to_device(cpu_timeline_list, source_path)
timeline_list.extend(cpu_timeline_list)
timeline_list.sort(key=lambda x: float(x[self._start_time_idx]))
self._max_scope_name_num = self._get_max_scope_name_num(timeline_list)
self._timeline_summary['op_exe_times'] = len(timeline_list)
self._timeline_summary['max_scope_name_num'] = self._max_scope_name_num
timeline_list = []
op_timeline_list = self._get_op_timeline(communication_info, source_path)
timeline_list.extend(op_timeline_list)
# Generate step time.
self._set_step_start_and_end_op_name(timeline_list)
@ -1571,14 +1588,15 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
op_type = framework_obj[1]
op_full_name = framework_obj[4]
op_info = framework_obj[5]
framework_info_dict[op_full_name] = {
framework_info = {
'name': op_name,
'args': {
'type': op_type,
'fullname': op_full_name
}
}
framework_info_dict[op_full_name]['args'].update(op_info)
framework_info.get('args').update(op_info)
framework_info_dict[op_full_name] = framework_info
# Insert framework info into timeline.
for timeline_item in self._timeline_meta:
@ -1616,136 +1634,91 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
communication time between stages slows down the training. The value of t3 indicates the degree
to which communication inside each stage slows down the training.
"""
step_num = len(step_info)
is_pipeline_parallel = False
comm_merged_timeline, _, comm_display_timeline = self._get_merged_time_list(
comm_info,
display_name="communication"
comm_info, display_name="communication"
)
aicore_timeline_interval, _, aicore_display_timeline = self._get_merged_time_list(
aicore_info,
get_interval_time=True
aicore_info, get_interval_time=True
)
# Consider if the overlap will be 0 or not.
comm_not_overlapped_timeline = self._get_intersection_time(
aicore_timeline_interval,
comm_merged_timeline
aicore_timeline_interval, comm_merged_timeline
)
# Process receive part.
all_timeline = aicore_info + comm_info
all_timeline.sort(key=lambda x: float(x[self._start_time_idx]))
receive_op_timeline, timeline_exclude_receive_op = self._produce_two_separated_timeline(
all_timeline,
"Receive-op"
all_timeline, "Receive-op"
)
if receive_op_timeline:
is_pipeline_parallel = True
receive_op_merged_timeline = self._get_merged_time_list(receive_op_timeline)[0]
timeline_exclude_receive_op_interval = self._get_merged_time_list(
timeline_exclude_receive_op,
get_interval_time=True
timeline_exclude_receive_op, get_interval_time=True
)[0]
receive_op_not_overlapped_timeline = self._get_intersection_time(
timeline_exclude_receive_op_interval,
receive_op_merged_timeline
timeline_exclude_receive_op_interval, receive_op_merged_timeline
)
# Process collective communication part.
collective_comm_timeline = self._produce_two_separated_timeline(
comm_info,
"Receive-op"
comm_info, "Receive-op"
)[-1]
collective_comm_merged_timeline = self._get_merged_time_list(collective_comm_timeline)[0]
collective_comm_not_overlapped_timeline = self._get_intersection_time(
aicore_timeline_interval,
collective_comm_merged_timeline
aicore_timeline_interval, collective_comm_merged_timeline
)
# Generate free time that exclude computation and communication time.
free_timeline = self._get_merged_time_list(
all_timeline,
get_interval_time=True,
display_name="free_time")[1]
all_timeline, get_interval_time=True, display_name="free_time"
)[1]
# Compute these five metrics mentioned above per step.
recieve_alone_time = self._compute_time_inside_step(receive_op_not_overlapped_timeline, step_info)
stage_time, computation_time = [], []
comm_alone_time = self._compute_time_inside_step(comm_not_overlapped_timeline, step_info)
collective_comm_alone_time = self._compute_time_inside_step(
collective_comm_not_overlapped_timeline,
step_info
)
for step in range(step_num):
try:
if is_pipeline_parallel:
stage_time.append(step_info[step][self._duration_idx] - recieve_alone_time[step])
computation_time.append(step_info[step][self._duration_idx] - comm_alone_time[step])
except IndexError as e:
logger.error(e)
metrices_per_step_list = [computation_time, comm_alone_time, stage_time, recieve_alone_time, \
collective_comm_alone_time]
if step_num > 1:
for metric in metrices_per_step_list:
metric.append(sum(metric[1:]) / (step_num - 1))
self._write_cluster_metrices(metrices_per_step_list, is_pipeline_parallel)
self._parse_cluster_metrices(step_info, receive_op_not_overlapped_timeline, comm_not_overlapped_timeline
, collective_comm_not_overlapped_timeline, is_pipeline_parallel)
res_timeline = []
res_timeline.extend(comm_not_overlapped_timeline)
res_timeline.extend(aicore_display_timeline)
res_timeline.extend(comm_display_timeline)
res_timeline.extend(free_timeline)
return res_timeline
def _write_cluster_metrices(self, metrices, is_pipeline_parallel):
"""Write cluster metric."""
# Note that the feature of cluster bottleneck analyse is not supported in offline parse mode,
# due to that parallel context is not set.
try:
parallel_mode = get_auto_parallel_context("parallel_mode")
stage_num = get_auto_parallel_context("pipeline_stages")
except RuntimeError:
logger.warning("[profiler] the feature of cluster bottleneck analyse "
"is not supported in offline parse mode.")
parallel_mode = "data_parallel"
stage_num = 1
if stage_num > 1:
parallel_mode = "pipeline-parallel"
elif parallel_mode != "data_parallel":
parallel_mode = "model-parallel"
else:
parallel_mode = "data-parallel"
cluster_analyse_file_path = os.path.join(
self._profiling_dir,
self._cluster_analyse_filename.format(parallel_mode, stage_num, self._rank_size, self._rank_id))
cluster_analyse_file_path = validate_and_normalize_path(cluster_analyse_file_path)
try:
with open(cluster_analyse_file_path, 'w') as file_handle:
csv_writer = csv.writer(file_handle)
def _parse_cluster_metrices(self, step_info, receive_op_not_overlapped_timeline, comm_not_overlapped_timeline
, collective_comm_not_overlapped_timeline, is_pipeline_parallel):
"""Write the cluster metrices"""
step_num = len(step_info)
# Compute these five metrics mentioned above per step.
recieve_alone_time = self._compute_time_inside_step(receive_op_not_overlapped_timeline, step_info)
stage_time, computation_time = [], []
comm_alone_time = self._compute_time_inside_step(comm_not_overlapped_timeline, step_info)
collective_comm_alone_time = self._compute_time_inside_step(
collective_comm_not_overlapped_timeline, step_info
)
for step in range(step_num):
try:
if is_pipeline_parallel:
header = ['computation_time', 'communication_alone_time', 'stage_time',
'receive_alone_time', 'collective_communication_alone_time']
zip_metrices = zip(metrices[0], metrices[1], metrices[2], metrices[3], metrices[4])
else:
header = ['computation_time', 'communication_alone_time']
zip_metrices = zip(metrices[0], metrices[1])
csv_writer.writerow(header)
for row_data in zip_metrices:
row_data = [round(val, 4) for val in row_data]
csv_writer.writerow(row_data)
os.chmod(cluster_analyse_file_path, stat.S_IREAD | stat.S_IWRITE)
except (IOError, OSError) as err:
logger.warning(f'Failed to save {cluster_analyse_file_path}. {err}')
raise ProfilerIOException
stage_time.append(step_info[step][self._duration_idx] - recieve_alone_time[step])
computation_time.append(step_info[step][self._duration_idx] - comm_alone_time[step])
except IndexError as err:
logger.error(err)
metrices_per_step_list = [computation_time, comm_alone_time, stage_time,
recieve_alone_time, collective_comm_alone_time]
if step_num > 1:
for metric in metrices_per_step_list:
metric.append(sum(metric[1:]) / (step_num - 1))
self._write_cluster_metrices(metrices_per_step_list, is_pipeline_parallel, "Ascend", self._rank_id)
def _compute_time_inside_step(self, metric_timeline, step_time_list):
"""Compute per step time of metric_timeline."""
per_step_time_list = []
step = 0
cur_step_metric_time = 0
step_end_time = step_time_list[step][self._start_time_idx] + step_time_list[step][self._duration_idx]
step_end_time = step_time_list[step][self._start_time_idx] + \
step_time_list[step][self._duration_idx]
for time_item in metric_timeline:
start_time = time_item[self._start_time_idx]
if start_time > step_end_time:
@ -1756,7 +1729,8 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
"find the data length is more than step count, "
"maybe current graph has multi sub graph, skip the last data.")
break
step_end_time = step_time_list[step][self._start_time_idx] + step_time_list[step][self._duration_idx]
step_end_time = step_time_list[step][self._start_time_idx] + \
step_time_list[step][self._duration_idx]
cur_step_metric_time = 0
cur_step_metric_time += time_item[self._duration_idx]
per_step_time_list.append(cur_step_metric_time)
@ -1770,7 +1744,9 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
first_list_len = len(first_time_list)
second_list_len = len(second_time_list)
intersection_segment_display_list = []
while first_list_idx < first_list_len and second_list_idx < second_list_len:
while first_list_idx < first_list_len and \
second_list_idx < second_list_len:
intersection_start = max(
first_time_list[first_list_idx][self._start_time_idx],
second_time_list[second_list_idx][self._start_time_idx]
@ -1780,9 +1756,10 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
second_time_list[second_list_idx][self._duration_idx]
)
if intersection_start < intersection_end:
tid = self._tid_dict.get(display_name, [0, 0])
intersection_segment_display_list.append(
[display_name, self._tid_dict[display_name][0],
intersection_start, intersection_end - intersection_start, self._tid_dict[display_name][1]]
[display_name, tid[0],
intersection_start, intersection_end - intersection_start, tid[1]]
)
if first_time_list[first_list_idx][self._duration_idx] >= \
second_time_list[second_list_idx][self._duration_idx]:
@ -1792,6 +1769,80 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
return intersection_segment_display_list
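_get_intersection_time is a two-pointer sweep over two sorted interval lists, and the tid fallback above just guards against an unknown display_name. A simplified, self-contained sketch of the core sweep, advancing whichever interval finishes first (the production code decides by duration; intervals here are (start, end) tuples):

    def intersect(first, second):
        """Intersect two sorted, non-overlapping interval lists."""
        i = j = 0
        out = []
        while i < len(first) and j < len(second):
            start = max(first[i][0], second[j][0])
            end = min(first[i][1], second[j][1])
            if start < end:
                out.append((start, end))
            if first[i][1] >= second[j][1]:
                j += 1          # the second interval is exhausted first
            else:
                i += 1
        return out

    # intersect([(0, 5), (8, 12)], [(3, 9)]) -> [(3, 5), (8, 9)]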
def init_pynative_timeline(self):
"""Init timeline for pynative model."""
timeline_list = OPIntermediateParser(self._profiling_dir, self._rank_id).get_timeline_data()
cpu_timeline_generator = CpuTimelineGenerator(self._profiling_dir, self._rank_id, self._rank_size)
cpu_timeline_list = cpu_timeline_generator.load_cpu_op_data()
if cpu_timeline_list:
self._pynative_clock_synchronize(cpu_timeline_list)
timeline_list.extend(cpu_timeline_list)
self._timeline_summary['op_exe_times'] = len(timeline_list)
self._max_scope_name_num = self._get_max_scope_name_num(timeline_list)
self._timeline_summary['max_scope_name_num'] = self._max_scope_name_num
timeline_list.sort(key=lambda x: float(x[self._start_time_idx]))
min_cycle_counter = float(timeline_list[0][self._start_time_idx])
step_timeline = self._pynative_get_step_timeline_list(timeline_list)
timeline_list.extend(step_timeline)
stream_count_dict = {}
max_scope_name_num = 0
for timeline in timeline_list:
self._parse_timeline_data(timeline, min_cycle_counter)
self._update_num_of_streams(timeline, stream_count_dict)
cur_scope_name_num = len(timeline[self._op_name_idx].split('/')) - 1
max_scope_name_num = max(cur_scope_name_num, max_scope_name_num)
self._timeline_summary['max_scope_name_num'] = max_scope_name_num
self._timeline_summary['num_of_streams'] = len(stream_count_dict)
def _pynative_get_step_timeline_list(self, timeline_list):
"""Get step timeline list for pynative model."""
step_list = []
if 'GetNext' not in timeline_list[0][self._op_name_idx]:
return step_list
step = [-1, -1]
step_num = 0
tid = "Steps"
for timeline in timeline_list:
if 'GetNext' not in timeline[self._op_name_idx]:
continue
start_time = float(timeline[self._start_time_idx])
if step[0] == -1:
step[0] = start_time
else:
step[1] = start_time - step[0]
step_num = step_num + 1
step_list.append([str(step_num), tid, float(step[0]), step[1]])
step = [start_time, -1]
if step[0] != -1 and step[1] == -1:
step_num = step_num + 1
step_list.append([str(step_num), tid, float(step[0]),
float(timeline_list[-1][self._start_time_idx]) - step[0]])
return step_list
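Step boundaries are inferred from the GetNext ops: each GetNext start time closes the previous step and opens the next one, and a final step is synthesized up to the last op's start time. A worked example under that reading (times in ms, illustrative):

    get_next_starts = [0.0, 10.0, 20.0]   # GetNext start times
    last_op_start = 28.0                  # start of the last timeline entry
    steps = [(prev, cur - prev) for prev, cur in zip(get_next_starts, get_next_starts[1:])]
    steps.append((get_next_starts[-1], last_op_start - get_next_starts[-1]))
    # steps == [(0.0, 10.0), (10.0, 10.0), (20.0, 8.0)]  as (start, duration) pairs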
def _pynative_clock_synchronize(self, timeline_list):
"""Synchronize the timestamp from device to host."""
start_time_file_path = os.path.join(self._profiling_dir, f"start_time_{self._rank_id}.txt")
try:
with open(start_time_file_path) as f_obj:
lines = f_obj.readlines()
# lines[0] stores the host monotonic time at the start of training.
host_monotonic_start_time = int(lines[0].strip().split(':')[-1])
# lines[1] stores the gpu time at the start of training.
gpu_start_time = int(lines[1].strip().split(':')[-1])
except (IOError, OSError) as err:
logger.critical(f'Error occurred when read {start_time_file_path}: {err}')
raise ProfilerIOException()
time_diff = gpu_start_time * 1000 - host_monotonic_start_time
for idx, time_item in enumerate(timeline_list):
timeline_list[idx][self._start_time_idx] = int(time_item[self._start_time_idx]) + time_diff
timeline_list[idx][self._start_time_idx] = timeline_list[idx][self._start_time_idx] / 1000000
timeline_list[idx][self._duration_idx] = timeline_list[idx][self._duration_idx] / 1000
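The synchronization shifts device timestamps into the host clock domain and normalizes units. A hedged worked example, assuming the start-time file stores the host monotonic time in ns and the device time in us (which is what the *1000 alignment suggests):

    host_monotonic_start_time = 2_000_000_000        # ns, from start_time_{rank_id}.txt
    gpu_start_time = 3_000_000                       # us, same file
    time_diff = gpu_start_time * 1000 - host_monotonic_start_time   # 1_000_000_000

    raw_start = 2_000_000_000                        # raw op start timestamp
    start_ms = (raw_start + time_diff) / 1_000_000   # -> 3000.0, host-aligned ms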
class CpuTimelineGenerator(GpuTimelineGenerator):
"""Generate cpu Timeline data from file."""

View File

@ -0,0 +1,144 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Op intermediate files parser."""
import csv
import os
from mindspore.profiler.common.exceptions.exceptions import ProfilerFileNotFoundException, \
ProfilerIOException
from mindspore import log as logger
from mindspore.profiler.common.validator.validate_path import validate_and_normalize_path
class OPIntermediateParser:
"""
Op intermediate files parser.
Args:
profiling_dir (str): The directory where the parsed profiling files are
located.
rank_id (str): The rank ID.
"""
_output_timeline_data_file_path = 'output_timeline_data_{}.txt'
_file_name_op_intermediate_type = 'pynative_op_intermediate_{}_type.csv'
_file_name_op_intermediate_detail = 'pynative_op_intermediate_{}_detail.csv'
_op_intermediate_type_header = ['op_type', 'execution_time', 'execution_frequency', 'percent']
_op_intermediate_op_header = ['full_op_name', 'execution_time']
_ms_decimal_digits = 6
_percent_decimal_digits = 2
def __init__(self, profiling_dir, rank_id):
self._profiling_dir = profiling_dir
self._rank_id = rank_id
def get_timeline_data(self, all_reduce_names=None):
"""
Load timeline data from file.
Args:
all_reduce_names (list): The communication operator list.
"""
all_reduce_names = all_reduce_names or []
file_path = os.path.join(
self._profiling_dir,
self._output_timeline_data_file_path.format(self._rank_id)
)
file_path = validate_and_normalize_path(file_path)
if not os.path.exists(file_path):
logger.critical("Failed to find parsed timeline file.")
raise ProfilerFileNotFoundException('parsed timeline file')
timeline_list = []
try:
with open(file_path, 'r') as f_obj:
for line in f_obj:
# line: op_name, stream_id, start_time(ms), duration(ms)
line_list = line.strip('\n').split(',')
# filter out communication operators
if line_list[0] == 'op_name' or line_list[0] in all_reduce_names:
continue
timeline_list.append(line_list)
except (IOError, OSError) as err:
logger.critical('Error occurred when read timeline intermediate file: %s', err)
raise ProfilerIOException()
return timeline_list
def parser_pynative_op_intermediate_detail(self):
"""Parse pynative op intermediate detail."""
timeline_list = self.get_timeline_data(None)
# key:op name, value:[op count, total op execution time]
op_intermediate_detail = {}
for timeline in timeline_list:
op_name = timeline[0].split('/')[-1]
detail = op_intermediate_detail.get(op_name)
if not detail:
detail = [0, 0]
op_intermediate_detail[op_name] = detail
detail[0] = detail[0] + 1
detail[1] = detail[1] + float(timeline[3])
op_op_file_path = os.path.join(self._profiling_dir,
self._file_name_op_intermediate_detail.format(self._rank_id))
with os.fdopen(os.open(op_op_file_path, os.O_WRONLY | os.O_CREAT, 0o660), 'w') as op_file:
csv_writer = csv.writer(op_file)
csv_writer.writerow(self._op_intermediate_op_header)
for op_name, op_name_time_info in op_intermediate_detail.items():
op_info = [
op_name, round(op_name_time_info[1] / op_name_time_info[0], self._ms_decimal_digits)
]
csv_writer.writerow(op_info)
def parser_pynative_op_type(self):
"""Parse pynative op intermediate type."""
timeline_list = self.get_timeline_data(None)
# key:op type, value:[op count, total op execution time, op execution time percent]
op_type_list = {}
for timeline in timeline_list:
type_name = timeline[0].split('/')[-1].split('-')[0]
op_type = op_type_list.get(type_name)
if not op_type:
op_type = [0, 0, 0]
op_type_list[type_name] = op_type
op_type[0] = op_type[0] + 1
op_type[1] = op_type[1] + float(timeline[3])
sum_avg_time = 0
for _, op_type in op_type_list.items():
op_type[1] = op_type[1] / op_type[0]
sum_avg_time = sum_avg_time + op_type[1]
if sum_avg_time <= 0:
logger.error("Operator time must be greater than 0.")
return
for _, op_type in op_type_list.items():
op_type[2] = op_type[1] / sum_avg_time
op_type_file_path = os.path.join(self._profiling_dir,
self._file_name_op_intermediate_type.format(self._rank_id))
with os.fdopen(os.open(op_type_file_path, os.O_WRONLY | os.O_CREAT, 0o660), 'w') as type_file:
csv_writer = csv.writer(type_file)
csv_writer.writerow(self._op_intermediate_type_header)
for op_type, op_type_time_info in op_type_list.items():
type_info = [
op_type, op_type_time_info[1], op_type_time_info[0],
round((op_type_time_info[1] / sum_avg_time) * 100, self._percent_decimal_digits)
]
csv_writer.writerow(type_info)
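This parser centralizes what _load_timeline_data used to do inside AscendTimelineGenerator, so the timeline and op-summary paths share one loader. A typical call sequence (output directory hypothetical):

    parser = OPIntermediateParser("/tmp/profiler_output", "0")
    rows = parser.get_timeline_data()   # rows of [op_name, stream_id, start_time(ms), duration(ms)]
    parser.parser_pynative_op_type()                  # writes pynative_op_intermediate_0_type.csv
    parser.parser_pynative_op_intermediate_detail()   # writes pynative_op_intermediate_0_detail.csv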

View File

@ -19,6 +19,7 @@ import time
import json
from enum import Enum
from mindspore.nn.cell import Cell
from mindspore import log as logger, context
from mindspore.communication.management import GlobalComm, get_rank, get_group_size
import mindspore._c_expression as c_expression
@ -34,7 +35,7 @@ from mindspore.profiler.parser.aicpu_data_parser import DataPreProcessParser
from mindspore.profiler.parser.framework_parser import FrameworkParser
from mindspore.profiler.parser.hwts_log_parser import HWTSLogParser
from mindspore.profiler.parser.integrator import Integrator
from mindspore.profiler.parser.integrator import GpuTimelineGenerator, AscendTimelineGenerator, CpuTimelineGenerator
from mindspore.profiler.parser.integrator import GpuTimelineGenerator, CpuTimelineGenerator, AscendTimelineGenerator
from mindspore.profiler.parser.memory_usage_parser import MemoryUsageParser
from mindspore.profiler.parser.minddata_parser import MinddataParser
from mindspore.profiler.parser.minddata_analyzer import MinddataProfilingAnalyzer
@ -44,7 +45,7 @@ from mindspore.profiler.parser.minddata_pipeline_parser import \
from mindspore.profiler.parser.optime_parser import OPComputeTimeParser
from mindspore.profiler.parser.step_trace_parser import GpuStepTraceParser, AscendStepTraceParser
from mindspore.profiler.parser.hccl_parser import HcclParser
from mindspore.nn.cell import Cell
from mindspore.profiler.parser.op_intermediate_parser import OPIntermediateParser
INIT_OP_NAME = 'Default/InitDataSetQueue'
@ -57,9 +58,6 @@ def deprecated(name, version):
def _environment_check():
if c_expression.security.enable_security():
raise RuntimeError("Profiler is not supported if compiled with \'-s on\'")
if context.get_context("mode") == context.PYNATIVE_MODE:
raise RuntimeError("Profiler is not supported in pynative mode currently, "
"and it is only supported in graph mode.")
class ProfileOption(Enum):
@ -140,6 +138,7 @@ class Profiler:
_aicpu_op_output_filename_target = "output_data_preprocess_aicpu_"
_has_analysed = False
_has_initialized = False
_ascend_profiling_options = {}
def __init__(self, **kwargs):
if Profiler._has_initialized:
@ -189,25 +188,13 @@ class Profiler:
self._parse_parameter_for_ascend(**kwargs)
os.environ['DEVICE_ID'] = self._dev_id
profiling_options = json.dumps(self._construct_profiling_options())
self._ascend_profiling_options = json.dumps(self._construct_profiling_options())
# Option strings longer than 2048 characters are ignored, resulting in profiling option resolution errors
if len(profiling_options) > 2048:
if len(self._ascend_profiling_options) > 2048:
msg = f"For '{self.__class__.__name__}', the environment parameter length exceeds " \
f"the limit (2048), please input valid parameters."
logger.critical(msg)
raise ValueError(msg)
# use context interface to open profiling, for the new mindspore version(after 2020.5.21)
self._ascend_profiler = c_expression.AscendProfiler.get_instance()
self._ascend_profiler.init(self._output_path, int(self._dev_id), profiling_options)
base_profiling_container_path = os.path.join(self._output_path, "container")
container_path = os.path.join(base_profiling_container_path, self._dev_id)
data_path = os.path.join(container_path, "data")
data_path = validate_and_normalize_path(data_path)
if not os.path.exists(data_path):
os.makedirs(data_path, exist_ok=True)
# add job id env through user input later
self._job_id_env = 0
if self.start_profile:
self.start()
@ -316,6 +303,19 @@ class Profiler:
self._ascend_analyse()
logger.info("Profiling: all the data have been analyzed.")
def _ascend_pynative_analyse(self):
"""Collect and analyse ascend pynative model performance data."""
op_intermediate_parser = OPIntermediateParser(self._output_path, self._rank_id)
op_intermediate_parser.parser_pynative_op_type()
op_intermediate_parser.parser_pynative_op_intermediate_detail()
timeline_analyser = AscendTimelineGenerator(self._output_path, self._dev_id, self._rank_id,
self._rank_size)
timeline_analyser.init_pynative_timeline()
size_limit = 100 * 1024 * 1024 # 100MB
timeline_analyser.write_timeline(size_limit)
timeline_analyser.write_timeline_summary()
def _ascend_analyse(self):
"""Collect and analyse ascend performance data."""
self._rank_size = 1
@ -330,14 +330,18 @@ class Profiler:
else:
logger.info("No need to stop profiler because profiler has been stopped.")
self._ascend_profiler.finalize()
if context.get_context("mode") == context.PYNATIVE_MODE:
self._ascend_pynative_analyse()
else:
self._ascend_graph_analyse()
job_id = self._get_profiling_job_id()
logger.info("Profiling: job id is %s ", job_id)
def _ascend_graph_op_analyse(self, source_path):
"""
Ascend graph model hwts analyse.
self._check_output_path(output_path=self._output_path)
source_path = os.path.join(self._output_path, job_id)
Returns:
list[obj]: The list is: framework_parser, aicpu_data_parser, optime_parser, op_task_dict
"""
# parse hwts.log.data.45.dev file, and get task profiling data
hwts_output_filename = self._hwts_output_filename_target + self._rank_id + ".txt"
hwts_output_filename = os.path.join(self._output_path, hwts_output_filename)
@ -353,8 +357,7 @@ class Profiler:
framework_parser.parse()
op_task_dict = framework_parser.to_task_id_full_op_name_dict()
if not op_task_dict:
logger.error("Profiling: fail to parse framework files.")
return
raise RuntimeError('Profiling: fail to parse framework files.')
# get op compute time from hwts data and framework data, write output_op_compute_time.txt
opcompute_output_filename = self._opcompute_output_filename_target + self._rank_id + ".txt"
@ -375,6 +378,10 @@ class Profiler:
logger.info("Profiling: analyzing the data preprocess data.")
aicpu_data_parser.execute()
return [framework_parser, aicpu_data_parser, optime_parser, op_task_dict]
def _ascend_graph_minddata_analyse(self, source_path):
"""Analyse mindadata for ascend graph model."""
# Parsing minddata AICPU profiling
logger.info("Profiling: analyzing the minddata AICPU data.")
MinddataParser.execute(source_path, self._output_path, self._rank_id)
@ -395,6 +402,23 @@ class Profiler:
except ProfilerException as err:
logger.warning(err.message)
def _ascend_graph_analyse(self):
"""Ascend graph mode analyse."""
self._ascend_profiler.finalize()
job_id = self._get_profiling_job_id()
logger.info("Profiling: job id is %s ", job_id)
self._check_output_path(output_path=self._output_path)
source_path = os.path.join(self._output_path, job_id)
op_parser_obj = self._ascend_graph_op_analyse(source_path)
framework_parser = op_parser_obj[0]
aicpu_data_parser = op_parser_obj[1]
optime_parser = op_parser_obj[2]
op_task_dict = op_parser_obj[3]
self._ascend_graph_minddata_analyse(source_path)
# analyse op compute time info
try:
logger.info("Profiling: analyzing the operation compute time.")
@ -497,7 +521,7 @@ class Profiler:
raise RuntimeError("The profiler has already started. Use profiler.start() only when start_profile value "
"is set to False.")
#No need to start anything if parse profiling data offline
# No need to start anything if parse profiling data offline
if self._is_offline_parser():
return
@ -507,7 +531,33 @@ class Profiler:
if self._device_target and self._device_target == "GPU":
self._gpu_profiler.step_profiling_enable(True)
elif self._device_target and self._device_target == "Ascend":
self._ascend_profiler.start()
if context.get_context("mode") == context.PYNATIVE_MODE:
self._ascend_pynative_start()
else:
self._ascend_graph_start()
def _ascend_pynative_start(self):
"""Ascend pynative mode start profiling."""
pynative_profiler = c_expression.PynativeProfiler
self._pynative_profiler = pynative_profiler.get_instance()
self._pynative_profiler.init(self._output_path)
def _ascend_graph_start(self):
"""Ascend graph mode start profiling."""
# use context interface to open profiling, for the new mindspore version(after 2020.5.21)
self._ascend_profiler = c_expression.AscendProfiler.get_instance()
self._ascend_profiler.init(self._output_path, int(self._dev_id), self._ascend_profiling_options)
base_profiling_container_path = os.path.join(self._output_path, "container")
container_path = os.path.join(base_profiling_container_path, self._dev_id)
data_path = os.path.join(container_path, "data")
data_path = validate_and_normalize_path(data_path)
if not os.path.exists(data_path):
os.makedirs(data_path, exist_ok=True)
# add job id env through user input later
self._job_id_env = 0
self._ascend_profiler.start()
def stop(self):
"""
@ -544,7 +594,7 @@ class Profiler:
else:
raise RuntimeError("The profiler has not started, so can not stop.")
#No need to stop anything if parse profiling data offline
# No need to stop anything if parse profiling data offline
if self._is_offline_parser():
return
@ -554,7 +604,11 @@ class Profiler:
if self._device_target and self._device_target == "GPU":
self._gpu_profiler.stop()
elif self._device_target and self._device_target == "Ascend":
self._ascend_profiler.stop()
if context.get_context("mode") == context.PYNATIVE_MODE:
self._pynative_profiler.stop()
else:
self._ascend_profiler.stop()
self._stop_time = int(time.time() * 10000000)
logger.info("Profiling: stop time: %d", self._stop_time)

View File

@ -52,3 +52,23 @@ def test_ascend_profiling():
add(Tensor(x), Tensor(y))
profiler.analyse()
assert len(glob.glob(f"{tmpdir}/profiler*/*PROF*/device_*/data/Framework*")) == 4
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_ascend_pynative_profiling():
"""
Feature: Test the ascend pynative model profiling
Description: Generate the Net op timeline
Expectation: Timeline generated successfully
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
with tempfile.TemporaryDirectory() as tmpdir:
profiler = Profiler(output_path=tmpdir)
add = Net()
add(Tensor(x), Tensor(y))
profiler.analyse()
assert len(glob.glob(f"{tmpdir}/profiler*/output_timeline_data_*.txt")) == 1

View File

@ -89,7 +89,7 @@ class LeNet5(nn.Cell):
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
"""create dataset for train"""
# define dataset
mnist_ds = ds.MnistDataset(data_path, num_samples=batch_size * 100)
mnist_ds = ds.MnistDataset(data_path, num_samples=batch_size * 10)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0