From: @xianzhu-liu
Reviewed-by: @c_34,@oacjiewen
Signed-off-by: @c_34
mindspore-ci-bot 2021-05-11 21:02:50 +08:00 committed by Gitee
commit d93c23626a
33 changed files with 780 additions and 985 deletions

@ -0,0 +1,173 @@
# Contents
- [CycleGAN Description](#cyclegan-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training](#training)
    - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [CycleGAN Description](#contents)
Image-to-image translation is a class of computer vision and graphics problems whose goal is to learn the mapping from an input image to an output image, traditionally from a training set of paired images. For many tasks, however, paired training data cannot be obtained. CycleGAN does not require paired data: given only unpaired images from two domains, it learns the mapping between them. It trains two generators, one per translation direction, and each domain has its own discriminator.
[Paper](https://arxiv.org/abs/1703.10593): Zhu J Y, Park T, Isola P, et al. Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks[J]. 2017.
![CycleGAN Imgs](imgs/objects-transfiguration.jpg)
# [Model Architecture](#contents)
CycleGAN contains two generator networks and two discriminator networks.
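The data flow between the two generators can be sketched in plain Python. This is only an illustration, not the repository's implementation: `toy_g_a` and `toy_g_b` are hypothetical stand-ins for the real networks, images are flat lists of floats, and the loss helpers merely mirror the "Mean Square Loss & L1 Loss" pairing listed in the performance table. Following the code in this commit, G_A maps domain A to domain B and G_B maps B back to A.

```python
def cycle_forward(g_a, g_b, img_a, img_b):
    """Translate in both directions, then translate back."""
    fake_b = g_a(img_a)  # A -> B
    fake_a = g_b(img_b)  # B -> A
    rec_a = g_b(fake_b)  # A -> B -> A, should resemble img_a
    rec_b = g_a(fake_a)  # B -> A -> B, should resemble img_b
    return fake_a, fake_b, rec_a, rec_b


def l1_loss(x, y):
    """Mean absolute error, used here as the cycle-consistency term."""
    return sum(abs(a - b) for a, b in zip(x, y)) / len(x)


def mse_gan_loss(scores, is_real):
    """Mean square error between discriminator scores and the 1/0 target."""
    target = 1.0 if is_real else 0.0
    return sum((s - target) ** 2 for s in scores) / len(scores)


# Hypothetical toy generators that are exact inverses of each other.
def toy_g_a(img):
    return [p + 0.5 for p in img]


def toy_g_b(img):
    return [p - 0.5 for p in img]


img_a = [0.0, 0.25, -1.0]
img_b = [1.0, 0.5, 0.0]
fake_a, fake_b, rec_a, rec_b = cycle_forward(toy_g_a, toy_g_b, img_a, img_b)

# Perfectly inverse generators reconstruct both inputs, so the cycle loss is zero.
cycle_loss = l1_loss(rec_a, img_a) + l1_loss(rec_b, img_b)
print(cycle_loss)
```

In real training the generators are not exact inverses, so the cycle term stays positive and is minimized together with the adversarial term.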
# [Dataset](#contents)
Download one of the CycleGAN datasets or create your own dataset. We provide data/download_cyclegan_dataset.sh to download the datasets.
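Whether downloaded or self-made, the dataset root is expected to contain the subfolders trainA, trainB, testA and testB (see the `dataroot` parameter below). A minimal sketch of a layout check, assuming nothing beyond the standard library; the helper name `missing_subdirs` is hypothetical and not part of this repository:

```python
import os
import tempfile

REQUIRED_SUBDIRS = ("trainA", "trainB", "testA", "testB")


def missing_subdirs(dataroot):
    """Return the required subfolders that do not exist under dataroot."""
    return [name for name in REQUIRED_SUBDIRS
            if not os.path.isdir(os.path.join(dataroot, name))]


# Demo on a temporary directory that only contains trainA.
with tempfile.TemporaryDirectory() as demo_root:
    os.makedirs(os.path.join(demo_root, "trainA"))
    demo_missing = missing_subdirs(demo_root)

print(demo_missing)
```

Running such a check before training gives a clearer error than a failed assertion deep inside the data loader.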
# [Environment Requirements](#contents)
- Hardware (Ascend/GPU)
    - Prepare the hardware environment with an Ascend or GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
## [Dependencies](#contents)
- Python==3.7.5
- MindSpore==1.1
# [Script Description](#contents)
## [Script and Sample Code](#contents)
The code is organized as follows:
```markdown
.CycleGAN
├─ README.md                          # descriptions about CycleGAN
├─ data
│  └─ download_cyclegan_dataset.sh    # download dataset
├─ scripts
│  ├─ run_train_ascend.sh             # launch Ascend training (1 pcs)
│  ├─ run_train_gpu.sh                # launch GPU training (1 pcs)
│  ├─ run_eval_ascend.sh              # launch Ascend eval
│  └─ run_eval_gpu.sh                 # launch GPU eval
├─ imgs
│  └─ objects-transfiguration.jpg     # CycleGAN Imgs
├─ src
│  ├─ __init__.py                     # init file
│  ├─ dataset
│  │  ├─ __init__.py                  # init file
│  │  ├─ cyclegan_dataset.py          # create cyclegan dataset
│  │  └─ distributed_sampler.py       # iterator of dataset
│  ├─ models
│  │  ├─ __init__.py                  # init file
│  │  ├─ cycle_gan.py                 # cyclegan model definition
│  │  ├─ losses.py                    # cyclegan loss function definitions
│  │  ├─ networks.py                  # cyclegan sub-network definitions
│  │  ├─ resnet.py                    # ResNet generator network
│  │  └─ depth_resnet.py              # better generator network
│  └─ utils
│     ├─ __init__.py                  # init file
│     ├─ args.py                      # parse args
│     ├─ reporter.py                  # Reporter class
│     └─ tools.py                     # utils for cyclegan
├─ eval.py                            # generate images from A->B and B->A
├─ train.py                           # train script
└─ export.py                          # export mindir script
```
## [Script Parameters](#contents)
The major parameters in train.py and config.py are as follows:
```python
"platform": "Ascend" # run platform, only GPU and Ascend are supported.
"device_id": 0 # device id, default is 0.
"model": "ResNet" # generator model, one of "ResNet", "DepthResNet" or "UNet".
"pool_size": 50 # the size of image buffer that stores previously generated images, default is 50.
"lr_policy": "linear" # learning rate policy, default is linear.
"image_size": 256 # input image_size, default is 256.
"batch_size": 1 # batch_size, default is 1.
"max_epoch": 200 # epoch size for training, default is 200.
"in_planes": 3 # input channels, default is 3.
"ngf": 64 # generator model filter numbers, default is 64.
"gl_num": 9 # generator model residual block numbers, default is 9.
"ndf": 64 # discriminator model filter numbers, default is 64.
"dl_num": 3 # number of ConvNormReLU blocks in the discriminator, default is 3.
"outputs_dir": "outputs" # models are saved here, default is ./outputs.
"dataroot": None # path of images (should have subfolders trainA, trainB, testA, testB, etc).
"load_ckpt": False # whether load pretrained ckpt.
"G_A_ckpt": None # pretrained checkpoint file path of G_A.
"G_B_ckpt": None # pretrained checkpoint file path of G_B.
"D_A_ckpt": None # pretrained checkpoint file path of D_A.
"D_B_ckpt": None # pretrained checkpoint file path of D_B.
```
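The `pool_size` parameter refers to a buffer of previously generated images. The sketch below shows one common way such a buffer is used in CycleGAN implementations: once the pool is full, each new fake image is either passed through unchanged or swapped with a randomly chosen stored image, so the discriminator also sees older fakes. This is an assumption about the general technique, not a copy of this repository's code; the class name `ImagePool` and its 50% swap probability are illustrative.

```python
import random


class ImagePool:
    """Fixed-size buffer of previously generated images (illustrative sketch)."""

    def __init__(self, pool_size, rng=None):
        self.pool_size = pool_size
        self.images = []
        # A dedicated random generator keeps the demo reproducible.
        self.rng = rng or random.Random()

    def query(self, image):
        """Store the new image and return either it or an older stored one."""
        if self.pool_size == 0:
            # A pool of size 0 disables buffering entirely.
            return image
        if len(self.images) < self.pool_size:
            # Pool not full yet: keep the image and return it unchanged.
            self.images.append(image)
            return image
        if self.rng.random() > 0.5:
            # Swap: hand back an older image and store the new one in its place.
            idx = self.rng.randrange(self.pool_size)
            old = self.images[idx]
            self.images[idx] = image
            return old
        # Otherwise return the new image without storing it.
        return image


pool = ImagePool(pool_size=2, rng=random.Random(0))
first = pool.query("fake_1")
second = pool.query("fake_2")
third = pool.query("fake_3")
print(first, second, third)
```

Any object can stand in for an image here; in training the stored items would be generator outputs.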
## [Training](#contents)
- running on Ascend with default parameters
```bash
sh ./scripts/run_train_ascend.sh
```
- running on GPU with default parameters
```bash
sh ./scripts/run_train_gpu.sh
```
## [Evaluation](#contents)
```bash
python eval.py --platform [PLATFORM] --dataroot [DATA_PATH] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKPT]
```
**Note: The generated images are saved in "[outputs_dir]/predict" ("./outputs/predict" by default).**
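The evaluation script in this commit writes each translated image next to a copy of its input, appending `_fake_B` or `_fake_A` to the file stem. A minimal sketch of that naming scheme follows; the helper `fake_image_path` is hypothetical, and it uses `os.path.splitext` instead of the fixed four-character slice in eval.py, which is a deliberate substitution so that extensions such as `.jpeg` are also handled.

```python
import os


def fake_image_path(outputs_dir, image_name, target_domain):
    """Build the output path for a translated image.

    target_domain is "A" or "B", the domain the image was translated into.
    """
    stem, _ = os.path.splitext(image_name)
    folder = "fake_" + target_domain
    return os.path.join(outputs_dir, "predict", folder, stem + "_" + folder + ".jpg")


# Hypothetical file name, used only for illustration.
example = fake_image_path("outputs", "horse_001.jpg", "B")
print(example)
```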
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | single Ascend/GPU |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | CycleGAN |
| Resource | Ascend 910/NV SXM2 V100-32G |
| MindSpore Version | 1.1 |
| Dataset | horse2zebra |
| Training Parameters | epoch=200, steps=1334, batch_size=1, lr=0.002 |
| Optimizer | Adam |
| Loss Function | Mean Square Loss & L1 Loss |
| outputs | probability |
| Speed | 1pc(Ascend): 123 ms/step; 1pc(GPU): 264 ms/step |
| Total time | 1pc(Ascend): 9.6h; 1pc(GPU): 19.1h |
| Checkpoint for Fine tuning | 44M (.ckpt file) |
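
As a rough cross-check, the Speed row can be turned into an estimated wall-clock time using the epoch and step counts from the Training Parameters row. The helper below is a hypothetical back-of-the-envelope calculation, not a measurement; the reported totals were presumably measured directly, so small differences from the estimate are expected.

```python
def estimated_hours(epochs, steps_per_epoch, ms_per_step):
    """Estimate training time in hours from per-step latency."""
    total_ms = epochs * steps_per_epoch * ms_per_step
    return total_ms / 1000.0 / 3600.0


# Values taken from the table: epoch=200, steps=1334, 123 ms/step and 264 ms/step.
ascend_hours = estimated_hours(200, 1334, 123)
gpu_hours = estimated_hours(200, 1334, 264)
print(round(ascend_hours, 1), round(gpu_hours, 1))
```

Both estimates land within about half an hour of the Total time row.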
### Evaluation Performance
| Parameters | single Ascend/GPU |
| ------------------- | --------------------------- |
| Model Version | CycleGAN |
| Resource | Ascend 910/NV SXM2 V100-32G |
| MindSpore Version | 1.1 |
| Dataset | horse2zebra |
| batch_size | 1 |
| outputs | probability |
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,23 @@
#!/bin/bash
FILE=$1
if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "mini" && $FILE != "mini_pix2pix" && $FILE != "mini_colorization" ]]; then
    echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
    exit 1
fi
if [[ $FILE == "cityscapes" ]]; then
    echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py."
    echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py"
    exit 1
fi
echo "Specified [$FILE]"
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
ZIP_FILE=./datasets/$FILE.zip
TARGET_DIR=./datasets/$FILE/
wget -N $URL -O $ZIP_FILE
mkdir $TARGET_DIR
unzip $ZIP_FILE -d ./datasets/
rm $ZIP_FILE

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,25 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN predict."""
"""Cycle GAN test."""
import os
from mindspore import Tensor
from src.models.cycle_gan import get_generator
from src.utils.args import get_args
from src.dataset.cyclegan_dataset import create_dataset
from src.utils.reporter import Reporter
from src.utils.tools import save_image, load_ckpt
from src.models import get_generator
from src.utils import get_args, load_ckpt, save_image, Reporter
from src.dataset import create_dataset
def predict():
"""Predict function."""
args = get_args("predict")
G_A = get_generator(args)
G_B = get_generator(args)
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
# Use real mean and variance rather than moving_mean and moving_variance in BatchNorm2d
G_A.set_train(True)
G_B.set_train(True)
load_ckpt(args, G_A, G_B)
imgs_out = os.path.join(args.outputs_dir, "predict")
if not os.path.exists(imgs_out):
os.makedirs(imgs_out)
@ -45,8 +46,10 @@ def predict():
for data in ds.create_dict_iterator(output_numpy=True):
img_A = Tensor(data["image"])
path_A = str(data["image_name"][0], encoding="utf-8")
path_B = path_A[0:-4] + "_fake_B.jpg"
fake_B = G_A(img_A)
save_image(fake_B, os.path.join(imgs_out, "fake_B", path_A))
save_image(fake_B, os.path.join(imgs_out, "fake_B", path_B))
save_image(img_A, os.path.join(imgs_out, "fake_B", path_A))
reporter.info('save fake_B at %s', os.path.join(imgs_out, "fake_B", path_A))
reporter.end_predict()
args.data_dir = 'testB'
@ -56,10 +59,13 @@ def predict():
for data in ds.create_dict_iterator(output_numpy=True):
img_B = Tensor(data["image"])
path_B = str(data["image_name"][0], encoding="utf-8")
path_A = path_B[0:-4] + "_fake_A.jpg"
fake_A = G_B(img_B)
save_image(fake_A, os.path.join(imgs_out, "fake_A", path_B))
save_image(fake_A, os.path.join(imgs_out, "fake_A", path_A))
save_image(img_B, os.path.join(imgs_out, "fake_A", path_B))
reporter.info('save fake_A at %s', os.path.join(imgs_out, "fake_A", path_B))
reporter.end_predict()
if __name__ == "__main__":
predict()

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,28 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export file."""
import numpy as np
"""export file."""
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import export
from src.models import get_generator
from src.utils import get_args, load_ckpt
from src.models.cycle_gan import get_generator
from src.utils.args import get_args
from src.utils.tools import load_ckpt
args = get_args("export")
model_args = get_args("export")
parser = argparse.ArgumentParser(description="CycleGAN export")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--file_name", type=str, default="CycleGAN", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=model_args.device_id)
if __name__ == '__main__':
G_A = get_generator(args)
G_B = get_generator(args)
G_A = get_generator(model_args)
G_B = get_generator(model_args)
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
# Use real mean and variance rather than moving_mean and moving_variance in BatchNorm2d
G_A.set_train(True)
G_B.set_train(True)
load_ckpt(args, G_A, G_B)
load_ckpt(model_args, G_A, G_B)
input_shp = [1, 3, args.image_size, args.image_size]
input_shp = [args.batch_size, 3, model_args.image_size, model_args.image_size]
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
G_A_file = f"{args.file_name}_BtoA"
export(G_A, input_array, file_name=G_A_file, file_format=args.file_format)

Binary file not shown (image, 681 KiB).

@ -1,4 +1,5 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,7 +12,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init file."""
from .datasets import UnalignedDataset, ImageFolderDataset, make_dataset
from .cyclegan_dataset import create_dataset
python eval.py --platform Ascend --device_id 0 --model DepthResNet --G_A_ckpt ./outputs/ckpt/G_A_200.ckpt --G_B_ckpt ./outputs/ckpt/G_B_200.ckpt

@ -1,4 +1,5 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,8 +12,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init file."""
from .cycle_gan import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD
from .losses import DiscriminatorLoss, GeneratorLoss, GANLoss
from .networks import init_weights
python eval.py --platform GPU --device_id 0 --model ResNet --G_A_ckpt ./outputs/ckpt/G_A_200.ckpt --G_B_ckpt ./outputs/ckpt/G_B_200.ckpt

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
echo "=============================================================================================================="
echo "Please run the train as: "
echo "python train.py device_id platform model max_epoch dataroot outputs_dir"
echo "for example: python train.py --platform Ascend --device_id 0 --model ResNet --max_epoch 200 --dataroot ./data/horse2zebra/ --outputs_dir ./outputs"
echo "================================================================================================================="
python train.py --platform Ascend --device_id 0 --model DepthResNet --max_epoch 200 --dataroot ./data/horse2zebra/ --outputs_dir ./outputs

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
echo "=============================================================================================================="
echo "Please run the train as: "
echo "python train.py device_id platform model max_epoch dataroot outputs_dir"
echo "for example: python train.py --platform GPU --device_id 0 --model ResNet --max_epoch 200 --dataroot ./data/horse2zebra/ --outputs_dir ./outputs"
echo "================================================================================================================="
python train.py --platform GPU --device_id 0 --model ResNet --max_epoch 200 --dataroot ./data/horse2zebra/ --outputs_dir ./outputs

@ -0,0 +1,183 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN dataset."""
import os
import random
import multiprocessing
import numpy as np
from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
from .distributed_sampler import DistributedSampler
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']


def is_image_file(filename):
    """Judge whether it is a picture."""
    return any(filename.lower().endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir_path, max_dataset_size=float("inf")):
    """Return image list in dir."""
    images = []
    assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path
    for root, _, fnames in sorted(os.walk(dir_path)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


class UnalignedDataset:
    """
    This dataset class can load unaligned/unpaired datasets.
    It requires two directories to host training images from domain A '/path/to/data/trainA'
    and from domain B '/path/to/data/trainB' respectively.
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    Similarly, you need to prepare two directories:
    '/path/to/data/testA' and '/path/to/data/testB' during test time.

    Returns:
        Two domain image path list.
    """

    def __init__(self, dataroot, phase, max_dataset_size=float("inf"), use_random=True):
        self.dir_A = os.path.join(dataroot, phase + 'A')
        self.dir_B = os.path.join(dataroot, phase + 'B')
        self.A_paths = sorted(make_dataset(self.dir_A, max_dataset_size))  # load images from '/path/to/data/trainA'
        self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size))  # load images from '/path/to/data/trainB'
        self.A_size = len(self.A_paths)  # get the size of dataset A
        self.B_size = len(self.B_paths)  # get the size of dataset B
        self.use_random = use_random

    def __getitem__(self, index):
        """Return one image from each domain.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a tuple (A_img, B_img):
            A_img (numpy.ndarray) -- an RGB image from domain A
            B_img (numpy.ndarray) -- an RGB image from domain B
        """
        index_B = index % self.B_size
        if index % max(self.A_size, self.B_size) == 0 and self.use_random:
            random.shuffle(self.A_paths)
            index_B = random.randint(0, self.B_size - 1)
        A_path = self.A_paths[index % self.A_size]
        B_path = self.B_paths[index_B]
        A_img = np.array(Image.open(A_path).convert('RGB'))
        B_img = np.array(Image.open(B_path).convert('RGB'))
        return A_img, B_img

    def __len__(self):
        """Return the total number of images in the dataset.
        """
        return max(self.A_size, self.B_size)


class ImageFolderDataset:
    """
    This dataset class can load images from image folder.

    Args:
        dataroot (str): Images root directory.
        max_dataset_size (int): Maximum number of return image paths.

    Returns:
        Image path list.
    """

    def __init__(self, dataroot, max_dataset_size=float("inf")):
        self.dataroot = dataroot
        self.paths = sorted(make_dataset(dataroot, max_dataset_size))
        self.size = len(self.paths)

    def __getitem__(self, index):
        img_path = self.paths[index % self.size]
        img = np.array(Image.open(img_path).convert('RGB'))
        return img, os.path.split(img_path)[1]

    def __len__(self):
        """Return the total number of images in the dataset."""
        return self.size


def create_dataset(args):
    """
    Create the dataset for training or testing.

    Args:
        args: Parsed arguments. Uses dataroot, phase, batch_size, device_num, rank,
            use_random, max_dataset_size and image_size, plus data_dir outside training.

    Returns:
        Batched dataset of normalized CHW images.
    """
    dataroot = args.dataroot
    phase = args.phase
    batch_size = args.batch_size
    device_num = args.device_num
    rank = args.rank
    shuffle = args.use_random
    max_dataset_size = args.max_dataset_size
    cores = multiprocessing.cpu_count()
    num_parallel_workers = min(8, int(cores / device_num))
    image_size = args.image_size
    # Normalizing with mean = std = 127.5 maps pixel values from [0, 255] to [-1, 1].
    mean = [0.5 * 255] * 3
    std = [0.5 * 255] * 3
    if phase == "train":
        dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size, use_random=args.use_random)
        distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
        ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
                                 sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
        if args.use_random:
            trans = [
                C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
                C.RandomHorizontalFlip(prob=0.5),
                C.Normalize(mean=mean, std=std),
                C.HWC2CHW()
            ]
        else:
            trans = [
                C.Resize((image_size, image_size)),
                C.Normalize(mean=mean, std=std),
                C.HWC2CHW()
            ]
        ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
        ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
        ds = ds.batch(batch_size, drop_remainder=True)
        ds = ds.repeat(1)
    else:
        datadir = os.path.join(dataroot, args.data_dir)
        dataset = ImageFolderDataset(datadir, max_dataset_size=max_dataset_size)
        ds = de.GeneratorDataset(dataset, column_names=["image", "image_name"],
                                 num_parallel_workers=num_parallel_workers)
        trans = [
            C.Resize((image_size, image_size)),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
        ds = ds.map(operations=trans, input_columns=["image"], num_parallel_workers=num_parallel_workers)
        ds = ds.batch(1, drop_remainder=True)
        ds = ds.repeat(1)
    args.dataset_size = len(dataset)
    return ds

@ -1,60 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset distributed sampler."""
from __future__ import division
import math
import numpy as np


class DistributedSampler:
    """Distributed sampler."""

    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            print("***********Setting world_size to 1 since it is not passed in ******************")
            num_replicas = 1
        if rank is None:
            print("***********Setting rank to 0 since it is not passed in ******************")
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            # np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
            indices = indices.tolist()
            self.epoch += 1
            # change to list type
        else:
            indices = list(range(self.dataset_size))
        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size
        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset distributed sampler."""
from __future__ import division
import math
import numpy as np


class DistributedSampler:
    """Distributed sampler."""

    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            print("***********Setting world_size to 1 since it is not passed in ******************")
            num_replicas = 1
        if rank is None:
            print("***********Setting rank to 0 since it is not passed in ******************")
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            # np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
            indices = indices.tolist()
            self.epoch += 1
            # change to list type
        else:
            indices = list(range(self.dataset_size))
        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size
        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN network."""
import mindspore as ms
@ -21,40 +22,40 @@ from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
import mindspore.ops as ops
from .resnet import ResNetGenerator
from .networks import ConvNormReLU, init_weights
from .resnet import ResNetGenerator
from .depth_resnet import DepthResNetGenerator
from .unet import UnetGenerator
def get_generator(args, teacher_net=False):
"""Return generator by args."""
if teacher_net:
if args.model == "resnet":
net = ResNetGenerator(in_planes=args.in_planes, ngf=args.t_ngf, n_layers=args.t_gl_num,
alpha=args.t_slope, norm_mode=args.t_norm_mode, dropout=False,
pad_mode=args.pad_mode)
init_weights(net, args.init_type, args.init_gain)
elif args.model == "unet":
net = UnetGenerator(in_planes=args.in_planes, out_planes=args.in_planes, ngf=args.t_ngf,
n_layers=args.t_gl_num, alpha=args.t_slope, norm_mode=args.t_norm_mode,
dropout=False)
init_weights(net, args.init_type, args.init_gain)
else:
raise NotImplementedError(f'Model {args.model} not recognized.')
def get_generator(args):
"""
This class implements the CycleGAN model, for learning image-to-image translation without paired data.
The model training requires '--dataset_mode unaligned' dataset.
By default, it uses a '--netG resnet_9blocks' ResNet generator,
a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
and a least-square GANs objective ('--gan_mode lsgan').
"""
if args.model == "ResNet":
net = ResNetGenerator(in_planes=args.in_planes, ngf=args.ngf, n_layers=args.gl_num,
alpha=args.slope, norm_mode=args.norm_mode, dropout=args.need_dropout,
pad_mode=args.pad_mode)
init_weights(net, args.init_type, args.init_gain)
elif args.model == "DepthResNet":
net = DepthResNetGenerator(in_planes=args.in_planes, ngf=args.ngf, n_layers=args.gl_num,
alpha=args.slope, norm_mode=args.norm_mode, dropout=args.need_dropout,
pad_mode=args.pad_mode)
init_weights(net, args.init_type, args.init_gain)
elif args.model == "UNet":
net = UnetGenerator(in_planes=args.in_planes, out_planes=args.in_planes, ngf=args.ngf, n_layers=args.gl_num,
alpha=args.slope, norm_mode=args.norm_mode, dropout=args.need_dropout)
init_weights(net, args.init_type, args.init_gain)
else:
if args.model == "resnet":
net = ResNetGenerator(in_planes=args.in_planes, ngf=args.ngf, n_layers=args.gl_num,
alpha=args.slope, norm_mode=args.norm_mode, dropout=args.need_dropout,
pad_mode=args.pad_mode)
init_weights(net, args.init_type, args.init_gain)
elif args.model == "unet":
net = UnetGenerator(in_planes=args.in_planes, out_planes=args.in_planes, ngf=args.ngf, n_layers=args.gl_num,
alpha=args.slope, norm_mode=args.norm_mode, dropout=args.need_dropout)
init_weights(net, args.init_type, args.init_gain)
else:
raise NotImplementedError(f'Model {args.model} not recognized.')
raise NotImplementedError(f'Model {args.model} not recognized.')
return net
def get_discriminator(args, teacher_net=False):
def get_discriminator(args):
"""Return discriminator by args."""
net = Discriminator(in_planes=args.in_planes, ndf=args.ndf, n_layers=args.dl_num,
alpha=args.slope, norm_mode=args.norm_mode)
@ -65,17 +66,14 @@ def get_discriminator(args, teacher_net=False):
class Discriminator(nn.Cell):
"""
Discriminator of GAN.
Args:
in_planes (int): Input channel.
ndf (int): Output channel.
n_layers (int): The number of ConvNormReLU blocks.
alpha (float): LeakyRelu slope. Default: 0.2.
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
Returns:
Tensor, output tensor.
Examples:
>>> Discriminator(3, 64, 3)
"""
@ -105,15 +103,12 @@ class Discriminator(nn.Cell):
class Generator(nn.Cell):
"""
Generator of CycleGAN, return fake_A, fake_B, rec_A, rec_B, identity_A and identity_B.
Args:
G_A (Cell): The generator network of domain A to domain B.
G_B (Cell): The generator network of domain B to domain A.
use_identity (bool): Use identity loss or not. Default: True.
Returns:
Tensors, fake_A, fake_B, rec_A, rec_B, identity_A and identity_B.
Examples:
>>> Generator(G_A, G_B)
"""
@ -142,7 +137,6 @@ class Generator(nn.Cell):
class WithLossCell(nn.Cell):
"""
Wrap the network with loss function to return generator loss.
Args:
network (Cell): The target network to wrap.
"""
@ -158,10 +152,8 @@ class WithLossCell(nn.Cell):
class TrainOneStepG(nn.Cell):
"""
Encapsulation class of Cycle GAN generator network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
Args:
G (Cell): Generator with loss Cell. Note that loss function should have been added.
generator (Cell): Generator of CycleGAN.
@ -210,10 +202,8 @@ class TrainOneStepG(nn.Cell):
class TrainOneStepD(nn.Cell):
"""
Encapsulation class of Cycle GAN discriminator network training.
Append an optimizer to the training network. After that, the construct
function can be called to create the backward graph.
Args:
G (Cell): Generator with loss Cell. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.

View File

@ -0,0 +1,113 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet Generator."""
import mindspore.nn as nn
from .networks import ConvNormReLU, ConvTransposeNormReLU
class ResidualBlock(nn.Cell):
"""
ResNet residual block definition.
The block applies two ConvNormReLU layers and adds the skip connection
in the construct function.
Args:
dim (int): Input and output channel.
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
dropout (bool): Use dropout or not. Default: False.
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
Default: "CONSTANT".
Returns:
Tensor, output tensor.
"""
def __init__(self, dim, norm_mode='batch', dropout=False, pad_mode="CONSTANT"):
super(ResidualBlock, self).__init__()
self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0.2, norm_mode, pad_mode)
self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0.2, norm_mode, pad_mode)
self.dropout = dropout
if dropout:
self.dropout = nn.Dropout(0.5)
def construct(self, x):
out = self.conv1(x)
if self.dropout:
out = self.dropout(out)
out = self.conv2(out)
return x + out
class DepthResNetGenerator(nn.Cell):
"""
ResNet Generator of GAN.
Args:
in_planes (int): Input channel.
ngf (int): Output channel.
n_layers (int): The number of ConvNormReLU blocks.
alpha (float): LeakyRelu slope. Default: 0.2.
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
dropout (bool): Use dropout or not. Default: False.
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
Default: "CONSTANT".
Returns:
Tensor, output tensor.
"""
def __init__(self, in_planes=3, ngf=64, n_layers=9, alpha=0.2, norm_mode='batch', dropout=False,
pad_mode="CONSTANT"):
super(DepthResNetGenerator, self).__init__()
conv_in1 = nn.Conv2d(in_planes, ngf, kernel_size=3, stride=1, has_bias=True)
conv_in2 = ConvNormReLU(ngf, ngf, 7, 1, alpha, norm_mode, pad_mode=pad_mode)
self.conv_in = nn.SequentialCell([conv_in1, conv_in2])
down_1 = ConvNormReLU(ngf, ngf * 2, 3, 2, alpha, norm_mode)
Res1 = ResidualBlock(ngf * 2, norm_mode, dropout=dropout, pad_mode=pad_mode)
self.down_1 = nn.SequentialCell([down_1, Res1])
down_2 = ConvNormReLU(ngf * 2, ngf * 3, 3, 2, alpha, norm_mode)
Res2 = ResidualBlock(ngf * 3, norm_mode, dropout=dropout, pad_mode=pad_mode)
self.down_2 = nn.SequentialCell([down_2, Res2])
self.down_3 = ConvNormReLU(ngf * 3, ngf * 4, 3, 2, alpha, norm_mode)
layers = [ResidualBlock(ngf * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * (n_layers-5)
self.residuals = nn.SequentialCell(layers)
up_3 = ConvTransposeNormReLU(ngf * 4, ngf * 3, 3, 2, alpha, norm_mode)
Res3 = ResidualBlock(ngf * 3, norm_mode, dropout=dropout, pad_mode=pad_mode)
self.up_3 = nn.SequentialCell([up_3, Res3])
up_2 = ConvTransposeNormReLU(ngf * 3, ngf * 2, 3, 2, alpha, norm_mode)
Res4 = ResidualBlock(ngf * 2, norm_mode, dropout=dropout, pad_mode=pad_mode)
self.up_2 = nn.SequentialCell([up_2, Res4])
up_1 = ConvTransposeNormReLU(ngf * 2, ngf, 3, 2, alpha, norm_mode)
Res5 = ResidualBlock(ngf, norm_mode, dropout=dropout, pad_mode=pad_mode)
self.up_1 = nn.SequentialCell([up_1, Res5])
tanh = nn.Tanh()
if pad_mode == "CONSTANT":
conv_out1 = nn.Conv2d(ngf, 3, kernel_size=7, stride=1, has_bias=True, pad_mode='pad', padding=3)
conv_out2 = nn.Conv2d(3, 3, kernel_size=3, stride=1, has_bias=True)
self.conv_out = nn.SequentialCell([conv_out1, tanh, conv_out2, tanh])
else:
pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)
conv = nn.Conv2d(ngf, 3, kernel_size=7, stride=1, pad_mode='pad')
self.conv_out = nn.SequentialCell([pad, conv, tanh])
def construct(self, x):
""" construct network """
x = self.conv_in(x)
x = self.down_1(x)
x = self.down_2(x)
x = self.down_3(x)
x = self.residuals(x)
x = self.up_3(x)
x = self.up_2(x)
x = self.up_1(x)
output = self.conv_out(x)
return output
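
One detail in the constructor above is worth checking: `[ResidualBlock(...)] * (n_layers-5)` uses Python list repetition, which repeats a reference to one object instead of building new ones. The sketch below shows the effect with a stand-in class (`Block`, `repeated` and `independent` are hypothetical names, not part of the repository). Whether the resulting sharing is intended here is not stated in the source.

```python
class Block:
    """Stand-in for a layer object; hypothetical, not the repository's ResidualBlock."""

    def __init__(self, dim):
        self.dim = dim


def repeated(n, dim):
    # List repetition copies the reference: every entry is the same object.
    return [Block(dim)] * n


def independent(n, dim):
    # A comprehension calls the constructor once per entry.
    return [Block(dim) for _ in range(n)]


shared = repeated(4, 256)
separate = independent(4, 256)

# Count the distinct objects held by each list.
print(len({id(b) for b in shared}))
print(len({id(b) for b in separate}))
```

If independent blocks are wanted, the comprehension form is the usual choice; how MindSpore's `SequentialCell` treats a repeated cell should be verified against its documentation.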

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,24 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN losses"""
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from .cycle_gan import get_generator
from ..utils import load_teacher_ckpt
class BCEWithLogits(nn.Cell):
"""
BCEWithLogits creates a criterion to measure the Binary Cross Entropy between the true labels and
predicted labels with sigmoid logits.
Args:
reduction (str): Specifies the reduction to be applied to the output.
Its value must be one of 'none', 'mean', 'sum'. Default: 'none'.
Outputs:
Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `inputs`.
Otherwise, the output is a scalar.
@ -58,13 +56,19 @@ class BCEWithLogits(nn.Cell):
class GANLoss(nn.Cell):
"""
Cycle GAN loss factory.
The GANLoss class abstracts away the need to create the target label tensor
that has the same size as the input.
Args:
mode (str): The type of GAN objective. It currently supports 'vanilla', 'lsgan'. Default: 'lsgan'.
reduction (str): Specifies the reduction to be applied to the output.
Its value must be one of 'none', 'mean', 'sum'. Default: 'none'.
Parameters:
gan_mode (str) -- the type of GAN objective. It currently supports vanilla and lsgan.
target_real_label (bool) -- label for a real image.
target_fake_label (bool) -- label for a fake image.
Note: Do not use sigmoid as the last layer of the Discriminator.
LSGAN needs no sigmoid. Vanilla GANs handle it with BCEWithLogits.
Outputs:
Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `inputs`.
Otherwise, the output is a scalar.
@ -90,13 +94,11 @@ class GANLoss(nn.Cell):
class GeneratorLoss(nn.Cell):
"""
Cycle GAN generator loss.
Args:
args (class): Option class.
generator (Cell): Generator of CycleGAN.
D_A (Cell): The discriminator network of domain A to domain B.
D_B (Cell): The discriminator network of domain B to domain A.
Outputs:
Tuple Tensor, the losses of generator.
"""
@ -112,14 +114,6 @@ class GeneratorLoss(nn.Cell):
self.D_A = D_A
self.D_B = D_B
self.true = Tensor(True, mstype.bool_)
self.kd = args.kd
if self.kd:
self.GT_A = get_generator(args, True)
load_teacher_ckpt(self.GT_A, args.GT_A_ckpt, "GT_A", "G_A")
self.GT_B = get_generator(args, True)
load_teacher_ckpt(self.GT_B, args.GT_B_ckpt, "GT_B", "G_B")
self.GT_A.set_train(True)
self.GT_B.set_train(True)
def construct(self, img_A, img_B):
"""If use_identity, identity loss will be used."""
@ -135,23 +129,18 @@ class GeneratorLoss(nn.Cell):
loss_idt_A = 0
loss_idt_B = 0
loss_G = loss_G_A + loss_G_B + loss_C_A + loss_C_B + loss_idt_A + loss_idt_B
if self.kd:
teacher_A = self.GT_B(img_B)
teacher_B = self.GT_A(img_A)
kd_loss_A = self.rec_loss(teacher_A, fake_A) * self.lambda_A * 5
kd_loss_B = self.rec_loss(teacher_B, fake_B) * self.lambda_A * 5
loss_G += kd_loss_A + kd_loss_B
return (fake_A, fake_B, loss_G, loss_G_A, loss_G_B, loss_C_A, loss_C_B, loss_idt_A, loss_idt_B)
class DiscriminatorLoss(nn.Cell):
"""
Cycle GAN discriminator loss.
Args:
args (class): option class.
D_A (Cell): The discriminator network of domain A to domain B.
D_B (Cell): The discriminator network of domain B to domain A.
real (tensor array) -- real images
fake (tensor array) -- images generated by a generator
Outputs:
Tuple Tensor, the loss of discriminator.
"""
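
The two objectives named in the `GANLoss` docstring can be sketched in plain Python. This is a minimal illustration under assumptions: the function names are hypothetical, the labels are taken as 1.0 for real and 0.0 for fake, and the repository's own implementation uses MindSpore cells instead of Python lists.

```python
import math


def lsgan_loss(preds, target_is_real):
    """Least-squares objective: mean squared error against a constant label."""
    target = 1.0 if target_is_real else 0.0
    return sum((p - target) ** 2 for p in preds) / len(preds)


def vanilla_loss(logits, target_is_real):
    """Binary cross entropy computed directly on raw logits (stable form)."""
    target = 1.0 if target_is_real else 0.0
    total = 0.0
    for x in logits:
        # max(x, 0) - x * t + log(1 + exp(-|x|)) avoids overflow for large |x|.
        total += max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


patch_scores = [0.9, 0.2, 0.4, 1.1]  # made-up discriminator outputs
print(lsgan_loss(patch_scores, True))
print(vanilla_loss(patch_scores, True))
```

Both functions take raw discriminator outputs, which is why the docstring advises against a sigmoid in the discriminator's last layer.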

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,20 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN network."""
import mindspore.nn as nn
from mindspore.common import initializer as init
def init_weights(net, init_type='normal', init_gain=0.02):
"""
Initialize network weights.
Parameters:
net (Cell): Network to be initialized
init_type (str): The name of an initialization method: normal | xavier.
init_gain (float): Gain factor for normal and xavier.
"""
for _, cell in net.cells_and_names():
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
@ -45,7 +45,6 @@ def init_weights(net, init_type='normal', init_gain=0.02):
class ConvNormReLU(nn.Cell):
"""
Convolution fused with BatchNorm/InstanceNorm and ReLU/LeakyReLU block definition.
Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
@ -57,7 +56,6 @@ class ConvNormReLU(nn.Cell):
Default: "CONSTANT".
use_relu (bool): Use relu or not. Default: True.
padding (int): Pad size, if it is None, it will calculate by kernel_size. Default: None.
Returns:
Tensor, output tensor.
"""
@ -103,7 +101,6 @@ class ConvNormReLU(nn.Cell):
class ConvTransposeNormReLU(nn.Cell):
"""
ConvTranspose2d fused with BatchNorm/InstanceNorm and ReLU/LeakyReLU block definition.
Args:
in_planes (int): Input channel.
out_planes (int): Output channel.
@ -115,7 +112,6 @@ class ConvTransposeNormReLU(nn.Cell):
Default: "CONSTANT".
use_relu (bool): use relu or not. Default: True.
padding (int): pad size, if it is None, it will calculate by kernel_size. Default: None.
Returns:
Tensor, output tensor.
"""

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,26 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet Generator."""
import mindspore.nn as nn
import mindspore.ops as ops
from .networks import ConvNormReLU, ConvTransposeNormReLU
class ResidualBlock(nn.Cell):
"""
ResNet residual block definition.
A ResNet block is a conv block with skip connections.
The conv block is built from ConvNormReLU layers,
and the skip connection is implemented in the construct function.
Args:
dim (int): Input and output channel.
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
dropout (bool): Use dropout or not. Default: False.
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
Default: "CONSTANT".
Returns:
Tensor, output tensor.
"""
def __init__(self, dim, norm_mode='batch', dropout=False, pad_mode="CONSTANT"):
super(ResidualBlock, self).__init__()
self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)
@ -51,7 +53,6 @@ class ResidualBlock(nn.Cell):
class ResNetGenerator(nn.Cell):
"""
ResNet Generator of GAN.
Args:
in_planes (int): Input channel.
ngf (int): Output channel.
@ -61,7 +62,6 @@ class ResNetGenerator(nn.Cell):
dropout (bool): Use dropout or not. Default: False.
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
Default: "CONSTANT".
Returns:
Tensor, output tensor.
"""

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,37 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""get args."""
import argparse
import ast
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.communication.management import init, get_rank
def get_args(phase):
"""Define the common options that are used in both training and test."""
parser = argparse.ArgumentParser(description='Cycle GAN.')
parser = argparse.ArgumentParser(description='Cycle GAN')
# basic parameters
parser.add_argument('--model', type=str, default="resnet", choices=("resnet", "unet"), \
help='generator model, should be in [resnet, unet].')
parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \
help='run platform, only support GPU, CPU and Ascend')
parser.add_argument("--device_id", type=int, default=0, help="device id, default is 0.")
parser.add_argument("--lr", type=float, default=0.0002, help="learning rate, default is 0.0002.")
parser.add_argument('--pool_size', type=int, default=50, \
help='the size of image buffer that stores previously generated images, default is 50.')
parser.add_argument('--lr_policy', type=str, default='linear', choices=("linear", "constant"), \
help='learning rate policy, default is linear')
parser.add_argument("--image_size", type=int, default=256, help="input image_size, default is 256.")
parser.add_argument('--batch_size', type=int, default=1, help='batch_size, default is 1.')
parser.add_argument('--max_epoch', type=int, default=200, help='epoch size for training, default is 200.')
parser.add_argument('--n_epochs', type=int, default=100, \
help='number of epochs with the initial learning rate, default is 100')
parser.add_argument("--beta1", type=float, default=0.5, help="Adam beta1, default is 0.5.")
parser.add_argument('--init_type', type=str, default='normal', choices=("normal", "xavier"), \
parser.add_argument('--platform', type=str, default='Ascend', help='only support GPU and Ascend')
parser.add_argument('--device_id', type=int, default=0, help='device id, default is 0.')
parser.add_argument('--device_num', type=int, default=1, help='device num, default is 1.')
parser.add_argument('--model', type=str, default='DepthResNet', choices=('DepthResNet', 'ResNet', 'UNet'), \
help='generator model')
parser.add_argument('--init_type', type=str, default='normal', choices=('normal', 'xavier'), \
help='network initialization, default is normal.')
parser.add_argument('--init_gain', type=float, default=0.02, \
help='scaling factor for normal, xavier and orthogonal, default is 0.02.')
parser.add_argument('--image_size', type=int, default=256, help='input image_size, default is 256.')
parser.add_argument('--batch_size', type=int, default=1, help='batch_size, default is 1.')
parser.add_argument('--pool_size', type=int, default=50, \
help='the size of image buffer that stores previously generated images')
parser.add_argument('--beta1', type=float, default=0.5, help='Adam beta1, default is 0.5.')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default is 0.0002.')
parser.add_argument('--lr_policy', type=str, default='linear', choices=('linear', 'constant'), \
help='learning rate policy, default is linear')
parser.add_argument('--max_epoch', type=int, default=200, help='epoch size for training, default is 200.')
parser.add_argument('--n_epochs', type=int, default=100, \
help='number of epochs with the initial learning rate, default is 100')
# model parameters
parser.add_argument('--in_planes', type=int, default=3, help='input channels, default is 3.')
@ -52,8 +55,8 @@ def get_args(phase):
parser.add_argument('--dl_num', type=int, default=3, \
help='discriminator model residual block numbers, default is 3.')
parser.add_argument('--slope', type=float, default=0.2, help='leakyrelu slope, default is 0.2.')
parser.add_argument('--norm_mode', type=str, default="instance", choices=("batch", "instance"), \
help='norm mode, default is instance.')
parser.add_argument('--norm_mode', type=str, default='batch', choices=('batch', 'instance'), \
help='norm mode, default is batch.')
parser.add_argument('--lambda_A', type=float, default=10.0, \
help='weight for cycle loss (A -> B -> A), default is 10.')
parser.add_argument('--lambda_B', type=float, default=10.0, \
@ -63,58 +66,44 @@ def get_args(phase):
'weight of the identity mapping loss. For example, if the weight of the identity loss '
'should be 10 times smaller than the weight of the reconstruction loss,'
'please set lambda_identity = 0.1, default is 0.5.')
parser.add_argument('--gan_mode', type=str, default='lsgan', choices=("lsgan", "vanilla"), \
parser.add_argument('--gan_mode', type=str, default='lsgan', choices=('lsgan', 'vanilla'), \
help='the type of GAN loss, default is lsgan.')
parser.add_argument('--pad_mode', type=str, default='REFLECT', choices=("CONSTANT", "REFLECT", "SYMMETRIC"), \
help='the type of Pad, default is REFLECT.')
parser.add_argument('--need_dropout', type=ast.literal_eval, default=True, \
help='whether need dropout, default is True.')
# distillation learning parameters
parser.add_argument('--kd', type=ast.literal_eval, default=False, \
help='knowledge distillation learning or not, default is False.')
parser.add_argument('--t_ngf', type=int, default=64, \
help='teacher network generator model filter numbers when `kd` is True, default is 64.')
parser.add_argument('--t_gl_num', type=int, default=9, \
help='teacher network generator model residual block numbers when `kd` is True, default is 9.')
parser.add_argument('--t_slope', type=float, default=0.2, \
help='teacher network leakyrelu slope when `kd` is True, default is 0.2.')
parser.add_argument('--t_norm_mode', type=str, default="instance", choices=("batch", "instance"), \
help='teacher network norm mode when `kd` is True, default is instance.')
parser.add_argument("--GT_A_ckpt", type=str, default=None, \
help="teacher network pretrained checkpoint file path of G_A when `kd` is True.")
parser.add_argument("--GT_B_ckpt", type=str, default=None, \
help="teacher network pretrained checkpoint file path of G_B when `kd` is True.")
parser.add_argument('--pad_mode', type=str, default='CONSTANT', choices=('CONSTANT', 'REFLECT', 'SYMMETRIC'), \
help='the type of Pad, default is CONSTANT.')
# additional parameters
parser.add_argument('--device_num', type=int, default=1, help='device num, default is 1.')
parser.add_argument("--G_A_ckpt", type=str, default=None, help="pretrained checkpoint file path of G_A.")
parser.add_argument("--G_B_ckpt", type=str, default=None, help="pretrained checkpoint file path of G_B.")
parser.add_argument("--D_A_ckpt", type=str, default=None, help="pretrained checkpoint file path of D_A.")
parser.add_argument("--D_B_ckpt", type=str, default=None, help="pretrained checkpoint file path of D_B.")
parser.add_argument("--save_checkpoint_epochs", type=int, default=1, help="Save checkpoint epochs, default is 10.")
parser.add_argument("--print_iter", type=int, default=100, help="log print iter, default is 100.")
parser.add_argument('--dataroot', default='./data/horse2zebra/', \
help='path of images (should have subfolders trainA, trainB, testA, testB, etc).')
parser.add_argument('--outputs_dir', type=str, default='./outputs', \
help='models are saved here, default is ./outputs.')
parser.add_argument('--load_ckpt', type=ast.literal_eval, default=False, \
help='whether load pretrained ckpt')
parser.add_argument('--G_A_ckpt', type=str, default='./outputs/ckpt/G_A_200.ckpt', \
help='checkpoint file path of G_A.')
parser.add_argument('--G_B_ckpt', type=str, default='./outputs/ckpt/G_B_200.ckpt', \
help='checkpoint file path of G_B.')
parser.add_argument('--D_A_ckpt', type=str, default='./outputs/ckpt/D_A_200.ckpt', \
help='checkpoint file path of D_A.')
parser.add_argument('--D_B_ckpt', type=str, default='./outputs/ckpt/D_B_200.ckpt', \
help='checkpoint file path of D_B.')
parser.add_argument('--save_checkpoint_epochs', type=int, default=10, \
help='Save checkpoint epochs, default is 10.')
parser.add_argument('--print_iter', type=int, default=100, help='log print iter, default is 100.')
parser.add_argument('--need_profiler', type=ast.literal_eval, default=False, \
help='whether need profiler, default is False.')
parser.add_argument('--save_graphs', type=ast.literal_eval, default=False, \
help='whether save graphs, default is False.')
parser.add_argument('--outputs_dir', type=str, default='./outputs', \
help='models are saved here, default is ./outputs.')
parser.add_argument('--dataroot', default=None, \
help='path of images (should have subfolders trainA, trainB, testA, testB, etc).')
parser.add_argument('--save_imgs', type=ast.literal_eval, default=True, \
help='whether save imgs when epoch end, if True result images will generate in '
'`outputs_dir/imgs`, default is True.')
help='whether save imgs when epoch end')
parser.add_argument('--use_random', type=ast.literal_eval, default=True, \
help='whether use random when training, default is True.')
parser.add_argument('--max_dataset_size', type=int, default=None, help='max images per epoch, default is None.')
if phase == "export":
parser.add_argument("--file_name", type=str, default="cyclegan", help="output file name prefix.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', \
help='file format')
parser.add_argument('--need_dropout', type=ast.literal_eval, default=False, \
help='whether need dropout, default is False.')
parser.add_argument('--max_dataset_size', type=int, default=None, \
help='max images per epoch, default is None.')
args = parser.parse_args()
if args.device_num > 1 and args.platform != "CPU":
if args.device_num > 1:
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=args.save_graphs)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
@ -127,26 +116,18 @@ def get_args(phase):
args.rank = 0
args.device_num = 1
if args.platform != "GPU":
if args.platform == "Ascend":
args.pad_mode = "CONSTANT"
if phase != "train" and (args.G_A_ckpt is None or args.G_B_ckpt is None):
raise ValueError('Must set G_A_ckpt and G_B_ckpt in predict phase!')
if args.kd:
if args.GT_A_ckpt is None or args.GT_B_ckpt is None:
raise ValueError('Must set GT_A_ckpt, GT_B_ckpt in knowledge distillation!')
if args.batch_size == 1:
args.norm_mode = "instance"
if args.norm_mode == "instance" or (args.kd and args.t_norm_mode == "instance"):
args.batch_size = 1
if args.dataroot is None and (phase in ["train", "predict"]):
if args.dataroot is None:
raise ValueError('Must set dataroot!')
if not args.use_random:
args.need_dropout = False
args.init_type = "constant"
if args.max_dataset_size is None:
args.max_dataset_size = float("inf")
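
The boolean options above are declared with `type=ast.literal_eval`, not `type=bool`. The sketch below shows the difference with a small stand-alone parser; `build_parser` and `--naive_flag` are hypothetical names used only for contrast.

```python
import argparse
import ast


def build_parser():
    parser = argparse.ArgumentParser(description="literal_eval flag demo")
    # Same pattern as the flags above: the string is evaluated as a Python literal.
    parser.add_argument("--use_random", type=ast.literal_eval, default=True)
    # For contrast: bool("False") is True because any non-empty string is truthy.
    parser.add_argument("--naive_flag", type=bool, default=True)
    return parser


args = build_parser().parse_args(["--use_random", "False", "--naive_flag", "False"])
print(args.use_random, args.naive_flag)
```

With this pattern, values have to be spelled as Python literals such as `True` or `False` on the command line.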

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Reporter class."""
import logging
import os
import time
@ -20,10 +22,10 @@ from datetime import datetime
from mindspore.train.serialization import save_checkpoint
from .tools import save_image
class Reporter(logging.Logger):
"""
This class includes several functions that can save images/checkpoints and print/save logging information.
Args:
args (class): Option class.
"""
@ -73,17 +75,6 @@ class Reporter(logging.Logger):
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.logger.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def epoch_start(self):
self.step_start_time = time.time()
self.epoch_start_time = time.time()
@ -119,14 +110,14 @@ class Reporter(logging.Logger):
self.info("Epoch [{}] total cost: {:.2f} ms, pre step: {:.2f} ms, G_loss: {:.2f}, D_loss: {:.2f}".format(
self.epoch, epoch_cost, pre_step_time, mean_loss_G, mean_loss_D))
if self.epoch % self.save_checkpoint_epochs == 0 and self.rank == 0:
if self.epoch % self.save_checkpoint_epochs == 0:
save_checkpoint(net.G.generator.G_A, os.path.join(self.ckpts_dir, f"G_A_{self.epoch}.ckpt"))
save_checkpoint(net.G.generator.G_B, os.path.join(self.ckpts_dir, f"G_B_{self.epoch}.ckpt"))
save_checkpoint(net.G.D_A, os.path.join(self.ckpts_dir, f"D_A_{self.epoch}.ckpt"))
save_checkpoint(net.G.D_B, os.path.join(self.ckpts_dir, f"D_B_{self.epoch}.ckpt"))
def visualizer(self, img_A, img_B, fake_A, fake_B):
if self.save_imgs and self.step % self.dataset_size == 0 and self.rank == 0:
if self.save_imgs and self.step % self.dataset_size == 0:
save_image(img_A, os.path.join(self.imgs_dir, f"{self.epoch}_img_A.jpg"))
save_image(img_B, os.path.join(self.imgs_dir, f"{self.epoch}_img_B.jpg"))
save_image(fake_A, os.path.join(self.imgs_dir, f"{self.epoch}_fake_A.jpg"))

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils for cyclegan."""
import random
import numpy as np
from PIL import Image
@ -23,15 +25,12 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
class ImagePool():
"""
This class implements an image buffer that stores previously generated images.
This buffer enables us to update discriminators using a history of generated images
rather than the ones produced by the latest generators.
"""
def __init__(self, pool_size):
"""
Initialize the ImagePool class
Args:
pool_size (int): the size of image buffer, if pool_size=0, no buffer will be created.
"""
@ -43,12 +42,9 @@ class ImagePool():
def query(self, images):
"""
Return an image from the pool.
Args:
images: the latest generated images from the generator
Returns images Tensor from the buffer.
With probability 0.5, the buffer returns the input images.
With probability 0.5, the buffer returns images previously stored in the buffer
and inserts the current images into the buffer.
@ -80,7 +76,6 @@ class ImagePool():
def save_image(img, img_path):
"""Save a numpy image to the disk
Parameters:
img (numpy array / Tensor): image to save.
img_path (str): the path of the image.
@ -101,7 +96,11 @@ def decode_image(img):
def get_lr(args):
"""Learning rate generator."""
"""
Learning rate generator.
For 'linear', we keep the same learning rate for the first args.n_epochs epochs
and linearly decay the rate towards zero over the remaining args.max_epoch - args.n_epochs epochs.
"""
if args.lr_policy == 'linear':
lrs = [args.lr] * args.dataset_size * args.n_epochs
lr_epoch = 0
@ -127,15 +126,3 @@ def load_ckpt(args, G_A, G_B, D_A=None, D_B=None):
if D_B is not None and args.D_B_ckpt is not None:
param_DB = load_checkpoint(args.D_B_ckpt)
load_param_into_net(D_B, param_DB)
def load_teacher_ckpt(net, ckpt_path, teacher, student):
"""Replace parameter name to teacher net and load parameter from checkpoint."""
param = load_checkpoint(ckpt_path)
new_param = {}
for k, v in param.items():
new_name = k.replace(student, teacher)
new_param_name = v.name.replace(student, teacher)
v.name = new_param_name
new_param[new_name] = v
load_param_into_net(net, new_param)
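
The 'linear' policy described in the `get_lr` docstring above can be sketched per epoch as follows. This is an illustration under assumptions: `linear_lr_per_epoch` is a hypothetical helper, and the exact decay formula inside the repository's `get_lr` is not visible in this diff, so the step size and end value here are assumed.

```python
def linear_lr_per_epoch(base_lr, n_epochs, max_epoch):
    """Return one learning rate per epoch: constant first, then linear decay."""
    decay_epochs = max_epoch - n_epochs
    lrs = [base_lr] * n_epochs
    for epoch in range(decay_epochs):
        # Assumed form: starts at base_lr and drops by base_lr / decay_epochs each epoch.
        lrs.append(base_lr * (decay_epochs - epoch) / decay_epochs)
    return lrs


schedule = linear_lr_per_epoch(0.0002, n_epochs=100, max_epoch=200)
print(len(schedule), schedule[0], schedule[-1])
```

The repository version builds the schedule per step by repeating each epoch's value `dataset_size` times, as the first line of the `linear` branch suggests.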

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,17 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN train."""
"""General-purpose training script for image-to-image translation.
You need to specify the dataset ('--dataroot'); the generator model ('--model') is optional and has a default.
Example:
Train a resnet model:
python train.py --dataroot ./data/horse2zebra --model ResNet
"""
import mindspore.nn as nn
from mindspore.common import set_seed
from src.utils.args import get_args
from src.utils.reporter import Reporter
from src.utils.tools import get_lr, ImagePool, load_ckpt
from src.dataset.cyclegan_dataset import create_dataset
from src.models.losses import DiscriminatorLoss, GeneratorLoss
from src.models.cycle_gan import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD
from src.models import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD, \
DiscriminatorLoss, GeneratorLoss
from src.utils import get_lr, get_args, Reporter, ImagePool, load_ckpt
from src.dataset import create_dataset
set_seed(1)
def train():
"""Train function."""
@ -35,7 +40,8 @@ def train():
G_B = get_generator(args)
D_A = get_discriminator(args)
D_B = get_discriminator(args)
load_ckpt(args, G_A, G_B, D_A, D_B)
if args.load_ckpt:
load_ckpt(args, G_A, G_B, D_A, D_B)
imgae_pool_A = ImagePool(args.pool_size)
imgae_pool_B = ImagePool(args.pool_size)
generator = Generator(G_A, G_B, args.lambda_idt > 0)

View File

@ -1,235 +0,0 @@
# Contents
- [CycleGAN Description](#cyclegan-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Knowledge Distillation Process](#knowledge-distillation-process)
- [Prediction Process](#prediction-process)
- [Evaluation with cityscape dataset](#evaluation-with-cityscape-dataset)
- [Export MindIR](#export-mindir)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [CycleGAN Description](#contents)
A Generative Adversarial Network (GAN) is an unsupervised learning method in which two neural networks learn by competing against each other. CycleGAN is a kind of GAN that consists of two generation networks and two discriminant networks. It learns to convert images of one type into images of another type from unpaired training images, which makes it suitable for style transfer.
[Paper](https://arxiv.org/abs/1703.10593): Zhu J Y , Park T , Isola P , et al. Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks[J]. 2017.
# [Model Architecture](#contents)
The CycleGAN contains two generator networks and two discriminator networks. Two generator architectures are supported: resnet and unet. The resnet architecture contains three convolutions, several residual blocks, two fractionally-strided convolutions with stride 1/2, and one convolution that maps features to RGB. The unet architecture contains three unet blocks to downsample and upsample, several further unet blocks, and one convolution that maps features to RGB. For the discriminator networks we use 70 × 70 PatchGANs, which classify whether overlapping 70 × 70 image patches are real or fake.
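The 70 × 70 patch size can be checked with a short receptive-field calculation. This is a minimal sketch under an assumption: the layer layout below (three stride-2 convolutions followed by two stride-1 convolutions, all with 4 × 4 kernels) is the commonly used PatchGAN configuration and is not taken from this repository's `networks.py`. The helper name `receptive_field` is hypothetical.

```python
def receptive_field(layers):
    """Return the input-side receptive field of a single output unit.

    `layers` is a list of (kernel_size, stride) pairs ordered from input to output.
    """
    field = 1
    # Walk backwards from the output towards the input.
    for kernel_size, stride in reversed(layers):
        field = (field - 1) * stride + kernel_size
    return field


# Assumed layout (not read from networks.py): 4x4 kernels,
# three stride-2 convolutions followed by two stride-1 convolutions.
patchgan_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(patchgan_layers))  # 70
```

Each output value of such a discriminator therefore judges one 70 × 70 region of the input, and neighbouring outputs cover overlapping regions.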
# [Dataset](#contents)
Note that you can run the scripts with the dataset used in the original paper or with datasets widely used in the relevant domain or network architecture. The following sections describe how to run the scripts with the dataset below.
Dataset used: [CityScape](<https://cityscapes-dataset.com>)
Please download the datasets [gtFine_trainvaltest.zip] and [leftImg8bit_trainvaltest.zip] and unzip them. We provide `src/utils/prepare_cityscapes_dataset.py` to process the images. gtFine contains the semantic segmentation maps. Use `--gtFine_dir` to specify the path to the unzipped gtFine_trainvaltest directory. leftImg8bit contains the dashcam photographs. Use `--leftImg8bit_dir` to specify the path to the unzipped leftImg8bit_trainvaltest directory.
The processed images are written to the directory given by `--output_dir`.
Example usage:
```bash
python src/utils/prepare_cityscapes_dataset.py --gtFine_dir ./cityscapes/gtFine/ --leftImg8bit_dir ./cityscapes/leftImg8bit --output_dir ./cityscapes/
```
The directory structure is as follows:
```path
.
└─cityscapes
    ├─trainA
    ├─trainB
    ├─testA
    └─testB
```
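Before training, it can help to confirm that the data root has the four subfolders listed above. The following is a minimal sketch; the helper name `missing_subdirs` is hypothetical and is not part of this repository.

```python
from pathlib import Path

# Subfolders named in the directory structure above.
EXPECTED_SUBDIRS = ("trainA", "trainB", "testA", "testB")


def missing_subdirs(dataroot):
    """Return the expected subfolders that do not exist under `dataroot`."""
    root = Path(dataroot)
    return [name for name in EXPECTED_SUBDIRS if not (root / name).is_dir()]
```

For example, `missing_subdirs("./cityscapes")` returns an empty list when the layout is complete, and otherwise lists the folders that still need to be created.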
# [Environment Requirements](#contents)
- Hardware (GPU)
    - Prepare hardware environment with GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```path
.
└─ cv
    └─ cyclegan
        ├─ src
        │   ├─ __init__.py                       # init file
        │   ├─ dataset
        │   │   ├─ __init__.py                   # init file
        │   │   ├─ cyclegan_dataset.py           # create cyclegan dataset
        │   │   ├─ datasets.py                   # UnalignedDataset and ImageFolderDataset class and some image utils
        │   │   └─ distributed_sampler.py        # iterator of dataset
        │   ├─ models
        │   │   ├─ __init__.py                   # init file
        │   │   ├─ cycle_gan.py                  # cyclegan model define
        │   │   ├─ losses.py                     # cyclegan losses function define
        │   │   ├─ networks.py                   # cyclegan sub networks define
        │   │   ├─ resnet.py                     # resnet generate network
        │   │   └─ unet.py                       # unet generate network
        │   └─ utils
        │       ├─ __init__.py                   # init file
        │       ├─ args.py                       # parse args
        │       ├─ prepare_cityscapes_dataset.py # prepare cityscapes dataset to cyclegan format
        │       ├─ cityscapes_utils.py           # cityscapes dataset evaluation utils
        │       ├─ reporter.py                   # Reporter class
        │       └─ tools.py                      # utils for cyclegan
        ├─ cityscape_eval.py                     # cityscape dataset eval script
        ├─ predict.py                            # generate images from A->B and B->A
        ├─ train.py                              # train script
        ├─ export.py                             # export mindir script
        ├─ README.md                             # descriptions about CycleGAN
        └─ mindspore_hub_conf.py                 # mindspore hub interface
```
## [Script Parameters](#contents)
```python
# Major parameters in train.py and config.py are as follows:
"model": "resnet"            # generator model, should be in [resnet, unet].
"platform": "GPU"            # run platform, support GPU, CPU and Ascend.
"device_id": 0               # device id, default is 0.
"lr": 0.0002                 # init learning rate, default is 0.0002.
"pool_size": 50              # the size of image buffer that stores previously generated images, default is 50.
"lr_policy": "linear"        # learning rate policy, default is linear.
"image_size": 256            # input image_size, default is 256.
"batch_size": 1              # batch_size, default is 1.
"max_epoch": 200             # epoch size for training, default is 200.
"n_epochs": 100              # number of epochs with the initial learning rate, default is 100.
"beta1": 0.5                 # Adam beta1, default is 0.5.
"init_type": normal          # network initialization, default is normal.
"init_gain": 0.02            # scaling factor for normal, xavier and orthogonal, default is 0.02.
"in_planes": 3               # input channels, default is 3.
"ngf": 64                    # generator model filter numbers, default is 64.
"gl_num": 9                  # generator model residual block numbers, default is 9.
"ndf": 64                    # discriminator model filter numbers, default is 64.
"dl_num": 3                  # discriminator model residual block numbers, default is 3.
"slope": 0.2                 # leakyrelu slope, default is 0.2.
"norm_mode": "instance"      # norm mode, should be [batch, instance], default is instance.
"lambda_A": 10               # weight for cycle loss (A -> B -> A), default is 10.
"lambda_B": 10               # weight for cycle loss (B -> A -> B), default is 10.
"lambda_idt": 0.5            # if lambda_idt > 0 use identity mapping.
"gan_mode": lsgan            # the type of GAN loss, should be [lsgan, vanilla], default is lsgan.
"pad_mode": REFLECT          # the type of Pad, should be [CONSTANT, REFLECT, SYMMETRIC], default is REFLECT.
"need_dropout": True         # whether need dropout, default is True.
"kd": False                  # knowledge distillation learning or not, default is False.
"t_ngf": 64                  # teacher network generator model filter numbers when `kd` is True, default is 64.
"t_gl_num": 9                # teacher network generator model residual block numbers when `kd` is True, default is 9.
"t_slope": 0.2               # teacher network leakyrelu slope when `kd` is True, default is 0.2.
"t_norm_mode": "instance"    # teacher network norm mode when `kd` is True, default is instance.
"print_iter": 100            # log print iter, default is 100.
"outputs_dir": "outputs"     # models are saved here, default is ./outputs.
"dataroot": None             # path of images (should have subfolders trainA, trainB, testA, testB, etc).
"save_imgs": True            # whether save imgs when epoch end, if True result images will generate in `outputs_dir/imgs`, default is True.
"GT_A_ckpt": None            # teacher network pretrained checkpoint file path of G_A when `kd` is True.
"GT_B_ckpt": None            # teacher network pretrained checkpoint file path of G_B when `kd` is True.
"G_A_ckpt": None             # pretrained checkpoint file path of G_A.
"G_B_ckpt": None             # pretrained checkpoint file path of G_B.
"D_A_ckpt": None             # pretrained checkpoint file path of D_A.
"D_B_ckpt": None             # pretrained checkpoint file path of D_B.
```
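The interaction between `lr`, `n_epochs`, `max_epoch` and the `linear` policy can be illustrated with a short sketch. This is an assumption about how a linear policy typically behaves (constant rate for the first `n_epochs`, then a linear decay to zero at `max_epoch`); the function `linear_lr` is hypothetical, and the repository's own `get_lr` in `src/utils/tools.py` may differ in detail.

```python
def linear_lr(epoch, base_lr=0.0002, n_epochs=100, max_epoch=200):
    """Learning rate for a zero-based `epoch` under the assumed linear policy."""
    decay_epochs = max_epoch - n_epochs
    if epoch < n_epochs or decay_epochs <= 0:
        # Constant phase: keep the initial learning rate.
        return base_lr
    # Decay phase: scale by the fraction of decay epochs still remaining.
    remaining = max(max_epoch - epoch, 0)
    return base_lr * remaining / decay_epochs


for epoch in (0, 99, 100, 150, 200):
    print(epoch, linear_lr(epoch))
```

With the defaults above, the rate stays at 0.0002 through epoch 99, is halved by epoch 150, and reaches zero at epoch 200.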
## [Training Process](#contents)
```bash
python train.py --platform [PLATFORM] --dataroot [DATA_PATH]
```
**Note: pad_mode should be CONSTANT when using Ascend or CPU. When using unet as the generator network, gl_num should be less than 7.**
## [Knowledge Distillation Process](#contents)
```bash
python train.py --platform [PLATFORM] --dataroot [DATA_PATH] --ngf [NGF] --kd True --GT_A_ckpt [G_A_CKPT] --GT_B_ckpt [G_B_CKPT]
```
**Note: the student network ngf should be 1/2 or 1/4 of the teacher network ngf. If you changed the default args when training the teacher generator networks, set the corresponding t_xx args to the same values in the knowledge distillation process.**
## [Prediction Process](#contents)
```bash
python predict.py --platform [PLATFORM] --dataroot [DATA_PATH] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKPT]
```
**Note: the results will be saved at `outputs_dir/predict`.**
## [Evaluation with cityscape dataset](#contents)
```bash
python cityscape_eval.py --cityscapes_dir [LABEL_PATH] --result_dir [FAKEB_PATH]
```
**Note: Please run cityscape_eval.py after prediction process.**
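The evaluation compares label maps, so each generated image is first converted from colors to class ids. `src/utils/cityscapes_utils.py` does this by assigning every pixel to the class whose reference color is nearest in L1 distance. The sketch below shows the same idea with NumPy broadcasting on a three-class palette; the function name `nearest_class_ids` and the reduced palette are illustrative choices, while the three colors are the road, building and sky entries from that file.

```python
import numpy as np

# Reduced palette for illustration (road, building, sky).
PALETTE = np.array([
    (128, 64, 128),   # road
    (70, 70, 70),     # building
    (70, 130, 180),   # sky
], dtype=np.int64)


def nearest_class_ids(image, palette=PALETTE):
    """Map an (H, W, 3) RGB image to an (H, W) array of palette indices."""
    # Cast to a signed type so that differences do not wrap around in uint8.
    pixels = image.astype(np.int64)[:, :, None, :]             # (H, W, 1, 3)
    distance = np.abs(pixels - palette[None, None]).sum(axis=-1)  # (H, W, C)
    return distance.argmin(axis=-1)


# A 1x2 image: one pixel close to "road", one close to "sky".
image = np.array([[[130, 60, 125], [72, 128, 178]]], dtype=np.uint8)
print(nearest_class_ids(image))  # [[0 2]]
```

Because the generator output is saved as a compressed image, colors rarely match the palette exactly, which is why a nearest-color rule is used rather than an exact lookup.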
## [Export MindIR](#contents)
```bash
python export.py --platform [PLATFORM] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKPT] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```
**Note: The file_name parameter is a prefix; the final files will be named [FILE_NAME]_AtoB.[FILE_FORMAT] and [FILE_NAME]_BtoA.[FILE_FORMAT].**
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Parameters | GPU |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | CycleGAN |
| Resource | NV SMX2 V100-32G |
| Uploaded Date | 12/10/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | Cityscapes |
| Training Parameters | epoch=200, steps=2975, batch_size=1, lr=0.002 |
| Optimizer | Adam |
| Loss Function | Mean Square Loss & L1 Loss |
| outputs | probability |
| Speed | 1pc: 264 ms/step; |
| Total time | 1pc: 43.6h; |
| Parameters (M) | 11.378 M |
| Checkpoint for Fine tuning | 44M (.ckpt file) |
| Scripts | [CycleGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/cycle_gan) |
### Inference Performance
| Parameters | GPU |
| ------------------- | --------------------------- |
| Model Version | CycleGAN |
| Resource | GPU |
| Uploaded Date | 12/10/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | Cityscapes |
| batch_size | 1 |
| outputs | probability |
| Accuracy | mean_pixel_acc: 54.8, mean_class_acc: 21.3, mean_class_iou: 16.1 |
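The three accuracy figures in the table are derived from a class confusion matrix. As a worked example, the sketch below computes them for a small two-class matrix, following the formulas used in `get_scores` in `src/utils/cityscapes_utils.py`. The function name `scores_from_confusion` and the example matrix are made up for illustration.

```python
import numpy as np


def scores_from_confusion(hist, eps=1e-12):
    """Return (mean pixel acc, mean class acc, mean class IoU).

    `hist[i, j]` counts pixels with true class i that were predicted as class j.
    """
    hist = np.asarray(hist, dtype=np.float64)
    correct = np.diag(hist)
    pixel_acc = correct.sum() / (hist.sum() + eps)
    class_acc = correct / (hist.sum(axis=1) + eps)
    class_iou = correct / (hist.sum(axis=1) + hist.sum(axis=0) - correct + eps)
    return pixel_acc, class_acc.mean(), class_iou.mean()


# Made-up example: 10 pixels, 8 of them classified correctly.
hist = [[3, 1],
        [1, 5]]
pixel_acc, mean_class_acc, mean_class_iou = scores_from_confusion(hist)
print(round(pixel_acc, 4), round(mean_class_acc, 4), round(mean_class_iou, 4))
# 0.8 0.7917 0.6571
```

Here the pixel accuracy is 8/10, the per-class accuracies are 3/4 and 5/6, and the per-class IoUs are 3/5 and 5/7. The table reports the same quantities scaled to percentages.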
# [Description of Random Situation](#contents)
If you set `--use_random=False`, no randomness is used during training.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -1,54 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval use cityscape dataset."""
import os
import argparse
import numpy as np
from src.dataset import make_dataset
from src.utils import CityScapes, fast_hist, get_scores

parser = argparse.ArgumentParser()
parser.add_argument("--cityscapes_dir", type=str, required=True, help="Path to the original cityscapes dataset")
parser.add_argument("--result_dir", type=str, required=True, help="Path to the generated images to be evaluated")
args = parser.parse_args()


def main():
    CS = CityScapes()
    cityscapes = make_dataset(args.cityscapes_dir)
    hist_perframe = np.zeros((CS.class_num, CS.class_num))
    for i, img_path in enumerate(cityscapes):
        if i % 100 == 0:
            print('Evaluating: %d/%d' % (i, len(cityscapes)))
        img_name = os.path.split(img_path)[1]
        ids1 = CS.get_id(os.path.join(args.cityscapes_dir, img_name))
        ids2 = CS.get_id(os.path.join(args.result_dir, img_name))
        hist_perframe += fast_hist(ids1.flatten(), ids2.flatten(), CS.class_num)

    mean_pixel_acc, mean_class_acc, mean_class_iou, per_class_acc, per_class_iou = get_scores(hist_perframe)
    print(f"mean_pixel_acc: {mean_pixel_acc}, mean_class_acc: {mean_class_acc}, mean_class_iou: {mean_class_iou}")
    with open('./evaluation_results.txt', 'w') as f:
        f.write('Mean pixel accuracy: %f\n' % mean_pixel_acc)
        f.write('Mean class accuracy: %f\n' % mean_class_acc)
        f.write('Mean class IoU: %f\n' % mean_class_iou)
        f.write('************ Per class numbers below ************\n')
        for i, cl in enumerate(CS.classes):
            while len(cl) < 15:
                cl = cl + ' '
            f.write('%s: acc = %f, iou = %f\n' % (cl, per_class_acc[i], per_class_iou[i]))


if __name__ == '__main__':
    main()


@ -1,27 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.models import get_generator


def create_network(name, *args, **kwargs):
    if name == "cyclegan":
        G_A = get_generator(*args, **kwargs)
        G_B = get_generator(*args, **kwargs)
        # Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
        # Use real mean and variance rather than moving_mean and moving_variance in BatchNorm2d
        G_A.set_train(True)
        G_B.set_train(True)
        return G_A, G_B
    raise NotImplementedError(f"{name} is not implemented in the repo")


@ -1,74 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN dataset."""
import os
import multiprocessing
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
from .distributed_sampler import DistributedSampler
from .datasets import UnalignedDataset, ImageFolderDataset


def create_dataset(args):
    """Create dataset"""
    dataroot = args.dataroot
    phase = args.phase
    batch_size = args.batch_size
    device_num = args.device_num
    rank = args.rank
    shuffle = args.use_random
    max_dataset_size = args.max_dataset_size
    cores = multiprocessing.cpu_count()
    num_parallel_workers = min(8, int(cores / device_num))
    image_size = args.image_size
    mean = [0.5 * 255] * 3
    std = [0.5 * 255] * 3
    if phase == "train":
        dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size, use_random=args.use_random)
        distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
        ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
                                 sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
        if args.use_random:
            trans = [
                C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
                C.RandomHorizontalFlip(prob=0.5),
                C.Normalize(mean=mean, std=std),
                C.HWC2CHW()
            ]
        else:
            trans = [
                C.Resize((image_size, image_size)),
                C.Normalize(mean=mean, std=std),
                C.HWC2CHW()
            ]
        ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
        ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
        ds = ds.batch(batch_size, drop_remainder=True)
        ds = ds.repeat(1)
    else:
        datadir = os.path.join(dataroot, args.data_dir)
        dataset = ImageFolderDataset(datadir, max_dataset_size=max_dataset_size)
        ds = de.GeneratorDataset(dataset, column_names=["image", "image_name"],
                                 num_parallel_workers=num_parallel_workers)
        trans = [
            C.Resize((image_size, image_size)),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
        ds = ds.map(operations=trans, input_columns=["image"], num_parallel_workers=num_parallel_workers)
        ds = ds.batch(1, drop_remainder=True)
        ds = ds.repeat(1)
    args.dataset_size = len(dataset)
    return ds


@ -1,105 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cycle GAN datasets."""
import os
import random
import numpy as np
from PIL import Image

random.seed(1)
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']


def is_image_file(filename):
    """Judge whether it is a picture."""
    return any(filename.lower().endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir_path, max_dataset_size=float("inf")):
    """Return image list in dir."""
    images = []
    assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path

    for root, _, fnames in sorted(os.walk(dir_path)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


class UnalignedDataset:
    """
    This dataset class can load unaligned/unpaired datasets.

    Args:
        dataroot (str): Images root directory.
        phase (str): Train or test. It requires two directories in dataroot, like trainA and trainB to
            host training images from domain A '{dataroot}/trainA' and from domain B '{dataroot}/trainB' respectively.
        max_dataset_size (int): Maximum number of return image paths.

    Returns:
        Two domain image path list.
    """
    def __init__(self, dataroot, phase, max_dataset_size=float("inf"), use_random=True):
        self.dir_A = os.path.join(dataroot, phase + 'A')
        self.dir_B = os.path.join(dataroot, phase + 'B')
        self.A_paths = sorted(make_dataset(self.dir_A, max_dataset_size))   # load images from '/path/to/data/trainA'
        self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size))   # load images from '/path/to/data/trainB'
        self.A_size = len(self.A_paths)  # get the size of dataset A
        self.B_size = len(self.B_paths)  # get the size of dataset B
        self.use_random = use_random

    def __getitem__(self, index):
        index_B = index % self.B_size
        if index % max(self.A_size, self.B_size) == 0 and self.use_random:
            random.shuffle(self.A_paths)
            index_B = random.randint(0, self.B_size - 1)
        A_path = self.A_paths[index % self.A_size]
        B_path = self.B_paths[index_B]
        A_img = np.array(Image.open(A_path).convert('RGB'))
        B_img = np.array(Image.open(B_path).convert('RGB'))

        return A_img, B_img

    def __len__(self):
        return max(self.A_size, self.B_size)


class ImageFolderDataset:
    """
    This dataset class can load images from image folder.

    Args:
        dataroot (str): Images root directory.
        max_dataset_size (int): Maximum number of return image paths.

    Returns:
        Image path list.
    """
    def __init__(self, dataroot, max_dataset_size=float("inf")):
        self.dataroot = dataroot
        self.paths = sorted(make_dataset(dataroot, max_dataset_size))
        self.size = len(self.paths)

    def __getitem__(self, index):
        img_path = self.paths[index % self.size]
        img = np.array(Image.open(img_path).convert('RGB'))

        return img, os.path.split(img_path)[1]

    def __len__(self):
        return self.size


@ -1,19 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init file."""
from .args import get_args
from .reporter import Reporter
from .tools import get_lr, load_teacher_ckpt, ImagePool, load_ckpt, save_image
from .cityscapes_utils import CityScapes, fast_hist, get_scores


@ -1,95 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cityscape utils."""
import numpy as np
from PIL import Image

# label name and RGB color map.
label2color = {
    'unlabeled': (0, 0, 0),
    'ego vehicle': (0, 0, 0),
    'rectification border': (0, 0, 0),
    'out of roi': (0, 0, 0),
    'static': (0, 0, 0),
    'dynamic': (111, 74, 0),
    'ground': (81, 0, 81),
    'road': (128, 64, 128),
    'sidewalk': (244, 35, 232),
    'parking': (250, 170, 160),
    'rail track': (230, 150, 140),
    'building': (70, 70, 70),
    'wall': (102, 102, 156),
    'fence': (190, 153, 153),
    'guard rail': (180, 165, 180),
    'bridge': (150, 100, 100),
    'tunnel': (150, 120, 90),
    'pole': (153, 153, 153),
    'polegroup': (153, 153, 153),
    'traffic light': (250, 170, 30),
    'traffic sign': (220, 220, 0),
    'vegetation': (107, 142, 35),
    'terrain': (152, 251, 152),
    'sky': (70, 130, 180),
    'person': (220, 20, 60),
    'rider': (255, 0, 0),
    'car': (0, 0, 142),
    'truck': (0, 0, 70),
    'bus': (0, 60, 100),
    'caravan': (0, 0, 90),
    'trailer': (0, 0, 110),
    'train': (0, 80, 100),
    'motorcycle': (0, 0, 230),
    'bicycle': (119, 11, 32),
    'license plate': (0, 0, 142)
}


def fast_hist(a, b, n):
    k = np.where((a >= 0) & (a < n))[0]
    bc = np.bincount(n * a[k].astype(int) + b[k], minlength=n**2)
    if len(bc) != n**2:
        # ignore this example if dimension mismatch
        return 0
    return bc.reshape(n, n)


def get_scores(hist):
    # Mean pixel accuracy
    acc = np.diag(hist).sum() / (hist.sum() + 1e-12)
    # Per class accuracy
    cl_acc = np.diag(hist) / (hist.sum(1) + 1e-12)
    # Per class IoU
    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-12)
    return acc, np.nanmean(cl_acc), np.nanmean(iu), cl_acc, iu


class CityScapes:
    """CityScapes util class."""
    def __init__(self):
        self.classes = ['road', 'sidewalk', 'building', 'wall', 'fence',
                        'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain',
                        'sky', 'person', 'rider', 'car', 'truck',
                        'bus', 'train', 'motorcycle', 'bicycle', 'unlabeled']
        self.color_list = []
        for name in self.classes:
            self.color_list.append(label2color[name])
        self.class_num = len(self.classes)

    def get_id(self, img_path):
        """Get train id by img"""
        img = np.array(Image.open(img_path).convert("RGB"))
        w, h, _ = img.shape
        img_tile = np.tile(img, (1, 1, self.class_num)).reshape(w, h, self.class_num, 3)
        diff = np.abs(img_tile - self.color_list).sum(axis=-1)
        ids = diff.argmin(axis=-1)
        return ids


@ -1,84 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""prepare cityscapes dataset to cyclegan format"""
import os
import argparse
import glob
from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument('--gtFine_dir', type=str, required=True,
                    help='Path to the Cityscapes gtFine directory.')
parser.add_argument('--leftImg8bit_dir', type=str, required=True,
                    help='Path to the Cityscapes leftImg8bit_trainvaltest directory.')
parser.add_argument('--output_dir', type=str, required=True,
                    default='./cityscapes',
                    help='Directory the output images will be written to.')
opt = parser.parse_args()


def load_resized_img(path):
    """Load image with RGB and resize to (256, 256)"""
    return Image.open(path).convert('RGB').resize((256, 256))


def check_matching_pair(segmap_path, photo_path):
    """Check the segment images and photo images are matched or not."""
    segmap_identifier = os.path.basename(segmap_path).replace('_gtFine_color', '')
    photo_identifier = os.path.basename(photo_path).replace('_leftImg8bit', '')
    assert segmap_identifier == photo_identifier, \
        f"[{segmap_path}] and [{photo_path}] don't seem to be matching. Aborting."


def process_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase):
    """Process cityscapes dataset to cyclegan dataset format."""
    save_phase = 'test' if phase == 'val' else 'train'
    savedir = os.path.join(output_dir, save_phase)
    os.makedirs(savedir + 'A', exist_ok=True)
    os.makedirs(savedir + 'B', exist_ok=True)
    print(f"Directory structure prepared at {output_dir}")

    segmap_expr = os.path.join(gtFine_dir, phase) + "/*/*_color.png"
    segmap_paths = glob.glob(segmap_expr)
    segmap_paths = sorted(segmap_paths)

    photo_expr = os.path.join(leftImg8bit_dir, phase) + "/*/*_leftImg8bit.png"
    photo_paths = glob.glob(photo_expr)
    photo_paths = sorted(photo_paths)

    assert len(segmap_paths) == len(photo_paths), \
        "{} images that match [{}], and {} images that match [{}]. Aborting.".format(
            len(segmap_paths), segmap_expr, len(photo_paths), photo_expr)

    # Report progress roughly every 10%; the max() guards against a zero
    # modulus when fewer than 10 images are found.
    report_every = max(1, len(segmap_paths) // 10)
    for i, (segmap_path, photo_path) in enumerate(zip(segmap_paths, photo_paths)):
        check_matching_pair(segmap_path, photo_path)
        segmap = load_resized_img(segmap_path)
        photo = load_resized_img(photo_path)
        # data for cyclegan where the two images are stored at two distinct directories
        savepath = os.path.join(savedir + 'A', f"{i + 1}.jpg")
        photo.save(savepath)
        savepath = os.path.join(savedir + 'B', f"{i + 1}.jpg")
        segmap.save(savepath)
        if i % report_every == 0:
            print("%d / %d: last image saved at %s, " % (i, len(segmap_paths), savepath))


if __name__ == '__main__':
    print('Preparing Cityscapes Dataset for val phase')
    process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "val")
    print('Preparing Cityscapes Dataset for train phase')
    process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "train")
    print('Done')