forked from mindspore-Ecosystem/mindspore
repair vgg precision problem
This commit is contained in:
parent
81833943ba
commit
a4f2e728ef
|
@ -158,7 +158,7 @@ def test(cloud_args=None):
|
|||
args.models = [args.pre_trained,]
|
||||
|
||||
for model in args.models:
|
||||
dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size)
|
||||
dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size, mode='eval')
|
||||
eval_dataloader = dataset.create_tuple_iterator()
|
||||
network = vgg16(args.num_classes, args, phase="test")
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ imagenet_cfg = edict({
|
|||
"image_size": '224,224',
|
||||
"pad_mode": 'pad',
|
||||
"padding": 1,
|
||||
"has_bias": True,
|
||||
"has_bias": False,
|
||||
"batch_norm": False,
|
||||
"keep_checkpoint_max": 10,
|
||||
"initialize_mode": "KaimingNormal",
|
||||
|
|
|
@ -31,10 +31,11 @@ def _make_layer(base, args, batch_norm):
|
|||
if v == 'M':
|
||||
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
|
||||
else:
|
||||
weight_shape = (v, in_channels, 3, 3)
|
||||
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
|
||||
if args.initialize_mode == "KaimingNormal":
|
||||
weight = 'normal'
|
||||
weight = 'ones'
|
||||
if args.initialize_mode == "XavierUniform":
|
||||
weight_shape = (v, in_channels, 3, 3)
|
||||
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
|
||||
|
||||
conv2d = nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=v,
|
||||
kernel_size=3,
|
||||
|
|
|
@ -127,7 +127,7 @@ def parse_args(cloud_args=None):
|
|||
# logging and checkpoint related
|
||||
parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
|
||||
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
|
||||
parser.add_argument('--ckpt_interval', type=int, default=2, help='ckpt_interval')
|
||||
parser.add_argument('--ckpt_interval', type=int, default=5, help='ckpt_interval')
|
||||
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
|
||||
|
||||
# distributed related
|
||||
|
@ -200,12 +200,12 @@ if __name__ == '__main__':
|
|||
device_num = args.group_size
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
parameter_broadcast=True, mirror_mean=True)
|
||||
else:
|
||||
context.set_context(device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
|
||||
# select for master rank save ckpt or all rank save, compatiable for model parallel
|
||||
# select for master rank save ckpt or all rank save, compatible for model parallel
|
||||
args.rank_save_ckpt_flag = 0
|
||||
if args.is_save_on_master:
|
||||
if args.rank == 0:
|
||||
|
|
Loading…
Reference in New Issue