forked from mindspore-Ecosystem/mindspore
profiler cleancode
This commit is contained in:
parent
7b55c9858b
commit
310841bd51
|
@ -17,7 +17,6 @@
|
|||
#include <fstream>
|
||||
#include <numeric>
|
||||
#include "sys/stat.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
|
@ -31,6 +30,10 @@ OpDetailInfo::OpDetailInfo(const std::shared_ptr<OpInfo> op_info, float proporti
|
|||
auto op_type_end_iter = op_full_name_.rfind('-');
|
||||
op_type_ = op_full_name_.substr(op_type_begin_iter, op_type_end_iter - op_type_begin_iter);
|
||||
op_name_ = op_full_name_.substr(op_type_begin_iter);
|
||||
if (op_info->op_count == 0) {
|
||||
MS_LOG(ERROR) << "The num of operations can not be 0.";
|
||||
return;
|
||||
}
|
||||
op_avg_time_ = op_info->op_host_cost_time / op_info->op_count;
|
||||
}
|
||||
|
||||
|
@ -39,6 +42,10 @@ void DataSaver::ParseOpInfo(const OpInfoMap &op_info_maps) {
|
|||
float total_time_sum = GetTotalOpTime(op_info_maps);
|
||||
for (auto item : op_info_maps) {
|
||||
op_timestamps_map_[item.first] = item.second.start_duration;
|
||||
if (total_time_sum == 0.0) {
|
||||
MS_LOG(ERROR) << "The total operation times can not be 0.";
|
||||
return;
|
||||
}
|
||||
float proportion = item.second.op_host_cost_time / total_time_sum;
|
||||
auto op_info = std::make_shared<OpInfo>(item.second);
|
||||
if (op_info == nullptr) {
|
||||
|
@ -52,6 +59,10 @@ void DataSaver::ParseOpInfo(const OpInfoMap &op_info_maps) {
|
|||
// update average time of op type
|
||||
for (auto &op_type : op_type_infos_) {
|
||||
// device_infos: <type_name, op_type_info>
|
||||
if (op_type.second.count_ == 0) {
|
||||
MS_LOG(ERROR) << "The num of operation type can not be 0.";
|
||||
return;
|
||||
}
|
||||
op_type.second.avg_time_ = op_type.second.total_time_ / op_type.second.count_;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Get " << op_detail_infos_.size() << " operation items.";
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
#include "profiler/device/profiling.h"
|
||||
#include "utils/log_adapter.h"
|
||||
namespace mindspore {
|
||||
namespace profiler {
|
||||
struct OpDetailInfo {
|
||||
|
@ -73,6 +74,14 @@ struct OpType {
|
|||
std::string GetGpuHeader() const { return "op_type,type_occurrences,total_time(us),total_proportion,avg_time(us)"; }
|
||||
|
||||
void OutputCpuOpTypeInfo(std::ostream &os) const {
|
||||
if (step_ == 0) {
|
||||
MS_LOG(ERROR) << "The run step can not be 0.";
|
||||
return;
|
||||
}
|
||||
if (count_ == 0) {
|
||||
MS_LOG(ERROR) << "The num of operation type can not be 0.";
|
||||
return;
|
||||
}
|
||||
os << op_type_ << ',' << count_ << ',' << count_ / step_ << ',' << total_time_ << ',' << total_time_ / count_ << ','
|
||||
<< proportion_ << std::endl;
|
||||
}
|
||||
|
|
|
@ -68,6 +68,10 @@ void GpuDataSaver::ParseEvent(const std::vector<Event> &events) {
|
|||
for (auto &device_infos : activity_infos_) {
|
||||
// device_infos: <device_id, DeviceActivityInfos>
|
||||
for (auto &activity_info : device_infos.second) {
|
||||
if (activity_info.second.count_ == 0) {
|
||||
MS_LOG(ERROR) << "The num of operations can not be 0.";
|
||||
return;
|
||||
}
|
||||
// activity_info: <kernel_name, Activity>
|
||||
activity_info.second.avg_duration_ = activity_info.second.total_duration_ / activity_info.second.count_;
|
||||
}
|
||||
|
|
|
@ -339,6 +339,10 @@ void GPUProfiler::OpsParser() {
|
|||
std::sort(order_vec.begin(), order_vec.end(), cmp_func);
|
||||
|
||||
for (auto iter = order_vec.begin(); iter != order_vec.end(); iter++) {
|
||||
if (iter->second.op_count == 0) {
|
||||
MS_LOG(ERROR) << "The num of operations can not be 0.";
|
||||
return;
|
||||
}
|
||||
MS_LOG(DEBUG) << "GPU_profiler"
|
||||
<< "," << iter->first << "," << iter->second.op_count << "," << iter->second.op_kernel_count << ","
|
||||
<< iter->second.op_kernel_api_count << ","
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
"""Profiler error code and messages."""
|
||||
from enum import unique, Enum
|
||||
|
||||
|
||||
_GENERAL_MASK = 0b00001 << 7
|
||||
_PARSER_MASK = 0b00010 << 7
|
||||
_ANALYSER_MASK = 0b00011 << 7
|
||||
|
@ -24,6 +23,7 @@ _ANALYSER_MASK = 0b00011 << 7
|
|||
class ProfilerMgrErrors(Enum):
|
||||
"""Enum definition for profiler errors"""
|
||||
|
||||
|
||||
@unique
|
||||
class ProfilerErrors(ProfilerMgrErrors):
|
||||
"""Profiler error codes."""
|
||||
|
@ -53,8 +53,6 @@ class ProfilerErrors(ProfilerMgrErrors):
|
|||
PIPELINE_OP_NOT_EXIST_ERROR = 8 | _ANALYSER_MASK
|
||||
|
||||
|
||||
|
||||
|
||||
@unique
|
||||
class ProfilerErrorMsg(Enum):
|
||||
"""Profiler error messages."""
|
||||
|
|
|
@ -46,7 +46,6 @@ class ProfilerException(Exception):
|
|||
self.message = message
|
||||
self.http_code = http_code
|
||||
|
||||
|
||||
@property
|
||||
def error_code(self):
|
||||
"""
|
||||
|
|
|
@ -23,6 +23,7 @@ class HWTSContainer:
|
|||
Args:
|
||||
split_list (list): The split list of metadata in HWTS output file.
|
||||
"""
|
||||
|
||||
def __init__(self, split_list):
|
||||
self._op_name = ''
|
||||
self._duration = None
|
||||
|
@ -79,6 +80,7 @@ class TimelineContainer:
|
|||
Args:
|
||||
split_list (list): The split list of metadata in op_compute output file.
|
||||
"""
|
||||
|
||||
def __init__(self, split_list):
|
||||
self._op_name = split_list[0]
|
||||
self._stream_id = str(split_list[1])
|
||||
|
@ -121,6 +123,7 @@ class MemoryGraph:
|
|||
Args:
|
||||
graph_proto (proto): Graph proto, defined in profiler module.
|
||||
"""
|
||||
|
||||
def __init__(self, graph_proto):
|
||||
self._graph_proto = graph_proto
|
||||
self.graph_id = graph_proto.graph_id
|
||||
|
@ -153,6 +156,7 @@ class MemoryNode:
|
|||
Args:
|
||||
node_proto (proto): Node proto.
|
||||
"""
|
||||
|
||||
def __init__(self, node_proto):
|
||||
self._node_proto = node_proto
|
||||
self.node_id = node_proto.node_id
|
||||
|
@ -192,6 +196,7 @@ class MemoryTensor:
|
|||
Args:
|
||||
tensor_proto (proto): Tensor proto.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor_proto):
|
||||
self._tensor_proto = tensor_proto
|
||||
self.tensor_id = tensor_proto.tensor_id
|
||||
|
|
|
@ -83,6 +83,10 @@ class FlopsParser:
|
|||
op_avg_time = op_avg_time_dict[op_name]
|
||||
# Time unit of op_avg_time is ms.
|
||||
# The unit of gflop_per_second is GFLOPS(1e9).
|
||||
if float(op_avg_time) == 0.0:
|
||||
raise ValueError("All operators take 0 ms.")
|
||||
if peak_flops == 0:
|
||||
raise ValueError("The frequency of an operator is 0.")
|
||||
gflop_per_second = task_fops / float(op_avg_time)
|
||||
flops_utilization = (gflop_per_second * 1e9 / peak_flops) * 100
|
||||
self._flops_summary['FLOPs'] += task_fops
|
||||
|
@ -170,9 +174,9 @@ class FlopsParser:
|
|||
# These formula is provided by HISI profiling.
|
||||
# a cube_fp16 instruction has (16**3)*2 float point operation.
|
||||
# a cube_fp16 instruction has 16*16*32*2 float point operation.
|
||||
cube_fops = cube_fp16_exec*(16**3)*2 + cube_int8_exec*16*16*32*2
|
||||
vec_fops = vec_fp32*32 + vec_fp16_128lane_exec*128 + \
|
||||
vec_fp16_64lane_exec*64 + vec_int32_exec*64 + vec_misc_exec*32
|
||||
cube_fops = cube_fp16_exec * (16 ** 3) * 2 + cube_int8_exec * 16 * 16 * 32 * 2
|
||||
vec_fops = vec_fp32 * 32 + vec_fp16_128lane_exec * 128 + \
|
||||
vec_fp16_64lane_exec * 64 + vec_int32_exec * 64 + vec_misc_exec * 32
|
||||
task_fops = cube_fops + vec_fops
|
||||
|
||||
return task_fops
|
||||
|
@ -231,14 +235,14 @@ class FlopsParser:
|
|||
suffix_name = "(recompute_Gradients)"
|
||||
else:
|
||||
suffix_name = f"({top_level_scope})"
|
||||
scope_list = list(map(lambda x: x+suffix_name, scope_list))
|
||||
scope_list = list(map(lambda x: x + suffix_name, scope_list))
|
||||
scope_list[0] = top_level_scope
|
||||
|
||||
# Add root node (refers to total flops).
|
||||
scope_list.insert(0, "Total")
|
||||
scope_depth = len(scope_list)
|
||||
for idx in range(scope_depth - 1):
|
||||
key_name = scope_list[idx] + " " + scope_list[idx+1]
|
||||
key_name = scope_list[idx] + " " + scope_list[idx + 1]
|
||||
self._flops_each_scope.setdefault(key_name, 0)
|
||||
self._flops_each_scope[key_name] += task_fops
|
||||
|
||||
|
|
|
@ -157,7 +157,7 @@ class HcclParser:
|
|||
csv_reader = csv.reader(src_file)
|
||||
# index_0:step_num, index_1:start_point, index_2:end_point
|
||||
# The unit of time stamp is 10ns. To convert it to μs, you need to divide it by 100.
|
||||
step_timestamps_info = [[info[0], float(info[1])/100, float(info[2])/100]
|
||||
step_timestamps_info = [[info[0], float(info[1]) / 100, float(info[2]) / 100]
|
||||
for info in csv_reader if info[0].isdigit()]
|
||||
|
||||
return step_timestamps_info
|
||||
|
@ -219,6 +219,7 @@ class HcclParser:
|
|||
|
||||
def _calculate_communication_operator_iter_cost(self, file_path):
|
||||
"""Calculate the time-consuming of communication operator in one execution round."""
|
||||
|
||||
def _inner_calculate_communication_operator_iter_cost(events):
|
||||
total_notify_wait = self._calculate_notify_wait_time(events)
|
||||
# Divide information by src dst rank_id.
|
||||
|
@ -362,7 +363,7 @@ class HcclParser:
|
|||
rdma_communication_size = 0
|
||||
rdma_communication_wait_time = 0
|
||||
start_index = 0
|
||||
end_index = len(trace_event)-1
|
||||
end_index = len(trace_event) - 1
|
||||
while start_index < end_index:
|
||||
first_task_type = trace_event[start_index].get("args").get("task type")
|
||||
if first_task_type == CommunicationInfo.RDMASEND.value and start_index < end_index - 1:
|
||||
|
@ -386,10 +387,10 @@ class HcclParser:
|
|||
# The unit of rdma_communication_wait_time is ms.
|
||||
# The unit of rdma_bandwidth is KB/s.
|
||||
# The unit of rdma_communication_size is k_byte and The unit of rdma_communication_time is ms.
|
||||
rdma_communication_wait_time = rdma_communication_wait_time/1e3
|
||||
rdma_communication_size = rdma_communication_size/1e3
|
||||
rdma_communication_time = rdma_communication_time/1e3
|
||||
rdma_bandwidth = rdma_communication_size/(rdma_communication_time/1e3) \
|
||||
rdma_communication_wait_time = rdma_communication_wait_time / 1e3
|
||||
rdma_communication_size = rdma_communication_size / 1e3
|
||||
rdma_communication_time = rdma_communication_time / 1e3
|
||||
rdma_bandwidth = rdma_communication_size / (rdma_communication_time / 1e3) \
|
||||
if rdma_communication_size else 0
|
||||
|
||||
return [rdma_communication_time, rdma_communication_size, rdma_bandwidth, rdma_communication_wait_time]
|
||||
|
@ -413,9 +414,9 @@ class HcclParser:
|
|||
|
||||
# The unit of sdma_bandwidth is KB/s.
|
||||
# The unit of sdma_communication_size is k_byte and The unit of sdma_communication_time is ms.
|
||||
sdma_communication_time = sdma_communication_time/1e3
|
||||
sdma_communication_size = sdma_communication_size/1e3
|
||||
sdma_bandwidth = sdma_communication_size/(sdma_communication_time/1e3) \
|
||||
sdma_communication_time = sdma_communication_time / 1e3
|
||||
sdma_communication_size = sdma_communication_size / 1e3
|
||||
sdma_bandwidth = sdma_communication_size / (sdma_communication_time / 1e3) \
|
||||
if sdma_communication_size else 0
|
||||
return [sdma_communication_time, sdma_communication_size, sdma_bandwidth]
|
||||
|
||||
|
@ -427,7 +428,7 @@ class HcclParser:
|
|||
if task_type == CommunicationInfo.NOTIFY_WAIT.value:
|
||||
total_notify_wait_time += item.get("dur", 0)
|
||||
# The unit of total_notify_wait_time is ms.
|
||||
total_notify_wait_time = total_notify_wait_time/1e3
|
||||
total_notify_wait_time = total_notify_wait_time / 1e3
|
||||
return total_notify_wait_time
|
||||
|
||||
def _calculate_communication_average_value(self, communication_info: list):
|
||||
|
@ -436,8 +437,8 @@ class HcclParser:
|
|||
if communication_info_size == 0:
|
||||
return []
|
||||
# index1: communication_cost,index2:wait_cost,index3:link_info
|
||||
communication_cost_average = sum([i[1] for i in communication_info])/communication_info_size
|
||||
wait_cost_average = sum([i[2] for i in communication_info])/communication_info_size
|
||||
communication_cost_average = sum([i[1] for i in communication_info]) / communication_info_size
|
||||
wait_cost_average = sum([i[2] for i in communication_info]) / communication_info_size
|
||||
link_info = [i[3] for i in communication_info]
|
||||
calculate_type = 'average'
|
||||
link_average_info = self._calculate_link_value(link_info, calculate_type)
|
||||
|
|
|
@ -20,6 +20,7 @@ from mindspore import log as logger
|
|||
from mindspore.profiler.common.validator.validate_path import \
|
||||
validate_and_normalize_path
|
||||
|
||||
|
||||
class HWTSLogParser:
|
||||
"""
|
||||
The Parser for hwts log files.
|
||||
|
@ -112,8 +113,8 @@ class HWTSLogParser:
|
|||
|
||||
if int(task_id) < 25000:
|
||||
task_id = str(stream_id) + "_" + str(task_id)
|
||||
result_data += ("%-14s %-4s %-8s %-9s %-8s %-15s %s\n" %(log_type[int(ms_type, 2)], cnt, core_id,
|
||||
blk_id, task_id, syscnt, stream_id))
|
||||
result_data += ("%-14s %-4s %-8s %-9s %-8s %-15s %s\n" % (log_type[int(ms_type, 2)], cnt, core_id,
|
||||
blk_id, task_id, syscnt, stream_id))
|
||||
|
||||
fwrite_format(self._output_filename, data_source=self._dst_file_title, is_start=True)
|
||||
fwrite_format(self._output_filename, data_source=self._dst_file_column_title)
|
||||
|
|
|
@ -113,6 +113,8 @@ class Integrator:
|
|||
op_type_time_cache[op_type][0] += op_time
|
||||
op_type_time_cache[op_type][1] += 1
|
||||
|
||||
if self._total_time == 0:
|
||||
raise ValueError("The total time of operations can not be 0.")
|
||||
op_type_file_name = 'aicore_intermediate_' + self._device_id + '_type.csv'
|
||||
op_type_file_path = os.path.join(self._profiling_dir, op_type_file_name)
|
||||
with open(op_type_file_path, 'w') as type_file:
|
||||
|
@ -1059,6 +1061,7 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
|
|||
framework_info (dict): The framework metadata.
|
||||
aicpu_info (dict): The metadata of AI CPU operator.
|
||||
min_cycle_counter (float): The minimum cycle counter of the timeline.
|
||||
source_path (str): The source of file.
|
||||
"""
|
||||
if min_cycle_counter == float('inf'):
|
||||
min_cycle_counter = 0
|
||||
|
|
|
@ -34,6 +34,7 @@ GIGABYTES = 1024 * 1024 * 1024
|
|||
|
||||
class MemoryUsageParser:
|
||||
"""MemoryUsageParser to parse memory raw data."""
|
||||
|
||||
def __init__(self, profiling_dir, device_id):
|
||||
self._profiling_dir = profiling_dir
|
||||
self._device_id = device_id
|
||||
|
@ -163,6 +164,7 @@ class MemoryUsageParser:
|
|||
|
||||
class GraphMemoryParser:
|
||||
"""Parse memory usage data for each graph."""
|
||||
|
||||
def __init__(self, graph_proto, points, framework):
|
||||
self.graph = None
|
||||
self.nodes = OrderedDict()
|
||||
|
@ -238,7 +240,7 @@ class GraphMemoryParser:
|
|||
if index == 0:
|
||||
node.mem_change = self._mem_change[index] - self.graph.static_mem
|
||||
else:
|
||||
node.mem_change = self._mem_change[index] - self._mem_change[index-1]
|
||||
node.mem_change = self._mem_change[index] - self._mem_change[index - 1]
|
||||
|
||||
self._update_nodes(node)
|
||||
self._update_tensor_source(node)
|
||||
|
@ -308,7 +310,7 @@ class GraphMemoryParser:
|
|||
elif life_long == 'LifeLongGraphStart': # lifetime is from graph start to tensor end
|
||||
if life_end is not None and life_end >= 0:
|
||||
tensor.life_start = 0
|
||||
self._update_mem_change(size, 0, life_end+1, tensor_id)
|
||||
self._update_mem_change(size, 0, life_end + 1, tensor_id)
|
||||
else:
|
||||
logger.info('Cannot locate lifetime end for tensor: %s', tensor_id)
|
||||
elif life_long == 'LifeLongGraphEnd': # lifetime is from tensor start to graph end
|
||||
|
@ -319,7 +321,7 @@ class GraphMemoryParser:
|
|||
logger.info('Cannot locate lifetime start for tensor: %s', tensor_id)
|
||||
elif life_long == 'LifeLongNone': # lifetime is from tensor start to tensor end
|
||||
if life_start is not None and life_end is not None and life_start <= life_end:
|
||||
self._update_mem_change(size, life_start, life_end+1, tensor_id)
|
||||
self._update_mem_change(size, life_start, life_end + 1, tensor_id)
|
||||
else:
|
||||
logger.info('Cannot locate lifetime start or end for tensor: %s', tensor_id)
|
||||
|
||||
|
|
|
@ -304,6 +304,8 @@ class MinddataProfilingAnalyzer:
|
|||
if metrics and metrics['output_queue']:
|
||||
queue_size = metrics['output_queue']['size']
|
||||
queue_length = metrics['output_queue']['length']
|
||||
if queue_length == 0:
|
||||
raise ValueError("The input queue can not be None.")
|
||||
queue_average_size = round(sum(queue_size) / len(queue_size), 2) if queue_size else -1
|
||||
queue_utilization_pct = round(100 * queue_average_size / queue_length, 2)
|
||||
# Compute percentage of time queue is empty
|
||||
|
|
|
@ -20,8 +20,10 @@ from mindspore import log as logger
|
|||
from mindspore.profiler.common.validator.validate_path import \
|
||||
validate_and_normalize_path
|
||||
|
||||
|
||||
class MinddataParser:
|
||||
"""Minddata Aicpu Parser."""
|
||||
|
||||
@staticmethod
|
||||
def parse_minddata_aicpu_data(minddata_aicpu_source_path):
|
||||
"""
|
||||
|
|
|
@ -262,8 +262,12 @@ class MinddataPipelineParser:
|
|||
output_queue = metrics.get('output_queue')
|
||||
if output_queue:
|
||||
queue_size = output_queue.get('size')
|
||||
if queue_size is None:
|
||||
raise ValueError("The queue can not be None.")
|
||||
queue_average_size = sum(queue_size) / len(queue_size)
|
||||
queue_length = output_queue.get('length')
|
||||
if queue_length == 0:
|
||||
raise ValueError("The length of queue can not be 0.")
|
||||
queue_usage_rate = queue_average_size / queue_length
|
||||
|
||||
children_id = op_node.get('children')
|
||||
|
|
|
@ -24,6 +24,7 @@ from mindspore.profiler.parser.container import HWTSContainer
|
|||
|
||||
TIMELINE_FILE_COLUMN_TITLE = 'op_name, stream_id, start_time(ms), duration(ms)'
|
||||
|
||||
|
||||
class OPComputeTimeParser:
|
||||
"""
|
||||
Join hwts info and framework info, get op time info, and output to the result file.
|
||||
|
@ -102,10 +103,12 @@ class OPComputeTimeParser:
|
|||
for op_name, time in op_name_time_dict.items():
|
||||
if op_name in op_name_stream_dict.keys():
|
||||
stream_id = op_name_stream_dict[op_name]
|
||||
if op_name_count_dict[op_name] == 0:
|
||||
raise ValueError("The number of operations can not be 0.")
|
||||
avg_time = time / op_name_count_dict[op_name]
|
||||
total_time += avg_time
|
||||
result_data += ("%s %s %s\n" %(op_name, str(avg_time), stream_id))
|
||||
result_data += ("total op %s 0" %(str(total_time)))
|
||||
result_data += ("%s %s %s\n" % (op_name, str(avg_time), stream_id))
|
||||
result_data += ("total op %s 0" % (str(total_time)))
|
||||
|
||||
timeline_data = []
|
||||
for op_name, time in op_name_time_dict.items():
|
||||
|
@ -146,8 +149,8 @@ class OPComputeTimeParser:
|
|||
Args:
|
||||
timeline_data (list): The metadata to be written into the file.
|
||||
[
|
||||
['op_name_1', 'stream_id_1', 'start_time_1', 'durarion_1'],
|
||||
['op_name_2', 'stream_id_2', 'start_time_2', 'durarion_2'],
|
||||
['op_name_1', 'stream_id_1', 'start_time_1', 'duration_1'],
|
||||
['op_name_2', 'stream_id_2', 'start_time_2', 'duration_2'],
|
||||
[...]
|
||||
]
|
||||
"""
|
||||
|
|
|
@ -348,12 +348,12 @@ class BaseStepTraceParser:
|
|||
csv_writer = csv.writer(file_handle)
|
||||
if not self._is_training_mode:
|
||||
self._header[FP_DURATION] = 'fp'
|
||||
self._header = self._header[:BP_POINT] + self._header[BP_POINT+1:TAIL]
|
||||
self._header = self._header[:BP_POINT] + self._header[BP_POINT + 1:TAIL]
|
||||
csv_writer.writerow(self._header)
|
||||
for row_data in self._result:
|
||||
if not self._is_training_mode:
|
||||
row_data[FP_DURATION] += row_data[TAIL]
|
||||
row_data = row_data[:BP_POINT] + row_data[BP_POINT+1:TAIL]
|
||||
row_data = row_data[:BP_POINT] + row_data[BP_POINT + 1:TAIL]
|
||||
csv_writer.writerow(row_data)
|
||||
os.chmod(self._output_path, stat.S_IREAD | stat.S_IWRITE)
|
||||
except (IOError, OSError) as err:
|
||||
|
|
|
@ -47,12 +47,14 @@ from mindspore.nn.cell import Cell
|
|||
|
||||
INIT_OP_NAME = 'Default/InitDataSetQueue'
|
||||
|
||||
|
||||
class ProfileOption(Enum):
|
||||
"""
|
||||
Profile Option Enum which be used in Profiler.profile.
|
||||
"""
|
||||
trainable_parameters = 0
|
||||
|
||||
|
||||
class Profiler:
|
||||
"""
|
||||
Performance profiling API.
|
||||
|
@ -211,7 +213,7 @@ class Profiler:
|
|||
"aic_metrics": "PipeUtilization",
|
||||
"aicpu": "on",
|
||||
"profile_memory": profile_memory
|
||||
}
|
||||
}
|
||||
|
||||
return profiling_options
|
||||
|
||||
|
@ -539,7 +541,7 @@ class Profiler:
|
|||
for line in f.readlines():
|
||||
if "clock_realtime" in line:
|
||||
# 16 means the first digit of the timestamp, len(line)-3 means the last.
|
||||
job_start_time = line[16:len(line)-3]
|
||||
job_start_time = line[16:len(line) - 3]
|
||||
|
||||
return job_start_time
|
||||
|
||||
|
@ -697,7 +699,7 @@ class Profiler:
|
|||
hccl_parse.parse()
|
||||
|
||||
@staticmethod
|
||||
def profile(network=None, profile_option=None):
|
||||
def profile(network, profile_option):
|
||||
"""
|
||||
Get the number of trainable parameters in the training network.
|
||||
|
||||
|
|
Loading…
Reference in New Issue