forked from mindspore-Ecosystem/mindspore
commit
26ecdc1fc4
|
@ -0,0 +1,201 @@
|
|||
# Contents
|
||||
|
||||
- [DCGAN Description](#DCGAN-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [DCGAN Description](#contents)
|
||||
|
||||
The deep convolutional generative adversarial networks (DCGANs) first introduced CNN into the GAN structure, and the strong feature extraction ability of convolution layer was used to improve the generation effect of GAN.
|
||||
|
||||
[Paper](https://arxiv.org/pdf/1511.06434.pdf): Radford A, Metz L, Chintala S. Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks[J]. Computer ence, 2015.
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
|
||||
Architecture guidelines for stable Deep Convolutional GANs
|
||||
|
||||
- Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator).
|
||||
- Use batchnorm in both the generator and the discriminator.
|
||||
- Remove fully connected hidden layers for deeper architectures.
|
||||
- Use ReLU activation in generator for all layers except for the output, which uses Tanh.
|
||||
- Use LeakyReLU activation in the discriminator for all layers.
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Train DCGAN Dataset used: [Imagenet-1k](<http://www.image-net.org/index>)
|
||||
|
||||
- Dataset size: ~125G, 1.2W colorful images in 1000 classes
|
||||
- Train: 120G, 1.2W images
|
||||
- Test: 5G, 50000 images
|
||||
- Data format: RGB images.
|
||||
- Note: Data will be processed in src/dataset.py
|
||||
|
||||
```path
|
||||
|
||||
└─imagenet_original
|
||||
└─train
|
||||
```
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware Ascend
|
||||
- Prepare hardware environment with Ascend processor.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
```shell
|
||||
.
|
||||
└─dcgan
|
||||
├─README.md # README
|
||||
├─scripts # shell script
|
||||
├─run_standalone_train.sh # training in standalone mode(1pcs)
|
||||
├─run_distribute_train.sh # training in parallel mode(8 pcs)
|
||||
└─run_eval.sh # evaluation
|
||||
├─ src
|
||||
├─dataset.py // dataset create
|
||||
├─cell.py // network definition
|
||||
├─dcgan.py // dcgan structure
|
||||
├─discriminator.py // discriminator structure
|
||||
├─generator.py // generator structure
|
||||
├─config.py // config
|
||||
├─ train.py // train dcgan
|
||||
├─ eval.py // eval dcgan
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
### [Training Script Parameters](#contents)
|
||||
|
||||
```shell
|
||||
# distributed training
|
||||
Usage: bash run_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [SAVE_PATH]
|
||||
|
||||
# standalone training
|
||||
Usage: bash run_standalone_train.sh [DATASET_PATH] [SAVE_PATH]
|
||||
```
|
||||
|
||||
### [Parameters Configuration](#contents)
|
||||
|
||||
```txt
|
||||
"img_width": 32, # width of the input images
|
||||
"img_height": 32, # height of the input images
|
||||
'num_classes': 1000,
|
||||
'epoch_size': 20,
|
||||
'batch_size': 128,
|
||||
'latent_size': 100,
|
||||
'feature_size': 64,
|
||||
'channel_size': 3,
|
||||
'image_height': 32,
|
||||
'image_width': 32,
|
||||
'learning_rate': 0.0002,
|
||||
'beta1': 0.5
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
- Set options in `config.py`, including learning rate, output filename and network hyperparameters. Click [here](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about dataset.
|
||||
|
||||
### [Training](#content)
|
||||
|
||||
- Run `run_standalone_train.sh` for non-distributed training of DCGAN model.
|
||||
|
||||
```bash
|
||||
# standalone training
|
||||
run_standalone_train.sh [DATASET_PATH] [SAVE_PATH]
|
||||
```
|
||||
|
||||
### [Distributed Training](#content)
|
||||
|
||||
- Run `run_distribute_train.sh` for distributed training of DCGAN model.
|
||||
|
||||
```bash
|
||||
run_distribute.sh [RANK_TABLE_FILE] [DATASET_PATH] [SAVE_PATH]
|
||||
```
|
||||
|
||||
- Notes
|
||||
1. hccl.json which is specified by RANK_TABLE_FILE is needed when you are running a distribute task. You can generate it by using the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
|
||||
|
||||
### [Training Result](#content)
|
||||
|
||||
Training result will be stored in save_path. You can find checkpoint file.
|
||||
|
||||
```bash
|
||||
# standalone training result(1p)
|
||||
Date time: 2021-04-13 13:55:39 epoch: 0 / 20 step: 0 / 10010 Dloss: 2.2297878 Gloss: 1.1530013
|
||||
Date time: 2021-04-13 13:56:01 epoch: 0 / 20 step: 50 / 10010 Dloss: 0.21959287 Gloss: 20.064941
|
||||
Date time: 2021-04-13 13:56:22 epoch: 0 / 20 step: 100 / 10010 Dloss: 0.18872623 Gloss: 5.872738
|
||||
Date time: 2021-04-13 13:56:44 epoch: 0 / 20 step: 150 / 10010 Dloss: 0.53905165 Gloss: 4.477289
|
||||
Date time: 2021-04-13 13:57:07 epoch: 0 / 20 step: 200 / 10010 Dloss: 0.47870708 Gloss: 2.2019134
|
||||
Date time: 2021-04-13 13:57:28 epoch: 0 / 20 step: 250 / 10010 Dloss: 0.3929835 Gloss: 1.8170083
|
||||
```
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### [Evaluation](#content)
|
||||
|
||||
- Run `run_eval.sh` for evaluation.
|
||||
|
||||
```bash
|
||||
# infer
|
||||
sh run_eval.sh [IMG_URL] [CKPT_URL]
|
||||
```
|
||||
|
||||
### [Evaluation result](#content)
|
||||
|
||||
Evaluation result will be stored in the img_url path. Under this, you can find generator result in generate.png.
|
||||
|
||||
## Model Export
|
||||
|
||||
```shell
|
||||
python export.py --ckpt_file [CKPT_PATH] --device_target [DEVICE_TARGET] --file_format[EXPORT_FORMAT]
|
||||
```
|
||||
|
||||
`EXPORT_FORMAT` should be "MINDIR"
|
||||
|
||||
# Model Description
|
||||
|
||||
## Performance
|
||||
|
||||
### Evaluation Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| Model Version | V1 |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
|
||||
| uploaded Date | 16/04/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.1.1 |
|
||||
| Dataset | ImageNet2012 |
|
||||
| Training Parameters | epoch=20, batch_size = 128 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | BCELoss |
|
||||
| Output | predict class |
|
||||
| Loss | 10.9852 |
|
||||
| Speed | 1pc: 420 ms/step; 8pcs: 143 ms/step |
|
||||
| Total time | 1pc: 24.32 hours |
|
||||
| Checkpoint for Fine tuning | 79.05M(.ckpt file) |
|
||||
| Scripts | [dcgan script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/dcgan) |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
We use random seed in train.py and cell.py for weight initialization.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""dcgan eval"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor, nn, load_checkpoint
|
||||
|
||||
from src.config import dcgan_imagenet_cfg as cfg
|
||||
from src.generator import Generator
|
||||
from src.discriminator import Discriminator
|
||||
from src.cell import WithLossCellD, WithLossCellG
|
||||
from src.dcgan import DCGAN
|
||||
|
||||
|
||||
def save_imgs(gen_imgs, img_url):
|
||||
"""save_imgs function"""
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
for i in range(gen_imgs.shape[0]):
|
||||
plt.subplot(4, 4, i + 1)
|
||||
gen_imgs[i] = gen_imgs[i] * 127.5 + 127.5
|
||||
perm = (1, 2, 0)
|
||||
show_imgs = np.transpose(gen_imgs[i], perm)
|
||||
sdf = show_imgs.astype(int)
|
||||
plt.imshow(sdf)
|
||||
plt.axis("off")
|
||||
plt.savefig(img_url + "/generate.png")
|
||||
|
||||
|
||||
def load_dcgan(ckpt_url):
|
||||
"""load_dcgan function"""
|
||||
netD = Discriminator()
|
||||
netG = Generator()
|
||||
|
||||
criterion = nn.BCELoss(reduction='mean')
|
||||
|
||||
netD_with_criterion = WithLossCellD(netD, netG, criterion)
|
||||
netG_with_criterion = WithLossCellG(netD, netG, criterion)
|
||||
|
||||
optimizerD = nn.Adam(netD.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
|
||||
optimizerG = nn.Adam(netG.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
|
||||
|
||||
myTrainOneStepCellForD = nn.TrainOneStepCell(netD_with_criterion, optimizerD)
|
||||
myTrainOneStepCellForG = nn.TrainOneStepCell(netG_with_criterion, optimizerG)
|
||||
|
||||
dcgan = DCGAN(myTrainOneStepCellForD, myTrainOneStepCellForG)
|
||||
load_checkpoint(ckpt_url, dcgan)
|
||||
netG_trained = dcgan.myTrainOneStepCellForG.network.netG
|
||||
return netG_trained
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='MindSpore dcgan training')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend (Default: 0)')
|
||||
parser.add_argument('--img_url', type=str, default=None, help='img save path')
|
||||
parser.add_argument('--ckpt_url', type=str, default=None, help='checkpoint load path')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
fixed_noise = Tensor(np.random.normal(size=(16, cfg.latent_size, 1, 1)).astype("float32"))
|
||||
|
||||
net_G = load_dcgan(args.ckpt_url)
|
||||
fake = net_G(fixed_noise)
|
||||
print("================saving images================")
|
||||
save_imgs(fake.asnumpy(), args.img_url)
|
||||
print("================success================")
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export checkpoint file into air, onnx, mindir models"""
|
||||
import argparse
|
||||
import ast
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import nn
|
||||
|
||||
from src.cell import WithLossCellD, WithLossCellG
|
||||
from src.dcgan import DCGAN
|
||||
from src.discriminator import Discriminator
|
||||
from src.generator import Generator
|
||||
from src.config import dcgan_imagenet_cfg as cfg
|
||||
|
||||
parser = argparse.ArgumentParser(description='ntsnet export')
|
||||
parser.add_argument("--run_modelart", type=ast.literal_eval, default=False, help="Run on modelArt, default is false.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=128, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file name.")
|
||||
parser.add_argument('--data_url', default=None, help='Directory contains CUB_200_2011 dataset.')
|
||||
parser.add_argument('--train_url', default=None, help='Directory contains checkpoint file')
|
||||
parser.add_argument("--file_name", type=str, default="ntsnet", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, default="MINDIR", help="file format")
|
||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||
choices=['Ascend', 'GPU', 'CPU'], help='device target (default: Ascend)')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
if args.device_target == "Ascend":
|
||||
context.set_context(device_id=args.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
netD = Discriminator()
|
||||
netG = Generator()
|
||||
|
||||
criterion = nn.BCELoss(reduction='mean')
|
||||
|
||||
netD_with_criterion = WithLossCellD(netD, netG, criterion)
|
||||
netG_with_criterion = WithLossCellG(netD, netG, criterion)
|
||||
|
||||
optimizerD = nn.Adam(netD.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
|
||||
optimizerG = nn.Adam(netG.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
|
||||
|
||||
myTrainOneStepCellForD = nn.TrainOneStepCell(netD_with_criterion, optimizerD)
|
||||
myTrainOneStepCellForG = nn.TrainOneStepCell(netG_with_criterion, optimizerG)
|
||||
|
||||
net = DCGAN(myTrainOneStepCellForD, myTrainOneStepCellForG)
|
||||
param_dict = load_checkpoint(os.path.join(args.train_url, args.ckpt_file))
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
# inputs = Tensor(np.random.rand(args.batch_size, 3, 448, 448), mstype.float32)
|
||||
real_data = Tensor(np.random.rand(args.batch_size, 3, 32, 32), mstype.float32)
|
||||
latent_code = Tensor(np.random.rand(args.batch_size, 100, 1, 1), mstype.float32)
|
||||
inputs = [real_data, latent_code]
|
||||
export(net, *inputs, file_name=args.file_name, file_format=args.file_format)
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from mindspore import nn
|
||||
|
||||
from src.cell import WithLossCellD, WithLossCellG
|
||||
from src.dcgan import DCGAN
|
||||
from src.discriminator import Discriminator
|
||||
from src.generator import Generator
|
||||
from src.config import dcgan_imagenet_cfg as cfg
|
||||
|
||||
|
||||
def create_network(name):
|
||||
"""create_network function"""
|
||||
if name == "dcgan":
|
||||
netD = Discriminator()
|
||||
netG = Generator()
|
||||
|
||||
criterion = nn.BCELoss(reduction='mean')
|
||||
|
||||
netD_with_criterion = WithLossCellD(netD, netG, criterion)
|
||||
netG_with_criterion = WithLossCellG(netD, netG, criterion)
|
||||
|
||||
optimizerD = nn.Adam(netD.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
|
||||
optimizerG = nn.Adam(netG.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
|
||||
|
||||
myTrainOneStepCellForD = nn.TrainOneStepCell(netD_with_criterion, optimizerD)
|
||||
myTrainOneStepCellForG = nn.TrainOneStepCell(netG_with_criterion, optimizerG)
|
||||
|
||||
dcgan = DCGAN(myTrainOneStepCellForD, myTrainOneStepCellForG)
|
||||
return dcgan
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
@ -0,0 +1,90 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [SAVE_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
PATH3=$(get_real_path $3)
|
||||
|
||||
echo $PATH1
|
||||
echo $PATH2
|
||||
echo $PATH3
|
||||
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH2 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$PATH2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH3 ]
|
||||
then
|
||||
echo "error: SAVE_PATH=$PATH3 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export HCCL_CONNECT_TIMEOUT=600
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
echo 3 > /proc/sys/vm/drop_caches
|
||||
|
||||
cpus=`cat /proc/cpuinfo| grep "processor"| wc -l`
|
||||
avg=`expr $cpus \/ $DEVICE_NUM`
|
||||
gap=`expr $avg \- 1`
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
start=`expr $i \* $avg`
|
||||
end=`expr $start \+ $gap`
|
||||
cmdopt=$start"-"$end
|
||||
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
taskset -c $cmdopt python -u train.py --device_id=$i --run_distribute=True \
|
||||
--dataset_path=$PATH2 --save_path=$PATH3 &> log &
|
||||
cd ..
|
||||
done
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [IMG_URL] [CKPT_URL]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: IMG_URL=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PATH2 ]
|
||||
then
|
||||
echo "error: CKPT_URL=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python -u eval.py --device_id=$DEVICE_ID --img_url=$PATH1 --ckpt_url=$PATH2 &> log &
|
||||
cd ..
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: bash run_standalone_train.sh [DATASET_PATH] [SAVE_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
PATH2=$(get_real_path $2)
|
||||
echo $PATH2
|
||||
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH2 ]
|
||||
then
|
||||
echo "error: SAVE_PATH=$PATH2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python -u train.py --device_id=$DEVICE_ID --dataset_path=$PATH1 --save_path=$PATH2 &> log &
|
||||
cd ..
|
|
@ -0,0 +1,243 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""dcgan cell"""
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import numpy as np
|
||||
from mindspore import nn, ops, context
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import Initializer, _assignment
|
||||
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
|
||||
from mindspore.train._utils import _make_directory
|
||||
from mindspore.train.serialization import _save_graph, save_checkpoint
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.train.callback._callback import set_cur_net
|
||||
from mindspore.train.callback._checkpoint import _check_file_name_prefix, _cur_dir, CheckpointConfig, CheckpointManager, \
|
||||
_chg_ckpt_file_name_if_same_exist
|
||||
|
||||
|
||||
class Reshape(nn.Cell):
|
||||
def __init__(self, shape, auto_prefix=True):
|
||||
super().__init__(auto_prefix=auto_prefix)
|
||||
self.shape = shape
|
||||
|
||||
def construct(self, x):
|
||||
return ops.operations.Reshape()(x, self.shape)
|
||||
|
||||
|
||||
class Normal(Initializer):
|
||||
def __init__(self, mean=0.0, sigma=0.01):
|
||||
super(Normal, self).__init__()
|
||||
self.sigma = sigma
|
||||
self.mean = mean
|
||||
|
||||
def _initialize(self, arr):
|
||||
np.random.seed(999)
|
||||
arr_normal = np.random.normal(self.mean, self.sigma, arr.shape)
|
||||
_assignment(arr, arr_normal)
|
||||
|
||||
|
||||
class ModelCheckpoint(Callback):
|
||||
"""
|
||||
The checkpoint callback class.
|
||||
|
||||
It is called to combine with train process and save the model and network parameters after traning.
|
||||
|
||||
Args:
|
||||
prefix (str): The prefix name of checkpoint files. Default: "CKP".
|
||||
directory (str): The path of the folder which will be saved in the checkpoint file. Default: None.
|
||||
config (CheckpointConfig): Checkpoint strategy configuration. Default: None.
|
||||
|
||||
Raises:
|
||||
ValueError: If the prefix is invalid.
|
||||
TypeError: If the config is not CheckpointConfig type.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix='CKP', directory=None, config=None):
|
||||
super(ModelCheckpoint, self).__init__()
|
||||
self._latest_ckpt_file_name = ""
|
||||
self._init_time = time.time()
|
||||
self._last_time = time.time()
|
||||
self._last_time_for_keep = time.time()
|
||||
self._last_triggered_step = 0
|
||||
|
||||
if _check_file_name_prefix(prefix):
|
||||
self._prefix = prefix
|
||||
else:
|
||||
raise ValueError("Prefix {} for checkpoint file name invalid, "
|
||||
"please check and correct it and then continue.".format(prefix))
|
||||
|
||||
if directory is not None:
|
||||
self._directory = _make_directory(directory)
|
||||
else:
|
||||
self._directory = _cur_dir
|
||||
|
||||
if config is None:
|
||||
self._config = CheckpointConfig()
|
||||
else:
|
||||
if not isinstance(config, CheckpointConfig):
|
||||
raise TypeError("config should be CheckpointConfig type.")
|
||||
self._config = config
|
||||
|
||||
# get existing checkpoint files
|
||||
self._manager = CheckpointManager()
|
||||
self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
|
||||
self._graph_saved = False
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
Save the checkpoint at the end of step.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the train running.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
# save graph (only once)
|
||||
if not self._graph_saved:
|
||||
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
|
||||
_save_graph(cb_params.train_network, graph_file_name)
|
||||
self._graph_saved = True
|
||||
self.save_ckpt(cb_params)
|
||||
|
||||
def end(self, run_context):
|
||||
"""
|
||||
Save the last checkpoint after training finished.
|
||||
|
||||
Args:
|
||||
run_context (RunContext): Context of the train running.
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
_to_save_last_ckpt = True
|
||||
self.save_ckpt(cb_params, _to_save_last_ckpt)
|
||||
|
||||
thread_list = threading.enumerate()
|
||||
if len(thread_list) > 1:
|
||||
for thread in thread_list:
|
||||
if thread.getName() == "asyn_save_ckpt":
|
||||
thread.join()
|
||||
|
||||
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
|
||||
destroy_allgather_cell()
|
||||
|
||||
def _check_save_ckpt(self, cb_params, force_to_save):
|
||||
"""Check whether save checkpoint files or not."""
|
||||
if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
|
||||
if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
|
||||
or force_to_save is True:
|
||||
return True
|
||||
elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
|
||||
self._cur_time = time.time()
|
||||
if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
|
||||
self._last_time = self._cur_time
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def save_ckpt(self, cb_params, force_to_save=False):
|
||||
"""Save checkpoint files."""
|
||||
if cb_params.cur_step_num == self._last_triggered_step:
|
||||
return
|
||||
|
||||
save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
|
||||
step_num_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
|
||||
if save_ckpt:
|
||||
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
|
||||
+ str(step_num_in_epoch) + ".ckpt"
|
||||
if _is_role_pserver():
|
||||
cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
|
||||
# update checkpoint file list.
|
||||
self._manager.update_ckpoint_filelist(self._directory, self._prefix)
|
||||
# keep checkpoint files number equal max number.
|
||||
if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
|
||||
self._manager.remove_oldest_ckpoint_file()
|
||||
elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
|
||||
self._cur_time_for_keep = time.time()
|
||||
if (self._cur_time_for_keep - self._last_time_for_keep) \
|
||||
< self._config.keep_checkpoint_per_n_minutes * 60:
|
||||
self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
|
||||
self._cur_time_for_keep)
|
||||
|
||||
# generate the new checkpoint file and rename it.
|
||||
cur_file = os.path.join(self._directory, cur_ckpoint_file)
|
||||
self._last_time_for_keep = time.time()
|
||||
self._last_triggered_step = cb_params.cur_step_num
|
||||
|
||||
if context.get_context("enable_ge"):
|
||||
set_cur_net(cb_params.train_network)
|
||||
cb_params.train_network.exec_checkpoint_graph()
|
||||
|
||||
save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
|
||||
self._config.async_save)
|
||||
|
||||
self._latest_ckpt_file_name = cur_file
|
||||
|
||||
@property
|
||||
def latest_ckpt_file_name(self):
|
||||
"""Return the latest checkpoint path and file name."""
|
||||
return self._latest_ckpt_file_name
|
||||
|
||||
|
||||
class WithLossCellD(nn.Cell):
|
||||
"""class WithLossCellD"""
|
||||
def __init__(self, netD, netG, loss_fn):
|
||||
super(WithLossCellD, self).__init__(auto_prefix=True)
|
||||
self.netD = netD
|
||||
self.netG = netG
|
||||
self.loss_fn = loss_fn
|
||||
|
||||
def construct(self, real_data, latent_code):
|
||||
"""class WithLossCellD construct"""
|
||||
ones = ops.Ones()
|
||||
zeros = ops.Zeros()
|
||||
|
||||
out1 = self.netD(real_data)
|
||||
label1 = ones(out1.shape, mstype.float32)
|
||||
loss1 = self.loss_fn(out1, label1)
|
||||
|
||||
fake_data = self.netG(latent_code)
|
||||
fake_data = F.stop_gradient(fake_data)
|
||||
out2 = self.netD(fake_data)
|
||||
label2 = zeros(out2.shape, mstype.float32)
|
||||
loss2 = self.loss_fn(out2, label2)
|
||||
return loss1 + loss2
|
||||
|
||||
@property
|
||||
def backbone_network(self):
|
||||
"""class WithLossCellD backbone_network"""
|
||||
return self.netD
|
||||
|
||||
|
||||
class WithLossCellG(nn.Cell):
|
||||
"""class WithLossCellG"""
|
||||
def __init__(self, netD, netG, loss_fn):
|
||||
super(WithLossCellG, self).__init__(auto_prefix=True)
|
||||
self.netD = netD
|
||||
self.netG = netG
|
||||
self.loss_fn = loss_fn
|
||||
|
||||
def construct(self, latent_code):
|
||||
ones = ops.Ones()
|
||||
fake_data = self.netG(latent_code)
|
||||
out = self.netD(fake_data)
|
||||
label = ones(out.shape, mstype.float32)
|
||||
loss = self.loss_fn(out, label)
|
||||
return loss
|
||||
|
||||
@property
|
||||
def backbone_network(self):
|
||||
return self.netG
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py
|
||||
"""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
dcgan_imagenet_cfg = edict({
|
||||
'num_classes': 1000,
|
||||
'epoch_size': 20,
|
||||
'batch_size': 128,
|
||||
'latent_size': 100,
|
||||
'feature_size': 64,
|
||||
'channel_size': 3,
|
||||
'image_height': 32,
|
||||
'image_width': 32,
|
||||
'learning_rate': 0.0002,
|
||||
'beta1': 0.5
|
||||
})
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""dcgan dataset"""
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as vision
|
||||
from src.config import dcgan_imagenet_cfg
|
||||
|
||||
|
||||
def create_dataset_imagenet(dataset_path, num_parallel_workers=None):
|
||||
"""
|
||||
create a train or eval imagenet2012 dataset for dcgan
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
|
||||
device_num, rank_id = _get_rank_info()
|
||||
|
||||
if device_num == 1:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers)
|
||||
else:
|
||||
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
assert dcgan_imagenet_cfg.image_height == dcgan_imagenet_cfg.image_width, "image_height not equal image_width"
|
||||
image_size = dcgan_imagenet_cfg.image_height
|
||||
|
||||
# define map operations
|
||||
transform_img = [
|
||||
vision.Decode(),
|
||||
vision.Resize(image_size),
|
||||
vision.CenterCrop(image_size),
|
||||
vision.HWC2CHW()
|
||||
]
|
||||
|
||||
data_set = data_set.map(input_columns="image", num_parallel_workers=num_parallel_workers, operations=transform_img,
|
||||
output_columns="image")
|
||||
data_set = data_set.map(input_columns="image", num_parallel_workers=num_parallel_workers,
|
||||
operations=lambda x: ((x - 127.5) / 127.5).astype("float32"))
|
||||
data_set = data_set.map(
|
||||
input_columns="image",
|
||||
operations=lambda x: (
|
||||
x,
|
||||
np.random.normal(size=(dcgan_imagenet_cfg.latent_size, 1, 1)).astype("float32")
|
||||
),
|
||||
output_columns=["image", "latent_code"],
|
||||
column_order=["image", "latent_code"],
|
||||
num_parallel_workers=num_parallel_workers
|
||||
)
|
||||
|
||||
data_set = data_set.batch(dcgan_imagenet_cfg.batch_size)
|
||||
|
||||
return data_set
|
||||
|
||||
|
||||
def _get_rank_info():
|
||||
"""
|
||||
get rank size and rank id
|
||||
"""
|
||||
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||
|
||||
if rank_size > 1:
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
rank_size = get_group_size()
|
||||
rank_id = get_rank()
|
||||
else:
|
||||
rank_size = rank_id = None
|
||||
return rank_size, rank_id
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""dcgan"""
|
||||
from mindspore import nn
|
||||
|
||||
|
||||
class DCGAN(nn.Cell):
|
||||
"""dcgan class"""
|
||||
def __init__(self, myTrainOneStepCellForD, myTrainOneStepCellForG):
|
||||
super(DCGAN, self).__init__(auto_prefix=True)
|
||||
self.myTrainOneStepCellForD = myTrainOneStepCellForD
|
||||
self.myTrainOneStepCellForG = myTrainOneStepCellForG
|
||||
|
||||
def construct(self, real_data, latent_code):
|
||||
output_D = self.myTrainOneStepCellForD(real_data, latent_code).view(-1)
|
||||
netD_loss = output_D.mean()
|
||||
output_G = self.myTrainOneStepCellForG(latent_code).view(-1)
|
||||
netG_loss = output_G.mean()
|
||||
return netD_loss, netG_loss
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""dcgan discriminator"""
|
||||
from mindspore import nn
|
||||
|
||||
from src.cell import Normal
|
||||
from src.config import dcgan_imagenet_cfg as cfg
|
||||
|
||||
|
||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="pad"):
|
||||
weight_init = Normal(mean=0, sigma=0.02)
|
||||
return nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
weight_init=weight_init, has_bias=False, pad_mode=pad_mode)
|
||||
|
||||
|
||||
def bm(num_features):
|
||||
gamma_init = Normal(mean=1, sigma=0.02)
|
||||
return nn.BatchNorm2d(num_features=num_features, gamma_init=gamma_init)
|
||||
|
||||
|
||||
class Discriminator(nn.Cell):
|
||||
"""
|
||||
DCGAN Discriminator
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Discriminator, self).__init__()
|
||||
self.discriminator = nn.SequentialCell()
|
||||
# input is 3 x 32 x 32
|
||||
self.discriminator.append(conv(cfg.channel_size, cfg.feature_size * 2, 4, 2, 1))
|
||||
self.discriminator.append(nn.LeakyReLU(0.2))
|
||||
# state size. 128 x 16 x 16
|
||||
self.discriminator.append(conv(cfg.feature_size * 2, cfg.feature_size * 4, 4, 2, 1))
|
||||
self.discriminator.append(bm(cfg.feature_size * 4))
|
||||
self.discriminator.append(nn.LeakyReLU(0.2))
|
||||
# state size. 256 x 8 x 8
|
||||
self.discriminator.append(conv(cfg.feature_size * 4, cfg.feature_size * 8, 4, 2, 1))
|
||||
self.discriminator.append(bm(cfg.feature_size * 8))
|
||||
self.discriminator.append(nn.LeakyReLU(0.2))
|
||||
# state size. 512 x 4 x 4
|
||||
self.discriminator.append(conv(cfg.feature_size * 8, 1, 4, 1))
|
||||
self.discriminator.append(nn.Sigmoid())
|
||||
|
||||
def construct(self, x):
|
||||
return self.discriminator(x)
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""dcgan generator"""
|
||||
from mindspore import nn
|
||||
|
||||
from src.cell import Normal
|
||||
from src.config import dcgan_imagenet_cfg as cfg
|
||||
|
||||
|
||||
def convt(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="pad"):
|
||||
weight_init = Normal(mean=0, sigma=0.02)
|
||||
return nn.Conv2dTranspose(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
weight_init=weight_init, has_bias=False, pad_mode=pad_mode)
|
||||
|
||||
|
||||
def bm(num_features):
|
||||
gamma_init = Normal(mean=1, sigma=0.02)
|
||||
return nn.BatchNorm2d(num_features=num_features, gamma_init=gamma_init)
|
||||
|
||||
|
||||
class Generator(nn.Cell):
|
||||
"""
|
||||
DCGAN Generator
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Generator, self).__init__()
|
||||
self.generator = nn.SequentialCell()
|
||||
# input is Z, going into a convolution
|
||||
self.generator.append(convt(cfg.latent_size, cfg.feature_size * 8, 4, 1, 0))
|
||||
self.generator.append(bm(cfg.feature_size * 8))
|
||||
self.generator.append(nn.ReLU())
|
||||
# state size. 512 x 4 x 4
|
||||
self.generator.append(convt(cfg.feature_size * 8, cfg.feature_size * 4, 4, 2, 1))
|
||||
self.generator.append(bm(cfg.feature_size * 4))
|
||||
self.generator.append(nn.ReLU())
|
||||
# state size. 256 x 8 x 8
|
||||
self.generator.append(convt(cfg.feature_size * 4, cfg.feature_size * 2, 4, 2, 1))
|
||||
self.generator.append(bm(cfg.feature_size * 2))
|
||||
self.generator.append(nn.ReLU())
|
||||
# state size. 128 x 16 x 16
|
||||
self.generator.append(convt(cfg.feature_size * 2, cfg.channel_size, 4, 2, 1))
|
||||
self.generator.append(nn.Tanh())
|
||||
# state size. 3 x 32 x 32
|
||||
|
||||
def construct(self, x):
|
||||
return self.generator(x)
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""train DCGAN and get checkpoint files."""
|
||||
import argparse
|
||||
import ast
|
||||
import os
|
||||
import datetime
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import nn, Tensor
|
||||
from mindspore.train.callback import CheckpointConfig, _InternalCallbackParam
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_group_size
|
||||
from src.dataset import create_dataset_imagenet
|
||||
from src.config import dcgan_imagenet_cfg as cfg
|
||||
from src.generator import Generator
|
||||
from src.discriminator import Discriminator
|
||||
from src.cell import WithLossCellD, WithLossCellG, ModelCheckpoint
|
||||
from src.dcgan import DCGAN
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='MindSpore dcgan training')
|
||||
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
|
||||
help="Run distribute, default is false.")
|
||||
parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend (Default: 0)')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='dataset path')
|
||||
parser.add_argument('--save_path', type=str, default=None, help='checkpoint save path')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.run_distribute:
|
||||
device_id = args.device_id
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
context.set_context(device_id=device_id)
|
||||
init()
|
||||
device_num = get_group_size()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
else:
|
||||
device_id = args.device_id
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
# Load Dataset
|
||||
ds = create_dataset_imagenet(os.path.join(args.dataset_path), num_parallel_workers=2)
|
||||
|
||||
steps_per_epoch = ds.get_dataset_size()
|
||||
|
||||
# Define Network
|
||||
netD = Discriminator()
|
||||
netG = Generator()
|
||||
|
||||
criterion = nn.BCELoss(reduction='mean')
|
||||
|
||||
netD_with_criterion = WithLossCellD(netD, netG, criterion)
|
||||
netG_with_criterion = WithLossCellG(netD, netG, criterion)
|
||||
|
||||
optimizerD = nn.Adam(netD.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
|
||||
optimizerG = nn.Adam(netG.trainable_params(), learning_rate=cfg.learning_rate, beta1=cfg.beta1)
|
||||
|
||||
myTrainOneStepCellForD = nn.TrainOneStepCell(netD_with_criterion, optimizerD)
|
||||
myTrainOneStepCellForG = nn.TrainOneStepCell(netG_with_criterion, optimizerG)
|
||||
|
||||
dcgan = DCGAN(myTrainOneStepCellForD, myTrainOneStepCellForG)
|
||||
dcgan.set_train()
|
||||
|
||||
# checkpoint save
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
|
||||
keep_checkpoint_max=cfg.epoch_size)
|
||||
ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.save_path, prefix='dcgan')
|
||||
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = dcgan
|
||||
cb_params.batch_num = steps_per_epoch
|
||||
cb_params.epoch_num = cfg.epoch_size
|
||||
# For each epoch
|
||||
cb_params.cur_epoch_num = 0
|
||||
cb_params.cur_step_num = 0
|
||||
|
||||
np.random.seed(1)
|
||||
fixed_noise = Tensor(np.random.normal(size=(16, cfg.latent_size, 1, 1)).astype("float32"))
|
||||
|
||||
data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=cfg.epoch_size)
|
||||
G_losses = []
|
||||
D_losses = []
|
||||
# Start Training Loop
|
||||
print("Starting Training Loop...")
|
||||
for epoch in range(cfg.epoch_size):
|
||||
# For each batch in the dataloader
|
||||
for i, data in enumerate(data_loader):
|
||||
real_data = Tensor(data['image'])
|
||||
latent_code = Tensor(data["latent_code"])
|
||||
netD_loss, netG_loss = dcgan(real_data, latent_code)
|
||||
if i % 50 == 0:
|
||||
print("Date time: ", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "\tepoch: ", epoch, "/",
|
||||
cfg.epoch_size, "\tstep: ", i, "/", steps_per_epoch, "\tDloss: ", netD_loss, "\tGloss: ",
|
||||
netG_loss)
|
||||
D_losses.append(netD_loss.asnumpy())
|
||||
G_losses.append(netG_loss.asnumpy())
|
||||
cb_params.cur_step_num = cb_params.cur_step_num + 1
|
||||
cb_params.cur_epoch_num = cb_params.cur_epoch_num + 1
|
||||
print("================saving model===================")
|
||||
if args.device_id == 0 or not args.run_distribute:
|
||||
ckpt_cb.save_ckpt(cb_params, True)
|
||||
print("================success================")
|
Loading…
Reference in New Issue