forked from mindspore-Ecosystem/mindspore
!14637 CycleGAN
From: @xianzhu-liu Reviewed-by: @c_34,@oacjiewen Signed-off-by: @c_34
commit
d93c23626a
|
@ -0,0 +1,173 @@
|
|||
# Contents

- [CycleGAN Description](#cyclegan-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training](#training)
    - [Evaluation](#evaluation)
    - [Prediction Process](#prediction-process)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [CycleGAN Description](#contents)

Image-to-image translation is a class of vision and graphics problems whose goal is to learn the mapping from input images to output images, traditionally from a training set of paired images. For many tasks, however, paired training data is not available. CycleGAN does not require paired data: given only unpaired collections of images from two domains, it learns the mappings between them. It trains two generators, one for each translation direction, and each domain has its own discriminator.

[Paper](https://arxiv.org/abs/1703.10593): Zhu J Y, Park T, Isola P, et al. Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks[J]. 2017.


|
||||
|
||||
# [Model Architecture](#contents)

CycleGAN contains two generator networks and two discriminator networks. The generator architecture is selected with `--model`; the provided scripts use `ResNet` and `DepthResNet`.
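
The interaction between the two generators can be illustrated with a small NumPy sketch. It is not the code in `src/models/losses.py`; the function names and the toy generators are assumptions made for illustration. It follows the convention used in `eval.py`, where `G_A` maps domain A to domain B and `G_B` maps B back to A, and it uses the two loss types listed in the performance table (mean square loss for the adversarial term, L1 loss for cycle consistency).

```python
import numpy as np


def cycle_consistency_loss(real_a, real_b, g_a, g_b):
    """Mean L1 distance between each image and its round-trip reconstruction."""
    rec_a = g_b(g_a(real_a))  # A -> B -> A
    rec_b = g_a(g_b(real_b))  # B -> A -> B
    return float(np.mean(np.abs(rec_a - real_a)) + np.mean(np.abs(rec_b - real_b)))


def lsgan_loss(prediction, target_is_real):
    """Least-squares GAN loss: mean squared error against an all-ones or all-zeros target."""
    target = np.ones_like(prediction) if target_is_real else np.zeros_like(prediction)
    return float(np.mean((prediction - target) ** 2))


# Toy stand-ins for the generators (hypothetical): exact inverses of each other.
def toy_g_a(x):
    return 2.0 * x


def toy_g_b(x):
    return 0.5 * x


rng = np.random.default_rng(0)
img_a = rng.uniform(-1.0, 1.0, size=(1, 3, 4, 4)).astype(np.float32)
img_b = rng.uniform(-1.0, 1.0, size=(1, 3, 4, 4)).astype(np.float32)

cycle = cycle_consistency_loss(img_a, img_b, toy_g_a, toy_g_b)
print("cycle loss:", cycle)
print("discriminator loss on a perfect 'real' prediction:", lsgan_loss(np.ones((1, 1, 2, 2)), True))
```

Because the toy generators invert each other exactly, the cycle loss is zero; real generators only approximate this, and the cycle term is what pushes them toward it in the absence of paired data.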
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
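Download one of the CycleGAN datasets or create your own. The script `data/download_cyclegan_dataset.sh` downloads a dataset by name, for example `horse2zebra`. The dataset root should contain the subfolders `trainA`, `trainB`, `testA` and `testB`.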
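
Before training on a custom dataset, it can help to confirm the folder layout. The following is a minimal sketch, assuming the four subfolder names mentioned for the `dataroot` parameter; the helper `missing_subdirs` is hypothetical and not part of this repository.

```python
import os
import tempfile

# Sub-folders expected under the dataset root (per the `dataroot` parameter description).
EXPECTED_SUBDIRS = ("trainA", "trainB", "testA", "testB")


def missing_subdirs(dataroot):
    """Return the expected sub-folders that do not exist under dataroot."""
    return [name for name in EXPECTED_SUBDIRS
            if not os.path.isdir(os.path.join(dataroot, name))]


# Demo on a temporary directory that only has the training folders.
with tempfile.TemporaryDirectory() as root:
    for name in ("trainA", "trainB"):
        os.makedirs(os.path.join(root, name))
    demo_missing = missing_subdirs(root)

print(demo_missing)  # ['testA', 'testB']
```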
|
||||
|
||||
# [Environment Requirements](#contents)

- Hardware (Ascend/GPU)
    - Prepare a hardware environment with an Ascend or GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

## [Dependencies](#contents)

- Python==3.7.5
- MindSpore==1.1
|
||||
|
||||
# [Script Description](#contents)

## [Script and Sample Code](#contents)

The code is organized as follows:

```markdown
.CycleGAN
├─ README.md                            # descriptions about CycleGAN
├─ data
  └─ download_cyclegan_dataset.sh       # download dataset
├─ scripts
  ├─ run_train_ascend.sh                # launch ascend training(1 pcs)
  ├─ run_train_gpu.sh                   # launch gpu training(1 pcs)
  ├─ run_eval_ascend.sh                 # launch ascend eval
  └─ run_eval_gpu.sh                    # launch gpu eval
├─ imgs
  └─ objects-transfiguration.jpg        # CycleGAN Imgs
├─ src
  ├─ __init__.py                        # init file
  ├─ dataset
    ├─ __init__.py                      # init file
    ├─ cyclegan_dataset.py              # create cyclegan dataset
    └─ distributed_sampler.py           # iterator of dataset
  ├─ models
    ├─ __init__.py                      # init file
    ├─ cycle_gan.py                     # cyclegan model definition
    ├─ losses.py                        # cyclegan loss function definitions
    ├─ networks.py                      # cyclegan sub-network definitions
    ├─ resnet.py                        # resnet generator network
    └─ depth_resnet.py                  # better generator network
  └─ utils
    ├─ __init__.py                      # init file
    ├─ args.py                          # parse args
    ├─ reporter.py                      # Reporter class
    └─ tools.py                         # utils for cyclegan
├─ eval.py                              # generate images from A->B and B->A
├─ train.py                             # train script
└─ export.py                            # export mindir script
```
|
||||
|
||||
## [Script Parameters](#contents)

The major parameters of train.py, parsed in src/utils/args.py, are as follows:

```python
"platform": "Ascend"      # run platform, only GPU and Ascend are supported.
"device_id": 0            # device id, default is 0.
"model": "ResNet"         # generator model, ResNet or DepthResNet.
"pool_size": 50           # the size of the image buffer that stores previously generated images, default is 50.
"lr_policy": "linear"     # learning rate policy, default is linear.
"image_size": 256         # input image size, default is 256.
"batch_size": 1           # batch size, default is 1.
"max_epoch": 200          # number of training epochs, default is 200.
"in_planes": 3            # input channels, default is 3.
"ngf": 64                 # number of generator filters, default is 64.
"gl_num": 9               # number of generator residual blocks, default is 9.
"ndf": 64                 # number of discriminator filters, default is 64.
"dl_num": 3               # number of discriminator layers, default is 3.
"outputs_dir": "outputs"  # models are saved here, default is ./outputs.
"dataroot": None          # path of images (should have subfolders trainA, trainB, testA, testB, etc).
"load_ckpt": False        # whether to load a pretrained ckpt.
"G_A_ckpt": None          # pretrained checkpoint file path of G_A.
"G_B_ckpt": None          # pretrained checkpoint file path of G_B.
"D_A_ckpt": None          # pretrained checkpoint file path of D_A.
"D_B_ckpt": None          # pretrained checkpoint file path of D_B.
```
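
As a rough illustration of how a few of the options listed above map to command-line flags, here is a minimal `argparse` sketch. It is an assumption about the general shape of such a parser, not the contents of `src/utils/args.py`; the function name `build_parser` and the `choices` restrictions are hypothetical.

```python
import argparse


def build_parser():
    """Hypothetical parser covering a subset of the documented options."""
    parser = argparse.ArgumentParser(description="CycleGAN options (illustrative sketch)")
    parser.add_argument("--platform", type=str, default="Ascend", choices=("Ascend", "GPU"),
                        help="run platform")
    parser.add_argument("--device_id", type=int, default=0, help="device id")
    parser.add_argument("--model", type=str, default="ResNet", choices=("ResNet", "DepthResNet"),
                        help="generator model")
    parser.add_argument("--max_epoch", type=int, default=200, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument("--image_size", type=int, default=256, help="input image size")
    parser.add_argument("--dataroot", type=str, default=None, help="dataset root directory")
    parser.add_argument("--outputs_dir", type=str, default="outputs", help="output directory")
    return parser


# Parse the same flags that scripts/run_train_gpu.sh passes to train.py.
demo_args = build_parser().parse_args([
    "--platform", "GPU", "--device_id", "0", "--model", "ResNet",
    "--max_epoch", "200", "--dataroot", "./data/horse2zebra/", "--outputs_dir", "./outputs",
])
print(demo_args.platform, demo_args.model, demo_args.max_epoch, demo_args.dataroot)
```

Options that are not passed, such as `--batch_size` and `--image_size` here, fall back to the documented defaults.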
|
||||
|
||||
## [Training](#contents)

- Running on Ascend with default parameters

```bash
sh ./scripts/run_train_ascend.sh
```

- Running on GPU with default parameters

```bash
sh ./scripts/run_train_gpu.sh
```
|
||||
|
||||
## [Evaluation](#contents)

```bash
python eval.py --platform [PLATFORM] --dataroot [DATA_PATH] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKPT]
```

**Note: The generated images are saved in the `predict` subdirectory of `outputs_dir` (`./outputs/predict` by default).**
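
The layout of the results can be sketched as follows, based on how `eval.py` builds its paths: translated images go to `predict/fake_B` or `predict/fake_A` with a `_fake_B.jpg` or `_fake_A.jpg` suffix, and the source image is saved next to them under its original name. The helpers below are hypothetical. Note one deliberate difference: `eval.py` strips the extension with `path[0:-4]`, which assumes a three-character extension, while this sketch uses `os.path.splitext` instead.

```python
import os


def fake_name(image_name, target_domain):
    """Build the translated image's file name, e.g. x.jpg -> x_fake_B.jpg."""
    stem, _ = os.path.splitext(image_name)
    return f"{stem}_fake_{target_domain}.jpg"


def output_paths(outputs_dir, image_name, target_domain):
    """Return (path of the saved source image, path of the translated image)."""
    folder = os.path.join(outputs_dir, "predict", f"fake_{target_domain}")
    return (os.path.join(folder, image_name),
            os.path.join(folder, fake_name(image_name, target_domain)))


source_path, translated_path = output_paths("outputs", "n02381460_1000.jpg", "B")
print(source_path)
print(translated_path)
```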
|
||||
|
||||
# [Model Description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters                 | single Ascend/GPU                               |
| -------------------------- | ----------------------------------------------- |
| Model Version              | CycleGAN                                        |
| Resource                   | Ascend 910/NV SXM2 V100-32G                     |
| MindSpore Version          | 1.1                                             |
| Dataset                    | horse2zebra                                     |
| Training Parameters        | epoch=200, steps=1334, batch_size=1, lr=0.002   |
| Optimizer                  | Adam                                            |
| Loss Function              | Mean Square Loss & L1 Loss                      |
| outputs                    | probability                                     |
| Speed                      | 1pc(Ascend): 123 ms/step; 1pc(GPU): 264 ms/step |
| Total time                 | 1pc(Ascend): 9.6h; 1pc(GPU): 19.1h              |
| Checkpoint for Fine tuning | 44M (.ckpt file)                                |
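
The speed and total-time rows can be cross-checked with simple arithmetic. The sketch below assumes that `steps=1334` means steps per epoch and that the per-step time is constant over the run; both are assumptions, not statements from the table.

```python
EPOCHS = 200
STEPS_PER_EPOCH = 1334  # assumed to be steps per epoch
MS_PER_STEP = {"Ascend": 123, "GPU": 264}


def estimated_hours(ms_per_step, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH):
    """Total training time in hours, from the per-step time in milliseconds."""
    total_ms = epochs * steps_per_epoch * ms_per_step
    return total_ms / 1000 / 3600


for device, ms in MS_PER_STEP.items():
    print(f"{device}: {estimated_hours(ms):.1f} h")
# Ascend: 9.1 h
# GPU: 19.6 h
```

Under these assumptions the estimates are about 9.1 h and 19.6 h, close to but not identical with the reported 9.6 h and 19.1 h, so the table figures are best read as approximate.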
|
||||
|
||||
### Evaluation Performance

| Parameters        | single Ascend/GPU           |
| ----------------- | --------------------------- |
| Model Version     | CycleGAN                    |
| Resource          | Ascend 910/NV SXM2 V100-32G |
| MindSpore Version | 1.1                         |
| Dataset           | horse2zebra                 |
| batch_size        | 1                           |
| outputs           | probability                 |
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,23 @@
|
|||
#!/bin/bash
|
||||
|
||||
FILE=$1
|
||||
|
||||
if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "mini" && $FILE != "mini_pix2pix" && $FILE != "mini_colorization" ]]; then
|
||||
echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ $FILE == "cityscapes" ]]; then
|
||||
echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py."
|
||||
echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Specified [$FILE]"
|
||||
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
|
||||
ZIP_FILE=./datasets/$FILE.zip
|
||||
TARGET_DIR=./datasets/$FILE/
|
||||
wget -N $URL -O $ZIP_FILE
|
||||
mkdir $TARGET_DIR
|
||||
unzip $ZIP_FILE -d ./datasets/
|
||||
rm $ZIP_FILE
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,25 +12,26 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Cycle GAN predict."""
|
||||
|
||||
"""Cycle GAN test."""
|
||||
|
||||
import os
|
||||
from mindspore import Tensor
|
||||
from src.models.cycle_gan import get_generator
|
||||
from src.utils.args import get_args
|
||||
from src.dataset.cyclegan_dataset import create_dataset
|
||||
from src.utils.reporter import Reporter
|
||||
from src.utils.tools import save_image, load_ckpt
|
||||
|
||||
from src.models import get_generator
|
||||
from src.utils import get_args, load_ckpt, save_image, Reporter
|
||||
from src.dataset import create_dataset
|
||||
|
||||
def predict():
|
||||
"""Predict function."""
|
||||
args = get_args("predict")
|
||||
G_A = get_generator(args)
|
||||
G_B = get_generator(args)
|
||||
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
|
||||
# Use real mean and variance rather than moving_mean and moving_variance in BatchNorm2d
|
||||
G_A.set_train(True)
|
||||
G_B.set_train(True)
|
||||
load_ckpt(args, G_A, G_B)
|
||||
|
||||
imgs_out = os.path.join(args.outputs_dir, "predict")
|
||||
if not os.path.exists(imgs_out):
|
||||
os.makedirs(imgs_out)
|
||||
|
@ -45,8 +46,10 @@ def predict():
|
|||
for data in ds.create_dict_iterator(output_numpy=True):
|
||||
img_A = Tensor(data["image"])
|
||||
path_A = str(data["image_name"][0], encoding="utf-8")
|
||||
path_B = path_A[0:-4] + "_fake_B.jpg"
|
||||
fake_B = G_A(img_A)
|
||||
save_image(fake_B, os.path.join(imgs_out, "fake_B", path_A))
|
||||
save_image(fake_B, os.path.join(imgs_out, "fake_B", path_B))
|
||||
save_image(img_A, os.path.join(imgs_out, "fake_B", path_A))
|
||||
reporter.info('save fake_B at %s', os.path.join(imgs_out, "fake_B", path_A))
|
||||
reporter.end_predict()
|
||||
args.data_dir = 'testB'
|
||||
|
@ -56,10 +59,13 @@ def predict():
|
|||
for data in ds.create_dict_iterator(output_numpy=True):
|
||||
img_B = Tensor(data["image"])
|
||||
path_B = str(data["image_name"][0], encoding="utf-8")
|
||||
path_A = path_B[0:-4] + "_fake_A.jpg"
|
||||
fake_A = G_B(img_B)
|
||||
save_image(fake_A, os.path.join(imgs_out, "fake_A", path_B))
|
||||
save_image(fake_A, os.path.join(imgs_out, "fake_A", path_A))
|
||||
save_image(img_B, os.path.join(imgs_out, "fake_A", path_B))
|
||||
reporter.info('save fake_A at %s', os.path.join(imgs_out, "fake_A", path_B))
|
||||
reporter.end_predict()
|
||||
|
||||
if __name__ == "__main__":
|
||||
predict()
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,28 +12,36 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export file."""
|
||||
import numpy as np
|
||||
|
||||
"""export file."""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import export
|
||||
from src.models import get_generator
|
||||
from src.utils import get_args, load_ckpt
|
||||
from src.models.cycle_gan import get_generator
|
||||
from src.utils.args import get_args
|
||||
from src.utils.tools import load_ckpt
|
||||
|
||||
args = get_args("export")
|
||||
model_args = get_args("export")
|
||||
parser = argparse.ArgumentParser(description="CycleGAN export")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--file_name", type=str, default="CycleGAN", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=model_args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
G_A = get_generator(args)
|
||||
G_B = get_generator(args)
|
||||
G_A = get_generator(model_args)
|
||||
G_B = get_generator(model_args)
|
||||
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
|
||||
# Use real mean and variance rather than moving_mean and moving_variance in BatchNorm2d
|
||||
G_A.set_train(True)
|
||||
G_B.set_train(True)
|
||||
load_ckpt(args, G_A, G_B)
|
||||
load_ckpt(model_args, G_A, G_B)
|
||||
|
||||
input_shp = [1, 3, args.image_size, args.image_size]
|
||||
input_shp = [args.batch_size, 3, model_args.image_size, model_args.image_size]
|
||||
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
|
||||
G_A_file = f"{args.file_name}_BtoA"
|
||||
export(G_A, input_array, file_name=G_A_file, file_format=args.file_format)
|
Binary file not shown. Size: 681 KiB
|
@ -1,4 +1,5 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -11,7 +12,5 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""init file."""
|
||||
from .datasets import UnalignedDataset, ImageFolderDataset, make_dataset
|
||||
from .cyclegan_dataset import create_dataset
|
||||
|
||||
python eval.py --platform Ascend --device_id 0 --model DepthResNet --G_A_ckpt ./outputs/ckpt/G_A_200.ckpt --G_B_ckpt ./outputs/ckpt/G_B_200.ckpt
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -11,8 +12,5 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""init file."""
|
||||
from .cycle_gan import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD
|
||||
from .losses import DiscriminatorLoss, GeneratorLoss, GANLoss
|
||||
from .networks import init_weights
|
||||
|
||||
python eval.py --platform GPU --device_id 0 --model ResNet --G_A_ckpt ./outputs/ckpt/G_A_200.ckpt --G_B_ckpt ./outputs/ckpt/G_B_200.ckpt
|
|
@ -0,0 +1,22 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the train as: "
|
||||
echo "python train.py device_id platform model max_epoch dataroot outputs_dir"
|
||||
echo "for example: python train.py --platform Ascend --device_id 0 --model ResNet --max_epoch 200 --dataroot ./data/horse2zebra/ --outputs_dir ./outputs"
|
||||
echo "================================================================================================================="
|
||||
|
||||
python train.py --platform Ascend --device_id 0 --model DepthResNet --max_epoch 200 --dataroot ./data/horse2zebra/ --outputs_dir ./outputs
|
|
@ -0,0 +1,22 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the train as: "
|
||||
echo "python train.py device_id platform model max_epoch dataroot outputs_dir"
|
||||
echo "for example: python train.py --platform GPU --device_id 0 --model ResNet --max_epoch 200 --dataroot ./data/horse2zebra/ --outputs_dir ./outputs"
|
||||
echo "================================================================================================================="
|
||||
|
||||
python train.py --platform GPU --device_id 0 --model ResNet --max_epoch 200 --dataroot ./data/horse2zebra/ --outputs_dir ./outputs
|
|
@ -0,0 +1,183 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Cycle GAN dataset."""
|
||||
|
||||
import os
|
||||
import random
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
from .distributed_sampler import DistributedSampler
|
||||
|
||||
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']
|
||||
|
||||
def is_image_file(filename):
|
||||
"""Judge whether it is a picture."""
|
||||
return any(filename.lower().endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
|
||||
def make_dataset(dir_path, max_dataset_size=float("inf")):
|
||||
"""Return image list in dir."""
|
||||
images = []
|
||||
assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path
|
||||
|
||||
for root, _, fnames in sorted(os.walk(dir_path)):
|
||||
for fname in fnames:
|
||||
if is_image_file(fname):
|
||||
path = os.path.join(root, fname)
|
||||
images.append(path)
|
||||
return images[:min(max_dataset_size, len(images))]
|
||||
|
||||
|
||||
class UnalignedDataset:
|
||||
"""
|
||||
This dataset class can load unaligned/unpaired datasets.
|
||||
It requires two directories to host training images from domain A '/path/to/data/trainA'
|
||||
and from domain B '/path/to/data/trainB' respectively.
|
||||
You can train the model with the dataset flag '--dataroot /path/to/data'.
|
||||
Similarly, you need to prepare two directories:
|
||||
'/path/to/data/testA' and '/path/to/data/testB' during test time.
|
||||
Returns:
|
||||
Two domain image path list.
|
||||
"""
|
||||
def __init__(self, dataroot, phase, max_dataset_size=float("inf"), use_random=True):
|
||||
self.dir_A = os.path.join(dataroot, phase + 'A')
|
||||
self.dir_B = os.path.join(dataroot, phase + 'B')
|
||||
|
||||
self.A_paths = sorted(make_dataset(self.dir_A, max_dataset_size)) # load images from '/path/to/data/trainA'
|
||||
self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size)) # load images from '/path/to/data/trainB'
|
||||
self.A_size = len(self.A_paths) # get the size of dataset A
|
||||
self.B_size = len(self.B_paths) # get the size of dataset B
|
||||
self.use_random = use_random
|
||||
|
||||
def __getitem__(self, index):
"""Return one unpaired sample from the two domains.

Parameters:
index (int) -- a random integer for data indexing

Returns a tuple (A_img, B_img)
A_img (numpy.ndarray) -- an RGB image in the input domain
B_img (numpy.ndarray) -- an RGB image in the target domain, not paired with A_img
"""
|
||||
index_B = index % self.B_size
|
||||
if index % max(self.A_size, self.B_size) == 0 and self.use_random:
|
||||
random.shuffle(self.A_paths)
|
||||
index_B = random.randint(0, self.B_size - 1)
|
||||
A_path = self.A_paths[index % self.A_size]
|
||||
B_path = self.B_paths[index_B]
|
||||
A_img = np.array(Image.open(A_path).convert('RGB'))
|
||||
B_img = np.array(Image.open(B_path).convert('RGB'))
|
||||
|
||||
return A_img, B_img
|
||||
|
||||
def __len__(self):
"""Return the total number of images in the dataset.

As the two domains may contain different numbers of images,
the larger of the two sizes is used.
"""
|
||||
return max(self.A_size, self.B_size)
|
||||
|
||||
|
||||
class ImageFolderDataset:
|
||||
"""
|
||||
This dataset class can load images from image folder.
|
||||
Args:
|
||||
dataroot (str): Images root directory.
|
||||
max_dataset_size (int): Maximum number of return image paths.
|
||||
Returns:
|
||||
Image path list.
|
||||
"""
|
||||
def __init__(self, dataroot, max_dataset_size=float("inf")):
|
||||
self.dataroot = dataroot
|
||||
self.paths = sorted(make_dataset(dataroot, max_dataset_size))
|
||||
self.size = len(self.paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_path = self.paths[index % self.size]
|
||||
img = np.array(Image.open(img_path).convert('RGB'))
|
||||
|
||||
return img, os.path.split(img_path)[1]
|
||||
|
||||
def __len__(self):
"""Return the total number of images in the dataset."""
|
||||
return self.size
|
||||
|
||||
|
||||
def create_dataset(args):
"""
Create the dataset for training or testing.

Args:
args: Parsed arguments; uses dataroot, phase, batch_size, device_num, rank,
use_random, max_dataset_size and image_size (and data_dir when not training).

Returns:
A MindSpore dataset of normalized CHW images. args.dataset_size is set as a side effect.
"""
|
||||
dataroot = args.dataroot
|
||||
phase = args.phase
|
||||
batch_size = args.batch_size
|
||||
device_num = args.device_num
|
||||
rank = args.rank
|
||||
shuffle = args.use_random
|
||||
max_dataset_size = args.max_dataset_size
|
||||
cores = multiprocessing.cpu_count()
|
||||
num_parallel_workers = min(8, int(cores / device_num))
|
||||
image_size = args.image_size
|
||||
mean = [0.5 * 255] * 3
|
||||
std = [0.5 * 255] * 3
|
||||
if phase == "train":
|
||||
dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size, use_random=args.use_random)
|
||||
distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
|
||||
ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
|
||||
sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
|
||||
if args.use_random:
|
||||
trans = [
|
||||
C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
else:
|
||||
trans = [
|
||||
C.Resize((image_size, image_size)),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(1)
|
||||
else:
|
||||
datadir = os.path.join(dataroot, args.data_dir)
|
||||
dataset = ImageFolderDataset(datadir, max_dataset_size=max_dataset_size)
|
||||
ds = de.GeneratorDataset(dataset, column_names=["image", "image_name"],
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
trans = [
|
||||
C.Resize((image_size, image_size)),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
ds = ds.map(operations=trans, input_columns=["image"], num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.batch(1, drop_remainder=True)
|
||||
ds = ds.repeat(1)
|
||||
args.dataset_size = len(dataset)
|
||||
return ds
|
||||
|
|
@ -1,60 +1,61 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dataset distributed sampler."""
|
||||
from __future__ import division
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DistributedSampler:
|
||||
"""Distributed sampler."""
|
||||
def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
|
||||
if num_replicas is None:
|
||||
print("***********Setting world_size to 1 since it is not passed in ******************")
|
||||
num_replicas = 1
|
||||
if rank is None:
|
||||
print("***********Setting rank to 0 since it is not passed in ******************")
|
||||
rank = 0
|
||||
self.dataset_size = dataset_size
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
|
||||
def __iter__(self):
|
||||
# deterministically shuffle based on epoch
|
||||
if self.shuffle:
|
||||
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
|
||||
# np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
|
||||
indices = indices.tolist()
|
||||
self.epoch += 1
|
||||
# change to list type
|
||||
else:
|
||||
indices = list(range(self.dataset_size))
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Dataset distributed sampler."""
|
||||
from __future__ import division
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DistributedSampler:
|
||||
"""Distributed sampler."""
|
||||
def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
|
||||
if num_replicas is None:
|
||||
print("***********Setting world_size to 1 since it is not passed in ******************")
|
||||
num_replicas = 1
|
||||
if rank is None:
|
||||
print("***********Setting rank to 0 since it is not passed in ******************")
|
||||
rank = 0
|
||||
self.dataset_size = dataset_size
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
|
||||
def __iter__(self):
|
||||
# deterministically shuffle based on epoch
|
||||
if self.shuffle:
|
||||
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
|
||||
# np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
|
||||
indices = indices.tolist()
|
||||
self.epoch += 1
|
||||
# change to list type
|
||||
else:
|
||||
indices = list(range(self.dataset_size))
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Cycle GAN network."""
|
||||
|
||||
import mindspore as ms
|
||||
|
@ -21,40 +22,40 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.communication.management import get_group_size
|
||||
import mindspore.ops as ops
|
||||
from .resnet import ResNetGenerator
|
||||
from .networks import ConvNormReLU, init_weights
|
||||
from .resnet import ResNetGenerator
|
||||
from .depth_resnet import DepthResNetGenerator
|
||||
from .unet import UnetGenerator
|
||||
|
||||
def get_generator(args, teacher_net=False):
|
||||
"""Return generator by args."""
|
||||
if teacher_net:
|
||||
if args.model == "resnet":
|
||||
net = ResNetGenerator(in_planes=args.in_planes, ngf=args.t_ngf, n_layers=args.t_gl_num,
|
||||
alpha=args.t_slope, norm_mode=args.t_norm_mode, dropout=False,
|
||||
pad_mode=args.pad_mode)
|
||||
init_weights(net, args.init_type, args.init_gain)
|
||||
elif args.model == "unet":
|
||||
net = UnetGenerator(in_planes=args.in_planes, out_planes=args.in_planes, ngf=args.t_ngf,
|
||||
n_layers=args.t_gl_num, alpha=args.t_slope, norm_mode=args.t_norm_mode,
|
||||
dropout=False)
|
||||
init_weights(net, args.init_type, args.init_gain)
|
||||
else:
|
||||
raise NotImplementedError(f'Model {args.model} not recognized.')
|
||||
def get_generator(args):
|
||||
"""
|
||||
This class implements the CycleGAN model, for learning image-to-image translation without paired data.
|
||||
|
||||
The model training requires '--dataset_mode unaligned' dataset.
|
||||
By default, it uses a '--netG resnet_9blocks' ResNet generator,
|
||||
a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
|
||||
and a least-square GANs objective ('--gan_mode lsgan').
|
||||
"""
|
||||
if args.model == "ResNet":
|
||||
net = ResNetGenerator(in_planes=args.in_planes, ngf=args.ngf, n_layers=args.gl_num,
|
||||
alpha=args.slope, norm_mode=args.norm_mode, dropout=args.need_dropout,
|
||||
pad_mode=args.pad_mode)
|
||||
init_weights(net, args.init_type, args.init_gain)
|
||||
elif args.model == "DepthResNet":
|
||||
net = DepthResNetGenerator(in_planes=args.in_planes, ngf=args.ngf, n_layers=args.gl_num,
|
||||
alpha=args.slope, norm_mode=args.norm_mode, dropout=args.need_dropout,
|
||||
pad_mode=args.pad_mode)
|
||||
init_weights(net, args.init_type, args.init_gain)
|
||||
elif args.model == "UNet":
|
||||
net = UnetGenerator(in_planes=args.in_planes, out_planes=args.in_planes, ngf=args.ngf, n_layers=args.gl_num,
|
||||
alpha=args.slope, norm_mode=args.norm_mode, dropout=args.need_dropout)
|
||||
init_weights(net, args.init_type, args.init_gain)
|
||||
else:
|
||||
if args.model == "resnet":
|
||||
net = ResNetGenerator(in_planes=args.in_planes, ngf=args.ngf, n_layers=args.gl_num,
|
||||
alpha=args.slope, norm_mode=args.norm_mode, dropout=args.need_dropout,
|
||||
pad_mode=args.pad_mode)
|
||||
init_weights(net, args.init_type, args.init_gain)
|
||||
elif args.model == "unet":
|
||||
net = UnetGenerator(in_planes=args.in_planes, out_planes=args.in_planes, ngf=args.ngf, n_layers=args.gl_num,
|
||||
alpha=args.slope, norm_mode=args.norm_mode, dropout=args.need_dropout)
|
||||
init_weights(net, args.init_type, args.init_gain)
|
||||
else:
|
||||
raise NotImplementedError(f'Model {args.model} not recognized.')
|
||||
raise NotImplementedError(f'Model {args.model} not recognized.')
|
||||
return net
|
||||
|
||||
def get_discriminator(args, teacher_net=False):
|
||||
|
||||
def get_discriminator(args):
|
||||
"""Return discriminator by args."""
|
||||
net = Discriminator(in_planes=args.in_planes, ndf=args.ndf, n_layers=args.dl_num,
|
||||
alpha=args.slope, norm_mode=args.norm_mode)
|
||||
|
@ -65,17 +66,14 @@ def get_discriminator(args, teacher_net=False):
|
|||
class Discriminator(nn.Cell):
|
||||
"""
|
||||
Discriminator of GAN.
|
||||
|
||||
Args:
|
||||
in_planes (int): Input channel.
|
||||
ndf (int): Output channel.
|
||||
n_layers (int): The number of ConvNormReLU blocks.
|
||||
alpha (float): LeakyRelu slope. Default: 0.2.
|
||||
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> Discriminator(3, 64, 3)
|
||||
"""
|
||||
|
@ -105,15 +103,12 @@ class Discriminator(nn.Cell):
|
|||
class Generator(nn.Cell):
|
||||
"""
|
||||
Generator of CycleGAN, return fake_A, fake_B, rec_A, rec_B, identity_A and identity_B.
|
||||
|
||||
Args:
|
||||
G_A (Cell): The generator network of domain A to domain B.
|
||||
G_B (Cell): The generator network of domain B to domain A.
|
||||
use_identity (bool): Use identity loss or not. Default: True.
|
||||
|
||||
Returns:
|
||||
Tensors, fake_A, fake_B, rec_A, rec_B, identity_A and identity_B.
|
||||
|
||||
Examples:
|
||||
>>> Generator(G_A, G_B)
|
||||
"""
|
||||
|
@ -142,7 +137,6 @@ class Generator(nn.Cell):
|
|||
class WithLossCell(nn.Cell):
|
||||
"""
|
||||
Wrap the network with loss function to return generator loss.
|
||||
|
||||
Args:
|
||||
network (Cell): The target network to wrap.
|
||||
"""
|
||||
|
@ -158,10 +152,8 @@ class WithLossCell(nn.Cell):
|
|||
class TrainOneStepG(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of Cycle GAN generator network training.
|
||||
|
||||
Append an optimizer to the training network, after which the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
G (Cell): Generator with loss Cell. Note that loss function should have been added.
|
||||
generator (Cell): Generator of CycleGAN.
|
||||
|
@ -210,10 +202,8 @@ class TrainOneStepG(nn.Cell):
|
|||
class TrainOneStepD(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of Cycle GAN discriminator network training.
|
||||
|
||||
Append an optimizer to the training network, after which the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
G (Cell): Generator with loss Cell. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
|
@ -0,0 +1,113 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ResNet Generator."""
|
||||
|
||||
import mindspore.nn as nn
from .networks import ConvNormReLU, ConvTransposeNormReLU


class ResidualBlock(nn.Cell):
    """
    ResNet residual block definition.
    The block applies two ConvNormReLU layers, with optional dropout between them,
    and adds the skip connection in the construct function.
    Args:
        dim (int): Input and output channel.
        norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
        dropout (bool): Use dropout or not. Default: False.
        pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
                        Default: "CONSTANT".
    Returns:
        Tensor, output tensor.
    """
    def __init__(self, dim, norm_mode='batch', dropout=False, pad_mode="CONSTANT"):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0.2, norm_mode, pad_mode)
        self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0.2, norm_mode, pad_mode)
        self.dropout = dropout
        if dropout:
            self.dropout = nn.Dropout(0.5)

    def construct(self, x):
        out = self.conv1(x)
        if self.dropout:
            out = self.dropout(out)
        out = self.conv2(out)
        return x + out


class DepthResNetGenerator(nn.Cell):
    """
    ResNet Generator of GAN.
    Args:
        in_planes (int): Input channel.
        ngf (int): Output channel.
        n_layers (int): The number of ConvNormReLU blocks.
        alpha (float): LeakyRelu slope. Default: 0.2.
        norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
        dropout (bool): Use dropout or not. Default: False.
        pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
                        Default: "CONSTANT".
    Returns:
        Tensor, output tensor.
    """
    def __init__(self, in_planes=3, ngf=64, n_layers=9, alpha=0.2, norm_mode='batch', dropout=False,
                 pad_mode="CONSTANT"):
        super(DepthResNetGenerator, self).__init__()
        conv_in1 = nn.Conv2d(in_planes, ngf, kernel_size=3, stride=1, has_bias=True)
        conv_in2 = ConvNormReLU(ngf, ngf, 7, 1, alpha, norm_mode, pad_mode=pad_mode)
        self.conv_in = nn.SequentialCell([conv_in1, conv_in2])
        down_1 = ConvNormReLU(ngf, ngf * 2, 3, 2, alpha, norm_mode)
        Res1 = ResidualBlock(ngf * 2, norm_mode, dropout=dropout, pad_mode=pad_mode)
        self.down_1 = nn.SequentialCell([down_1, Res1])
        down_2 = ConvNormReLU(ngf * 2, ngf * 3, 3, 2, alpha, norm_mode)
        Res2 = ResidualBlock(ngf * 3, norm_mode, dropout=dropout, pad_mode=pad_mode)
        self.down_2 = nn.SequentialCell([down_2, Res2])
        self.down_3 = ConvNormReLU(ngf * 3, ngf * 4, 3, 2, alpha, norm_mode)
        layers = [ResidualBlock(ngf * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * (n_layers-5)
        self.residuals = nn.SequentialCell(layers)
        up_3 = ConvTransposeNormReLU(ngf * 4, ngf * 3, 3, 2, alpha, norm_mode)
        Res3 = ResidualBlock(ngf * 3, norm_mode, dropout=dropout, pad_mode=pad_mode)
        self.up_3 = nn.SequentialCell([up_3, Res3])
        up_2 = ConvTransposeNormReLU(ngf * 3, ngf * 2, 3, 2, alpha, norm_mode)
        Res4 = ResidualBlock(ngf * 2, norm_mode, dropout=dropout, pad_mode=pad_mode)
        self.up_2 = nn.SequentialCell([up_2, Res4])
        up_1 = ConvTransposeNormReLU(ngf * 2, ngf, 3, 2, alpha, norm_mode)
        Res5 = ResidualBlock(ngf, norm_mode, dropout=dropout, pad_mode=pad_mode)
        self.up_1 = nn.SequentialCell([up_1, Res5])
        tanh = nn.Tanh()
        if pad_mode == "CONSTANT":
            conv_out1 = nn.Conv2d(ngf, 3, kernel_size=7, stride=1, has_bias=True, pad_mode='pad', padding=3)
            conv_out2 = nn.Conv2d(3, 3, kernel_size=3, stride=1, has_bias=True)
            self.conv_out = nn.SequentialCell([conv_out1, tanh, conv_out2, tanh])
        else:
            pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)
            conv = nn.Conv2d(ngf, 3, kernel_size=7, stride=1, pad_mode='pad')
            self.conv_out = nn.SequentialCell([pad, conv, tanh])

    def construct(self, x):
        """ construct network """
        x = self.conv_in(x)
        x = self.down_1(x)
        x = self.down_2(x)
        x = self.down_3(x)
        x = self.residuals(x)
        x = self.up_3(x)
        x = self.up_2(x)
        x = self.up_1(x)
        output = self.conv_out(x)
        return output
|
||||
|
|
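One detail in `DepthResNetGenerator` that may be worth a second look is the `[ResidualBlock(...)] * (n_layers-5)` expression. In plain Python, multiplying a list repeats references to the same object rather than creating copies. The sketch below shows that behaviour with a hypothetical `Block` class standing in for a layer. It makes no claim about how `nn.SequentialCell` treats repeated cells, and the source does not say whether sharing one block across the bottleneck is intended.

```python
class Block:
    """Hypothetical stand-in for a layer that owns its own weights."""

    def __init__(self):
        self.weight = [0.0]


# List multiplication: one Block object referenced four times.
repeated = [Block()] * 4
# List comprehension: four separately constructed Block objects.
independent = [Block() for _ in range(4)]

# Mutating the first entry shows which lists share state.
repeated[0].weight[0] = 1.0
independent[0].weight[0] = 1.0

shared_ids = {id(block) for block in repeated}
separate_ids = {id(block) for block in independent}
```

If independent residual blocks are wanted, the list-comprehension form is the usual way to build them; if shared weights are intended, the multiplication form expresses that directly.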
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,24 +12,22 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Cycle GAN losses"""
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor
|
||||
from .cycle_gan import get_generator
|
||||
from ..utils import load_teacher_ckpt
|
||||
|
||||
|
||||
class BCEWithLogits(nn.Cell):
|
||||
"""
|
||||
BCEWithLogits creates a criterion to measure the Binary Cross Entropy between the true labels and
|
||||
predicted labels with sigmoid logits.
|
||||
|
||||
Args:
|
||||
reduction (str): Specifies the reduction to be applied to the output.
|
||||
Its value must be one of 'none', 'mean', 'sum'. Default: 'none'.
|
||||
|
||||
Outputs:
|
||||
Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `inputs`.
|
||||
Otherwise, the output is a scalar.
|
||||
|
@ -58,13 +56,19 @@ class BCEWithLogits(nn.Cell):
|
|||
|
||||
class GANLoss(nn.Cell):
|
||||
"""
|
||||
Cycle GAN loss factory.
|
||||
|
||||
The GANLoss class abstracts away the need to create the target label tensor
|
||||
that has the same size as the input.
|
||||
Args:
|
||||
mode (str): The type of GAN objective. It currently supports 'vanilla', 'lsgan'. Default: 'lsgan'.
|
||||
reduction (str): Specifies the reduction to be applied to the output.
|
||||
Its value must be one of 'none', 'mean', 'sum'. Default: 'none'.
|
||||
Parameters:
|
||||
gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
|
||||
target_real_label (bool) - - label for a real image
|
||||
target_fake_label (bool) - - label of a fake image
|
||||
|
||||
Note: Do not use sigmoid as the last layer of Discriminator.
|
||||
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
|
||||
Outputs:
|
||||
Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `inputs`.
|
||||
Otherwise, the output is a scalar.
|
||||
|
@ -90,13 +94,11 @@ class GANLoss(nn.Cell):
|
|||
class GeneratorLoss(nn.Cell):
|
||||
"""
|
||||
Cycle GAN generator loss.
|
||||
|
||||
Args:
|
||||
args (class): Option class.
|
||||
generator (Cell): Generator of CycleGAN.
|
||||
D_A (Cell): The discriminator network of domain A to domain B.
|
||||
D_B (Cell): The discriminator network of domain B to domain A.
|
||||
|
||||
Outputs:
|
||||
Tuple Tensor, the losses of generator.
|
||||
"""
|
||||
|
@ -112,14 +114,6 @@ class GeneratorLoss(nn.Cell):
|
|||
self.D_A = D_A
|
||||
self.D_B = D_B
|
||||
self.true = Tensor(True, mstype.bool_)
|
||||
self.kd = args.kd
|
||||
if self.kd:
|
||||
self.GT_A = get_generator(args, True)
|
||||
load_teacher_ckpt(self.GT_A, args.GT_A_ckpt, "GT_A", "G_A")
|
||||
self.GT_B = get_generator(args, True)
|
||||
load_teacher_ckpt(self.GT_B, args.GT_B_ckpt, "GT_B", "G_B")
|
||||
self.GT_A.set_train(True)
|
||||
self.GT_B.set_train(True)
|
||||
|
||||
def construct(self, img_A, img_B):
|
||||
"""If use_identity, identity loss will be used."""
|
||||
|
@ -135,23 +129,18 @@ class GeneratorLoss(nn.Cell):
|
|||
loss_idt_A = 0
|
||||
loss_idt_B = 0
|
||||
loss_G = loss_G_A + loss_G_B + loss_C_A + loss_C_B + loss_idt_A + loss_idt_B
|
||||
if self.kd:
|
||||
teacher_A = self.GT_B(img_B)
|
||||
teacher_B = self.GT_A(img_A)
|
||||
kd_loss_A = self.rec_loss(teacher_A, fake_A) * self.lambda_A * 5
|
||||
kd_loss_B = self.rec_loss(teacher_B, fake_B) * self.lambda_A * 5
|
||||
loss_G += kd_loss_A + kd_loss_B
|
||||
return (fake_A, fake_B, loss_G, loss_G_A, loss_G_B, loss_C_A, loss_C_B, loss_idt_A, loss_idt_B)
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Cell):
|
||||
"""
|
||||
Cycle GAN discriminator loss.
|
||||
|
||||
Args:
|
||||
args (class): option class.
|
||||
D_A (Cell): The discriminator network of domain A to domain B.
|
||||
D_B (Cell): The discriminator network of domain B to domain A.
|
||||
|
||||
real (tensor array) -- real images
|
||||
fake (tensor array) -- images generated by a generator
|
||||
Outputs:
|
||||
Tuple Tensor, the loss of discriminator.
|
||||
"""
|
|
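The `GANLoss` docstring above distinguishes the `lsgan` and `vanilla` objectives and notes that the discriminator should not end in a sigmoid. A minimal pure-Python sketch of the two objectives follows. The function names are hypothetical, and the target labels of 1.0 (real) and 0.0 (fake) and the mean reduction are assumptions; the diff does not show the actual values used by `GANLoss`.

```python
import math


def lsgan_loss(predictions, target_is_real):
    """Least-squares objective: mean squared error against a constant label."""
    target = 1.0 if target_is_real else 0.0  # assumed label values
    return sum((p - target) ** 2 for p in predictions) / len(predictions)


def vanilla_loss(logits, target_is_real):
    """Binary cross-entropy on raw logits, in a numerically stable form."""
    target = 1.0 if target_is_real else 0.0  # assumed label values
    total = 0.0
    for x in logits:
        # max(x, 0) - x*t + log(1 + exp(-|x|)) equals
        # -[t*log(s) + (1 - t)*log(1 - s)] with s = sigmoid(x).
        total += max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


patch_scores = [0.9, 0.2, -0.4]  # hypothetical discriminator outputs
loss_ls = lsgan_loss(patch_scores, target_is_real=True)
loss_bce = vanilla_loss(patch_scores, target_is_real=True)
```

Both functions consume unbounded scores, which is why a trailing sigmoid in the discriminator would be redundant for `vanilla` and would distort the regression target for `lsgan`.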
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,20 +12,20 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Cycle GAN network."""
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common import initializer as init
|
||||
|
||||
|
||||
def init_weights(net, init_type='normal', init_gain=0.02):
|
||||
"""
|
||||
Initialize network weights.
|
||||
|
||||
Parameters:
|
||||
net (Cell): Network to be initialized
|
||||
init_type (str): The name of an initialization method: normal | xavier.
|
||||
init_gain (float): Gain factor for normal and xavier.
|
||||
|
||||
"""
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
|
||||
|
@ -45,7 +45,6 @@ def init_weights(net, init_type='normal', init_gain=0.02):
|
|||
class ConvNormReLU(nn.Cell):
|
||||
"""
|
||||
Convolution fused with BatchNorm/InstanceNorm and ReLU/LeakyReLU block definition.
|
||||
|
||||
Args:
|
||||
in_planes (int): Input channel.
|
||||
out_planes (int): Output channel.
|
||||
|
@ -57,7 +56,6 @@ class ConvNormReLU(nn.Cell):
|
|||
Default: "CONSTANT".
|
||||
use_relu (bool): Use relu or not. Default: True.
|
||||
padding (int): Pad size, if it is None, it will calculate by kernel_size. Default: None.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
||||
|
@ -103,7 +101,6 @@ class ConvNormReLU(nn.Cell):
|
|||
class ConvTransposeNormReLU(nn.Cell):
|
||||
"""
|
||||
ConvTranspose2d fused with BatchNorm/InstanceNorm and ReLU/LeakyReLU block definition.
|
||||
|
||||
Args:
|
||||
in_planes (int): Input channel.
|
||||
out_planes (int): Output channel.
|
||||
|
@ -115,7 +112,6 @@ class ConvTransposeNormReLU(nn.Cell):
|
|||
Default: "CONSTANT".
|
||||
use_relu (bool): use relu or not. Default: True.
|
||||
padding (int): pad size, if it is None, it will calculate by kernel_size. Default: None.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,26 +12,28 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ResNet Generator."""
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from .networks import ConvNormReLU, ConvTransposeNormReLU
|
||||
|
||||
|
||||
class ResidualBlock(nn.Cell):
|
||||
"""
|
||||
ResNet residual block definition.
|
||||
|
||||
A resnet block is a conv block with skip connections.
|
||||
We construct a conv block with build_conv_block function,
|
||||
and implement skip connections in <forward> function.
|
||||
Args:
|
||||
dim (int): Input and output channel.
|
||||
norm_mode (str): Specifies norm method. The optional values are "batch", "instance".
|
||||
dropout (bool): Use dropout or not. Default: False.
|
||||
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
|
||||
Default: "CONSTANT".
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, norm_mode='batch', dropout=False, pad_mode="CONSTANT"):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)
|
||||
|
@ -51,7 +53,6 @@ class ResidualBlock(nn.Cell):
|
|||
class ResNetGenerator(nn.Cell):
|
||||
"""
|
||||
ResNet Generator of GAN.
|
||||
|
||||
Args:
|
||||
in_planes (int): Input channel.
|
||||
ngf (int): Output channel.
|
||||
|
@ -61,7 +62,6 @@ class ResNetGenerator(nn.Cell):
|
|||
dropout (bool): Use dropout or not. Default: False.
|
||||
pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
|
||||
Default: "CONSTANT".
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,37 +12,40 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""get args."""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init, get_rank
|
||||
|
||||
|
||||
def get_args(phase):
|
||||
"""Define the common options that are used in both training and test."""
|
||||
parser = argparse.ArgumentParser(description='Cycle GAN.')
|
||||
parser = argparse.ArgumentParser(description='Cycle GAN')
|
||||
# basic parameters
|
||||
parser.add_argument('--model', type=str, default="resnet", choices=("resnet", "unet"), \
|
||||
help='generator model, should be in [resnet, unet].')
|
||||
parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \
|
||||
help='run platform, only support GPU, CPU and Ascend')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id, default is 0.")
|
||||
parser.add_argument("--lr", type=float, default=0.0002, help="learning rate, default is 0.0002.")
|
||||
parser.add_argument('--pool_size', type=int, default=50, \
|
||||
help='the size of image buffer that stores previously generated images, default is 50.')
|
||||
parser.add_argument('--lr_policy', type=str, default='linear', choices=("linear", "constant"), \
|
||||
help='learning rate policy, default is linear')
|
||||
parser.add_argument("--image_size", type=int, default=256, help="input image_size, default is 256.")
|
||||
parser.add_argument('--batch_size', type=int, default=1, help='batch_size, default is 1.')
|
||||
parser.add_argument('--max_epoch', type=int, default=200, help='epoch size for training, default is 200.')
|
||||
parser.add_argument('--n_epochs', type=int, default=100, \
|
||||
help='number of epochs with the initial learning rate, default is 100')
|
||||
parser.add_argument("--beta1", type=float, default=0.5, help="Adam beta1, default is 0.5.")
|
||||
parser.add_argument('--init_type', type=str, default='normal', choices=("normal", "xavier"), \
|
||||
parser.add_argument('--platform', type=str, default='Ascend', help='only support GPU and Ascend')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id, default is 0.')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='device num, default is 1.')
|
||||
parser.add_argument('--model', type=str, default='DepthResNet', choices=('DepthResNet', 'ResNet', 'UNet'), \
|
||||
help='generator model')
|
||||
parser.add_argument('--init_type', type=str, default='normal', choices=('normal', 'xavier'), \
|
||||
help='network initialization, default is normal.')
|
||||
parser.add_argument('--init_gain', type=float, default=0.02, \
|
||||
help='scaling factor for normal, xavier and orthogonal, default is 0.02.')
|
||||
parser.add_argument('--image_size', type=int, default=256, help='input image_size, default is 256.')
|
||||
parser.add_argument('--batch_size', type=int, default=1, help='batch_size, default is 1.')
|
||||
parser.add_argument('--pool_size', type=int, default=50, \
|
||||
help='the size of image buffer that stores previously generated images')
|
||||
parser.add_argument('--beta1', type=float, default=0.5, help='Adam beta1, default is 0.5.')
|
||||
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default is 0.0002.')
|
||||
parser.add_argument('--lr_policy', type=str, default='linear', choices=('linear', 'constant'), \
|
||||
help='learning rate policy, default is linear')
|
||||
parser.add_argument('--max_epoch', type=int, default=200, help='epoch size for training, default is 200.')
|
||||
parser.add_argument('--n_epochs', type=int, default=100, \
|
||||
help='number of epochs with the initial learning rate, default is 100')
|
||||
|
||||
# model parameters
|
||||
parser.add_argument('--in_planes', type=int, default=3, help='input channels, default is 3.')
|
||||
|
@ -52,8 +55,8 @@ def get_args(phase):
|
|||
parser.add_argument('--dl_num', type=int, default=3, \
|
||||
help='discriminator model residual block numbers, default is 3.')
|
||||
parser.add_argument('--slope', type=float, default=0.2, help='leakyrelu slope, default is 0.2.')
|
||||
parser.add_argument('--norm_mode', type=str, default="instance", choices=("batch", "instance"), \
|
||||
help='norm mode, default is instance.')
|
||||
parser.add_argument('--norm_mode', type=str, default='batch', choices=('batch', 'instance'), \
|
||||
help='norm mode, default is batch.')
|
||||
parser.add_argument('--lambda_A', type=float, default=10.0, \
|
||||
help='weight for cycle loss (A -> B -> A), default is 10.')
|
||||
parser.add_argument('--lambda_B', type=float, default=10.0, \
|
||||
|
@ -63,58 +66,44 @@ def get_args(phase):
|
|||
'weight of the identity mapping loss. For example, if the weight of the identity loss '
|
||||
'should be 10 times smaller than the weight of the reconstruction loss,'
|
||||
'please set lambda_identity = 0.1, default is 0.5.')
|
||||
parser.add_argument('--gan_mode', type=str, default='lsgan', choices=("lsgan", "vanilla"), \
|
||||
parser.add_argument('--gan_mode', type=str, default='lsgan', choices=('lsgan', 'vanilla'), \
|
||||
help='the type of GAN loss, default is lsgan.')
|
||||
parser.add_argument('--pad_mode', type=str, default='REFLECT', choices=("CONSTANT", "REFLECT", "SYMMETRIC"), \
|
||||
help='the type of Pad, default is REFLECT.')
|
||||
parser.add_argument('--need_dropout', type=ast.literal_eval, default=True, \
|
||||
help='whether need dropout, default is True.')
|
||||
|
||||
# distillation learning parameters
|
||||
parser.add_argument('--kd', type=ast.literal_eval, default=False, \
|
||||
help='knowledge distillation learning or not, default is False.')
|
||||
parser.add_argument('--t_ngf', type=int, default=64, \
|
||||
help='teacher network generator model filter numbers when `kd` is True, default is 64.')
|
||||
parser.add_argument('--t_gl_num', type=int, default=9, \
|
||||
help='teacher network generator model residual block numbers when `kd` is True, default is 9.')
|
||||
parser.add_argument('--t_slope', type=float, default=0.2, \
|
||||
help='teacher network leakyrelu slope when `kd` is True, default is 0.2.')
|
||||
parser.add_argument('--t_norm_mode', type=str, default="instance", choices=("batch", "instance"), \
|
||||
help='teacher network norm mode when `kd` is True, default is instance.')
|
||||
parser.add_argument("--GT_A_ckpt", type=str, default=None, \
|
||||
help="teacher network pretrained checkpoint file path of G_A when `kd` is True.")
|
||||
parser.add_argument("--GT_B_ckpt", type=str, default=None, \
|
||||
help="teacher network pretrained checkpoint file path of G_B when `kd` is True.")
|
||||
parser.add_argument('--pad_mode', type=str, default='CONSTANT', choices=('CONSTANT', 'REFLECT', 'SYMMETRIC'), \
|
||||
help='the type of Pad, default is CONSTANT.')
|
||||
|
||||
# additional parameters
|
||||
parser.add_argument('--device_num', type=int, default=1, help='device num, default is 1.')
|
||||
parser.add_argument("--G_A_ckpt", type=str, default=None, help="pretrained checkpoint file path of G_A.")
|
||||
parser.add_argument("--G_B_ckpt", type=str, default=None, help="pretrained checkpoint file path of G_B.")
|
||||
parser.add_argument("--D_A_ckpt", type=str, default=None, help="pretrained checkpoint file path of D_A.")
|
||||
parser.add_argument("--D_B_ckpt", type=str, default=None, help="pretrained checkpoint file path of D_B.")
|
||||
parser.add_argument("--save_checkpoint_epochs", type=int, default=1, help="Save checkpoint epochs, default is 1.")
|
||||
parser.add_argument("--print_iter", type=int, default=100, help="log print iter, default is 100.")
|
||||
parser.add_argument('--dataroot', default='./data/horse2zebra/', \
|
||||
help='path of images (should have subfolders trainA, trainB, testA, testB, etc).')
|
||||
parser.add_argument('--outputs_dir', type=str, default='./outputs', \
|
||||
help='models are saved here, default is ./outputs.')
|
||||
parser.add_argument('--load_ckpt', type=ast.literal_eval, default=False, \
|
||||
help='whether load pretrained ckpt')
|
||||
parser.add_argument('--G_A_ckpt', type=str, default='./outputs/ckpt/G_A_200.ckpt', \
|
||||
help='checkpoint file path of G_A.')
|
||||
parser.add_argument('--G_B_ckpt', type=str, default='./outputs/ckpt/G_B_200.ckpt', \
|
||||
help='checkpoint file path of G_B.')
|
||||
parser.add_argument('--D_A_ckpt', type=str, default='./outputs/ckpt/D_A_200.ckpt', \
|
||||
help='checkpoint file path of D_A.')
|
||||
parser.add_argument('--D_B_ckpt', type=str, default='./outputs/ckpt/D_B_200.ckpt', \
|
||||
help='checkpoint file path of D_B.')
|
||||
parser.add_argument('--save_checkpoint_epochs', type=int, default=10, \
|
||||
help='Save checkpoint epochs, default is 10.')
|
||||
parser.add_argument('--print_iter', type=int, default=100, help='log print iter, default is 100.')
|
||||
parser.add_argument('--need_profiler', type=ast.literal_eval, default=False, \
|
||||
help='whether need profiler, default is False.')
|
||||
parser.add_argument('--save_graphs', type=ast.literal_eval, default=False, \
|
||||
help='whether save graphs, default is False.')
|
||||
parser.add_argument('--outputs_dir', type=str, default='./outputs', \
|
||||
help='models are saved here, default is ./outputs.')
|
||||
parser.add_argument('--dataroot', default=None, \
|
||||
help='path of images (should have subfolders trainA, trainB, testA, testB, etc).')
|
||||
parser.add_argument('--save_imgs', type=ast.literal_eval, default=True, \
|
||||
help='whether save imgs when epoch end, if True result images will generate in '
|
||||
'`outputs_dir/imgs`, default is True.')
|
||||
help='whether save imgs when epoch end')
|
||||
parser.add_argument('--use_random', type=ast.literal_eval, default=True, \
|
||||
help='whether use random when training, default is True.')
|
||||
parser.add_argument('--max_dataset_size', type=int, default=None, help='max images per epoch, default is None.')
|
||||
if phase == "export":
|
||||
parser.add_argument("--file_name", type=str, default="cyclegan", help="output file name prefix.")
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', \
|
||||
help='file format')
|
||||
|
||||
parser.add_argument('--need_dropout', type=ast.literal_eval, default=False, \
|
||||
help='whether need dropout, default is False.')
|
||||
parser.add_argument('--max_dataset_size', type=int, default=None, \
|
||||
help='max images per epoch, default is None.')
|
||||
args = parser.parse_args()
|
||||
if args.device_num > 1 and args.platform != "CPU":
|
||||
|
||||
if args.device_num > 1:
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=args.save_graphs)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
||||
|
@ -127,26 +116,18 @@ def get_args(phase):
|
|||
args.rank = 0
|
||||
args.device_num = 1
|
||||
|
||||
if args.platform != "GPU":
|
||||
if args.platform == "Ascend":
|
||||
args.pad_mode = "CONSTANT"
|
||||
|
||||
if phase != "train" and (args.G_A_ckpt is None or args.G_B_ckpt is None):
|
||||
raise ValueError('Must set G_A_ckpt and G_B_ckpt in predict phase!')
|
||||
|
||||
if args.kd:
|
||||
if args.GT_A_ckpt is None or args.GT_B_ckpt is None:
|
||||
raise ValueError('Must set GT_A_ckpt, GT_B_ckpt in knowledge distillation!')
|
||||
if args.batch_size == 1:
|
||||
args.norm_mode = "instance"
|
||||
|
||||
if args.norm_mode == "instance" or (args.kd and args.t_norm_mode == "instance"):
|
||||
args.batch_size = 1
|
||||
|
||||
if args.dataroot is None and (phase in ["train", "predict"]):
|
||||
if args.dataroot is None:
|
||||
raise ValueError('Must set dataroot!')
|
||||
|
||||
if not args.use_random:
|
||||
args.need_dropout = False
|
||||
args.init_type = "constant"
|
||||
|
||||
if args.max_dataset_size is None:
|
||||
args.max_dataset_size = float("inf")
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,7 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Reporter class."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
@ -20,10 +22,10 @@ from datetime import datetime
|
|||
from mindspore.train.serialization import save_checkpoint
|
||||
from .tools import save_image
|
||||
|
||||
|
||||
class Reporter(logging.Logger):
|
||||
"""
|
||||
This class includes several functions that can save images/checkpoints and print/save logging information.
|
||||
|
||||
Args:
|
||||
args (class): Option class.
|
||||
"""
|
||||
|
@ -73,17 +75,6 @@ class Reporter(logging.Logger):
|
|||
self.info('--> %s: %s', key, args_dict[key])
|
||||
self.info('')
|
||||
|
||||
def important_info(self, msg, *args, **kwargs):
|
||||
if self.logger.isEnabledFor(logging.INFO) and self.rank == 0:
|
||||
line_width = 2
|
||||
important_msg = '\n'
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += '*'*line_width + ' '*8 + msg + '\n'
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
self.info(important_msg, *args, **kwargs)
|
||||
|
||||
def epoch_start(self):
|
||||
self.step_start_time = time.time()
|
||||
self.epoch_start_time = time.time()
|
||||
|
@ -119,14 +110,14 @@ class Reporter(logging.Logger):
|
|||
self.info("Epoch [{}] total cost: {:.2f} ms, per step: {:.2f} ms, G_loss: {:.2f}, D_loss: {:.2f}".format(
|
||||
self.epoch, epoch_cost, pre_step_time, mean_loss_G, mean_loss_D))
|
||||
|
||||
if self.epoch % self.save_checkpoint_epochs == 0 and self.rank == 0:
|
||||
if self.epoch % self.save_checkpoint_epochs == 0:
|
||||
save_checkpoint(net.G.generator.G_A, os.path.join(self.ckpts_dir, f"G_A_{self.epoch}.ckpt"))
|
||||
save_checkpoint(net.G.generator.G_B, os.path.join(self.ckpts_dir, f"G_B_{self.epoch}.ckpt"))
|
||||
save_checkpoint(net.G.D_A, os.path.join(self.ckpts_dir, f"D_A_{self.epoch}.ckpt"))
|
||||
save_checkpoint(net.G.D_B, os.path.join(self.ckpts_dir, f"D_B_{self.epoch}.ckpt"))
|
||||
|
||||
def visualizer(self, img_A, img_B, fake_A, fake_B):
|
||||
if self.save_imgs and self.step % self.dataset_size == 0 and self.rank == 0:
|
||||
if self.save_imgs and self.step % self.dataset_size == 0:
|
||||
save_image(img_A, os.path.join(self.imgs_dir, f"{self.epoch}_img_A.jpg"))
|
||||
save_image(img_B, os.path.join(self.imgs_dir, f"{self.epoch}_img_B.jpg"))
|
||||
save_image(fake_A, os.path.join(self.imgs_dir, f"{self.epoch}_fake_A.jpg"))
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,7 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Utils for cyclegan."""
|
||||
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
@ -23,15 +25,12 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|||
class ImagePool():
|
||||
"""
|
||||
This class implements an image buffer that stores previously generated images.
|
||||
|
||||
This buffer enables us to update discriminators using a history of generated images
|
||||
rather than the ones produced by the latest generators.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_size):
|
||||
"""
|
||||
Initialize the ImagePool class
|
||||
|
||||
Args:
|
||||
pool_size (int): the size of image buffer, if pool_size=0, no buffer will be created.
|
||||
"""
|
||||
|
@ -43,12 +42,9 @@ class ImagePool():
|
|||
def query(self, images):
|
||||
"""
|
||||
Return an image from the pool.
|
||||
|
||||
Args:
|
||||
images: the latest generated images from the generator
|
||||
|
||||
Returns images Tensor from the buffer.
|
||||
|
||||
With probability 0.5, the buffer returns the input images.
With probability 0.5, the buffer returns images previously stored in the buffer
and inserts the current images into the buffer.
|
||||
|
@ -80,7 +76,6 @@ class ImagePool():
|
|||
|
||||
def save_image(img, img_path):
|
||||
"""Save a numpy image to the disk
|
||||
|
||||
Parameters:
|
||||
img (numpy array / Tensor): image to save.
|
||||
img_path (str): the path of the image.
|
||||
|
@ -101,7 +96,11 @@ def decode_image(img):
|
|||
|
||||
|
||||
def get_lr(args):
|
||||
"""Learning rate generator."""
|
||||
"""
|
||||
Learning rate generator.
|
||||
For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
|
||||
and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
|
||||
"""
|
||||
if args.lr_policy == 'linear':
|
||||
lrs = [args.lr] * args.dataset_size * args.n_epochs
|
||||
lr_epoch = 0
|
||||
|
@ -127,15 +126,3 @@ def load_ckpt(args, G_A, G_B, D_A=None, D_B=None):
|
|||
if D_B is not None and args.D_B_ckpt is not None:
|
||||
param_DB = load_checkpoint(args.D_B_ckpt)
|
||||
load_param_into_net(D_B, param_DB)
|
||||
|
||||
|
||||
def load_teacher_ckpt(net, ckpt_path, teacher, student):
|
||||
"""Replace parameter name to teacher net and load parameter from checkpoint."""
|
||||
param = load_checkpoint(ckpt_path)
|
||||
new_param = {}
|
||||
for k, v in param.items():
|
||||
new_name = k.replace(student, teacher)
|
||||
new_param_name = v.name.replace(student, teacher)
|
||||
v.name = new_param_name
|
||||
new_param[new_name] = v
|
||||
load_param_into_net(net, new_param)
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,17 +12,22 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Cycle GAN train."""
|
||||
|
||||
"""General-purpose training script for image-to-image translation.
|
||||
You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').
|
||||
Example:
|
||||
Train a resnet model:
|
||||
python train.py --dataroot ./data/horse2zebra --model ResNet
|
||||
"""
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common import set_seed
|
||||
from src.utils.args import get_args
|
||||
from src.utils.reporter import Reporter
|
||||
from src.utils.tools import get_lr, ImagePool, load_ckpt
|
||||
from src.dataset.cyclegan_dataset import create_dataset
|
||||
from src.models.losses import DiscriminatorLoss, GeneratorLoss
|
||||
from src.models.cycle_gan import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD
|
||||
|
||||
from src.models import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD, \
|
||||
DiscriminatorLoss, GeneratorLoss
|
||||
from src.utils import get_lr, get_args, Reporter, ImagePool, load_ckpt
|
||||
from src.dataset import create_dataset
|
||||
|
||||
set_seed(1)
|
||||
|
||||
def train():
|
||||
"""Train function."""
|
||||
|
@ -35,7 +40,8 @@ def train():
|
|||
G_B = get_generator(args)
|
||||
D_A = get_discriminator(args)
|
||||
D_B = get_discriminator(args)
|
||||
load_ckpt(args, G_A, G_B, D_A, D_B)
|
||||
if args.load_ckpt:
|
||||
load_ckpt(args, G_A, G_B, D_A, D_B)
|
||||
imgae_pool_A = ImagePool(args.pool_size)
|
||||
imgae_pool_B = ImagePool(args.pool_size)
|
||||
generator = Generator(G_A, G_B, args.lambda_idt > 0)
|
|
@ -1,235 +0,0 @@
|
|||
# Contents
|
||||
|
||||
- [CycleGAN Description](#cyclegan-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
    - [Knowledge Distillation Process](#knowledge-distillation-process)
    - [Prediction Process](#prediction-process)
    - [Evaluation with cityscape dataset](#evaluation-with-cityscape-dataset)
    - [Export MindIR](#export-mindir)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [CycleGAN Description](#contents)
|
||||
|
||||
Generative Adversarial Network (referred to as GAN) is an unsupervised learning method that learns by letting two neural networks play against each other. CycleGAN is a kind of GAN that consists of two generation networks and two discriminant networks. It learns to translate images from one domain into another from unpaired training images, which makes it usable for style transfer.
|
||||
|
||||
[Paper](https://arxiv.org/abs/1703.10593): Zhu J Y , Park T , Isola P , et al. Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks[J]. 2017.
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
|
||||
The CycleGAN contains two generation networks and two discriminant networks. We support two architectures for generation networks: resnet and unet. The resnet architecture contains three convolutions, several residual blocks, two fractionally-strided convolutions with stride 1/2, and one convolution that maps features to RGB. The unet architecture contains three unet blocks for downsampling and upsampling, several intermediate unet blocks, and one convolution that maps features to RGB. For the discriminator networks we use 70 × 70 PatchGANs, which aim to classify whether 70 × 70 overlapping image patches are real or fake.
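
The 70 × 70 figure is the receptive field of one PatchGAN output unit. The sketch below shows how that number can arise. The layer layout is an assumption (three stride-2 and two stride-1 convolutions with 4 × 4 kernels, as in the commonly used PatchGAN with three downsampling layers); it is not read from `networks.py`, and `receptive_field` is a hypothetical helper.

```python
def receptive_field(layers):
    """Receptive field of stacked conv layers given as (kernel, stride) pairs, input to output."""
    rf = 1
    # Walk from the output back to the input, growing the field at each layer.
    for kernel, stride in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf


# Assumed PatchGAN layout: three stride-2 convs, then two stride-1 convs, all 4x4 kernels.
patchgan_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(patchgan_layers))  # 70
```

Under this assumed layout, each discriminator output scores one 70 × 70 patch of the input, and neighbouring outputs see overlapping patches.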
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
|
||||
|
||||
Dataset used: [CityScape](<https://cityscapes-dataset.com>)
|
||||
|
||||
Please download the datasets [gtFine_trainvaltest.zip] and [leftImg8bit_trainvaltest.zip] and unzip them. We provide `src/utils/prepare_cityscapes_dataset.py` to process images. gtFine contains the semantic segmentations. Use --gtFine_dir to specify the path to the unzipped gtFine_trainvaltest directory. leftImg8bit contains the dashcam photographs. Use --leftImg8bit_dir to specify the path to the unzipped leftImg8bit_trainvaltest directory.
|
||||
The processed images will be placed at --output_dir.
|
||||
|
||||
Example usage:
|
||||
|
||||
```bash
|
||||
python src/utils/prepare_cityscapes_dataset.py --gtFine_dir ./cityscapes/gtFine/ --leftImg8bit_dir ./cityscapes/leftImg8bit --output_dir ./cityscapes/
|
||||
```
|
||||
|
||||
The directory structure is as follows:
|
||||
|
||||
```path
.
└─cityscapes
    ├─trainA
    ├─trainB
    ├─testA
    └─testB
```
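
Before training, it can help to confirm that the data root has this layout. The following is a minimal sketch; `missing_subdirs` is a hypothetical helper and not part of the repository, and the demo builds a throwaway directory only to show the behaviour.

```python
import os
import tempfile

REQUIRED_SUBDIRS = ("trainA", "trainB", "testA", "testB")


def missing_subdirs(dataroot, required=REQUIRED_SUBDIRS):
    """Return the required sub-folders that do not exist under dataroot."""
    return [name for name in required
            if not os.path.isdir(os.path.join(dataroot, name))]


# Demo on a temporary directory that deliberately lacks testB.
with tempfile.TemporaryDirectory() as root:
    for name in ("trainA", "trainB", "testA"):
        os.makedirs(os.path.join(root, name))
    demo_missing = missing_subdirs(root)

print(demo_missing)  # ['testB']
```

An empty list means the folder can be passed as `--dataroot`.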
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware (GPU)
    - Prepare hardware environment with GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
```path
.
└─ cv
    └─ cyclegan
        ├─ src
        │   ├─ __init__.py                       # init file
        │   ├─ dataset
        │   │   ├─ __init__.py                   # init file
        │   │   ├─ cyclegan_dataset.py           # create cyclegan dataset
        │   │   ├─ datasets.py                   # UnalignedDataset and ImageFolderDataset class and some image utils
        │   │   └─ distributed_sampler.py        # iterator of dataset
        │   ├─ models
        │   │   ├─ __init__.py                   # init file
        │   │   ├─ cycle_gan.py                  # cyclegan model define
        │   │   ├─ losses.py                     # cyclegan losses function define
        │   │   ├─ networks.py                   # cyclegan sub networks define
        │   │   ├─ resnet.py                     # resnet generate network
        │   │   └─ unet.py                       # unet generate network
        │   └─ utils
        │       ├─ __init__.py                   # init file
        │       ├─ args.py                       # parse args
        │       ├─ prepare_cityscapes_dataset.py # prepare cityscapes dataset to cyclegan format
        │       ├─ cityscapes_utils.py           # cityscapes dataset evaluation utils
        │       ├─ reporter.py                   # Reporter class
        │       └─ tools.py                      # utils for cyclegan
        ├─ cityscape_eval.py                     # cityscape dataset eval script
        ├─ predict.py                            # generate images from A->B and B->A
        ├─ train.py                              # train script
        ├─ export.py                             # export mindir script
        ├─ README.md                             # descriptions about CycleGAN
        └─ mindspore_hub_conf.py                 # mindspore hub interface
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
```python
Major parameters in train.py and config.py are as follows:

"model": "resnet"        # generator model, should be in [resnet, unet].
"platform": "GPU"        # run platform, support GPU, CPU and Ascend.
"device_id": 0           # device id, default is 0.
"lr": 0.0002             # init learning rate, default is 0.0002.
"pool_size": 50          # the size of image buffer that stores previously generated images, default is 50.
"lr_policy": "linear"    # learning rate policy, default is linear.
"image_size": 256        # input image_size, default is 256.
"batch_size": 1          # batch_size, default is 1.
"max_epoch": 200         # epoch size for training, default is 200.
"n_epochs": 100          # number of epochs with the initial learning rate, default is 100.
"beta1": 0.5             # Adam beta1, default is 0.5.
"init_type": normal      # network initialization, default is normal.
"init_gain": 0.02        # scaling factor for normal, xavier and orthogonal, default is 0.02.
"in_planes": 3           # input channels, default is 3.
"ngf": 64                # generator model filter numbers, default is 64.
"gl_num": 9              # generator model residual block numbers, default is 9.
"ndf": 64                # discriminator model filter numbers, default is 64.
"dl_num": 3              # discriminator model residual block numbers, default is 3.
"slope": 0.2             # leakyrelu slope, default is 0.2.
"norm_mode": "instance"  # norm mode, should be [batch, instance], default is instance.
"lambda_A": 10           # weight for cycle loss (A -> B -> A), default is 10.
"lambda_B": 10           # weight for cycle loss (B -> A -> B), default is 10.
"lambda_idt": 0.5        # if lambda_idt > 0 use identity mapping.
"gan_mode": lsgan        # the type of GAN loss, should be [lsgan, vanilla], default is lsgan.
"pad_mode": REFLECT      # the type of Pad, should be [CONSTANT, REFLECT, SYMMETRIC], default is REFLECT.
"need_dropout": True     # whether need dropout, default is True.
"kd": False              # knowledge distillation learning or not, default is False.
"t_ngf": 64              # teacher network generator model filter numbers when `kd` is True, default is 64.
"t_gl_num": 9            # teacher network generator model residual block numbers when `kd` is True, default is 9.
"t_slope": 0.2           # teacher network leakyrelu slope when `kd` is True, default is 0.2.
"t_norm_mode": "instance" # teacher network norm mode when `kd` is True, default is instance.
"print_iter": 100        # log print iter, default is 100.
"outputs_dir": "outputs" # models are saved here, default is ./outputs.
"dataroot": None         # path of images (should have subfolders trainA, trainB, testA, testB, etc).
"save_imgs": True        # whether save imgs when epoch end, if True result images will generate in `outputs_dir/imgs`, default is True.
"GT_A_ckpt": None        # teacher network pretrained checkpoint file path of G_A when `kd` is True.
"GT_B_ckpt": None        # teacher network pretrained checkpoint file path of G_B when `kd` is True.
"G_A_ckpt": None         # pretrained checkpoint file path of G_A.
"G_B_ckpt": None         # pretrained checkpoint file path of G_B.
"D_A_ckpt": None         # pretrained checkpoint file path of D_A.
"D_B_ckpt": None         # pretrained checkpoint file path of D_B.
```
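
The `lr`, `lr_policy`, `max_epoch` and `n_epochs` parameters together describe the `linear` policy: the rate stays at `lr` for the first `n_epochs` epochs and then decays linearly over the remaining ones. The sketch below is one possible reading of that policy, not the repository's `get_lr`. The helper name `linear_lr_schedule`, its signature, and the exact decay rule (the `+ 1` in the denominator, which keeps the last epoch's rate above zero) are all assumptions.

```python
def linear_lr_schedule(lr, steps_per_epoch, max_epoch, n_epochs):
    """Per-step learning rates: constant for n_epochs, then linear decay towards zero."""
    n_epochs_decay = max_epoch - n_epochs
    # Constant phase: one entry per training step.
    lrs = [lr] * (steps_per_epoch * n_epochs)
    for decay_epoch in range(1, n_epochs_decay + 1):
        # Assumed rule: the rate would hit zero one epoch after training ends.
        epoch_lr = lr * (1.0 - decay_epoch / (n_epochs_decay + 1))
        lrs.extend([epoch_lr] * steps_per_epoch)
    return lrs


# Tiny illustration: 4 epochs of 2 steps each, the first 2 epochs at the initial rate.
schedule = linear_lr_schedule(lr=0.0002, steps_per_epoch=2, max_epoch=4, n_epochs=2)
print(len(schedule), schedule[0], schedule[-1])
```

With the defaults listed above (`max_epoch=200`, `n_epochs=100`), this reading keeps the rate at 0.0002 for 100 epochs and lowers it once per epoch over the last 100.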
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
```bash
|
||||
python train.py --platform [PLATFORM] --dataroot [DATA_PATH]
|
||||
```
|
||||
|
||||
**Note: pad_mode should be CONSTANT when using Ascend or CPU. When using unet as the generator network, gl_num should be less than 7.**
|
||||
|
||||
## [Knowledge Distillation Process](#contents)
|
||||
|
||||
```bash
|
||||
python train.py --platform [PLATFORM] --dataroot [DATA_PATH] --ngf [NGF] --kd True --GT_A_ckpt [G_A_CKPT] --GT_B_ckpt [G_B_CKPT]
|
||||
```
|
||||
|
||||
**Note: the student network's ngf should be 1/2 or 1/4 of the teacher network's ngf. If you changed the default args when training the teacher generator networks, set the corresponding t_xx args in the knowledge distillation process.**
|
||||
|
||||
## [Prediction Process](#contents)
|
||||
|
||||
```bash
|
||||
python predict.py --platform [PLATFORM] --dataroot [DATA_PATH] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKPT]
|
||||
```
|
||||
|
||||
**Note: the results will be saved at `outputs_dir/predict`.**
|
||||
|
||||
## [Evaluation with cityscape dataset](#contents)
|
||||
|
||||
```bash
|
||||
python cityscape_eval.py --cityscapes_dir [LABEL_PATH] --result_dir [FAKEB_PATH]
|
||||
```
|
||||
|
||||
**Note: Please run cityscape_eval.py after the prediction process.**
|
||||
|
||||
## [Export MindIR](#contents)
|
||||
|
||||
```bash
|
||||
python export.py --platform [PLATFORM] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKPT] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
|
||||
**Note: The file_name parameter is a prefix; the final files will be named [FILE_NAME]_AtoB.[FILE_FORMAT] and [FILE_NAME]_BtoA.[FILE_FORMAT].**
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
### Evaluation Performance
|
||||
|
||||
| Parameters                 | GPU                                                         |
| -------------------------- | ----------------------------------------------------------- |
| Model Version              | CycleGAN                                                    |
| Resource                   | NV SMX2 V100-32G                                            |
| Uploaded Date              | 12/10/2020 (month/day/year)                                 |
| MindSpore Version          | 1.1.0                                                       |
| Dataset                    | Cityscapes                                                  |
| Training Parameters        | epoch=200, steps=2975, batch_size=1, lr=0.002               |
| Optimizer                  | Adam                                                        |
| Loss Function              | Mean Square Loss & L1 Loss                                  |
| outputs                    | probability                                                 |
| Speed                      | 1pc: 264 ms/step;                                           |
| Total time                 | 1pc: 43.6h;                                                 |
| Parameters (M)             | 11.378 M                                                    |
| Checkpoint for Fine tuning | 44M (.ckpt file)                                            |
| Scripts                    | [CycleGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/cycle_gan) |
|
||||
|
||||
### Inference Performance
|
||||
|
||||
| Parameters          | GPU                         |
| ------------------- | --------------------------- |
| Model Version       | CycleGAN                    |
| Resource            | GPU                         |
| Uploaded Date       | 12/10/2020 (month/day/year) |
| MindSpore Version   | 1.1.0                       |
| Dataset             | Cityscapes                  |
| batch_size          | 1                           |
| outputs             | probability                 |
| Accuracy            | mean_pixel_acc: 54.8, mean_class_acc: 21.3, mean_class_iou: 16.1 |
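
The three accuracy figures are summary statistics of a class confusion matrix. The sketch below shows the conventional definitions in plain Python. The repository's own `fast_hist` and `get_scores` are not shown here, so the helper names and exact formulas are assumptions. The sketch returns fractions in [0, 1], whereas the table appears to report percentages.

```python
def confusion_matrix(labels, preds, num_classes):
    """hist[t][p] counts pixels whose true class is t and predicted class is p."""
    hist = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(labels, preds):
        # Skip ids outside the valid class range (e.g. an "ignore" label).
        if 0 <= t < num_classes and 0 <= p < num_classes:
            hist[t][p] += 1
    return hist


def summary_scores(hist):
    """Return (mean pixel accuracy, mean class accuracy, mean class IoU)."""
    n = len(hist)
    diag = [hist[i][i] for i in range(n)]
    row_sums = [sum(hist[i]) for i in range(n)]                       # true pixels per class
    col_sums = [sum(hist[i][j] for i in range(n)) for j in range(n)]  # predicted pixels per class
    pixel_acc = sum(diag) / sum(row_sums)
    class_acc = [diag[i] / row_sums[i] for i in range(n) if row_sums[i] > 0]
    unions = [row_sums[i] + col_sums[i] - diag[i] for i in range(n)]
    class_iou = [diag[i] / unions[i] for i in range(n) if unions[i] > 0]
    return pixel_acc, sum(class_acc) / len(class_acc), sum(class_iou) / len(class_iou)


# Four pixels, two classes: one class-0 pixel is mistaken for class 1.
demo_hist = confusion_matrix([0, 0, 1, 1], [0, 1, 1, 1], num_classes=2)
pixel_acc, class_acc, class_iou = summary_scores(demo_hist)
print(pixel_acc, class_acc, class_iou)
```

Classes that never occur are left out of the class averages in this sketch; the actual evaluation utilities may handle that case differently.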
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
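If you set --use_random=False, training runs without randomness: dataset shuffling and the random crop/flip augmentations are disabled, dropout is turned off, and the network weights use constant initialization.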
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -1,54 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Eval use cityscape dataset."""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from src.dataset import make_dataset
|
||||
from src.utils import CityScapes, fast_hist, get_scores
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--cityscapes_dir", type=str, required=True, help="Path to the original cityscapes dataset")
|
||||
parser.add_argument("--result_dir", type=str, required=True, help="Path to the generated images to be evaluated")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
CS = CityScapes()
|
||||
cityscapes = make_dataset(args.cityscapes_dir)
|
||||
hist_perframe = np.zeros((CS.class_num, CS.class_num))
|
||||
for i, img_path in enumerate(cityscapes):
|
||||
if i % 100 == 0:
|
||||
print('Evaluating: %d/%d' % (i, len(cityscapes)))
|
||||
img_name = os.path.split(img_path)[1]
|
||||
ids1 = CS.get_id(os.path.join(args.cityscapes_dir, img_name))
|
||||
ids2 = CS.get_id(os.path.join(args.result_dir, img_name))
|
||||
hist_perframe += fast_hist(ids1.flatten(), ids2.flatten(), CS.class_num)
|
||||
|
||||
mean_pixel_acc, mean_class_acc, mean_class_iou, per_class_acc, per_class_iou = get_scores(hist_perframe)
|
||||
print(f"mean_pixel_acc: {mean_pixel_acc}, mean_class_acc: {mean_class_acc}, mean_class_iou: {mean_class_iou}")
|
||||
with open('./evaluation_results.txt', 'w') as f:
|
||||
f.write('Mean pixel accuracy: %f\n' % mean_pixel_acc)
|
||||
f.write('Mean class accuracy: %f\n' % mean_class_acc)
|
||||
f.write('Mean class IoU: %f\n' % mean_class_iou)
|
||||
f.write('************ Per class numbers below ************\n')
|
||||
for i, cl in enumerate(CS.classes):
|
||||
while len(cl) < 15:
|
||||
cl = cl + ' '
|
||||
f.write('%s: acc = %f, iou = %f\n' % (cl, per_class_acc[i], per_class_iou[i]))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.models import get_generator
|
||||
|
||||
def create_network(name, *args, **kwargs):
    """Create the two CycleGAN generators (G_A, G_B) for MindSpore Hub."""
    if name == "cyclegan":
        G_A = get_generator(*args, **kwargs)
        G_B = get_generator(*args, **kwargs)
        # Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
        # Use real mean and variance rather than moving_mean and moving_variance in BatchNorm2d
        G_A.set_train(True)
        G_B.set_train(True)
        return G_A, G_B
    raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -1,74 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Cycle GAN dataset."""
|
||||
import os
|
||||
import multiprocessing
|
||||
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
from .distributed_sampler import DistributedSampler
|
||||
from .datasets import UnalignedDataset, ImageFolderDataset
|
||||
|
||||
def create_dataset(args):
|
||||
"""Create dataset"""
|
||||
dataroot = args.dataroot
|
||||
phase = args.phase
|
||||
batch_size = args.batch_size
|
||||
device_num = args.device_num
|
||||
rank = args.rank
|
||||
shuffle = args.use_random
|
||||
max_dataset_size = args.max_dataset_size
|
||||
cores = multiprocessing.cpu_count()
|
||||
num_parallel_workers = min(8, int(cores / device_num))
|
||||
image_size = args.image_size
|
||||
mean = [0.5 * 255] * 3
|
||||
std = [0.5 * 255] * 3
|
||||
if phase == "train":
|
||||
dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size, use_random=args.use_random)
|
||||
distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
|
||||
ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
|
||||
sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
|
||||
if args.use_random:
|
||||
trans = [
|
||||
C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
else:
|
||||
trans = [
|
||||
C.Resize((image_size, image_size)),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(1)
|
||||
else:
|
||||
datadir = os.path.join(dataroot, args.data_dir)
|
||||
dataset = ImageFolderDataset(datadir, max_dataset_size=max_dataset_size)
|
||||
ds = de.GeneratorDataset(dataset, column_names=["image", "image_name"],
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
trans = [
|
||||
C.Resize((image_size, image_size)),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
ds = ds.map(operations=trans, input_columns=["image"], num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.batch(1, drop_remainder=True)
|
||||
ds = ds.repeat(1)
|
||||
args.dataset_size = len(dataset)
|
||||
return ds
|
|
@ -1,105 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Cycle GAN datasets."""
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
random.seed(1)
|
||||
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']
|
||||
|
||||
def is_image_file(filename):
|
||||
"""Judge whether it is a picture."""
|
||||
return any(filename.lower().endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
|
||||
def make_dataset(dir_path, max_dataset_size=float("inf")):
|
||||
"""Return image list in dir."""
|
||||
images = []
|
||||
assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path
|
||||
|
||||
for root, _, fnames in sorted(os.walk(dir_path)):
|
||||
for fname in fnames:
|
||||
if is_image_file(fname):
|
||||
path = os.path.join(root, fname)
|
||||
images.append(path)
|
||||
return images[:min(max_dataset_size, len(images))]
|
||||
|
||||
|
||||
class UnalignedDataset:
    """
    This dataset class can load unaligned/unpaired datasets.

    Args:
        dataroot (str): Images root directory.
        phase (str): 'train' or 'test'. dataroot must contain two directories named after the phase,
            e.g. '{dataroot}/trainA' for domain A images and '{dataroot}/trainB' for domain B images.
        max_dataset_size (int): Maximum number of image paths to load per domain.
        use_random (bool): If True, shuffle the domain A paths and draw a random domain B index
            whenever index is a multiple of the dataset length. Default: True.

    Returns:
        Two RGB images as numpy arrays, one from domain A and one from domain B.
    """

    def __init__(self, dataroot, phase, max_dataset_size=float("inf"), use_random=True):
        self.dir_A = os.path.join(dataroot, phase + 'A')
        self.dir_B = os.path.join(dataroot, phase + 'B')

        self.A_paths = sorted(make_dataset(self.dir_A, max_dataset_size))   # load images from '/path/to/data/trainA'
        self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size))   # load images from '/path/to/data/trainB'
        self.A_size = len(self.A_paths)  # get the size of dataset A
        self.B_size = len(self.B_paths)  # get the size of dataset B
        self.use_random = use_random

    def __getitem__(self, index):
        index_B = index % self.B_size
        if index % max(self.A_size, self.B_size) == 0 and self.use_random:
            random.shuffle(self.A_paths)
            index_B = random.randint(0, self.B_size - 1)
        A_path = self.A_paths[index % self.A_size]
        B_path = self.B_paths[index_B]
        A_img = np.array(Image.open(A_path).convert('RGB'))
        B_img = np.array(Image.open(B_path).convert('RGB'))

        return A_img, B_img

    def __len__(self):
        return max(self.A_size, self.B_size)

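The index arithmetic in `UnalignedDataset.__getitem__` is what lets two domains of different sizes be iterated together: the dataset length is the larger of the two sizes, and the shorter domain wraps around through the modulo. A minimal sketch of the `use_random=False` case, using hypothetical path lists in place of `A_paths` and `B_paths`:

```python
# Hypothetical path lists standing in for A_paths and B_paths.
a_paths = ["a0.jpg", "a1.jpg", "a2.jpg", "a3.jpg", "a4.jpg"]
b_paths = ["b0.jpg", "b1.jpg"]

length = max(len(a_paths), len(b_paths))  # mirrors __len__


def pair(index):
    """Return the (A, B) paths chosen for index when use_random is False."""
    return a_paths[index % len(a_paths)], b_paths[index % len(b_paths)]


pairs = [pair(i) for i in range(length)]
print(pairs)
```

With five images in domain A and two in domain B, every A image is visited once per pass while the B images repeat.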
class ImageFolderDataset:
    """
    This dataset class can load images from image folder.

    Args:
        dataroot (str): Images root directory.
        max_dataset_size (int): Maximum number of image paths to load.

    Returns:
        An RGB image as a numpy array and its file name.
    """

    def __init__(self, dataroot, max_dataset_size=float("inf")):
        self.dataroot = dataroot
        self.paths = sorted(make_dataset(dataroot, max_dataset_size))
        self.size = len(self.paths)

    def __getitem__(self, index):
        img_path = self.paths[index % self.size]
        img = np.array(Image.open(img_path).convert('RGB'))

        return img, os.path.split(img_path)[1]

    def __len__(self):
        return self.size
@@ -1,19 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init file."""
from .args import get_args
from .reporter import Reporter
from .tools import get_lr, load_teacher_ckpt, ImagePool, load_ckpt, save_image
from .cityscapes_utils import CityScapes, fast_hist, get_scores
@@ -1,95 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cityscape utils."""

import numpy as np
from PIL import Image

# label name and RGB color map.
label2color = {
    'unlabeled': (0, 0, 0),
    'ego vehicle': (0, 0, 0),
    'rectification border': (0, 0, 0),
    'out of roi': (0, 0, 0),
    'static': (0, 0, 0),
    'dynamic': (111, 74, 0),
    'ground': (81, 0, 81),
    'road': (128, 64, 128),
    'sidewalk': (244, 35, 232),
    'parking': (250, 170, 160),
    'rail track': (230, 150, 140),
    'building': (70, 70, 70),
    'wall': (102, 102, 156),
    'fence': (190, 153, 153),
    'guard rail': (180, 165, 180),
    'bridge': (150, 100, 100),
    'tunnel': (150, 120, 90),
    'pole': (153, 153, 153),
    'polegroup': (153, 153, 153),
    'traffic light': (250, 170, 30),
    'traffic sign': (220, 220, 0),
    'vegetation': (107, 142, 35),
    'terrain': (152, 251, 152),
    'sky': (70, 130, 180),
    'person': (220, 20, 60),
    'rider': (255, 0, 0),
    'car': (0, 0, 142),
    'truck': (0, 0, 70),
    'bus': (0, 60, 100),
    'caravan': (0, 0, 90),
    'trailer': (0, 0, 110),
    'train': (0, 80, 100),
    'motorcycle': (0, 0, 230),
    'bicycle': (119, 11, 32),
    'license plate': (0, 0, 142)
}


def fast_hist(a, b, n):
    """Return the n x n confusion matrix of flattened label arrays a (rows) and b (columns)."""
    # keep only the positions where a holds a valid class id
    k = np.where((a >= 0) & (a < n))[0]
    bc = np.bincount(n * a[k].astype(int) + b[k], minlength=n**2)
    if len(bc) != n**2:
        # ignore this example if dimension mismatch
        return 0
    return bc.reshape(n, n)


def get_scores(hist):
    """Return pixel accuracy, mean class accuracy, mean IoU, per-class accuracy and per-class IoU."""
    # Overall pixel accuracy
    acc = np.diag(hist).sum() / (hist.sum() + 1e-12)
    # Per class accuracy
    cl_acc = np.diag(hist) / (hist.sum(1) + 1e-12)
    # Per class IoU
    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-12)
    return acc, np.nanmean(cl_acc), np.nanmean(iu), cl_acc, iu

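A small worked example may help when reading `fast_hist` and `get_scores`. The sketch below repeats the same `np.bincount` trick on two hypothetical flattened label arrays with 3 classes. Treating `a` as ground truth and `b` as prediction is an assumption made for illustration, and the `1e-12` stabiliser is left out so the numbers stay exact.

```python
import numpy as np

n = 3
a = np.array([0, 0, 1, 1, 2, 2])  # assumed ground-truth labels
b = np.array([0, 1, 1, 1, 2, 0])  # assumed predicted labels

# Each (a, b) pair maps to the flat index n * a + b of an n x n matrix.
hist = np.bincount(n * a + b, minlength=n ** 2).reshape(n, n)

# Diagonal entries are the correctly labelled pixels of each class.
pixel_acc = np.diag(hist).sum() / hist.sum()
class_acc = np.diag(hist) / hist.sum(axis=1)
iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))

print(hist)
print(pixel_acc, class_acc, iou)
```

Row `i` of `hist` counts how the pixels labelled `i` in `a` were labelled in `b`, which is why per-class accuracy divides by the row sums and IoU divides by row sum plus column sum minus the diagonal.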
class CityScapes:
    """CityScapes util class."""
    def __init__(self):
        self.classes = ['road', 'sidewalk', 'building', 'wall', 'fence',
                        'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain',
                        'sky', 'person', 'rider', 'car', 'truck',
                        'bus', 'train', 'motorcycle', 'bicycle', 'unlabeled']
        self.color_list = []
        for name in self.classes:
            self.color_list.append(label2color[name])
        self.class_num = len(self.classes)

    def get_id(self, img_path):
        """Get train id by img"""
        img = np.array(Image.open(img_path).convert("RGB"))
        # numpy image arrays are laid out as (height, width, channels)
        h, w, _ = img.shape
        img_tile = np.tile(img, (1, 1, self.class_num)).reshape(h, w, self.class_num, 3)
        # pick, per pixel, the class whose color has the smallest summed absolute RGB difference
        diff = np.abs(img_tile - self.color_list).sum(axis=-1)
        ids = diff.argmin(axis=-1)
        return ids
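`CityScapes.get_id` assigns each pixel the class whose reference color is closest in summed absolute RGB difference. A minimal sketch on a hypothetical 1 x 2 image with three colors copied from `label2color`. Note that broadcasting replaces the `np.tile` call here, so this illustrates the idea rather than reproducing the original code.

```python
import numpy as np

# Three reference colors copied from label2color: road, sidewalk, sky.
palette = np.array([(128, 64, 128), (244, 35, 232), (70, 130, 180)])

# Hypothetical 1 x 2 RGB image: a slightly noisy road pixel and an exact sky pixel.
img = np.array([[[126, 66, 130], [70, 130, 180]]], dtype=np.uint8)

# Cast to a signed type first so the subtraction cannot wrap around,
# then compare every pixel with every palette entry via broadcasting.
pixels = img.astype(np.int64)[:, :, None, :]        # shape (1, 2, 1, 3)
diff = np.abs(pixels - palette[None, None, :, :]).sum(axis=-1)  # shape (1, 2, 3)
ids = diff.argmin(axis=-1)                           # shape (1, 2)

print(ids)
```

The noisy pixel still maps to index 0 (road) because its distance to the road color is far smaller than to the other two entries.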
@@ -1,84 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""prepare cityscapes dataset to cyclegan format"""
import os
import argparse
import glob
from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument('--gtFine_dir', type=str, required=True,
                    help='Path to the Cityscapes gtFine directory.')
parser.add_argument('--leftImg8bit_dir', type=str, required=True,
                    help='Path to the Cityscapes leftImg8bit_trainvaltest directory.')
# not marked as required, so that the default value can take effect
parser.add_argument('--output_dir', type=str,
                    default='./cityscapes',
                    help='Directory the output images will be written to.')
opt = parser.parse_args()


def load_resized_img(path):
    """Load image with RGB and resize to (256, 256)"""
    return Image.open(path).convert('RGB').resize((256, 256))


def check_matching_pair(segmap_path, photo_path):
    """Check that the segmentation map and the photo belong to the same sample."""
    segmap_identifier = os.path.basename(segmap_path).replace('_gtFine_color', '')
    photo_identifier = os.path.basename(photo_path).replace('_leftImg8bit', '')

    assert segmap_identifier == photo_identifier, \
        f"[{segmap_path}] and [{photo_path}] don't seem to be matching. Aborting."

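`check_matching_pair` relies on the two file names differing only in their suffix. A minimal sketch of that comparison, using hypothetical paths that follow the Cityscapes naming pattern; `identifier` is an illustrative helper, not part of the script.

```python
import os


def identifier(path, suffix):
    """Strip the directory and the given suffix from a file path."""
    return os.path.basename(path).replace(suffix, '')


# Hypothetical paths following the Cityscapes naming pattern.
segmap_path = "gtFine/train/aachen/aachen_000000_000019_gtFine_color.png"
photo_path = "leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png"

segmap_id = identifier(segmap_path, '_gtFine_color')
photo_id = identifier(photo_path, '_leftImg8bit')

print(segmap_id, photo_id, segmap_id == photo_id)
```

Because both lists are sorted before being zipped, this check mainly guards against a missing or extra file shifting the two lists out of step.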
def process_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase):
    """Process Cityscapes dataset to cyclegan dataset format."""
    save_phase = 'test' if phase == 'val' else 'train'
    savedir = os.path.join(output_dir, save_phase)
    os.makedirs(savedir + 'A', exist_ok=True)
    os.makedirs(savedir + 'B', exist_ok=True)
    print(f"Directory structure prepared at {output_dir}")

    segmap_expr = os.path.join(gtFine_dir, phase) + "/*/*_color.png"
    segmap_paths = glob.glob(segmap_expr)
    segmap_paths = sorted(segmap_paths)

    photo_expr = os.path.join(leftImg8bit_dir, phase) + "/*/*_leftImg8bit.png"
    photo_paths = glob.glob(photo_expr)
    photo_paths = sorted(photo_paths)

    assert len(segmap_paths) == len(photo_paths), \
        "{} images that match [{}], and {} images that match [{}]. Aborting.".format(
            len(segmap_paths), segmap_expr, len(photo_paths), photo_expr)

    for i, (segmap_path, photo_path) in enumerate(zip(segmap_paths, photo_paths)):
        check_matching_pair(segmap_path, photo_path)
        segmap = load_resized_img(segmap_path)
        photo = load_resized_img(photo_path)

        # data for cyclegan where the two images are stored at two distinct directories
        savepath = os.path.join(savedir + 'A', f"{i + 1}.jpg")
        photo.save(savepath)
        savepath = os.path.join(savedir + 'B', f"{i + 1}.jpg")
        segmap.save(savepath)

        # report progress about ten times; max() avoids a zero modulus when there are fewer than 10 images
        if i % max(1, len(segmap_paths) // 10) == 0:
            print("%d / %d: last image saved at %s" % (i, len(segmap_paths), savepath))


if __name__ == '__main__':
    print('Preparing Cityscapes Dataset for val phase')
    process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "val")
    print('Preparing Cityscapes Dataset for train phase')
    process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "train")
    print('Done')