forked from mindspore-Ecosystem/mindspore
!19835 AttGAN commit -2021/07/09
Merge pull request !19835 from MR.D/master
This commit is contained in:
commit
ab5b963495
|
@ -0,0 +1,223 @@
|
|||
# 目录
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [目录](#目录)
|
||||
- [AttGAN描述](#AttGAN描述)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本及样例代码](#脚本及样例代码)
|
||||
- [训练过程](#训练过程)
|
||||
- [训练](#训练)
|
||||
- [分布式训练](#分布式训练)
|
||||
- [评估过程](#评估过程)
|
||||
- [评估](#评估)
|
||||
- [推理过程](#推理过程)
|
||||
- [导出MindIR](#导出MindIR)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [CelebA上的AttGAN](#CelebA上的AttGAN)
|
||||
- [推理性能](#推理性能)
|
||||
- [CelebA上的AttGAN](#CelebA上的AttGAN)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# AttGAN描述
|
||||
|
||||
AttGAN指的是AttGAN: Facial Attribute Editing by Only Changing What You Want, 这个网络的特点是可以在不影响面部其它属性的情况下修改人脸属性。
|
||||
|
||||
[论文](https://arxiv.org/abs/1711.10678):[Zhenliang He](https://github.com/LynnHo/AttGAN-Tensorflow), [Wangmeng Zuo](https://github.com/LynnHo/AttGAN-Tensorflow), [Meina Kan](https://github.com/LynnHo/AttGAN-Tensorflow), [Shiguang Shan](https://github.com/LynnHo/AttGAN-Tensorflow), [Xilin Chen](https://github.com/LynnHo/AttGAN-Tensorflow), et al. AttGAN: Facial Attribute Editing by Only Changing What You Want[C]// 2017 CVPR. IEEE
|
||||
|
||||
# 模型架构
|
||||
|
||||
整个网络结构由一个生成器和一个判别器构成,生成器由编码器和解码器构成。该模型移除了严格的attribute-independent约束,仅需要通过attribute classification来保证正确地修改属性,同时整合了attribute classification constraint、adversarial learning和reconstruction learning,具有较好的修改面部属性的效果。
|
||||
|
||||
# 数据集
|
||||
|
||||
使用的数据集: [CelebA](<http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>)
|
||||
|
||||
CelebFaces Attributes Dataset (CelebA) 是一个大规模的人脸属性数据集,拥有超过 200K 的名人图像,每个图像有 40 个属性注释。 CelebA 多样性大、数量多、注释丰富,包括
|
||||
|
||||
- 10,177 number of identities,
|
||||
- 202,599 number of face images, and 5 landmark locations, 40 binary attributes annotations per image.
|
||||
|
||||
该数据集可用作以下计算机视觉任务的训练和测试集:人脸属性识别、人脸检测以及人脸编辑和合成。
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件(Ascend)
|
||||
- 使用Ascend来搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# 快速入门
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```python
|
||||
# 运行训练示例
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=1
|
||||
python train.py --experiment_name 128_shortcut1_inject1_none --data_path /path/data/img_align_celeba --attr_path /path/data/list_attr_celeba.txt
|
||||
OR
|
||||
bash run_single_train.sh experiment_name /path/data/img_align_celeba /path/data/list_attr_celeba
|
||||
|
||||
# 运行分布式训练示例
|
||||
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba
|
||||
|
||||
# 运行评估示例
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=1
|
||||
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --enc_ckpt_name encoder-119_84999.ckpt --dec_ckpt_name decoder-119_84999.ckpt
|
||||
OR
|
||||
bash run_eval.sh experiment_name /path/data/custom/ /path/data/list_attr_custom enc_ckpt_name dec_ckpt_name
|
||||
```
|
||||
|
||||
对于分布式训练,需要提前创建JSON格式的hccl配置文件。该配置文件的绝对路径作为运行分布式脚本的第一个参数。
|
||||
|
||||
请遵循以下链接中的说明:
|
||||
|
||||
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>
|
||||
|
||||
对于评估脚本,需要提前创建存放自定义图片(jpg)的目录以及属性编辑文件,关于属性编辑文件的说明见[脚本及样例代码](#脚本及样例代码)。目录以及属性编辑文件分别对应参数`custom_data`和`custom_attr`。checkpoint文件被训练脚本默认放置在
|
||||
`/output/{experiment_name}/checkpoint`目录下,执行脚本时需要将两个检查点文件(Encoder和Decoder)的名称作为参数传入。
|
||||
|
||||
[注意] 以上路径均应设置为绝对路径
|
||||
|
||||
# 脚本说明
|
||||
|
||||
## 脚本及样例代码
|
||||
|
||||
```text
|
||||
.
|
||||
└─ cv
|
||||
└─ AttGAN
|
||||
├─ data
|
||||
├─ custom # 自定义图像目录
|
||||
├─ list_attr_custom.txt # 属性控制文件
|
||||
├── scripts
|
||||
├──run_distribute_train.sh # 分布式训练的shell脚本
|
||||
├──run_single_train.sh # 单卡训练的shell脚本
|
||||
├──run_eval.sh # 推理脚本
|
||||
├─ src
|
||||
├─ __init__.py # 初始化文件
|
||||
├─ block.py # 基础cell
|
||||
├─ attgan.py # 生成网络和判别网络
|
||||
├─ utils.py # 辅助函数
|
||||
├─ cell.py # loss网络wrapper
|
||||
├─ data.py # 数据加载
|
||||
├─ helpers.py # 进度条显示
|
||||
├─ loss.py # loss计算
|
||||
├─ eval.py # 测试脚本
|
||||
├─ train.py # 训练脚本
|
||||
└─ README_CN.md # AttGAN的文件描述
|
||||
```
|
||||
|
||||
上述脚本目录中,custom用于存放用户想要修改属性的图片文件,list_attr_custom.txt用于设置想要修改的属性,该脚本可以修改13种属性,分别为:Bald Bangs Black_Hair Blond_Hair Brown_Hair Bushy_Eyebrows Eyeglasses Male Mouth_Slightly_Open Mustache No_Beard Pale_Skin Young。
|
||||
|
||||
list_attr_custom.txt文件中第一行参数为要评估的图片数量,第二行为13种属性,接下来的行表示对相应图片想要进行修改的属性,如果要修改为1,不要修改为-1,如(xxx.jpg -1 -1 -1 1 -1 1 -1 1 1 -1 1 -1 -1)。
|
||||
|
||||
## 训练过程
|
||||
|
||||
### 训练
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=1
|
||||
python train.py --img_size 128 --experiment_name 128_shortcut1_inject1_none --data_path /path/data/img_align_celeba --attr_path /path/data/list_attr_celeba.txt
|
||||
```
|
||||
|
||||
训练结束后,当前目录下会生成output目录,在该目录下会根据你设置的experiment_name参数生成相应的子目录,训练时的参数保存在该子目录下的setting.txt文件中,checkpoint文件保存在`output/experiment_name/rank0`下。
|
||||
|
||||
### 分布式训练
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
bash run_distribute_train.sh /path/hccl_config_file.json /path/data/img_align_celeba /path/data/list_attr_celeba
|
||||
```
|
||||
|
||||
上述shell脚本将在后台运行分布式训练。该脚本将在脚本目录下生成相应的LOG{RANK_ID}目录,每个进程的输出记录在相应LOG{RANK_ID}目录下的log.txt文件中。checkpoint文件保存在output/experiment_name/rank{RANK_ID}下。
|
||||
|
||||
## 评估过程
|
||||
|
||||
### 评估
|
||||
|
||||
- 在Ascend环境运行时评估自定义数据集
|
||||
该网络可以用于修改面部属性,用户将希望修改的图片放在自定义的图片目录下,并根据自己期望修改的属性来编辑list_attr_custom.txt文件(文件的具体参数见[脚本及样例代码](#脚本及样例代码))。完成后,需要将自定义图片目录和属性编辑文件作为参数传入测试脚本,分别对应custom_data以及custom_attr。
|
||||
|
||||
评估时选择已经生成好的检查点文件,作为参数传入测试脚本,对应参数为`enc_ckpt_name`和`dec_ckpt_name`(分别保存了编码器和解码器的参数)
|
||||
|
||||
```bash
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=1
|
||||
python eval.py --experiment_name 128_shortcut1_inject1_none --test_int 1.0 --custom_data /path/data/custom/ --custom_attr /path/data/list_attr_custom.txt --custom_img --enc_ckpt_name encoder-119_84999.ckpt --dec_ckpt_name decoder-119_84999.ckpt
|
||||
```
|
||||
|
||||
测试脚本执行完成后,用户进入当前目录下的`output/{experiment_name}/custom_img`下查看修改好的图片。
|
||||
|
||||
## 推理过程
|
||||
|
||||
### 导出MindIR
|
||||
|
||||
```shell
|
||||
python export.py --experiment_name [EXPERIMENT_NAME] --enc_ckpt_name [ENCODER_CKPT_NAME] --dec_ckpt_name [DECODER_CKPT_NAME] --file_format [FILE_FORMAT]
|
||||
```
|
||||
|
||||
`file_format` 必须在 ["AIR", "MINDIR"]中选择。
|
||||
`experiment_name` 是output目录下的存放结果的文件夹的名称,此参数用于帮助export寻找参数
|
||||
|
||||
脚本会在当前目录下生成对应的MINDIR文件。
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
||||
### 评估性能
|
||||
|
||||
#### CelebA上的AttGAN
|
||||
|
||||
| 参数 | Ascend 910 |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| 模型版本 | AttGAN |
|
||||
| 资源 | Ascend |
|
||||
| 上传日期 | 06/30/2021 (month/day/year) |
|
||||
| MindSpore版本 | 1.2.0 |
|
||||
| 数据集 | CelebA |
|
||||
| 训练参数 | batch_size=32, lr=0.0002 |
|
||||
| 优化器 | Adam |
|
||||
| 生成器输出 | image |
|
||||
| 速度 | 5.56 step/s |
|
||||
| 脚本 | [AttGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/AttGAN) |
|
||||
|
||||
### 推理性能
|
||||
|
||||
#### CelebA上的AttGAN
|
||||
|
||||
| 参数 | Ascend 910 |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| 模型版本 | AttGAN |
|
||||
| 资源 | Ascend |
|
||||
| 上传日期 | 06/30/2021 (month/day/year) |
|
||||
| MindSpore版本 | 1.2.0 |
|
||||
| 数据集 | CelebA |
|
||||
| 推理参数 | batch_size=1 |
|
||||
| 输出 | image |
|
||||
|
||||
推理完成后可以获得对原图像进行属性编辑后的图片slide.
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
Binary file not shown.
After Width: | Height: | Size: 2.0 MiB |
Binary file not shown.
After Width: | Height: | Size: 1.4 MiB |
Binary file not shown.
After Width: | Height: | Size: 103 KiB |
Binary file not shown.
After Width: | Height: | Size: 34 KiB |
Binary file not shown.
After Width: | Height: | Size: 470 KiB |
Binary file not shown.
After Width: | Height: | Size: 91 KiB |
|
@ -0,0 +1,8 @@
|
|||
6
|
||||
Bald Bangs Black_Hair Blond_Hair Brown_Hair Bushy_Eyebrows Eyeglasses Male Mouth_Slightly_Open Mustache No_Beard Pale_Skin Young
|
||||
donald_trump.jpg -1 -1 -1 1 -1 1 -1 1 1 -1 1 -1 -1
|
||||
emma_watson.jpg -1 1 -1 -1 1 -1 -1 -1 -1 -1 1 -1 1
|
||||
jay_chou.jpeg -1 1 1 -1 -1 1 -1 1 1 -1 1 -1 1
|
||||
ji-eun_lee.jpg -1 1 -1 -1 1 -1 -1 -1 1 -1 1 -1 1
|
||||
tom_cruise.jpg -1 -1 -1 -1 1 1 -1 1 1 -1 1 -1 -1
|
||||
yui_aragaki.jpg -1 1 -1 -1 1 -1 -1 -1 -1 -1 1 -1 1
|
|
@ -0,0 +1,150 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Entry point for testing AttGAN network"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from os.path import join
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as de
|
||||
from mindspore import context, Tensor, ops
|
||||
from mindspore.train.serialization import load_param_into_net
|
||||
|
||||
from src.attgan import Genc, Gdec
|
||||
from src.cell import init_weights
|
||||
from src.data import check_attribute_conflict
|
||||
from src.data import get_loader, Custom
|
||||
from src.helpers import Progressbar
|
||||
from src.utils import resume_model, denorm
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
|
||||
|
||||
def parse(arg=None):
|
||||
"""Define configuration of Evaluation"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--experiment_name', dest='experiment_name', required=True)
|
||||
parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
|
||||
parser.add_argument('--num_test', dest='num_test', type=int)
|
||||
parser.add_argument('--enc_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--dec_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--custom_img', action='store_true')
|
||||
parser.add_argument('--custom_data', type=str, default='../data/custom')
|
||||
parser.add_argument('--custom_attr', type=str, default='../data/list_attr_custom.txt')
|
||||
parser.add_argument('--shortcut_layers', dest='shortcut_layers', type=int, default=1)
|
||||
parser.add_argument('--inject_layers', dest='inject_layers', type=int, default=1)
|
||||
return parser.parse_args(arg)
|
||||
|
||||
|
||||
args_ = parse()
|
||||
print(args_)
|
||||
|
||||
with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
|
||||
args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
|
||||
args.test_int = args_.test_int
|
||||
args.num_test = args_.num_test
|
||||
args.enc_ckpt_name = args_.enc_ckpt_name
|
||||
args.dec_ckpt_name = args_.dec_ckpt_name
|
||||
args.custom_img = args_.custom_img
|
||||
args.custom_data = args_.custom_data
|
||||
args.custom_attr = args_.custom_attr
|
||||
args.shortcut_layers = args_.shortcut_layers
|
||||
args.inject_layers = args_.inject_layers
|
||||
args.n_attrs = len(args.attrs)
|
||||
args.betas = (args.beta1, args.beta2)
|
||||
|
||||
print(args)
|
||||
|
||||
# Data loader
|
||||
if args.custom_img:
|
||||
output_path = join("output", args.experiment_name, "custom_testing")
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
test_dataset = Custom(args.custom_data, args.custom_attr, args.attrs)
|
||||
test_len = len(test_dataset)
|
||||
else:
|
||||
output_path = join("output", args.experiment_name, "sample_testing")
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
test_dataset = get_loader(args.data_path, args.attr_path,
|
||||
selected_attrs=args.attrs,
|
||||
mode="test"
|
||||
)
|
||||
test_len = len(test_dataset)
|
||||
dataset_column_names = ["image", "attr"]
|
||||
num_parallel_workers = 8
|
||||
ds = de.GeneratorDataset(test_dataset, column_names=dataset_column_names,
|
||||
num_parallel_workers=min(32, num_parallel_workers))
|
||||
ds = ds.batch(1, num_parallel_workers=min(8, num_parallel_workers), drop_remainder=False)
|
||||
test_dataset_iter = ds.create_dict_iterator()
|
||||
|
||||
if args.num_test is None:
|
||||
print('Testing images:', test_len)
|
||||
else:
|
||||
print('Testing images:', min(test_len, args.num_test))
|
||||
|
||||
# Model loader
|
||||
genc = Genc(mode='test')
|
||||
gdec = Gdec(shortcut_layers=args.shortcut_layers, inject_layers=args.inject_layers, mode='test')
|
||||
|
||||
# Initialize network
|
||||
init_weights(genc, 'KaimingUniform', math.sqrt(5))
|
||||
init_weights(gdec, 'KaimingUniform', math.sqrt(5))
|
||||
para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name)
|
||||
load_param_into_net(genc, para_genc)
|
||||
load_param_into_net(gdec, para_gdec)
|
||||
|
||||
progressbar = Progressbar()
|
||||
it = 0
|
||||
for data in test_dataset_iter:
|
||||
img_a = data["image"]
|
||||
att_a = data["attr"]
|
||||
if args.num_test is not None and it == args.num_test:
|
||||
break
|
||||
|
||||
att_a = Tensor(att_a, mstype.float32)
|
||||
att_b_list = [att_a]
|
||||
for i in range(args.n_attrs):
|
||||
clone = ops.Identity()
|
||||
tmp = clone(att_a)
|
||||
tmp[:, i] = 1 - tmp[:, i]
|
||||
tmp = check_attribute_conflict(tmp, args.attrs[i], args.attrs)
|
||||
att_b_list.append(tmp)
|
||||
|
||||
samples = [img_a]
|
||||
|
||||
for i, att_b in enumerate(att_b_list):
|
||||
att_b_ = (att_b * 2 - 1) * args.thres_int
|
||||
if i > 0:
|
||||
att_b_[..., i - 1] = att_b_[..., i - 1] * args.test_int / args.thres_int
|
||||
a_enc = genc(img_a)
|
||||
samples.append(gdec(a_enc, att_b_))
|
||||
cat = ops.Concat(axis=3)
|
||||
samples = cat(samples).asnumpy()
|
||||
result = denorm(samples)
|
||||
result = np.reshape(result, (128, -1, 3))
|
||||
im = Image.fromarray(np.uint8(result))
|
||||
if args.custom_img:
|
||||
out_file = test_dataset.images[it]
|
||||
else:
|
||||
out_file = "{:06d}.jpg".format(it + 182638)
|
||||
im.save(output_path + '/' + out_file)
|
||||
print('Successful save image in ' + output_path + '/' + out_file)
|
||||
it += 1
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export file."""
|
||||
import argparse
|
||||
import json
|
||||
from os.path import join
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import export, load_param_into_net
|
||||
|
||||
from src.utils import resume_model
|
||||
from src.attgan import Genc, Gdec
|
||||
|
||||
parser = argparse.ArgumentParser(description='Attribute Edit')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument('--enc_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--dec_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format')
|
||||
parser.add_argument('--experiment_name', dest='experiment_name', required=True)
|
||||
|
||||
args_ = parser.parse_args()
|
||||
print(args_)
|
||||
|
||||
with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
|
||||
args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
|
||||
args.device_id = args_.device_id
|
||||
args.batch_size = args_.batch_size
|
||||
args.enc_ckpt_name = args_.enc_ckpt_name
|
||||
args.dec_ckpt_name = args_.dec_ckpt_name
|
||||
args.file_format = args_.file_format
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
genc = Genc(mode='test')
|
||||
gdec = Gdec(mode='test')
|
||||
|
||||
para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name)
|
||||
load_param_into_net(genc, para_genc)
|
||||
load_param_into_net(gdec, para_gdec)
|
||||
|
||||
enc_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128, 128)).astype(np.float32))
|
||||
dec_array = genc(enc_array)
|
||||
input_label = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 13)).astype(np.float32))
|
||||
|
||||
G_enc_file = f"AttGAN_Generator_Encoder"
|
||||
export(genc, enc_array, file_name=G_enc_file, file_format=args.file_format)
|
||||
|
||||
G_dec_file = f"AttGAN_Generator_Decoder"
|
||||
export(gdec, *(dec_array, input_label), file_name=G_dec_file, file_format=args.file_format)
|
|
@ -0,0 +1,66 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] [ATTR_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$1
|
||||
export RANK_TABLE_FILE=$1
|
||||
export RANK_SIZE=8
|
||||
export HCCL_CONNECT_TIMEOUT=600
|
||||
echo "hccl connect time out has changed to 600 second"
|
||||
|
||||
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
|
||||
|
||||
data_path=$2
|
||||
attr_path=$3
|
||||
|
||||
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
|
||||
echo "the number of logical core" $cores
|
||||
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
|
||||
core_gap=`expr $avg_core_per_rank \- 1`
|
||||
echo "avg_core_per_rank" $avg_core_per_rank
|
||||
echo "core_gap" $core_gap
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
start=`expr $i \* $avg_core_per_rank`
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
export DEPLOY_MODE=0
|
||||
export GE_USE_STATIC_MEMORY=1
|
||||
end=`expr $start \+ $core_gap`
|
||||
cmdopt=$start"-"$end
|
||||
|
||||
rm -rf LOG$i
|
||||
mkdir ./LOG$i
|
||||
cd ./LOG$i || exit
|
||||
echo "Start training for rank $i, device $DEVICE_ID"
|
||||
|
||||
env > env.log
|
||||
cd ../../
|
||||
taskset -c $cmdopt python train.py \
|
||||
--img_size 128 \
|
||||
--shortcut_layers 1 \
|
||||
--inject_layers 1 \
|
||||
--experiment_name 128_shortcut1_inject1_none \
|
||||
--data_path $data_path \
|
||||
--attr_path $attr_path \
|
||||
--run_distribute 1 > ./scripts/LOG$i/log.txt 2>&1 &
|
||||
cd scripts
|
||||
done
|
|
@ -0,0 +1,51 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 5 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [EXPERIMENT_NAME] [CUSTOM_DATA_PATH] [CUSTOM_ATTR_PATH] [ENC_CKPT_NAME] [DEC_CKPT_NAME]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
experiment_name=$1
|
||||
data_path=$2
|
||||
attr_path=$3
|
||||
enc_name=$4
|
||||
dec_name=$5
|
||||
|
||||
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
|
||||
echo "The number of logical core" $cores
|
||||
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
rm -rf EVAL_LOG
|
||||
mkdir ./EVAL_LOG
|
||||
cd ./EVAL_LOG || exit
|
||||
echo "Start training for rank 0, device 0, directory is EVAL_LOG"
|
||||
|
||||
env > env.log
|
||||
cd ../../
|
||||
|
||||
python eval.py \
|
||||
--experiment_name $experiment_name \
|
||||
--test_int 1.0 \
|
||||
--custom_data $data_path \
|
||||
--custom_attr $attr_path \
|
||||
--custom_img \
|
||||
--enc_ckpt_name $enc_name \
|
||||
--dec_ckpt_name $dec_name > ./scripts/EVAL_LOG/log.txt 2>&1 &
|
|
@ -0,0 +1,48 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_single_train.sh [EXPERIMENT_NAME] [DATA_PATH] [ATTR_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
experiment_name=$1
|
||||
data_path=$2
|
||||
attr_path=$3
|
||||
|
||||
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
|
||||
echo "the number of logical core" $cores
|
||||
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
rm -rf LOG
|
||||
mkdir ./LOG
|
||||
cd ./LOG || exit
|
||||
echo "Start training for rank 0, device 0, directory is LOG"
|
||||
|
||||
env > env.log
|
||||
cd ../../
|
||||
|
||||
python train.py \
|
||||
--img_size 128 \
|
||||
--shortcut_layers 1 \
|
||||
--inject_layers 1 \
|
||||
--experiment_name $experiment_name \
|
||||
--data_path $data_path \
|
||||
--attr_path $attr_path > ./scripts/LOG/log.txt 2>&1 &
|
|
@ -0,0 +1,133 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""AttGAN Network Topology"""
|
||||
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore import nn
|
||||
|
||||
from src.block import LinearBlock, Conv2dBlock, ConvTranspose2dBlock
|
||||
|
||||
# Image size 128 x 128
|
||||
MAX_DIM = 64 * 16
|
||||
|
||||
|
||||
class Genc(nn.Cell):
|
||||
"""Generator encoder"""
|
||||
def __init__(self, enc_dim=64, enc_layers=5, enc_norm_fn="batchnorm", enc_acti_fn="lrelu",
|
||||
mode='test'):
|
||||
super().__init__()
|
||||
|
||||
layers = []
|
||||
n_in = 3
|
||||
for i in range(enc_layers):
|
||||
n_out = min(enc_dim * 2 ** i, MAX_DIM)
|
||||
layers += [Conv2dBlock(
|
||||
n_in, n_out, (4, 4), stride=2, padding=1, norm_fn=enc_norm_fn, acti_fn=enc_acti_fn, mode=mode
|
||||
)]
|
||||
n_in = n_out
|
||||
self.enc_layers = nn.CellList(layers)
|
||||
|
||||
def construct(self, x):
|
||||
"""Encoder construct"""
|
||||
z = x
|
||||
zs = []
|
||||
for layer in self.enc_layers:
|
||||
z = layer(z)
|
||||
zs.append(z)
|
||||
return zs
|
||||
|
||||
|
||||
class Gdec(nn.Cell):
|
||||
"""Generator decoder"""
|
||||
def __init__(self, dec_dim=64, dec_layers=5, dec_norm_fn="batchnorm", dec_acti_fn="relu", n_attrs=13,
|
||||
shortcut_layers=1, inject_layers=1, img_size=128, mode='test'):
|
||||
super().__init__()
|
||||
self.shortcut_layers = min(shortcut_layers, dec_layers - 1)
|
||||
self.inject_layers = min(inject_layers, dec_layers - 1)
|
||||
self.f_size = img_size // 2 ** dec_layers # f_size = 4 for 128x128
|
||||
|
||||
layers = []
|
||||
n_in = 1024
|
||||
n_in = n_in + n_attrs # 1024 + 13
|
||||
for i in range(dec_layers):
|
||||
if i < dec_layers - 1:
|
||||
n_out = min(dec_dim * 2 ** (dec_layers - i - 1), MAX_DIM)
|
||||
layers += [ConvTranspose2dBlock(
|
||||
n_in, n_out, (4, 4), stride=2, padding=1, norm_fn=dec_norm_fn, acti_fn=dec_acti_fn, mode=mode
|
||||
)]
|
||||
n_in = n_out
|
||||
n_in = n_in + n_in // 2 if self.shortcut_layers > i else n_in
|
||||
n_in = n_in + n_attrs if self.inject_layers > i else n_in
|
||||
else:
|
||||
layers += [ConvTranspose2dBlock(
|
||||
n_in, 3, (4, 4), stride=2, padding=1, norm_fn='none', acti_fn='tanh', mode=mode
|
||||
)]
|
||||
self.dec_layers = nn.CellList(layers)
|
||||
|
||||
self.view = P.Reshape()
|
||||
self.repeat = P.Tile()
|
||||
self.cat = P.Concat(1)
|
||||
|
||||
def construct(self, zs, a):
|
||||
"""Decoder construct"""
|
||||
a_tile = self.view(a, (a.shape[0], -1, 1, 1))
|
||||
multiples = (1, 1, self.f_size, self.f_size)
|
||||
a_tile = self.repeat(a_tile, multiples)
|
||||
|
||||
z = self.cat((zs[-1], a_tile))
|
||||
i = 0
|
||||
for layer in self.dec_layers:
|
||||
z = layer(z)
|
||||
if self.shortcut_layers > i:
|
||||
z = self.cat((z, zs[len(self.dec_layers) - 2 - i]))
|
||||
if self.inject_layers > i:
|
||||
a_tile = self.view(a, (a.shape[0], -1, 1, 1))
|
||||
multiples = (1, 1, self.f_size * 2 ** (i + 1), self.f_size * 2 ** (i + 1))
|
||||
a_tile = self.repeat(a_tile, multiples)
|
||||
z = self.cat((z, a_tile))
|
||||
i = i + 1
|
||||
return z
|
||||
|
||||
|
||||
class Dis(nn.Cell):
|
||||
"""Discriminator"""
|
||||
def __init__(self, dim=64, norm_fn='none', acti_fn='lrelu',
|
||||
fc_dim=1024, fc_norm_fn='none', fc_acti_fn='lrelu', n_layers=5, img_size=128, mode='test'):
|
||||
super().__init__()
|
||||
self.f_size = img_size // 2 ** n_layers
|
||||
|
||||
layers = []
|
||||
n_in = 3
|
||||
for i in range(n_layers):
|
||||
n_out = min(dim * 2 ** i, MAX_DIM)
|
||||
layers += [Conv2dBlock(
|
||||
n_in, n_out, (4, 4), stride=2, padding=1, norm_fn=norm_fn, acti_fn=acti_fn, mode=mode
|
||||
)]
|
||||
n_in = n_out
|
||||
self.conv = nn.SequentialCell(layers)
|
||||
self.fc_adv = nn.SequentialCell(
|
||||
[LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn, mode),
|
||||
LinearBlock(fc_dim, 1, 'none', 'none', mode)])
|
||||
|
||||
self.fc_cls = nn.SequentialCell(
|
||||
[LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn, mode),
|
||||
LinearBlock(fc_dim, 13, 'none', 'none', mode)])
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
h = self.conv(x)
|
||||
view = P.Reshape()
|
||||
h = view(h, (h.shape[0], -1))
|
||||
return self.fc_adv(h), self.fc_cls(h)
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Network Component"""
|
||||
|
||||
from mindspore import nn
|
||||
|
||||
def add_normalization_1d(layers, fn, n_out, mode='test'):
|
||||
if fn == "none":
|
||||
pass
|
||||
elif fn == "batchnorm":
|
||||
layers.append(nn.BatchNorm1d(n_out, use_batch_statistics=(mode == 'train')))
|
||||
elif fn == "instancenorm":
|
||||
layers.append(nn.GroupNorm(n_out, n_out, affine=True))
|
||||
else:
|
||||
raise Exception('Unsupported normalization: ' + str(fn))
|
||||
return layers
|
||||
|
||||
|
||||
def add_normalization_2d(layers, fn, n_out, mode='test'):
|
||||
if fn == 'none':
|
||||
pass
|
||||
elif fn == 'batchnorm':
|
||||
layers.append(nn.BatchNorm2d(n_out, use_batch_statistics=(mode == 'train')))
|
||||
elif fn == "instancenorm":
|
||||
layers.append(nn.GroupNorm(n_out, n_out, affine=True))
|
||||
else:
|
||||
raise Exception('Unsupported normalization: ' + str(fn))
|
||||
return layers
|
||||
|
||||
|
||||
def add_activation(layers, fn):
|
||||
"""Add Activation"""
|
||||
if fn == "none":
|
||||
pass
|
||||
elif fn == "relu":
|
||||
layers.append(nn.ReLU())
|
||||
elif fn == "lrelu":
|
||||
layers.append(nn.LeakyReLU(alpha=0.01))
|
||||
elif fn == "sigmoid":
|
||||
layers.append(nn.Sigmoid())
|
||||
elif fn == "tanh":
|
||||
layers.append(nn.Tanh())
|
||||
else:
|
||||
raise Exception('Unsupported activation function: ' + str(fn))
|
||||
return layers
|
||||
|
||||
|
||||
class LinearBlock(nn.Cell):
|
||||
"""Linear Block"""
|
||||
def __init__(self, n_in, n_out, norm_fn="none", acti_fn="none", mode='test'):
|
||||
super().__init__()
|
||||
layers = [nn.Dense(n_in, n_out, has_bias=(norm_fn == 'none'))]
|
||||
layers = add_normalization_1d(layers, norm_fn, n_out, mode)
|
||||
layers = add_activation(layers, acti_fn)
|
||||
self.layers = nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
class Conv2dBlock(nn.Cell):
|
||||
"""Convolution Block"""
|
||||
def __init__(self, n_in, n_out, kernel_size, stride=1, padding=0,
|
||||
norm_fn=None, acti_fn=None, mode='test'):
|
||||
super().__init__()
|
||||
layers = [nn.Conv2d(n_in, n_out, kernel_size, stride=stride, padding=padding, pad_mode='pad',
|
||||
has_bias=(norm_fn == 'none'))]
|
||||
layers = add_normalization_2d(layers, norm_fn, n_out, mode)
|
||||
layers = add_activation(layers, acti_fn)
|
||||
self.layers = nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
class ConvTranspose2dBlock(nn.Cell):
|
||||
"""Transpose Convolution Block"""
|
||||
def __init__(self, n_in, n_out, kernel_size, stride=1, padding=0,
|
||||
norm_fn=None, acti_fn=None, mode='test'):
|
||||
super().__init__()
|
||||
layers = [nn.Conv2dTranspose(n_in, n_out, kernel_size, stride=stride, padding=padding, pad_mode='pad',
|
||||
has_bias=(norm_fn == 'none'))]
|
||||
layers = add_normalization_2d(layers, norm_fn, n_out, mode)
|
||||
layers = add_activation(layers, acti_fn)
|
||||
self.layers = nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
return self.layers(x)
|
|
@ -0,0 +1,155 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Cell Definition"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.ops.functional as F
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore import nn, ops
|
||||
from mindspore.common import initializer as init, set_seed
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
|
||||
_get_parallel_mode)
|
||||
|
||||
set_seed(1)
|
||||
np.random.seed(1)
|
||||
|
||||
|
||||
def init_weights(net, init_type='normal', init_gain=0.02):
|
||||
"""
|
||||
Initialize network weights.
|
||||
|
||||
Parameters:
|
||||
net (Cell): Network to be initialized
|
||||
init_type (str): The name of an initialization method: normal | xavier.
|
||||
init_gain (float): Gain factor for normal and xavier.
|
||||
|
||||
"""
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
|
||||
if init_type == 'normal':
|
||||
cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
|
||||
elif init_type == 'xavier':
|
||||
cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
|
||||
elif init_type == 'KaimingUniform':
|
||||
cell.weight.set_data(init.initializer(init.HeUniform(init_gain), cell.weight.shape))
|
||||
elif init_type == 'constant':
|
||||
cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
|
||||
else:
|
||||
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
||||
elif isinstance(cell, (nn.GroupNorm, nn.BatchNorm2d)):
|
||||
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
|
||||
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
|
||||
|
||||
|
||||
class GenWithLossCell(nn.Cell):
|
||||
"""
|
||||
Wrap the network with loss function to return generator loss
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
super().__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
|
||||
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
|
||||
_, g_loss, _, _, _, = self.network(img_a, att_a, att_a_, att_b, att_b_)
|
||||
return g_loss
|
||||
|
||||
|
||||
class DisWithLossCell(nn.Cell):
|
||||
"""
|
||||
Wrap the network with loss function to return discriminator loss
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
super().__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
|
||||
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
|
||||
d_loss, _, _, _, _ = self.network(img_a, att_a, att_a_, att_b, att_b_)
|
||||
return d_loss
|
||||
|
||||
|
||||
class TrainOneStepCellGen(nn.Cell):
|
||||
"""Encapsulation class of AttGAN generator network training."""
|
||||
|
||||
def __init__(self, generator, optimizer, sens=1.0):
|
||||
super().__init__()
|
||||
self.optimizer = optimizer
|
||||
self.generator = generator
|
||||
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
|
||||
|
||||
self.sens = sens
|
||||
self.weights = optimizer.parameters
|
||||
self.network = GenWithLossCell(generator)
|
||||
self.network.add_flags(defer_inline=True)
|
||||
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer = F.identity
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
|
||||
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
|
||||
|
||||
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
|
||||
weights = self.weights
|
||||
_, loss, gf_loss, gc_loss, gr_loss = self.generator(img_a, att_a, att_a_, att_b, att_b_)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(img_a, att_a, att_a_, att_b, att_b_, sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
return F.depend(loss, self.optimizer(grads)), gf_loss, gc_loss, gr_loss
|
||||
|
||||
|
||||
class TrainOneStepCellDis(nn.Cell):
|
||||
"""Encapsulation class of AttGAN discriminator network training."""
|
||||
|
||||
def __init__(self, discriminator, optimizer, sens=1.0):
|
||||
super().__init__()
|
||||
self.optimizer = optimizer
|
||||
self.discriminator = discriminator
|
||||
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
|
||||
|
||||
self.sens = sens
|
||||
self.weights = optimizer.parameters
|
||||
self.network = DisWithLossCell(discriminator)
|
||||
self.network.add_flags(defer_inline=True)
|
||||
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer = F.identity
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
|
||||
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
|
||||
|
||||
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
|
||||
weights = self.weights
|
||||
loss, d_real_loss, d_fake_loss, dc_loss, df_gp = self.discriminator(img_a, att_a, att_a_, att_b, att_b_)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(img_a, att_a, att_a_, att_b, att_b_, sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
|
||||
return F.depend(loss, self.optimizer(grads)), d_real_loss, d_fake_loss, dc_loss, df_gp
|
|
@ -0,0 +1,198 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" DataLoader: CelebA"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.py_transforms as py_vision
|
||||
from mindspore.dataset.transforms import py_transforms
|
||||
|
||||
from src.utils import DistributedSampler
|
||||
|
||||
|
||||
class Custom:
|
||||
"""
|
||||
Custom data loader
|
||||
"""
|
||||
|
||||
def __init__(self, data_path, attr_path, selected_attrs):
|
||||
self.data_path = data_path
|
||||
|
||||
att_list = open(attr_path, 'r', encoding='utf-8').readlines()[1].split()
|
||||
atts = [att_list.index(att) + 1 for att in selected_attrs]
|
||||
images = np.loadtxt(attr_path, skiprows=2, usecols=[0], dtype=np.str)
|
||||
labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=np.int)
|
||||
|
||||
mean = [0.5, 0.5, 0.5]
|
||||
std = [0.5, 0.5, 0.5]
|
||||
transform = [py_vision.ToPIL()]
|
||||
transform.append(py_vision.Resize([128, 128]))
|
||||
transform.append(py_vision.ToTensor())
|
||||
transform.append(py_vision.Normalize(mean=mean, std=std))
|
||||
transform = py_transforms.Compose(transform)
|
||||
self.transform = transform
|
||||
self.images = np.array([images]) if images.size == 1 else images[0:]
|
||||
self.labels = np.array([labels]) if images.size == 1 else labels[0:]
|
||||
self.length = len(self.images)
|
||||
|
||||
def __getitem__(self, index):
|
||||
image = np.asarray(Image.open(os.path.join(self.data_path, self.images[index])))
|
||||
att = np.asarray((self.labels[index] + 1) // 2)
|
||||
image = np.squeeze(self.transform(image))
|
||||
return image, att
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
||||
class CelebA:
|
||||
"""
|
||||
CelebA dataset
|
||||
Input:
|
||||
data_path: Image Path
|
||||
attr_path: Attr_list Path
|
||||
image_size: Image Size
|
||||
mode: Train, Valid or Test
|
||||
selected_attrs: selected attributes
|
||||
transform: Image Processing
|
||||
"""
|
||||
|
||||
def __init__(self, data_path, attr_path, image_size, mode, selected_attrs, transform):
|
||||
super().__init__()
|
||||
self.data_path = data_path
|
||||
self.transform = transform
|
||||
self.img_size = image_size
|
||||
|
||||
att_list = open(attr_path, 'r', encoding='utf-8').readlines()[1].split()
|
||||
atts = [att_list.index(att) + 1 for att in selected_attrs]
|
||||
|
||||
images = np.loadtxt(attr_path, skiprows=2, usecols=[0], dtype=np.str)
|
||||
labels = np.loadtxt(attr_path, skiprows=2, usecols=atts, dtype=np.int)
|
||||
|
||||
if mode == "train":
|
||||
self.images = images[:182000]
|
||||
self.labels = labels[:182000]
|
||||
if mode == "valid":
|
||||
self.images = images[182000:182637]
|
||||
self.labels = labels[182000:182637]
|
||||
if mode == "test":
|
||||
self.images = images[182637:]
|
||||
self.labels = labels[182637:]
|
||||
|
||||
self.length = len(self.images)
|
||||
|
||||
def __getitem__(self, index):
|
||||
image = np.asarray(Image.open(os.path.join(self.data_path, self.images[index])))
|
||||
att = np.asarray((self.labels[index] + 1) // 2)
|
||||
image = np.squeeze(self.transform(image))
|
||||
return image, att
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
||||
def get_loader(data_root, attr_path, selected_attrs, crop_size=170, image_size=128, mode="train"):
|
||||
"""Build and return dataloader"""
|
||||
|
||||
mean = [0.5, 0.5, 0.5]
|
||||
std = [0.5, 0.5, 0.5]
|
||||
transform = [py_vision.ToPIL()]
|
||||
transform.append(py_vision.CenterCrop((crop_size, crop_size)))
|
||||
transform.append(py_vision.Resize([image_size, image_size]))
|
||||
transform.append(py_vision.ToTensor())
|
||||
transform.append(py_vision.Normalize(mean=mean, std=std))
|
||||
transform = py_transforms.Compose(transform)
|
||||
|
||||
dataset = CelebA(data_root, attr_path, image_size, mode, selected_attrs, transform)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def data_loader(img_path, attr_path, selected_attrs, mode="train", batch_size=1, device_num=1, rank=0, shuffle=True):
|
||||
"""CelebA data loader"""
|
||||
num_parallel_workers = 8
|
||||
|
||||
dataset_loader = get_loader(img_path, attr_path, selected_attrs, mode=mode)
|
||||
length_dataset = len(dataset_loader)
|
||||
|
||||
distributed_sampler = DistributedSampler(length_dataset, device_num, rank, shuffle=shuffle)
|
||||
dataset_column_names = ["image", "attr"]
|
||||
|
||||
if device_num != 8:
|
||||
ds = de.GeneratorDataset(dataset_loader, column_names=dataset_column_names,
|
||||
num_parallel_workers=min(32, num_parallel_workers),
|
||||
sampler=distributed_sampler)
|
||||
ds = ds.batch(batch_size, num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True)
|
||||
else:
|
||||
ds = de.GeneratorDataset(dataset_loader, column_names=dataset_column_names, sampler=distributed_sampler)
|
||||
ds = ds.batch(batch_size, num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True)
|
||||
|
||||
# ds = ds.repeat(200)
|
||||
|
||||
return ds, length_dataset
|
||||
|
||||
|
||||
def check_attribute_conflict(att_batch, att_name, att_names):
|
||||
"""Check Attributes"""
|
||||
def _set(att, att_name):
|
||||
if att_name in att_names:
|
||||
att[att_names.index(att_name)] = 0.0
|
||||
|
||||
att_id = att_names.index(att_name)
|
||||
for att in att_batch:
|
||||
if att_name in ['Bald', 'Receding_Hairline'] and att[att_id] != 0:
|
||||
_set(att, 'Bangs')
|
||||
elif att_name == 'Bangs' and att[att_id] != 0:
|
||||
_set(att, 'Bald')
|
||||
_set(att, 'Receding_Hairline')
|
||||
elif att_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'] and att[att_id] != 0:
|
||||
for n in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
|
||||
if n != att_name:
|
||||
_set(att, n)
|
||||
elif att_name in ['Straight_Hair', 'Wavy_Hair'] and att[att_id] != 0:
|
||||
for n in ['Straight_Hair', 'Wavy_Hair']:
|
||||
if n != att_name:
|
||||
_set(att, n)
|
||||
elif att_name in ['Mustache', 'No_Beard'] and att[att_id] != 0:
|
||||
for n in ['Mustache', 'No_Beard']:
|
||||
if n != att_name:
|
||||
_set(att, n)
|
||||
return att_batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
attrs_default = [
|
||||
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
|
||||
'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
|
||||
]
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--attrs', dest='attrs', default=attrs_default, nargs='+', help='attributes to test')
|
||||
parser.add_argument('--data_path', dest='data_path', type=str, required=True)
|
||||
parser.add_argument('--attr_path', dest='attr_path', type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
####### Test CelebA #######
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
dataset_ce, length_ce = data_loader(args.data_path, args.attr_path, attrs_default, mode="train")
|
||||
i = 0
|
||||
for data in dataset_ce.create_dict_iterator():
|
||||
print('Number:', i, 'Value:', data["attr"], 'Type:', type(data["attr"]))
|
||||
i += 1
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================s
|
||||
"""Helper functions for training"""
|
||||
|
||||
import datetime
|
||||
import platform
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def name_experiment(prefix="", suffix=""):
|
||||
experiment_name = datetime.datetime.now().strftime('%b%d_%H-%M-%S_') + platform.node()
|
||||
if prefix is not None and prefix != '':
|
||||
experiment_name = prefix + '_' + experiment_name
|
||||
if suffix is not None and suffix != '':
|
||||
experiment_name = experiment_name + '_' + suffix
|
||||
return experiment_name
|
||||
|
||||
|
||||
class Progressbar():
|
||||
"""Progress Bar"""
|
||||
|
||||
def __init__(self):
|
||||
self.p = None
|
||||
|
||||
def __call__(self, iterable, length):
|
||||
self.p = tqdm(iterable, total=length)
|
||||
return self.p
|
||||
|
||||
def say(self, **kwargs):
|
||||
if self.p is not None:
|
||||
self.p.set_postfix(**kwargs)
|
|
@ -0,0 +1,169 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================s
|
||||
"""Loss Computation of Generator and Discriminator"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore import nn, Tensor, ops
|
||||
from mindspore.ops import constexpr
|
||||
|
||||
|
||||
class ClassificationLoss(nn.Cell):
|
||||
"""Define classification loss for AttGAN"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bce_loss = P.BinaryCrossEntropy(reduction='sum')
|
||||
|
||||
def construct(self, pred, label):
|
||||
weight = ops.Ones()(pred.shape, mindspore.float32)
|
||||
pred_ = P.Sigmoid()(pred)
|
||||
x = self.bce_loss(pred_, label, weight) / pred.shape[0]
|
||||
return x
|
||||
|
||||
|
||||
@constexpr
|
||||
def generate_tensor(batch_size):
|
||||
np_array = np.random.randn(batch_size, 1, 1, 1)
|
||||
return Tensor(np_array, mindspore.float32)
|
||||
|
||||
|
||||
class GradientWithInput(nn.Cell):
|
||||
"""Get Discriminator Gradient with Input"""
|
||||
|
||||
def __init__(self, discriminator):
|
||||
super().__init__()
|
||||
self.reduce_sum = ops.ReduceSum()
|
||||
self.discriminator = discriminator
|
||||
self.discriminator.set_train(mode=True)
|
||||
self.discriminator.set_grad(True)
|
||||
|
||||
def construct(self, interpolates):
|
||||
decision_interpolate, _ = self.discriminator(interpolates)
|
||||
decision_interpolate = self.reduce_sum(decision_interpolate, 0)
|
||||
return decision_interpolate
|
||||
|
||||
|
||||
class WGANGPGradientPenalty(nn.Cell):
|
||||
"""Define WGAN loss for AttGAN"""
|
||||
|
||||
def __init__(self, discriminator):
|
||||
super().__init__()
|
||||
self.gradient_op = ops.GradOperation()
|
||||
|
||||
self.reduce_sum = ops.ReduceSum()
|
||||
self.reduce_sum_keep_dim = ops.ReduceSum(keep_dims=True)
|
||||
|
||||
self.sqrt = ops.Sqrt()
|
||||
self.discriminator = discriminator
|
||||
self.GradientWithInput = GradientWithInput(discriminator)
|
||||
self.GradientWithInput.set_grad(True)
|
||||
|
||||
def construct(self, x_real, x_fake):
|
||||
"""get gradient penalty"""
|
||||
batch_size = x_real.shape[0]
|
||||
alpha = generate_tensor(batch_size)
|
||||
alpha = alpha.expand_as(x_real)
|
||||
x_fake = ops.functional.stop_gradient(x_fake)
|
||||
x_hat = x_real + alpha * (x_fake - x_real)
|
||||
|
||||
gradient = self.gradient_op(self.GradientWithInput)(x_hat)
|
||||
gradient_1 = ops.reshape(gradient, (batch_size, -1))
|
||||
gradient_1 = self.sqrt(self.reduce_sum(gradient_1 * gradient_1, 1))
|
||||
gradient_penalty = self.reduce_sum((gradient_1 - 1.0) ** 2) / x_real.shape[0]
|
||||
return gradient_penalty
|
||||
|
||||
|
||||
class GenLoss(nn.Cell):
|
||||
"""Define total Generator loss"""
|
||||
|
||||
def __init__(self, args, encoder, decoder, discriminator):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.discriminator = discriminator
|
||||
|
||||
self.lambda_1 = Tensor(args.lambda_1, mstype.float32)
|
||||
self.lambda_2 = Tensor(args.lambda_2, mstype.float32)
|
||||
self.lambda_3 = Tensor(args.lambda_3, mstype.float32)
|
||||
self.lambda_gp = Tensor(args.lambda_gp, mstype.float32)
|
||||
|
||||
self.cyc_loss = P.ReduceMean()
|
||||
self.cls_loss = nn.BCEWithLogitsLoss()
|
||||
self.rec_loss = nn.L1Loss("mean")
|
||||
|
||||
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
|
||||
"""Get generator loss"""
|
||||
# generate
|
||||
zs_a = self.encoder(img_a)
|
||||
img_fake = self.decoder(zs_a, att_b_)
|
||||
img_recon = self.decoder(zs_a, att_a_)
|
||||
|
||||
# discriminate
|
||||
d_fake, dc_fake = self.discriminator(img_fake)
|
||||
|
||||
# generator loss
|
||||
gf_loss = - self.cyc_loss(d_fake)
|
||||
gc_loss = self.cls_loss(dc_fake, att_b)
|
||||
gr_loss = self.rec_loss(img_a, img_recon)
|
||||
|
||||
g_loss = gf_loss + self.lambda_2 * gc_loss + self.lambda_1 * gr_loss
|
||||
|
||||
return (img_fake, g_loss, gf_loss, gc_loss, gr_loss)
|
||||
|
||||
|
||||
class DisLoss(nn.Cell):
|
||||
"""Define total discriminator loss"""
|
||||
|
||||
def __init__(self, args, encoder, decoder, discriminator):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.discriminator = discriminator
|
||||
|
||||
self.cyc_loss = P.ReduceMean()
|
||||
self.cls_loss = nn.BCEWithLogitsLoss()
|
||||
self.WGANLoss = WGANGPGradientPenalty(discriminator)
|
||||
|
||||
self.lambda_1 = Tensor(args.lambda_1, mstype.float32)
|
||||
self.lambda_2 = Tensor(args.lambda_2, mstype.float32)
|
||||
self.lambda_3 = Tensor(args.lambda_3, mstype.float32)
|
||||
self.lambda_gp = Tensor(args.lambda_gp, mstype.float32)
|
||||
|
||||
def construct(self, img_a, att_a, att_a_, att_b, att_b_):
|
||||
"""Get discriminator loss"""
|
||||
# generate
|
||||
z = self.encoder(img_a)
|
||||
img_fake = self.decoder(z, att_b_)
|
||||
|
||||
# discriminate
|
||||
d_real, dc_real = self.discriminator(img_a)
|
||||
d_fake, _ = self.discriminator(img_fake)
|
||||
|
||||
# discriminator losses
|
||||
d_real_loss = - self.cyc_loss(d_real)
|
||||
d_fake_loss = self.cyc_loss(d_fake)
|
||||
df_loss = d_real_loss + d_fake_loss
|
||||
|
||||
df_gp = self.WGANLoss(img_a, img_fake)
|
||||
|
||||
dc_loss = self.cls_loss(dc_real, att_a)
|
||||
|
||||
d_loss = df_loss + self.lambda_gp * df_gp + self.lambda_3 * dc_loss
|
||||
|
||||
return (d_loss, d_real_loss, d_fake_loss, dc_loss, df_gp)
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Helper functions"""
|
||||
import math
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from mindspore import load_checkpoint
|
||||
|
||||
|
||||
class DistributedSampler:
|
||||
"""Distributed sampler."""
|
||||
|
||||
def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=False):
|
||||
if num_replicas is None:
|
||||
print("***********Setting world_size to 1 since it is not passed in ******************")
|
||||
num_replicas = 1
|
||||
if rank is None:
|
||||
print("***********Setting rank to 0 since it is not passed in ******************")
|
||||
rank = 0
|
||||
|
||||
self.dataset_size = dataset_size
|
||||
self.num_replicas = num_replicas
|
||||
self.epoch = 0
|
||||
self.rank = rank
|
||||
self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
|
||||
def __iter__(self):
|
||||
# shuffle based on epoch
|
||||
if self.shuffle:
|
||||
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
|
||||
indices = indices.tolist()
|
||||
self.epoch += 1
|
||||
else:
|
||||
indices = list(range(self.dataset_size))
|
||||
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank: self.total_size: self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def resume_model(args, encoder, decoder, enc_ckpt_name, dec_ckpt_name):
|
||||
"""Restore the trained generator and discriminator"""
|
||||
print("Loading the trained models from step {}...".format(args.save_interval))
|
||||
encoder_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', enc_ckpt_name)
|
||||
decoder_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', dec_ckpt_name)
|
||||
|
||||
param_encoder = load_checkpoint(encoder_path, encoder)
|
||||
param_decoder = load_checkpoint(decoder_path, decoder)
|
||||
|
||||
return param_encoder, param_decoder
|
||||
|
||||
def resume_discriminator(args, discriminator, dis_ckpt_name):
|
||||
"""Restore the trained discriminator"""
|
||||
print("Loading the trained models from step {}...".format(args.save_interval))
|
||||
discriminator_path = os.path.join('output', args.experiment_name, 'checkpoint/rank0', dis_ckpt_name)
|
||||
param_discriminator = load_checkpoint(discriminator_path, discriminator)
|
||||
|
||||
return param_discriminator
|
||||
|
||||
def denorm(x):
|
||||
image_numpy = (np.transpose(x, (0, 2, 3, 1)) + 1) / 2.0 * 255.0
|
||||
image_numpy = np.clip(image_numpy, 0, 255)
|
||||
return image_numpy
|
|
@ -0,0 +1,265 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================sss
|
||||
"""Entry point for training AttGAN network"""
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from os.path import join
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, context
|
||||
from mindspore import nn
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.communication.management import init, get_rank
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, _InternalCallbackParam, RunContext
|
||||
from mindspore.train.serialization import load_param_into_net
|
||||
|
||||
from src.attgan import Genc, Gdec, Dis
|
||||
from src.cell import TrainOneStepCellGen, TrainOneStepCellDis, init_weights
|
||||
from src.data import data_loader
|
||||
from src.helpers import Progressbar
|
||||
from src.loss import GenLoss, DisLoss
|
||||
from src.utils import resume_model, resume_discriminator
|
||||
|
||||
attrs_default = [
|
||||
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows',
|
||||
'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'
|
||||
]
|
||||
|
||||
|
||||
def parse(arg=None):
|
||||
"""Define configuration of Model"""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--attrs', dest='attrs', default=attrs_default, nargs='+', help='attributes to learn')
|
||||
parser.add_argument('--data', dest='data', type=str, choices=['CelebA'], default='CelebA')
|
||||
parser.add_argument('--data_path', dest='data_path', type=str, default='./data/img_align_celeba')
|
||||
parser.add_argument('--attr_path', dest='attr_path', type=str, default='./data/list_attr_celeba.txt')
|
||||
|
||||
parser.add_argument('--img_size', dest='img_size', type=int, default=128)
|
||||
parser.add_argument('--shortcut_layers', dest='shortcut_layers', type=int, default=1)
|
||||
parser.add_argument('--inject_layers', dest='inject_layers', type=int, default=1)
|
||||
|
||||
parser.add_argument('--enc_dim', dest='enc_dim', type=int, default=64)
|
||||
parser.add_argument('--dec_dim', dest='dec_dim', type=int, default=64)
|
||||
parser.add_argument('--dis_dim', dest='dis_dim', type=int, default=64)
|
||||
parser.add_argument('--dis_fc_dim', dest='dis_fc_dim', type=int, default=1024)
|
||||
parser.add_argument('--enc_layers', dest='enc_layers', type=int, default=5)
|
||||
parser.add_argument('--dec_layers', dest='dec_layers', type=int, default=5)
|
||||
parser.add_argument('--dis_layers', dest='dis_layers', type=int, default=5)
|
||||
parser.add_argument('--enc_norm', dest='enc_norm', type=str, default='batchnorm')
|
||||
parser.add_argument('--dec_norm', dest='dec_norm', type=str, default='batchnorm')
|
||||
parser.add_argument('--dis_norm', dest='dis_norm', type=str, default='instancenorm')
|
||||
parser.add_argument('--dis_fc_norm', dest='dis_fc_norm', type=str, default='none')
|
||||
parser.add_argument('--enc_acti', dest='enc_acti', type=str, default='lrelu')
|
||||
parser.add_argument('--dec_acti', dest='dec_acti', type=str, default='relu')
|
||||
parser.add_argument('--dis_acti', dest='dis_acti', type=str, default='lrelu')
|
||||
parser.add_argument('--dis_fc_acti', dest='dis_fc_acti', type=str, default='relu')
|
||||
parser.add_argument('--lambda_1', dest='lambda_1', type=float, default=100.0)
|
||||
parser.add_argument('--lambda_2', dest='lambda_2', type=float, default=10.0)
|
||||
parser.add_argument('--lambda_3', dest='lambda_3', type=float, default=1.0)
|
||||
parser.add_argument('--lambda_gp', dest='lambda_gp', type=float, default=10.0)
|
||||
|
||||
parser.add_argument('--epochs', dest='epochs', type=int, default=200, help='# of epochs')
|
||||
parser.add_argument('--batch_size', dest='batch_size', type=int, default=32)
|
||||
parser.add_argument('--num_workers', dest='num_workers', type=int, default=16)
|
||||
parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='learning rate')
|
||||
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5)
|
||||
parser.add_argument('--beta2', dest='beta2', type=float, default=0.999)
|
||||
parser.add_argument('--n_d', dest='n_d', type=int, default=5, help='# of d updates per g update')
|
||||
|
||||
parser.add_argument('--thres_int', dest='thres_int', type=float, default=0.5)
|
||||
parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
|
||||
parser.add_argument('--save_interval', dest='save_interval', type=int, default=500)
|
||||
parser.add_argument('--experiment_name', dest='experiment_name',
|
||||
default=datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
|
||||
parser.add_argument("--run_distribute", type=int, default=0, help="Run distribute, default: false.")
|
||||
parser.add_argument('--resume_model', action='store_true')
|
||||
parser.add_argument('--enc_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--dec_ckpt_name', type=str, default='')
|
||||
parser.add_argument('--dis_ckpt_name', type=str, default='')
|
||||
|
||||
return parser.parse_args(arg)
|
||||
|
||||
|
||||
args = parse()
|
||||
print(args)
|
||||
|
||||
args.lr_base = args.lr
|
||||
args.n_attrs = len(args.attrs)
|
||||
|
||||
# initialize environment
|
||||
set_seed(1)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
|
||||
if args.run_distribute:
|
||||
if os.getenv("DEVICE_ID", "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv("DEVICE_ID")))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
print(device_num)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
||||
device_num=device_num)
|
||||
init()
|
||||
rank = get_rank()
|
||||
else:
|
||||
if os.getenv("DEVICE_ID", "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv("DEVICE_ID")))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
rank = 0
|
||||
|
||||
print("Initialize successful!")
|
||||
|
||||
os.makedirs(join('output', args.experiment_name), exist_ok=True)
|
||||
os.makedirs(join('output', args.experiment_name, 'checkpoint'), exist_ok=True)
|
||||
with open(join('output', args.experiment_name, 'setting.txt'), 'w') as f:
|
||||
f.write(json.dumps(vars(args), indent=4, separators=(',', ':')))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# Define dataloader
|
||||
train_dataset, train_length = data_loader(img_path=args.data_path,
|
||||
attr_path=args.attr_path,
|
||||
selected_attrs=args.attrs,
|
||||
mode="train",
|
||||
batch_size=args.batch_size,
|
||||
device_num=device_num,
|
||||
shuffle=True)
|
||||
train_loader = train_dataset.create_dict_iterator()
|
||||
|
||||
valid_dataset, valid_length = data_loader(img_path=args.data_path,
|
||||
attr_path=args.attr_path,
|
||||
selected_attrs=args.attrs,
|
||||
mode="valid",
|
||||
batch_size=args.batch_size,
|
||||
device_num=device_num,
|
||||
shuffle=False)
|
||||
valid_loader = valid_dataset.create_dict_iterator()
|
||||
|
||||
print('Training images:', train_length, '/', 'Validating images:', valid_length)
|
||||
|
||||
# Define network
|
||||
genc = Genc(args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, mode='train')
|
||||
gdec = Gdec(args.dec_dim, args.dec_layers, args.dec_norm, args.dec_acti, args.n_attrs, args.shortcut_layers,
|
||||
args.inject_layers, args.img_size, mode='train')
|
||||
dis = Dis(args.dis_dim, args.dis_norm, args.dis_acti, args.dis_fc_dim, args.dis_fc_norm, args.dis_fc_acti,
|
||||
args.dis_layers, args.img_size, mode='train')
|
||||
|
||||
# Initialize network
|
||||
init_weights(genc, 'KaimingUniform', math.sqrt(5))
|
||||
init_weights(gdec, 'KaimingUniform', math.sqrt(5))
|
||||
init_weights(dis, 'KaimingUniform', math.sqrt(5))
|
||||
|
||||
# Resume from checkpoint
|
||||
if args.resume_model:
|
||||
para_genc, para_gdec = resume_model(args, genc, gdec, args.enc_ckpt_name, args.dec_ckpt_name)
|
||||
para_dis = resume_discriminator(args, dis, args.dis_ckpt_name)
|
||||
load_param_into_net(genc, para_genc)
|
||||
load_param_into_net(gdec, para_gdec)
|
||||
load_param_into_net(dis, para_dis)
|
||||
|
||||
# Define network with loss
|
||||
G_loss_cell = GenLoss(args, genc, gdec, dis)
|
||||
D_loss_cell = DisLoss(args, genc, gdec, dis)
|
||||
|
||||
# Define Optimizer
|
||||
G_trainable_params = genc.trainable_params() + gdec.trainable_params()
|
||||
optimizer_G = nn.Adam(params=G_trainable_params, learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
|
||||
optimizer_D = nn.Adam(params=dis.trainable_params(), learning_rate=args.lr, beta1=args.beta1, beta2=args.beta2)
|
||||
|
||||
# Define One Step Train
|
||||
G_trainOneStep = TrainOneStepCellGen(G_loss_cell, optimizer_G)
|
||||
D_trainOneStep = TrainOneStepCellDis(D_loss_cell, optimizer_D)
|
||||
|
||||
# Train
|
||||
G_trainOneStep.set_train(True)
|
||||
D_trainOneStep.set_train(True)
|
||||
|
||||
print("Start Training")
|
||||
|
||||
train_iter = train_length // args.batch_size
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_interval)
|
||||
|
||||
if rank == 0:
|
||||
local_train_url = os.path.join('output', args.experiment_name, 'checkpoint/rank{}'.format(rank))
|
||||
ckpt_cb_genc = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='encoder')
|
||||
ckpt_cb_gdec = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='decoder')
|
||||
ckpt_cb_dis = ModelCheckpoint(config=ckpt_config, directory=local_train_url, prefix='discriminator')
|
||||
|
||||
cb_params_genc = _InternalCallbackParam()
|
||||
cb_params_genc.train_network = genc
|
||||
cb_params_genc.cur_epoch_num = 0
|
||||
genc_run_context = RunContext(cb_params_genc)
|
||||
ckpt_cb_genc.begin(genc_run_context)
|
||||
|
||||
cb_params_gdec = _InternalCallbackParam()
|
||||
cb_params_gdec.train_network = gdec
|
||||
cb_params_gdec.cur_epoch_num = 0
|
||||
gdec_run_context = RunContext(cb_params_gdec)
|
||||
ckpt_cb_gdec.begin(gdec_run_context)
|
||||
|
||||
cb_params_dis = _InternalCallbackParam()
|
||||
cb_params_dis.train_network = dis
|
||||
cb_params_dis.cur_epoch_num = 0
|
||||
dis_run_context = RunContext(cb_params_dis)
|
||||
ckpt_cb_dis.begin(dis_run_context)
|
||||
|
||||
# Initialize Progressbar
|
||||
progressbar = Progressbar()
|
||||
|
||||
it = 0
|
||||
for epoch in range(args.epochs):
|
||||
|
||||
for data in progressbar(train_loader, train_iter):
|
||||
img_a = data["image"]
|
||||
att_a = data["attr"]
|
||||
att_a = att_a.asnumpy()
|
||||
att_b = np.random.permutation(att_a)
|
||||
|
||||
att_a_ = (att_a * 2 - 1) * args.thres_int
|
||||
att_b_ = (att_b * 2 - 1) * args.thres_int
|
||||
|
||||
att_a = Tensor(att_a, mstype.float32)
|
||||
att_a_ = Tensor(att_a_, mstype.float32)
|
||||
att_b = Tensor(att_b, mstype.float32)
|
||||
att_b_ = Tensor(att_b_, mstype.float32)
|
||||
|
||||
if (it + 1) % (args.n_d + 1) != 0:
|
||||
d_out, d_real_loss, d_fake_loss, dc_loss, df_gp = D_trainOneStep(img_a, att_a, att_a_, att_b, att_b_)
|
||||
else:
|
||||
g_out, gf_loss, gc_loss, gr_loss = G_trainOneStep(img_a, att_a, att_a_, att_b, att_b_)
|
||||
progressbar.say(epoch=epoch, iter=it + 1, d_loss=d_out, g_loss=g_out, gf_loss=gf_loss, gc_loss=gc_loss,
|
||||
gr_loss=gr_loss, dc_loss=dc_loss, df_gp=df_gp)
|
||||
|
||||
if (epoch + 1) % 5 == 0 and (it + 1) % args.save_interval == 0 and rank == 0:
|
||||
cb_params_genc.cur_epoch_num = epoch + 1
|
||||
cb_params_gdec.cur_epoch_num = epoch + 1
|
||||
cb_params_dis.cur_epoch_num = epoch + 1
|
||||
cb_params_genc.cur_step_num = it + 1
|
||||
cb_params_gdec.cur_step_num = it + 1
|
||||
cb_params_dis.cur_step_num = it + 1
|
||||
cb_params_genc.batch_num = it + 2
|
||||
cb_params_gdec.batch_num = it + 2
|
||||
cb_params_dis.batch_num = it + 2
|
||||
ckpt_cb_genc.step_end(genc_run_context)
|
||||
ckpt_cb_gdec.step_end(gdec_run_context)
|
||||
ckpt_cb_dis.step_end(dis_run_context)
|
||||
it += 1
|
Loading…
Reference in New Issue