The profiler cancels the validation of op_summary

This commit is contained in:
臧庆香 2024-05-02 22:48:37 +08:00
parent 915305f3f8
commit 7ad9d8f1e2
6 changed files with 159 additions and 8 deletions

View File

@ -251,10 +251,10 @@ class AscendMsprofExporter:
msprof_json.add(f)
if not op_summary:
raise RuntimeError("The op_summary csv file was not found, perhaps the original data was not collected.")
logger.warning("The op_summary csv file was not found, perhaps the original data was not collected.")
if not op_statistic:
raise RuntimeError("The op_statistics csv file was not found, perhaps the original data was not collected.")
if not msprof_json:
raise RuntimeError("The msprof json file was not found, perhaps the original data was not collected.")
logger.warning("The msprof json file was not found, perhaps the original data was not collected.")
logger.info("Finish checking files.")

View File

@ -88,7 +88,10 @@ class AscendMsprofDataGenerator:
"""read op summary to memory"""
op_summary = []
op_summary_name = fr'{self.mindstudio_profiler_output}/op_summary_*.csv'
op_summary_file = get_newest_file(glob.glob(op_summary_name))[0]
op_summary_files = glob.glob(op_summary_name)
if not op_summary_files:
return
op_summary_file = get_newest_file(op_summary_files)[0]
with open(op_summary_file, newline='') as csvfile:
reader = csv.DictReader(csvfile, delimiter=',', quotechar='"')
for row in reader:

View File

@ -67,7 +67,9 @@ class AscendOPGenerator:
"""
Analyse op summary op statistic generate op data.
"""
if isinstance(self.op_summary, np.ndarray) and self.op_summary.shape[0] == 0 or \
not isinstance(self.op_summary, np.ndarray) and not self.op_summary:
return
self._combine_op_and_kernel(self.op_summary, self.launch_ops)
# aicore intermediation detail
self.op_detail = self._parse_op_detail(self.op_summary)
@ -173,6 +175,9 @@ class AscendOPGenerator:
def _combine_op_and_kernel(self, op_summary, launch_ops):
"""update op name, kernel name etc."""
if isinstance(op_summary, np.ndarray) and op_summary.shape[0] == 0 or not isinstance(op_summary, np.ndarray) \
and not op_summary:
return
self._full_kernel_name = op_summary['Op Name'].copy()
self._op_name = op_summary['Op Name'].copy()
self._kernel_name = np.array(
@ -199,6 +204,9 @@ class AscendOPGenerator:
Args:
op_summary(DataFrame): op summary data.
"""
if isinstance(op_summary, np.ndarray) and op_summary.shape[0] == 0 or \
not isinstance(op_summary, np.ndarray) and not op_summary:
return None
if self.aclnn_status:
op_detail = np.empty((len(op_summary),), dtype=self.op_detail_dt)
op_detail['task_type'] = op_summary['Task Type']
@ -226,7 +234,9 @@ class AscendOPGenerator:
Args:
op_statistic(DataFrame): op statistic data.
"""
if isinstance(op_statistic, np.ndarray) and op_statistic.shape[0] == 0 or \
not isinstance(op_statistic, np.ndarray) and not op_statistic:
return None
groups, _, inverse, _ = np.unique(op_statistic['Op Type'], return_index=True, return_inverse=True,
return_counts=True)
@ -246,7 +256,9 @@ class AscendOPGenerator:
Args:
op_summary(DataFrame): op summary data.
"""
if isinstance(op_summary, np.ndarray) and op_summary.shape[0] == 0 or \
not isinstance(op_summary, np.ndarray) and not op_summary:
return None
op_summary = op_summary[op_summary['Task Type'] == 'AI_CPU']
aicpu_detail = np.empty((len(op_summary),), dtype=self.aicpu_detail_dt)
@ -271,6 +283,8 @@ class AscendOPGenerator:
def op_info_analyse(row):
"""generate op info data"""
if not row['Input Shapes']:
return ""
input_shapes = row['Input Shapes'].replace('"', '').split(';')
input_data_types = row['Input Data Types'].replace('_', '').split(';')
input_formats = row['Input Formats'].replace('_', '').split(';')
@ -295,7 +309,9 @@ class AscendOPGenerator:
'shape': output_shapes[i]
}
return json.dumps(op_info)
if isinstance(op_summary, np.ndarray) and op_summary.shape[0] == 0 or \
not isinstance(op_summary, np.ndarray) and not op_summary:
return None
if self.dynamic_status or self.aclnn_status:
index = list(range(op_summary.shape[0]))
else:

View File

@ -63,7 +63,9 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
"""
logger.info('parse cluster data...')
if isinstance(op_summary, np.ndarray) and op_summary.shape[0] == 0 or \
not isinstance(op_summary, np.ndarray) and not op_summary:
return
timeline_list = op_summary[~np.isin(op_summary['Task Type'], ['AI_CPU', 'HCCL'])][
['Op Name', 'Stream ID', 'Task Start Time', 'Task Duration']]

View File

@ -1469,6 +1469,9 @@ class Profiler:
ProfilerInfo.set_export_flag(flag)
op_summary, op_statistic, steptrace, steptrace_model \
= _ascend_graph_msprof_analyse(mindstudio_profiler_output)
if isinstance(op_statistic, np.ndarray) and op_statistic.shape[0] == 0 or \
not isinstance(op_statistic, np.ndarray) and not op_statistic:
return
kernels = self._ascend_timeline_analyse(op_summary, steptrace, source_path, mindstudio_profiler_output)
launch_ops = self._get_kernel_op_map(op_summary, kernels)
self._ascend_op_analyse(op_summary, op_statistic, self._dynamic_status, launch_ops)

View File

@ -0,0 +1,127 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test dataset profiling."""
import os
import tempfile
import glob
import pytest
import mindspore.dataset as ds
from mindspore.dataset import DSCallback
from mindspore import dtype as mstype
import mindspore.log as logger
import mindspore.dataset.transforms as transforms
import mindspore as ms
from mindspore.profiler import Profiler
from tests.security_utils import security_off_wrap
MNIST_DIR = "/home/workspace/mindspore_dataset/mnist/"
CIFAR10_DIR = "/home/workspace/mindspore_dataset/cifar-10-batches-bin/"
def create_dict_iterator(datasets):
"""create_dict_iterator"""
count = 0
for _ in datasets.create_dict_iterator(num_epochs=1, output_numpy=True):
count += 1
class PrintInfo(DSCallback):
"""PrintInfo"""
@staticmethod
def ds_begin(ds_run_context):
"""ds_begin"""
logger.info("callback: start dataset pipeline", ds_run_context.cur_epoch_num)
@staticmethod
def ds_epoch_begin(ds_run_context):
"""ds_epoch_begin"""
logger.info("callback: epoch begin, we are in epoch", ds_run_context.cur_epoch_num)
@staticmethod
def ds_epoch_end(ds_run_context):
"""ds_epoch_end"""
logger.info("callback: epoch end, we are in epoch", ds_run_context.cur_epoch_num)
@staticmethod
def ds_step_begin(ds_run_context):
"""ds_step_begin"""
logger.info("callback: step start, we are in epoch", ds_run_context.cur_step_num)
@staticmethod
def ds_step_end(ds_run_context):
"""ds_step_end"""
logger.info("callback: step end, we are in epoch", ds_run_context.cur_step_num)
def add_one_by_epoch(batchinfo):
"""add_one_by_epoch"""
return batchinfo.get_epoch_num() + 1
def other_method_dataset():
"""create other_method dataset"""
path_base = os.path.split(os.path.realpath(__file__))[0]
data = []
for d in range(10):
data.append(d)
dataset = ds.GeneratorDataset(data, "column1")
dataset = dataset.batch(batch_size=add_one_by_epoch)
create_dict_iterator(dataset)
dataset = ds.GeneratorDataset([1, 2], "col1", shuffle=False, num_parallel_workers=1)
dataset = dataset.map(operations=lambda x: x, callbacks=PrintInfo())
create_dict_iterator(dataset)
schema = ds.Schema()
schema.add_column(name='col1', de_type=mstype.int64, shape=[2])
columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]},
{'name': 'label', 'type': 'int8', 'shape': [1]}]
schema.parse_columns(columns1)
pipeline1 = ds.MnistDataset(MNIST_DIR, num_samples=100)
pipeline2 = ds.Cifar10Dataset(CIFAR10_DIR, num_samples=100)
ds.compare(pipeline1, pipeline2)
dataset = ds.MnistDataset(MNIST_DIR, num_samples=100)
one_hot_encode = transforms.OneHot(10)
dataset = dataset.map(operations=one_hot_encode, input_columns="label")
dataset = dataset.batch(batch_size=10, drop_remainder=True)
ds.serialize(dataset, json_filepath=os.path.join(path_base, "mnist_dataset_pipeline.json"))
ds.show(dataset)
serialized_data = ds.serialize(dataset)
ds.deserialize(input_dict=serialized_data)
return dataset
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_onecard
@security_off_wrap
def test_ascend_dataset_profiler():
"""
Feature: Test the dataset profiling.
Description: Traverse the dataset data, perform data preprocessing, and then verify the collected profiling data.
Expectation: No dataset_iterator_profiling file generated.
"""
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
with tempfile.TemporaryDirectory() as tmpdir:
profiler = Profiler(output_path=tmpdir)
other_method_dataset()
profiler.analyse()
assert len(glob.glob(f"{tmpdir}/profiler*/dataset_iterator_profiling_*.txt")) == 1