!15406 dcgan master pr

Merge pull request !15406 from 周洪飞/master
i-robot 2021-08-12 13:22:36 +00:00 committed by Gitee
commit 26ecdc1fc4
14 changed files with 1246 additions and 0 deletions

@ -0,0 +1,201 @@
# Contents
- [DCGAN Description](#dcgan-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [DCGAN Description](#contents)
Deep convolutional generative adversarial networks (DCGANs) were the first to introduce CNNs into the GAN structure, using the strong feature-extraction ability of convolutional layers to improve the quality of the generated samples.
[Paper](https://arxiv.org/pdf/1511.06434.pdf): Radford A, Metz L, Chintala S. Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks[J]. Computer Science, 2015.
# [Model Architecture](#contents)
Architecture guidelines for stable deep convolutional GANs:
- Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator).
- Use batchnorm in both the generator and the discriminator.
- Remove fully connected hidden layers for deeper architectures.
- Use ReLU activation in generator for all layers except for the output, which uses Tanh.
- Use LeakyReLU activation in the discriminator for all layers.
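To make the first guideline concrete, the sketch below computes how strided convolutions shrink a feature map without any pooling layer, and how a fractional-strided (transposed) convolution grows it again. It assumes the standard output-size formulas and reuses kernel size 4, stride 2 and padding 1, the values this repo's discriminator passes to its first convolution on a 32 x 32 input. The function names are hypothetical and the layer count is illustrative only.

```python
def conv_out_size(size, kernel, stride, padding):
    """Spatial output size of a strided convolution (floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out_size(size, kernel, stride, padding):
    """Spatial output size of a fractional-strided (transposed) convolution."""
    return (size - 1) * stride - 2 * padding + kernel


def downsample_chain(size, layers):
    """Apply `layers` stride-2 convolutions and record every intermediate size."""
    sizes = [size]
    for _ in range(layers):
        size = conv_out_size(size, kernel=4, stride=2, padding=1)
        sizes.append(size)
    return sizes


print(downsample_chain(32, 3))                # [32, 16, 8, 4]
print(conv_transpose_out_size(4, 4, 2, 1))    # 8
```

Each stride-2 layer halves the resolution, which is the job pooling would otherwise do; the transposed convolution with the same settings doubles it.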
# [Dataset](#contents)
Dataset used to train DCGAN: [ImageNet-1k](http://www.image-net.org/index)
- Dataset size: ~125G, color images in 1000 classes
    - Train: 120G, about 1.28 million images
    - Test: 5G, 50000 images
- Data format: RGB images.
    - Note: Data will be processed in src/dataset.py
```path
└─imagenet_original
    └─train
```
# [Environment Requirements](#contents)
- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```shell
.
└─dcgan
  ├─README.md                    # README
  ├─scripts                      # shell scripts
  │ ├─run_standalone_train.sh    # training in standalone mode (1 pc)
  │ ├─run_distribute_train.sh    # training in parallel mode (8 pcs)
  │ └─run_eval.sh                # evaluation
  ├─src
  │ ├─dataset.py                 # dataset creation
  │ ├─cell.py                    # network definition
  │ ├─dcgan.py                   # dcgan structure
  │ ├─discriminator.py           # discriminator structure
  │ ├─generator.py               # generator structure
  │ └─config.py                  # config
  ├─train.py                     # train dcgan
  └─eval.py                      # eval dcgan
```
## [Script Parameters](#contents)
### [Training Script Parameters](#contents)
```shell
# distributed training
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [SAVE_PATH]
# standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH] [SAVE_PATH]
```
### [Parameters Configuration](#contents)
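```txt
'num_classes': 1000,        # number of dataset classes
'epoch_size': 20,           # number of training epochs
'batch_size': 128,          # batch size
'latent_size': 100,         # length of the latent vector fed to the generator
'feature_size': 64,         # base number of feature maps
'channel_size': 3,          # number of image channels (RGB)
'image_height': 32,         # height of the input images
'image_width': 32,          # width of the input images
'learning_rate': 0.0002,    # learning rate of the Adam optimizers
'beta1': 0.5                # beta1 of the Adam optimizers
```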
## [Training Process](#contents)
- Set options in `config.py`, including the learning rate and network hyperparameters. See [here](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about dataset preparation.
### [Training](#contents)
- Run `run_standalone_train.sh` for non-distributed training of DCGAN model.
```bash
# standalone training
bash run_standalone_train.sh [DATASET_PATH] [SAVE_PATH]
```
### [Distributed Training](#contents)
- Run `run_distribute_train.sh` for distributed training of DCGAN model.
```bash
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [SAVE_PATH]
```
- Notes
    1. A distributed task needs the hccl.json file specified by RANK_TABLE_FILE. You can generate it with [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
### [Training Result](#contents)
The training results are stored in save_path, where you can find the checkpoint files.
```bash
# standalone training result(1p)
Date time: 2021-04-13 13:55:39 epoch: 0 / 20 step: 0 / 10010 Dloss: 2.2297878 Gloss: 1.1530013
Date time: 2021-04-13 13:56:01 epoch: 0 / 20 step: 50 / 10010 Dloss: 0.21959287 Gloss: 20.064941
Date time: 2021-04-13 13:56:22 epoch: 0 / 20 step: 100 / 10010 Dloss: 0.18872623 Gloss: 5.872738
Date time: 2021-04-13 13:56:44 epoch: 0 / 20 step: 150 / 10010 Dloss: 0.53905165 Gloss: 4.477289
Date time: 2021-04-13 13:57:07 epoch: 0 / 20 step: 200 / 10010 Dloss: 0.47870708 Gloss: 2.2019134
Date time: 2021-04-13 13:57:28 epoch: 0 / 20 step: 250 / 10010 Dloss: 0.3929835 Gloss: 1.8170083
```
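The log lines above follow a fixed layout, so the loss curves can be recovered from a log file with a small parser. The sketch below is a hypothetical helper, not part of this repo; its regular expression is inferred only from the sample lines shown above and may need adjusting if the log format differs.

```python
import re

# Pattern inferred from the sample log lines; field names are illustrative.
LOG_PATTERN = re.compile(
    r"epoch:\s*(?P<epoch>\d+)\s*/\s*\d+\s+"
    r"step:\s*(?P<step>\d+)\s*/\s*\d+\s+"
    r"Dloss:\s*(?P<dloss>[-+0-9.eE]+)\s+"
    r"Gloss:\s*(?P<gloss>[-+0-9.eE]+)"
)


def parse_log_line(line):
    """Return epoch, step and both losses from one log line, or None if it does not match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    return {
        "epoch": int(match.group("epoch")),
        "step": int(match.group("step")),
        "dloss": float(match.group("dloss")),
        "gloss": float(match.group("gloss")),
    }


sample = ("Date time: 2021-04-13 13:56:01 epoch: 0 / 20 step: 50 / 10010 "
          "Dloss: 0.21959287 Gloss: 20.064941")
record = parse_log_line(sample)
print(record)
```

Applying `parse_log_line` to every line of a log file and dropping the `None` results gives a list that can be plotted directly.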
## [Evaluation Process](#contents)
### [Evaluation](#contents)
- Run `run_eval.sh` for evaluation.
```bash
# infer
bash run_eval.sh [IMG_URL] [CKPT_URL]
```
### [Evaluation result](#contents)
The evaluation result is stored in the img_url path, where the generated images are saved as generate.png.
## Model Export
```shell
python export.py --ckpt_file [CKPT_PATH] --device_target [DEVICE_TARGET] --file_format [EXPORT_FORMAT]
```
`EXPORT_FORMAT` should be "MINDIR"
# Model Description
## Performance
### Evaluation Performance
| Parameters | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | V1 |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
| Uploaded Date | 04/16/2021 (month/day/year) |
| MindSpore Version | 1.1.1 |
| Dataset | ImageNet2012 |
| Training Parameters | epoch=20, batch_size = 128 |
| Optimizer | Adam |
| Loss Function | BCELoss |
| Output | generated images |
| Loss | 10.9852 |
| Speed | 1pc: 420 ms/step; 8pcs: 143 ms/step |
| Total time | 1pc: 24.32 hours |
| Checkpoint for Fine tuning | 79.05M(.ckpt file) |
| Scripts | [dcgan script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/dcgan) |
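
As a rough consistency check on the table, the reported 1pc step time can be multiplied by the number of training steps. The sketch below assumes 10010 steps per epoch (the figure shown in the standalone training log) and 20 epochs. It ignores data loading, checkpointing and start-up overhead, so it is only an estimate; the function name is hypothetical.

```python
def estimated_hours(ms_per_step, steps_per_epoch, epochs):
    """Estimate total training time in hours from the average step time."""
    total_seconds = ms_per_step / 1000.0 * steps_per_epoch * epochs
    return total_seconds / 3600.0


hours = estimated_hours(ms_per_step=420, steps_per_epoch=10010, epochs=20)
print(f"{hours:.2f} hours")
```

The estimate comes out a little under the 24.32 hours reported in the table, which is plausible once per-epoch overhead is added.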
# [Description of Random Situation](#contents)
We use random seed in train.py and cell.py for weight initialization.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,81 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dcgan eval"""
import argparse
import numpy as np
from mindspore import context, Tensor, nn, load_checkpoint
from src.config import dcgan_imagenet_cfg as cfg
from src.generator import Generator
from src.discriminator import Discriminator
from src.cell import WithLossCellD, WithLossCellG
from src.dcgan import DCGAN


def save_imgs(gen_imgs, img_url):
    """Save a batch of generated images as a 4x4 grid in img_url/generate.png."""
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    for i in range(gen_imgs.shape[0]):
        plt.subplot(4, 4, i + 1)
        # map the generator output from [-1, 1] back to [0, 255]
        gen_imgs[i] = gen_imgs[i] * 127.5 + 127.5
        # CHW -> HWC, the layout expected by imshow
        perm = (1, 2, 0)
        show_imgs = np.transpose(gen_imgs[i], perm)
        sdf = show_imgs.astype(int)
        plt.imshow(sdf)
        plt.axis("off")
    plt.savefig(img_url + "/generate.png")


def load_dcgan(ckpt_url):
    """load_dcgan function"""
    netD = Discriminator()
    netG = Generator()
    criterion = nn.BCELoss(reduction='mean')
    netD_with_criterion = WithLossCellD(netD, netG, criterion)
    netG_with_criterion = WithLossCellG(netD, netG, criterion)
    optimizerD = nn.Adam(netD.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
    optimizerG = nn.Adam(netG.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
    myTrainOneStepCellForD = nn.TrainOneStepCell(netD_with_criterion, optimizerD)
    myTrainOneStepCellForG = nn.TrainOneStepCell(netG_with_criterion, optimizerG)
    dcgan = DCGAN(myTrainOneStepCellForD, myTrainOneStepCellForG)
    load_checkpoint(ckpt_url, dcgan)
    netG_trained = dcgan.myTrainOneStepCellForG.network.netG
    return netG_trained


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MindSpore dcgan evaluation')
    parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend (Default: 0)')
    parser.add_argument('--img_url', type=str, default=None, help='img save path')
    parser.add_argument('--ckpt_url', type=str, default=None, help='checkpoint load path')
    args = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(device_id=args.device_id)
    fixed_noise = Tensor(np.random.normal(size=(16, cfg.latent_size, 1, 1)).astype("float32"))
    net_G = load_dcgan(args.ckpt_url)
    fake = net_G(fixed_noise)
    print("================saving images================")
    save_imgs(fake.asnumpy(), args.img_url)
    print("================success================")
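
The post-processing in `save_imgs` can be illustrated without MindSpore or matplotlib. The sketch below uses plain nested lists and hypothetical helper names. It assumes the generator output lies in [-1, 1], the range of a Tanh output layer, and shows the rescaling to [0, 255] followed by the channel-first (CHW) to channel-last (HWC) reordering.

```python
def denormalize(value):
    """Map a value from [-1, 1] to an integer pixel in [0, 255] (truncating)."""
    return int(value * 127.5 + 127.5)


def chw_to_hwc(image):
    """Convert a nested-list image from [channel][row][col] to [row][col][channel]."""
    channels = len(image)
    height = len(image[0])
    width = len(image[0][0])
    return [
        [[denormalize(image[c][h][w]) for c in range(channels)] for w in range(width)]
        for h in range(height)
    ]


# A 3-channel image with one row and two columns.
tiny = [
    [[-1.0, 1.0]],   # R plane
    [[0.0, 0.0]],    # G plane
    [[1.0, -1.0]],   # B plane
]
print(chw_to_hwc(tiny))  # [[[0, 127, 255], [255, 127, 0]]]
```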

@ -0,0 +1,72 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into air, onnx, mindir models"""
import argparse
import ast
import os
import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
import mindspore.common.dtype as mstype
from mindspore import nn
from src.cell import WithLossCellD, WithLossCellG
from src.dcgan import DCGAN
from src.discriminator import Discriminator
from src.generator import Generator
from src.config import dcgan_imagenet_cfg as cfg

parser = argparse.ArgumentParser(description='dcgan export')
parser.add_argument("--run_modelart", type=ast.literal_eval, default=False, help="Run on modelArt, default is false.")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=128, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file name.")
parser.add_argument('--data_url', default=None, help='Directory contains dataset.')
parser.add_argument('--train_url', default=None, help='Directory contains checkpoint file')
parser.add_argument("--file_name", type=str, default="dcgan", help="output file name.")
parser.add_argument("--file_format", type=str, default="MINDIR", help="file format")
parser.add_argument('--device_target', type=str, default="Ascend",
                    choices=['Ascend', 'GPU', 'CPU'], help='device target (default: Ascend)')
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)

if __name__ == '__main__':
    netD = Discriminator()
    netG = Generator()
    criterion = nn.BCELoss(reduction='mean')
    netD_with_criterion = WithLossCellD(netD, netG, criterion)
    netG_with_criterion = WithLossCellG(netD, netG, criterion)
    optimizerD = nn.Adam(netD.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
    optimizerG = nn.Adam(netG.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
    myTrainOneStepCellForD = nn.TrainOneStepCell(netD_with_criterion, optimizerD)
    myTrainOneStepCellForG = nn.TrainOneStepCell(netG_with_criterion, optimizerG)
    net = DCGAN(myTrainOneStepCellForD, myTrainOneStepCellForG)
    # train_url is optional; without it, ckpt_file is used as the full checkpoint path
    if args.train_url:
        ckpt_path = os.path.join(args.train_url, args.ckpt_file)
    else:
        ckpt_path = args.ckpt_file
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    # dummy inputs that fix the exported graph's input shapes
    real_data = Tensor(np.random.rand(args.batch_size, 3, 32, 32), mstype.float32)
    latent_code = Tensor(np.random.rand(args.batch_size, 100, 1, 1), mstype.float32)
    inputs = [real_data, latent_code]
    export(net, *inputs, file_name=args.file_name, file_format=args.file_format)

@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from mindspore import nn
from src.cell import WithLossCellD, WithLossCellG
from src.dcgan import DCGAN
from src.discriminator import Discriminator
from src.generator import Generator
from src.config import dcgan_imagenet_cfg as cfg


def create_network(name):
    """create_network function"""
    if name == "dcgan":
        netD = Discriminator()
        netG = Generator()
        criterion = nn.BCELoss(reduction='mean')
        netD_with_criterion = WithLossCellD(netD, netG, criterion)
        netG_with_criterion = WithLossCellG(netD, netG, criterion)
        optimizerD = nn.Adam(netD.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
        optimizerG = nn.Adam(netG.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
        myTrainOneStepCellForD = nn.TrainOneStepCell(netD_with_criterion, optimizerD)
        myTrainOneStepCellForG = nn.TrainOneStepCell(netG_with_criterion, optimizerG)
        dcgan = DCGAN(myTrainOneStepCellForD, myTrainOneStepCellForG)
        return dcgan
    raise NotImplementedError(f"{name} is not implemented in the repo")

@ -0,0 +1,90 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
    echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [SAVE_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
PATH3=$(get_real_path $3)
echo $PATH1
echo $PATH2
echo $PATH3

if [ ! -f $PATH1 ]
then
    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
    exit 1
fi

if [ ! -d $PATH2 ]
then
    echo "error: DATASET_PATH=$PATH2 is not a directory"
    exit 1
fi

if [ ! -d $PATH3 ]
then
    echo "error: SAVE_PATH=$PATH3 is not a directory"
    exit 1
fi

ulimit -u unlimited
export HCCL_CONNECT_TIMEOUT=600
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
echo 3 > /proc/sys/vm/drop_caches
cpus=`cat /proc/cpuinfo| grep "processor"| wc -l`
avg=`expr $cpus \/ $DEVICE_NUM`
gap=`expr $avg \- 1`

for((i=0; i<${DEVICE_NUM}; i++))
do
    start=`expr $i \* $avg`
    end=`expr $start \+ $gap`
    cmdopt=$start"-"$end
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    taskset -c $cmdopt python -u train.py --device_id=$i --run_distribute=True \
        --dataset_path=$PATH2 --save_path=$PATH3 &> log &
    cd ..
done
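
The CPU binding computed in the loop above can be sketched in Python. The function name is hypothetical; it mirrors the script's integer arithmetic, where `avg` is the number of cores divided by `DEVICE_NUM` and each rank receives `avg` consecutive cores, and it produces the ranges passed to `taskset -c`. The 192-core figure is taken from the README's resource description.

```python
def core_ranges(cpus, device_num):
    """Split `cpus` logical cores into one contiguous range per device."""
    avg = cpus // device_num          # cores per device (integer division, like expr)
    ranges = []
    for i in range(device_num):
        start = i * avg
        end = start + avg - 1         # inclusive upper bound, as taskset expects
        ranges.append(f"{start}-{end}")
    return ranges


print(core_ranges(192, 8))
```

If the core count is not a multiple of the device count, the leftover cores are simply not assigned to any rank, which matches the behaviour of the shell arithmetic.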

@ -0,0 +1,65 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
then
    echo "Usage: bash run_eval.sh [IMG_URL] [CKPT_URL]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)

if [ ! -d $PATH1 ]
then
    echo "error: IMG_URL=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f $PATH2 ]
then
    echo "error: CKPT_URL=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python -u eval.py --device_id=$DEVICE_ID --img_url=$PATH1 --ckpt_url=$PATH2 &> log &
cd ..

@ -0,0 +1,65 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
then
    echo "Usage: bash run_standalone_train.sh [DATASET_PATH] [SAVE_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
echo $PATH1
PATH2=$(get_real_path $2)
echo $PATH2

if [ ! -d $PATH1 ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ ! -d $PATH2 ]
then
    echo "error: SAVE_PATH=$PATH2 is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python -u train.py --device_id=$DEVICE_ID --dataset_path=$PATH1 --save_path=$PATH2 &> log &
cd ..

@ -0,0 +1,243 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dcgan cell"""
import os
import threading
import time
import numpy as np
from mindspore import nn, ops, context
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.common.initializer import Initializer, _assignment
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import _save_graph, save_checkpoint
from mindspore.train.callback import Callback
from mindspore.train.callback._callback import set_cur_net
from mindspore.train.callback._checkpoint import _check_file_name_prefix, _cur_dir, CheckpointConfig, CheckpointManager, \
    _chg_ckpt_file_name_if_same_exist


class Reshape(nn.Cell):
    """Reshape the input tensor to a fixed shape."""
    def __init__(self, shape, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.shape = shape

    def construct(self, x):
        return ops.operations.Reshape()(x, self.shape)


class Normal(Initializer):
    """Initialize an array from a normal distribution with the given mean and sigma."""
    def __init__(self, mean=0.0, sigma=0.01):
        super(Normal, self).__init__()
        self.sigma = sigma
        self.mean = mean

    def _initialize(self, arr):
        # the seed is reset on every call, so arrays with the same shape,
        # mean and sigma receive identical initial values
        np.random.seed(999)
        arr_normal = np.random.normal(self.mean, self.sigma, arr.shape)
        _assignment(arr, arr_normal)


class ModelCheckpoint(Callback):
    """
    The checkpoint callback class.

    It is used together with the training process to save the model and network parameters after training.

    Args:
        prefix (str): The prefix name of checkpoint files. Default: "CKP".
        directory (str): The folder in which the checkpoint files will be saved. Default: None.
        config (CheckpointConfig): Checkpoint strategy configuration. Default: None.

    Raises:
        ValueError: If the prefix is invalid.
        TypeError: If the config is not CheckpointConfig type.
    """
    def __init__(self, prefix='CKP', directory=None, config=None):
        super(ModelCheckpoint, self).__init__()
        self._latest_ckpt_file_name = ""
        self._init_time = time.time()
        self._last_time = time.time()
        self._last_time_for_keep = time.time()
        self._last_triggered_step = 0
        if _check_file_name_prefix(prefix):
            self._prefix = prefix
        else:
            raise ValueError("Prefix {} for checkpoint file name invalid, "
                             "please check and correct it and then continue.".format(prefix))
        if directory is not None:
            self._directory = _make_directory(directory)
        else:
            self._directory = _cur_dir
        if config is None:
            self._config = CheckpointConfig()
        else:
            if not isinstance(config, CheckpointConfig):
                raise TypeError("config should be CheckpointConfig type.")
            self._config = config
        # get existing checkpoint files
        self._manager = CheckpointManager()
        self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
        self._graph_saved = False

    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        # save graph (only once)
        if not self._graph_saved:
            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
            _save_graph(cb_params.train_network, graph_file_name)
            self._graph_saved = True
        self.save_ckpt(cb_params)

    def end(self, run_context):
        """
        Save the last checkpoint after training finished.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        _to_save_last_ckpt = True
        self.save_ckpt(cb_params, _to_save_last_ckpt)
        thread_list = threading.enumerate()
        if len(thread_list) > 1:
            for thread in thread_list:
                if thread.getName() == "asyn_save_ckpt":
                    thread.join()
        from mindspore.parallel._cell_wrapper import destroy_allgather_cell
        destroy_allgather_cell()

    def _check_save_ckpt(self, cb_params, force_to_save):
        """Check whether save checkpoint files or not."""
        if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
            if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
                    or force_to_save is True:
                return True
        elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
            self._cur_time = time.time()
            if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
                self._last_time = self._cur_time
                return True
        return False

    def save_ckpt(self, cb_params, force_to_save=False):
        """Save checkpoint files."""
        if cb_params.cur_step_num == self._last_triggered_step:
            return
        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
        step_num_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
        if save_ckpt:
            cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                + str(step_num_in_epoch) + ".ckpt"
            if _is_role_pserver():
                cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
            # update checkpoint file list.
            self._manager.update_ckpoint_filelist(self._directory, self._prefix)
            # keep checkpoint files number equal max number.
            if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
                self._manager.remove_oldest_ckpoint_file()
            elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
                self._cur_time_for_keep = time.time()
                if (self._cur_time_for_keep - self._last_time_for_keep) \
                        < self._config.keep_checkpoint_per_n_minutes * 60:
                    self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
                                                               self._cur_time_for_keep)
            # generate the new checkpoint file and rename it.
            cur_file = os.path.join(self._directory, cur_ckpoint_file)
            self._last_time_for_keep = time.time()
            self._last_triggered_step = cb_params.cur_step_num
            if context.get_context("enable_ge"):
                set_cur_net(cb_params.train_network)
                cb_params.train_network.exec_checkpoint_graph()
            save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
                            self._config.async_save)
            self._latest_ckpt_file_name = cur_file

    @property
    def latest_ckpt_file_name(self):
        """Return the latest checkpoint path and file name."""
        return self._latest_ckpt_file_name


class WithLossCellD(nn.Cell):
    """class WithLossCellD"""
    def __init__(self, netD, netG, loss_fn):
        super(WithLossCellD, self).__init__(auto_prefix=True)
        self.netD = netD
        self.netG = netG
        self.loss_fn = loss_fn

    def construct(self, real_data, latent_code):
        """class WithLossCellD construct"""
        ones = ops.Ones()
        zeros = ops.Zeros()
        # real images are labelled 1
        out1 = self.netD(real_data)
        label1 = ones(out1.shape, mstype.float32)
        loss1 = self.loss_fn(out1, label1)
        # generated images are labelled 0; gradients do not flow into the generator
        fake_data = self.netG(latent_code)
        fake_data = F.stop_gradient(fake_data)
        out2 = self.netD(fake_data)
        label2 = zeros(out2.shape, mstype.float32)
        loss2 = self.loss_fn(out2, label2)
        return loss1 + loss2

    @property
    def backbone_network(self):
        """class WithLossCellD backbone_network"""
        return self.netD


class WithLossCellG(nn.Cell):
    """class WithLossCellG"""
    def __init__(self, netD, netG, loss_fn):
        super(WithLossCellG, self).__init__(auto_prefix=True)
        self.netD = netD
        self.netG = netG
        self.loss_fn = loss_fn

    def construct(self, latent_code):
        """class WithLossCellG construct"""
        ones = ops.Ones()
        # the generator is trained so that its images are classified as real (label 1)
        fake_data = self.netG(latent_code)
        out = self.netD(fake_data)
        label = ones(out.shape, mstype.float32)
        loss = self.loss_fn(out, label)
        return loss

    @property
    def backbone_network(self):
        """class WithLossCellG backbone_network"""
        return self.netG
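
The two loss cells above implement the usual GAN objectives with binary cross-entropy. The sketch below restates them in plain Python on scalar discriminator outputs. The function names are hypothetical, and the clamping constant `eps` is an assumption added here only to keep the logarithm finite; it does not describe how `nn.BCELoss` handles extreme inputs.

```python
import math


def bce(prediction, label, eps=1e-12):
    """Binary cross-entropy for one probability and one 0/1 label."""
    p = min(max(prediction, eps), 1.0 - eps)   # assumption: clamp to avoid log(0)
    return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))


def discriminator_loss(d_real, d_fake):
    """Real samples are labelled 1 and generated samples 0, as in WithLossCellD."""
    return bce(d_real, 1.0) + bce(d_fake, 0.0)


def generator_loss(d_fake):
    """The generator wants its samples to be labelled 1, as in WithLossCellG."""
    return bce(d_fake, 1.0)


# An undecided discriminator (output 0.5 everywhere) gives 2*ln(2) and ln(2).
print(discriminator_loss(0.5, 0.5))
print(generator_loss(0.5))
```

A discriminator loss well below 2 ln 2 together with a rising generator loss, as in the early steps of the README's training log, indicates that the discriminator is currently winning.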

@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
from easydict import EasyDict as edict

dcgan_imagenet_cfg = edict({
    'num_classes': 1000,
    'epoch_size': 20,
    'batch_size': 128,
    'latent_size': 100,
    'feature_size': 64,
    'channel_size': 3,
    'image_height': 32,
    'image_width': 32,
    'learning_rate': 0.0002,
    'beta1': 0.5
})

@ -0,0 +1,86 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dcgan dataset"""
import os
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from src.config import dcgan_imagenet_cfg


def create_dataset_imagenet(dataset_path, num_parallel_workers=None):
    """
    create a train or eval imagenet2012 dataset for dcgan

    Args:
        dataset_path(string): the path of dataset.
        num_parallel_workers(int): number of workers used to read and transform the data. Default: None.

    Returns:
        dataset
    """
    device_num, rank_id = _get_rank_info()
    if device_num == 1:
        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers)
    else:
        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers,
                                         num_shards=device_num, shard_id=rank_id)
    assert dcgan_imagenet_cfg.image_height == dcgan_imagenet_cfg.image_width, "image_height not equal image_width"
    image_size = dcgan_imagenet_cfg.image_height
    # define map operations
    transform_img = [
        vision.Decode(),
        vision.Resize(image_size),
        vision.CenterCrop(image_size),
        vision.HWC2CHW()
    ]
    data_set = data_set.map(input_columns="image", num_parallel_workers=num_parallel_workers, operations=transform_img,
                            output_columns="image")
    # rescale pixel values from [0, 255] to [-1, 1]
    data_set = data_set.map(input_columns="image", num_parallel_workers=num_parallel_workers,
                            operations=lambda x: ((x - 127.5) / 127.5).astype("float32"))
    # attach a random latent vector to every image
    data_set = data_set.map(
        input_columns="image",
        operations=lambda x: (
            x,
            np.random.normal(size=(dcgan_imagenet_cfg.latent_size, 1, 1)).astype("float32")
        ),
        output_columns=["image", "latent_code"],
        column_order=["image", "latent_code"],
        num_parallel_workers=num_parallel_workers
    )
    data_set = data_set.batch(dcgan_imagenet_cfg.batch_size)
    return data_set


def _get_rank_info():
    """
    get rank size and rank id
    """
    rank_size = int(os.environ.get("RANK_SIZE", 1))
    if rank_size > 1:
        from mindspore.communication.management import get_rank, get_group_size
        rank_size = get_group_size()
        rank_id = get_rank()
    else:
        rank_size = rank_id = None
    return rank_size, rank_id
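
`create_dataset_imagenet` calls `data_set.batch(...)` without `drop_remainder`, so a final partial batch is kept. The sketch below, with a hypothetical helper name, assumes the standard ImageNet-1k training split of 1,281,167 images and shows that a batch size of 128 yields the 10010 steps per epoch seen in the README's standalone training log.

```python
import math


def steps_per_epoch(num_images, batch_size, drop_remainder=False):
    """Number of batches produced from num_images samples in one epoch."""
    if drop_remainder:
        return num_images // batch_size
    return math.ceil(num_images / batch_size)


print(steps_per_epoch(1281167, 128))                        # 10010
print(steps_per_epoch(1281167, 128, drop_remainder=True))   # 10009
```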

@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dcgan"""
from mindspore import nn


class DCGAN(nn.Cell):
    """dcgan class"""
    def __init__(self, myTrainOneStepCellForD, myTrainOneStepCellForG):
        super(DCGAN, self).__init__(auto_prefix=True)
        self.myTrainOneStepCellForD = myTrainOneStepCellForD
        self.myTrainOneStepCellForG = myTrainOneStepCellForG

    def construct(self, real_data, latent_code):
        # one discriminator update followed by one generator update
        output_D = self.myTrainOneStepCellForD(real_data, latent_code).view(-1)
        netD_loss = output_D.mean()
        output_G = self.myTrainOneStepCellForG(latent_code).view(-1)
        netG_loss = output_G.mean()
        return netD_loss, netG_loss


@ -0,0 +1,58 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dcgan discriminator"""
from mindspore import nn

from src.cell import Normal
from src.config import dcgan_imagenet_cfg as cfg


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="pad"):
    weight_init = Normal(mean=0, sigma=0.02)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight_init, has_bias=False, pad_mode=pad_mode)


def bm(num_features):
    gamma_init = Normal(mean=1, sigma=0.02)
    return nn.BatchNorm2d(num_features=num_features, gamma_init=gamma_init)


class Discriminator(nn.Cell):
    """
    DCGAN Discriminator
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.discriminator = nn.SequentialCell()
        # input is 3 x 32 x 32
        self.discriminator.append(conv(cfg.channel_size, cfg.feature_size * 2, 4, 2, 1))
        self.discriminator.append(nn.LeakyReLU(0.2))
        # state size. 128 x 16 x 16
        self.discriminator.append(conv(cfg.feature_size * 2, cfg.feature_size * 4, 4, 2, 1))
        self.discriminator.append(bm(cfg.feature_size * 4))
        self.discriminator.append(nn.LeakyReLU(0.2))
        # state size. 256 x 8 x 8
        self.discriminator.append(conv(cfg.feature_size * 4, cfg.feature_size * 8, 4, 2, 1))
        self.discriminator.append(bm(cfg.feature_size * 8))
        self.discriminator.append(nn.LeakyReLU(0.2))
        # state size. 512 x 4 x 4
        self.discriminator.append(conv(cfg.feature_size * 8, 1, 4, 1))
        self.discriminator.append(nn.Sigmoid())

    def construct(self, x):
        return self.discriminator(x)
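
The "state size" comments in the discriminator can be checked with the usual output-size formula for a convolution with explicit padding and no dilation, `out = floor((in + 2 * padding - kernel) / stride) + 1`. The sketch below is an illustration only: `conv_out_size` is a hypothetical helper, and the 32 x 32 input is taken from the comment in the code rather than from the config.

```python
def conv_out_size(size, kernel_size, stride, padding):
    """Spatial output size of a strided convolution (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# (kernel_size, stride, padding) of the four discriminator convolutions.
disc_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 32  # input height/width stated in the code comment
disc_sizes = [size]
for kernel_size, stride, padding in disc_layers:
    size = conv_out_size(size, kernel_size, stride, padding)
    disc_sizes.append(size)

print(disc_sizes)  # [32, 16, 8, 4, 1]
```

Each stride-2 layer halves the resolution, and the final 4 x 4 kernel with no padding collapses the map to a single score per image.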


@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dcgan generator"""
from mindspore import nn

from src.cell import Normal
from src.config import dcgan_imagenet_cfg as cfg


def convt(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="pad"):
    weight_init = Normal(mean=0, sigma=0.02)
    return nn.Conv2dTranspose(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride, padding=padding,
                              weight_init=weight_init, has_bias=False, pad_mode=pad_mode)


def bm(num_features):
    gamma_init = Normal(mean=1, sigma=0.02)
    return nn.BatchNorm2d(num_features=num_features, gamma_init=gamma_init)


class Generator(nn.Cell):
    """
    DCGAN Generator
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.generator = nn.SequentialCell()
        # input is Z, going into a convolution
        self.generator.append(convt(cfg.latent_size, cfg.feature_size * 8, 4, 1, 0))
        self.generator.append(bm(cfg.feature_size * 8))
        self.generator.append(nn.ReLU())
        # state size. 512 x 4 x 4
        self.generator.append(convt(cfg.feature_size * 8, cfg.feature_size * 4, 4, 2, 1))
        self.generator.append(bm(cfg.feature_size * 4))
        self.generator.append(nn.ReLU())
        # state size. 256 x 8 x 8
        self.generator.append(convt(cfg.feature_size * 4, cfg.feature_size * 2, 4, 2, 1))
        self.generator.append(bm(cfg.feature_size * 2))
        self.generator.append(nn.ReLU())
        # state size. 128 x 16 x 16
        self.generator.append(convt(cfg.feature_size * 2, cfg.channel_size, 4, 2, 1))
        self.generator.append(nn.Tanh())
        # state size. 3 x 32 x 32

    def construct(self, x):
        return self.generator(x)
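
The generator mirrors the discriminator: each transposed convolution grows the feature map according to `out = (in - 1) * stride - 2 * padding + kernel` (no dilation, no output padding). The sketch below is an illustration under that assumption; `convt_out_size` is a hypothetical helper, and the 1 x 1 starting size follows from the latent code shape `(latent_size, 1, 1)` used in the dataset pipeline.

```python
def convt_out_size(size, kernel_size, stride, padding):
    """Spatial output size of a transposed convolution (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel_size


# (kernel_size, stride, padding) of the four generator transposed convolutions.
gen_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the latent code is latent_size x 1 x 1
gen_sizes = [size]
for kernel_size, stride, padding in gen_layers:
    size = convt_out_size(size, kernel_size, stride, padding)
    gen_sizes.append(size)

print(gen_sizes)  # [1, 4, 8, 16, 32]
```

The result agrees with the "state size" comments: 4 x 4 after the first layer, then a doubling at every stride-2 layer up to 32 x 32.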


@ -0,0 +1,118 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train DCGAN and get checkpoint files."""
import argparse
import ast
import os
import datetime
import numpy as np
from mindspore import context
from mindspore import nn, Tensor
from mindspore.train.callback import CheckpointConfig, _InternalCallbackParam
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_group_size
from src.dataset import create_dataset_imagenet
from src.config import dcgan_imagenet_cfg as cfg
from src.generator import Generator
from src.discriminator import Discriminator
from src.cell import WithLossCellD, WithLossCellG, ModelCheckpoint
from src.dcgan import DCGAN
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MindSpore dcgan training')
    parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
                        help="Run distribute, default is false.")
    parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend (Default: 0)')
    parser.add_argument('--dataset_path', type=str, default=None, help='dataset path')
    parser.add_argument('--save_path', type=str, default=None, help='checkpoint save path')
    args = parser.parse_args()

    if args.run_distribute:
        device_id = args.device_id
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
        context.set_context(device_id=device_id)
        init()
        device_num = get_group_size()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    else:
        device_id = args.device_id
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
        context.set_context(device_id=device_id)

    # Load Dataset
    ds = create_dataset_imagenet(os.path.join(args.dataset_path), num_parallel_workers=2)

    steps_per_epoch = ds.get_dataset_size()

    # Define Network
    netD = Discriminator()
    netG = Generator()

    criterion = nn.BCELoss(reduction='mean')

    netD_with_criterion = WithLossCellD(netD, netG, criterion)
    netG_with_criterion = WithLossCellG(netD, netG, criterion)

    optimizerD = nn.Adam(netD.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
    optimizerG = nn.Adam(netG.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)

    myTrainOneStepCellForD = nn.TrainOneStepCell(netD_with_criterion, optimizerD)
    myTrainOneStepCellForG = nn.TrainOneStepCell(netG_with_criterion, optimizerG)

    dcgan = DCGAN(myTrainOneStepCellForD, myTrainOneStepCellForG)
    dcgan.set_train()

    # checkpoint save
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=cfg.epoch_size)
    ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.save_path, prefix='dcgan')

    cb_params = _InternalCallbackParam()
    cb_params.train_network = dcgan
    cb_params.batch_num = steps_per_epoch
    cb_params.epoch_num = cfg.epoch_size
    # For each epoch
    cb_params.cur_epoch_num = 0
    cb_params.cur_step_num = 0

    np.random.seed(1)
    fixed_noise = Tensor(np.random.normal(size=(16, cfg.latent_size, 1, 1)).astype("float32"))

    data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=cfg.epoch_size)
    G_losses = []
    D_losses = []
    # Start Training Loop
    print("Starting Training Loop...")
    for epoch in range(cfg.epoch_size):
        # For each batch in the dataloader
        for i, data in enumerate(data_loader):
            real_data = Tensor(data['image'])
            latent_code = Tensor(data["latent_code"])
            netD_loss, netG_loss = dcgan(real_data, latent_code)
            if i % 50 == 0:
                print("Date time: ", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "\tepoch: ", epoch, "/",
                      cfg.epoch_size, "\tstep: ", i, "/", steps_per_epoch, "\tDloss: ", netD_loss, "\tGloss: ",
                      netG_loss)
            D_losses.append(netD_loss.asnumpy())
            G_losses.append(netG_loss.asnumpy())
            cb_params.cur_step_num = cb_params.cur_step_num + 1
        cb_params.cur_epoch_num = cb_params.cur_epoch_num + 1
        print("================saving model===================")
        if args.device_id == 0 or not args.run_distribute:
            ckpt_cb.save_ckpt(cb_params, True)
        print("================success================")
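
The training script wraps both networks with `nn.BCELoss`, but the bodies of `WithLossCellD` and `WithLossCellG` are not shown in this part of the diff. The NumPy sketch below therefore assumes the standard DCGAN formulation (real images labelled 1, generated images labelled 0, and the generator trained to make the discriminator output 1); the function names are hypothetical and may not match what `src/cell.py` actually does.

```python
import numpy as np


def bce(pred, target, eps=1e-12):
    """Mean binary cross-entropy between probabilities and 0/1 targets."""
    pred = np.clip(pred, eps, 1.0 - eps)
    return float(np.mean(-(target * np.log(pred) + (1.0 - target) * np.log(1.0 - pred))))


def discriminator_loss(d_real, d_fake):
    """Assumed D objective: push D(x) towards 1 and D(G(z)) towards 0."""
    return bce(d_real, np.ones_like(d_real)) + bce(d_fake, np.zeros_like(d_fake))


def generator_loss(d_fake):
    """Assumed G objective: push D(G(z)) towards 1."""
    return bce(d_fake, np.ones_like(d_fake))


# An undecided discriminator outputs 0.5 for every sample.
d_real = np.full(4, 0.5)
d_fake = np.full(4, 0.5)
loss_d = discriminator_loss(d_real, d_fake)
loss_g = generator_loss(d_fake)
print(loss_d, loss_g)
```

With every discriminator output at 0.5, the discriminator loss equals `2 * ln 2` and the generator loss equals `ln 2`, which is a handy reference point when reading the `Dloss` and `Gloss` values printed by the loop.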