forked from mindspore-Ecosystem/mindspore

commit de9636cb16 (parent f8e5bffe4d)

add use_random control options in cyclegan
@@ -228,7 +228,7 @@ python export.py --platform [PLATFORM] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKPT]

 # [Description of Random Situation](#contents)

-In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py.
+In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py. If you set --use_random=False, there is no randomness during training.

 # [ModelZoo Homepage](#contents)
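Usage note (following the README's own invocation style; the dataroot path below is a placeholder): with the new flag, fully deterministic training can be launched as

```bash
python train.py --dataroot ./data/horse2zebra --use_random=False
```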
@@ -21,29 +21,38 @@ import mindspore.dataset.vision.c_transforms as C

 from .distributed_sampler import DistributedSampler
 from .datasets import UnalignedDataset, ImageFolderDataset


-def create_dataset(args, shuffle=True, max_dataset_size=float("inf")):
+def create_dataset(args):
     """Create dataset"""
     dataroot = args.dataroot
     phase = args.phase
     batch_size = args.batch_size
     device_num = args.device_num
     rank = args.rank
+    shuffle = args.use_random
+    max_dataset_size = args.max_dataset_size
     cores = multiprocessing.cpu_count()
     num_parallel_workers = min(8, int(cores / device_num))
     image_size = args.image_size
     mean = [0.5 * 255] * 3
     std = [0.5 * 255] * 3
     if phase == "train":
-        dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size)
+        dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size, use_random=args.use_random)
         distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
         ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
                                  sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
-        trans = [
-            C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
-            C.RandomHorizontalFlip(prob=0.5),
-            C.Normalize(mean=mean, std=std),
-            C.HWC2CHW()
-        ]
+        if args.use_random:
+            trans = [
+                C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
+                C.RandomHorizontalFlip(prob=0.5),
+                C.Normalize(mean=mean, std=std),
+                C.HWC2CHW()
+            ]
+        else:
+            trans = [
+                C.Resize((image_size, image_size)),
+                C.Normalize(mean=mean, std=std),
+                C.HWC2CHW()
+            ]
         ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
         ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
         ds = ds.batch(batch_size, drop_remainder=True)
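For orientation, a minimal sketch of how the reworked create_dataset might be driven after this change; the SimpleNamespace stands in for the argparse result, and the field values are assumptions:

```python
from types import SimpleNamespace

# Hypothetical args object; real code gets this from get_args("train").
args = SimpleNamespace(dataroot="./data/horse2zebra", phase="train",
                       batch_size=1, device_num=1, rank=0, image_size=256,
                       use_random=False, max_dataset_size=float("inf"))

ds = create_dataset(args)  # shuffling and augmentation now follow args.use_random
```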
@@ -52,7 +52,7 @@ class UnalignedDataset:

         Two domain image path list.
     """

-    def __init__(self, dataroot, phase, max_dataset_size=float("inf")):
+    def __init__(self, dataroot, phase, max_dataset_size=float("inf"), use_random=True):
         self.dir_A = os.path.join(dataroot, phase + 'A')
         self.dir_B = os.path.join(dataroot, phase + 'B')

@@ -60,12 +60,14 @@ class UnalignedDataset:

         self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size))  # load images from '/path/to/data/trainB'
         self.A_size = len(self.A_paths)  # get the size of dataset A
         self.B_size = len(self.B_paths)  # get the size of dataset B
+        self.use_random = use_random

     def __getitem__(self, index):
-        if index % max(self.A_size, self.B_size) == 0:
+        index_B = index % self.B_size
+        if index % max(self.A_size, self.B_size) == 0 and self.use_random:
             random.shuffle(self.A_paths)
+            index_B = random.randint(0, self.B_size - 1)
         A_path = self.A_paths[index % self.A_size]
-        index_B = random.randint(0, self.B_size - 1)
         B_path = self.B_paths[index_B]
         A_img = np.array(Image.open(A_path).convert('RGB'))
         B_img = np.array(Image.open(B_path).convert('RGB'))
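To make the new pairing rule concrete, here is a standalone re-statement of the index logic (not the repository code): with use_random=False, sample i always pairs A[i % A_size] with B[i % B_size], so every epoch sees identical A-B pairs.

```python
import random

def pair_indices(index, a_size, b_size, use_random):
    """Mirror of the __getitem__ index logic above."""
    index_b = index % b_size
    if index % max(a_size, b_size) == 0 and use_random:
        index_b = random.randint(0, b_size - 1)  # random B partner once per pass
    return index % a_size, index_b

assert pair_indices(5, a_size=4, b_size=3, use_random=False) == (1, 2)  # deterministic
```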
@@ -15,7 +15,7 @@

 """Cycle GAN network."""

 import mindspore.nn as nn
-
+from mindspore.common import initializer as init

 def init_weights(net, init_type='normal', init_gain=0.02):
     """
@@ -27,12 +27,14 @@ def init_weights(net, init_type='normal', init_gain=0.02):

         init_gain (float): Gain factor for normal and xavier.

     """
-    for cell in net.cells_and_names():
-        if isinstance(cell, nn.Conv2d):
+    for _, cell in net.cells_and_names():
+        if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
             if init_type == 'normal':
-                cell.weight.set_data(init.initializer(init.Normal(init_gain)))
+                cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
             elif init_type == 'xavier':
-                cell.weight.set_data(init.initializer(init.XavierUniform(init_gain)))
+                cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
+            elif init_type == 'constant':
+                cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
             else:
                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
         elif isinstance(cell, nn.BatchNorm2d):
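The new 'constant' branch is what lets --use_random=False produce bit-identical weights across runs: every matching conv weight is set to 0.001. A quick sketch (the toy network is an assumption; requires a MindSpore environment):

```python
import mindspore.nn as nn

# Hypothetical toy network; init_weights walks its cells via cells_and_names().
net = nn.SequentialCell([nn.Conv2d(3, 8, 3), nn.Conv2dTranspose(8, 3, 3)])
init_weights(net, init_type='constant')  # all conv weights become 0.001
```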
@@ -105,6 +105,9 @@ def get_args(phase):

     parser.add_argument('--save_imgs', type=ast.literal_eval, default=True, \
                         help='whether save imgs when epoch end, if True result images will generate in '
                              '`outputs_dir/imgs`, default is True.')
+    parser.add_argument('--use_random', type=ast.literal_eval, default=True, \
+                        help='whether to use randomness when training, default is True.')
+    parser.add_argument('--max_dataset_size', type=int, default=None, help='max images per epoch, default is None.')
     if phase == "export":
         parser.add_argument("--file_name", type=str, default="cyclegan", help="output file name prefix.")
         parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', \
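A note on type=ast.literal_eval for the boolean flag: argparse's type=bool would treat any non-empty string, including "False", as truthy, whereas literal_eval parses the actual Python literal. A small illustration:

```python
import ast

assert ast.literal_eval("False") is False  # --use_random=False works as intended
assert bool("False") is True               # why type=bool would not
```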
@@ -140,6 +143,14 @@ def get_args(phase):

     if args.dataroot is None and (phase in ["train", "predict"]):
         raise ValueError('Must set dataroot!')

+    if not args.use_random:
+        args.need_dropout = False
+        args.init_type = "constant"
+
+    if args.max_dataset_size is None:
+        args.max_dataset_size = float("inf")
+
+    args.n_epochs = min(args.max_epoch, args.n_epochs)
     args.n_epochs_decay = args.max_epoch - args.n_epochs
     args.phase = phase
     return args
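Putting the post-processing together, a sketch of the derived settings after parsing (field names are taken from this diff; everything else about get_args is assumed):

```python
# e.g. launched as: python train.py --dataroot ./data --use_random=False
args = get_args("train")
# With --use_random=False this diff forces the deterministic companions:
#   args.need_dropout     -> False
#   args.init_type        -> "constant"
# and an unset --max_dataset_size is normalized:
#   args.max_dataset_size -> float("inf")
```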