diff --git a/model_zoo/official/cv/yolov3_darknet53/train.py b/model_zoo/official/cv/yolov3_darknet53/train.py index 1fa0f29b72e..7d213acc5bc 100644 --- a/model_zoo/official/cv/yolov3_darknet53/train.py +++ b/model_zoo/official/cv/yolov3_darknet53/train.py @@ -135,6 +135,14 @@ def network_init(args): devid = int(os.getenv('DEVICE_ID', '0')) context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, device_target=args.device_target, save_graphs=False, device_id=devid) + + profiler = None + if args.need_profiler: + from mindspore.profiler import Profiler + profiling_dir = os.path.join("profiling", + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + profiler = Profiler(output_path=profiling_dir, is_detail=True, is_show_op_path=True) + # init distributed if args.is_distributed: if args.device_target == "Ascend": @@ -155,6 +163,7 @@ def network_init(args): datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) args.logger = get_logger(args.outputs_dir, args.rank) args.logger.save_args(args) + return profiler def parallel_init(args): @@ -169,10 +178,7 @@ def parallel_init(args): def train(): """Train function.""" args = parse_args() - network_init(args) - if args.need_profiler: - from mindspore.profiler import Profiler - profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True) + profiler = network_init(args) loss_meter = AverageMeter('loss') parallel_init(args)