update README
chenyang_Marvin 2021-03-30 19:49:24 +08:00
parent 4e58f833be
commit 72b9b5125c
16 changed files with 2557 additions and 0 deletions


@ -0,0 +1,136 @@
# Contents
- [StarGAN Description](#stargan-description)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Training Process](#training-process)
    - [Prediction Process](#prediction-process)
    - [Export MindIR](#export-mindir)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [StarGAN Description](#contents)
> **StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation**<br>
> [Yunjey Choi](https://github.com/yunjey)<sup>1,2</sup>, [Minje Choi](https://github.com/mjc92)<sup>1,2</sup>, [Munyoung Kim](https://www.facebook.com/munyoung.kim.1291)<sup>2,3</sup>, [Jung-Woo Ha](https://www.facebook.com/jungwoo.ha.921)<sup>2</sup>, [Sung Kim](https://www.cse.ust.hk/~hunkim/)<sup>2,4</sup>, [Jaegul Choo](https://sites.google.com/site/jaegulchoo/)<sup>1,2</sup>    <br/>
> <sup>1</sup>Korea University, <sup>2</sup>Clova AI Research, NAVER Corp. <br>
> <sup>3</sup>The College of New Jersey, <sup>4</sup>Hong Kong University of Science and Technology <br/>
> https://arxiv.org/abs/1711.09020 <br>
>
> **Abstract:** *Recent studies have shown remarkable success in image-to-image translation for two domains. However, existing approaches have limited scalability and robustness in handling more than two domains, since different models should be built independently for every pair of image domains. To address this limitation, we propose StarGAN, a novel and scalable approach that can perform image-to-image translations for multiple domains using only a single model. Such a unified model architecture of StarGAN allows simultaneous training of multiple datasets with different domains within a single network. This leads to StarGAN's superior quality of translated images compared to existing models as well as the novel capability of flexibly translating an input image to any desired target domain. We empirically demonstrate the effectiveness of our approach on a facial attribute transfer and a facial expression synthesis tasks.*
# [Dataset](#contents)
You can run the scripts with the dataset used in the original paper or with another dataset that is widely used for this domain or network architecture. The following sections show how to run the scripts with the dataset below.

Dataset used: [CelebA](<http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>)

CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset with more than 200K celebrity images, each with 40 attribute annotations. The images in this dataset cover large pose variations and background clutter. CelebA offers large diversity, large quantity, and rich annotations, including

- 10,177 identities,
- 202,599 face images, and
- 5 landmark locations and 40 binary attribute annotations per image.

The dataset can be used as the training and test sets for the following computer vision tasks: face attribute recognition, face detection, landmark (or facial part) localization, and face editing & synthesis.
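
As an illustration of how the binary attribute annotations become training labels, the sketch below parses a tiny, made-up excerpt in the layout that `src/dataset.py` expects from `list_attr_celeba.txt`: a count line, a line of attribute names, then one row per image with `1` or `-1` per attribute. The sample text and the helper name `parse_attr_file` are assumptions for this example; the real file lists all 40 attributes.

```python
import io

# Made-up excerpt with only 5 attributes; the real file has 40 columns.
SAMPLE = """3
Black_Hair Blond_Hair Brown_Hair Male Young
000001.jpg -1  1 -1 -1  1
000002.jpg  1 -1 -1  1  1
000003.jpg -1 -1  1  1 -1
"""


def parse_attr_file(handle, selected_attrs):
    """Return (filename, label) pairs with one 0.0/1.0 entry per selected attribute."""
    lines = [line.rstrip() for line in handle]
    # Line 0 is the image count, line 1 holds the attribute names.
    attr2idx = {name: i for i, name in enumerate(lines[1].split())}
    records = []
    for line in lines[2:]:
        if not line:
            continue
        parts = line.split()
        filename, values = parts[0], parts[1:]
        # The annotation file stores 1 / -1; map it to 1.0 / 0.0.
        label = [1.0 if values[attr2idx[name]] == '1' else 0.0
                 for name in selected_attrs]
        records.append((filename, label))
    return records


records = parse_attr_file(io.StringIO(SAMPLE), ["Blond_Hair", "Male"])
for filename, label in records:
    print(filename, label)
```

Each label has as many entries as there are selected attributes, which is why the label dimension in the config has to follow the attribute selection.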
# [Environment Requirements](#contents)
- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```text
.
└─ cv
   └─ StarGAN
      ├─ src
      │  ├─ __init__.py        # init file
      │  ├─ cell.py            # StarGAN model define
      │  ├─ model.py           # define subnetwork about generator and discriminator
      │  ├─ utils.py           # utils for StarGAN
      │  ├─ config.py          # parse args
      │  ├─ dataset.py         # prepare celebA dataset to cyclegan format
      │  ├─ reporter.py        # Reporter class
      │  └─ loss.py            # losses for StarGAN
      ├─ cityscape_eval.py     # cityscape dataset eval script
      ├─ eval.py               # translate attributes from original images
      ├─ train.py              # train script
      ├─ export.py             # export mindir script
      └─ README.md             # descriptions about StarGAN
```
## [Training Process](#contents)
Before training the network, select the attributes to translate with `selected_attrs` in the config, then set `c_dim` in the config to the number of selected attributes.
```bash
python train.py
```
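
The relationship between the selected attributes and `c_dim` can be checked before a long run. The snippet below is a minimal sketch of such a check; `check_c_dim` is a hypothetical helper that is not part of this repository, and the attribute list is the default from `src/config.py`.

```python
def check_c_dim(selected_attrs, c_dim):
    """Raise early if the label dimension does not match the attribute list."""
    if len(selected_attrs) != c_dim:
        raise ValueError(
            "c_dim=%d but %d attributes were selected: %s"
            % (c_dim, len(selected_attrs), selected_attrs))
    return c_dim


# Default attribute list from src/config.py, which matches the default c_dim of 5.
selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']
print(check_c_dim(selected_attrs, 5))

# A mismatch is reported instead of failing later inside the network.
try:
    check_c_dim(['Male', 'Young'], 5)
except ValueError as err:
    print(err)
```

With the argument parser in `src/config.py`, the same pairing would be passed on the command line as `--selected_attrs Male Young --c_dim 2`.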
## [Prediction Process](#contents)
```bash
python eval.py
```
**Note: the results will be saved in `./results/`.**
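
Each saved file is a contact sheet: `eval.py` concatenates the real batch and one translated batch per target domain along the width axis, then stacks the batch vertically. The NumPy sketch below reproduces only that tiling with random data. The `denorm` function here is a hypothetical stand-in for `src.utils.denorm` (assumed to map `[-1, 1]` NCHW data to `[0, 255]` NHWC), and the sizes are taken from the defaults in the scripts (batch of 4, 128x128 images, 5 domains).

```python
import numpy as np


def denorm(x):
    """Assumed behaviour: scale [-1, 1] to [0, 255] and move channels last."""
    x = (x + 1.0) / 2.0 * 255.0
    return np.clip(x, 0, 255).transpose(0, 2, 3, 1)


batch, size, n_domains = 4, 128, 5
rng = np.random.default_rng(0)

# Random stand-ins for the real images and the generator outputs.
x_real = rng.uniform(-1, 1, size=(batch, 3, size, size)).astype(np.float32)
x_fakes = [rng.uniform(-1, 1, size=(batch, 3, size, size)).astype(np.float32)
           for _ in range(n_domains)]

# Real image first, then one translation per domain, side by side (width axis).
row = np.concatenate([x_real] + x_fakes, axis=3)        # (4, 3, 128, 768)

# Stack the batch vertically: one image row per input sample.
sheet = np.uint8(denorm(row).reshape(-1, size * (n_domains + 1), 3))
print(sheet.shape)
```

Under these assumptions every `test_{i}.jpg` is 768 pixels wide (6 tiles of 128) and 512 pixels tall (4 samples of 128).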
## [Export MindIR](#contents)
```bash
python export.py
```
**Note: The `file_name` parameter is the file name prefix; the final file is named `StarGAN_G.[FILE_FORMAT]`.**
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Parameters | Ascend 910 |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | StarGAN |
| Resource | Ascend |
| Uploaded Date              | 03/30/2021 (month/day/year)                                 |
| MindSpore Version | 1.1.1 |
| Dataset | CelebA |
| Training Parameters | steps=200000, batch_size=1, lr=0.0001 |
| Optimizer | Adam |
| Outputs                    | probability                                                 |
| Speed | 1pc: 100 ms/step; |
| Total time | 1pc: 10h; |
| Parameters (M) | 8.423 M |
| Checkpoint for Fine tuning | 32.15M (.ckpt file) |
| Scripts | [StarGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/StarGAN) |
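
As a rough plausibility check of the table, the checkpoint size can be estimated from the parameter count. The sketch below assumes that every parameter is stored as a 4-byte float32 value and ignores any metadata in the `.ckpt` file; both are assumptions, not statements from the source.

```python
params_millions = 8.423      # "Parameters (M)" row of the table
bytes_per_param = 4          # assumption: float32 weights

size_bytes = params_millions * 1e6 * bytes_per_param
size_mib = size_bytes / 2 ** 20

print("estimated checkpoint size: %.2f MiB" % size_mib)
```

The estimate lands close to the 32.15M reported for the checkpoint, so the two rows are consistent under the float32 assumption.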
### Inference Performance
| Parameters | Ascend 910 |
| ------------------- | --------------------------- |
| Model Version | StarGAN |
| Resource | Ascend |
| Uploaded Date | 03/30/2021 (month/day/year) |
| MindSpore Version | 1.1.1 |
| Dataset | CelebA |
| batch_size | 4 |
| outputs | image |
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for StarGAN"""
import os
import numpy as np
from PIL import Image
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import load_param_into_net
from mindspore.common import dtype as mstype
import mindspore.ops as ops
from src.utils import resume_model, create_labels, denorm, get_network
from src.config import get_config
from src.dataset import dataloader


if __name__ == "__main__":
    config = get_config()
    # Use the device id from the command line (run_eval.sh passes --device_id) instead of a hard-coded one.
    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=config.device_id)
    G, D = get_network(config)
    para_g, _ = resume_model(config, G, D)
    load_param_into_net(G, para_g)

    if not os.path.exists(config.result_dir):
        os.mkdir(config.result_dir)

    # Define Dataset
    data_path = config.celeba_image_dir
    attr_path = config.attr_path
    dataset, length = dataloader(img_path=data_path,
                                 attr_path=attr_path,
                                 batch_size=4,
                                 selected_attr=config.selected_attrs,
                                 device_num=config.num_workers,
                                 dataset=config.dataset,
                                 mode=config.mode,
                                 shuffle=False)

    # Concatenate the real image and all translated images along the width axis (NCHW layout).
    op = ops.Concat(axis=3)
    ds = dataset.create_dict_iterator()
    print(length)
    print('Start Evaluating!')
    for i, data in enumerate(ds):
        result_list = ()
        x_real = Tensor(data['image'], mstype.float32)
        result_list += (x_real,)
        # One target label tensor per selected attribute.
        c_trg_list = create_labels(data['attr'].asnumpy(), selected_attrs=config.selected_attrs)
        c_trg_list = Tensor(c_trg_list, mstype.float32)
        for c_trg in c_trg_list:
            x_fake = G(x_real, c_trg)
            x = Tensor(x_fake.asnumpy().copy())
            result_list += (x,)
        x_fake_list = op(result_list)
        # Stack the batch vertically: 6 tiles of 128 pixels give a width of 768.
        result = denorm(x_fake_list.asnumpy())
        result = np.reshape(result, (-1, 768, 3))
        im = Image.fromarray(np.uint8(result))
        im.save(config.result_dir + '/test_{}.jpg'.format(i))
        print('Successful save image in ' + config.result_dir + '/test_{}.jpg'.format(i))


@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export file."""
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_param_into_net
from src.config import get_config
from src.utils import get_network, resume_model


if __name__ == '__main__':
    config = get_config()
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
    G, D = get_network(config)
    # Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
    # Use real mean and variance rather than moving_mean and moving_variance in BatchNorm2d
    G.set_train(True)
    param_G, _ = resume_model(config, G, D)
    load_param_into_net(G, param_G)

    # Dummy inputs that fix the exported graph's input shapes: one 128x128 RGB image and one 5-dim label.
    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32))
    input_label = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 5)).astype(np.float32))
    G_file = "StarGAN_Generator"
    # The generator is called as G(image, label) in eval.py, so both inputs are passed to export.
    export(G, input_array, input_label, file_name=G_file, file_format=config.file_format)


@ -0,0 +1,25 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_NUM=1
export DEVICE_ID=0
export MODE='test'
echo "start evaluation for device $DEVICE_ID"
env > env.log
python eval.py --run_distribute=0 --device_num=$DEVICE_NUM --device_id=$DEVICE_ID --mode=$MODE> log_eval.txt 2>&1 &
cd ..


@ -0,0 +1,50 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
    echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [DISTRIBUTE] [RANK_TABLE_FILE]"
    exit 1
fi

echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"

export RANK_SIZE=$1
DISTRIBUTE=$2
export RANK_TABLE_FILE=$3

for((i=0;i<RANK_SIZE;i++))
do
    export DEVICE_ID=$i
    rm -rf LOG$i
    mkdir ./LOG$i
    cp ./*.json ./LOG$i
    cp ./*.py ./LOG$i
    cp -r ./src ./LOG$i
    cp -r ./scripts ./LOG$i
    cd ./LOG$i || exit
    export RANK_ID=$i
    echo "start training for rank $i, device $DEVICE_ID"
    env > env.log
    if [ $# == 3 ]
    then
        python train.py \
        --run_distribute=$DISTRIBUTE \
        --device_num=$RANK_SIZE \
        --device_id=$DEVICE_ID > log.txt 2>&1 &
    fi
    cd ../
done


@ -0,0 +1,25 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_NUM=1
export DEVICE_ID=0
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --run_distribute=0 --device_num=$DEVICE_NUM --device_id=$DEVICE_ID > log.txt 2>&1 &
cd ..


@ -0,0 +1,15 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init file."""


@ -0,0 +1,166 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train cell for StarGAN"""
import numpy as np
from mindspore import nn
from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
import mindspore.ops.operations as P
import mindspore.ops.functional as F
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
_get_parallel_mode)
import mindspore.ops as ops
from mindspore.context import ParallelMode
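from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
# get_group_size is used by TrainOneStepCellDis when device_num is not set explicitly.
from mindspore.communication.management import get_group_size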
from mindspore.common import initializer as init, set_seed
set_seed(1)
np.random.seed(1)


def init_weights(net, init_type='normal', init_gain=0.02):
    """
    Initialize network weights.

    Parameters:
        net (Cell): Network to be initialized
        init_type (str): The name of an initialization method: normal | xavier.
        init_gain (float): Gain factor for normal and xavier.
    """
    for _, cell in net.cells_and_names():
        if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
            if init_type == 'normal':
                cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
            elif init_type == 'xavier':
                cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
            elif init_type == 'KaimingUniform':
                cell.weight.set_data(init.initializer(init.HeUniform(init_gain), cell.weight.shape))
            elif init_type == 'constant':
                cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif isinstance(cell, nn.GroupNorm):
            cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
            cell.beta.set_data(init.initializer('zeros', cell.beta.shape))


class GeneratorWithLossCell(nn.Cell):
    """
    Wrap the network with loss function to return generator loss.

    Args:
        network (Cell): The target network to wrap.
    """
    def __init__(self, network):
        super(GeneratorWithLossCell, self).__init__(auto_prefix=False)
        self.network = network

    def construct(self, x_real, c_org, c_trg):
        # The wrapped network returns (fake_image, loss, ...); only the total loss is needed here.
        _, G_Loss, _, _, _, = self.network(x_real, c_org, c_trg)
        return G_Loss


class DiscriminatorWithLossCell(nn.Cell):
    """
    Wrap the network with loss function to return discriminator loss.

    Args:
        network (Cell): The target network to wrap.
    """
    def __init__(self, network):
        super(DiscriminatorWithLossCell, self).__init__(auto_prefix=False)
        self.network = network

    def construct(self, x_real, c_org, c_trg):
        # The wrapped network returns the total loss first, followed by its components.
        D_Loss, _, _, _, _ = self.network(x_real, c_org, c_trg)
        return D_Loss


class TrainOneStepCellGen(nn.Cell):
    """Encapsulation class of StarGAN generator network training.
    Append an optimizer to the training network after that the construct
    function can be called to create the backward graph."""
    def __init__(self, G, optimizer, sens=1.0):
        super(TrainOneStepCellGen, self).__init__()
        self.optimizer = optimizer
        self.G = G
        self.G.set_grad()
        self.G.set_train()
        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.weights = optimizer.parameters
        self.network = GeneratorWithLossCell(G)
        self.network.add_flags(defer_inline=True)
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.parallel_mode = _get_parallel_mode()
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)

    def construct(self, img_real, c_org, c_trg):
        weights = self.weights
        fake_image, loss, G_fake_loss, G_fake_cls_loss, G_rec_loss = self.G(img_real, c_org, c_trg)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(img_real, c_org, c_trg, sens)
        grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads)), fake_image, loss, G_fake_loss, G_fake_cls_loss, G_rec_loss


class TrainOneStepCellDis(nn.Cell):
    """Encapsulation class of StarGAN Discriminator network training.
    Append an optimizer to the training network after that the construct
    function can be called to create the backward graph."""
    def __init__(self, D, optimizer, sens=1.0):
        super(TrainOneStepCellDis, self).__init__()
        self.optimizer = optimizer
        self.D = D
        self.D.set_grad()
        self.D.set_train()
        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.weights = optimizer.parameters
        self.network = DiscriminatorWithLossCell(D)
        self.network.add_flags(defer_inline=True)
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, img_real, c_org, c_trg):
        weights = self.weights
        loss, D_real_loss, D_fake_loss, D_real_cls_loss, D_gp_loss = self.D(img_real, c_org, c_trg)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(img_real, c_org, c_trg, sens)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads)), loss, D_real_loss, D_fake_loss, D_real_cls_loss, D_gp_loss


@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define Configuration for StarGAN"""
import argparse


def get_config():
    """Define configuration of Model"""
    parser = argparse.ArgumentParser(description='StarGAN')

    # Model configuration.
    parser.add_argument('--c_dim', type=int, default=5, help='dimension of domain labels (1st dataset)')
    parser.add_argument('--c2_dim', type=int, default=7, help='dimension of domain labels (2nd dataset)')
    parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset')
    parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset')
    parser.add_argument('--image_size', type=int, default=128, help='image resolution')
    parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G')
    parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
    parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G')
    parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D')
    parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
    parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
    parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')

    # Training configuration.
    parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'Both'])
    parser.add_argument('--batch_size', type=int, default=4, help='mini-batch size')
    parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
    parser.add_argument('--epochs', type=int, default=59, help='number of epoch')
    parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
    parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G')
    parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
    parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
    parser.add_argument('--resume_iters', type=int, default=200000, help='resume training from this step')
    parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
                        default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
    parser.add_argument('--init_type', type=str, default='normal', choices=("normal", "xavier"),
                        help='network initialization, default is normal.')
    parser.add_argument('--init_gain', type=float, default=0.02,
                        help='scaling factor for normal, xavier and orthogonal, default is 0.02.')

    # Test configuration.
    parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')

    # Train Device.
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
    parser.add_argument('--device_target', type=str, default='Ascend')
    parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
    parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.")
    parser.add_argument("--device_num", type=int, default=1, help="number of device, default: 1.")
    parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")

    # Directories.
    parser.add_argument('--celeba_image_dir', type=str, default=r'/root/wcy/StarGAN_copy/celeba/images')
    parser.add_argument('--attr_path', type=str, default=r'/root/wcy/StarGAN_copy/celeba/list_attr_celeba.txt')
    parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train')
    parser.add_argument('--log_dir', type=str, default='stargan/logs')
    parser.add_argument('--model_save_dir', type=str, default='./models/')
    parser.add_argument('--result_dir', type=str, default='./results')

    # Step size.
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=5000)
    parser.add_argument('--model_save_step', type=int, default=5000)
    parser.add_argument('--lr_update_step', type=int, default=1000)

    # export
    parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR',
                        help='file format')

    config = parser.parse_args()
    return config


@ -0,0 +1,198 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data Processing for StarGAN"""
import os
import random
import multiprocessing
import numpy as np
from PIL import Image
import mindspore.dataset.vision.py_transforms as py_vision
import mindspore.dataset.transforms.py_transforms as py_transforms
import mindspore.dataset as de
from src.utils import DistributedSampler


def is_image_file(filename):
    """Judge whether it is an image"""
    IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']
    return any(filename.lower().endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir_path, max_dataset_size=float("inf")):
    """Return image list in dir"""
    images = []
    assert os.path.isdir(dir_path), "%s is not a valid directory" % dir_path

    for root, _, fnames in sorted(os.walk(dir_path)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


class CelebA:
    """
    This dataset class helps load celebA dataset.
    """
    def __init__(self, image_dir, attr_path, selected_attrs, transform, mode):
        """Initialize and preprocess the CelebA dataset."""
        self.image_dir = image_dir
        self.attr_path = attr_path
        self.selected_attrs = selected_attrs
        self.transform = transform
        self.mode = mode
        self.train_dataset = []
        self.test_dataset = []
        self.attr2idx = {}
        self.idx2attr = {}
        self.preprocess()

        if mode == 'train':
            self.num_images = len(self.train_dataset)
        else:
            self.num_images = len(self.test_dataset)

    def preprocess(self):
        """Preprocess the CelebA attribute file."""
        # Read the file inside a context manager so the handle is closed.
        with open(self.attr_path, 'r') as attr_file:
            lines = [line.rstrip() for line in attr_file]
        # Line 0 is the image count, line 1 holds the attribute names.
        all_attr_names = lines[1].split()
        for i, attr_name in enumerate(all_attr_names):
            self.attr2idx[attr_name] = i
            self.idx2attr[i] = attr_name

        lines = lines[2:]
        # Fixed seed so the train/test split is reproducible.
        random.seed(1234)
        random.shuffle(lines)
        for i, line in enumerate(lines):
            split = line.split()
            filename = split[0]
            values = split[1:]

            label = []
            for attr_name in self.selected_attrs:
                idx = self.attr2idx[attr_name]
                label.append(1.0 if values[idx] == '1' else 0.0)

            # The first shuffled entries form the test set, the rest the training set.
            if (i+1) < 2000:
                self.test_dataset.append([filename, label])
            else:
                self.train_dataset.append([filename, label])

        print('Finished preprocessing the CelebA dataset...')

    def __getitem__(self, idx):
        """Return one image and its corresponding attribute label."""
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        filename, label = dataset[idx]
        image = np.asarray(Image.open(os.path.join(self.image_dir, filename)))
        label = np.asarray(label)
        image = np.squeeze(self.transform(image))
        # image = Tensor(image, mstype.float32)
        # label = Tensor(label, mstype.float32)
        return image, label

    def __len__(self):
        """Return the number of images."""
        return self.num_images


class ImageFolderDataset:
    """
    This dataset class can load images from image folder.

    Args:
        data_root (str): Images root directory.
        max_dataset_size (int): Maximum number of return image paths.

    Returns:
        Image path list.
    """
    def __init__(self, data_root, transform, max_dataset_size=float("inf")):
        self.data_root = data_root
        self.transform = transform
        self.paths = sorted(make_dataset(data_root, max_dataset_size))
        self.size = len(self.paths)

    def __getitem__(self, index):
        img_path = self.paths[index % self.size]
        # image = np.array(Image.open(img_path).convert('RGB'))
        image = np.asarray(Image.open(img_path))
        return np.squeeze(self.transform(image)), os.path.split(img_path)[1]
        # return image, os.path.split(img_path)[1]

    def __len__(self):
        return self.size


def get_loader(data_root, attr_path, selected_attrs, crop_size=178, image_size=128,
               dataset='CelebA', mode='train'):
    """Build and return a data loader."""
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    transform = [py_vision.ToPIL()]
    # Random flipping is only applied as augmentation during training.
    if mode == 'train':
        transform.append(py_vision.RandomHorizontalFlip())
    transform.append(py_vision.CenterCrop(crop_size))
    transform.append(py_vision.Resize(image_size))
    transform.append(py_vision.ToTensor())
    transform.append(py_vision.Normalize(mean=mean, std=std))
    transform = py_transforms.Compose(transform)

    if dataset == 'CelebA':
        dataset = CelebA(data_root, attr_path, selected_attrs, transform, mode)
    elif dataset == 'RaFD':
        dataset = ImageFolderDataset(data_root, transform)
    return dataset


def dataloader(img_path, attr_path, selected_attr, dataset, mode='train',
               batch_size=1, device_num=1, rank=0, shuffle=True):
    """Get dataloader"""
    assert dataset in ['CelebA', 'RaFD']

    cores = multiprocessing.cpu_count()
    num_parallel_workers = int(cores / device_num)

    if dataset == 'CelebA':
        dataset_loader = get_loader(img_path, attr_path, selected_attr, mode=mode)
        length_dataset = len(dataset_loader)
        distributed_sampler = DistributedSampler(length_dataset, device_num, rank, shuffle=shuffle)
        dataset_column_names = ["image", "attr"]
    else:
        dataset_loader = get_loader(img_path, None, None, dataset='RaFD')
        length_dataset = len(dataset_loader)
        distributed_sampler = DistributedSampler(length_dataset, device_num, rank, shuffle=shuffle)
        dataset_column_names = ["image", "image_path"]

    if device_num != 8:
        ds = de.GeneratorDataset(dataset_loader, column_names=dataset_column_names,
                                 num_parallel_workers=min(32, num_parallel_workers),
                                 sampler=distributed_sampler)
        ds = ds.batch(batch_size, num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True)
    else:
        ds = de.GeneratorDataset(dataset_loader, column_names=dataset_column_names, sampler=distributed_sampler)
        ds = ds.batch(batch_size, num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True)

    if mode == 'train':
        ds = ds.repeat(200)

    return ds, length_dataset


@ -0,0 +1,154 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define loss function for StarGAN"""
import numpy as np
import mindspore.ops as ops
import mindspore.ops.operations as P
from mindspore import nn, Tensor
from mindspore import dtype as mstype
from mindspore.ops import constexpr
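
@constexpr
def generate_tensor(batch_size):
    """Sample the interpolation coefficient alpha used by the gradient penalty."""
    # WGAN-GP interpolates with alpha drawn uniformly from [0, 1), so use rand rather than randn.
    np_array = np.random.rand(batch_size, 1, 1, 1)
    return Tensor(np_array, mstype.float32)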

class ClassificationLoss(nn.Cell):
    """Define classification loss for StarGAN"""
    def __init__(self, dataset='CelebA'):
        super().__init__()
        self.BCELoss = P.BinaryCrossEntropy(reduction='sum')
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.dataset = dataset

    def construct(self, pred, label):
        if self.dataset == 'CelebA':
            # Multi-label attributes: sigmoid followed by binary cross-entropy,
            # summed over all elements and divided by the batch size.
            weight = ops.Ones()(pred.shape, mstype.float32)
            pred_ = P.Sigmoid()(pred)
            x = self.BCELoss(pred_, label, weight) / pred.shape[0]
        else:
            x = self.cross_entropy(pred, label)
        return x

class GradientWithInput(nn.Cell):
    """Get Discriminator Gradient with Input"""
    def __init__(self, discriminator):
        super(GradientWithInput, self).__init__()
        self.reduce_sum = ops.ReduceSum()
        self.discriminator = discriminator

    def construct(self, interpolates):
        decision_interpolate, _ = self.discriminator(interpolates)
        decision_interpolate = self.reduce_sum(decision_interpolate, 0)
        return decision_interpolate

class WGANGPGradientPenalty(nn.Cell):
    """Define the WGAN-GP gradient penalty for StarGAN"""
    def __init__(self, discriminator):
        super(WGANGPGradientPenalty, self).__init__()
        self.gradient_op = ops.GradOperation()
        self.reduce_sum = ops.ReduceSum()
        self.reduce_sum_keep_dim = ops.ReduceSum(keep_dims=True)
        self.sqrt = ops.Sqrt()
        self.discriminator = discriminator
        self.gradientWithInput = GradientWithInput(discriminator)

    def construct(self, x_real, x_fake):
        """get gradient penalty"""
        batch_size = x_real.shape[0]
        alpha = generate_tensor(batch_size)
        alpha = alpha.expand_as(x_real)
        x_fake = ops.functional.stop_gradient(x_fake)
        # Interpolate between real and generated samples.
        x_hat = (alpha * x_real + (1 - alpha) * x_fake)

        gradient = self.gradient_op(self.gradientWithInput)(x_hat)
        gradient_1 = ops.reshape(gradient, (batch_size, -1))
        # Per-sample L2 norm of the gradient; the penalty pushes it towards 1.
        gradient_1 = self.sqrt(self.reduce_sum(gradient_1*gradient_1, 1))
        gradient_penalty = self.reduce_sum((gradient_1 - 1.0)**2) / x_real.shape[0]
        return gradient_penalty

class GeneratorLoss(nn.Cell):
    """Define total Generator loss"""
    def __init__(self, args, generator, discriminator):
        super(GeneratorLoss, self).__init__()
        self.net_G = generator
        self.net_D = discriminator
        self.cyc_loss = P.ReduceMean()
        self.rec_loss = nn.L1Loss("mean")
        self.cls_loss = ClassificationLoss()

        self.lambda_rec = args.lambda_rec
        self.lambda_cls = args.lambda_cls

    def construct(self, x_real, c_org, c_trg):
        """Get generator loss"""
        # Original to Target
        x_fake = self.net_G(x_real, c_trg)
        fake_src, fake_cls = self.net_D(x_fake)

        G_fake_loss = - self.cyc_loss(fake_src)
        G_fake_cls_loss = self.cls_loss(fake_cls, c_trg)

        # Target to Original
        x_rec = self.net_G(x_fake, c_org)
        G_rec_loss = self.rec_loss(x_real, x_rec)

        g_loss = G_fake_loss + self.lambda_cls * G_fake_cls_loss + self.lambda_rec * G_rec_loss
        return (x_fake, g_loss, G_fake_loss, G_fake_cls_loss, G_rec_loss)

class DiscriminatorLoss(nn.Cell):
    """Define total discriminator loss"""
    def __init__(self, args, generator, discriminator):
        super(DiscriminatorLoss, self).__init__()
        self.net_G = generator
        self.net_D = discriminator
        self.cyc_loss = P.ReduceMean()
        self.cls_loss = ClassificationLoss()
        self.WGANLoss = WGANGPGradientPenalty(discriminator)

        self.lambda_rec = Tensor(args.lambda_rec)
        self.lambda_cls = Tensor(args.lambda_cls)
        self.lambda_gp = Tensor(args.lambda_gp)

    def construct(self, x_real, c_org, c_trg):
        """Get discriminator loss"""
        # Compute loss with real images
        real_src, real_cls = self.net_D(x_real)

        D_real_loss = - self.cyc_loss(real_src)
        D_real_cls_loss = self.cls_loss(real_cls, c_org)

        # Compute loss with fake images
        x_fake = self.net_G(x_real, c_trg)
        fake_src, _ = self.net_D(x_fake)
        D_fake_loss = self.cyc_loss(fake_src)

        D_gp_loss = self.WGANLoss(x_real, x_fake)

        d_loss = D_real_loss + D_fake_loss + self.lambda_cls * D_real_cls_loss + self.lambda_gp * D_gp_loss
        return (d_loss, D_real_loss, D_fake_loss, D_real_cls_loss, D_gp_loss)

View File

@ -0,0 +1,142 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define Generator and Discriminator for StarGAN"""
import math
import numpy as np
import mindspore.nn as nn
import mindspore.ops as P
from mindspore import set_seed, Tensor
from mindspore.common import initializer as init
set_seed(1)
np.random.seed(1)

class ResidualBlock(nn.Cell):
    """Residual Block with instance normalization.

    GroupNorm with num_groups equal to num_channels normalizes each channel
    of each sample separately, which is how instance normalization is obtained here.
    """
    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        self.main = nn.SequentialCell(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, pad_mode='pad', has_bias=False),
            nn.GroupNorm(num_groups=dim_out, num_channels=dim_out),
            nn.ReLU(),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, pad_mode='pad', has_bias=False),
            nn.GroupNorm(num_groups=dim_out, num_channels=dim_out)
        )

    def construct(self, x):
        return x + self.main(x)

class Generator(nn.Cell):
    """Generator network."""
    def __init__(self, conv_dim=64, c_dim=4, repeat_num=6):
        super(Generator, self).__init__()
        layers = []
        layers.append((nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1,
                                 padding=3, pad_mode='pad', has_bias=False)))
        layers.append(nn.GroupNorm(num_groups=conv_dim, num_channels=conv_dim))
        layers.append(nn.ReLU())

        # Down-sampling layers.
        curr_dim = conv_dim
        for _ in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2,
                                    padding=1, pad_mode='pad', has_bias=False))
            layers.append(nn.GroupNorm(num_groups=curr_dim*2, num_channels=curr_dim*2))
            layers.append(nn.ReLU())
            curr_dim = curr_dim*2

        # Bottleneck layers.
        for _ in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        # Up-sampling layers.
        for _ in range(2):
            layers.append(nn.Conv2dTranspose(curr_dim, curr_dim // 2, kernel_size=4, stride=2,
                                             padding=1, pad_mode='pad', has_bias=False))
            layers.append(nn.GroupNorm(num_groups=curr_dim // 2, num_channels=curr_dim // 2))
            layers.append(nn.ReLU())
            curr_dim = curr_dim // 2

        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, pad_mode='pad', has_bias=False))
        layers.append(nn.Tanh())
        self.main = nn.SequentialCell(*layers)

    def construct(self, x, c):
        # Replicate the domain label spatially and concatenate it with the image along the channel axis.
        reshape = P.Reshape()
        c = reshape(c, (c.shape[0], c.shape[1], 1, 1))
        tile = P.Tile()
        c = tile(c, (1, 1, x.shape[2], x.shape[3]))
        op = P.Concat(1)
        x = op((x, c))
        return self.main(x)

class ResidualBlock_2(nn.Cell):
    """Residual Block with instance normalization."""
    def __init__(self, weight, dim_in, dim_out):
        super(ResidualBlock_2, self).__init__()
        self.main = nn.SequentialCell(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1,
                      pad_mode='pad', has_bias=False, weight_init=Tensor(weight[0])),
            nn.GroupNorm(num_groups=dim_out, num_channels=dim_out),
            nn.ReLU(),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1,
                      pad_mode='pad', has_bias=False, weight_init=Tensor(weight[3])),
            nn.GroupNorm(num_groups=dim_out, num_channels=dim_out)
        )

    def construct(self, x):
        return x + self.main(x)

class Discriminator(nn.Cell):
    """Discriminator network with PatchGAN."""
    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()
        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1, has_bias=True,
                                pad_mode='pad', bias_init=init.Uniform(1 / math.sqrt(3))))
        layers.append(nn.LeakyReLU(alpha=0.01))

        curr_dim = conv_dim
        for _ in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, has_bias=True,
                                    pad_mode='pad', bias_init=init.Uniform(1 / math.sqrt(curr_dim))))
            layers.append(nn.LeakyReLU(alpha=0.01))
            curr_dim = curr_dim * 2

        kernel_size = int(image_size / np.power(2, repeat_num))
        self.main = nn.SequentialCell(*layers)
        # PatchGAN output heads: conv1 gives the real/fake score map, conv2 gives the domain class logits.
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, pad_mode='pad', has_bias=False)
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, has_bias=False, pad_mode='valid')

    def construct(self, x):
        h = self.main(x)
        out_src = self.conv1(h)
        out_cls = self.conv2(h)
        reshape = P.Reshape()
        out_cls = reshape(out_cls, (out_cls.shape[0], out_cls.shape[1]))
        return out_src, out_cls

View File

@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Reporter class."""
import logging
import time
import datetime
from mindspore import Tensor

class Reporter(logging.Logger):
    """
    Logger that tracks the generator and discriminator losses and prints training progress.

    Args:
        args (class): Option class.
    """
    def __init__(self, args):
        super(Reporter, self).__init__("StarGAN")
        self.epoch = 0
        self.step = 0
        self.print_iter = 50
        self.G_loss = []
        self.D_loss = []
        self.total_step = args.num_iters
        self.runs_step = 0

    def epoch_start(self):
        self.step_start_time = time.time()
        self.epoch_start_time = time.time()
        self.step = 0
        self.epoch += 1
        self.G_loss = []
        self.D_loss = []

    def print_info(self, start_time, step, lossG, lossD):
        """print log after some steps."""
        resG, resD, _, _ = self.return_loss_array(lossG, lossD)
        if self.step % self.print_iter == 0:
            # step_cost = str(round(float(time.time() - start_time) * 1000 / self.print_iter,2))
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))
            losses = "D_loss: [{:.3f}], G_loss: [{:.3f}].\nD_real_loss: {:.3f}, " \
                     "D_fake_loss: {:.3f}, D_real_cls_loss: {:.3f}, " \
                     "D_gp_loss: {:.3f}, G_fake_loss: {:.3f}, " \
                     "G_fake_cls_loss: {:.3f}, G_rec_loss: {:.3f}".format(
                         resD[0], resG[0], resD[1], resD[2], resD[3], resD[4], resG[1], resG[2], resG[3])
            print("Step [{}/{}] Elapsed [{} s], {}".format(
                step + 1, self.total_step, elapsed[:-7], losses))

    def return_loss_array(self, lossG, lossD):
        """Transform the network outputs into loss arrays."""
        resG = []
        Glist = ['G_loss', 'G_fake_loss', 'G_fake_cls_loss', 'G_rec_loss']
        dict_G = {'G_loss': 0., 'G_fake_loss': 0., 'G_fake_cls_loss': 0., 'G_rec_loss': 0.}
        self.G_loss.append(float(lossG[2].asnumpy()))
        for i, item in enumerate(lossG[2:]):
            resG.append(float(item.asnumpy()))
            dict_G[Glist[i]] = Tensor(float(item.asnumpy()))

        resD = []
        Dlist = ['Dloss', 'D_real_loss', 'D_fake_loss', 'D_real_cls_loss', 'D_gp_loss']
        dict_D = {'Dloss': 0., 'D_real_loss': 0., 'D_fake_loss': 0., 'D_real_cls_loss': 0., 'D_gp_loss': 0.}
        self.D_loss.append(float(lossD[1].asnumpy()))
        for i, item in enumerate(lossD[1:]):
            resD.append(float(item.asnumpy()))
            dict_D[Dlist[i]] = Tensor(float(item.asnumpy()))

        return resG, resD, dict_G, dict_D

    def lr_decay_info(self, step, G_lr, D_lr):
        """print log after learning rate decay"""
        print('Decayed learning rates in step {}, g_lr: {}, d_lr: {}.'.format(step, G_lr, D_lr))

    def epoch_end(self):
        """Print the epoch summary log when an epoch ends."""
        epoch_cost = (time.time() - self.epoch_start_time) * 1000
        per_step_time = epoch_cost / self.step
        mean_loss_G = sum(self.G_loss) / self.step
        mean_loss_D = sum(self.D_loss) / self.step
        self.info("Epoch [{}] total cost: {:.2f} ms, per step: {:.2f} ms, G_loss: {:.2f}, D_loss: {:.2f}".format(
            self.epoch, epoch_cost, per_step_time, mean_loss_G, mean_loss_D))

View File

@ -0,0 +1,146 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset distributed sampler."""
from __future__ import division
import os
import math
import numpy as np
from mindspore import load_checkpoint
from mindspore import Tensor
from mindspore import dtype as mstype
from src.cell import init_weights
from src.model import Generator, Discriminator

class DistributedSampler:
    """Distributed sampler."""
    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=False):
        if num_replicas is None:
            print("***********Setting world_size to 1 since it is not passed in ******************")
            num_replicas = 1
        if rank is None:
            print("***********Setting rank to 0 since it is not passed in ******************")
            rank = 0

        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.epoch = 0
        self.rank = rank
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # shuffle based on epoch
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            indices = indices.tolist()
            self.epoch += 1
        else:
            indices = list(range(self.dataset_size))

        # pad with the leading indices so that every replica gets the same number of samples
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank: self.total_size: self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

def resume_model(config, G, D):
    """Restore the trained generator; the discriminator is returned unchanged."""
    print('Loading the trained models from step {}...'.format(config.resume_iters))
    G_path = os.path.join(config.model_save_dir, "Generator_2-0_%d.ckpt" % config.resume_iters)
    # D_path = os.path.join(config.model_save_dir, "Net_D_%d.ckpt" % config.resume_iters)
    param_G = load_checkpoint(G_path, G)
    # param_D = load_checkpoint(D_path, D)

    return param_G, D

def print_network(model, name):
    """Print out the network information."""
    num_params = 0
    for p in model.trainable_params():
        num_params += np.prod(p.shape)
    print(model)
    print(name)
    print('The number of parameters: {}'.format(num_params))

def get_lr(init_lr, total_step, update_step, num_iters_decay):
    """Build the per-step learning rate array.

    The rate stays at init_lr, then drops by init_lr / num_iters_decay every
    update_step steps during the last num_iters_decay steps.
    """
    lr_each_step = []
    lr = init_lr
    for i in range(total_step):
        if (i+1) % update_step == 0 and (i+1) > total_step-num_iters_decay:
            lr = lr - (init_lr / float(num_iters_decay))
        if lr < 0:
            lr = 1e-6
        lr_each_step.append(lr)
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step

def get_network(args):
    """Create and initialize a generator and a discriminator."""
    G = Generator(args.g_conv_dim, args.c_dim, args.g_repeat_num)
    D = Discriminator(args.image_size, args.d_conv_dim, args.c_dim, args.d_repeat_num)

    init_weights(G, 'KaimingUniform', math.sqrt(5))
    init_weights(D, 'KaimingUniform', math.sqrt(5))

    print_network(G, 'Generator')
    print_network(D, 'Discriminator')

    return G, D

def create_labels(c_org, c_dim=5, selected_attrs=None):
    """Generate target domain labels for debugging and testing"""
    # Get hair color indices.
    hair_color_indices = []
    for i, attr_name in enumerate(selected_attrs):
        if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
            hair_color_indices.append(i)

    c_trg_list = []
    for i in range(c_dim):
        c_trg = c_org.copy()
        if i in hair_color_indices:
            # Hair colors are mutually exclusive: set one and clear the others.
            c_trg[:, i] = 1
            for j in hair_color_indices:
                if j != i:
                    c_trg[:, j] = 0
        else:
            # Flip the binary attribute.
            c_trg[:, i] = (c_trg[:, i] == 0)
        c_trg_list.append(c_trg)

    c_trg_list = Tensor(c_trg_list, mstype.float16)
    return c_trg_list

def denorm(x):
    """Convert a normalized NCHW batch in [-1, 1] to NHWC pixel values in [0, 255]."""
    image_numpy = (np.transpose(x, (0, 2, 3, 1)) + 1) / 2.0 * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    return image_numpy

View File

@ -0,0 +1,219 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train the model."""
from time import time
import os
import argparse
import ast
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore import Tensor, context
from mindspore.common import set_seed
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext
from src.dataset import dataloader
from src.config import get_config
from src.utils import get_network
from src.cell import TrainOneStepCellGen, TrainOneStepCellDis
from src.loss import GeneratorLoss, DiscriminatorLoss, ClassificationLoss, WGANGPGradientPenalty
from src.reporter import Reporter
set_seed(1)

# ModelArts and distributed-training arguments
parser = argparse.ArgumentParser(description='StarGAN_args')
parser.add_argument('--modelarts', type=ast.literal_eval, default=False, help='Whether to run on ModelArts, default: False.')
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.")
parser.add_argument("--device_num", type=int, default=1, help="number of devices, default: 1.")
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
args_opt = parser.parse_args()

if __name__ == '__main__':
    config = get_config()
    if args_opt.modelarts:
        import moxing as mox
        device_id = int(os.getenv('DEVICE_ID'))
        device_num = int(os.getenv('RANK_SIZE'))
        context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
        context.set_context(device_id=device_id)
        local_data_url = './cache/data'
        local_train_url = '/cache/ckpt'

        local_data_url = os.path.join(local_data_url, str(device_id))
        local_train_url = os.path.join(local_train_url, str(device_id))

        # Copy the dataset archive from OBS and unzip it locally
        path = os.getcwd()
        print("cwd: %s" % path)
        data_url = 'obs://hit-wcy/data/CelebA/'

        data_name = '/celeba.zip'
        print('listdir1: %s' % os.listdir('./'))

        a1time = time()
        mox.file.copy_parallel(data_url, local_data_url)
        print('listdir2: %s' % os.listdir(local_data_url))
        b1time = time()
        print('time1:', b1time - a1time)

        a2time = time()
        zip_command = "unzip -o %s -d %s" % (local_data_url + data_name, local_data_url)
        if os.system(zip_command) == 0:
            print('Unzip succeeded')
        else:
            print('Unzip FAILED')
        b2time = time()
        print('time2:', b2time - a2time)
        print('listdir3: %s' % os.listdir(local_data_url))

        # Device Environment
        if config.run_distribute:
            if config.device_target == "Ascend":
                rank = device_id
                # device_num = device_num
                context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                                  gradients_mean=True)
                init()
        else:
            rank = 0
            device_num = 1

        data_path = local_data_url + '/celeba/images'
        attr_path = local_data_url + '/celeba/list_attr_celeba.txt'
        dataset, length = dataloader(img_path=data_path,
                                     attr_path=attr_path,
                                     batch_size=config.batch_size,
                                     selected_attr=config.selected_attrs,
                                     device_num=config.num_workers,
                                     dataset=config.dataset,
                                     mode=config.mode,
                                     shuffle=True)
    else:
        context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target,
                            device_id=config.device_id, save_graphs=False)
        rank = 0
        if args_opt.run_distribute:
            if os.getenv("DEVICE_ID", "not_set").isdigit():
                context.set_context(device_id=int(os.getenv("DEVICE_ID")))
            device_num = config.device_num
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                              device_num=device_num)
            init()
            rank = get_rank()

        data_path = config.celeba_image_dir
        attr_path = config.attr_path
        local_train_url = config.model_save_dir
        # Pass the rank so that each device samples its own shard of the dataset.
        dataset, length = dataloader(img_path=data_path,
                                     attr_path=attr_path,
                                     batch_size=config.batch_size,
                                     selected_attr=config.selected_attrs,
                                     device_num=config.device_num,
                                     rank=rank,
                                     dataset=config.dataset,
                                     mode=config.mode,
                                     shuffle=True)
    print(length)
    dataset_iter = dataset.create_dict_iterator()

    # Create and initialize the networks
    generator, discriminator = get_network(config)

    cls_loss = ClassificationLoss()
    wgan_loss = WGANGPGradientPenalty(discriminator)

    # Define network with loss
    G_loss_cell = GeneratorLoss(config, generator, discriminator)
    D_loss_cell = DiscriminatorLoss(config, generator, discriminator)

    # Define Optimizer
    star_iter = 0
    iter_sum = config.num_iters

    Optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=config.g_lr,
                          beta1=config.beta1, beta2=config.beta2)
    Optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=config.d_lr,
                          beta1=config.beta1, beta2=config.beta2)

    # Define One step train
    G_trainOneStep = TrainOneStepCellGen(G_loss_cell, Optimizer_G)
    D_trainOneStep = TrainOneStepCellDis(D_loss_cell, Optimizer_D)

    # Train
    G_trainOneStep.set_train()
    D_trainOneStep.set_train()

    print('Start Training')

    reporter = Reporter(config)

    ckpt_config = CheckpointConfig(save_checkpoint_steps=config.model_save_step)
    ckpt_cb_g = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='Generator')
    ckpt_cb_d = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='Discriminator')

    cb_params_g = _InternalCallbackParam()
    cb_params_g.train_network = generator
    cb_params_g.cur_step_num = 0
    cb_params_g.batch_num = 4
    cb_params_g.cur_epoch_num = 0

    cb_params_d = _InternalCallbackParam()
    cb_params_d.train_network = discriminator
    cb_params_d.cur_step_num = 0
    cb_params_d.batch_num = config.batch_size
    cb_params_d.cur_epoch_num = 0

    run_context_g = RunContext(cb_params_g)
    run_context_d = RunContext(cb_params_d)
    ckpt_cb_g.begin(run_context_g)
    ckpt_cb_d.begin(run_context_d)
    start = time()

    for iterator in range(config.num_iters):
        data = next(dataset_iter)
        x_real = Tensor(data['image'], mstype.float32)
        c_trg = Tensor(data['attr'], mstype.float32)
        c_org = Tensor(data['attr'], mstype.float32)
        # Shuffle the labels across the batch to obtain the target domain labels.
        np.random.shuffle(c_trg)

        # The discriminator is updated every iteration.
        d_out = D_trainOneStep(x_real, c_org, c_trg)

        # The generator is updated once every n_critic iterations.
        if (iterator + 1) % config.n_critic == 0:
            g_out = G_trainOneStep(x_real, c_org, c_trg)

        if (iterator + 1) % config.log_step == 0:
            reporter.print_info(start, iterator, g_out, d_out)
            _, _, dict_G, dict_D = reporter.return_loss_array(g_out, d_out)

        if (iterator + 1) % config.model_save_step == 0:
            cb_params_d.cur_step_num = iterator + 1
            cb_params_d.batch_num = iterator + 2
            cb_params_g.cur_step_num = iterator + 1
            cb_params_g.batch_num = iterator + 2
            ckpt_cb_g.step_end(run_context_g)
            ckpt_cb_d.step_end(run_context_d)

    if args_opt.modelarts:
        mox.file.copy_parallel(local_train_url, config.train_url)