diff --git a/mindspore/profiler/parser/flops_parser.py b/mindspore/profiler/parser/flops_parser.py
index e0caf9b9e59..a44525b82aa 100644
--- a/mindspore/profiler/parser/flops_parser.py
+++ b/mindspore/profiler/parser/flops_parser.py
@@ -101,6 +101,8 @@ class FlopsParser:
             op_name_set.add(op_name)
             self._add_flops_to_each_scope(op_name, task_fops)
 
+        if not op_name_set:
+            raise ProfilerRawFileException("No aicore operator found.")
         self._flops_summary['FLOPS'] /= len(op_name_set)
         self._flops_summary['FLOPS_Utilization'] /= len(op_name_set)
         self._format_scope_flops()
diff --git a/mindspore/profiler/parser/integrator.py b/mindspore/profiler/parser/integrator.py
index 7f57196a9d2..1fc64906bc2 100644
--- a/mindspore/profiler/parser/integrator.py
+++ b/mindspore/profiler/parser/integrator.py
@@ -20,6 +20,8 @@ import stat
 from decimal import Decimal
 
 from mindspore import log as logger
+from mindspore.context import get_auto_parallel_context
+from mindspore.communication.management import get_group_size
 from mindspore.profiler.common.exceptions.exceptions import ProfilerIOException, \
     ProfilerFileNotFoundException, ProfilerRawFileException, ProfilerParamValueErrorException
 from mindspore.profiler.common.util import query_latest_trace_time_file, to_int, to_millisecond
@@ -898,7 +900,6 @@ class GpuTimelineGenerator(BaseTimelineGenerator):
 
         return activity_timeline_list
 
-
     def init_timeline(self):
         """Init timeline metadata, adding all collected info."""
         timeline_list = self._load_timeline_data()
@@ -988,10 +989,17 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
     """Generate ascend Timeline data from file."""
     _display_filename = 'ascend_timeline_display_{}.json'
     _timeline_summary_filename = 'ascend_timeline_summary_{}.json'
+    _cluster_analyse_filename = 'ascend_cluster_analyse_{}_{}_{}_{}.csv'
 
     def __init__(self, profiling_dir, device_id):
         self._profiling_dir = profiling_dir
         self._device_id = device_id
+        self._tid_dict = {
+            "aicore": 7999,
+            "communication_not_overlapped": 8000,
+            "communication": 8001,
+            "free_time": 8002
+        }
 
     def _load_timeline_data(self):
         """Load timeline data from file."""
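The fixed tids above pin each synthetic timeline row ("aicore", "communication",
"communication_not_overlapped", "free_time") to its own lane in the trace viewer. As a
minimal sketch (not part of the patch), assuming the display file follows the Chrome trace
event format and using an illustrative helper name, one merged display item of the form
[display_name, tid, start_time, duration] maps onto a trace event like this:

def display_item_to_trace_event(item, pid=0):
    """Illustrative only: convert one display row into a Chrome trace 'complete' event."""
    display_name, tid, start_time, duration = item
    return {
        "name": display_name,  # e.g. "communication_not_overlapped"
        "pid": pid,            # process lane; the real generator chooses its own value
        "tid": tid,            # fixed thread lane from _tid_dict, e.g. 8000
        "ph": "X",             # complete event: start timestamp plus duration
        "ts": start_time,
        "dur": duration,
    }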
@@ -1046,12 +1054,12 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
         self._update_format_meta_data(timeline_dict)
         self._timeline_meta.append(timeline_dict)
 
-    def init_timeline(self, all_reduce_info, framework_info, aicpu_info, min_cycle_counter, source_path):
+    def init_timeline(self, communication_info, framework_info, aicpu_info, min_cycle_counter, source_path):
         """
         Init timeline metadata, adding all collected info.
 
         Args:
-            all_reduce_info (list[list]): The metadata of AllReduce operator.
+            communication_info (list[list]): The metadata of communication operators.
             framework_info (dict): The framework metadata.
             aicpu_info (dict): The metadata of AI CPU operator.
             min_cycle_counter (float): The minimum cycle counter of the timeline.
@@ -1080,29 +1088,34 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
         default_scope_name_time_list = self._get_scope_name_time_list(timeline_list, "Default")
         gradient_scope_name_time_list = self._get_scope_name_time_list(timeline_list, "Gradients")
         recompute_scope_name_time_list = self._get_scope_name_time_list(timeline_list, "recompute_Default")
-        timeline_list.extend(step_time_list)
-        timeline_list.extend(default_scope_name_time_list)
-        timeline_list.extend(recompute_scope_name_time_list)
-        timeline_list.extend(gradient_scope_name_time_list)
         timeline_list.sort(key=lambda x: (float(x[self._start_time_idx]), x[self._tid_idx]))
 
         # Add AllReduce info to timeline temp list and sort by start time.
-        if all_reduce_info:
+        if communication_info:
             logger.debug('AllReduce info found. Start adding info into timeline...')
-            timeline_list.extend(all_reduce_info)
+            cluster_related_timeline = self._analyse_and_write_cluster_profiling_data(
+                timeline_list, communication_info, step_time_list)
+            timeline_list.extend(cluster_related_timeline)
+            timeline_list.extend(communication_info)
             timeline_list.sort(key=lambda x: float(x[self._start_time_idx]))
 
         # Add AI CPU data into timeline temp list and sort by start time.
         aicpu_data = aicpu_info.get('info')
         if aicpu_data:
             timeline_list.extend(aicpu_data)
-            timeline_list.sort(key=lambda x: float(x[2]))
             self._timeline_summary['op_exe_times'] += aicpu_info.get('op_exe_times', 0)
             self._timeline_summary['num_of_streams'] += aicpu_info.get('num_of_streams', 0)
             self._timeline_summary['num_of_ops'] += aicpu_info.get('num_of_ops', 0)
             self._timeline_summary['total_time'] += aicpu_info.get('total_time', 0)
 
+        # Add the step time and scope name info.
+        timeline_list.extend(step_time_list)
+        timeline_list.extend(default_scope_name_time_list)
+        timeline_list.extend(recompute_scope_name_time_list)
+        timeline_list.extend(gradient_scope_name_time_list)
+        timeline_list.sort(key=lambda x: float(x[self._start_time_idx]))
+
         # Init a dict for counting the num of streams.
         stream_count_dict = {}
         for timeline in timeline_list:
@@ -1187,6 +1200,246 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
                     timeline_item['args'] = framework_item.get('args')
         logger.debug('Finished adding framework info into timeline...')
 
+    def _produce_two_separated_timeline(self, timeline, op_name):
+        """Produce two separated timelines based on op_name."""
+        timeline_include_op_name = []
+        timeline_exclude_op_name = []
+        for time_item in timeline:
+            if op_name in time_item[self._op_name_idx]:
+                timeline_include_op_name.append(time_item)
+            else:
+                timeline_exclude_op_name.append(time_item)
+        return timeline_include_op_name, timeline_exclude_op_name
+
+    def _analyse_and_write_cluster_profiling_data(self, aicore_timeline, communication_timeline, step_time_list):
+        """
+        Analyse the cluster communication and computation data, and write the result to a file.
+
+        To analyse cluster performance bottlenecks based on the timeline, define the time of a
+        training step as "t_total" and propose five metrics as follows:
+        1) The time during which "receive" operators are not overlapped by other operators (t1).
+        2) The time consumed inside the stage (t_total - t1).
+        3) The time during which "communication" operators are not overlapped by other operators (t2).
+        4) The time consumed by computation (t_total - t2).
+        5) The time during which "collective communication" operators are not overlapped by
+           other operators (t3).
+        In pipeline parallel mode, the slow stage can be located based on "t_total - t1", and the
+        slow card inside each stage can be located based on "t_total - t2". The value of t1
+        indicates how much the communication between stages slows down the training, and the
+        value of t3 indicates how much the communication inside each stage slows down the training.
+        """
+        is_pipeline_parallel = False
+        comm_merged_timeline, _, comm_display_timeline = self._get_merged_time_list(
+            communication_timeline,
+            display_name="communication"
+        )
+        aicore_timeline_interval, _, aicore_display_timeline = self._get_merged_time_list(
+            aicore_timeline,
+            get_interval_time=True
+        )
+        # Intersecting the aicore gaps with the merged communication timeline yields the
+        # communication time that is not overlapped by computation (it may be empty).
+        comm_not_overlapped_timeline = self._get_intersection_time(
+            aicore_timeline_interval,
+            comm_merged_timeline
+        )
+
+        # Process the receive part.
+        all_timeline = aicore_timeline + communication_timeline
+        all_timeline.sort(key=lambda x: float(x[self._start_time_idx]))
+        receive_op_timeline, timeline_exclude_receive_op = self._produce_two_separated_timeline(
+            all_timeline,
+            "Receive-op"
+        )
+        if receive_op_timeline:
+            is_pipeline_parallel = True
+        receive_op_merged_timeline = self._get_merged_time_list(receive_op_timeline)[0]
+        timeline_exclude_receive_op_interval = self._get_merged_time_list(
+            timeline_exclude_receive_op,
+            get_interval_time=True
+        )[0]
+        receive_op_not_overlapped_timeline = self._get_intersection_time(
+            timeline_exclude_receive_op_interval,
+            receive_op_merged_timeline
+        )
+
+        # Process the collective communication part.
+        collective_comm_timeline = self._produce_two_separated_timeline(
+            communication_timeline,
+            "Receive-op"
+        )[-1]
+        collective_comm_merged_timeline = self._get_merged_time_list(collective_comm_timeline)[0]
+        collective_comm_not_overlapped_timeline = self._get_intersection_time(
+            aicore_timeline_interval,
+            collective_comm_merged_timeline
+        )
+
+        # Generate the free time that excludes both computation and communication time.
+        free_timeline = self._get_merged_time_list(
+            all_timeline,
+            get_interval_time=True,
+            display_name="free_time"
+        )[1]
+
+        # Compute the five metrics mentioned above per step.
+        receive_alone_time = self._compute_time_inside_step(receive_op_not_overlapped_timeline, step_time_list)
+        stage_time, computation_time = [], []
+        comm_alone_time = self._compute_time_inside_step(comm_not_overlapped_timeline, step_time_list)
+        collective_comm_alone_time = self._compute_time_inside_step(
+            collective_comm_not_overlapped_timeline,
+            step_time_list
+        )
+        for step in range(len(step_time_list)):
+            if is_pipeline_parallel:
+                stage_time.append(step_time_list[step][self._duration_idx] - receive_alone_time[step])
+            computation_time.append(step_time_list[step][self._duration_idx] - comm_alone_time[step])
+
+        metrics_per_step_list = [computation_time, comm_alone_time, stage_time,
+                                 receive_alone_time, collective_comm_alone_time]
+        self._write_cluster_metrics(metrics_per_step_list, is_pipeline_parallel)
+        res_timeline = []
+        res_timeline.extend(comm_not_overlapped_timeline)
+        res_timeline.extend(aicore_display_timeline)
+        res_timeline.extend(comm_display_timeline)
+        res_timeline.extend(free_timeline)
+
+        return res_timeline
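To make the five per-step metrics concrete, here is a small worked example with illustrative
numbers only (not part of the patch):

# Suppose one training step takes 100 ms in total (t_total) and the profiler measured:
t_total = 100.0  # total step time
t1 = 15.0        # "receive" time not overlapped by other operators
t2 = 30.0        # communication time not overlapped by computation
t3 = 10.0        # collective communication time not overlapped by computation

stage_time = t_total - t1        # 85.0 ms consumed inside the stage
computation_time = t_total - t2  # 70.0 ms consumed by computation

# A stage whose stage_time is clearly larger than its peers' is the slow stage;
# within a stage, the card with the largest computation_time is the slow card.
assert (stage_time, computation_time) == (85.0, 70.0)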
+    def _write_cluster_metrics(self, metrics, is_pipeline_parallel):
+        """Write the cluster metrics to a csv file."""
+        # Note that the cluster bottleneck analysis does not support offline parsing,
+        # because the parallel context is not set in that case.
+        try:
+            parallel_mode = get_auto_parallel_context("parallel_mode")
+            stage_num = get_auto_parallel_context("pipeline_stages")
+            rank_size = get_group_size()
+        except RuntimeError:
+            logger.error("[profiler] the cluster bottleneck analysis does not support offline parsing.")
+            parallel_mode = "data-parallel"
+            stage_num, rank_size = 1, 1
+
+        cluster_analyse_file_path = os.path.join(
+            self._profiling_dir,
+            self._cluster_analyse_filename.format(parallel_mode, stage_num, rank_size, self._device_id)
+        )
+        cluster_analyse_file_path = validate_and_normalize_path(cluster_analyse_file_path)
+
+        try:
+            with open(cluster_analyse_file_path, 'w') as file_handle:
+                csv_writer = csv.writer(file_handle)
+                if is_pipeline_parallel:
+                    header = ['computation_time', 'communication_alone_time', 'stage_time',
+                              'receive_alone_time', 'collective_communication_alone_time']
+                    zip_metrics = zip(metrics[0], metrics[1], metrics[2], metrics[3], metrics[4])
+                else:
+                    header = ['computation_time', 'communication_alone_time']
+                    zip_metrics = zip(metrics[0], metrics[1])
+                csv_writer.writerow(header)
+                for row_data in zip_metrics:
+                    row_data = [round(val, 4) for val in row_data]
+                    csv_writer.writerow(row_data)
+            os.chmod(cluster_analyse_file_path, stat.S_IREAD | stat.S_IWRITE)
+        except (IOError, OSError) as err:
+            logger.warning(f'Failed to save {cluster_analyse_file_path}. {err}')
+            raise ProfilerIOException
+
+    def _compute_time_inside_step(self, metric_timeline, step_time_list):
+        """Compute the per-step time of metric_timeline."""
+        per_step_time_list = []
+        step = 0
+        cur_step_metric_time = 0
+        step_end_time = step_time_list[step][self._start_time_idx] + \
+            step_time_list[step][self._duration_idx]
+        for time_item in metric_timeline:
+            start_time = time_item[self._start_time_idx]
+            if start_time > step_end_time:
+                per_step_time_list.append(cur_step_metric_time)
+                step += 1
+                if step >= len(step_time_list):
+                    # Guard against items that fall outside the collected steps.
+                    break
+                step_end_time = step_time_list[step][self._start_time_idx] + \
+                    step_time_list[step][self._duration_idx]
+                cur_step_metric_time = 0
+            cur_step_metric_time += time_item[self._duration_idx]
+        per_step_time_list.append(cur_step_metric_time)
+
+        return per_step_time_list
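The per-step accumulation above walks the metric timeline and the step list in a single pass.
A standalone sketch (not part of the patch), assuming plain (start, duration) tuples sorted by
start time instead of the internal timeline row format:

def compute_time_inside_step(segments, steps):
    """segments and steps are lists of (start, duration), both sorted by start time."""
    per_step, step, acc = [], 0, 0.0
    step_end = steps[step][0] + steps[step][1]
    for start, duration in segments:
        if start > step_end:
            per_step.append(acc)
            step += 1
            if step >= len(steps):  # same guard as above
                break
            step_end = steps[step][0] + steps[step][1]
            acc = 0.0
        acc += duration
    per_step.append(acc)
    return per_step

# Two steps of 10 time units each; 3 units of the metric fall into the first step, 5 into the second.
assert compute_time_inside_step([(1, 3), (12, 5)], [(0, 10), (11, 10)]) == [3, 5]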
+    def _get_merged_time_list(self, time_list, get_interval_time=False, display_name="aicore"):
+        """
+        Get the merged time segment list.
+
+        The merge works as follows: given, for example, the list [[1, 5], [2, 6], [7, 8]],
+        in which each item contains a start time and an end time, the merged result
+        is [[1, 6], [7, 8]].
+        """
+        time_merged_segment_list = []
+        tid = self._tid_dict[display_name]
+        for time_item in time_list:
+            time_segment = list(map(float, time_item[self._start_time_idx:self._duration_idx + 1]))
+            time_segment[1] += time_segment[0]
+            if not time_merged_segment_list or \
+                    time_segment[0] > time_merged_segment_list[-1]:
+                time_merged_segment_list.extend(time_segment)
+            else:
+                time_merged_segment_list[-1] = max(
+                    time_merged_segment_list[-1],
+                    time_segment[1]
+                )
+
+        # merged_display_list is used for the ui page.
+        merged_display_list = [
+            [display_name, tid, time_merged_segment_list[i * 2],
+             time_merged_segment_list[i * 2 + 1] - time_merged_segment_list[i * 2]]
+            for i in range(len(time_merged_segment_list) // 2)
+        ]
+
+        if get_interval_time:
+            time_merged_segment_list = time_merged_segment_list[1:-1]
+
+        # merged_res_list is used to compute the overlap with another time_list.
+        merged_res_list = [
+            [display_name, tid, time_merged_segment_list[i * 2], time_merged_segment_list[i * 2 + 1]]
+            for i in range(len(time_merged_segment_list) // 2)
+        ]
+
+        # interval_display_list is the interval time used for the ui page.
+        interval_display_list = [
+            [display_name, tid, time_merged_segment_list[i * 2],
+             time_merged_segment_list[i * 2 + 1] - time_merged_segment_list[i * 2]]
+            for i in range(len(time_merged_segment_list) // 2)
+        ]
+
+        return merged_res_list, interval_display_list, merged_display_list
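The loop above implements a standard union of sorted intervals, kept in a flat
[start1, end1, start2, end2, ...] list. A minimal sketch (not part of the patch), assuming
plain [start, end] pairs sorted by start time, reproduces the docstring example:

def merge_segments(segments):
    """Merge overlapping [start, end] pairs; segments must be sorted by start time."""
    merged = []
    for start, end in segments:
        if not merged or start > merged[-1][1]:
            merged.append([start, end])              # disjoint: open a new segment
        else:
            merged[-1][1] = max(merged[-1][1], end)  # overlapping: extend the last one
    return merged

assert merge_segments([[1, 5], [2, 6], [7, 8]]) == [[1, 6], [7, 8]]

Dropping the first and last values of the flat list, as the get_interval_time branch does,
turns the merged segments into the gaps between them.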
+    def _get_intersection_time(self, first_time_list, second_time_list,
+                               display_name="communication_not_overlapped"):
+        """Get the intersection time of two time lists."""
+        first_list_idx, second_list_idx = 0, 0
+        first_list_len = len(first_time_list)
+        second_list_len = len(second_time_list)
+        intersection_segment_display_list = []
+
+        while first_list_idx < first_list_len and \
+                second_list_idx < second_list_len:
+            intersection_start = max(
+                first_time_list[first_list_idx][self._start_time_idx],
+                second_time_list[second_list_idx][self._start_time_idx]
+            )
+            intersection_end = min(
+                first_time_list[first_list_idx][self._duration_idx],
+                second_time_list[second_list_idx][self._duration_idx]
+            )
+            if intersection_start < intersection_end:
+                intersection_segment_display_list.append(
+                    [display_name, self._tid_dict[display_name],
+                     intersection_start, intersection_end - intersection_start]
+                )
+            if first_time_list[first_list_idx][self._duration_idx] >= \
+                    second_time_list[second_list_idx][self._duration_idx]:
+                second_list_idx += 1
+            else:
+                first_list_idx += 1
+
+        return intersection_segment_display_list
+
+
 class CpuTimelineGenerator(GpuTimelineGenerator):
     """Generate cpu Timeline data from file."""
     _output_op_execute_time_file_path = "cpu_op_execute_timestamp_{}.txt"
diff --git a/mindspore/profiler/profiling.py b/mindspore/profiler/profiling.py
index ca9d3d1d6d2..61a1aa0b79a 100644
--- a/mindspore/profiler/profiling.py
+++ b/mindspore/profiler/profiling.py
@@ -186,7 +186,7 @@ class Profiler:
             "bp_point": bp_point,
             "training_trace": "on",
             "task_trace": "on",
-            "aic_metrics": "PipeUtilization",
+            "aic_metrics": "ArithmeticUtilization",
             "aicpu": "on",
             "profile_memory": profile_memory
         }
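The _get_intersection_time helper added to integrator.py above is a two-pointer sweep over two
sorted segment lists. A minimal sketch (not part of the patch), assuming plain [start, end]
pairs sorted by start time:

def intersect_segments(first, second):
    """Return the overlapping [start, end] parts of two sorted segment lists."""
    result, i, j = [], 0, 0
    while i < len(first) and j < len(second):
        start = max(first[i][0], second[j][0])
        end = min(first[i][1], second[j][1])
        if start < end:
            result.append([start, end])
        # Advance past the segment that ends first; it cannot overlap anything else.
        if first[i][1] >= second[j][1]:
            j += 1
        else:
            i += 1
    return result

# The gaps of one timeline intersected with another timeline give the "not overlapped" time.
assert intersect_segments([[0, 4], [6, 9]], [[2, 7]]) == [[2, 4], [6, 7]]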