From 0c1761322f10ec169901d0ad438f37089688814e Mon Sep 17 00:00:00 2001 From: dessyang Date: Tue, 21 Jul 2020 16:44:40 -0400 Subject: [PATCH] Add the necessary rescale op and the option of not shuffling validation data --- model_zoo/googlenet/src/dataset.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/model_zoo/googlenet/src/dataset.py b/model_zoo/googlenet/src/dataset.py index a3f74a06178..cc33d2e5944 100644 --- a/model_zoo/googlenet/src/dataset.py +++ b/model_zoo/googlenet/src/dataset.py @@ -32,7 +32,10 @@ def create_dataset(data_home, repeat_num=1, training=True): data_dir = os.path.join(data_home, "cifar-10-verify-bin") rank_size, rank_id = _get_rank_info() - data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) + if training: + data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=True) + else: + data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False) resize_height = cfg.image_height resize_width = cfg.image_width @@ -41,6 +44,7 @@ def create_dataset(data_home, repeat_num=1, training=True): random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT random_horizontal_op = vision.RandomHorizontalFlip() resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR + rescale_op = vision.Rescale(1.0/255.0, 0.0) normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) changeswap_op = vision.HWC2CHW() type_cast_op = C.TypeCast(mstype.int32) @@ -48,21 +52,18 @@ def create_dataset(data_home, repeat_num=1, training=True): c_trans = [] if training: c_trans = [random_crop_op, random_horizontal_op] - c_trans += [resize_op, normalize_op, changeswap_op] + c_trans += [resize_op, rescale_op, normalize_op, changeswap_op] # apply map operations on images data_set = data_set.map(input_columns="label", operations=type_cast_op) data_set = data_set.map(input_columns="image", operations=c_trans) - # apply repeat operations - data_set = data_set.repeat(repeat_num) - - # apply shuffle operations - data_set = data_set.shuffle(buffer_size=10) - # apply batch operations data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True) + # apply repeat operations + data_set = data_set.repeat(repeat_num) + return data_set