forked from mindspore-Ecosystem/mindspore
add use_random control options in cyclegan
This commit is contained in:
parent
f8e5bffe4d
commit
de9636cb16
|
@ -228,7 +228,7 @@ python export.py --platform [PLATFORM] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKP
|
|||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py.
|
||||
If you set --use_random=False, there are no random when training.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
|
|
|
@ -21,29 +21,38 @@ import mindspore.dataset.vision.c_transforms as C
|
|||
from .distributed_sampler import DistributedSampler
|
||||
from .datasets import UnalignedDataset, ImageFolderDataset
|
||||
|
||||
def create_dataset(args, shuffle=True, max_dataset_size=float("inf")):
|
||||
def create_dataset(args):
|
||||
"""Create dataset"""
|
||||
dataroot = args.dataroot
|
||||
phase = args.phase
|
||||
batch_size = args.batch_size
|
||||
device_num = args.device_num
|
||||
rank = args.rank
|
||||
shuffle = args.use_random
|
||||
max_dataset_size = args.max_dataset_size
|
||||
cores = multiprocessing.cpu_count()
|
||||
num_parallel_workers = min(8, int(cores / device_num))
|
||||
image_size = args.image_size
|
||||
mean = [0.5 * 255] * 3
|
||||
std = [0.5 * 255] * 3
|
||||
if phase == "train":
|
||||
dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size)
|
||||
dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size, use_random=args.use_random)
|
||||
distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
|
||||
ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
|
||||
sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
|
||||
trans = [
|
||||
C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
if args.use_random:
|
||||
trans = [
|
||||
C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
else:
|
||||
trans = [
|
||||
C.Resize((image_size, image_size)),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
|
|
|
@ -52,7 +52,7 @@ class UnalignedDataset:
|
|||
Two domain image path list.
|
||||
"""
|
||||
|
||||
def __init__(self, dataroot, phase, max_dataset_size=float("inf")):
|
||||
def __init__(self, dataroot, phase, max_dataset_size=float("inf"), use_random=True):
|
||||
self.dir_A = os.path.join(dataroot, phase + 'A')
|
||||
self.dir_B = os.path.join(dataroot, phase + 'B')
|
||||
|
||||
|
@ -60,12 +60,14 @@ class UnalignedDataset:
|
|||
self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size)) # load images from '/path/to/data/trainB'
|
||||
self.A_size = len(self.A_paths) # get the size of dataset A
|
||||
self.B_size = len(self.B_paths) # get the size of dataset B
|
||||
self.use_random = use_random
|
||||
|
||||
def __getitem__(self, index):
|
||||
if index % max(self.A_size, self.B_size) == 0:
|
||||
index_B = index % self.B_size
|
||||
if index % max(self.A_size, self.B_size) == 0 and self.use_random:
|
||||
random.shuffle(self.A_paths)
|
||||
index_B = random.randint(0, self.B_size - 1)
|
||||
A_path = self.A_paths[index % self.A_size]
|
||||
index_B = random.randint(0, self.B_size - 1)
|
||||
B_path = self.B_paths[index_B]
|
||||
A_img = np.array(Image.open(A_path).convert('RGB'))
|
||||
B_img = np.array(Image.open(B_path).convert('RGB'))
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"""Cycle GAN network."""
|
||||
|
||||
import mindspore.nn as nn
|
||||
|
||||
from mindspore.common import initializer as init
|
||||
|
||||
def init_weights(net, init_type='normal', init_gain=0.02):
|
||||
"""
|
||||
|
@ -27,12 +27,14 @@ def init_weights(net, init_type='normal', init_gain=0.02):
|
|||
init_gain (float): Gain factor for normal and xavier.
|
||||
|
||||
"""
|
||||
for cell in net.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
|
||||
if init_type == 'normal':
|
||||
cell.weight.set_data(init.initializer(init.Normal(init_gain)))
|
||||
cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
|
||||
elif init_type == 'xavier':
|
||||
cell.weight.set_data(init.initializer(init.XavierUniform(init_gain)))
|
||||
cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
|
||||
elif init_type == 'constant':
|
||||
cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
|
||||
else:
|
||||
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
||||
elif isinstance(cell, nn.BatchNorm2d):
|
||||
|
|
|
@ -105,6 +105,9 @@ def get_args(phase):
|
|||
parser.add_argument('--save_imgs', type=ast.literal_eval, default=True, \
|
||||
help='whether save imgs when epoch end, if True result images will generate in '
|
||||
'`outputs_dir/imgs`, default is True.')
|
||||
parser.add_argument('--use_random', type=ast.literal_eval, default=True, \
|
||||
help='whether use random when training, default is True.')
|
||||
parser.add_argument('--max_dataset_size', type=int, default=None, help='max images pre epoch, default is None.')
|
||||
if phase == "export":
|
||||
parser.add_argument("--file_name", type=str, default="cyclegan", help="output file name prefix.")
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', \
|
||||
|
@ -140,6 +143,14 @@ def get_args(phase):
|
|||
if args.dataroot is None and (phase in ["train", "predict"]):
|
||||
raise ValueError('Must set dataroot!')
|
||||
|
||||
if not args.use_random:
|
||||
args.need_dropout = False
|
||||
args.init_type = "constant"
|
||||
|
||||
if args.max_dataset_size is None:
|
||||
args.max_dataset_size = float("inf")
|
||||
|
||||
args.n_epochs = min(args.max_epoch, args.n_epochs)
|
||||
args.n_epochs_decay = args.max_epoch - args.n_epochs
|
||||
args.phase = phase
|
||||
return args
|
||||
|
|
Loading…
Reference in New Issue