add use_random control options in cyclegan

This commit is contained in:
zhaoting 2020-12-21 19:53:50 +08:00
parent f8e5bffe4d
commit de9636cb16
5 changed files with 41 additions and 17 deletions

View File

@ -228,7 +228,7 @@ python export.py --platform [PLATFORM] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKP
# [Description of Random Situation](#contents)
In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py.
If you set --use_random=False, there are no random when training.
# [ModelZoo Homepage](#contents)

View File

@ -21,29 +21,38 @@ import mindspore.dataset.vision.c_transforms as C
from .distributed_sampler import DistributedSampler
from .datasets import UnalignedDataset, ImageFolderDataset
def create_dataset(args, shuffle=True, max_dataset_size=float("inf")):
def create_dataset(args):
"""Create dataset"""
dataroot = args.dataroot
phase = args.phase
batch_size = args.batch_size
device_num = args.device_num
rank = args.rank
shuffle = args.use_random
max_dataset_size = args.max_dataset_size
cores = multiprocessing.cpu_count()
num_parallel_workers = min(8, int(cores / device_num))
image_size = args.image_size
mean = [0.5 * 255] * 3
std = [0.5 * 255] * 3
if phase == "train":
dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size)
dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size, use_random=args.use_random)
distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
trans = [
C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
if args.use_random:
trans = [
C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
else:
trans = [
C.Resize((image_size, image_size)),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)

View File

@ -52,7 +52,7 @@ class UnalignedDataset:
Two domain image path list.
"""
def __init__(self, dataroot, phase, max_dataset_size=float("inf")):
def __init__(self, dataroot, phase, max_dataset_size=float("inf"), use_random=True):
self.dir_A = os.path.join(dataroot, phase + 'A')
self.dir_B = os.path.join(dataroot, phase + 'B')
@ -60,12 +60,14 @@ class UnalignedDataset:
self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size)) # load images from '/path/to/data/trainB'
self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B
self.use_random = use_random
def __getitem__(self, index):
if index % max(self.A_size, self.B_size) == 0:
index_B = index % self.B_size
if index % max(self.A_size, self.B_size) == 0 and self.use_random:
random.shuffle(self.A_paths)
index_B = random.randint(0, self.B_size - 1)
A_path = self.A_paths[index % self.A_size]
index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
A_img = np.array(Image.open(A_path).convert('RGB'))
B_img = np.array(Image.open(B_path).convert('RGB'))

View File

@ -15,7 +15,7 @@
"""Cycle GAN network."""
import mindspore.nn as nn
from mindspore.common import initializer as init
def init_weights(net, init_type='normal', init_gain=0.02):
"""
@ -27,12 +27,14 @@ def init_weights(net, init_type='normal', init_gain=0.02):
init_gain (float): Gain factor for normal and xavier.
"""
for cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d):
for _, cell in net.cells_and_names():
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
if init_type == 'normal':
cell.weight.set_data(init.initializer(init.Normal(init_gain)))
cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
elif init_type == 'xavier':
cell.weight.set_data(init.initializer(init.XavierUniform(init_gain)))
cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
elif init_type == 'constant':
cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
elif isinstance(cell, nn.BatchNorm2d):

View File

@ -105,6 +105,9 @@ def get_args(phase):
parser.add_argument('--save_imgs', type=ast.literal_eval, default=True, \
help='whether save imgs when epoch end, if True result images will generate in '
'`outputs_dir/imgs`, default is True.')
parser.add_argument('--use_random', type=ast.literal_eval, default=True, \
help='whether use random when training, default is True.')
parser.add_argument('--max_dataset_size', type=int, default=None, help='max images pre epoch, default is None.')
if phase == "export":
parser.add_argument("--file_name", type=str, default="cyclegan", help="output file name prefix.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', \
@ -140,6 +143,14 @@ def get_args(phase):
if args.dataroot is None and (phase in ["train", "predict"]):
raise ValueError('Must set dataroot!')
if not args.use_random:
args.need_dropout = False
args.init_type = "constant"
if args.max_dataset_size is None:
args.max_dataset_size = float("inf")
args.n_epochs = min(args.max_epoch, args.n_epochs)
args.n_epochs_decay = args.max_epoch - args.n_epochs
args.phase = phase
return args