perator CUDA operator performance data.

This commit is contained in:
liuchuting 2022-06-07 15:01:21 +08:00
parent 7eaddb9f49
commit 38797fd535
12 changed files with 175 additions and 39 deletions

View File

@ -122,11 +122,11 @@ void ProfilingReporter::ReportStepPoint(const std::vector<std::shared_ptr<StepPo
}
void ProfilingReporter::DynamicNodeReport(const CNodePtr &node, uint32_t stream_id, uint32_t task_id,
KernelType kernel_type) {
const KernelType kernel_type) {
ReportTask(node, stream_id, task_id, kernel_type);
ReportNode(node, stream_id, task_id, MSPROF_GE_TENSOR_TYPE_INPUT);
ReportNode(node, stream_id, task_id, MSPROF_GE_TENSOR_TYPE_OUTPUT);
MS_LOG(INFO) << "Profiling report one dynamic node data finish.";
MS_LOG(INFO) << "Profiling report one dynamic node <" << node->fullname_with_scope() << "> data finish.";
}
const CNodePtr ProfilingReporter::GetCNode(const std::string &name) const {

View File

@ -64,7 +64,7 @@ class ProfilingReporter {
~ProfilingReporter() = default;
void ReportTasks();
void DynamicNodeReport(const CNodePtr &node, uint32_t stream_id, uint32_t task_id, KernelType kernel_type);
void DynamicNodeReport(const CNodePtr &node, uint32_t stream_id, uint32_t task_id, const KernelType kernel_type);
void ReportStepPoint(const vector<std::shared_ptr<StepPointDesc>> &points);
private:

View File

@ -187,7 +187,7 @@ void AscendProfiler::Finalize() const {
}
void AscendProfiler::GetNodeTaskIdStreamId(const CNodePtr &kernel, uint32_t graph_id, int device_id,
KernelType kernel_type) {
const KernelType kernel_type) {
uint32_t stream_id;
uint32_t task_id;
uint32_t rt_model_id = 0;

View File

@ -46,7 +46,7 @@ class AscendProfiler : public Profiler {
void Finalize() const;
bool IsInitialized() const { return init_flag_; }
void ReportErrorMessage() const;
void GetNodeTaskIdStreamId(const CNodePtr &kernel, uint32_t graph_id, int device_id, KernelType kernel_type);
void GetNodeTaskIdStreamId(const CNodePtr &kernel, uint32_t graph_id, int device_id, const KernelType kernel_type);
bool GetNetDynamicShapeStatus() const { return is_dynamic_shape_net_; }
void SetNetDynamicShapeStatus() { is_dynamic_shape_net_ = true; }
std::map<std::thread::id, uint32_t> last_tid;

View File

@ -194,7 +194,15 @@ void CPUProfiler::SaveProfileData() {
}
}
void CPUProfiler::ClearInst() { op_info_map_.clear(); }
void CPUProfiler::ClearInst() {
op_info_map_.clear();
all_step_start_end_info_.clear();
step_start_end_info_vector_.clear();
all_kernel_info_.clear();
init_flag_ = false;
enable_flag_ = false;
has_find = false;
}
REGISTER_PYBIND_DEFINE(CPUProfiler_, ([](const py::module *m) {
(void)py::class_<CPUProfiler, std::shared_ptr<CPUProfiler>>(*m, "CPUProfiler")

View File

@ -95,7 +95,7 @@ float DataSaver::GetTotalOpTime(const OpInfoMap &op_info_maps) const {
return sum;
}
void DataSaver::WriteOpType(const std::string &saver_base_dir) const {
void DataSaver::WriteOpType(const std::string &saver_base_dir) {
std::string file_path = saver_base_dir + "/" + op_side_ + "_op_type_info_" + device_id_ + ".csv";
std::ofstream ofs(file_path);
// check if the file is writable
@ -123,9 +123,10 @@ void DataSaver::WriteOpType(const std::string &saver_base_dir) const {
ofs.close();
ChangeFileMode(file_path);
MS_LOG(INFO) << "Write " << op_type_infos_.size() << " op type infos into file: " << file_path;
op_type_infos_.clear();
}
void DataSaver::WriteOpDetail(const std::string &saver_base_dir) const {
void DataSaver::WriteOpDetail(const std::string &saver_base_dir) {
std::string file_path = saver_base_dir + "/" + op_side_ + "_op_detail_info_" + device_id_ + ".csv";
std::ofstream ofs(file_path);
if (!ofs.is_open()) {
@ -152,9 +153,10 @@ void DataSaver::WriteOpDetail(const std::string &saver_base_dir) const {
ofs.close();
ChangeFileMode(file_path);
MS_LOG(INFO) << "Write " << op_detail_infos_.size() << " op detail infos into file: " << file_path;
op_detail_infos_.clear();
}
void DataSaver::WriteOpTimestamp(const std::string &saver_base_dir) const {
void DataSaver::WriteOpTimestamp(const std::string &saver_base_dir) {
std::string file_path = saver_base_dir + "/" + op_side_ + "_op_execute_timestamp_" + device_id_ + ".txt";
std::ofstream ofs(file_path);
// check if the file is writable
@ -183,6 +185,9 @@ void DataSaver::WriteOpTimestamp(const std::string &saver_base_dir) const {
}
ofs.close();
ChangeFileMode(file_path);
if (op_side_ == "cpu") {
op_timestamps_map_.clear();
}
}
void DataSaver::ChangeFileMode(const std::string &file_path) const {

View File

@ -120,11 +120,11 @@ class DataSaver {
float GetTotalOpTime(const OpInfoMap &op_info_maps) const;
void WriteOpType(const std::string &saver_base_dir) const;
void WriteOpType(const std::string &saver_base_dir);
void WriteOpDetail(const std::string &saver_base_dir) const;
void WriteOpDetail(const std::string &saver_base_dir);
void WriteOpTimestamp(const std::string &saver_base_dir) const;
void WriteOpTimestamp(const std::string &saver_base_dir);
void ChangeFileMode(const std::string &file_path) const;

View File

@ -290,6 +290,7 @@ void GpuDataSaver::WriteStepTraceAsyncLaunchKernel(const std::string &saver_base
ofs.close();
ChangeFileMode(file_path);
MS_LOG(INFO) << "Write step trace infos into file: " << file_path;
op_timestamps_map_.clear();
}
void GpuDataSaver::WriteStepTrace(const std::string &saver_base_dir) {
@ -336,6 +337,7 @@ void GpuDataSaver::WriteStepTrace(const std::string &saver_base_dir) {
ofs.close();
ChangeFileMode(file_path);
MS_LOG(INFO) << "Write step trace infos into file: " << file_path;
op_timestamps_map_.clear();
}
void GpuDataSaver::WriteStartTime(const std::string &saver_base_dir, const BaseTime &start_time) {

View File

@ -548,8 +548,16 @@ void GPUProfiler::ClearInst() {
op_name_map_.clear();
events_.clear();
activities_enable_.clear();
all_step_start_end_info_.clear();
step_start_end_info_vector_.clear();
all_kernel_info_.clear();
is_init_ = false;
is_dynamic_shape_net_ = false;
enable_flag_ = false;
sync_enable_flag_ = true;
init_flag_ = false;
enable_flag_ = false;
has_find = false;
cupti_callback_events_count_ = 0l;
cupti_callback_events_drop_count_ = 0l;
cupti_activity_events_count_ = 0l;

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Thr parser for parsing framework files."""
"""The parser for parsing framework files."""
import csv
import os
import re
@ -41,14 +41,13 @@ FILE_DATA_STRUCT_DICT = {
FileDataType.TASK_DESC_INFO.value: TASK_DESC_STRUCT
}
COL_NAMES = ['task_id', 'stream_id', 'block_dim', 'full_op_name', 'op_name', 'op_type', 'subgraph', 'op_info']
OpData = namedtuple('OpData', field_names=COL_NAMES)
class FrameworkParser:
"""
Thr parser for parsing framework files.
The parser for parsing framework files.
Args:
profiling_path (str): The profiling path which should contain CANN profiling data.
@ -408,14 +407,15 @@ class FrameworkParser:
class GpuFrameWorkParser:
"""
Thr parser for parsing framework files.
The parser for parsing framework files.
Args:
output_path (str): The profiling path which should contain GPU profiling data.
dev_id (str): The device ID.
"""
def __init__(self, output_path, dev_id, op_names):
"""Thr parser for parsing framework files."""
"""The parser for parsing framework files."""
self._dev_id = dev_id
self._output_path = output_path
self.op_names = op_names
@ -423,14 +423,18 @@ class GpuFrameWorkParser:
self.framework_list = []
self.op_detail = {}
self.operation_info = {}
self.detail_info_dir = []
self.activity_info_dir = []
self.framework_info_dir = []
self.cpu_detail_info_dir = []
self.gpu_detail_info_dir = []
self.op_execute_times = {}
def parse(self):
"""Parse op performance data."""
self.get_device_target_filename()
self.get_framework_summary()
self.get_op_detail_info()
self.get_cpu_op_detail_info()
self.get_activity_op_info()
if isinstance(self.op_names, str):
self.combine_performance_data(self.op_names)
elif isinstance(self.op_names, list):
@ -455,9 +459,9 @@ class GpuFrameWorkParser:
if item not in self.framework_list:
self.framework_list.append(item)
def get_op_detail_info(self):
"""Get op detail data."""
for filename in self.detail_info_dir:
def get_cpu_op_detail_info(self):
"""Get cpu operators detail data."""
for filename in self.cpu_detail_info_dir:
op_side = filename.split('_')[0]
op_detail_file_path = os.path.join(self._output_path, filename)
op_detail_file_path = validate_and_normalize_path(op_detail_file_path)
@ -469,6 +473,48 @@ class GpuFrameWorkParser:
# line_info[4]: op_occurrences, line_info[5]: op_detail_time(us), line_info[6]: op_avg_time(us);
self.op_detail[line_info[2]] = [line_info[4], line_info[5], line_info[6], op_side]
def get_execute_times(self):
"""Get gpu operators execute times."""
if self.gpu_detail_info_dir:
gpu_op_detail_file_path = os.path.join(self._output_path, self.gpu_detail_info_dir[0])
gpu_op_detail_file_path = validate_and_normalize_path(gpu_op_detail_file_path)
with open(gpu_op_detail_file_path, 'r') as fp:
op_detail_info = fp.readlines()
for line_info in op_detail_info[1:]:
line_info = line_info.strip(' ').strip('\n').split(',')
self.op_execute_times[line_info[2]] = line_info[4]
def get_activity_op_info(self):
"""Get op detail data."""
all_file = os.listdir(self._output_path)
for file_name in all_file:
if file_name.startswith('gpu_op_detail') and file_name.endswith(f'{self._dev_id}.csv'):
self.gpu_detail_info_dir.append(file_name)
if not self.gpu_detail_info_dir and self.activity_info_dir:
raise RuntimeError(f'The output file <%s> is not found.' % self.gpu_detail_info_dir)
self.get_execute_times()
for filename in self.activity_info_dir:
op_side = filename.split('_')[0]
activity_file_path = os.path.join(self._output_path, filename)
activity_file_path = validate_and_normalize_path(activity_file_path)
with open(activity_file_path, 'r') as file:
activity_info = file.readlines()
for line_info in activity_info[1:]:
line_info = line_info.strip(' ').strip('\n').replace(', ', ';').split(',')
op_name = line_info[2].split('/')[-1]
op_occurrences = int(self.op_execute_times.get(op_name))
op_total_time = float(line_info[-4])
if not self.op_detail.get(op_name):
# line_info[4]: op_occurrences, line_info[5]: op_detail_time(us), line_info[6]: op_avg_time(us);
self.op_detail[op_name] = [op_occurrences, op_total_time,
round(op_total_time/op_occurrences, 4), op_side]
else:
self.op_detail.get(op_name)[1] += op_total_time
self.op_detail.get(op_name)[2] = self.op_detail.get(op_name)[1] / self.op_detail.get(op_name)[0]
self.op_detail[op_name] = [self.op_detail.get(op_name)[0],
round(self.op_detail.get(op_name)[1], 4),
round(self.op_detail.get(op_name)[2], 4), op_side]
def combine_performance_data(self, op_name):
"""Combine operator detail info with framework info."""
unique_op_info = []
@ -477,6 +523,8 @@ class GpuFrameWorkParser:
factor = 1000 # convert time unit from ms to us.
for line_info in self.framework_list:
op_detail = self.op_detail.get(line_info[1])
if not op_detail:
continue
if op_name in line_info and line_info[3] == op_detail[3]:
op_side = line_info[3]
op_shape = '{}:{}'.format(op_side, ','.join(line_info[2]))
@ -487,7 +535,7 @@ class GpuFrameWorkParser:
# Classify according to the operator information of the same shape.
op_shape_dict.get(op_shape)[0] += op_occurrences
op_shape_dict.get(op_shape)[1] += op_total_time
op_shape_dict.get(op_shape)[2] += op_shape_dict.get(op_shape)[1] / op_shape_dict.get(op_shape)[0]
op_shape_dict.get(op_shape)[2] = op_shape_dict.get(op_shape)[1] / op_shape_dict.get(op_shape)[0]
op_shape_dict[op_shape] = [op_shape_dict.get(op_shape)[0], round(op_shape_dict.get(op_shape)[1], 4),
round(op_shape_dict.get(op_shape)[2], 4), op_side]
else:
@ -510,31 +558,37 @@ class GpuFrameWorkParser:
if unique_op_info:
self.operation_info[op_name] = unique_op_info
else:
logger.warning(f'The information of {op_name} is not found. Please verify that the operator name is correct'
f' or the operator is used in the network.')
raise RuntimeError(f'The information of <{op_name}> is not found. Please verify that the operator name is'
f' correct or the operator is used in the network.')
def get_device_target_filename(self):
"""Get device target filename."""
gpu_framework_file = f'gpu_framework_{self._dev_id}.txt'
cpu_framework_file = f'cpu_framework_{self._dev_id}.txt'
gpu_op_detail_file = f'gpu_op_detail_info_{self._dev_id}.csv'
gpu_activity_file = f'gpu_activity_data_{self._dev_id}.csv'
cpu_op_detail_file = f'cpu_op_detail_info_{self._dev_id}.csv'
all_file = os.listdir(self._output_path)
if not all_file:
raise RuntimeError(f'No profiler file is found in the path <%s>. '
f'Check whether the profiler path is correct.' % self._output_path)
if gpu_op_detail_file in all_file and gpu_framework_file not in all_file:
if gpu_activity_file in all_file and gpu_framework_file not in all_file:
raise RuntimeError(f'The output file <%s> is not found.' % gpu_framework_file)
if cpu_op_detail_file in all_file and cpu_framework_file not in all_file:
raise RuntimeError(f'The output file <%s> is not found.' % cpu_framework_file)
if gpu_op_detail_file not in all_file and cpu_op_detail_file not in all_file:
if gpu_framework_file in all_file and gpu_activity_file not in all_file:
raise RuntimeError(f'The output file <%s> is not found.' % gpu_activity_file)
if cpu_framework_file in all_file and cpu_op_detail_file not in all_file:
raise RuntimeError(f'The output file <%s> is not found.' % cpu_op_detail_file)
if gpu_activity_file not in all_file and cpu_op_detail_file not in all_file:
raise RuntimeError(f'The profiling data of this card which device_id is equal to {self._dev_id} does not'
f' exist. Check whether device_id is correct.')
for file_name in all_file:
if file_name.endswith(f'detail_info_{self._dev_id}.csv'):
self.detail_info_dir.append(file_name)
if file_name.endswith(f'activity_data_{self._dev_id}.csv'):
self.activity_info_dir.append(file_name)
if file_name.endswith(f'framework_{self._dev_id}.txt'):
self.framework_info_dir.append(file_name)
if file_name.startswith('cpu_op_detail') and file_name.endswith(f'{self._dev_id}.csv'):
self.cpu_detail_info_dir.append(file_name)
class DynamicFrameWorkParser:
@ -545,6 +599,7 @@ class DynamicFrameWorkParser:
output_path (str): The profiling path which should contain Ascend profiling data.
rank_id (int): The rank ID.
"""
def __init__(self, output_path, rank_id):
"""Initialization of parsing framework data."""
self._output_path = output_path
@ -590,11 +645,13 @@ class DynamicFrameWorkParser:
timeline_origin_file_path = os.path.join(self._output_path, timeline_origin_file_name)
timeline_origin_file_path = validate_and_normalize_path(timeline_origin_file_path)
aicpu_file_path = os.path.join(self._output_path, aicpu_file_name)
def read_file(file_path):
"""Read file data."""
with open(file_path, 'r') as fp:
file_info = fp.readlines()[1:]
return file_info
timeline_info = read_file(timeline_origin_file_path)
for line_info in timeline_info:
line_info = line_info.strip('\n').split(',')

View File

@ -139,6 +139,7 @@ class Profiler:
self._dev_id = None
self._cpu_profiler = None
self._gpu_profiler = None
self._md_profiler = None
self._init_time = None
self._ascend_job_id = ''
self._job_id_env = None
@ -158,10 +159,6 @@ class Profiler:
self._ascend_dynamic_status = False
self._cpu_dynamic_status = False
self._gpu_dynamic_status = False
# Setup and start MindData Profiling
self._md_profiler = cde.GlobalContext.profiling_manager()
self._md_profiler.init()
self._decide_device_target(kwargs)
if self.start_profile:
self.start()
@ -211,14 +208,24 @@ class Profiler:
if device_id and not isinstance(device_id, int):
raise TypeError(f"For 'Profiler.op_analyse()', the parameter device_id must be int, "
f"but got type {type(device_id)}")
online_device_id = int(self._dev_id)
self._dev_id = self._dev_id if device_id is None else device_id
if self._dev_id is None:
self._dev_id = 0
if not isinstance(op_name, str) and not isinstance(op_name, list):
raise TypeError(f"For 'Profiler.op_analyse()', the parameter op_name must be str or list, "
f"but got type {type(op_name)}")
if not op_name:
raise TypeError(f"For 'Profiler.op_analyse()', the parameter op_name cannot be "", '' or [].")
parser = GpuFrameWorkParser(self._output_path, self._dev_id, op_name)
op_info = parser.parse()
if self._rank_size > 1:
if online_device_id == int(self._dev_id):
return op_info
if online_device_id != int(self._dev_id):
message = f"For 'Profiler.op_analyse()', the parameter device_id is equal to {self._dev_id}, but the " \
f"current device id is {online_device_id}, so no operator performance information is queried."
return message
return op_info
def _decide_device_target(self, kwargs):
@ -249,6 +256,9 @@ class Profiler:
def _gpu_profiler_init(self, kwargs):
"""Gpu profiler init."""
# Setup and start MindData Profiling
self._md_profiler = cde.GlobalContext.profiling_manager()
self._md_profiler.init()
if context.get_context("mode") == context.PYNATIVE_MODE:
raise RuntimeError("Pynative model is not supported on GPU currently.")
self._parse_parameter_for_gpu(kwargs)
@ -262,6 +272,9 @@ class Profiler:
def _ascend_profiler_init(self, kwargs):
"""Ascend profiler init."""
# Setup and start MindData Profiling
self._md_profiler = cde.GlobalContext.profiling_manager()
self._md_profiler.init()
self._init_time = int(time.time() * 10000000)
logger.info("Profiling: profiling init time: %d", self._init_time)
self._parse_parameter_for_ascend(kwargs)
@ -383,10 +396,7 @@ class Profiler:
"""
Collect and analyze training performance data, support calls during and after training. The example shows above.
"""
if Profiler._has_analysed:
msg = "Do not analyze twice in the profiler."
raise RuntimeError(msg)
Profiler._has_analysed = True
Profiler._has_initialized = False
self._cpu_dynamic_status = self._cpu_profiler.dynamic_status()
_environment_check()
@ -674,12 +684,13 @@ class Profiler:
if self._is_offline_parser():
return
self._md_profiler.start()
self._cpu_profiler.step_profiling_enable(True)
if self._device_target and self._device_target == DeviceTarget.GPU.value:
self._md_profiler.start()
self._gpu_profiler.step_profiling_enable(True)
elif self._device_target and self._device_target == DeviceTarget.ASCEND.value:
self._md_profiler.start()
if context.get_context("mode") == context.PYNATIVE_MODE:
self._ascend_pynative_start()
else:

View File

@ -21,7 +21,9 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.dataset as ds
from mindspore import Profiler
from mindspore import Model
from tests.security_utils import security_off_wrap
@ -38,6 +40,26 @@ x = np.random.randn(1, 3, 3, 4).astype(np.float32)
y = np.random.randn(1, 3, 3, 4).astype(np.float32)
class NetWork(nn.Cell):
def __init__(self):
super(NetWork, self).__init__()
self.unique = P.Unique()
self.shape = P.TensorShape()
self.reshape = P.Reshape()
self.add = P.Add()
def construct(self, a, b):
val = self.add(a, b)
size = self.shape(val)
res = self.reshape(val, size)
return res
def dataset_generator():
for i in range(1, 10):
yield (np.ones((32, 2 * i), dtype=np.float32), np.ones((32, 2 * i), dtype=np.float32))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -72,3 +94,26 @@ def test_ascend_pynative_profiling():
add(Tensor(x), Tensor(y))
profiler.analyse()
assert len(glob.glob(f"{tmpdir}/profiler*/output_timeline_data_*.txt")) == 1
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_shape():
"""
Feature: Test the ascend dynamic shape model profiling
Description: Generate the Net dynamic shape data.
Expectation: Dynamic shape data generated successfully
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
with tempfile.TemporaryDirectory() as tmpdir:
network = NetWork()
profiler = Profiler(output_path=tmpdir)
dataset = ds.GeneratorDataset(dataset_generator, ["data1", "data2"])
dataset.set_dynamic_columns(columns={"data1": [32, None], "data2": [32, None]})
model = Model(network)
model.train(1, dataset, dataset_sink_mode=True)
profiler.analyse()
assert len(glob.glob(f"{tmpdir}/profiler*/dynamic_shape_*.json")) == 1