forked from mindspore-Ecosystem/mindspore
fix resnet50 distribute bug
This commit is contained in:
parent
b16a552d41
commit
e03fbc0b98
|
@ -15,6 +15,7 @@
|
||||||
"""train_imagenet."""
|
"""train_imagenet."""
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
|
import numpy as np
|
||||||
from dataset import create_dataset
|
from dataset import create_dataset
|
||||||
from lr_generator import get_lr
|
from lr_generator import get_lr
|
||||||
from config import config
|
from config import config
|
||||||
|
@ -45,6 +46,7 @@ if __name__ == '__main__':
|
||||||
target = args_opt.device_target
|
target = args_opt.device_target
|
||||||
ckpt_save_dir = config.save_checkpoint_path
|
ckpt_save_dir = config.save_checkpoint_path
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||||
|
np.random.seed(1)
|
||||||
if not args_opt.do_eval and args_opt.run_distribute:
|
if not args_opt.do_eval and args_opt.run_distribute:
|
||||||
if target == "Ascend":
|
if target == "Ascend":
|
||||||
device_id = int(os.getenv('DEVICE_ID'))
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
"""train_imagenet."""
|
"""train_imagenet."""
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
|
import numpy as np
|
||||||
from dataset import create_dataset
|
from dataset import create_dataset
|
||||||
from lr_generator import get_lr
|
from lr_generator import get_lr
|
||||||
from config import config
|
from config import config
|
||||||
|
@ -48,6 +49,7 @@ if __name__ == '__main__':
|
||||||
target = args_opt.device_target
|
target = args_opt.device_target
|
||||||
ckpt_save_dir = config.save_checkpoint_path
|
ckpt_save_dir = config.save_checkpoint_path
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||||
|
np.random.seed(1)
|
||||||
if not args_opt.do_eval and args_opt.run_distribute:
|
if not args_opt.do_eval and args_opt.run_distribute:
|
||||||
if target == "Ascend":
|
if target == "Ascend":
|
||||||
device_id = int(os.getenv('DEVICE_ID'))
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
|
Loading…
Reference in New Issue