repair vgg precision problem

ms_yan 2020-08-17 21:50:08 +08:00
parent 81833943ba
commit a4f2e728ef
4 changed files with 10 additions and 9 deletions

View File

@@ -158,7 +158,7 @@ def test(cloud_args=None):
     args.models = [args.pre_trained,]
     for model in args.models:
-        dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size)
+        dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size, mode='eval')
         eval_dataloader = dataset.create_tuple_iterator()
         network = vgg16(args.num_classes, args, phase="test")
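
The added mode='eval' flag matters because a classification pipeline normally applies random augmentation only while training; running evaluation through train-time transforms distorts the measured accuracy. The sketch below is illustrative only and is not this repo's classification_dataset: the helper name, transform choices and normalization constants are assumptions, and the class names follow current MindSpore (older releases use slightly different aliases such as ImageFolderDatasetV2).

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision

def build_dataset(data_dir, image_size, batch_size, mode='train'):
    """Hypothetical helper: random augmentation for training, deterministic preprocessing for eval."""
    dataset = ds.ImageFolderDataset(data_dir, shuffle=(mode == 'train'))
    if mode == 'train':
        trans = [vision.RandomCropDecodeResize(image_size),
                 vision.RandomHorizontalFlip()]
    else:  # eval: decode, resize and center-crop only, so every run sees identical inputs
        trans = [vision.Decode(),
                 vision.Resize(256),
                 vision.CenterCrop(image_size)]
    trans += [vision.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
              vision.HWC2CHW()]
    dataset = dataset.map(operations=trans, input_columns="image")
    return dataset.batch(batch_size, drop_remainder=(mode == 'train'))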

View File

@@ -64,7 +64,7 @@ imagenet_cfg = edict({
     "image_size": '224,224',
     "pad_mode": 'pad',
     "padding": 1,
-    "has_bias": True,
+    "has_bias": False,
     "batch_norm": False,
     "keep_checkpoint_max": 10,
     "initialize_mode": "KaimingNormal",

View File

@@ -31,10 +31,11 @@ def _make_layer(base, args, batch_norm):
         if v == 'M':
             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
         else:
-            weight_shape = (v, in_channels, 3, 3)
-            weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
-            if args.initialize_mode == "KaimingNormal":
-                weight = 'normal'
+            weight = 'ones'
+            if args.initialize_mode == "XavierUniform":
+                weight_shape = (v, in_channels, 3, 3)
+                weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
+
             conv2d = nn.Conv2d(in_channels=in_channels,
                                out_channels=v,
                                kernel_size=3,
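
Read together, the new branch keeps the string 'ones' as a placeholder weight_init and only materializes an explicit Xavier tensor when the config asks for XavierUniform; with imagenet_cfg's KaimingNormal mode the convolutions are presumably re-initialized elsewhere in vgg.py after construction. A sketch of just the selection logic, under those assumptions (the helper name is invented; v and in_channels come from the _make_layer loop above, and .to_tensor() mirrors the diff, while newer MindSpore returns a Tensor directly):

from mindspore.common.initializer import initializer
import mindspore.common.dtype as mstype

def pick_conv_weight(v, in_channels, initialize_mode):
    """Hypothetical extraction of the per-layer weight_init choice."""
    weight = 'ones'  # init name as a string; a placeholder until any later custom initialization
    if initialize_mode == "XavierUniform":
        weight_shape = (v, in_channels, 3, 3)  # (out_channels, in_channels, kH, kW)
        weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
    return weight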

View File

@@ -127,7 +127,7 @@ def parse_args(cloud_args=None):
     # logging and checkpoint related
     parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
     parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
-    parser.add_argument('--ckpt_interval', type=int, default=2, help='ckpt_interval')
+    parser.add_argument('--ckpt_interval', type=int, default=5, help='ckpt_interval')
     parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')

     # distributed related
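
Raising the default from 2 to 5 only changes how often checkpoints are written. As a hedged sketch of the usual wiring (the exact conversion in this train.py may differ; steps_per_epoch, keep_max and the prefix are placeholder names), a per-epoch interval is typically turned into save_checkpoint_steps for MindSpore's checkpoint callback:

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

def make_ckpt_callback(ckpt_interval, steps_per_epoch, keep_max, ckpt_path, rank):
    """Save a checkpoint every ckpt_interval epochs, keeping at most keep_max files."""
    config = CheckpointConfig(save_checkpoint_steps=ckpt_interval * steps_per_epoch,
                              keep_checkpoint_max=keep_max)
    return ModelCheckpoint(config=config, directory=ckpt_path, prefix='vgg_rank{}'.format(rank))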
@@ -200,12 +200,12 @@ if __name__ == '__main__':
         device_num = args.group_size
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          mirror_mean=True)
+                                          parameter_broadcast=True, mirror_mean=True)
     else:
         context.set_context(device_id=args.device_id)
     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

-    # select for master rank save ckpt or all rank save, compatiable for model parallel
+    # select for master rank save ckpt or all rank save, compatible for model parallel
     args.rank_save_ckpt_flag = 0
     if args.is_save_on_master:
         if args.rank == 0:
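
For context on the parameter_broadcast addition: in data parallelism every rank constructs its own copy of the network, and without a broadcast the randomly initialized weights can differ between ranks, so averaged gradients get applied to inconsistent parameters; parameter_broadcast=True makes rank 0 push its initial weights to all ranks, and mirror_mean=True averages rather than sums gradients across devices. A minimal sketch of the surrounding setup in the MindSpore API of that era (mirror_mean was later renamed gradients_mean, the ParallelMode import path varies across versions, and the "Ascend" target is just an example):

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()  # bring up the HCCL/NCCL collective communication backend
rank, group_size = get_rank(), get_group_size()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=group_size,
                                  parallel_mode=ParallelMode.DATA_PARALLEL,
                                  parameter_broadcast=True,
                                  mirror_mean=True)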