diff --git a/model_zoo/official/cv/resnet/eval.py b/model_zoo/official/cv/resnet/eval.py index 1536a36f898..756d76257cf 100755 --- a/model_zoo/official/cv/resnet/eval.py +++ b/model_zoo/official/cv/resnet/eval.py @@ -49,7 +49,7 @@ if args_opt.net in ("resnet18", "resnet50"): elif args_opt.net == "resnet34": from src.resnet import resnet34 as resnet from src.config import config_resnet34 as config - from src.dataset import create_dataset_resnet34 as create_dataset + from src.dataset import create_dataset2 as create_dataset elif args_opt.net == "resnet101": from src.resnet import resnet101 as resnet from src.config import config3 as config diff --git a/model_zoo/official/cv/resnet/src/dataset.py b/model_zoo/official/cv/resnet/src/dataset.py index 4aaeb1da594..98b4fa7d376 100755 --- a/model_zoo/official/cv/resnet/src/dataset.py +++ b/model_zoo/official/cv/resnet/src/dataset.py @@ -400,65 +400,6 @@ def create_dataset4(dataset_path, do_train, repeat_num=1, batch_size=32, target= return data_set -def create_dataset_resnet34(dataset_path, do_train, repeat_num=1, batch_size=32): - """ - create a train or eval imagenet2012 dataset for resnet34 - - Args: - dataset_path(string): the path of dataset. - do_train(bool): whether dataset is used for train or eval. - repeat_num(int): the repeat times of dataset. Default: 1 - batch_size(int): the batch size of dataset. Default: 32 - - Returns: - data_set - """ - device_id = int(os.getenv('DEVICE_ID')) - device_num = int(os.getenv('RANK_SIZE')) - - if device_num == 1: - data_set = ds.ImageFolderDataset(dataset_path) - else: - if do_train: - data_set = ds.ImageFolderDataset(dataset_path, shuffle=True, - num_shards=device_num, shard_id=device_id) - else: - data_set = ds.ImageFolderDataset(dataset_path) - - image_size = 224 - mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] - std = [0.229 * 255, 0.224 * 255, 0.225 * 255] - - # define map operations - if do_train: - trans = [ - C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), - C.RandomHorizontalFlip(prob=0.5), - C.Normalize(mean=mean, std=std), - C.HWC2CHW() - ] - else: - trans = [ - C.Decode(), - C.Resize(256), - C.CenterCrop(image_size), - C.Normalize(mean=mean, std=std), - C.HWC2CHW() - ] - - type_cast_op = C2.TypeCast(mstype.int32) - - data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8) - data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8) - - # apply batch operations - data_set = data_set.batch(batch_size, drop_remainder=True) - - # apply dataset repeat operation - data_set = data_set.repeat(repeat_num) - - return data_set - def _get_rank_info(): """ get rank size and rank id diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index dc3250a9532..642a7426392 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -86,7 +86,7 @@ if args_opt.net in ("resnet18", "resnet50"): elif args_opt.net == "resnet34": from src.resnet import resnet34 as resnet from src.config import config_resnet34 as config - from src.dataset import create_dataset_resnet34 as create_dataset + from src.dataset import create_dataset2 as create_dataset elif args_opt.net == "resnet101": from src.resnet import resnet101 as resnet from src.config import config3 as config