From e03fbc0b98b92848813e48c74c7ed97eb253d788 Mon Sep 17 00:00:00 2001 From: zhaoting Date: Tue, 9 Jun 2020 20:58:25 +0800 Subject: [PATCH] fix resnet50 distribute bug --- example/resnet50_cifar10/train.py | 2 ++ example/resnet50_imagenet2012/train.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/example/resnet50_cifar10/train.py b/example/resnet50_cifar10/train.py index 275f7188a7c..323695ae291 100755 --- a/example/resnet50_cifar10/train.py +++ b/example/resnet50_cifar10/train.py @@ -15,6 +15,7 @@ """train_imagenet.""" import os import argparse +import numpy as np from dataset import create_dataset from lr_generator import get_lr from config import config @@ -45,6 +46,7 @@ if __name__ == '__main__': target = args_opt.device_target ckpt_save_dir = config.save_checkpoint_path context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + np.random.seed(1) if not args_opt.do_eval and args_opt.run_distribute: if target == "Ascend": device_id = int(os.getenv('DEVICE_ID')) diff --git a/example/resnet50_imagenet2012/train.py b/example/resnet50_imagenet2012/train.py index a76de78f6d5..abb55731dce 100755 --- a/example/resnet50_imagenet2012/train.py +++ b/example/resnet50_imagenet2012/train.py @@ -15,6 +15,7 @@ """train_imagenet.""" import os import argparse +import numpy as np from dataset import create_dataset from lr_generator import get_lr from config import config @@ -48,6 +49,7 @@ if __name__ == '__main__': target = args_opt.device_target ckpt_save_dir = config.save_checkpoint_path context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + np.random.seed(1) if not args_opt.do_eval and args_opt.run_distribute: if target == "Ascend": device_id = int(os.getenv('DEVICE_ID'))