forked from mindspore-Ecosystem/mindspore
add new model: wgan
This commit is contained in:
parent
5ed742a893
commit
283672992c
|
@ -0,0 +1,269 @@
|
|||
# 目录
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [目录](#目录)
|
||||
- [WGAN描述](#WGAN描述)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本及样例代码](#脚本及样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [单机训练](#单机训练)
|
||||
- [推理过程](#推理过程)
|
||||
- [推理](#推理)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [训练性能](#训练性能)
|
||||
- [第一种情况(选用标准卷积DCGAN的生成器结构)](#第一种情况(选用标准卷积DCGAN的生成器结构))
|
||||
- [第二种情况(选用没有BatchNorm的卷积DCGAN的生成器结构)](#第二种情况(选用没有BatchNorm的卷积DCGAN的生成器结构))
|
||||
- [推理性能](#推理性能)
|
||||
- [随机情况说明](#随机情况说明)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# WGAN描述
|
||||
|
||||
WGAN(Wasserstein GAN的简称)是一种基于Wasserstein距离的生成对抗网络(GAN),包括生成器网络和判别器网络,它通过改进原始GAN的算法流程,彻底解决了GAN训练不稳定的问题,确保了生成样本的多样性,并且训练过程中终于有一个像交叉熵、准确率这样的数值来指示训练的进程,即-loss_D,这个数值越小代表GAN训练得越好,代表生成器产生的图像质量越高。
|
||||
|
||||
[论文](https://arxiv.org/abs/1701.07875):Martin Arjovsky, Soumith Chintala, Léon Bottou. "Wasserstein GAN"*In International Conference on Machine Learning(ICML 2017).
|
||||
|
||||
# 模型架构
|
||||
|
||||
WGAN网络包含两部分,生成器网络和判别器网络。判别器网络采用卷积DCGAN的架构,即多层二维卷积相连。生成器网络分别采用卷积DCGAN生成器结构、没有BatchNorm的卷积DCGAN生成器结构。输入数据包括真实图片数据和噪声数据,真实图片resize到64*64,噪声数据随机生成。
|
||||
|
||||
# 数据集
|
||||
|
||||
[LSUN-Bedrooms](<http://dl.yf.io/lsun/scenes/bedroom_train_lmdb.zip>)
|
||||
|
||||
- 数据集大小:42.8G
|
||||
- 训练集:42.8G,共3033044张图像。
|
||||
- 注:对于生成对抗网络,推理部分是传入噪声数据生成图片,故无需使用测试集数据。
|
||||
- 数据格式:原始数据格式为lmdb格式,需要使用LSUN官网格式转换脚本把lmdb数据export所有图片,并将Bedrooms这一类图片放到同一文件夹下。
|
||||
- 注:LSUN数据集官网的数据格式转换脚本地址:(https://github.com/fyu/lsun)
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件(Ascend)
|
||||
- 使用Ascend来搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# 快速入门
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```python
|
||||
# 运行单机训练示例(包括以下两种情况):
|
||||
sh run_train.sh [DATASET] [DATAROOT] [DEVICE_ID] [NOBN]
|
||||
|
||||
# 第一种情况(选用标准卷积DCGAN的生成器结构):
|
||||
sh run_train.sh [DATASET] [DATAROOT] [DEVICE_ID] False
|
||||
|
||||
# 第二种情况(选用没有BatchNorm的卷积DCGAN的生成器结构):
|
||||
sh run_train.sh [DATASET] [DATAROOT] [DEVICE_ID] True
|
||||
|
||||
|
||||
# 运行评估示例
|
||||
sh run_eval.sh [DEVICE_ID] [CONFIG_PATH] [CKPT_FILE_PATH] [OUTPUT_DIR] [NIMAGES]
|
||||
```
|
||||
|
||||
# 脚本说明
|
||||
|
||||
## 脚本及样例代码
|
||||
|
||||
```bash
|
||||
├── model_zoo
|
||||
├── README.md // 所有模型相关说明
|
||||
├── WGAN
|
||||
├── README.md // WGAN相关说明
|
||||
├── scripts
|
||||
│ ├── run_train.sh // 单机到Ascend处理器的shell脚本
|
||||
│ ├── run_eval.sh // Ascend评估的shell脚本
|
||||
├── src
|
||||
│ ├── dataset.py // 创建数据集及数据预处理
|
||||
│ ├── dcgan_model.py // WGAN架构,标准的DCGAN架构
|
||||
│ ├── dcgannobn_model.py // WGAN架构,没有BatchNorm的DCGAN架构
|
||||
│ ├── args.py // 参数配置文件
|
||||
│ ├── cell.py // 模型单步训练文件
|
||||
├── train.py // 训练脚本
|
||||
├── eval.py // 评估脚本
|
||||
├── export.py // 将checkpoint文件导出到mindir下
|
||||
```
|
||||
|
||||
## 脚本参数
|
||||
|
||||
在args.py中可以同时配置训练参数、评估参数及模型导出参数。
|
||||
|
||||
```python
|
||||
# common_config
|
||||
'device_target': 'Ascend', # 运行设备
|
||||
'device_id': 0, # 用于训练或评估数据集的设备ID
|
||||
|
||||
# train_config
|
||||
'dataset': 'lsun', # 数据集名称
|
||||
'dataroot': None, # 数据集路径,必须输入,不能为空
|
||||
'workers': 8, # 数据加载线程数
|
||||
'batchSize': 64, # 批处理大小
|
||||
'imageSize': 64, # 图片尺寸大小
|
||||
'nc': 3, # 传入图片的通道数
|
||||
'nz': 100, # 初始噪声向量大小
|
||||
'ndf': 64, # 判别器网络基础特征数目
|
||||
'ngf': 64, # 生成器网络基础特征数目
|
||||
'niter': 25, # 网络训练的epoch数
|
||||
'lrD': 0.00005, # 判别器初始学习率
|
||||
'lrG': 0.00005, # 生成器初始学习率
|
||||
'netG': '', # 恢复训练的生成器的ckpt文件路径
|
||||
'netD': '', # 恢复训练的判别器的ckpt文件路径
|
||||
'clamp_lower': -0.01, # 将优化器参数限定在某一范围的下界
|
||||
'clamp_upper': 0.01, # 将优化器参数限定在某一范围的上界
|
||||
'Diters': 5, # 每训练一次生成器需要训练判别器的次数
|
||||
'noBN': False, # 卷积生成器网络中是否使用BatchNorm,默认是使用
|
||||
'n_extra_layers': 0, # 生成器和判别器网络中附加层的数目,默认是0
|
||||
'experiment': None, # 保存模型和生成图片的路径,若不指定,则使用默认路径
|
||||
'adam': False, # 是否使用Adam优化器,默认是不使用,使用的是RMSprop优化器
|
||||
|
||||
# eval_config
|
||||
'config': None, # 训练生成的生成器的配置文件.json文件路径,必须指定
|
||||
'ckpt_file': None, # 训练时保存的生成器的权重文件.ckpt的路径,必须指定
|
||||
'output_dir': None, # 生成图片的输出路径,必须指定
|
||||
'nimages': 1, # 生成图片的数量,默认是1
|
||||
|
||||
# export_config
|
||||
'config': None, # 训练生成的生成器的配置文件.json文件路径,必须指定
|
||||
'ckpt_file': None, # 训练时保存的生成器的权重文件.ckpt的路径,必须指定
|
||||
'file_name': 'WGAN', # 输出文件名字的前缀,默认是'WGAN'
|
||||
'file_format': 'AIR', # 模型输出格式,可选["AIR", "ONNX", "MINDIR"],默认是'AIR'
|
||||
'nimages': 1, # 生成图片的数量,默认是1
|
||||
|
||||
```
|
||||
|
||||
更多配置细节请参考脚本`args.py`。
|
||||
|
||||
## 训练过程
|
||||
|
||||
### 单机训练
|
||||
|
||||
- Ascend处理器环境运行
|
||||
|
||||
```bash
|
||||
sh run_train.sh [DATASET] [DATAROOT] [DEVICE_ID] [NOBN]
|
||||
```
|
||||
|
||||
第一种情况(选用标准卷积DCGAN的生成器结构):
|
||||
|
||||
```bash
|
||||
sh run_train.sh [DATASET] [DATAROOT] [DEVICE_ID] False
|
||||
```
|
||||
|
||||
第二种情况(选用没有BatchNorm的卷积DCGAN的生成器结构):
|
||||
|
||||
```bash
|
||||
sh run_train.sh [DATASET] [DATAROOT] [DEVICE_ID] True
|
||||
```
|
||||
|
||||
上述python命令将在后台运行,您可以通过train.log文件查看结果。
|
||||
|
||||
训练结束后,您可在存储的文件夹(默认是./samples)下找到生成的图片、检查点文件和.json文件。采用以下方式得到损失值:
|
||||
|
||||
```bash
|
||||
[0/25][2300/47391][23] Loss_D: -1.555344 Loss_G: 0.761238
|
||||
[0/25][2400/47391][24] Loss_D: -1.557617 Loss_G: 0.762344
|
||||
...
|
||||
```
|
||||
|
||||
## 推理过程
|
||||
|
||||
### 推理
|
||||
|
||||
- 在Ascend环境下评估
|
||||
|
||||
在运行以下命令之前,请检查用于推理的检查点和json文件路径,并设置输出图片的路径。
|
||||
|
||||
```bash
|
||||
sh run_eval.sh [DEVICE_ID] [CONFIG_PATH] [CKPT_FILE_PATH] [OUTPUT_DIR] [NIMAGES]
|
||||
```
|
||||
|
||||
上述python命令将在后台运行,您可以通过eval/eval.log文件查看日志信息,在输出图片的路径下查看生成的图片。
|
||||
|
||||
# 模型描述
|
||||
|
||||
## 性能
|
||||
|
||||
### 训练性能
|
||||
|
||||
#### 第一种情况(选用标准卷积DCGAN的生成器结构)
|
||||
|
||||
| 参数 | Ascend |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| 资源 | Ascend 910 ;CPU 2.60GHz,192核;内存:755G |
|
||||
| 上传日期 | 2021-05-14 |
|
||||
| MindSpore版本 | 1.2.0-alpha |
|
||||
| 数据集 | LSUN-Bedrooms |
|
||||
| 训练参数 | max_epoch=25, batch_size=64, lr_init=0.00005 |
|
||||
| 优化器 | RMSProp |
|
||||
| 损失函数 | 自定义损失函数 |
|
||||
| 输出 | 生成的图片 |
|
||||
| 速度 | 单卡:190毫秒/步 |
|
||||
| 总时长 | 单卡12小时10分钟 |
|
||||
| 参数(M) | 6.57 |
|
||||
| 微调检查点 | 13.98M (.ckpt文件) |
|
||||
| 推理模型 | 14.00M (.mindir文件) |
|
||||
| 脚本 | [WGAN脚本](https://gitee.com/mindspore/mindspore/tree/r1.2/model_zoo/research/cv/WGAN) |
|
||||
|
||||
生成图片效果如下:
|
||||
|
||||
![GenSample1](imgs/WGAN_1.png "第一种情况生成的图片样本")
|
||||
|
||||
#### 第二种情况(选用没有BatchNorm的卷积DCGAN的生成器结构)
|
||||
|
||||
| 参数 | Ascend |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| 资源 | Ascend 910 ;CPU 2.60GHz,192核;内存:755G |
|
||||
| 上传日期 | 2021-05-14 |
|
||||
| MindSpore版本 | 1.2.0-alpha |
|
||||
| 数据集 | LSUN-Bedrooms |
|
||||
| 训练参数 | max_epoch=25, batch_size=64, lr_init=0.00005 |
|
||||
| 优化器 | RMSProp |
|
||||
| 损失函数 | 自定义损失函数 |
|
||||
| 输出 | 生成的图片 |
|
||||
| 速度 | 单卡:180毫秒/步 |
|
||||
| 总时长 | 单卡:11小时40分钟 |
|
||||
| 参数(M) | 6.45 |
|
||||
| 微调检查点 | 13.98M (.ckpt文件) |
|
||||
| 推理模型 | 14.00M (.mindir文件) |
|
||||
| 脚本 | [WGAN脚本](https://gitee.com/mindspore/mindspore/tree/r1.2/model_zoo/research/cv/WGAN) |
|
||||
|
||||
生成图片效果如下:
|
||||
|
||||
![GenSample2](imgs/WGAN_2.png "第二种情况生成的图片样本")
|
||||
|
||||
### 推理性能
|
||||
|
||||
#### 推理
|
||||
|
||||
| 参数 | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| 资源 | Ascend 910 |
|
||||
| 上传日期 | 2021-05-14 |
|
||||
| MindSpore 版本 | 1.2.0-alpha |
|
||||
| 数据集 | LSUN-Bedrooms |
|
||||
| batch_size | 1 |
|
||||
| 输出 | 生成的图片 |
|
||||
|
||||
# 随机情况说明
|
||||
|
||||
在train.py中,我们设置了随机种子。
|
||||
|
||||
# ModelZoo主页
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test WGAN """
|
||||
import os
|
||||
import json
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore import context
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from src.dcgan_model import DcganG
|
||||
from src.dcgannobn_model import DcgannobnG
|
||||
from src.args import get_args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
args_opt = get_args('eval')
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
|
||||
with open(args_opt.config, 'r') as gencfg:
|
||||
generator_config = json.loads(gencfg.read())
|
||||
|
||||
imageSize = generator_config["imageSize"]
|
||||
nz = generator_config["nz"]
|
||||
nc = generator_config["nc"]
|
||||
ngf = generator_config["ngf"]
|
||||
noBN = generator_config["noBN"]
|
||||
n_extra_layers = generator_config["n_extra_layers"]
|
||||
|
||||
# generator
|
||||
if noBN:
|
||||
netG = DcgannobnG(imageSize, nz, nc, ngf, n_extra_layers)
|
||||
else:
|
||||
netG = DcganG(imageSize, nz, nc, ngf, n_extra_layers)
|
||||
|
||||
# load weights
|
||||
load_param_into_net(netG, load_checkpoint(args_opt.ckpt_file))
|
||||
|
||||
# initialize noise
|
||||
fixed_noise = Tensor(np.random.normal(size=[args_opt.nimages, nz, 1, 1]), dtype=mstype.float32)
|
||||
|
||||
fake = netG(fixed_noise)
|
||||
mul = ops.Mul()
|
||||
add = ops.Add()
|
||||
reshape = ops.Reshape()
|
||||
fake = mul(fake, 0.5*255)
|
||||
fake = add(fake, 0.5*255)
|
||||
|
||||
for i in range(args_opt.nimages):
|
||||
img_pil = reshape(fake[i, ...], (1, nc, imageSize, imageSize))
|
||||
img_pil = img_pil.asnumpy()[0].astype(np.uint8).transpose((1, 2, 0))
|
||||
img_pil = Image.fromarray(img_pil)
|
||||
img_pil.save(os.path.join(args_opt.output_dir, "generated_%02d.png" % i))
|
||||
|
||||
print("Generate images success!")
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
##############export checkpoint file into air, onnx, mindir models#################
|
||||
python export.py
|
||||
"""
|
||||
import json
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
|
||||
from src.args import get_args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args_opt = get_args('export')
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
|
||||
with open(args_opt.config, 'r') as gencfg:
|
||||
generator_config = json.loads(gencfg.read())
|
||||
|
||||
imageSize = generator_config["imageSize"]
|
||||
nz = generator_config["nz"]
|
||||
nc = generator_config["nc"]
|
||||
ngf = generator_config["ngf"]
|
||||
noBN = generator_config["noBN"]
|
||||
n_extra_layers = generator_config["n_extra_layers"]
|
||||
|
||||
# generator
|
||||
if noBN:
|
||||
netG = DcgannobnG(imageSize, nz, nc, ngf, n_extra_layers)
|
||||
else:
|
||||
netG = DcganG(imageSize, nz, nc, ngf, n_extra_layers)
|
||||
|
||||
# load weights
|
||||
load_param_into_net(netG, load_checkpoint(args_opt.ckpt_file))
|
||||
|
||||
# initialize noise
|
||||
fixed_noise = Tensor(np.random.normal(size=[args_opt.nimages, nz, 1, 1]), dtype=mstype.float32)
|
||||
|
||||
export(netG, fixed_noise, file_name=args_opt.file_name, file_format=args_opt.file_format)
|
Binary file not shown.
Binary file not shown.
After Width: | Height: | Size: 531 KiB |
Binary file not shown.
After Width: | Height: | Size: 536 KiB |
|
@ -0,0 +1,3 @@
|
|||
mindspore
|
||||
numpy
|
||||
Pillow
|
|
@ -0,0 +1,49 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash bash run_eval.sh device_id config ckpt_file output_dir nimages"
|
||||
echo "For example: bash run_eval.sh DEVICE_ID CONFIG_PATH CKPT_FILE_PATH OUTPUT_DIR NIMAGES"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
cd ../
|
||||
rm -rf eval
|
||||
mkdir eval
|
||||
cd ./eval
|
||||
mkdir src
|
||||
cd ../
|
||||
cp ./*.py ./eval
|
||||
cp ./src/*.py ./eval/src
|
||||
cd ./eval
|
||||
|
||||
env > env0.log
|
||||
|
||||
echo "train begin."
|
||||
python eval.py --device_id $1 --config $2 --ckpt_file $3 --output_dir $4 --nimages $5 > ./eval.log 2>&1 &
|
||||
|
||||
if [ $? -eq 0 ];then
|
||||
echo "eval success"
|
||||
else
|
||||
echo "eval failed"
|
||||
exit 2
|
||||
fi
|
||||
echo "finish"
|
||||
cd ../
|
|
@ -0,0 +1,49 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash bash run_train.sh dataset dataroot device_id noBN"
|
||||
echo "For example: bash run_train.sh lsun /opt_data/xidian_wks/lsun 3 False"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
cd ../
|
||||
rm -rf train
|
||||
mkdir train
|
||||
cd ./train
|
||||
mkdir src
|
||||
cd ../
|
||||
cp ./*.py ./train
|
||||
cp ./src/*.py ./train/src
|
||||
cd ./train
|
||||
|
||||
env > env0.log
|
||||
|
||||
echo "train begin."
|
||||
python train.py --dataset $1 --dataroot $2 --device_id $3 --noBN $4 > ./train.log 2>&1 &
|
||||
|
||||
if [ $? -eq 0 ];then
|
||||
echo "training success"
|
||||
else
|
||||
echo "training failed"
|
||||
exit 2
|
||||
fi
|
||||
echo "finish"
|
||||
cd ../
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""get args"""
|
||||
import ast
|
||||
import argparse
|
||||
|
||||
def get_args(phase):
|
||||
"""Define the common options that are used in training."""
|
||||
parser = argparse.ArgumentParser(description='WGAN')
|
||||
parser.add_argument('--device_target', default='Ascend', help='enables npu')
|
||||
parser.add_argument('--device_id', type=int, default=0)
|
||||
|
||||
if phase == 'train':
|
||||
parser.add_argument('--dataset', default='lsun', help='cifar10 | lsun')
|
||||
parser.add_argument('--dataroot', default=None, help='path to dataset')
|
||||
parser.add_argument('--is_modelarts', type=ast.literal_eval, default=False, help='train in Modelarts or not')
|
||||
parser.add_argument('--data_url', default=None, help='Location of data.')
|
||||
parser.add_argument('--train_url', default=None, help='Location of training outputs.')
|
||||
|
||||
parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
|
||||
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
|
||||
parser.add_argument('--imageSize', type=int, default=64, help='the height/width of the input image to network')
|
||||
parser.add_argument('--nc', type=int, default=3, help='input image channels')
|
||||
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
|
||||
parser.add_argument('--ngf', type=int, default=64)
|
||||
parser.add_argument('--ndf', type=int, default=64)
|
||||
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
|
||||
parser.add_argument('--lrD', type=float, default=0.00005, help='learning rate for Critic, default=0.00005')
|
||||
parser.add_argument('--lrG', type=float, default=0.00005, help='learning rate for Generator, default=0.00005')
|
||||
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
|
||||
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
|
||||
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
|
||||
parser.add_argument('--clamp_lower', type=float, default=-0.01)
|
||||
parser.add_argument('--clamp_upper', type=float, default=0.01)
|
||||
parser.add_argument('--Diters', type=int, default=5, help='number of D iters per each G iter')
|
||||
parser.add_argument('--noBN', type=ast.literal_eval, default=False, help='use batchnorm or not (for DCGAN)')
|
||||
parser.add_argument('--n_extra_layers', type=int, default=0, help='Number of extra layers on gen and disc')
|
||||
parser.add_argument('--experiment', default=None, help='Where to store samples and models')
|
||||
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
|
||||
|
||||
elif phase == 'export':
|
||||
parser.add_argument('--config', required=True, type=str, help='path to generator config .json file')
|
||||
parser.add_argument('--ckpt_file', type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument('--file_name', type=str, default="WGAN", help="output file name prefix.")
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], \
|
||||
default='AIR', help='file format')
|
||||
parser.add_argument('--nimages', required=True, type=int, help="number of images to generate", default=1)
|
||||
|
||||
elif phase == 'eval':
|
||||
parser.add_argument('--config', required=True, type=str, help='path to generator config .json file')
|
||||
parser.add_argument('--ckpt_file', required=True, type=str, help='path to generator weights .ckpt file')
|
||||
parser.add_argument('--output_dir', required=True, type=str, help="path to to output directory")
|
||||
parser.add_argument('--nimages', required=True, type=int, help="number of images to generate", default=1)
|
||||
|
||||
args_opt = parser.parse_args()
|
||||
return args_opt
|
|
@ -0,0 +1,170 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
""" Train one step """
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.composite as C
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.ops.functional as F
|
||||
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
|
||||
_get_parallel_mode)
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
|
||||
|
||||
class GenWithLossCell(nn.Cell):
|
||||
"""Generator with loss(wrapped)"""
|
||||
def __init__(self, netG, netD):
|
||||
super(GenWithLossCell, self).__init__()
|
||||
self.netG = netG
|
||||
self.netD = netD
|
||||
|
||||
def construct(self, noise):
|
||||
"""construct"""
|
||||
fake = self.netG(noise)
|
||||
errG = self.netD(fake)
|
||||
loss_G = errG
|
||||
return loss_G
|
||||
|
||||
|
||||
class DisWithLossCell(nn.Cell):
|
||||
""" Discriminator with loss(wrapped) """
|
||||
def __init__(self, netG, netD):
|
||||
super(DisWithLossCell, self).__init__()
|
||||
self.netG = netG
|
||||
self.netD = netD
|
||||
|
||||
def construct(self, real, noise):
|
||||
"""construct"""
|
||||
errD_real = self.netD(real)
|
||||
fake = self.netG(noise)
|
||||
errD_fake = self.netD(fake)
|
||||
loss_D = errD_real - errD_fake
|
||||
return loss_D
|
||||
|
||||
|
||||
class ClipParameter(nn.Cell):
|
||||
""" Clip the parameter """
|
||||
def __init__(self):
|
||||
super(ClipParameter, self).__init__()
|
||||
self.cast = P.Cast()
|
||||
self.dtype = P.DType()
|
||||
|
||||
def construct(self, params, clip_lower, clip_upper):
|
||||
"""construct"""
|
||||
new_params = ()
|
||||
for param in params:
|
||||
dt = self.dtype(param)
|
||||
t = C.clip_by_value(param, self.cast(F.tuple_to_array((clip_lower,)), dt),
|
||||
self.cast(F.tuple_to_array((clip_upper,)), dt))
|
||||
new_params = new_params + (t,)
|
||||
|
||||
return new_params
|
||||
|
||||
|
||||
class GenTrainOneStepCell(nn.Cell):
|
||||
""" Generator TrainOneStepCell """
|
||||
def __init__(self, netG, netD,
|
||||
optimizerG: nn.Optimizer,
|
||||
sens=1.0):
|
||||
super(GenTrainOneStepCell, self).__init__()
|
||||
self.netD = netD
|
||||
self.netD.set_train(False)
|
||||
self.netD.set_grad(False)
|
||||
self.weights_G = optimizerG.parameters
|
||||
self.optimizerG = optimizerG
|
||||
self.net = GenWithLossCell(netG, netD)
|
||||
self.net.set_train()
|
||||
self.net.set_grad()
|
||||
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
|
||||
# parallel process
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer_G = F.identity
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer_G = DistributedGradReducer(self.weights_G, mean, degree) # A distributed optimizer.
|
||||
|
||||
def construct(self, noise):
|
||||
""" construct """
|
||||
loss_G = self.net(noise)
|
||||
sens = P.Fill()(P.DType()(loss_G), P.Shape()(loss_G), self.sens)
|
||||
grads = self.grad(self.net, self.weights_G)(noise, sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer_G(grads)
|
||||
return F.depend(loss_G, self.optimizerG(grads))
|
||||
|
||||
|
||||
_my_adam_opt = C.MultitypeFuncGraph("_my_adam_opt")
|
||||
|
||||
|
||||
@_my_adam_opt.register("Tensor", "Tensor")
|
||||
def _update_run_op(param, param_clipped):
|
||||
param_clipped = F.depend(param_clipped, F.assign(param, param_clipped))
|
||||
return param_clipped
|
||||
|
||||
|
||||
class DisTrainOneStepCell(nn.Cell):
|
||||
""" Discriminator TrainOneStepCell """
|
||||
def __init__(self, netG, netD,
|
||||
optimizerD: nn.Optimizer,
|
||||
clip_lower=-0.01, clip_upper=0.01, sens=1.0):
|
||||
super(DisTrainOneStepCell, self).__init__()
|
||||
self.weights_D = optimizerD.parameters
|
||||
self.clip_parameters = ClipParameter()
|
||||
self.optimizerD = optimizerD
|
||||
self.net = DisWithLossCell(netG, netD)
|
||||
self.net.set_train()
|
||||
self.net.set_grad()
|
||||
|
||||
self.reduce_flag = False
|
||||
self.op_cast = P.Cast()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
self.clip_lower = clip_lower
|
||||
self.clip_upper = clip_upper
|
||||
|
||||
# parallel process
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer_D = F.identity
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer_D = DistributedGradReducer(self.weights_D, mean, degree) # A distributed optimizer.
|
||||
|
||||
def construct(self, real, noise):
|
||||
""" construct """
|
||||
loss_D = self.net(real, noise)
|
||||
sens = P.Fill()(P.DType()(loss_D), P.Shape()(loss_D), self.sens)
|
||||
grads = self.grad(self.net, self.weights_D)(real, noise, sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer_D(grads)
|
||||
|
||||
upd = self.optimizerD(grads)
|
||||
weights_D_cliped = self.clip_parameters(self.weights_D, self.clip_lower, self.clip_upper)
|
||||
res = self.hyper_map(F.partial(_my_adam_opt), self.weights_D, weights_D_cliped)
|
||||
res = F.depend(upd, res)
|
||||
return F.depend(loss_D, res)
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
""" dataset """
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as c
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
|
||||
def create_dataset(dataroot, dataset, batchSize, imageSize, repeat_num=1, workers=8, target='Ascend'):
|
||||
"""Create dataset"""
|
||||
rank_id = 0
|
||||
device_num = 1
|
||||
|
||||
# define map operations
|
||||
resize_op = c.Resize(imageSize)
|
||||
center_crop_op = c.CenterCrop(imageSize)
|
||||
normalize_op = c.Normalize(mean=(0.5*255, 0.5*255, 0.5*255), std=(0.5*255, 0.5*255, 0.5*255))
|
||||
hwc2chw_op = c.HWC2CHW()
|
||||
|
||||
if dataset == 'lsun':
|
||||
if device_num == 1:
|
||||
data_set = ds.ImageFolderDataset(dataroot, num_parallel_workers=workers, shuffle=True, decode=True)
|
||||
else:
|
||||
data_set = ds.ImageFolderDataset(dataroot, num_parallel_workers=workers, shuffle=True, decode=True,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
transform = [resize_op, center_crop_op, normalize_op, hwc2chw_op]
|
||||
else:
|
||||
if device_num == 1:
|
||||
data_set = ds.Cifar10Dataset(dataroot, num_parallel_workers=workers, shuffle=True)
|
||||
else:
|
||||
data_set = ds.Cifar10Dataset(dataroot, num_parallel_workers=workers, shuffle=True, \
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
transform = [resize_op, normalize_op, hwc2chw_op]
|
||||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
|
||||
data_set = data_set.map(input_columns='image', operations=transform, num_parallel_workers=workers)
|
||||
data_set = data_set.map(input_columns='label', operations=type_cast_op, num_parallel_workers=workers)
|
||||
|
||||
data_set = data_set.batch(batchSize, drop_remainder=True)
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
|
||||
return data_set
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
""" Dcgan model """
|
||||
import mindspore.nn as nn
|
||||
|
||||
class DcganD(nn.Cell):
|
||||
""" dcgan Decriminator """
|
||||
def __init__(self, isize, nz, nc, ndf, n_extra_layers=0):
|
||||
super(DcganD, self).__init__()
|
||||
assert isize % 16 == 0, "isize has to be a multiple of 16"
|
||||
|
||||
main = nn.SequentialCell()
|
||||
main.append(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, has_bias=False))
|
||||
main.append(nn.LeakyReLU(0.2))
|
||||
|
||||
csize, cndf = isize / 2, ndf
|
||||
|
||||
# Extra layers
|
||||
for _ in range(n_extra_layers):
|
||||
main.append(nn.Conv2d(cndf, cndf, 3, 1, 'pad', 1, has_bias=False))
|
||||
main.append(nn.BatchNorm2d(cndf))
|
||||
main.append(nn.LeakyReLU(0.2))
|
||||
|
||||
while csize > 4:
|
||||
in_feat = cndf
|
||||
out_feat = cndf * 2
|
||||
|
||||
main.append(nn.Conv2d(in_feat, out_feat, 4, 2, 'pad', 1, has_bias=False))
|
||||
main.append(nn.BatchNorm2d(out_feat))
|
||||
main.append(nn.LeakyReLU(0.2))
|
||||
|
||||
cndf = cndf * 2
|
||||
csize = csize / 2
|
||||
|
||||
# state size. K x 4 x 4
|
||||
main.append(nn.Conv2d(cndf, 1, 4, 1, 'pad', 0, has_bias=False))
|
||||
self.main = main
|
||||
|
||||
def construct(self, input1):
|
||||
"""construct"""
|
||||
output = self.main(input1)
|
||||
output = output.mean(0)
|
||||
return output.view(1)
|
||||
|
||||
|
||||
class DcganG(nn.Cell):
|
||||
""" dcgan generator """
|
||||
def __init__(self, isize, nz, nc, ngf, n_extra_layers=0):
|
||||
super(DcganG, self).__init__()
|
||||
assert isize % 16 == 0, "isize has to be a multiple of 16"
|
||||
|
||||
cngf, tisize = ngf // 2, 4
|
||||
while tisize != isize:
|
||||
cngf = cngf * 2
|
||||
tisize = tisize * 2
|
||||
|
||||
main = nn.SequentialCell()
|
||||
# input is Z, going into a convolution
|
||||
main.append(nn.Conv2dTranspose(nz, cngf, 4, 1, 'pad', 0, has_bias=False))
|
||||
main.append(nn.BatchNorm2d(cngf))
|
||||
main.append(nn.ReLU())
|
||||
|
||||
csize = 4
|
||||
while csize < isize // 2:
|
||||
main.append(nn.Conv2dTranspose(cngf, cngf // 2, 4, 2, 'pad', 1, has_bias=False))
|
||||
main.append(nn.BatchNorm2d(cngf // 2))
|
||||
main.append(nn.ReLU())
|
||||
|
||||
cngf = cngf // 2
|
||||
csize = csize * 2
|
||||
|
||||
# Extra layers
|
||||
for _ in range(n_extra_layers):
|
||||
main.append(nn.Conv2d(cngf, cngf, 3, 1, 'pad', 1, has_bias=False))
|
||||
main.append(nn.BatchNorm2d(cngf))
|
||||
main.append(nn.ReLU())
|
||||
|
||||
main.append(nn.Conv2dTranspose(cngf, nc, 4, 2, 'pad', 1, has_bias=False))
|
||||
main.append(nn.Tanh())
|
||||
self.main = main
|
||||
|
||||
def construct(self, input1):
|
||||
"""construct"""
|
||||
output = self.main(input1)
|
||||
return output
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
""" DCGAN with on Batchnorm model """
|
||||
import mindspore.nn as nn
|
||||
|
||||
class DcgannobnD(nn.Cell):
|
||||
""" DCGAN Descriminator with no Batchnorm layer """
|
||||
def __init__(self, isize, nz, nc, ndf, n_extra_layers=0):
|
||||
super(DcgannobnD, self).__init__()
|
||||
assert isize % 16 == 0, "isize has to be a multiple of 16"
|
||||
|
||||
main = nn.SequentialCell()
|
||||
# input is nc x isize x isize
|
||||
main.append(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, has_bias=False))
|
||||
main.append(nn.LeakyReLU(0.2))
|
||||
|
||||
csize, cndf = isize / 2, ndf
|
||||
|
||||
# Extra layers
|
||||
for _ in range(n_extra_layers):
|
||||
main.append(nn.Conv2d(cndf, cndf, 3, 1, 'pad', 1, has_bias=False))
|
||||
main.append(nn.LeakyReLU(0.2))
|
||||
|
||||
while csize > 4:
|
||||
in_feat = cndf
|
||||
out_feat = cndf * 2
|
||||
main.append(nn.Conv2d(in_feat, out_feat, 4, 2, 'pad', 1, has_bias=False))
|
||||
main.append(nn.LeakyReLU(0.2))
|
||||
|
||||
cndf = cndf * 2
|
||||
csize = csize / 2
|
||||
|
||||
# state size. K x 4 x 4
|
||||
main.append(nn.Conv2d(cndf, 1, 4, 1, 'pad', 0, has_bias=False))
|
||||
self.main = main
|
||||
|
||||
def construct(self, input1):
|
||||
"""construct"""
|
||||
output = self.main(input1)
|
||||
output = output.mean(0)
|
||||
return output.view(1)
|
||||
|
||||
|
||||
class DcgannobnG(nn.Cell):
|
||||
""" DCGAN Generator with no BatchNorm layer """
|
||||
def __init__(self, isize, nz, nc, ngf, n_extra_layers=0):
|
||||
super(DcgannobnG, self).__init__()
|
||||
assert isize % 16 == 0, "isize has to be a multiple of 16"
|
||||
|
||||
cngf, tisize = ngf // 2, 4
|
||||
while tisize != isize:
|
||||
cngf = cngf * 2
|
||||
tisize = tisize * 2
|
||||
|
||||
main = nn.SequentialCell()
|
||||
main.append(nn.Conv2dTranspose(nz, cngf, 4, 1, 'pad', 0, has_bias=False))
|
||||
main.append(nn.ReLU())
|
||||
|
||||
csize = 4
|
||||
while csize < isize // 2:
|
||||
main.append(nn.Conv2dTranspose(cngf, cngf // 2, 4, 2, 'pad', 1, has_bias=False))
|
||||
main.append(nn.ReLU())
|
||||
|
||||
cngf = cngf // 2
|
||||
csize = csize * 2
|
||||
|
||||
# Extra layers
|
||||
for _ in range(n_extra_layers):
|
||||
main.append(nn.Conv2d(cngf, cngf, 3, 1, 'pad', 1, has_bias=False))
|
||||
main.append(nn.ReLU())
|
||||
|
||||
main.append(nn.Conv2dTranspose(cngf, nc, 4, 2, 'pad', 1, has_bias=False))
|
||||
main.append(nn.Tanh())
|
||||
self.main = main
|
||||
|
||||
def construct(self, input1):
|
||||
"""construct"""
|
||||
output = self.main(input1)
|
||||
return output
|
|
@ -0,0 +1,232 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train WGAN"""
|
||||
import os
|
||||
import random
|
||||
import json
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.ops as ops
|
||||
from mindspore.common import initializer as init
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, save_checkpoint
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from src.dataset import create_dataset
|
||||
from src.dcgan_model import DcganG, DcganD
|
||||
from src.dcgannobn_model import DcgannobnG
|
||||
from src.cell import GenTrainOneStepCell, DisTrainOneStepCell
|
||||
from src.args import get_args
|
||||
|
||||
if __name__ == '__main__':
|
||||
args_opt = get_args('train')
|
||||
print(args_opt)
|
||||
|
||||
# init context
|
||||
target = args_opt.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target)
|
||||
|
||||
# whether train on modelarts or local server
|
||||
if not args_opt.is_modelarts:
|
||||
if args_opt.experiment is None:
|
||||
args_opt.experiment = 'samples'
|
||||
os.system('mkdir {0}'.format(args_opt.experiment))
|
||||
context.set_context(device_id=int(args_opt.device_id))
|
||||
dataset = create_dataset(args_opt.dataroot, args_opt.dataset, args_opt.batchSize, args_opt.imageSize, 1,
|
||||
args_opt.workers, target)
|
||||
|
||||
else:
|
||||
import moxing as mox
|
||||
if args_opt.experiment is None:
|
||||
args_opt.experiment = '/cache/train_output'
|
||||
os.system('mkdir {0}'.format(args_opt.experiment))
|
||||
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
||||
data_name = 'LSUN-bedroom.zip'
|
||||
local_data_url = '/cache/data_path/'
|
||||
mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=local_data_url)
|
||||
zip_command = "unzip -o -q %s -d %s" % (local_data_url + data_name, local_data_url)
|
||||
os.system(zip_command)
|
||||
print("Unzip success!")
|
||||
|
||||
dataset = create_dataset(local_data_url, args_opt.dataset, args_opt.batchSize, args_opt.imageSize, 1,
|
||||
args_opt.workers, target)
|
||||
|
||||
|
||||
# fix seed
|
||||
args_opt.manualSeed = random.randint(1, 10000)
|
||||
print("Random Seed: ", args_opt.manualSeed)
|
||||
random.seed(args_opt.manualSeed)
|
||||
ds.config.set_seed(args_opt.manualSeed)
|
||||
|
||||
|
||||
# initialize hyperparameters
|
||||
nz = int(args_opt.nz)
|
||||
ngf = int(args_opt.ngf)
|
||||
ndf = int(args_opt.ndf)
|
||||
nc = int(args_opt.nc)
|
||||
n_extra_layers = int(args_opt.n_extra_layers)
|
||||
|
||||
# write out generator config to generate images together wth training checkpoints
|
||||
generator_config = {"imageSize": args_opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf,
|
||||
"n_extra_layers": n_extra_layers, "noBN": args_opt.noBN}
|
||||
|
||||
with open(os.path.join(args_opt.experiment, "generator_config.json"), 'w') as gcfg:
|
||||
gcfg.write(json.dumps(generator_config) + "\n")
|
||||
|
||||
def init_weight(net):
|
||||
"""initial net weight"""
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
|
||||
cell.weight.set_data(init.initializer(init.Normal(0.02), cell.weight.shape))
|
||||
elif isinstance(cell, nn.BatchNorm2d):
|
||||
cell.gamma.set_data(init.initializer(Tensor(np.random.normal(1, 0.02, cell.gamma.shape), \
|
||||
mstype.float32), cell.gamma.shape))
|
||||
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
|
||||
|
||||
|
||||
def save_image(img, img_path):
|
||||
"""save image"""
|
||||
mul = ops.Mul()
|
||||
add = ops.Add()
|
||||
if isinstance(img, Tensor):
|
||||
img = mul(img, 255 * 0.5)
|
||||
img = add(img, 255 * 0.5)
|
||||
|
||||
img = img.asnumpy().astype(np.uint8).transpose((0, 2, 3, 1))
|
||||
|
||||
elif not isinstance(img, np.ndarray):
|
||||
raise ValueError("img should be Tensor or numpy array, but get {}".format(type(img)))
|
||||
|
||||
IMAGE_SIZE = 64 # Image size
|
||||
IMAGE_ROW = 8 # Row num
|
||||
IMAGE_COLUMN = 8 # Column num
|
||||
PADDING = 2 # Interval of small pictures
|
||||
to_image = Image.new('RGB', (IMAGE_COLUMN * IMAGE_SIZE + PADDING * (IMAGE_COLUMN + 1),
|
||||
IMAGE_ROW * IMAGE_SIZE + PADDING * (IMAGE_ROW + 1))) # create a new picture
|
||||
# cycle
|
||||
ii = 0
|
||||
for y in range(1, IMAGE_ROW + 1):
|
||||
for x in range(1, IMAGE_COLUMN + 1):
|
||||
from_image = Image.fromarray(img[ii])
|
||||
to_image.paste(from_image, ((x - 1) * IMAGE_SIZE + PADDING * x, (y - 1) * IMAGE_SIZE + PADDING * y))
|
||||
ii = ii + 1
|
||||
|
||||
to_image.save(img_path) # save
|
||||
|
||||
|
||||
# define net----------------------------------------------------------------------------------------------
|
||||
# Generator
|
||||
if args_opt.noBN:
|
||||
netG = DcgannobnG(args_opt.imageSize, nz, nc, ngf, n_extra_layers)
|
||||
else:
|
||||
netG = DcganG(args_opt.imageSize, nz, nc, ngf, n_extra_layers)
|
||||
|
||||
# write out generator config to generate images together wth training checkpoints
|
||||
generator_config = {"imageSize": args_opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf,
|
||||
"n_extra_layers": n_extra_layers, "noBN": args_opt.noBN}
|
||||
with open(os.path.join(args_opt.experiment, "generator_config.json"), 'w') as gcfg:
|
||||
gcfg.write(json.dumps(generator_config) + "\n")
|
||||
|
||||
init_weight(netG)
|
||||
|
||||
if args_opt.netG != '': # load checkpoint if needed
|
||||
load_param_into_net(netG, load_checkpoint(args_opt.netG))
|
||||
print(netG)
|
||||
|
||||
netD = DcganD(args_opt.imageSize, nz, nc, ndf, n_extra_layers)
|
||||
init_weight(netD)
|
||||
|
||||
if args_opt.netD != '':
|
||||
load_param_into_net(netD, load_checkpoint(args_opt.netD))
|
||||
print(netD)
|
||||
|
||||
input1 = Tensor(np.zeros([args_opt.batchSize, 3, args_opt.imageSize, args_opt.imageSize]), dtype=mstype.float32)
|
||||
noise = Tensor(np.zeros([args_opt.batchSize, nz, 1, 1]), dtype=mstype.float32)
|
||||
fixed_noise = Tensor(np.random.normal(0, 1, size=[args_opt.batchSize, nz, 1, 1]), dtype=mstype.float32)
|
||||
|
||||
# setup optimizer
|
||||
if args_opt.adam:
|
||||
optimizerD = nn.Adam(netD.trainable_params(), learning_rate=args_opt.lrD, beta1=args_opt.beta1, beta2=.999)
|
||||
optimizerG = nn.Adam(netG.trainable_params(), learning_rate=args_opt.lrG, beta1=args_opt.beta1, beta2=.999)
|
||||
else:
|
||||
optimizerD = nn.RMSProp(netD.trainable_params(), learning_rate=args_opt.lrD, decay=0.99)
|
||||
optimizerG = nn.RMSProp(netG.trainable_params(), learning_rate=args_opt.lrG, decay=0.99)
|
||||
|
||||
netG_train = GenTrainOneStepCell(netG, netD, optimizerG)
|
||||
netD_train = DisTrainOneStepCell(netG, netD, optimizerD, args_opt.clamp_lower, args_opt.clamp_upper)
|
||||
|
||||
netG_train.set_train()
|
||||
netD_train.set_train()
|
||||
|
||||
gen_iterations = 0
|
||||
|
||||
# Train
|
||||
for epoch in range(args_opt.niter): # niter: the num of epoch
|
||||
data_iter = dataset.create_dict_iterator()
|
||||
length = dataset.get_dataset_size()
|
||||
i = 0
|
||||
while i < length:
|
||||
############################
|
||||
# (1) Update D network
|
||||
###########################
|
||||
for p in netD.trainable_params(): # reset requires_grad
|
||||
p.requires_grad = True # they are set to False below in netG update
|
||||
|
||||
# train the discriminator Diters times
|
||||
if gen_iterations < 25 or gen_iterations % 500 == 0:
|
||||
Diters = 100
|
||||
else:
|
||||
Diters = args_opt.Diters
|
||||
j = 0
|
||||
while j < Diters and i < length:
|
||||
j += 1
|
||||
|
||||
data = data_iter.__next__()
|
||||
i += 1
|
||||
|
||||
# train with real and fake
|
||||
real = data['image']
|
||||
noise = Tensor(np.random.normal(0, 1, size=[args_opt.batchSize, nz, 1, 1]), dtype=mstype.float32)
|
||||
loss_D = netD_train(real, noise)
|
||||
|
||||
############################
|
||||
# (2) Update G network
|
||||
###########################
|
||||
for p in netD.trainable_params():
|
||||
p.requires_grad = False # to avoid computation
|
||||
|
||||
noise = Tensor(np.random.normal(0, 1, size=[args_opt.batchSize, nz, 1, 1]), dtype=mstype.float32)
|
||||
loss_G = netG_train(noise)
|
||||
gen_iterations += 1
|
||||
|
||||
print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f'
|
||||
% (epoch, args_opt.niter, i, length, gen_iterations,
|
||||
loss_D.asnumpy(), loss_G.asnumpy()))
|
||||
|
||||
if gen_iterations % 500 == 0:
|
||||
fake = netG(fixed_noise)
|
||||
save_image(real, '{0}/real_samples.png'.format(args_opt.experiment))
|
||||
save_image(fake, '{0}/fake_samples_{1}.png'.format(args_opt.experiment, gen_iterations))
|
||||
|
||||
save_checkpoint(netD, '{0}/netD_epoch_{1}.ckpt'.format(args_opt.experiment, epoch))
|
||||
save_checkpoint(netG, '{0}/netG_epoch_{1}.ckpt'.format(args_opt.experiment, epoch))
|
||||
|
||||
if args_opt.is_modelarts:
|
||||
mox.file.copy_parallel(src_url='/cache/train_output', dst_url=args_opt.train_url)
|
||||
|
||||
print("Train success!")
|
Loading…
Reference in New Issue