!32522 parallel strategy pb to json

Merge pull request !32522 from zangqx/zangqx_10
This commit is contained in:
i-robot 2022-04-04 04:05:32 +00:00 committed by Gitee
commit e84f16ac7d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 30 additions and 8 deletions

View File

@ -38,7 +38,7 @@ namespace profiler {
namespace ascend {
bool has_save_parallel_strategy = false;
bool has_got_parallel_strategy_data = false;
bool profiling_parallel_strategy_enabled = false;
bool profiling_parallel_strategy_enabled = true;
irpb::ProfilingParallel cache_profiling_parallel_pb;
bool IsProfilingParallelStrategyEnabled() {
@ -168,7 +168,7 @@ void SaveParallelStrategyToFile() {
if (rank_id.empty()) {
rank_id = "0";
}
std::string file_path = dir_path + std::string("/parallel_strategy_") + std::string(rank_id) + std::string(".json");
std::string file_path = dir_path + std::string("/parallel_strategy_pb_") + std::string(rank_id) + std::string(".bin");
MS_LOG(INFO) << "Start to write parallel strategy string, file path is " << file_path;
std::ofstream ofs(file_path);
@ -178,9 +178,7 @@ void SaveParallelStrategyToFile() {
return;
}
std::string profiling_parallel_str;
google::protobuf::util::MessageToJsonString(cache_profiling_parallel_pb, &profiling_parallel_str);
ofs << profiling_parallel_str;
ofs << cache_profiling_parallel_pb.SerializeAsString();
ofs.close();
ChangeFileMode(file_path, S_IRUSR | S_IWUSR);

View File

@ -1845,7 +1845,9 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
def _pynative_get_step_timeline_list(self, timeline_list):
"""Get step timeline list for pynative model."""
step_list = []
if 'GetNext' not in timeline_list[0][self._op_name_idx]:
# The timeline starts with the GetNext op
if len(timeline_list) < 2 or 'GetNext' not in timeline_list[0][self._op_name_idx] and \
'GetNext' not in timeline_list[1][self._op_name_idx]:
return step_list
step = [-1, -1]
step_num = 0

View File

@ -17,11 +17,13 @@ import os
import stat
import time
import json
from google.protobuf.json_format import MessageToJson
from mindspore import log as logger, context
from mindspore.communication.management import GlobalComm, get_rank, get_group_size
import mindspore._c_expression as c_expression
import mindspore._c_dataengine as cde
from mindspore.train.profiling_parallel_pb2 import ProfilingParallel
from mindspore.profiler.common.exceptions.exceptions import ProfilerFileNotFoundException, \
ProfilerIOException, ProfilerException, ProfilerRawFileException
from mindspore.profiler.common.exceptions.exceptions import ProfilerPathErrorException
@ -139,6 +141,7 @@ class Profiler:
self._job_id_env = None
self._filt_optype_names = ''
self._output_path = ''
self._rank_size = 0
_environment_check()
# get device_id and device_target
self._get_devid_rankid_and_devtarget()
@ -156,8 +159,6 @@ class Profiler:
self._decide_device_target(kwargs)
if self.start_profile:
self.start()
elif context.get_context("mode") == context.PYNATIVE_MODE:
raise RuntimeError("Pynative model does not support conditional collection of performance data.")
def _decide_device_target(self, kwargs):
"""Complete Profiler initialization according to device_target"""
@ -532,6 +533,8 @@ class Profiler:
self._dev_id, self._rank_id, is_training_mode_flag)
logger.info("Profiling: analyzing the operation FLOPs.")
flops_parser.execute()
logger.info("Profiling: analyzing the parallel strategy.")
self._analyse_parallel_strategy()
@staticmethod
def _check_output_path(output_path):
@ -578,6 +581,8 @@ class Profiler:
>>> def end(self, run_context):
... self.profiler.analyse()
"""
if not self.start_profile and context.get_context("mode") == context.PYNATIVE_MODE:
raise RuntimeError("Pynative model does not support conditional collection of performance data.")
self._start_time = int(time.time() * 10000000)
logger.info("Profiling: start time: %d", self._start_time)
@ -1104,3 +1109,20 @@ class Profiler:
hccl_parse = HcclParser(hccl_path, self._dev_id, self._rank_id, self._output_path)
hccl_parse.parse()
logger.info("Analyse hccl info successfully.")
def _analyse_parallel_strategy(self):
"""Analyse parallel strategy from proto binary to json."""
binary_file = os.path.join(self._output_path, 'parallel_strategy_pb_{}.bin'.format(self._rank_id))
binary_file = validate_and_normalize_path(binary_file)
if not os.path.isfile(binary_file):
return
with open(binary_file, 'rb') as f:
data = f.read()
parallel = ProfilingParallel()
parallel.ParseFromString(data)
parallel_json = MessageToJson(parallel)
json_file = os.path.join(self._output_path, 'parallel_strategy_{}.json'.format(self._rank_id))
with os.fdopen(os.open(json_file, os.O_WRONLY | os.O_CREAT, 0o660), 'w') as f:
f.write(parallel_json)
os.remove(binary_file)