add new model: wgan

This commit is contained in:
jiang 2021-05-27 20:28:34 +08:00
parent 5ed742a893
commit 283672992c
15 changed files with 1217 additions and 0 deletions

View File

@ -0,0 +1,269 @@
# 目录
<!-- TOC -->
- [目录](#目录)
- [WGAN描述](#WGAN描述)
- [模型架构](#模型架构)
- [数据集](#数据集)
- [环境要求](#环境要求)
- [快速入门](#快速入门)
- [脚本说明](#脚本说明)
- [脚本及样例代码](#脚本及样例代码)
- [脚本参数](#脚本参数)
- [训练过程](#训练过程)
- [单机训练](#单机训练)
- [推理过程](#推理过程)
- [推理](#推理)
- [模型描述](#模型描述)
- [性能](#性能)
- [训练性能](#训练性能)
- [第一种情况选用标准卷积DCGAN的生成器结构](#第一种情况选用标准卷积DCGAN的生成器结构)
- [第二种情况选用没有BatchNorm的卷积DCGAN的生成器结构](#第二种情况选用没有BatchNorm的卷积DCGAN的生成器结构)
- [推理性能](#推理性能)
- [随机情况说明](#随机情况说明)
- [ModelZoo主页](#modelzoo主页)
<!-- /TOC -->
# WGAN描述
WGAN(Wasserstein GAN的简称)是一种基于Wasserstein距离的生成对抗网络(GAN)包括生成器网络和判别器网络它通过改进原始GAN的算法流程彻底解决了GAN训练不稳定的问题确保了生成样本的多样性并且训练过程中终于有一个像交叉熵、准确率这样的数值来指示训练的进程即-loss_D这个数值越小代表GAN训练得越好代表生成器产生的图像质量越高。
[论文](https://arxiv.org/abs/1701.07875)Martin Arjovsky, Soumith Chintala, Léon Bottou. "Wasserstein GAN"*In International Conference on Machine Learning(ICML 2017).
# 模型架构
WGAN网络包含两部分生成器网络和判别器网络。判别器网络采用卷积DCGAN的架构即多层二维卷积相连。生成器网络分别采用卷积DCGAN生成器结构、没有BatchNorm的卷积DCGAN生成器结构。输入数据包括真实图片数据和噪声数据真实图片resize到64*64噪声数据随机生成。
# 数据集
[LSUN-Bedrooms](<http://dl.yf.io/lsun/scenes/bedroom_train_lmdb.zip>)
- 数据集大小42.8G
- 训练集42.8G共3033044张图像。
- 注:对于生成对抗网络,推理部分是传入噪声数据生成图片,故无需使用测试集数据。
- 数据格式原始数据格式为lmdb格式需要使用LSUN官网格式转换脚本把lmdb数据export所有图片并将Bedrooms这一类图片放到同一文件夹下。
- 注LSUN数据集官网的数据格式转换脚本地址(https://github.com/fyu/lsun)
# 环境要求
- 硬件Ascend
- 使用Ascend来搭建硬件环境。
- 框架
- [MindSpore](https://www.mindspore.cn/install/en)
- 如需查看详情,请参见如下资源:
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# 快速入门
通过官方网站安装MindSpore后您可以按照如下步骤进行训练和评估
- Ascend处理器环境运行
```python
# 运行单机训练示例(包括以下两种情况):
sh run_train.sh [DATASET] [DATAROOT] [DEVICE_ID] [NOBN]
# 第一种情况选用标准卷积DCGAN的生成器结构
sh run_train.sh [DATASET] [DATAROOT] [DEVICE_ID] False
# 第二种情况选用没有BatchNorm的卷积DCGAN的生成器结构
sh run_train.sh [DATASET] [DATAROOT] [DEVICE_ID] True
# 运行评估示例
sh run_eval.sh [DEVICE_ID] [CONFIG_PATH] [CKPT_FILE_PATH] [OUTPUT_DIR] [NIMAGES]
```
# 脚本说明
## 脚本及样例代码
```bash
├── model_zoo
├── README.md // 所有模型相关说明
├── WGAN
├── README.md // WGAN相关说明
├── scripts
│ ├── run_train.sh // 单机到Ascend处理器的shell脚本
│ ├── run_eval.sh // Ascend评估的shell脚本
├── src
│ ├── dataset.py // 创建数据集及数据预处理
│ ├── dcgan_model.py // WGAN架构,标准的DCGAN架构
│ ├── dcgannobn_model.py // WGAN架构没有BatchNorm的DCGAN架构
│ ├── args.py // 参数配置文件
│ ├── cell.py // 模型单步训练文件
├── train.py // 训练脚本
├── eval.py // 评估脚本
├── export.py // 将checkpoint文件导出到mindir下
```
## 脚本参数
在args.py中可以同时配置训练参数、评估参数及模型导出参数。
```python
# common_config
'device_target': 'Ascend', # 运行设备
'device_id': 0, # 用于训练或评估数据集的设备ID
# train_config
'dataset': 'lsun', # 数据集名称
'dataroot': None, # 数据集路径,必须输入,不能为空
'workers': 8, # 数据加载线程数
'batchSize': 64, # 批处理大小
'imageSize': 64, # 图片尺寸大小
'nc': 3, # 传入图片的通道数
'nz': 100, # 初始噪声向量大小
'ndf': 64, # 判别器网络基础特征数目
'ngf': 64, # 生成器网络基础特征数目
'niter': 25, # 网络训练的epoch数
'lrD': 0.00005, # 判别器初始学习率
'lrG': 0.00005, # 生成器初始学习率
'netG': '', # 恢复训练的生成器的ckpt文件路径
'netD': '', # 恢复训练的判别器的ckpt文件路径
'clamp_lower': -0.01, # 将优化器参数限定在某一范围的下界
'clamp_upper': 0.01, # 将优化器参数限定在某一范围的上界
'Diters': 5, # 每训练一次生成器需要训练判别器的次数
'noBN': False, # 卷积生成器网络中是否使用BatchNorm默认是使用
'n_extra_layers': 0, # 生成器和判别器网络中附加层的数目默认是0
'experiment': None, # 保存模型和生成图片的路径,若不指定,则使用默认路径
'adam': False, # 是否使用Adam优化器默认是不使用使用的是RMSprop优化器
# eval_config
'config': None, # 训练生成的生成器的配置文件.json文件路径必须指定
'ckpt_file': None, # 训练时保存的生成器的权重文件.ckpt的路径必须指定
'output_dir': None, # 生成图片的输出路径,必须指定
'nimages': 1, # 生成图片的数量默认是1
# export_config
'config': None, # 训练生成的生成器的配置文件.json文件路径必须指定
'ckpt_file': None, # 训练时保存的生成器的权重文件.ckpt的路径必须指定
'file_name': 'WGAN', # 输出文件名字的前缀,默认是'WGAN'
'file_format': 'AIR', # 模型输出格式,可选["AIR", "ONNX", "MINDIR"],默认是'AIR'
'nimages': 1, # 生成图片的数量默认是1
```
更多配置细节请参考脚本`args.py`。
## 训练过程
### 单机训练
- Ascend处理器环境运行
```bash
sh run_train.sh [DATASET] [DATAROOT] [DEVICE_ID] [NOBN]
```
第一种情况选用标准卷积DCGAN的生成器结构
```bash
sh run_train.sh [DATASET] [DATAROOT] [DEVICE_ID] False
```
第二种情况选用没有BatchNorm的卷积DCGAN的生成器结构
```bash
sh run_train.sh [DATASET] [DATAROOT] [DEVICE_ID] True
```
上述python命令将在后台运行您可以通过train.log文件查看结果。
训练结束后,您可在存储的文件夹(默认是./samples下找到生成的图片、检查点文件和.json文件。采用以下方式得到损失值
```bash
[0/25][2300/47391][23] Loss_D: -1.555344 Loss_G: 0.761238
[0/25][2400/47391][24] Loss_D: -1.557617 Loss_G: 0.762344
...
```
## 推理过程
### 推理
- 在Ascend环境下评估
在运行以下命令之前请检查用于推理的检查点和json文件路径并设置输出图片的路径。
```bash
sh run_eval.sh [DEVICE_ID] [CONFIG_PATH] [CKPT_FILE_PATH] [OUTPUT_DIR] [NIMAGES]
```
上述python命令将在后台运行您可以通过eval/eval.log文件查看日志信息在输出图片的路径下查看生成的图片。
# 模型描述
## 性能
### 训练性能
#### 第一种情况选用标准卷积DCGAN的生成器结构
| 参数 | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| 资源 | Ascend 910 CPU 2.60GHz192核内存755G |
| 上传日期 | 2021-05-14 |
| MindSpore版本 | 1.2.0-alpha |
| 数据集 | LSUN-Bedrooms |
| 训练参数 | max_epoch=25, batch_size=64, lr_init=0.00005 |
| 优化器 | RMSProp |
| 损失函数 | 自定义损失函数 |
| 输出 | 生成的图片 |
| 速度 | 单卡190毫秒/步 |
| 总时长 | 单卡12小时10分钟 |
| 参数(M) | 6.57 |
| 微调检查点 | 13.98M (.ckpt文件) |
| 推理模型 | 14.00M (.mindir文件) |
| 脚本 | [WGAN脚本](https://gitee.com/mindspore/mindspore/tree/r1.2/model_zoo/research/cv/WGAN) |
生成图片效果如下:
![GenSample1](imgs/WGAN_1.png "第一种情况生成的图片样本")
#### 第二种情况选用没有BatchNorm的卷积DCGAN的生成器结构
| 参数 | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| 资源 | Ascend 910 CPU 2.60GHz192核内存755G |
| 上传日期 | 2021-05-14 |
| MindSpore版本 | 1.2.0-alpha |
| 数据集 | LSUN-Bedrooms |
| 训练参数 | max_epoch=25, batch_size=64, lr_init=0.00005 |
| 优化器 | RMSProp |
| 损失函数 | 自定义损失函数 |
| 输出 | 生成的图片 |
| 速度 | 单卡180毫秒/步 |
| 总时长 | 单卡11小时40分钟 |
| 参数(M) | 6.45 |
| 微调检查点 | 13.98M (.ckpt文件) |
| 推理模型 | 14.00M (.mindir文件) |
| 脚本 | [WGAN脚本](https://gitee.com/mindspore/mindspore/tree/r1.2/model_zoo/research/cv/WGAN) |
生成图片效果如下:
![GenSample2](imgs/WGAN_2.png "第二种情况生成的图片样本")
### 推理性能
#### 推理
| 参数 | Ascend |
| ------------------- | --------------------------- |
| 资源 | Ascend 910 |
| 上传日期 | 2021-05-14 |
| MindSpore 版本 | 1.2.0-alpha |
| 数据集 | LSUN-Bedrooms |
| batch_size | 1 |
| 输出 | 生成的图片 |
# 随机情况说明
在train.py中我们设置了随机种子。
# ModelZoo主页
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。

View File

@ -0,0 +1,72 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test WGAN """
import os
import json
import mindspore.common.dtype as mstype
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import context
import numpy as np
from PIL import Image
from src.dcgan_model import DcganG
from src.dcgannobn_model import DcgannobnG
from src.args import get_args
if __name__ == "__main__":
args_opt = get_args('eval')
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id)
with open(args_opt.config, 'r') as gencfg:
generator_config = json.loads(gencfg.read())
imageSize = generator_config["imageSize"]
nz = generator_config["nz"]
nc = generator_config["nc"]
ngf = generator_config["ngf"]
noBN = generator_config["noBN"]
n_extra_layers = generator_config["n_extra_layers"]
# generator
if noBN:
netG = DcgannobnG(imageSize, nz, nc, ngf, n_extra_layers)
else:
netG = DcganG(imageSize, nz, nc, ngf, n_extra_layers)
# load weights
load_param_into_net(netG, load_checkpoint(args_opt.ckpt_file))
# initialize noise
fixed_noise = Tensor(np.random.normal(size=[args_opt.nimages, nz, 1, 1]), dtype=mstype.float32)
fake = netG(fixed_noise)
mul = ops.Mul()
add = ops.Add()
reshape = ops.Reshape()
fake = mul(fake, 0.5*255)
fake = add(fake, 0.5*255)
for i in range(args_opt.nimages):
img_pil = reshape(fake[i, ...], (1, nc, imageSize, imageSize))
img_pil = img_pil.asnumpy()[0].astype(np.uint8).transpose((1, 2, 0))
img_pil = Image.fromarray(img_pil)
img_pil.save(os.path.join(args_opt.output_dir, "generated_%02d.png" % i))
print("Generate images success!")

View File

@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import json
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.args import get_args
if __name__ == '__main__':
args_opt = get_args('export')
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id)
with open(args_opt.config, 'r') as gencfg:
generator_config = json.loads(gencfg.read())
imageSize = generator_config["imageSize"]
nz = generator_config["nz"]
nc = generator_config["nc"]
ngf = generator_config["ngf"]
noBN = generator_config["noBN"]
n_extra_layers = generator_config["n_extra_layers"]
# generator
if noBN:
netG = DcgannobnG(imageSize, nz, nc, ngf, n_extra_layers)
else:
netG = DcganG(imageSize, nz, nc, ngf, n_extra_layers)
# load weights
load_param_into_net(netG, load_checkpoint(args_opt.ckpt_file))
# initialize noise
fixed_noise = Tensor(np.random.normal(size=[args_opt.nimages, nz, 1, 1]), dtype=mstype.float32)
export(netG, fixed_noise, file_name=args_opt.file_name, file_format=args_opt.file_format)

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 531 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 536 KiB

View File

@ -0,0 +1,3 @@
mindspore
numpy
Pillow

View File

@ -0,0 +1,49 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash bash run_eval.sh device_id config ckpt_file output_dir nimages"
echo "For example: bash run_eval.sh DEVICE_ID CONFIG_PATH CKPT_FILE_PATH OUTPUT_DIR NIMAGES"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
cd ../
rm -rf eval
mkdir eval
cd ./eval
mkdir src
cd ../
cp ./*.py ./eval
cp ./src/*.py ./eval/src
cd ./eval
env > env0.log
echo "train begin."
python eval.py --device_id $1 --config $2 --ckpt_file $3 --output_dir $4 --nimages $5 > ./eval.log 2>&1 &
if [ $? -eq 0 ];then
echo "eval success"
else
echo "eval failed"
exit 2
fi
echo "finish"
cd ../

View File

@ -0,0 +1,49 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash bash run_train.sh dataset dataroot device_id noBN"
echo "For example: bash run_train.sh lsun /opt_data/xidian_wks/lsun 3 False"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
cd ../
rm -rf train
mkdir train
cd ./train
mkdir src
cd ../
cp ./*.py ./train
cp ./src/*.py ./train/src
cd ./train
env > env0.log
echo "train begin."
python train.py --dataset $1 --dataroot $2 --device_id $3 --noBN $4 > ./train.log 2>&1 &
if [ $? -eq 0 ];then
echo "training success"
else
echo "training failed"
exit 2
fi
echo "finish"
cd ../

View File

@ -0,0 +1,69 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""get args"""
import ast
import argparse
def get_args(phase):
"""Define the common options that are used in training."""
parser = argparse.ArgumentParser(description='WGAN')
parser.add_argument('--device_target', default='Ascend', help='enables npu')
parser.add_argument('--device_id', type=int, default=0)
if phase == 'train':
parser.add_argument('--dataset', default='lsun', help='cifar10 | lsun')
parser.add_argument('--dataroot', default=None, help='path to dataset')
parser.add_argument('--is_modelarts', type=ast.literal_eval, default=False, help='train in Modelarts or not')
parser.add_argument('--data_url', default=None, help='Location of data.')
parser.add_argument('--train_url', default=None, help='Location of training outputs.')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height/width of the input image to network')
parser.add_argument('--nc', type=int, default=3, help='input image channels')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lrD', type=float, default=0.00005, help='learning rate for Critic, default=0.00005')
parser.add_argument('--lrG', type=float, default=0.00005, help='learning rate for Generator, default=0.00005')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--clamp_lower', type=float, default=-0.01)
parser.add_argument('--clamp_upper', type=float, default=0.01)
parser.add_argument('--Diters', type=int, default=5, help='number of D iters per each G iter')
parser.add_argument('--noBN', type=ast.literal_eval, default=False, help='use batchnorm or not (for DCGAN)')
parser.add_argument('--n_extra_layers', type=int, default=0, help='Number of extra layers on gen and disc')
parser.add_argument('--experiment', default=None, help='Where to store samples and models')
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
elif phase == 'export':
parser.add_argument('--config', required=True, type=str, help='path to generator config .json file')
parser.add_argument('--ckpt_file', type=str, required=True, help="Checkpoint file path.")
parser.add_argument('--file_name', type=str, default="WGAN", help="output file name prefix.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], \
default='AIR', help='file format')
parser.add_argument('--nimages', required=True, type=int, help="number of images to generate", default=1)
elif phase == 'eval':
parser.add_argument('--config', required=True, type=str, help='path to generator config .json file')
parser.add_argument('--ckpt_file', required=True, type=str, help='path to generator weights .ckpt file')
parser.add_argument('--output_dir', required=True, type=str, help="path to to output directory")
parser.add_argument('--nimages', required=True, type=int, help="number of images to generate", default=1)
args_opt = parser.parse_args()
return args_opt

View File

@ -0,0 +1,170 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Train one step """
import mindspore.nn as nn
import mindspore.ops.composite as C
import mindspore.ops.operations as P
import mindspore.ops.functional as F
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
_get_parallel_mode)
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
class GenWithLossCell(nn.Cell):
"""Generator with loss(wrapped)"""
def __init__(self, netG, netD):
super(GenWithLossCell, self).__init__()
self.netG = netG
self.netD = netD
def construct(self, noise):
"""construct"""
fake = self.netG(noise)
errG = self.netD(fake)
loss_G = errG
return loss_G
class DisWithLossCell(nn.Cell):
""" Discriminator with loss(wrapped) """
def __init__(self, netG, netD):
super(DisWithLossCell, self).__init__()
self.netG = netG
self.netD = netD
def construct(self, real, noise):
"""construct"""
errD_real = self.netD(real)
fake = self.netG(noise)
errD_fake = self.netD(fake)
loss_D = errD_real - errD_fake
return loss_D
class ClipParameter(nn.Cell):
""" Clip the parameter """
def __init__(self):
super(ClipParameter, self).__init__()
self.cast = P.Cast()
self.dtype = P.DType()
def construct(self, params, clip_lower, clip_upper):
"""construct"""
new_params = ()
for param in params:
dt = self.dtype(param)
t = C.clip_by_value(param, self.cast(F.tuple_to_array((clip_lower,)), dt),
self.cast(F.tuple_to_array((clip_upper,)), dt))
new_params = new_params + (t,)
return new_params
class GenTrainOneStepCell(nn.Cell):
""" Generator TrainOneStepCell """
def __init__(self, netG, netD,
optimizerG: nn.Optimizer,
sens=1.0):
super(GenTrainOneStepCell, self).__init__()
self.netD = netD
self.netD.set_train(False)
self.netD.set_grad(False)
self.weights_G = optimizerG.parameters
self.optimizerG = optimizerG
self.net = GenWithLossCell(netG, netD)
self.net.set_train()
self.net.set_grad()
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
# parallel process
self.reducer_flag = False
self.grad_reducer_G = F.identity
self.parallel_mode = _get_parallel_mode()
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer_G = DistributedGradReducer(self.weights_G, mean, degree) # A distributed optimizer.
def construct(self, noise):
""" construct """
loss_G = self.net(noise)
sens = P.Fill()(P.DType()(loss_G), P.Shape()(loss_G), self.sens)
grads = self.grad(self.net, self.weights_G)(noise, sens)
if self.reducer_flag:
grads = self.grad_reducer_G(grads)
return F.depend(loss_G, self.optimizerG(grads))
_my_adam_opt = C.MultitypeFuncGraph("_my_adam_opt")
@_my_adam_opt.register("Tensor", "Tensor")
def _update_run_op(param, param_clipped):
param_clipped = F.depend(param_clipped, F.assign(param, param_clipped))
return param_clipped
class DisTrainOneStepCell(nn.Cell):
""" Discriminator TrainOneStepCell """
def __init__(self, netG, netD,
optimizerD: nn.Optimizer,
clip_lower=-0.01, clip_upper=0.01, sens=1.0):
super(DisTrainOneStepCell, self).__init__()
self.weights_D = optimizerD.parameters
self.clip_parameters = ClipParameter()
self.optimizerD = optimizerD
self.net = DisWithLossCell(netG, netD)
self.net.set_train()
self.net.set_grad()
self.reduce_flag = False
self.op_cast = P.Cast()
self.hyper_map = C.HyperMap()
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.clip_lower = clip_lower
self.clip_upper = clip_upper
# parallel process
self.reducer_flag = False
self.grad_reducer_D = F.identity
self.parallel_mode = _get_parallel_mode()
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer_D = DistributedGradReducer(self.weights_D, mean, degree) # A distributed optimizer.
def construct(self, real, noise):
""" construct """
loss_D = self.net(real, noise)
sens = P.Fill()(P.DType()(loss_D), P.Shape()(loss_D), self.sens)
grads = self.grad(self.net, self.weights_D)(real, noise, sens)
if self.reducer_flag:
grads = self.grad_reducer_D(grads)
upd = self.optimizerD(grads)
weights_D_cliped = self.clip_parameters(self.weights_D, self.clip_lower, self.clip_upper)
res = self.hyper_map(F.partial(_my_adam_opt), self.weights_D, weights_D_cliped)
res = F.depend(upd, res)
return F.depend(loss_D, res)

View File

@ -0,0 +1,59 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" dataset """
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.common.dtype as mstype
def create_dataset(dataroot, dataset, batchSize, imageSize, repeat_num=1, workers=8, target='Ascend'):
"""Create dataset"""
rank_id = 0
device_num = 1
# define map operations
resize_op = c.Resize(imageSize)
center_crop_op = c.CenterCrop(imageSize)
normalize_op = c.Normalize(mean=(0.5*255, 0.5*255, 0.5*255), std=(0.5*255, 0.5*255, 0.5*255))
hwc2chw_op = c.HWC2CHW()
if dataset == 'lsun':
if device_num == 1:
data_set = ds.ImageFolderDataset(dataroot, num_parallel_workers=workers, shuffle=True, decode=True)
else:
data_set = ds.ImageFolderDataset(dataroot, num_parallel_workers=workers, shuffle=True, decode=True,
num_shards=device_num, shard_id=rank_id)
transform = [resize_op, center_crop_op, normalize_op, hwc2chw_op]
else:
if device_num == 1:
data_set = ds.Cifar10Dataset(dataroot, num_parallel_workers=workers, shuffle=True)
else:
data_set = ds.Cifar10Dataset(dataroot, num_parallel_workers=workers, shuffle=True, \
num_shards=device_num, shard_id=rank_id)
transform = [resize_op, normalize_op, hwc2chw_op]
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(input_columns='image', operations=transform, num_parallel_workers=workers)
data_set = data_set.map(input_columns='label', operations=type_cast_op, num_parallel_workers=workers)
data_set = data_set.batch(batchSize, drop_remainder=True)
data_set = data_set.repeat(repeat_num)
return data_set

View File

@ -0,0 +1,98 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Dcgan model """
import mindspore.nn as nn
class DcganD(nn.Cell):
""" dcgan Decriminator """
def __init__(self, isize, nz, nc, ndf, n_extra_layers=0):
super(DcganD, self).__init__()
assert isize % 16 == 0, "isize has to be a multiple of 16"
main = nn.SequentialCell()
main.append(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, has_bias=False))
main.append(nn.LeakyReLU(0.2))
csize, cndf = isize / 2, ndf
# Extra layers
for _ in range(n_extra_layers):
main.append(nn.Conv2d(cndf, cndf, 3, 1, 'pad', 1, has_bias=False))
main.append(nn.BatchNorm2d(cndf))
main.append(nn.LeakyReLU(0.2))
while csize > 4:
in_feat = cndf
out_feat = cndf * 2
main.append(nn.Conv2d(in_feat, out_feat, 4, 2, 'pad', 1, has_bias=False))
main.append(nn.BatchNorm2d(out_feat))
main.append(nn.LeakyReLU(0.2))
cndf = cndf * 2
csize = csize / 2
# state size. K x 4 x 4
main.append(nn.Conv2d(cndf, 1, 4, 1, 'pad', 0, has_bias=False))
self.main = main
def construct(self, input1):
"""construct"""
output = self.main(input1)
output = output.mean(0)
return output.view(1)
class DcganG(nn.Cell):
""" dcgan generator """
def __init__(self, isize, nz, nc, ngf, n_extra_layers=0):
super(DcganG, self).__init__()
assert isize % 16 == 0, "isize has to be a multiple of 16"
cngf, tisize = ngf // 2, 4
while tisize != isize:
cngf = cngf * 2
tisize = tisize * 2
main = nn.SequentialCell()
# input is Z, going into a convolution
main.append(nn.Conv2dTranspose(nz, cngf, 4, 1, 'pad', 0, has_bias=False))
main.append(nn.BatchNorm2d(cngf))
main.append(nn.ReLU())
csize = 4
while csize < isize // 2:
main.append(nn.Conv2dTranspose(cngf, cngf // 2, 4, 2, 'pad', 1, has_bias=False))
main.append(nn.BatchNorm2d(cngf // 2))
main.append(nn.ReLU())
cngf = cngf // 2
csize = csize * 2
# Extra layers
for _ in range(n_extra_layers):
main.append(nn.Conv2d(cngf, cngf, 3, 1, 'pad', 1, has_bias=False))
main.append(nn.BatchNorm2d(cngf))
main.append(nn.ReLU())
main.append(nn.Conv2dTranspose(cngf, nc, 4, 2, 'pad', 1, has_bias=False))
main.append(nn.Tanh())
self.main = main
def construct(self, input1):
"""construct"""
output = self.main(input1)
return output

View File

@ -0,0 +1,92 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" DCGAN with on Batchnorm model """
import mindspore.nn as nn
class DcgannobnD(nn.Cell):
""" DCGAN Descriminator with no Batchnorm layer """
def __init__(self, isize, nz, nc, ndf, n_extra_layers=0):
super(DcgannobnD, self).__init__()
assert isize % 16 == 0, "isize has to be a multiple of 16"
main = nn.SequentialCell()
# input is nc x isize x isize
main.append(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, has_bias=False))
main.append(nn.LeakyReLU(0.2))
csize, cndf = isize / 2, ndf
# Extra layers
for _ in range(n_extra_layers):
main.append(nn.Conv2d(cndf, cndf, 3, 1, 'pad', 1, has_bias=False))
main.append(nn.LeakyReLU(0.2))
while csize > 4:
in_feat = cndf
out_feat = cndf * 2
main.append(nn.Conv2d(in_feat, out_feat, 4, 2, 'pad', 1, has_bias=False))
main.append(nn.LeakyReLU(0.2))
cndf = cndf * 2
csize = csize / 2
# state size. K x 4 x 4
main.append(nn.Conv2d(cndf, 1, 4, 1, 'pad', 0, has_bias=False))
self.main = main
def construct(self, input1):
"""construct"""
output = self.main(input1)
output = output.mean(0)
return output.view(1)
class DcgannobnG(nn.Cell):
""" DCGAN Generator with no BatchNorm layer """
def __init__(self, isize, nz, nc, ngf, n_extra_layers=0):
super(DcgannobnG, self).__init__()
assert isize % 16 == 0, "isize has to be a multiple of 16"
cngf, tisize = ngf // 2, 4
while tisize != isize:
cngf = cngf * 2
tisize = tisize * 2
main = nn.SequentialCell()
main.append(nn.Conv2dTranspose(nz, cngf, 4, 1, 'pad', 0, has_bias=False))
main.append(nn.ReLU())
csize = 4
while csize < isize // 2:
main.append(nn.Conv2dTranspose(cngf, cngf // 2, 4, 2, 'pad', 1, has_bias=False))
main.append(nn.ReLU())
cngf = cngf // 2
csize = csize * 2
# Extra layers
for _ in range(n_extra_layers):
main.append(nn.Conv2d(cngf, cngf, 3, 1, 'pad', 1, has_bias=False))
main.append(nn.ReLU())
main.append(nn.Conv2dTranspose(cngf, nc, 4, 2, 'pad', 1, has_bias=False))
main.append(nn.Tanh())
self.main = main
def construct(self, input1):
"""construct"""
output = self.main(input1)
return output

View File

@ -0,0 +1,232 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train WGAN"""
import os
import random
import json
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.dataset as ds
import mindspore.ops as ops
from mindspore.common import initializer as init
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, save_checkpoint
from PIL import Image
import numpy as np
from src.dataset import create_dataset
from src.dcgan_model import DcganG, DcganD
from src.dcgannobn_model import DcgannobnG
from src.cell import GenTrainOneStepCell, DisTrainOneStepCell
from src.args import get_args
if __name__ == '__main__':
args_opt = get_args('train')
print(args_opt)
# init context
target = args_opt.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=target)
# whether train on modelarts or local server
if not args_opt.is_modelarts:
if args_opt.experiment is None:
args_opt.experiment = 'samples'
os.system('mkdir {0}'.format(args_opt.experiment))
context.set_context(device_id=int(args_opt.device_id))
dataset = create_dataset(args_opt.dataroot, args_opt.dataset, args_opt.batchSize, args_opt.imageSize, 1,
args_opt.workers, target)
else:
import moxing as mox
if args_opt.experiment is None:
args_opt.experiment = '/cache/train_output'
os.system('mkdir {0}'.format(args_opt.experiment))
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
data_name = 'LSUN-bedroom.zip'
local_data_url = '/cache/data_path/'
mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=local_data_url)
zip_command = "unzip -o -q %s -d %s" % (local_data_url + data_name, local_data_url)
os.system(zip_command)
print("Unzip success!")
dataset = create_dataset(local_data_url, args_opt.dataset, args_opt.batchSize, args_opt.imageSize, 1,
args_opt.workers, target)
# fix seed
args_opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", args_opt.manualSeed)
random.seed(args_opt.manualSeed)
ds.config.set_seed(args_opt.manualSeed)
# initialize hyperparameters
nz = int(args_opt.nz)
ngf = int(args_opt.ngf)
ndf = int(args_opt.ndf)
nc = int(args_opt.nc)
n_extra_layers = int(args_opt.n_extra_layers)
# write out generator config to generate images together wth training checkpoints
generator_config = {"imageSize": args_opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf,
"n_extra_layers": n_extra_layers, "noBN": args_opt.noBN}
with open(os.path.join(args_opt.experiment, "generator_config.json"), 'w') as gcfg:
gcfg.write(json.dumps(generator_config) + "\n")
def init_weight(net):
"""initial net weight"""
for _, cell in net.cells_and_names():
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
cell.weight.set_data(init.initializer(init.Normal(0.02), cell.weight.shape))
elif isinstance(cell, nn.BatchNorm2d):
cell.gamma.set_data(init.initializer(Tensor(np.random.normal(1, 0.02, cell.gamma.shape), \
mstype.float32), cell.gamma.shape))
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
def save_image(img, img_path):
"""save image"""
mul = ops.Mul()
add = ops.Add()
if isinstance(img, Tensor):
img = mul(img, 255 * 0.5)
img = add(img, 255 * 0.5)
img = img.asnumpy().astype(np.uint8).transpose((0, 2, 3, 1))
elif not isinstance(img, np.ndarray):
raise ValueError("img should be Tensor or numpy array, but get {}".format(type(img)))
IMAGE_SIZE = 64 # Image size
IMAGE_ROW = 8 # Row num
IMAGE_COLUMN = 8 # Column num
PADDING = 2 # Interval of small pictures
to_image = Image.new('RGB', (IMAGE_COLUMN * IMAGE_SIZE + PADDING * (IMAGE_COLUMN + 1),
IMAGE_ROW * IMAGE_SIZE + PADDING * (IMAGE_ROW + 1))) # create a new picture
# cycle
ii = 0
for y in range(1, IMAGE_ROW + 1):
for x in range(1, IMAGE_COLUMN + 1):
from_image = Image.fromarray(img[ii])
to_image.paste(from_image, ((x - 1) * IMAGE_SIZE + PADDING * x, (y - 1) * IMAGE_SIZE + PADDING * y))
ii = ii + 1
to_image.save(img_path) # save
# define net----------------------------------------------------------------------------------------------
# Generator
if args_opt.noBN:
netG = DcgannobnG(args_opt.imageSize, nz, nc, ngf, n_extra_layers)
else:
netG = DcganG(args_opt.imageSize, nz, nc, ngf, n_extra_layers)
# write out generator config to generate images together wth training checkpoints
generator_config = {"imageSize": args_opt.imageSize, "nz": nz, "nc": nc, "ngf": ngf,
"n_extra_layers": n_extra_layers, "noBN": args_opt.noBN}
with open(os.path.join(args_opt.experiment, "generator_config.json"), 'w') as gcfg:
gcfg.write(json.dumps(generator_config) + "\n")
init_weight(netG)
if args_opt.netG != '': # load checkpoint if needed
load_param_into_net(netG, load_checkpoint(args_opt.netG))
print(netG)
netD = DcganD(args_opt.imageSize, nz, nc, ndf, n_extra_layers)
init_weight(netD)
if args_opt.netD != '':
load_param_into_net(netD, load_checkpoint(args_opt.netD))
print(netD)
input1 = Tensor(np.zeros([args_opt.batchSize, 3, args_opt.imageSize, args_opt.imageSize]), dtype=mstype.float32)
noise = Tensor(np.zeros([args_opt.batchSize, nz, 1, 1]), dtype=mstype.float32)
fixed_noise = Tensor(np.random.normal(0, 1, size=[args_opt.batchSize, nz, 1, 1]), dtype=mstype.float32)
# setup optimizer
if args_opt.adam:
optimizerD = nn.Adam(netD.trainable_params(), learning_rate=args_opt.lrD, beta1=args_opt.beta1, beta2=.999)
optimizerG = nn.Adam(netG.trainable_params(), learning_rate=args_opt.lrG, beta1=args_opt.beta1, beta2=.999)
else:
optimizerD = nn.RMSProp(netD.trainable_params(), learning_rate=args_opt.lrD, decay=0.99)
optimizerG = nn.RMSProp(netG.trainable_params(), learning_rate=args_opt.lrG, decay=0.99)
netG_train = GenTrainOneStepCell(netG, netD, optimizerG)
netD_train = DisTrainOneStepCell(netG, netD, optimizerD, args_opt.clamp_lower, args_opt.clamp_upper)
netG_train.set_train()
netD_train.set_train()
gen_iterations = 0
# Train
for epoch in range(args_opt.niter): # niter: the num of epoch
data_iter = dataset.create_dict_iterator()
length = dataset.get_dataset_size()
i = 0
while i < length:
############################
# (1) Update D network
###########################
for p in netD.trainable_params(): # reset requires_grad
p.requires_grad = True # they are set to False below in netG update
# train the discriminator Diters times
if gen_iterations < 25 or gen_iterations % 500 == 0:
Diters = 100
else:
Diters = args_opt.Diters
j = 0
while j < Diters and i < length:
j += 1
data = data_iter.__next__()
i += 1
# train with real and fake
real = data['image']
noise = Tensor(np.random.normal(0, 1, size=[args_opt.batchSize, nz, 1, 1]), dtype=mstype.float32)
loss_D = netD_train(real, noise)
############################
# (2) Update G network
###########################
for p in netD.trainable_params():
p.requires_grad = False # to avoid computation
noise = Tensor(np.random.normal(0, 1, size=[args_opt.batchSize, nz, 1, 1]), dtype=mstype.float32)
loss_G = netG_train(noise)
gen_iterations += 1
print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f'
% (epoch, args_opt.niter, i, length, gen_iterations,
loss_D.asnumpy(), loss_G.asnumpy()))
if gen_iterations % 500 == 0:
fake = netG(fixed_noise)
save_image(real, '{0}/real_samples.png'.format(args_opt.experiment))
save_image(fake, '{0}/fake_samples_{1}.png'.format(args_opt.experiment, gen_iterations))
save_checkpoint(netD, '{0}/netD_epoch_{1}.ckpt'.format(args_opt.experiment, epoch))
save_checkpoint(netG, '{0}/netG_epoch_{1}.ckpt'.format(args_opt.experiment, epoch))
if args_opt.is_modelarts:
mox.file.copy_parallel(src_url='/cache/train_output', dst_url=args_opt.train_url)
print("Train success!")