!13719 modified train.py in network
From: @shuzigood Reviewed-by: Signed-off-by:
This commit is contained in:
commit
7324ee14c6
|
@ -135,6 +135,14 @@ def network_init(args):
|
||||||
devid = int(os.getenv('DEVICE_ID', '0'))
|
devid = int(os.getenv('DEVICE_ID', '0'))
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||||
device_target=args.device_target, save_graphs=False, device_id=devid)
|
device_target=args.device_target, save_graphs=False, device_id=devid)
|
||||||
|
|
||||||
|
profiler = None
|
||||||
|
if args.need_profiler:
|
||||||
|
from mindspore.profiler import Profiler
|
||||||
|
profiling_dir = os.path.join("profiling",
|
||||||
|
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||||
|
profiler = Profiler(output_path=profiling_dir, is_detail=True, is_show_op_path=True)
|
||||||
|
|
||||||
# init distributed
|
# init distributed
|
||||||
if args.is_distributed:
|
if args.is_distributed:
|
||||||
if args.device_target == "Ascend":
|
if args.device_target == "Ascend":
|
||||||
|
@ -155,6 +163,7 @@ def network_init(args):
|
||||||
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||||
args.logger = get_logger(args.outputs_dir, args.rank)
|
args.logger = get_logger(args.outputs_dir, args.rank)
|
||||||
args.logger.save_args(args)
|
args.logger.save_args(args)
|
||||||
|
return profiler
|
||||||
|
|
||||||
|
|
||||||
def parallel_init(args):
|
def parallel_init(args):
|
||||||
|
@ -169,10 +178,7 @@ def parallel_init(args):
|
||||||
def train():
|
def train():
|
||||||
"""Train function."""
|
"""Train function."""
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
network_init(args)
|
profiler = network_init(args)
|
||||||
if args.need_profiler:
|
|
||||||
from mindspore.profiler import Profiler
|
|
||||||
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)
|
|
||||||
|
|
||||||
loss_meter = AverageMeter('loss')
|
loss_meter = AverageMeter('loss')
|
||||||
parallel_init(args)
|
parallel_init(args)
|
||||||
|
|
Loading…
Reference in New Issue