parent
4e58f833be
commit
72b9b5125c
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,136 @@
|
|||
# Contents
|
||||
|
||||
- [StarGAN Description](#StarGAN-description)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Training Process](#training-process)
|
||||
- [Prediction Process](#prediction-process)
|
||||
- [Export MindIR](#export-mindir)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [Inference Performance](#evaluation-performance)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [StarGAN-description](#contents)
|
||||
|
||||
> **StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation**<br>
|
||||
> [Yunjey Choi](https://github.com/yunjey)<sup>1,2</sup>, [Minje Choi](https://github.com/mjc92)<sup>1,2</sup>, [Munyoung Kim](https://www.facebook.com/munyoung.kim.1291)<sup>2,3</sup>, [Jung-Woo Ha](https://www.facebook.com/jungwoo.ha.921)<sup>2</sup>, [Sung Kim](https://www.cse.ust.hk/~hunkim/)<sup>2,4</sup>, [Jaegul Choo](https://sites.google.com/site/jaegulchoo/)<sup>1,2</sup> <br/>
|
||||
> <sup>1</sup>Korea University, <sup>2</sup>Clova AI Research, NAVER Corp. <br>
|
||||
> <sup>3</sup>The College of New Jersey, <sup>4</sup>Hong Kong University of Science and Technology <br/>
|
||||
> https://arxiv.org/abs/1711.09020 <br>
|
||||
>
|
||||
> **Abstract:** *Recent studies have shown remarkable success in image-to-image translation for two domains. However, existing approaches have limited scalability and robustness in handling more than two domains, since different models should be built independently for every pair of image domains. To address this limitation, we propose StarGAN, a novel and scalable approach that can perform image-to-image translations for multiple domains using only a single model. Such a unified model architecture of StarGAN allows simultaneous training of multiple datasets with different domains within a single network. This leads to StarGAN's superior quality of translated images compared to existing models as well as the novel capability of flexibly translating an input image to any desired target domain. We empirically demonstrate the effectiveness of our approach on a facial attribute transfer and a facial expression synthesis tasks.*
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
|
||||
|
||||
Dataset used: [CelebA](<http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>)
|
||||
|
||||
CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset with more than 200K celebrity images, each with 40 attribute annotations. The images in this dataset cover large pose variations and background clutter. CelebA has large diversities, large quantities, and rich annotations, including
|
||||
|
||||
- 10,177 number of identities,
|
||||
|
||||
- 202,599 number of face images, and 5 landmark locations, 40 binary attributes annotations per image.
|
||||
|
||||
The dataset can be employed as the training and test sets for the following computer vision tasks: face attribute recognition, face detection, landmark (or facial part) localization, and face editing & synthesis.
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware Ascend
|
||||
- Prepare hardware environment with Ascend processor.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
```text
|
||||
.
|
||||
└─ cv
|
||||
└─ StarGAN
|
||||
├─ src
|
||||
├─ __init__.py # init file
|
||||
├─ cell.py # StarGAN model define
|
||||
├─ model.py # define subnetwork about generator and discriminator
|
||||
├─ utils.py # utils for StarGAN
|
||||
├─ config.py # parse args
|
||||
├─ dataset.py # prepare celebA dataset to cyclegan format
|
||||
├─ reporter.py # Reporter class
|
||||
├─ loss.py # losses for StarGAN
|
||||
├─ cityscape_eval.py # cityscape dataset eval script
|
||||
├─ eval.py # translate attritubes from original images
|
||||
├─ train.py # train script
|
||||
├─ export.py # export mindir script
|
||||
└─ README.md # descriptions about StarGAN
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
When training the network, you should selected the attributes in config, then you should change the c_dim in config which is same as the number of selected attributes.
|
||||
|
||||
```bash
|
||||
python train.py
|
||||
```
|
||||
|
||||
## [Prediction Process](#contents)
|
||||
|
||||
```bash
|
||||
python eval.py
|
||||
```
|
||||
|
||||
**Note: the result will saved at `./results/`.**
|
||||
|
||||
## [Export MindIR](#contents)
|
||||
|
||||
```bash
|
||||
python export.py
|
||||
```
|
||||
|
||||
**Note: The file_name parameter is the prefix, the final file will as StarGAN_G.[FILE_FORMAT].**
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
### Evaluation Performance
|
||||
|
||||
| Parameters | Ascend 910 |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| Model Version | StarGAN |
|
||||
| Resource | Ascend |
|
||||
| uploaded Date | 03/30/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.1.1 |
|
||||
| Dataset | CelebA |
|
||||
| Training Parameters | steps=200000, batch_size=1, lr=0.0001 |
|
||||
| Optimizer | Adam |
|
||||
| outputs | probability |
|
||||
| Speed | 1pc: 100 ms/step; |
|
||||
| Total time | 1pc: 10h; |
|
||||
| Parameters (M) | 8.423 M |
|
||||
| Checkpoint for Fine tuning | 32.15M (.ckpt file) |
|
||||
| Scripts | [StarGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/StarGAN) |
|
||||
|
||||
### Inference Performance
|
||||
|
||||
| Parameters | Ascend 910 |
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | StarGAN |
|
||||
| Resource | Ascend |
|
||||
| Uploaded Date | 03/30/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.1.1 |
|
||||
| Dataset | CelebA |
|
||||
| batch_size | 4 |
|
||||
| outputs | image |
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation for StarGAN"""
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_param_into_net
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.ops as ops
|
||||
|
||||
from src.utils import resume_model, create_labels, denorm, get_network
|
||||
from src.config import get_config
|
||||
from src.dataset import dataloader
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=1)
|
||||
config = get_config()
|
||||
G, D = get_network(config)
|
||||
para_g, _ = resume_model(config, G, D)
|
||||
load_param_into_net(G, para_g)
|
||||
|
||||
if not os.path.exists(config.result_dir):
|
||||
os.mkdir(config.result_dir)
|
||||
# Define Dataset
|
||||
|
||||
data_path = config.celeba_image_dir
|
||||
attr_path = config.attr_path
|
||||
|
||||
dataset, length = dataloader(img_path=data_path,
|
||||
attr_path=attr_path,
|
||||
batch_size=4,
|
||||
selected_attr=config.selected_attrs,
|
||||
device_num=config.num_workers,
|
||||
dataset=config.dataset,
|
||||
mode=config.mode,
|
||||
shuffle=False)
|
||||
|
||||
op = ops.Concat(axis=3)
|
||||
ds = dataset.create_dict_iterator()
|
||||
print(length)
|
||||
print('Start Evaluating!')
|
||||
for i, data in enumerate(ds):
|
||||
result_list = ()
|
||||
img_real = denorm(data['image'].asnumpy())
|
||||
x_real = Tensor(data['image'], mstype.float32)
|
||||
result_list += (x_real,)
|
||||
c_trg_list = create_labels(data['attr'].asnumpy(), selected_attrs=config.selected_attrs)
|
||||
c_trg_list = Tensor(c_trg_list, mstype.float32)
|
||||
x_fake_list = []
|
||||
|
||||
for c_trg in c_trg_list:
|
||||
|
||||
x_fake = G(x_real, c_trg)
|
||||
x = Tensor(x_fake.asnumpy().copy())
|
||||
|
||||
result_list += (x,)
|
||||
|
||||
x_fake_list = op(result_list)
|
||||
|
||||
result = denorm(x_fake_list.asnumpy())
|
||||
result = np.reshape(result, (-1, 768, 3))
|
||||
|
||||
im = Image.fromarray(np.uint8(result))
|
||||
im.save(config.result_dir + '/test_{}.jpg'.format(i))
|
||||
print('Successful save image in ' + config.result_dir + '/test_{}.jpg'.format(i))
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export file."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import export, load_param_into_net
|
||||
from src.config import get_config
|
||||
from src.utils import get_network, resume_model
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
config = get_config()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
|
||||
G, D = get_network(config)
|
||||
|
||||
# Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d
|
||||
# Use real mean and varance rather than moving_men and moving_varance in BatchNorm2d
|
||||
|
||||
G.set_train(True)
|
||||
param_G, _ = resume_model(config, G, D)
|
||||
load_param_into_net(G, param_G)
|
||||
|
||||
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32))
|
||||
input_label = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 5)).astype(np.float32))
|
||||
G_file = f"StarGAN_Generator"
|
||||
export(G, input_array, file_name=G_file, file_format=config.file_format)
|
|
@ -0,0 +1,25 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export MODE='test'
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python eval.py --run_distribute=0 --device_num=$DEVICE_NUM --device_id=$DEVICE_ID --mode=$MODE> log_eval.txt 2>&1 &
|
||||
|
||||
cd ..
|
|
@ -0,0 +1,50 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [DISTRIBUTE] [RANK_TABLE_FILE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
|
||||
|
||||
export RANK_SIZE=$1
|
||||
DISTRIBUTE=$2
|
||||
export RANK_TABLE_FILE=$3
|
||||
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
rm -rf LOG$i
|
||||
mkdir ./LOG$i
|
||||
cp ./*.json ./LOG$i
|
||||
cp ./*.py ./LOG$i
|
||||
cp -r ./src ./LOG$i
|
||||
cp -r ./scripts ./LOG$i
|
||||
cd ./LOG$i || exit
|
||||
export RANK_ID=$i
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
env > env.log
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
python train.py \
|
||||
--run_distribute=$DISTRIBUTE \
|
||||
--device_num=$RANK_SIZE \
|
||||
--device_id=$DEVICE_ID > log.txt 2>&1 &
|
||||
fi
|
||||
cd ../
|
||||
done
|
|
@ -0,0 +1,25 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --run_distribute=0 --device_num=$DEVICE_NUM --device_id=$DEVICE_ID > log.txt 2>&1 &
|
||||
|
||||
cd ..
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""init file."""
|
|
@ -0,0 +1,166 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Train cell for StarGAN"""
|
||||
import numpy as np
|
||||
from mindspore import nn
|
||||
from mindspore import context
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.ops.functional as F
|
||||
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
|
||||
_get_parallel_mode)
|
||||
import mindspore.ops as ops
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.common import initializer as init, set_seed
|
||||
set_seed(1)
|
||||
np.random.seed(1)
|
||||
|
||||
def init_weights(net, init_type='normal', init_gain=0.02):
|
||||
"""
|
||||
Initialize network weights.
|
||||
|
||||
Parameters:
|
||||
net (Cell): Network to be initialized
|
||||
init_type (str): The name of an initialization method: normal | xavier.
|
||||
init_gain (float): Gain factor for normal and xavier.
|
||||
|
||||
"""
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
|
||||
if init_type == 'normal':
|
||||
cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
|
||||
elif init_type == 'xavier':
|
||||
cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
|
||||
elif init_type == 'KaimingUniform':
|
||||
cell.weight.set_data(init.initializer(init.HeUniform(init_gain), cell.weight.shape))
|
||||
elif init_type == 'constant':
|
||||
cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
|
||||
else:
|
||||
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
||||
elif isinstance(cell, nn.GroupNorm):
|
||||
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
|
||||
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
|
||||
|
||||
|
||||
class GeneratorWithLossCell(nn.Cell):
|
||||
"""
|
||||
Wrap the network with loss function to return generator loss.
|
||||
|
||||
Args:
|
||||
network (Cell): The target network to wrap.
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(GeneratorWithLossCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
|
||||
def construct(self, x_real, c_org, c_trg):
|
||||
_, G_Loss, _, _, _, = self.network(x_real, c_org, c_trg)
|
||||
return G_Loss
|
||||
|
||||
|
||||
class DiscriminatorWithLossCell(nn.Cell):
|
||||
"""
|
||||
Wrap the network with loss function to return generator loss.
|
||||
|
||||
Args:
|
||||
network (Cell): The target network to wrap.
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(DiscriminatorWithLossCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
|
||||
def construct(self, x_real, c_org, c_trg):
|
||||
D_Loss, _, _, _, _ = self.network(x_real, c_org, c_trg)
|
||||
return D_Loss
|
||||
|
||||
|
||||
class TrainOneStepCellGen(nn.Cell):
|
||||
"""Encapsulation class of StarGAN generator network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph."""
|
||||
def __init__(self, G, optimizer, sens=1.0):
|
||||
super(TrainOneStepCellGen, self).__init__()
|
||||
self.optimizer = optimizer
|
||||
self.G = G
|
||||
self.G.set_grad()
|
||||
self.G.set_train()
|
||||
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
self.weights = optimizer.parameters
|
||||
self.network = GeneratorWithLossCell(G)
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.sens = sens
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer = F.identity
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
|
||||
|
||||
def construct(self, img_real, c_org, c_trg):
|
||||
weights = self.weights
|
||||
fake_image, loss, G_fake_loss, G_fake_cls_loss, G_rec_loss = self.G(img_real, c_org, c_trg)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(img_real, c_org, c_trg, sens)
|
||||
grads = self.grad_reducer(grads)
|
||||
|
||||
return F.depend(loss, self.optimizer(grads)), fake_image, loss, G_fake_loss, G_fake_cls_loss, G_rec_loss
|
||||
|
||||
|
||||
class TrainOneStepCellDis(nn.Cell):
|
||||
"""Encapsulation class of StarGAN Discriminator network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph."""
|
||||
def __init__(self, D, optimizer, sens=1.0):
|
||||
super(TrainOneStepCellDis, self).__init__()
|
||||
self.optimizer = optimizer
|
||||
self.D = D
|
||||
self.D.set_grad()
|
||||
self.D.set_train()
|
||||
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
self.weights = optimizer.parameters
|
||||
self.network = DiscriminatorWithLossCell(D)
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.sens = sens
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer = F.identity
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = context.get_auto_parallel_context("gradients_mean")
|
||||
if auto_parallel_context().get_device_num_is_set():
|
||||
degree = context.get_auto_parallel_context("device_num")
|
||||
else:
|
||||
degree = get_group_size()
|
||||
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
|
||||
def construct(self, img_real, c_org, c_trg):
|
||||
weights = self.weights
|
||||
|
||||
loss, D_real_loss, D_fake_loss, D_real_cls_loss, D_gp_loss = self.D(img_real, c_org, c_trg)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(img_real, c_org, c_trg, sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
|
||||
return F.depend(loss, self.optimizer(grads)), loss, D_real_loss, D_fake_loss, D_real_cls_loss, D_gp_loss
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Define Configuration for StarGAN"""
|
||||
import argparse
|
||||
|
||||
|
||||
def get_config():
|
||||
"""Define configuration of Model"""
|
||||
parser = argparse.ArgumentParser(description='StarGAN')
|
||||
|
||||
# Model configuration.
|
||||
parser.add_argument('--c_dim', type=int, default=5, help='dimension of domain labels (1st dataset)')
|
||||
parser.add_argument('--c2_dim', type=int, default=7, help='dimension of domain labels (2nd dataset)')
|
||||
parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset')
|
||||
parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset')
|
||||
parser.add_argument('--image_size', type=int, default=128, help='image resolution')
|
||||
parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G')
|
||||
parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
|
||||
parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G')
|
||||
parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D')
|
||||
parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
|
||||
parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
|
||||
parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
|
||||
|
||||
# Training configuration.
|
||||
parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'Both'])
|
||||
parser.add_argument('--batch_size', type=int, default=4, help='mini-batch size')
|
||||
parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
|
||||
parser.add_argument('--epochs', type=int, default=59, help='number of epoch')
|
||||
parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
|
||||
parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G')
|
||||
parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
|
||||
parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')
|
||||
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
|
||||
parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
|
||||
parser.add_argument('--resume_iters', type=int, default=200000, help='resume training from this step')
|
||||
parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
|
||||
default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
|
||||
parser.add_argument('--init_type', type=str, default='normal', choices=("normal", "xavier"),
|
||||
help='network initialization, default is normal.')
|
||||
parser.add_argument('--init_gain', type=float, default=0.02,
|
||||
help='scaling factor for normal, xavier and orthogonal, default is 0.02.')
|
||||
|
||||
# Test configuration.
|
||||
parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')
|
||||
|
||||
# Train Device.
|
||||
parser.add_argument('--num_workers', type=int, default=8)
|
||||
parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
|
||||
parser.add_argument('--device_target', type=str, default='Ascend')
|
||||
parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.")
|
||||
parser.add_argument("--device_num", type=int, default=1, help="number of device, default: 0.")
|
||||
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
|
||||
|
||||
|
||||
# Directories.
|
||||
parser.add_argument('--celeba_image_dir', type=str, default=r'/root/wcy/StarGAN_copy/celeba/images')
|
||||
parser.add_argument('--attr_path', type=str, default=r'/root/wcy/StarGAN_copy/celeba/list_attr_celeba.txt')
|
||||
parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train')
|
||||
parser.add_argument('--log_dir', type=str, default='stargan/logs')
|
||||
parser.add_argument('--model_save_dir', type=str, default='./models/')
|
||||
parser.add_argument('--result_dir', type=str, default='./results')
|
||||
|
||||
# Step size.
|
||||
parser.add_argument('--log_step', type=int, default=10)
|
||||
parser.add_argument('--sample_step', type=int, default=5000)
|
||||
parser.add_argument('--model_save_step', type=int, default=5000)
|
||||
parser.add_argument('--lr_update_step', type=int, default=1000)
|
||||
|
||||
# export
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', \
|
||||
help='file format')
|
||||
|
||||
|
||||
config = parser.parse_args()
|
||||
|
||||
return config
|
|
@ -0,0 +1,198 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Data Processing for StarGAN"""
|
||||
import os
|
||||
import random
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import mindspore.dataset.vision.py_transforms as py_vision
|
||||
import mindspore.dataset.transforms.py_transforms as py_transforms
|
||||
import mindspore.dataset as de
|
||||
|
||||
from src.utils import DistributedSampler
|
||||
|
||||
|
||||
def is_image_file(filename):
|
||||
"""Judge whether it is an image"""
|
||||
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']
|
||||
return any(filename.lower().endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
|
||||
def make_dataset(dir_path, max_dataset_size=float("inf")):
|
||||
"""Return image list in dir"""
|
||||
images = []
|
||||
assert os.path.isdir(dir_path), "%s is not a valid directory" % dir_path
|
||||
|
||||
for root, _, fnames in sorted(os.walk(dir_path)):
|
||||
for fname in fnames:
|
||||
if is_image_file(fname):
|
||||
path = os.path.join(root, fname)
|
||||
images.append(path)
|
||||
return images[:min(max_dataset_size, len(images))]
|
||||
|
||||
|
||||
class CelebA:
|
||||
"""
|
||||
This dataset class helps load celebA dataset.
|
||||
"""
|
||||
def __init__(self, image_dir, attr_path, selected_attrs, transform, mode):
|
||||
"""Initialize and preprocess the CelebA dataset."""
|
||||
self.image_dir = image_dir
|
||||
self.attr_path = attr_path
|
||||
self.selected_attrs = selected_attrs
|
||||
self.transform = transform
|
||||
self.mode = mode
|
||||
self.train_dataset = []
|
||||
self.test_dataset = []
|
||||
self.attr2idx = {}
|
||||
self.idx2attr = {}
|
||||
self.preprocess()
|
||||
|
||||
if mode == 'train':
|
||||
self.num_images = len(self.train_dataset)
|
||||
else:
|
||||
self.num_images = len(self.test_dataset)
|
||||
|
||||
def preprocess(self):
|
||||
"""Preprocess the CelebA attribute file."""
|
||||
lines = [line.rstrip() for line in open(self.attr_path, 'r')]
|
||||
all_attr_names = lines[1].split()
|
||||
for i, attr_name in enumerate(all_attr_names):
|
||||
self.attr2idx[attr_name] = i
|
||||
self.idx2attr[i] = attr_name
|
||||
|
||||
lines = lines[2:]
|
||||
random.seed(1234)
|
||||
random.shuffle(lines)
|
||||
for i, line in enumerate(lines):
|
||||
split = line.split()
|
||||
filename = split[0]
|
||||
values = split[1:]
|
||||
|
||||
label = []
|
||||
for attr_name in self.selected_attrs:
|
||||
idx = self.attr2idx[attr_name]
|
||||
label.append(1.0 if values[idx] == '1' else 0.0)
|
||||
|
||||
if (i+1) < 2000:
|
||||
self.test_dataset.append([filename, label])
|
||||
else:
|
||||
self.train_dataset.append([filename, label])
|
||||
|
||||
print('Finished preprocessing the CelebA dataset...')
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Return one image and its corresponding attribute label."""
|
||||
dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
|
||||
filename, label = dataset[idx]
|
||||
image = np.asarray(Image.open(os.path.join(self.image_dir, filename)))
|
||||
label = np.asarray(label)
|
||||
image = np.squeeze(self.transform(image))
|
||||
# image = Tensor(image, mstype.float32)
|
||||
# label = Tensor(label, mstype.float32)
|
||||
|
||||
return image, label
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of images."""
|
||||
return self.num_images
|
||||
|
||||
|
||||
class ImageFolderDataset:
|
||||
"""
|
||||
This dataset class can load images from image folder.
|
||||
|
||||
Args:
|
||||
data_root (str): Images root directory.
|
||||
max_dataset_size (int): Maximum number of return image paths.
|
||||
|
||||
Returns:
|
||||
Image path list.
|
||||
"""
|
||||
|
||||
def __init__(self, data_root, transform, max_dataset_size=float("inf")):
|
||||
self.data_root = data_root
|
||||
self.transform = transform
|
||||
self.paths = sorted(make_dataset(data_root, max_dataset_size))
|
||||
self.size = len(self.paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_path = self.paths[index % self.size]
|
||||
# image = np.array(Image.open(img_path).convert('RGB'))
|
||||
image = np.asarray(Image.open(img_path))
|
||||
return np.squeeze(self.transform(image)), os.path.split(img_path)[1]
|
||||
# return image, os.path.split(img_path)[1]
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
|
||||
def get_loader(data_root, attr_path, selected_attrs, crop_size=178, image_size=128,
|
||||
dataset='CelebA', mode='train'):
|
||||
"""Build and return a data loader."""
|
||||
mean = [0.5, 0.5, 0.5]
|
||||
std = [0.5, 0.5, 0.5]
|
||||
transform = [py_vision.ToPIL()]
|
||||
if mode == 'train':
|
||||
transform.append(py_vision.RandomHorizontalFlip())
|
||||
transform.append(py_vision.CenterCrop(crop_size))
|
||||
transform.append(py_vision.Resize(image_size))
|
||||
transform.append(py_vision.ToTensor())
|
||||
transform.append(py_vision.Normalize(mean=mean, std=std))
|
||||
transform = py_transforms.Compose(transform)
|
||||
|
||||
if dataset == 'CelebA':
|
||||
dataset = CelebA(data_root, attr_path, selected_attrs, transform, mode)
|
||||
|
||||
elif dataset == 'RaFD':
|
||||
dataset = ImageFolderDataset(data_root, transform)
|
||||
return dataset
|
||||
|
||||
|
||||
def dataloader(img_path, attr_path, selected_attr, dataset, mode='train',
|
||||
batch_size=1, device_num=1, rank=0, shuffle=True):
|
||||
"""Get dataloader"""
|
||||
assert dataset in ['CelebA', 'RaFD']
|
||||
|
||||
cores = multiprocessing.cpu_count()
|
||||
num_parallel_workers = int(cores / device_num)
|
||||
|
||||
if dataset == 'CelebA':
|
||||
dataset_loader = get_loader(img_path, attr_path, selected_attr, mode=mode)
|
||||
length_dataset = len(dataset_loader)
|
||||
distributed_sampler = DistributedSampler(length_dataset, device_num, rank, shuffle=shuffle)
|
||||
dataset_column_names = ["image", "attr"]
|
||||
|
||||
else:
|
||||
dataset_loader = get_loader(img_path, None, None, dataset='RaFD')
|
||||
length_dataset = len(dataset_loader)
|
||||
distributed_sampler = DistributedSampler(length_dataset, device_num, rank, shuffle=shuffle)
|
||||
dataset_column_names = ["image", "image_path"]
|
||||
|
||||
if device_num != 8:
|
||||
ds = de.GeneratorDataset(dataset_loader, column_names=dataset_column_names,
|
||||
num_parallel_workers=min(32, num_parallel_workers),
|
||||
sampler=distributed_sampler)
|
||||
ds = ds.batch(batch_size, num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True)
|
||||
|
||||
else:
|
||||
ds = de.GeneratorDataset(dataset_loader, column_names=dataset_column_names, sampler=distributed_sampler)
|
||||
ds = ds.batch(batch_size, num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True)
|
||||
if mode == 'train':
|
||||
ds = ds.repeat(200)
|
||||
|
||||
return ds, length_dataset
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Define loss function for StarGAN"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.ops as ops
|
||||
import mindspore.ops.operations as P
|
||||
|
||||
from mindspore import nn, Tensor
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import constexpr
|
||||
|
||||
|
||||
@constexpr
|
||||
def generate_tensor(batch_size):
|
||||
np_array = np.random.randn(batch_size, 1, 1, 1)
|
||||
return Tensor(np_array, mstype.float32)
|
||||
|
||||
|
||||
class ClassificationLoss(nn.Cell):
|
||||
"""Define classification loss for StarGAN"""
|
||||
def __init__(self, dataset='CelebA'):
|
||||
super().__init__()
|
||||
self.BCELoss = P.BinaryCrossEntropy(reduction='sum')
|
||||
self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
|
||||
self.dataset = dataset
|
||||
self.bec = nn.BCELoss(reduction='sum')
|
||||
|
||||
def construct(self, pred, label):
|
||||
if self.dataset == 'CelebA':
|
||||
weight = ops.Ones()(pred.shape, mstype.float32)
|
||||
pred_ = P.Sigmoid()(pred)
|
||||
x = self.BCELoss(pred_, label, weight) / pred.shape[0]
|
||||
|
||||
else:
|
||||
x = self.cross_entropy(pred, label)
|
||||
return x
|
||||
|
||||
|
||||
class GradientWithInput(nn.Cell):
|
||||
"""Get Discriminator Gradient with Input"""
|
||||
def __init__(self, discrimator):
|
||||
super(GradientWithInput, self).__init__()
|
||||
self.reduce_sum = ops.ReduceSum()
|
||||
self.discrimator = discrimator
|
||||
|
||||
def construct(self, interpolates):
|
||||
decision_interpolate, _ = self.discrimator(interpolates)
|
||||
decision_interpolate = self.reduce_sum(decision_interpolate, 0)
|
||||
return decision_interpolate
|
||||
|
||||
|
||||
class WGANGPGradientPenalty(nn.Cell):
|
||||
"""Define WGAN loss for StarGAN"""
|
||||
def __init__(self, discrimator):
|
||||
super(WGANGPGradientPenalty, self).__init__()
|
||||
self.gradient_op = ops.GradOperation()
|
||||
|
||||
self.reduce_sum = ops.ReduceSum()
|
||||
self.reduce_sum_keep_dim = ops.ReduceSum(keep_dims=True)
|
||||
self.sqrt = ops.Sqrt()
|
||||
self.discrimator = discrimator
|
||||
self.gradientWithInput = GradientWithInput(discrimator)
|
||||
|
||||
def construct(self, x_real, x_fake):
|
||||
"""get gradient penalty"""
|
||||
batch_size = x_real.shape[0]
|
||||
alpha = generate_tensor(batch_size)
|
||||
alpha = alpha.expand_as(x_real)
|
||||
x_fake = ops.functional.stop_gradient(x_fake)
|
||||
x_hat = (alpha * x_real + (1 - alpha) * x_fake)
|
||||
|
||||
gradient = self.gradient_op(self.gradientWithInput)(x_hat)
|
||||
gradient_1 = ops.reshape(gradient, (batch_size, -1))
|
||||
gradient_1 = self.sqrt(self.reduce_sum(gradient_1*gradient_1, 1))
|
||||
gradient_penalty = self.reduce_sum((gradient_1 - 1.0)**2) / x_real.shape[0]
|
||||
return gradient_penalty
|
||||
|
||||
|
||||
class GeneratorLoss(nn.Cell):
|
||||
"""Define total Generator loss"""
|
||||
def __init__(self, args, generator, discriminator):
|
||||
super(GeneratorLoss, self).__init__()
|
||||
self.net_G = generator
|
||||
self.net_D = discriminator
|
||||
self.cyc_loss = P.ReduceMean()
|
||||
self.rec_loss = nn.L1Loss("mean")
|
||||
self.cls_loss = ClassificationLoss()
|
||||
|
||||
self.lambda_rec = args.lambda_rec
|
||||
self.lambda_cls = args.lambda_cls
|
||||
|
||||
def construct(self, x_real, c_org, c_trg):
|
||||
"""Get generator loss"""
|
||||
# Original to Target
|
||||
x_fake = self.net_G(x_real, c_trg)
|
||||
fake_src, fake_cls = self.net_D(x_fake)
|
||||
|
||||
G_fake_loss = - self.cyc_loss(fake_src)
|
||||
G_fake_cls_loss = self.cls_loss(fake_cls, c_trg)
|
||||
|
||||
# Target to Original
|
||||
x_rec = self.net_G(x_fake, c_org)
|
||||
G_rec_loss = self.rec_loss(x_real, x_rec)
|
||||
|
||||
g_loss = G_fake_loss + self.lambda_cls * G_fake_cls_loss + self.lambda_rec * G_rec_loss
|
||||
|
||||
return (x_fake, g_loss, G_fake_loss, G_fake_cls_loss, G_rec_loss)
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Cell):
|
||||
"""Define total discriminator loss"""
|
||||
def __init__(self, args, generator, discriminator):
|
||||
super(DiscriminatorLoss, self).__init__()
|
||||
self.net_G = generator
|
||||
self.net_D = discriminator
|
||||
self.cyc_loss = P.ReduceMean()
|
||||
self.cls_loss = ClassificationLoss()
|
||||
self.WGANLoss = WGANGPGradientPenalty(discriminator)
|
||||
|
||||
self.lambda_rec = Tensor(args.lambda_rec)
|
||||
self.lambda_cls = Tensor(args.lambda_cls)
|
||||
self.lambda_gp = Tensor(args.lambda_gp)
|
||||
|
||||
def construct(self, x_real, c_org, c_trg):
|
||||
"""Get discriminator loss"""
|
||||
# Compute loss with real images
|
||||
real_src, real_cls = self.net_D(x_real)
|
||||
|
||||
D_real_loss = - self.cyc_loss(real_src)
|
||||
D_real_cls_loss = self.cls_loss(real_cls, c_org)
|
||||
|
||||
# Compute loss with fake images
|
||||
x_fake = self.net_G(x_real, c_trg)
|
||||
fake_src, _ = self.net_D(x_fake)
|
||||
D_fake_loss = self.cyc_loss(fake_src)
|
||||
|
||||
D_gp_loss = self.WGANLoss(x_real, x_fake)
|
||||
|
||||
d_loss = D_real_loss + D_fake_loss + self.lambda_cls * D_real_cls_loss + self.lambda_gp *D_gp_loss
|
||||
|
||||
return (d_loss, D_real_loss, D_fake_loss, D_real_cls_loss, D_gp_loss)
|
|
@ -0,0 +1,142 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Define Generator and Discriminator for StarGAN"""
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as P
|
||||
from mindspore import set_seed, Tensor
|
||||
from mindspore.common import initializer as init
|
||||
|
||||
|
||||
set_seed(1)
|
||||
np.random.seed(1)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Cell):
|
||||
"""Residual Block with instance normalization."""
|
||||
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.main = nn.SequentialCell(
|
||||
nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, pad_mode='pad', has_bias=False),
|
||||
nn.GroupNorm(num_groups=dim_out, num_channels=dim_out),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, pad_mode='pad', has_bias=False),
|
||||
nn.GroupNorm(num_groups=dim_out, num_channels=dim_out)
|
||||
)
|
||||
|
||||
def construct(self, x):
|
||||
return x + self.main(x)
|
||||
|
||||
|
||||
class Generator(nn.Cell):
|
||||
"""Generator network."""
|
||||
|
||||
def __init__(self, conv_dim=64, c_dim=4, repeat_num=6):
|
||||
super(Generator, self).__init__()
|
||||
|
||||
layers = []
|
||||
layers.append((nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1,
|
||||
padding=3, pad_mode='pad', has_bias=False)))
|
||||
layers.append(nn.GroupNorm(num_groups=conv_dim, num_channels=conv_dim))
|
||||
layers.append(nn.ReLU())
|
||||
|
||||
# Down-sampling layers.
|
||||
curr_dim = conv_dim
|
||||
for _ in range(2):
|
||||
layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2,
|
||||
padding=1, pad_mode='pad', has_bias=False))
|
||||
layers.append(nn.GroupNorm(num_groups=curr_dim*2, num_channels=curr_dim*2))
|
||||
layers.append(nn.ReLU())
|
||||
curr_dim = curr_dim*2
|
||||
|
||||
# Bottleneck layers.
|
||||
for _ in range(repeat_num):
|
||||
layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
|
||||
|
||||
# Up-sampling layers.
|
||||
for _ in range(2):
|
||||
layers.append(nn.Conv2dTranspose(curr_dim, int(curr_dim/2), kernel_size=4, stride=2,
|
||||
padding=1, pad_mode='pad', has_bias=False))
|
||||
layers.append(nn.GroupNorm(num_groups=int(curr_dim/2), num_channels=int(curr_dim/2)))
|
||||
layers.append(nn.ReLU())
|
||||
curr_dim = curr_dim // 2
|
||||
|
||||
layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, pad_mode='pad', has_bias=False))
|
||||
layers.append(nn.Tanh())
|
||||
self.main = nn.SequentialCell(*layers)
|
||||
|
||||
def construct(self, x, c):
|
||||
reshape = P.Reshape()
|
||||
c = reshape(c, (c.shape[0], c.shape[1], 1, 1))
|
||||
c = P.functional.reshape(c, (c.shape[0], c.shape[1], 1, 1))
|
||||
tile = P.Tile()
|
||||
c = tile(c, (1, 1, x.shape[2], x.shape[3]))
|
||||
op = P.Concat(1)
|
||||
x = op((x, c))
|
||||
return self.main(x)
|
||||
|
||||
|
||||
class ResidualBlock_2(nn.Cell):
|
||||
"""Residual Block with instance normalization."""
|
||||
|
||||
def __init__(self, weight, dim_in, dim_out):
|
||||
super(ResidualBlock_2, self).__init__()
|
||||
self.main = nn.SequentialCell(
|
||||
nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1,
|
||||
pad_mode='pad', has_bias=False, weight_init=Tensor(weight[0])),
|
||||
nn.GroupNorm(num_groups=dim_out, num_channels=dim_out),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1,
|
||||
pad_mode='pad', has_bias=False, weight_init=Tensor(weight[3])),
|
||||
nn.GroupNorm(num_groups=dim_out, num_channels=dim_out)
|
||||
)
|
||||
|
||||
def construct(self, x):
|
||||
return x + self.main(x)
|
||||
|
||||
|
||||
class Discriminator(nn.Cell):
|
||||
"""Discriminator network with PatchGAN."""
|
||||
|
||||
def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
|
||||
super(Discriminator, self).__init__()
|
||||
layers = []
|
||||
layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1, has_bias=True,
|
||||
pad_mode='pad', bias_init=init.Uniform(1 / math.sqrt(3))))
|
||||
layers.append(nn.LeakyReLU(alpha=0.01))
|
||||
|
||||
curr_dim = conv_dim
|
||||
for _ in range(1, repeat_num):
|
||||
layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, has_bias=True,
|
||||
pad_mode='pad', bias_init=init.Uniform(1 / math.sqrt(curr_dim))))
|
||||
layers.append(nn.LeakyReLU(alpha=0.01))
|
||||
curr_dim = curr_dim * 2
|
||||
|
||||
kernel_size = int(image_size / np.power(2, repeat_num))
|
||||
self.main = nn.SequentialCell(*layers)
|
||||
# Patch GAN输出结果
|
||||
self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, pad_mode='pad', has_bias=False)
|
||||
self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, has_bias=False, pad_mode='valid')
|
||||
|
||||
def construct(self, x):
|
||||
h = self.main(x)
|
||||
out_src = self.conv1(h)
|
||||
out_cls = self.conv2(h)
|
||||
reshape = P.Reshape()
|
||||
out_cls = reshape(out_cls, (out_cls.shape[0], out_cls.shape[1]))
|
||||
return out_src, out_cls
|
|
@ -0,0 +1,94 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Reporter class."""
|
||||
import logging
|
||||
import time
|
||||
import datetime
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class Reporter(logging.Logger):
|
||||
"""
|
||||
This class includes several functions that can save images/checkpoints and print/save logging information.
|
||||
|
||||
Args:
|
||||
args (class): Option class.
|
||||
"""
|
||||
|
||||
def __init__(self, args):
|
||||
super(Reporter, self).__init__("StarGAN")
|
||||
|
||||
self.epoch = 0
|
||||
self.step = 0
|
||||
self.print_iter = 50
|
||||
self.G_loss = []
|
||||
self.D_loss = []
|
||||
self.total_step = args.num_iters
|
||||
self.runs_step = 0
|
||||
|
||||
def epoch_start(self):
|
||||
self.step_start_time = time.time()
|
||||
self.epoch_start_time = time.time()
|
||||
self.step = 0
|
||||
self.epoch += 1
|
||||
self.G_loss = []
|
||||
self.D_loss = []
|
||||
|
||||
def print_info(self, start_time, step, lossG, lossD):
|
||||
"""print log after some steps."""
|
||||
resG, resD, _, _ = self.return_loss_array(lossG, lossD)
|
||||
if self.step % self.print_iter == 0:
|
||||
# step_cost = str(round(float(time.time() - start_time) * 1000 / self.print_iter,2))
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
losses = "D_loss: [{:.3f}], G_loss: [{:.3f}].\nD_real_loss: {:.3f}, " \
|
||||
"D_fake_loss: {:.3f}, D_real_cls_loss: {:.3f}, " \
|
||||
"D_gp_loss: {:.3f}, G_fake_loss: {:.3f}, " \
|
||||
"G_fake_cls_loss: {:.3f}, G_rec_loss: {:.3f}".format(
|
||||
resD[0], resG[0], resD[1], resD[2], resD[3], resD[4], resG[1], resG[2], resG[3])
|
||||
print("Step [{}/{}] Elapsed [{} s], {}".format(
|
||||
step + 1, self.total_step, elapsed[:-7], losses))
|
||||
|
||||
def return_loss_array(self, lossG, lossD):
|
||||
"""Transform output to loooooss array"""
|
||||
resG = []
|
||||
Glist = ['G_loss', 'G_fake_loss', 'G_fake_cls_loss', 'G_rec_loss']
|
||||
dict_G = {'G_loss': 0., 'G_fake_loss': 0., 'G_fake_cls_loss': 0., 'G_rec_loss': 0.}
|
||||
self.G_loss.append(float(lossG[2].asnumpy()))
|
||||
for i, item in enumerate(lossG[2:]):
|
||||
resG.append(float(item.asnumpy()))
|
||||
dict_G[Glist[i]] = Tensor(float(item.asnumpy()))
|
||||
resD = []
|
||||
Dlist = ['Dloss', 'D_real_loss', 'D_fake_loss', 'D_real_cls_loss', 'D_gp_loss']
|
||||
dict_D = {'Dloss': 0., 'D_real_loss': 0., 'D_fake_loss': 0., 'D_real_cls_loss': 0., 'D_gp_loss': 0.}
|
||||
self.D_loss.append(float(lossD[1].asnumpy()))
|
||||
for i, item in enumerate(lossD[1:]):
|
||||
resD.append(float(item.asnumpy()))
|
||||
dict_D[Dlist[i]] = Tensor(float(item.asnumpy()))
|
||||
|
||||
return resG, resD, dict_G, dict_D
|
||||
|
||||
def lr_decay_info(self, step, G_lr, D_lr):
|
||||
"""print log after learning rate decay"""
|
||||
print('Decayed learning rates in step {}, g_lr: {}, d_lr: {}.'.format(step, G_lr, D_lr))
|
||||
|
||||
def epoch_end(self):
|
||||
"""print log and save cgeckpoints when epoch end."""
|
||||
epoch_cost = (time.time() - self.epoch_start_time) * 1000
|
||||
pre_step_time = epoch_cost / self.step
|
||||
mean_loss_G = sum(self.G_loss) / self.step
|
||||
mean_loss_D = sum(self.D_loss) / self.step
|
||||
self.info("Epoch [{}] total cost: {:.2f} ms, pre step: {:.2f} ms, G_loss: {:.2f}, D_loss: {:.2f}".format(
|
||||
self.epoch, epoch_cost, pre_step_time, mean_loss_G, mean_loss_D))
|
|
@ -0,0 +1,146 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dataset distributed sampler."""
|
||||
from __future__ import division
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from mindspore import load_checkpoint
|
||||
from mindspore import Tensor
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
from src.cell import init_weights
|
||||
from src.model import Generator, Discriminator
|
||||
|
||||
|
||||
class DistributedSampler:
|
||||
"""Distributed sampler."""
|
||||
def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=False):
|
||||
if num_replicas is None:
|
||||
print("***********Setting world_size to 1 since it is not passed in ******************")
|
||||
num_replicas = 1
|
||||
if rank is None:
|
||||
print("***********Setting rank to 0 since it is not passed in ******************")
|
||||
rank = 0
|
||||
|
||||
self.dataset_size = dataset_size
|
||||
self.num_replicas = num_replicas
|
||||
self.epoch = 0
|
||||
self.rank = rank
|
||||
self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
|
||||
def __iter__(self):
|
||||
# shuffle based on epoch
|
||||
if self.shuffle:
|
||||
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
|
||||
indices = indices.tolist()
|
||||
self.epoch += 1
|
||||
else:
|
||||
indices = list(range(self.dataset_size))
|
||||
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank: self.total_size: self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
def resume_model(config, G, D):
|
||||
"""Restore the trained generator and discriminator."""
|
||||
print('Loading the trained models from step {}...'.format(config.resume_iters))
|
||||
G_path = os.path.join(config.model_save_dir, f"Generator_2-0_%d.ckpt" % config.resume_iters)
|
||||
# D_path = os.path.join(config.model_save_dir, f"Net_D_%d.ckpt" % config.resume_iters)
|
||||
param_G = load_checkpoint(G_path, G)
|
||||
# param_D = load_checkpoint(D_path, D)
|
||||
|
||||
return param_G, D
|
||||
|
||||
|
||||
def print_network(model, name):
|
||||
"""Print out the network information."""
|
||||
num_params = 0
|
||||
for p in model.trainable_params():
|
||||
num_params += np.prod(p.shape)
|
||||
print(model)
|
||||
print(name)
|
||||
print('The number of parameters: {}'.format(num_params))
|
||||
|
||||
|
||||
def get_lr(init_lr, total_step, update_step, num_iters_decay):
|
||||
"""Get changed learning rate."""
|
||||
lr_each_step = []
|
||||
lr = init_lr
|
||||
for i in range(total_step):
|
||||
if (i+1) % update_step == 0 and (i+1) > total_step-num_iters_decay:
|
||||
lr = lr - (init_lr / float(num_iters_decay))
|
||||
if lr < 0:
|
||||
lr = 1e-6
|
||||
lr_each_step.append(lr)
|
||||
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
return lr_each_step
|
||||
|
||||
|
||||
def get_network(args):
|
||||
"""Create and initial a generator and a discriminator."""
|
||||
|
||||
G = Generator(args.g_conv_dim, args.c_dim, args.g_repeat_num)
|
||||
D = Discriminator(args.image_size, args.d_conv_dim, args.c_dim, args.d_repeat_num)
|
||||
|
||||
init_weights(G, 'KaimingUniform', math.sqrt(5))
|
||||
init_weights(D, 'KaimingUniform', math.sqrt(5))
|
||||
|
||||
print_network(G, 'Generator')
|
||||
print_network(D, 'Discriminator')
|
||||
|
||||
return G, D
|
||||
|
||||
|
||||
def create_labels(c_org, c_dim=5, selected_attrs=None):
|
||||
"""Generate target domain labels for debugging and testing"""
|
||||
# Get hair color indices.
|
||||
hair_color_indices = []
|
||||
for i, attr_name in enumerate(selected_attrs):
|
||||
if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
|
||||
hair_color_indices.append(i)
|
||||
c_trg_list = []
|
||||
for i in range(c_dim):
|
||||
c_trg = c_org.copy()
|
||||
if i in hair_color_indices:
|
||||
c_trg[:, i] = 1
|
||||
for j in hair_color_indices:
|
||||
if j != i:
|
||||
c_trg[:, j] = 0
|
||||
else:
|
||||
c_trg[:, i] = (c_trg[:, i] == 0)
|
||||
c_trg_list.append(c_trg)
|
||||
|
||||
c_trg_list = Tensor(c_trg_list, mstype.float16)
|
||||
return c_trg_list
|
||||
|
||||
def denorm(x):
|
||||
image_numpy = (np.transpose(x, (0, 2, 3, 1)) + 1) / 2.0 * 255.0
|
||||
image_numpy = np.clip(image_numpy, 0, 255)
|
||||
return image_numpy
|
|
@ -0,0 +1,219 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Train the model."""
|
||||
from time import time
|
||||
import os
|
||||
import argparse
|
||||
import ast
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import nn
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext
|
||||
|
||||
|
||||
from src.dataset import dataloader
|
||||
from src.config import get_config
|
||||
from src.utils import get_network
|
||||
from src.cell import TrainOneStepCellGen, TrainOneStepCellDis
|
||||
from src.loss import GeneratorLoss, DiscriminatorLoss, ClassificationLoss, WGANGPGradientPenalty
|
||||
from src.reporter import Reporter
|
||||
|
||||
set_seed(1)
|
||||
|
||||
# Modelarts
|
||||
parser = argparse.ArgumentParser(description='StarGAN_args')
|
||||
parser.add_argument('--modelarts', type=ast.literal_eval, default=False, help='Dataset path')
|
||||
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
|
||||
parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.")
|
||||
parser.add_argument("--device_num", type=int, default=1, help="number of device, default: 0.")
|
||||
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
config = get_config()
|
||||
if args_opt.modelarts:
|
||||
import moxing as mox
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
|
||||
context.set_context(device_id=device_id)
|
||||
local_data_url = './cache/data'
|
||||
local_train_url = '/cache/ckpt'
|
||||
|
||||
local_data_url = os.path.join(local_data_url, str(device_id))
|
||||
local_train_url = os.path.join(local_train_url, str(device_id))
|
||||
|
||||
# unzip data
|
||||
path = os.getcwd()
|
||||
print("cwd: %s" % path)
|
||||
data_url = 'obs://hit-wcy/data/CelebA/'
|
||||
|
||||
data_name = '/celeba.zip'
|
||||
print('listdir1: %s' % os.listdir('./'))
|
||||
|
||||
a1time = time()
|
||||
mox.file.copy_parallel(data_url, local_data_url)
|
||||
print('listdir2: %s' % os.listdir(local_data_url))
|
||||
b1time = time()
|
||||
print('time1:', b1time - a1time)
|
||||
|
||||
a2time = time()
|
||||
zip_command = "unzip -o %s -d %s" % (local_data_url + data_name, local_data_url)
|
||||
if os.system(zip_command) == 0:
|
||||
print('Successful backup')
|
||||
else:
|
||||
print('FAILED backup')
|
||||
b2time = time()
|
||||
print('time2:', b2time - a2time)
|
||||
print('listdir3: %s' % os.listdir(local_data_url))
|
||||
|
||||
# Device Environment
|
||||
if config.run_distribute:
|
||||
if config.device_target == "Ascend":
|
||||
rank = device_id
|
||||
# device_num = device_num
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
else:
|
||||
rank = 0
|
||||
device_num = 1
|
||||
|
||||
data_path = local_data_url + '/celeba/images'
|
||||
attr_path = local_data_url + '/celeba/list_attr_celeba.txt'
|
||||
dataset, length = dataloader(img_path=data_path,
|
||||
attr_path=attr_path,
|
||||
batch_size=config.batch_size,
|
||||
selected_attr=config.selected_attrs,
|
||||
device_num=config.num_workers,
|
||||
dataset=config.dataset,
|
||||
mode=config.mode,
|
||||
shuffle=True)
|
||||
|
||||
|
||||
else:
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target,
|
||||
device_id=config.device_id, save_graphs=False)
|
||||
if args_opt.run_distribute:
|
||||
if os.getenv("DEVICE_ID", "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv("DEVICE_ID")))
|
||||
device_num = config.device_num
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
||||
device_num=device_num)
|
||||
init()
|
||||
|
||||
rank = get_rank()
|
||||
|
||||
data_path = config.celeba_image_dir
|
||||
attr_path = config.attr_path
|
||||
local_train_url = config.model_save_dir
|
||||
dataset, length = dataloader(img_path=data_path,
|
||||
attr_path=attr_path,
|
||||
batch_size=config.batch_size,
|
||||
selected_attr=config.selected_attrs,
|
||||
device_num=config.device_num,
|
||||
dataset=config.dataset,
|
||||
mode=config.mode,
|
||||
shuffle=True)
|
||||
print(length)
|
||||
dataset_iter = dataset.create_dict_iterator()
|
||||
|
||||
# Get and initial network
|
||||
generator, discriminator = get_network(config)
|
||||
|
||||
cls_loss = ClassificationLoss()
|
||||
wgan_loss = WGANGPGradientPenalty(discriminator)
|
||||
|
||||
# Define network with loss
|
||||
G_loss_cell = GeneratorLoss(config, generator, discriminator)
|
||||
D_loss_cell = DiscriminatorLoss(config, generator, discriminator)
|
||||
|
||||
# Define Optimizer
|
||||
star_iter = 0
|
||||
iter_sum = config.num_iters
|
||||
|
||||
Optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=config.g_lr,
|
||||
beta1=config.beta1, beta2=config.beta2)
|
||||
Optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=config.d_lr,
|
||||
beta1=config.beta1, beta2=config.beta2)
|
||||
|
||||
# Define One step train
|
||||
G_trainOneStep = TrainOneStepCellGen(G_loss_cell, Optimizer_G)
|
||||
D_trainOneStep = TrainOneStepCellDis(D_loss_cell, Optimizer_D)
|
||||
|
||||
# Train
|
||||
G_trainOneStep.set_train()
|
||||
D_trainOneStep.set_train()
|
||||
|
||||
print('Start Training')
|
||||
|
||||
reporter = Reporter(config)
|
||||
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.model_save_step)
|
||||
ckpt_cb_g = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='Generator')
|
||||
ckpt_cb_d = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='Discriminator')
|
||||
|
||||
cb_params_g = _InternalCallbackParam()
|
||||
cb_params_g.train_network = generator
|
||||
cb_params_g.cur_step_num = 0
|
||||
cb_params_g.batch_num = 4
|
||||
cb_params_g.cur_epoch_num = 0
|
||||
|
||||
cb_params_d = _InternalCallbackParam()
|
||||
cb_params_d.train_network = discriminator
|
||||
cb_params_d.cur_step_num = 0
|
||||
cb_params_d.batch_num = config.batch_size
|
||||
cb_params_d.cur_epoch_num = 0
|
||||
run_context_g = RunContext(cb_params_g)
|
||||
run_context_d = RunContext(cb_params_d)
|
||||
ckpt_cb_g.begin(run_context_g)
|
||||
ckpt_cb_d.begin(run_context_d)
|
||||
start = time()
|
||||
|
||||
for iterator in range(config.num_iters):
|
||||
data = next(dataset_iter)
|
||||
x_real = Tensor(data['image'], mstype.float32)
|
||||
c_trg = Tensor(data['attr'], mstype.float32)
|
||||
c_org = Tensor(data['attr'], mstype.float32)
|
||||
np.random.shuffle(c_trg)
|
||||
|
||||
d_out = D_trainOneStep(x_real, c_org, c_trg)
|
||||
|
||||
if (iterator + 1) % config.n_critic == 0:
|
||||
g_out = G_trainOneStep(x_real, c_org, c_trg)
|
||||
|
||||
if (iterator + 1) % config.log_step == 0:
|
||||
reporter.print_info(start, iterator, g_out, d_out)
|
||||
_, _, dict_G, dict_D = reporter.return_loss_array(g_out, d_out)
|
||||
|
||||
if (iterator + 1) % config.model_save_step == 0:
|
||||
cb_params_d.cur_step_num = iterator + 1
|
||||
cb_params_d.batch_num = iterator + 2
|
||||
cb_params_g.cur_step_num = iterator + 1
|
||||
cb_params_g.batch_num = iterator + 2
|
||||
ckpt_cb_g.step_end(run_context_g)
|
||||
ckpt_cb_d.step_end(run_context_d)
|
||||
|
||||
if args_opt.modelarts:
|
||||
mox.file.copy_parallel(local_train_url, config.train_url)
|
Loading…
Reference in New Issue