From e7b3ac0a015d90a91a32cce211ad0f206ec7ee37 Mon Sep 17 00:00:00 2001 From: wandongdong Date: Sat, 9 May 2020 20:25:45 +0800 Subject: [PATCH] fix eval fp16 bug and codex --- example/mobilenetv2_imagenet2012/eval.py | 3 +++ example/mobilenetv2_imagenet2012/launch.py | 31 ++++++++++++++-------- example/mobilenetv2_imagenet2012/train.py | 4 +-- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/example/mobilenetv2_imagenet2012/eval.py b/example/mobilenetv2_imagenet2012/eval.py index 0060862a4e5..f7085859b5d 100644 --- a/example/mobilenetv2_imagenet2012/eval.py +++ b/example/mobilenetv2_imagenet2012/eval.py @@ -42,6 +42,9 @@ if __name__ == '__main__': loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') net = mobilenet_v2(num_classes=config.num_classes) net.to_float(mstype.float16) + for _, cell in net.cells_and_names(): + if isinstance(cell, nn.Dense): + cell.add_flags_recursive(fp32=True) dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size) step_size = dataset.get_dataset_size() diff --git a/example/mobilenetv2_imagenet2012/launch.py b/example/mobilenetv2_imagenet2012/launch.py index 02021a7249a..22c4af0c315 100644 --- a/example/mobilenetv2_imagenet2012/launch.py +++ b/example/mobilenetv2_imagenet2012/launch.py @@ -17,6 +17,7 @@ import os import sys import json import subprocess +import shutil from argparse import ArgumentParser def parse_args(): @@ -126,25 +127,33 @@ def main(): # spawn the processes processes = [] cmds = [] + log_files = [] + env = os.environ.copy() + env['RANK_SIZE'] = str(args.nproc_per_node) for rank_id in range(0, args.nproc_per_node): device_id = visible_devices[rank_id] device_dir = os.path.join(os.getcwd(), 'device{}'.format(rank_id)) - rank_process = 'export RANK_SIZE={} && export RANK_ID={} && export DEVICE_ID={} && '.format(args.nproc_per_node, - rank_id, device_id) + env['RANK_ID'] = str(rank_id) + env['DEVICE_ID'] = str(device_id) if args.nproc_per_node > 1: - rank_process += 'export MINDSPORE_HCCL_CONFIG_PATH={} && '.format(table_fn) - rank_process += 'export RANK_TABLE_FILE={} && '.format(table_fn) - rank_process += 'rm -rf {dir} && mkdir {dir} && cd {dir} && python {script} '.format(dir=device_dir, - script=args.training_script - ) - rank_process += ' '.join(args.training_script_args) + ' > log{}.log 2>&1 &'.format(rank_id) - process = subprocess.Popen(rank_process, shell=True) + env['MINDSPORE_HCCL_CONFIG_PATH'] = table_fn + env['RANK_TABLE_FILE'] = table_fn + if os.path.exists(device_dir): + shutil.rmtree(device_dir) + os.mkdir(device_dir) + cmd = [sys.executable, '-u'] + cmd.append(args.training_script) + cmd.extend(args.training_script_args) + log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') + process = subprocess.Popen(cmd, stdout=log_file, env=env) processes.append(process) - cmds.append(rank_process) - for process, cmd in zip(processes, cmds): + cmds.append(cmd) + log_files.append(log_file) + for process, cmd, log_file in zip(processes, cmds, log_files): process.wait() if process.returncode != 0: raise subprocess.CalledProcessError(returncode=process, cmd=cmd) + log_file.close() if __name__ == "__main__": diff --git a/example/mobilenetv2_imagenet2012/train.py b/example/mobilenetv2_imagenet2012/train.py index d36737821c6..d22e97a290b 100644 --- a/example/mobilenetv2_imagenet2012/train.py +++ b/example/mobilenetv2_imagenet2012/train.py @@ -119,7 +119,7 @@ class Monitor(Callback): print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds, per_step_mseconds, np.mean(self.losses) - ), flush=True) + )) def step_begin(self, run_context): self.step_time = time.time() @@ -139,7 +139,7 @@ class Monitor(Callback): print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format( cb_params.cur_epoch_num - 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss, - np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]), flush=True) + np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1])) if __name__ == '__main__':