forked from mindspore-Ecosystem/mindspore

fix shufflenetv2 script

This commit is contained in:
parent 40222f59a7
commit 64e01c2348
@@ -14,7 +14,6 @@
 # ============================================================================
 """evaluate imagenet"""
 import argparse
-import os
 
 import mindspore.nn as nn
 from mindspore import context
@@ -33,9 +32,8 @@ if __name__ == '__main__':
     parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
     args_opt = parser.parse_args()
 
-    if args_opt.platform == 'Ascend':
-        device_id = int(os.getenv('DEVICE_ID'))
-        context.set_context(device_id=device_id)
+    if args_opt.platform != 'GPU':
+        raise ValueError("Only supported GPU training.")
 
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
 
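Assembled, the platform handling in the patched evaluation script reduces to the sketch below. This is illustrative only: the parser is trimmed to the one option visible in the hunks (the real eval script also takes dataset and checkpoint arguments, and its description string is assumed here), and 'Ascend' deliberately remains an argparse choice even though it is now rejected at runtime.

import argparse

from mindspore import context

if __name__ == '__main__':
    # Minimal parser for illustration; the real script defines more options.
    parser = argparse.ArgumentParser(description='evaluate imagenet')
    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'),
                        help='run platform')
    args_opt = parser.parse_args()

    # GPU is the only supported target after this change; Ascend is rejected up front.
    if args_opt.platform != 'GPU':
        raise ValueError("Only supported GPU training.")

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)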
@@ -92,7 +92,7 @@ parser.add_argument('--resume', default='', type=str, metavar='PATH',
 
 def main():
     args, _ = parser.parse_known_args()
-    devid, rank_id, rank_size = 0, 0, 1
+    rank_id, rank_size = 0, 1
 
     context.set_context(mode=context.GRAPH_MODE)
 
@@ -101,10 +101,7 @@ def main():
             init("nccl")
             context.set_context(device_target='GPU')
         else:
-            init()
-            devid = int(os.getenv('DEVICE_ID'))
-            context.set_context(
-                device_target='Ascend', device_id=devid, reserve_class_name_in_scope=False)
+            raise ValueError("Only supported GPU training.")
         context.reset_auto_parallel_context()
         rank_id = get_rank()
         rank_size = get_group_size()
@@ -113,6 +110,8 @@ def main():
     else:
         if args.GPU:
             context.set_context(device_target='GPU')
+        else:
+            raise ValueError("Only supported GPU training.")
 
     net = efficientnet_b0(num_classes=cfg.num_classes,
                           drop_rate=cfg.drop,
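The three training hunks above belong to one main() (the efficientnet_b0 constructor suggests the EfficientNet training script). A standalone sketch of the device setup after the change follows; the --distributed flag name, the parser description, and the minimal parser itself are assumptions, the communication import path follows MindSpore 1.x conventions, and everything else mirrors the hunks.

import argparse

from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size

# Hypothetical minimal parser; the real script defines many more options (--resume, data paths, ...).
parser = argparse.ArgumentParser(description='efficientnet training (GPU only)')
parser.add_argument('--GPU', action='store_true', help='train on GPU')
parser.add_argument('--distributed', action='store_true', help='distributed training')  # flag name assumed

def main():
    args, _ = parser.parse_known_args()
    rank_id, rank_size = 0, 1

    context.set_context(mode=context.GRAPH_MODE)

    if args.distributed:
        if args.GPU:
            init("nccl")
            context.set_context(device_target='GPU')
        else:
            raise ValueError("Only supported GPU training.")
        context.reset_auto_parallel_context()
        rank_id = get_rank()
        rank_size = get_group_size()
    else:
        if args.GPU:
            context.set_context(device_target='GPU')
        else:
            raise ValueError("Only supported GPU training.")
    return rank_id, rank_size

if __name__ == '__main__':
    main()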
@@ -14,7 +14,6 @@
 # ============================================================================
 """evaluate imagenet"""
 import argparse
-import os
 
 import mindspore.nn as nn
 from mindspore import context
@@ -34,9 +33,8 @@ if __name__ == '__main__':
     parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
     args_opt = parser.parse_args()
 
-    if args_opt.platform == 'Ascend':
-        device_id = int(os.getenv('DEVICE_ID'))
-        context.set_context(device_id=device_id)
+    if args_opt.platform != 'GPU':
+        raise ValueError("Only supported GPU training.")
 
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
     net = NASNetAMobile(num_classes=cfg.num_classes, is_training=False)
@@ -45,15 +45,15 @@ if __name__ == '__main__':
     parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
     args_opt = parser.parse_args()
 
+    if args_opt.platform != "GPU":
+        raise ValueError("Only supported GPU training.")
+
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
     if os.getenv('DEVICE_ID', "not_set").isdigit():
         context.set_context(device_id=int(os.getenv('DEVICE_ID')))
 
     # init distributed
     if args_opt.is_distributed:
-        if args_opt.platform == "Ascend":
-            init()
-        else:
-            init("nccl")
+        init("nccl")
         cfg.rank = get_rank()
         cfg.group_size = get_group_size()
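The same hunk is applied again to the ShuffleNetV2 training script further down; in both cases the Ascend/HCCL branch of the distributed setup disappears and only NCCL initialisation remains. A self-contained sketch of that shared logic as a helper (the function name and return values are illustrative, not part of the patch; the communication import path follows MindSpore 1.x conventions):

import os

from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size


def init_gpu_context(platform, is_distributed):
    """GPU-only context and communication setup mirroring the patched train scripts."""
    if platform != "GPU":
        raise ValueError("Only supported GPU training.")

    context.set_context(mode=context.GRAPH_MODE, device_target=platform, save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    rank, group_size = 0, 1
    if is_distributed:
        init("nccl")  # NCCL is now the only backend; the Ascend/HCCL path was removed
        rank = get_rank()
        group_size = get_group_size()
    return rank, group_size

Usage would be along the lines of: rank, group_size = init_gpu_context(args_opt.platform, args_opt.is_distributed).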
@@ -14,7 +14,6 @@
 # ============================================================================
 """evaluate_imagenet"""
 import argparse
-import os
 
 import mindspore.nn as nn
 from mindspore import context
@@ -33,9 +32,8 @@ if __name__ == '__main__':
     parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
     args_opt = parser.parse_args()
 
-    if args_opt.platform == 'Ascend':
-        device_id = int(os.getenv('DEVICE_ID'))
-        context.set_context(device_id=device_id)
+    if args_opt.platform != 'GPU':
+        raise ValueError("Only supported GPU training.")
 
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, device_id=0)
     net = ShuffleNetV2(n_class=cfg.num_classes)
@@ -34,7 +34,6 @@ config_gpu = edict({
 
     ### Loss Config
     'label_smooth_factor': 0.1,
-    'aux_factor': 0.4,
 
     ### Learning Rate Config
     'lr_init': 0.5,
@@ -42,8 +41,6 @@ config_gpu = edict({
     ### Optimization Config
     'weight_decay': 0.00004,
     'momentum': 0.9,
-    'opt_eps': 1.0,
-    'rmsprop_decay': 0.9,
     "loss_scale": 1,
 
 })
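After the two config hunks, config_gpu no longer carries the auxiliary-loss factor or the RMSProp-specific optimizer keys. A sketch of the affected region after the change, limited to the keys visible in the hunks; edict is presumably easydict.EasyDict, as is usual for these config files, and the dict's other entries are elided here:

from easydict import edict

config_gpu = edict({
    ### Loss Config
    'label_smooth_factor': 0.1,

    ### Learning Rate Config
    'lr_init': 0.5,

    ### Optimization Config
    'weight_decay': 0.00004,
    'momentum': 0.9,
    "loss_scale": 1,
})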
@@ -47,15 +47,15 @@ if __name__ == '__main__':
     parser.add_argument('--model_size', type=str, default='1.0x', help='ShuffleNetV2 model size parameter')
     args_opt = parser.parse_args()
 
+    if args_opt.platform != "GPU":
+        raise ValueError("Only supported GPU training.")
+
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
     if os.getenv('DEVICE_ID', "not_set").isdigit():
         context.set_context(device_id=int(os.getenv('DEVICE_ID')))
 
     # init distributed
     if args_opt.is_distributed:
-        if args_opt.platform == "Ascend":
-            init()
-        else:
-            init("nccl")
+        init("nccl")
         cfg.rank = get_rank()
         cfg.group_size = get_group_size()