!19835 AttGAN commit -2021/07/09

Merge pull request !19835 from MR.D/master
This commit is contained in:
i-robot 2021-07-13 07:02:08 +00:00 committed by Gitee
commit ab5b963495
22 changed files with 1760 additions and 0 deletions

View File: README_CN.md

@@ -0,0 +1,223 @@
# Contents

<!-- TOC -->

- [Contents](#contents)
- [AttGAN Description](#attgan-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
    - [Inference Process](#inference-process)
        - [Exporting MindIR](#exporting-mindir)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
            - [AttGAN on CelebA](#attgan-on-celeba)
        - [Inference Performance](#inference-performance)
            - [AttGAN on CelebA](#attgan-on-celeba-1)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->

# AttGAN Description

AttGAN ("AttGAN: Facial Attribute Editing by Only Changing What You Want") is a network whose distinguishing feature is that it can edit a given facial attribute without disturbing the other attributes of the face.

[Paper](https://arxiv.org/abs/1711.10678): Zhenliang He, Wangmeng Zuo, Meina Kan, Shiguang Shan, Xilin Chen. "AttGAN: Facial Attribute Editing by Only Changing What You Want". IEEE Transactions on Image Processing, 2019.

# Model Architecture

The network consists of a generator and a discriminator, where the generator is composed of an encoder and a decoder. The model drops the strict attribute-independence constraint on the latent representation and instead applies an attribute classification constraint to the generated image to guarantee that the desired attributes are modified correctly. By combining the attribute classification constraint, adversarial learning, and reconstruction learning, it achieves good facial attribute editing results.

# Dataset

Dataset used: [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)

CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset with more than 200K celebrity images, each annotated with 40 attributes. CelebA has large diversity, large quantity, and rich annotations, including:

- 10,177 identities,
- 202,599 face images, with 5 landmark locations and 40 binary attribute annotations per image.

The dataset can be employed as training and test sets for the following computer vision tasks: face attribute recognition, face detection, and face editing and synthesis.
# 环境要求
- 硬件Ascend
- 使用Ascend来搭建硬件环境。
- 框架
- [MindSpore](https://www.mindspore.cn/install/en)
- 如需查看详情,请参见如下资源:
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

# Quick Start

After installing MindSpore via the official website, you can perform training and evaluation as follows:

- Running on Ascend

```bash
# run the training example
export DEVICE_ID=0
export RANK_SIZE=1
python train.py --experiment_name 128_shortcut1_inject1_none --data_path /path/data/img_align_celeba --attr_path /path/data/list_attr_celeba.txt
# or
bash run_single_train.sh experiment_name /path/data/img_align_celeba /path/data/list_attr_celeba

# run the distributed training example
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba

# run the evaluation example
export DEVICE_ID=0
export RANK_SIZE=1
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --enc_ckpt_name encoder-119_84999.ckpt --dec_ckpt_name decoder-119_84999.ckpt
# or
bash run_eval.sh experiment_name /path/data/custom/ /path/data/list_attr_custom enc_ckpt_name dec_ckpt_name
```

For distributed training, an hccl configuration file in JSON format must be created in advance; its absolute path is passed as the first argument of the distributed training script. Please follow the instructions at:

<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>

For the evaluation script, a directory holding the custom images (jpg) and an attribute editing file must be prepared in advance; see [Script and Sample Code](#script-and-sample-code) for the format of the attribute editing file. The image directory and the attribute editing file correspond to the parameters `custom_data` and `custom_attr`, respectively. The training script places checkpoint files under `output/{experiment_name}/checkpoint` by default; when running the evaluation script, pass the names of the two checkpoint files (encoder and decoder) as arguments.

[Note] All of the paths above should be absolute paths.

# Script Description

## Script and Sample Code

```text
.
└─ cv
  └─ AttGAN
    ├─ data
      ├─ custom                    # directory for custom images
      ├─ list_attr_custom.txt      # attribute control file
    ├─ scripts
      ├─ run_distribute_train.sh   # shell script for distributed training
      ├─ run_single_train.sh       # shell script for single-device training
      ├─ run_eval.sh               # shell script for evaluation
    ├─ src
      ├─ __init__.py               # init file
      ├─ block.py                  # basic cells
      ├─ attgan.py                 # generator and discriminator networks
      ├─ utils.py                  # utility functions
      ├─ cell.py                   # loss network wrappers
      ├─ data.py                   # data loading
      ├─ helpers.py                # progress bar display
      ├─ loss.py                   # loss computation
    ├─ eval.py                     # evaluation script
    ├─ train.py                    # training script
    └─ README_CN.md                # AttGAN documentation
```

In the tree above, `custom` holds the images whose attributes the user wants to edit, and `list_attr_custom.txt` specifies which attributes to edit. The scripts can edit 13 attributes: Bald, Bangs, Black_Hair, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses, Male, Mouth_Slightly_Open, Mustache, No_Beard, Pale_Skin, Young.

In `list_attr_custom.txt`, the first line is the number of images to evaluate and the second line lists the 13 attributes. Each following line gives the desired edits for the corresponding image: 1 means edit the attribute, -1 means leave it unchanged, e.g. (xxx.jpg -1 -1 -1 1 -1 1 -1 1 1 -1 1 -1 -1).
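
For reference, the sample attribute file shipped with this commit (under `data/`) begins like this:

```text
6
Bald Bangs Black_Hair Blond_Hair Brown_Hair Bushy_Eyebrows Eyeglasses Male Mouth_Slightly_Open Mustache No_Beard Pale_Skin Young
donald_trump.jpg -1 -1 -1 1 -1 1 -1 1 1 -1 1 -1 -1
emma_watson.jpg -1 1 -1 -1 1 -1 -1 -1 -1 -1 1 -1 1
```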

## Training Process

### Training

- Running on Ascend

```bash
export DEVICE_ID=0
export RANK_SIZE=1
python train.py --img_size 128 --experiment_name 128_shortcut1_inject1_none --data_path /path/data/img_align_celeba --attr_path /path/data/list_attr_celeba.txt
```

After training, an `output` directory is created in the current directory, containing a subdirectory named after the `experiment_name` parameter. The training settings are saved in the `setting.txt` file inside that subdirectory, and the checkpoint files are saved under `output/{experiment_name}/checkpoint/rank0`.

### Distributed Training

- Running on Ascend

```bash
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba
```

The shell script above runs distributed training in the background. It creates a `LOG{RANK_ID}` directory under the scripts directory; the output of each process is recorded in the `log.txt` file inside the corresponding `LOG{RANK_ID}` directory, and checkpoint files are saved under `output/{experiment_name}/checkpoint/rank{RANK_ID}`.

## Evaluation Process

### Evaluation

- Evaluating a custom dataset in the Ascend environment

The network can be used to edit facial attributes. Place the images you want to edit in the custom image directory, and edit `list_attr_custom.txt` according to the attributes you wish to change (see [Script and Sample Code](#script-and-sample-code) for the file format). Then pass the custom image directory and the attribute editing file to the evaluation script as the `custom_data` and `custom_attr` parameters.

For evaluation, choose previously generated checkpoint files and pass them to the evaluation script via the `enc_ckpt_name` and `dec_ckpt_name` parameters (holding the encoder and decoder weights, respectively).

```bash
export DEVICE_ID=0
export RANK_SIZE=1
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --enc_ckpt_name encoder-119_84999.ckpt --dec_ckpt_name decoder-119_84999.ckpt
```

After the evaluation script finishes, the edited images can be found under `output/{experiment_name}/custom_testing` in the current directory (the output directory created by `eval.py`).

## Inference Process

### Exporting MindIR

```shell
python export.py --experiment_name [EXPERIMENT_NAME] --enc_ckpt_name [ENCODER_CKPT_NAME] --dec_ckpt_name [DECODER_CKPT_NAME] --file_format [FILE_FORMAT]
```

`file_format` must be chosen from ["AIR", "MINDIR"].
`experiment_name` is the name of the folder under the `output` directory that stores the training results; it helps `export.py` locate the saved training parameters.
The script generates the corresponding MINDIR files in the current directory.
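
As a concrete example (reusing the experiment and checkpoint names from the evaluation example above; substitute your own), the following exports the encoder and decoder as `AttGAN_Generator_Encoder` and `AttGAN_Generator_Decoder`:

```shell
python export.py --experiment_name 128_shortcut1_inject1_none --enc_ckpt_name encoder-119_84999.ckpt --dec_ckpt_name decoder-119_84999.ckpt --file_format MINDIR
```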

# Model Description

## Performance

### Evaluation Performance

#### AttGAN on CelebA

| Parameters                 | Ascend 910                                                   |
| -------------------------- | ------------------------------------------------------------ |
| Model version              | AttGAN                                                       |
| Resources                  | Ascend                                                       |
| Upload date                | 06/30/2021 (month/day/year)                                  |
| MindSpore version          | 1.2.0                                                        |
| Dataset                    | CelebA                                                       |
| Training parameters        | batch_size=32, lr=0.0002                                     |
| Optimizer                  | Adam                                                         |
| Generator output           | image                                                        |
| Speed                      | 5.56 steps/s                                                 |
| Script                     | [AttGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/AttGAN) |

### Inference Performance

#### AttGAN on CelebA

| Parameters                 | Ascend 910                  |
| -------------------------- | --------------------------- |
| Model version              | AttGAN                      |
| Resources                  | Ascend                      |
| Upload date                | 06/30/2021 (month/day/year) |
| MindSpore version          | 1.2.0                       |
| Dataset                    | CelebA                      |
| Inference parameters       | batch_size=1                |
| Output                     | image                       |

After inference, a strip of attribute-edited versions of each original image is produced.

# ModelZoo Homepage

Please check out the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

Six binary image files added (contents not shown); sizes: 2.0 MiB, 1.4 MiB, 103 KiB, 34 KiB, 470 KiB, 91 KiB. These are presumably the sample images under `data/custom` referenced by `list_attr_custom.txt`.

View File: data/list_attr_custom.txt

@@ -0,0 +1,8 @@
6
Bald Bangs Black_Hair Blond_Hair Brown_Hair Bushy_Eyebrows Eyeglasses Male Mouth_Slightly_Open Mustache No_Beard Pale_Skin Young
donald_trump.jpg -1 -1 -1 1 -1 1 -1 1 1 -1 1 -1 -1
emma_watson.jpg -1 1 -1 -1 1 -1 -1 -1 -1 -1 1 -1 1
jay_chou.jpeg -1 1 1 -1 -1 1 -1 1 1 -1 1 -1 1
ji-eun_lee.jpg -1 1 -1 -1 1 -1 -1 -1 1 -1 1 -1 1
tom_cruise.jpg -1 -1 -1 -1 1 1 -1 1 1 -1 1 -1 -1
yui_aragaki.jpg -1 1 -1 -1 1 -1 -1 -1 -1 -1 1 -1 1

View File: eval.py

@@ -0,0 +1,150 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Entry point for testing AttGAN network"""
import argparse
import json
import math
import os
from os.path import join
import numpy as np
from PIL import Image
import mindspore.common.dtype as mstype
import mindspore.dataset as de
from mindspore import context, Tensor, ops
from mindspore.train.serialization import load_param_into_net
from src.attgan import Genc, Gdec
from src.cell import init_weights
from src.data import check_attribute_conflict
from src.data import get_loader, Custom
from src.helpers import Progressbar
from src.utils import resume_model, denorm
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
def parse(arg=None):
"""Define configuration of Evaluation"""
parser = argparse.ArgumentParser()
parser.add_argument('--experiment_name', dest='experiment_name', required=True)
parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
parser.add_argument('--num_test', dest='num_test', type=int)
parser.add_argument('--enc_ckpt_name', type=str, default='')
parser.add_argument('--dec_ckpt_name', type=str, default='')
parser.add_argument('--custom_img', action='store_true')
parser.add_argument('--custom_data', type=str, default='../data/custom')
parser.add_argument('--custom_attr', type=str, default='../data/list_attr_custom.txt')
parser.add_argument('--shortcut_layers', dest='shortcut_layers', type=int, default=1)
parser.add_argument('--inject_layers', dest='inject_layers', type=int, default=1)
return parser.parse_args(arg)
args_ = parse()
print(args_)
with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
args.test_int = args_.test_int
args.num_test = args_.num_test
args.enc_ckpt_name = args_.enc_ckpt_name
args.dec_ckpt_name = args_.dec_ckpt_name
args.custom_img = args_.custom_img
args.custom_data = args_.custom_data
args.custom_attr = args_.custom_attr
args.shortcut_layers = args_.shortcut_layers
args.inject_layers = args_.inject_layers
args.n_attrs = len(args.attrs)
args.betas = (args.beta1, args.beta2)
print(args)
# Data loader
if args.custom_img:
output_path = join("output", args.experiment_name, "custom_testing")
os.makedirs(output_path, exist_ok=True)
test_dataset = Custom(args.custom_data, args.custom_attr, args.attrs)
test_len = len(test_dataset)
else:
output_path = join("output", args.experiment_name, "sample_testing")
os.makedirs(output_path, exist_ok=True)
test_dataset = get_loader(args.data_path, args.attr_path,
selected_attrs=args.attrs,
mode="test"
)
test_len = len(test_dataset)
dataset_column_names = ["image", "attr"]
num_parallel_workers = 8
ds = de.GeneratorDataset(test_dataset, column_names=dataset_column_names,
num_parallel_workers=min(32, num_parallel_workers))
ds = ds.batch(1, num_parallel_workers=min(8, num_parallel_workers), drop_remainder=False)
test_dataset_iter = ds.create_dict_iterator()
if args.num_test is None:
print('Testing images:', test_len)
else:
print('Testing images:', min(test_len, args.num_test))
# Model loader
genc = Genc(mode='test')
gdec = Gdec(shortcut_layers=args.shortcut_layers, inject_layers=args.inject_layers, mode='test')
# Initialize network
init_weights(genc, 'KaimingUniform', math.sqrt(5))
init_weights(gdec, 'KaimingUniform', math.sqrt(5))
para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name)
load_param_into_net(genc, para_genc)
load_param_into_net(gdec, para_gdec)
progressbar = Progressbar()
it = 0
for data in test_dataset_iter:
img_a = data["image"]
att_a = data["attr"]
if args.num_test is not None and it == args.num_test:
break
att_a = Tensor(att_a, mstype.float32)
att_b_list = [att_a]
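# For each attribute, flip its value and resolve conflicting attributes, collecting one target attribute vector per edit.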
for i in range(args.n_attrs):
clone = ops.Identity()
tmp = clone(att_a)
tmp[:, i] = 1 - tmp[:, i]
tmp = check_attribute_conflict(tmp, args.attrs[i], args.attrs)
att_b_list.append(tmp)
samples = [img_a]
for i, att_b in enumerate(att_b_list):
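# Scale labels from {0, 1} to [-thres_int, thres_int]; for edited samples, boost the flipped attribute to test_int strength.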
att_b_ = (att_b * 2 - 1) * args.thres_int
if i > 0:
att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
a_enc = genc(img_a)
samples.append(gdec(a_enc, att_b_))
cat = ops.Concat(axis=3)
samples = cat(samples).asnumpy()
result = denorm(samples)
result = np.reshape(result, (128, -1, 3))
im = Image.fromarray(np.uint8(result))
if args.custom_img:
out_file = test_dataset.images[it]
else:
out_file = "{:06d}.jpg".format(it + 182638)
im.save(output_path + '/' + out_file)
print('Successfully saved image to ' + output_path + '/' + out_file)
it += 1

View File: export.py

@@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export file."""
import argparse
import json
from os.path import join
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_param_into_net
from src.utils import resume_model
from src.attgan import Genc, Gdec
parser = argparse.ArgumentParser(description='Attribute Edit')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument('--enc_ckpt_name', type=str, default='')
parser.add_argument('--dec_ckpt_name', type=str, default='')
parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format')
parser.add_argument('--experiment_name', dest='experiment_name', required=True)
args_ = parser.parse_args()
print(args_)
with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
args.device_id = args_.device_id
args.batch_size = args_.batch_size
args.enc_ckpt_name = args_.enc_ckpt_name
args.dec_ckpt_name = args_.dec_ckpt_name
args.file_format = args_.file_format
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
if __name__ == '__main__':
genc = Genc(mode='test')
gdec = Gdec(mode='test')
para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name)
load_param_into_net(genc, para_genc)
load_param_into_net(gdec, para_gdec)
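# Trace each sub-network with dummy inputs: a 1x3x128x128 image for the encoder, and the encoder features plus a 1x13 attribute vector for the decoder.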
enc_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32))
dec_array = genc(enc_array)
input_label = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 13)).astype(np.float32))
G_enc_file = "AttGAN_Generator_Encoder"
export(genc, enc_array, file_name=G_enc_file, file_format=args.file_format)
G_dec_file = "AttGAN_Generator_Decoder"
export(gdec, *(dec_array, input_label), file_name=G_dec_file, file_format=args.file_format)

View File: scripts/run_distribute_train.sh

@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] [ATTR_PATH]"
exit 1
fi
export MINDSPORE_HCCL_CONFIG_PATH=$1
export RANK_TABLE_FILE=$1
export RANK_SIZE=8
export HCCL_CONNECT_TIMEOUT=600
echo "hccl connect time out has changed to 600 second"
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
data_path=$2
attr_path=$3
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
start=`expr $i \* $avg_core_per_rank`
export DEVICE_ID=$i
export RANK_ID=$i
export DEPLOY_MODE=0
export GE_USE_STATIC_MEMORY=1
end=`expr $start \+ $core_gap`
cmdopt=$start"-"$end
rm -rf LOG$i
mkdir ./LOG$i
cd ./LOG$i || exit
echo "Start training for rank $i, device $DEVICE_ID"
env > env.log
cd ../../
taskset -c $cmdopt python train.py \
--img_size 128 \
--shortcut_layers 1 \
--inject_layers 1 \
--experiment_name 128_shortcut1_inject1_none \
--data_path $data_path \
--attr_path $attr_path \
--run_distribute 1 > ./scripts/LOG$i/log.txt 2>&1 &
cd scripts
done

View File: scripts/run_eval.sh

@@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 5 ]
then
echo "Usage: sh run_eval.sh [EXPERIMENT_NAME] [CUSTOM_DATA_PATH] [CUSTOM_ATTR_PATH] [ENC_CKPT_NAME] [DEC_CKPT_NAME]"
exit 1
fi
experiment_name=$1
data_path=$2
attr_path=$3
enc_name=$4
dec_name=$5
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "The number of logical core" $cores
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
rm -rf EVAL_LOG
mkdir ./EVAL_LOG
cd ./EVAL_LOG || exit
echo "Start training for rank 0, device 0, directory is EVAL_LOG"
env > env.log
cd ../../
python eval.py \
--experiment_name $experiment_name \
--test_int 1.0 \
--custom_data $data_path \
--custom_attr $attr_path \
--custom_img \
--enc_ckpt_name $enc_name \
--dec_ckpt_name $dec_name > ./scripts/EVAL_LOG/log.txt 2>&1 &

View File: scripts/run_single_train.sh

@@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_single_train.sh [EXPERIMENT_NAME] [DATA_PATH] [ATTR_PATH]"
exit 1
fi
experiment_name=$1
data_path=$2
attr_path=$3
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
rm -rf LOG
mkdir ./LOG
cd ./LOG || exit
echo "Start training for rank 0, device 0, directory is LOG"
env > env.log
cd ../../
python train.py \
--img_size 128 \
--shortcut_layers 1 \
--inject_layers 1 \
--experiment_name $experiment_name \
--data_path $data_path \
--attr_path $attr_path > ./scripts/LOG/log.txt 2>&1 &

View File: src/attgan.py

@@ -0,0 +1,133 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AttGAN Network Topology"""
import mindspore.ops.operations as P
from mindspore import nn
from src.block import LinearBlock, Conv2dBlock, ConvTranspose2dBlock
# Image size 128 x 128
MAX_DIM = 64 * 16
class Genc(nn.Cell):
"""Generator encoder"""
def __init__(self, enc_dim=64, enc_layers=5, enc_norm_fn="batchnorm", enc_acti_fn="lrelu",
mode='test'):
super().__init__()
layers = []
n_in = 3
for i in range(enc_layers):
n_out = min(enc_dim * 2 ** i, MAX_DIM)
layers += [Conv2dBlock(
n_in, n_out, (4, 4), stride=2, padding=1, norm_fn=enc_norm_fn, acti_fn=enc_acti_fn, mode=mode
)]
n_in = n_out
self.enc_layers = nn.CellList(layers)
def construct(self, x):
"""Encoder construct"""
z = x
zs = []
for layer in self.enc_layers:
z = layer(z)
zs.append(z)
return zs
class Gdec(nn.Cell):
"""Generator decoder"""
def __init__(self, dec_dim=64, dec_layers=5, dec_norm_fn="batchnorm", dec_acti_fn="relu", n_attrs=13,
shortcut_layers=1, inject_layers=1, img_size=128, mode='test'):
super().__init__()
self.shortcut_layers = min(shortcut_layers, dec_layers - 1)
self.inject_layers = min(inject_layers, dec_layers - 1)
self.f_size = img_size // 2 ** dec_layers # f_size = 4 for 128x128
layers = []
n_in = 1024
n_in = n_in + n_attrs # 1024 + 13
for i in range(dec_layers):
if i < dec_layers - 1:
n_out = min(dec_dim * 2 ** (dec_layers - i - 1), MAX_DIM)
layers += [ConvTranspose2dBlock(
n_in, n_out, (4, 4), stride=2, padding=1, norm_fn=dec_norm_fn, acti_fn=dec_acti_fn, mode=mode
)]
n_in = n_out
n_in = n_in + n_in // 2 if self.shortcut_layers > i else n_in
n_in = n_in + n_attrs if self.inject_layers > i else n_in
else:
layers += [ConvTranspose2dBlock(
n_in, 3, (4, 4), stride=2, padding=1, norm_fn='none', acti_fn='tanh', mode=mode
)]
self.dec_layers = nn.CellList(layers)
self.view = P.Reshape()
self.repeat = P.Tile()
self.cat = P.Concat(1)
def construct(self, zs, a):
"""Decoder construct"""
a_tile = self.view(a, (a.shape[0], -1, 1, 1))
multiples = (1, 1, self.f_size, self.f_size)
a_tile = self.repeat(a_tile, multiples)
z = self.cat((zs[-1], a_tile))
i = 0
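# For the first `shortcut_layers` stages, concatenate the mirrored encoder feature (skip connection);
# for the first `inject_layers` stages, also re-concatenate the attribute map at the new resolution.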
for layer in self.dec_layers:
z = layer(z)
if self.shortcut_layers > i:
z = self.cat((z, zs[len(self.dec_layers) - 2 - i]))
if self.inject_layers > i:
a_tile = self.view(a, (a.shape[0], -1, 1, 1))
multiples = (1, 1, self.f_size * 2 ** (i + 1), self.f_size * 2 ** (i + 1))
a_tile = self.repeat(a_tile, multiples)
z = self.cat((z, a_tile))
i = i + 1
return z
class Dis(nn.Cell):
"""Discriminator"""
def __init__(self, dim=64, norm_fn='none', acti_fn='lrelu',
fc_dim=1024, fc_norm_fn='none', fc_acti_fn='lrelu', n_layers=5, img_size=128, mode='test'):
super().__init__()
self.f_size = img_size // 2 ** n_layers
layers = []
n_in = 3
for i in range(n_layers):
n_out = min(dim * 2 ** i, MAX_DIM)
layers += [Conv2dBlock(
n_in, n_out, (4, 4), stride=2, padding=1, norm_fn=norm_fn, acti_fn=acti_fn, mode=mode
)]
n_in = n_out
self.conv = nn.SequentialCell(layers)
self.fc_adv = nn.SequentialCell(
[LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn, mode),
LinearBlock(fc_dim, 1, 'none', 'none', mode)])
self.fc_cls = nn.SequentialCell(
[LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn, mode),
LinearBlock(fc_dim, 13, 'none', 'none', mode)])
def construct(self, x):
"""construct"""
h = self.conv(x)
view = P.Reshape()
h = view(h, (h.shape[0], -1))
return self.fc_adv(h), self.fc_cls(h)

View File: src/block.py

@@ -0,0 +1,100 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network Component"""
from mindspore import nn
def add_normalization_1d(layers, fn, n_out, mode='test'):
if fn == "none":
pass
elif fn == "batchnorm":
layers.append(nn.BatchNorm1d(n_out, use_batch_statistics=(mode == 'train')))
elif fn == "instancenorm":
layers.append(nn.GroupNorm(n_out, n_out, affine=True))
else:
raise Exception('Unsupported normalization: ' + str(fn))
return layers
def add_normalization_2d(layers, fn, n_out, mode='test'):
if fn == 'none':
pass
elif fn == 'batchnorm':
layers.append(nn.BatchNorm2d(n_out, use_batch_statistics=(mode == 'train')))
elif fn == "instancenorm":
layers.append(nn.GroupNorm(n_out, n_out, affine=True))
else:
raise Exception('Unsupported normalization: ' + str(fn))
return layers
def add_activation(layers, fn):
"""Add Activation"""
if fn == "none":
pass
elif fn == "relu":
layers.append(nn.ReLU())
elif fn == "lrelu":
layers.append(nn.LeakyReLU(alpha=0.01))
elif fn == "sigmoid":
layers.append(nn.Sigmoid())
elif fn == "tanh":
layers.append(nn.Tanh())
else:
raise Exception('Unsupported activation function: ' + str(fn))
return layers
class LinearBlock(nn.Cell):
"""Linear Block"""
def __init__(self, n_in, n_out, norm_fn="none", acti_fn="none", mode='test'):
super().__init__()
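# Use a bias only when no normalization follows; a norm layer's beta parameter makes the bias redundant.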
layers = [nn.Dense(n_in, n_out, has_bias=(norm_fn == 'none'))]
layers = add_normalization_1d(layers, norm_fn, n_out, mode)
layers = add_activation(layers, acti_fn)
self.layers = nn.SequentialCell(layers)
def construct(self, x):
return self.layers(x)
class Conv2dBlock(nn.Cell):
"""Convolution Block"""
def __init__(self, n_in, n_out, kernel_size, stride=1, padding=0,
norm_fn=None, acti_fn=None, mode='test'):
super().__init__()
layers = [nn.Conv2d(n_in, n_out, kernel_size, stride=stride, padding=padding, pad_mode='pad',
has_bias=(norm_fn == 'none'))]
layers = add_normalization_2d(layers, norm_fn, n_out, mode)
layers = add_activation(layers, acti_fn)
self.layers = nn.SequentialCell(layers)
def construct(self, x):
return self.layers(x)
class ConvTranspose2dBlock(nn.Cell):
"""Transpose Convolution Block"""
def __init__(self, n_in, n_out, kernel_size, stride=1, padding=0,
norm_fn=None, acti_fn=None, mode='test'):
super().__init__()
layers = [nn.Conv2dTranspose(n_in, n_out, kernel_size, stride=stride, padding=padding, pad_mode='pad',
has_bias=(norm_fn == 'none'))]
layers = add_normalization_2d(layers, norm_fn, n_out, mode)
layers = add_activation(layers, acti_fn)
self.layers = nn.SequentialCell(layers)
def construct(self, x):
return self.layers(x)

View File: src/cell.py

@@ -0,0 +1,155 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cell Definition"""
import numpy as np
import mindspore.ops.functional as F
import mindspore.ops.operations as P
from mindspore import nn, ops
from mindspore.common import initializer as init, set_seed
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
_get_parallel_mode)
set_seed(1)
np.random.seed(1)
def init_weights(net, init_type='normal', init_gain=0.02):
"""
Initialize network weights.
Parameters:
net (Cell): Network to be initialized
init_type (str): The name of an initialization method: normal | xavier.
init_gain (float): Gain factor for normal and xavier.
"""
for _, cell in net.cells_and_names():
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
if init_type == 'normal':
cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
elif init_type == 'xavier':
cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
elif init_type == 'KaimingUniform':
cell.weight.set_data(init.initializer(init.HeUniform(init_gain), cell.weight.shape))
elif init_type == 'constant':
cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
elif isinstance(cell, (nn.GroupNorm, nn.BatchNorm2d)):
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
class GenWithLossCell(nn.Cell):
"""
Wrap the network with loss function to return generator loss
"""
def __init__(self, network):
super().__init__(auto_prefix=False)
self.network = network
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
_, g_loss, _, _, _, = self.network(img_a, att_a, att_a_, att_b, att_b_)
return g_loss
class DisWithLossCell(nn.Cell):
"""
Wrap the network with loss function to return discriminator loss
"""
def __init__(self, network):
super().__init__(auto_prefix=False)
self.network = network
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
d_loss, _, _, _, _ = self.network(img_a, att_a, att_a_, att_b, att_b_)
return d_loss
class TrainOneStepCellGen(nn.Cell):
"""Encapsulation class of AttGAN generator network training."""
def __init__(self, generator, optimizer, sens=1.0):
super().__init__()
self.optimizer = optimizer
self.generator = generator
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.weights = optimizer.parameters
self.network = GenWithLossCell(generator)
self.network.add_flags(defer_inline=True)
self.reducer_flag = False
self.grad_reducer = F.identity
self.parallel_mode = _get_parallel_mode()
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
weights = self.weights
_, loss, gf_loss, gc_loss, gr_loss = self.generator(img_a, att_a, att_a_, att_b, att_b_)
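# Seed backpropagation with a constant sensitivity tensor shaped and typed like the loss.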
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(img_a, att_a, att_a_, att_b, att_b_, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads)), gf_loss, gc_loss, gr_loss
class TrainOneStepCellDis(nn.Cell):
"""Encapsulation class of AttGAN discriminator network training."""
def __init__(self, discriminator, optimizer, sens=1.0):
super().__init__()
self.optimizer = optimizer
self.discriminator = discriminator
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.weights = optimizer.parameters
self.network = DisWithLossCell(discriminator)
self.network.add_flags(defer_inline=True)
self.reducer_flag = False
self.grad_reducer = F.identity
self.parallel_mode = _get_parallel_mode()
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
weights = self.weights
loss, d_real_loss, d_fake_loss, dc_loss, df_gp = self.discriminator(img_a, att_a, att_a_, att_b, att_b_)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(img_a, att_a, att_a_, att_b, att_b_, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads)), d_real_loss, d_fake_loss, dc_loss, df_gp

View File: src/data.py

@@ -0,0 +1,198 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" DataLoader: CelebA"""
import os
import numpy as np
from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore.dataset.transforms import py_transforms
from src.utils import DistributedSampler
class Custom:
"""
Custom data loader
"""
def __init__(self, data_path, attr_path, selected_attrs):
self.data_path = data_path
att_list = open(attr_path, 'r', encoding='utf-8').readlines()[1].split()
atts = [att_list.index(att) + 1 for att in selected_attrs]
images = np.loadtxt(attr_path, skiprows=2, usecols=[0], dtype=np.str)
labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=np.int)
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
transform = [py_vision.ToPIL()]
transform.append(py_vision.Resize([128, 128]))
transform.append(py_vision.ToTensor())
transform.append(py_vision.Normalize(mean=mean, std=std))
transform = py_transforms.Compose(transform)
self.transform = transform
self.images = np.array([images]) if images.size == 1 else images[0:]
self.labels = np.array([labels]) if images.size == 1 else labels[0:]
self.length = len(self.images)
def __getitem__(self, index):
image = np.asarray(Image.open(os.path.join(self.data_path, self.images[index])))
att = np.asarray((self.labels[index] + 1) // 2)
image = np.squeeze(self.transform(image))
return image, att
def __len__(self):
return self.length
class CelebA:
"""
CelebA dataset
Input:
data_path: Image Path
attr_path: Attr_list Path
image_size: Image Size
mode: Train, Valid or Test
selected_attrs: selected attributes
transform: Image Processing
"""
def __init__(self, data_path, attr_path, image_size, mode, selected_attrs, transform):
super().__init__()
self.data_path = data_path
self.transform = transform
self.img_size = image_size
att_list = open(attr_path, 'r', encoding='utf-8').readlines()[1].split()
atts = [att_list.index(att) + 1 for att in selected_attrs]
images = np.loadtxt(attr_path, skiprows=2, usecols=[0], dtype=np.str)
labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=np.int)
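# Fixed CelebA split: images[:182000] for training, [182000:182637] for validation, and the remainder for testing.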
if mode == "train":
self.images = images[:182000]
self.labels = labels[:182000]
if mode == "valid":
self.images = images[182000:182637]
self.labels = labels[182000:182637]
if mode == "test":
self.images = images[182637:]
self.labels = labels[182637:]
self.length = len(self.images)
def __getitem__(self, index):
image = np.asarray(Image.open(os.path.join(self.data_path, self.images[index])))
att = np.asarray((self.labels[index] + 1) // 2)
image = np.squeeze(self.transform(image))
return image, att
def __len__(self):
return self.length
def get_loader(data_root, attr_path, selected_attrs, crop_size=170, image_size=128, mode="train"):
"""Build and return dataloader"""
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
transform = [py_vision.ToPIL()]
transform.append(py_vision.CenterCrop((crop_size, crop_size)))
transform.append(py_vision.Resize([image_size, image_size]))
transform.append(py_vision.ToTensor())
transform.append(py_vision.Normalize(mean=mean, std=std))
transform = py_transforms.Compose(transform)
dataset = CelebA(data_root, attr_path, image_size, mode, selected_attrs, transform)
return dataset
def data_loader(img_path, attr_path, selected_attrs, mode="train", batch_size=1, device_num=1, rank=0, shuffle=True):
"""CelebA data loader"""
num_parallel_workers = 8
dataset_loader = get_loader(img_path, attr_path, selected_attrs, mode=mode)
length_dataset = len(dataset_loader)
distributed_sampler = DistributedSampler(length_dataset, device_num, rank, shuffle=shuffle)
dataset_column_names = ["image", "attr"]
if device_num != 8:
ds = de.GeneratorDataset(dataset_loader, column_names=dataset_column_names,
num_parallel_workers=min(32, num_parallel_workers),
sampler=distributed_sampler)
ds = ds.batch(batch_size, num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True)
else:
ds = de.GeneratorDataset(dataset_loader, column_names=dataset_column_names, sampler=distributed_sampler)
ds = ds.batch(batch_size, num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True)
# ds = ds.repeat(200)
return ds, length_dataset
def check_attribute_conflict(att_batch, att_name, att_names):
"""Check Attributes"""
def _set(att, att_name):
if att_name in att_names:
att[att_names.index(att_name)] = 0.0
att_id = att_names.index(att_name)
for att in att_batch:
if att_name in ['Bald', 'Receding_Hairline'] and att[att_id] != 0:
_set(att, 'Bangs')
elif att_name == 'Bangs' and att[att_id] != 0:
_set(att, 'Bald')
_set(att, 'Receding_Hairline')
elif att_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'] and att[att_id] != 0:
for n in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
if n != att_name:
_set(att, n)
elif att_name in ['Straight_Hair', 'Wavy_Hair'] and att[att_id] != 0:
for n in ['Straight_Hair', 'Wavy_Hair']:
if n != att_name:
_set(att, n)
elif att_name in ['Mustache', 'No_Beard'] and att[att_id] != 0:
for n in ['Mustache', 'No_Beard']:
if n != att_name:
_set(att, n)
return att_batch
if __name__ == "__main__":
attrs_default = [
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
]
parser = argparse.ArgumentParser()
parser.add_argument('--attrs', dest='attrs', default=attrs_default, nargs='+', help='attributes to test')
parser.add_argument('--data_path', dest='data_path', type=str, required=True)
parser.add_argument('--attr_path', dest='attr_path', type=str, required=True)
args = parser.parse_args()
####### Test CelebA #######
context.set_context(device_target="Ascend")
dataset_ce, length_ce = data_loader(args.data_path, args.attr_path, attrs_default, mode="train")
i = 0
for data in dataset_ce.create_dict_iterator():
print('Number:', i, 'Value:', data["attr"], 'Type:', type(data["attr"]))
i += 1

View File: src/helpers.py

@@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Helper functions for training"""
import datetime
import platform
from tqdm import tqdm
def name_experiment(prefix="", suffix=""):
experiment_name = datetime.datetime.now().strftime('%b%d_%H-%M-%S_') + platform.node()
if prefix is not None and prefix != '':
experiment_name = prefix + '_' + experiment_name
if suffix is not None and suffix != '':
experiment_name = experiment_name + '_' + suffix
return experiment_name
class Progressbar():
"""Progress Bar"""
def __init__(self):
self.p = None
def __call__(self, iterable, length):
self.p = tqdm(iterable, total=length)
return self.p
def say(self, **kwargs):
if self.p is not None:
self.p.set_postfix(**kwargs)

View File: src/loss.py

@@ -0,0 +1,169 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss Computation of Generator and Discriminator"""
import numpy as np
import mindspore
import mindspore.ops.operations as P
from mindspore import dtype as mstype
from mindspore import nn, Tensor, ops
from mindspore.ops import constexpr
class ClassificationLoss(nn.Cell):
"""Define classification loss for AttGAN"""
def __init__(self):
super().__init__()
self.bce_loss = P.BinaryCrossEntropy(reduction='sum')
def construct(self, pred, label):
weight = ops.Ones()(pred.shape, mindspore.float32)
pred_ = P.Sigmoid()(pred)
x = self.bce_loss(pred_, label, weight) / pred.shape[0]
return x
@constexpr
def generate_tensor(batch_size):
np_array = np.random.randn(batch_size, 1, 1, 1)
return Tensor(np_array, mindspore.float32)
class GradientWithInput(nn.Cell):
"""Get Discriminator Gradient with Input"""
def __init__(self, discriminator):
super().__init__()
self.reduce_sum = ops.ReduceSum()
self.discriminator = discriminator
self.discriminator.set_train(mode=True)
self.discriminator.set_grad(True)
def construct(self, interpolates):
decision_interpolate, _ = self.discriminator(interpolates)
decision_interpolate = self.reduce_sum(decision_interpolate, 0)
return decision_interpolate
class WGANGPGradientPenalty(nn.Cell):
"""Define WGAN loss for AttGAN"""
def __init__(self, discriminator):
super().__init__()
self.gradient_op = ops.GradOperation()
self.reduce_sum = ops.ReduceSum()
self.reduce_sum_keep_dim = ops.ReduceSum(keep_dims=True)
self.sqrt = ops.Sqrt()
self.discriminator = discriminator
self.GradientWithInput = GradientWithInput(discriminator)
self.GradientWithInput.set_grad(True)
def construct(self, x_real, x_fake):
"""get gradient penalty"""
batch_size = x_real.shape[0]
alpha = generate_tensor(batch_size)
alpha = alpha.expand_as(x_real)
x_fake = ops.functional.stop_gradient(x_fake)
x_hat = x_real + alpha * (x_fake - x_real)
gradient = self.gradient_op(self.GradientWithInput)(x_hat)
gradient_1 = ops.reshape(gradient, (batch_size, -1))
gradient_1 = self.sqrt(self.reduce_sum(gradient_1 * gradient_1, 1))
gradient_penalty = self.reduce_sum((gradient_1 - 1.0) ** 2) / x_real.shape[0]
return gradient_penalty
class GenLoss(nn.Cell):
"""Define total Generator loss"""
def __init__(self, args, encoder, decoder, discriminator):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.discriminator = discriminator
self.lambda_1 = Tensor(args.lambda_1, mstype.float32)
self.lambda_2 = Tensor(args.lambda_2, mstype.float32)
self.lambda_3 = Tensor(args.lambda_3, mstype.float32)
self.lambda_gp = Tensor(args.lambda_gp, mstype.float32)
self.cyc_loss = P.ReduceMean()
self.cls_loss = nn.BCEWithLogitsLoss()
self.rec_loss = nn.L1Loss("mean")
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
"""Get generator loss"""
# generate
zs_a = self.encoder(img_a)
img_fake = self.decoder(zs_a, att_b_)
img_recon = self.decoder(zs_a, att_a_)
# discriminate
d_fake, dc_fake = self.discriminator(img_fake)
# generator loss
gf_loss = - self.cyc_loss(d_fake)
gc_loss = self.cls_loss(dc_fake, att_b)
gr_loss = self.rec_loss(img_a, img_recon)
g_loss = gf_loss + self.lambda_2 * gc_loss + self.lambda_1 * gr_loss
return (img_fake, g_loss, gf_loss, gc_loss, gr_loss)
class DisLoss(nn.Cell):
"""Define total discriminator loss"""
def __init__(self, args, encoder, decoder, discriminator):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.discriminator = discriminator
self.cyc_loss = P.ReduceMean()
self.cls_loss = nn.BCEWithLogitsLoss()
self.WGANLoss = WGANGPGradientPenalty(discriminator)
self.lambda_1 = Tensor(args.lambda_1, mstype.float32)
self.lambda_2 = Tensor(args.lambda_2, mstype.float32)
self.lambda_3 = Tensor(args.lambda_3, mstype.float32)
self.lambda_gp = Tensor(args.lambda_gp, mstype.float32)
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
"""Get discriminator loss"""
# generate
z = self.encoder(img_a)
img_fake = self.decoder(z, att_b_)
# discriminate
d_real, dc_real = self.discriminator(img_a)
d_fake, _ = self.discriminator(img_fake)
# discriminator losses
d_real_loss = - self.cyc_loss(d_real)
d_fake_loss = self.cyc_loss(d_fake)
df_loss = d_real_loss + d_fake_loss
df_gp = self.WGANLoss(img_a, img_fake)
dc_loss = self.cls_loss(dc_real, att_a)
d_loss = df_loss + self.lambda_gp * df_gp + self.lambda_3 * dc_loss
return (d_loss, d_real_loss, d_fake_loss, dc_loss, df_gp)

View File: src/utils.py

@@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Helper functions"""
import math
import os
import numpy as np
from mindspore import load_checkpoint
class DistributedSampler:
"""Distributed sampler."""
def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=False):
if num_replicas is None:
print("***********Setting world_size to 1 since it is not passed in ******************")
num_replicas = 1
if rank is None:
print("***********Setting rank to 0 since it is not passed in ******************")
rank = 0
self.dataset_size = dataset_size
self.num_replicas = num_replicas
self.epoch = 0
self.rank = rank
self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
def __iter__(self):
# shuffle based on epoch
if self.shuffle:
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
indices = indices.tolist()
self.epoch += 1
else:
indices = list(range(self.dataset_size))
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank: self.total_size: self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def resume_model(args, encoder, decoder, enc_ckpt_name, dec_ckpt_name):
"""Restore the trained generator and discriminator"""
print("Loading the trained models from step {}...".format(args.save_interval))
encoder_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', enc_ckpt_name)
decoder_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', dec_ckpt_name)
param_encoder = load_checkpoint(encoder_path, encoder)
param_decoder = load_checkpoint(decoder_path, decoder)
return param_encoder, param_decoder
def resume_discriminator(args, discriminator, dis_ckpt_name):
"""Restore the trained discriminator"""
print("Loading the trained models from step {}...".format(args.save_interval))
discriminator_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', dis_ckpt_name)
param_discriminator = load_checkpoint(discriminator_path, discriminator)
return param_discriminator
def denorm(x):
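# Convert NCHW images in [-1, 1] to NHWC arrays in [0, 255] for saving.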
image_numpy = (np.transpose(x, (0, 2, 3, 1)) + 1) / 2.0 * 255.0
image_numpy = np.clip(image_numpy, 0, 255)
return image_numpy

View File: train.py

@@ -0,0 +1,265 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Entry point for training AttGAN network"""
import argparse
import datetime
import json
import math
import os
from os.path import join
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore import nn
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext
from mindspore.train.serialization import load_param_into_net
from src.attgan import Genc, Gdec, Dis
from src.cell import TrainOneStepCellGen, TrainOneStepCellDis, init_weights
from src.data import data_loader
from src.helpers import Progressbar
from src.loss import GenLoss, DisLoss
from src.utils import resume_model, resume_discriminator
attrs_default = [
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
]
def parse(arg=None):
"""Define configuration of Model"""
parser = argparse.ArgumentParser()
parser.add_argument('--attrs', dest='attrs', default=attrs_default, nargs='+', help='attributes to learn')
parser.add_argument('--data', dest='data', type=str, choices=['CelebA'], default='CelebA')
parser.add_argument('--data_path', dest='data_path', type=str, default='./data/img_align_celeba')
parser.add_argument('--attr_path', dest='attr_path', type=str, default='./data/list_attr_celeba.txt')
parser.add_argument('--img_size', dest='img_size', type=int, default=128)
parser.add_argument('--shortcut_layers', dest='shortcut_layers', type=int, default=1)
parser.add_argument('--inject_layers', dest='inject_layers', type=int, default=1)
parser.add_argument('--enc_dim', dest='enc_dim', type=int, default=64)
parser.add_argument('--dec_dim', dest='dec_dim', type=int, default=64)
parser.add_argument('--dis_dim', dest='dis_dim', type=int, default=64)
parser.add_argument('--dis_fc_dim', dest='dis_fc_dim', type=int, default=1024)
parser.add_argument('--enc_layers', dest='enc_layers', type=int, default=5)
parser.add_argument('--dec_layers', dest='dec_layers', type=int, default=5)
parser.add_argument('--dis_layers', dest='dis_layers', type=int, default=5)
parser.add_argument('--enc_norm', dest='enc_norm', type=str, default='batchnorm')
parser.add_argument('--dec_norm', dest='dec_norm', type=str, default='batchnorm')
parser.add_argument('--dis_norm', dest='dis_norm', type=str, default='instancenorm')
parser.add_argument('--dis_fc_norm', dest='dis_fc_norm', type=str, default='none')
parser.add_argument('--enc_acti', dest='enc_acti', type=str, default='lrelu')
parser.add_argument('--dec_acti', dest='dec_acti', type=str, default='relu')
parser.add_argument('--dis_acti', dest='dis_acti', type=str, default='lrelu')
parser.add_argument('--dis_fc_acti', dest='dis_fc_acti', type=str, default='relu')
parser.add_argument('--lambda_1', dest='lambda_1', type=float, default=100.0)
parser.add_argument('--lambda_2', dest='lambda_2', type=float, default=10.0)
parser.add_argument('--lambda_3', dest='lambda_3', type=float, default=1.0)
parser.add_argument('--lambda_gp', dest='lambda_gp', type=float, default=10.0)
parser.add_argument('--epochs', dest='epochs', type=int, default=200, help='# of epochs')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=32)
parser.add_argument('--num_workers', dest='num_workers', type=int, default=16)
parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='learning rate')
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5)
parser.add_argument('--beta2', dest='beta2', type=float, default=0.999)
parser.add_argument('--n_d', dest='n_d', type=int, default=5, help='# of d updates per g update')
parser.add_argument('--thres_int', dest='thres_int', type=float, default=0.5)
parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
parser.add_argument('--save_interval', dest='save_interval', type=int, default=500)
parser.add_argument('--experiment_name', dest='experiment_name',
default=datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
parser.add_argument('--resume_model', action='store_true')
parser.add_argument('--enc_ckpt_name', type=str, default='')
parser.add_argument('--dec_ckpt_name', type=str, default='')
parser.add_argument('--dis_ckpt_name', type=str, default='')
return parser.parse_args(arg)
args = parse()
print(args)
args.lr_base = args.lr
args.n_attrs = len(args.attrs)
# initialize environment
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
if args.run_distribute:
if os.getenv("DEVICE_ID", "not_set").isdigit():
context.set_context(device_id=int(os.getenv("DEVICE_ID")))
device_num = int(os.getenv('RANK_SIZE'))
print(device_num)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
init()
rank = get_rank()
else:
if os.getenv("DEVICE_ID", "not_set").isdigit():
context.set_context(device_id=int(os.getenv("DEVICE_ID")))
device_num = int(os.getenv('RANK_SIZE'))
rank = 0
print("Initialize successful!")
os.makedirs(join('output', args.experiment_name), exist_ok=True)
os.makedirs(join('output', args.experiment_name, 'checkpoint'), exist_ok=True)
with open(join('output', args.experiment_name, 'setting.txt'), 'w') as f:
f.write(json.dumps(vars(args), indent=4, separators=(',', ':')))
if __name__ == '__main__':
# Define dataloader
train_dataset, train_length = data_loader(img_path=args.data_path,
attr_path=args.attr_path,
selected_attrs=args.attrs,
mode="train",
batch_size=args.batch_size,
device_num=device_num,
shuffle=True)
train_loader = train_dataset.create_dict_iterator()
valid_dataset, valid_length = data_loader(img_path=args.data_path,
attr_path=args.attr_path,
selected_attrs=args.attrs,
mode="valid",
batch_size=args.batch_size,
device_num=device_num,
shuffle=False)
valid_loader = valid_dataset.create_dict_iterator()
print('Training images:', train_length, '/', 'Validating images:', valid_length)
# Define network
genc = Genc(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, mode='train')
gdec = Gdec(args.dec_dim, args.dec_layers, args.dec_norm, args.dec_acti, args.n_attrs, args.shortcut_layers,
args.inject_layers, args.img_size, mode='train')
dis = Dis(args.dis_dim, args.dis_norm, args.dis_acti, args.dis_fc_dim, args.dis_fc_norm, args.dis_fc_acti,
args.dis_layers, args.img_size, mode='train')
# Initialize network
init_weights(genc, 'KaimingUniform', math.sqrt(5))
init_weights(gdec, 'KaimingUniform', math.sqrt(5))
init_weights(dis, 'KaimingUniform', math.sqrt(5))
# Resume from checkpoint
if args.resume_model:
para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name)
para_dis = resume_discriminator(args, dis, args.dis_ckpt_name)
load_param_into_net(genc, para_genc)
load_param_into_net(gdec, para_gdec)
load_param_into_net(dis, para_dis)
# Define network with loss
G_loss_cell = GenLoss(args, genc, gdec, dis)
D_loss_cell = DisLoss(args, genc, gdec, dis)
# Define Optimizer
G_trainable_params = genc.trainable_params() + gdec.trainable_params()
optimizer_G = nn.Adam(params=G_trainable_params, learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
optimizer_D = nn.Adam(params=dis.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
# Define One Step Train
G_trainOneStep = TrainOneStepCellGen(G_loss_cell, optimizer_G)
D_trainOneStep = TrainOneStepCellDis(D_loss_cell, optimizer_D)
# Train
G_trainOneStep.set_train(True)
D_trainOneStep.set_train(True)
print("Start Training")
train_iter = train_length // args.batch_size
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_interval)
if rank == 0:
local_train_url = os.path.join('output', args.experiment_name, 'checkpoint/rank{}'.format(rank))
ckpt_cb_genc = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='encoder')
ckpt_cb_gdec = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='decoder')
ckpt_cb_dis = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='discriminator')
cb_params_genc = _InternalCallbackParam()
cb_params_genc.train_network = genc
cb_params_genc.cur_epoch_num = 0
genc_run_context = RunContext(cb_params_genc)
ckpt_cb_genc.begin(genc_run_context)
cb_params_gdec = _InternalCallbackParam()
cb_params_gdec.train_network = gdec
cb_params_gdec.cur_epoch_num = 0
gdec_run_context = RunContext(cb_params_gdec)
ckpt_cb_gdec.begin(gdec_run_context)
cb_params_dis = _InternalCallbackParam()
cb_params_dis.train_network = dis
cb_params_dis.cur_epoch_num = 0
dis_run_context = RunContext(cb_params_dis)
ckpt_cb_dis.begin(dis_run_context)
# Initialize Progressbar
progressbar = Progressbar()
it = 0
for epoch in range(args.epochs):
for data in progressbar(train_loader, train_iter):
img_a = data["image"]
att_a = data["attr"]
att_a = att_a.asnumpy()
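# Build the target attribute vectors by randomly permuting the real attribute rows within the batch.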
att_b = np.random.permutation(att_a)
att_a_ = (att_a * 2 - 1) * args.thres_int
att_b_ = (att_b * 2 - 1) * args.thres_int
att_a = Tensor(att_a, mstype.float32)
att_a_ = Tensor(att_a_, mstype.float32)
att_b = Tensor(att_b, mstype.float32)
att_b_ = Tensor(att_b_, mstype.float32)
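# Update the discriminator n_d times for every generator update (see --n_d).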
if (it + 1) % (args.n_d + 1) != 0:
d_out, d_real_loss, d_fake_loss, dc_loss, df_gp = D_trainOneStep(img_a, att_a, att_a_, att_b, att_b_)
else:
g_out, gf_loss, gc_loss, gr_loss = G_trainOneStep(img_a, att_a, att_a_, att_b, att_b_)
progressbar.say(epoch=epoch, iter=it + 1, d_loss=d_out, g_loss=g_out, gf_loss=gf_loss, gc_loss=gc_loss,
gr_loss=gr_loss, dc_loss=dc_loss, df_gp=df_gp)
if (epoch + 1) % 5 == 0 and (it + 1) % args.save_interval == 0 and rank == 0:
cb_params_genc.cur_epoch_num = epoch + 1
cb_params_gdec.cur_epoch_num = epoch + 1
cb_params_dis.cur_epoch_num = epoch + 1
cb_params_genc.cur_step_num = it + 1
cb_params_gdec.cur_step_num = it + 1
cb_params_dis.cur_step_num = it + 1
cb_params_genc.batch_num = it + 2
cb_params_gdec.batch_num = it + 2
cb_params_dis.batch_num = it + 2
ckpt_cb_genc.step_end(genc_run_context)
ckpt_cb_gdec.step_end(gdec_run_context)
ckpt_cb_dis.step_end(dis_run_context)
it += 1