forked from mindspore-Ecosystem/mindspore
make xception can get device id from environment in standalone mode, and
modify requirements.txt
This commit is contained in:
parent
837d6e71de
commit
7ca9108982
|
@ -38,7 +38,7 @@ set_seed(1)
|
|||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='image classification training')
|
||||
parser.add_argument('--is_distributed', action='store_true', default=False, help='distributed training')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||
help='run platform, (Default: Ascend)')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='dataset path')
|
||||
parser.add_argument("--is_fp32", action='store_true', default=False, help='fp32 training, add --is_fp32')
|
||||
|
@ -64,9 +64,8 @@ if __name__ == '__main__':
|
|||
else:
|
||||
rank = 0
|
||||
group_size = 1
|
||||
context.set_context(device_id=0)
|
||||
# if os.getenv('DEVICE_ID', "not_set").isdigit():
|
||||
# context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
||||
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
|
||||
# define network
|
||||
|
|
|
@ -15,7 +15,7 @@ sklearn >= 0.0 # for st test
|
|||
pandas >= 1.0.2 # for ut test
|
||||
astunparse >= 1.6.3
|
||||
packaging >= 20.0
|
||||
pycocotools >= 2.0.0 # for st test
|
||||
pycocotools >= 2.0.2 # for st test
|
||||
tables >= 3.6.1 # for st test
|
||||
psutil >= 5.6.1
|
||||
subword-nmt>=0.3.7 # for st test
|
||||
|
|
Loading…
Reference in New Issue