forked from mindspore-Ecosystem/mindspore
!18615 Add a new model - Conditional GAN
Merge pull request !18615 from Ucalan/ucalan/mindspore
This commit is contained in:
commit
0d2eede453
|
@ -0,0 +1,181 @@
|
|||
# Contents
|
||||
|
||||
- [CGAN Description](#cgan-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Script Parameters](#training-script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Training](#training)
|
||||
- [Distributed Training](#distributed-training)
|
||||
- [Training Result](#training-result)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Evaluation result](#evaluation-result)
|
||||
- [Model Export](#model-export)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [CGAN Description](#contents)
|
||||
|
||||
Generative Adversarial Nets were recently introduced as a novel way to train generative models. In this work we introduce the conditional version of generative adversarial nets, which can be constructed by simply feeding the data, y, we wish to condition on to both the generator and discriminator. We show that this model can generate MNIST digits conditioned on class labels. We also illustrate how this model could be used to learn a multi-modal model, and provide preliminary examples of an application to image tagging in which we demonstrate how this approach can generate descriptive tags which are not part of training labels.
|
||||
|
||||
[Paper](https://arxiv.org/pdf/1411.1784.pdf): Conditional Generative Adversarial Nets.
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
|
||||
Architecture guidelines for Conditional GANs
|
||||
|
||||
- Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator).
|
||||
- Use batchnorm in both the generator and the discriminator.
|
||||
- Remove fully connected hidden layers for deeper architectures.
|
||||
- Use ReLU activation in generator for all layers except for the output, which uses Tanh.
|
||||
- Use LeakyReLU activation in the discriminator for all layers.
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Train CGAN Dataset used: [MNIST](<http://yann.lecun.com/exdb/mnist/>)
|
||||
|
||||
- Dataset size:52.4M,60,000 28*28 in 10 classes
|
||||
- Train:60,000 images
|
||||
- Test:10,000 images
|
||||
- Data format:binary files
|
||||
- Note:Data will be processed in dataset.py
|
||||
|
||||
```text
|
||||
|
||||
└─data
|
||||
└─MNIST_Data
|
||||
└─train
|
||||
```
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware Ascend
|
||||
- Prepare hardware environment with Ascend processor.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
```text
|
||||
.
|
||||
└─CGAN
|
||||
├─README.md # README
|
||||
├─requirements.txt # required modules
|
||||
├─scripts # shell script
|
||||
├─run_standalone_train.sh # training in standalone mode(1pcs)
|
||||
├─run_distributed_train_ascend.sh # training in parallel mode(8 pcs)
|
||||
└─run_eval_ascend.sh # evaluation
|
||||
├─ src
|
||||
├─dataset.py # dataset create
|
||||
├─cell.py # network definition
|
||||
├─ckpt_util.py # utility of checkpoint
|
||||
├─model.py # discriminator & generator structure
|
||||
├─ train.py # train cgan
|
||||
├─ eval.py # eval cgan
|
||||
├─ export.py # export mindir
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
### [Training Script Parameters](#contents)
|
||||
|
||||
```shell
|
||||
# distributed training
|
||||
bash run_distributed_train_ascend.sh /path/to/MNIST_Data/train /path/to/hccl_8p_01234567_127.0.0.1.json 8
|
||||
|
||||
# standalone training
|
||||
bash run_standalone_train.sh /path/MNIST_Data/train 0
|
||||
|
||||
# evaluating
|
||||
bash run_eval_ascend.sh /path/to/script/train_parallel/0/ckpt/G_50.ckpt 0
|
||||
```
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
### [Training](#content)
|
||||
|
||||
- Run `run_standalone_train_ascend.sh` for non-distributed training of CGAN model.
|
||||
|
||||
```bash
|
||||
# standalone training
|
||||
bash run_standalone_train_ascend.sh /path/MNIST_Data/train 0
|
||||
```
|
||||
|
||||
### [Distributed Training](#content)
|
||||
|
||||
- Run `run_distributed_train_ascend.sh` for distributed training of CGAN model.
|
||||
|
||||
```bash
|
||||
bash run_distributed_train_ascend.sh /path/to/MNIST_Data/train /path/to/hccl_8p_01234567_127.0.0.1.json 8
|
||||
```
|
||||
|
||||
- Notes
|
||||
1. hccl.json which is specified by RANK_TABLE_FILE is needed when you are running a distribute task. You can generate it by using the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
|
||||
|
||||
### [Training Result](#content)
|
||||
|
||||
Training result will be stored in `img_eval`.
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### [Evaluation](#content)
|
||||
|
||||
- Run `run_eval_ascend.sh` for evaluation.
|
||||
|
||||
```bash
|
||||
# eval
|
||||
bash run_eval_ascend.sh /path/to/script/train_parallel/0/ckpt/G_50.ckpt 0
|
||||
```
|
||||
|
||||
### [Evaluation result](#content)
|
||||
|
||||
Evaluation result will be stored in the img_eval path. Under this, you can find generator result in result.png.
|
||||
|
||||
## Model Export
|
||||
|
||||
```bash
|
||||
python export.py --ckpt_dir /path/to/train/ckpt/G_50.ckpt
|
||||
```
|
||||
|
||||
# Model Description
|
||||
|
||||
## Performance
|
||||
|
||||
### Evaluation Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | ------------------------------------------------------------------------------------------- |
|
||||
| Model Version | V1 |
|
||||
| Resource | CentOs 8.2; Ascend 910; CPU 2.60GHz, 192cores; Memory 755G |
|
||||
| uploaded Date | 07/04/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 |
|
||||
| Dataset | MNIST Dataset |
|
||||
| Training Parameters | epoch=50, batch_size = 128 |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | BCELoss |
|
||||
| Output | predict class |
|
||||
| Loss | g_loss: 4.9693 d_loss: 0.1540 |
|
||||
| Total time | 7.5 mins(8p) |
|
||||
| Checkpoint for Fine tuning | 26.2M(.ckpt file) |
|
||||
| Scripts | [cgan script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/CGAN) |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
We use random seed in train.py and cell.py for weight initialization.
|
||||
|
||||
# [Model_Zoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""eval cgan"""
|
||||
import os
|
||||
import itertools
|
||||
import argparse
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.model import Generator
|
||||
|
||||
def preLauch():
|
||||
"""parse the console argument"""
|
||||
parser = argparse.ArgumentParser(description='MindSpore cgan training')
|
||||
parser.add_argument('--device_id', type=int, default=0,
|
||||
help='device id of Ascend (Default: 0)')
|
||||
parser.add_argument('--ckpt_dir', type=str,
|
||||
default='ckpt', help='checkpoint dir of CGAN')
|
||||
parser.add_argument('--img_out', type=str,
|
||||
default='img_eval', help='the dir of output img')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(device_id=args.device_id,
|
||||
mode=context.GRAPH_MODE,
|
||||
device_target="Ascend")
|
||||
# if not exists 'img_out', make it
|
||||
if not os.path.exists(args.img_out):
|
||||
os.mkdir(args.img_out)
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
# before training, we should set some arguments
|
||||
args = preLauch()
|
||||
|
||||
# training argument
|
||||
input_dim = 100
|
||||
|
||||
# create G Cell & D Cell
|
||||
netG = Generator(input_dim)
|
||||
|
||||
latent_code_eval = Tensor(np.random.randn(200, input_dim), dtype=mstype.float32)
|
||||
|
||||
label_eval = np.zeros((200, 10))
|
||||
for i in range(200):
|
||||
j = i // 20
|
||||
label_eval[i][j] = 1
|
||||
label_eval = Tensor(label_eval, dtype=mstype.float32)
|
||||
|
||||
fig, ax = plt.subplots(10, 20, figsize=(10, 5))
|
||||
for digit, num in itertools.product(range(10), range(20)):
|
||||
ax[digit, num].get_xaxis().set_visible(False)
|
||||
ax[digit, num].get_yaxis().set_visible(False)
|
||||
|
||||
param_G = load_checkpoint(args.ckpt_dir)
|
||||
load_param_into_net(netG, param_G)
|
||||
gen_imgs_eval = netG(latent_code_eval, label_eval)
|
||||
for i in range(200):
|
||||
if (i + 1) % 20 == 0:
|
||||
print("process ========= {}/200".format(i+1))
|
||||
digit = i // 20
|
||||
num = i % 20
|
||||
img = gen_imgs_eval[i].asnumpy().reshape((28, 28))
|
||||
ax[digit, num].cla()
|
||||
ax[digit, num].imshow(img * 127.5 + 127.5, cmap="gray")
|
||||
|
||||
label = 'eval result'
|
||||
fig.text(0.5, 0.01, label, ha='center')
|
||||
fig.tight_layout()
|
||||
plt.subplots_adjust(wspace=0.01, hspace=0.01)
|
||||
print("===========saving image===========")
|
||||
plt.savefig("./img_eval/result.png")
|
||||
print("===========success================")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from src.model import Generator
|
||||
|
||||
|
||||
def preLauch():
|
||||
"""parse the console argument"""
|
||||
parser = argparse.ArgumentParser(description='MindSpore cgan training')
|
||||
parser.add_argument('--device_id', type=int, default=0,
|
||||
help='device id of Ascend (Default: 0)')
|
||||
parser.add_argument('--ckpt_dir', type=str,
|
||||
default='ckpt', help='checkpoint dir of CGAN')
|
||||
args = parser.parse_args()
|
||||
context.set_context(device_id=args.device_id, mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
return args
|
||||
|
||||
def main():
|
||||
# before training, we should set some arguments
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
args = preLauch()
|
||||
|
||||
# training argument
|
||||
input_dim = 100
|
||||
# create G Cell & D Cell
|
||||
netG = Generator(input_dim)
|
||||
|
||||
latent_code_eval = Tensor(np.random.randn(200, input_dim), dtype=mstype.float32)
|
||||
|
||||
label_eval = np.zeros((200, 10))
|
||||
for i in range(200):
|
||||
j = i // 20
|
||||
label_eval[i][j] = 1
|
||||
label_eval = Tensor(label_eval, dtype=mstype.float32)
|
||||
|
||||
param_G = load_checkpoint(args.ckpt_dir)
|
||||
load_param_into_net(netG, param_G)
|
||||
netG.set_train(False)
|
||||
export(netG, latent_code_eval, label_eval, file_name="CGAN", file_format="MINDIR")
|
||||
print("CGAN exported")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Binary file not shown.
|
@ -0,0 +1,56 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_train_ascend.sh [dataset] [rank_table] [device_num]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DATASET=$1
|
||||
export RANK_TABLE_FILE=$2
|
||||
export DEVICE_NUM=$3
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export HCCL_CONNECT_TIMEOUT=600
|
||||
export SERVER_ID=0
|
||||
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||
|
||||
# remove old train_parallel files
|
||||
rm -rf ./train_parallel
|
||||
mkdir ./train_parallel
|
||||
echo "device count=$DEVICE_NUM"
|
||||
|
||||
i=0
|
||||
while [ $i -lt ${DEVICE_NUM} ]; do
|
||||
export DEVICE_ID=${i}
|
||||
export RANK_ID=$((rank_start + i))
|
||||
|
||||
# mkdirs
|
||||
mkdir ./train_parallel/$i
|
||||
mkdir ./train_parallel/$i/src
|
||||
|
||||
# move files
|
||||
cp ../*.py ./train_parallel/$i
|
||||
cp ../src/*.py ./train_parallel/$i/src
|
||||
|
||||
# goto the training dirs of each training
|
||||
cd ./train_parallel/$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
# input logs to env.log
|
||||
env > env.log
|
||||
python -u train.py --device_id=$i --distribute=True --ckpt_dir=./ckpt --dataset=$DATASET > log 2>&1 &
|
||||
cd ../..
|
||||
i=$((i + 1))
|
||||
done
|
|
@ -0,0 +1,26 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_eval_ascend.sh [checkpoint_path] [device_id]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export CKPT=$1
|
||||
export DEVICE_ID=$2
|
||||
|
||||
python -u ../eval.py --ckpt_dir=$CKPT --device_id=$DEVICE_ID > log 2>&1 &
|
|
@ -0,0 +1,32 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_train_ascend.sh [dataset] [device_id]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DATASET=$1
|
||||
export DEVICE_ID=$2
|
||||
|
||||
rm -rf ./train
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train
|
||||
python -u ./train.py --dataset=$DATASET --device_id=$DEVICE_ID > log 2>&1 &
|
||||
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""cell define"""
|
||||
from mindspore import nn
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.ops.functional as F
|
||||
import mindspore.ops.composite as C
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.ops import OnesLike, ZerosLike
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
|
||||
|
||||
class GenWithLossCell(nn.Cell):
|
||||
"""GenWithLossCell"""
|
||||
def __init__(self, netG, netD, auto_prefix=True):
|
||||
super(GenWithLossCell, self).__init__(auto_prefix=auto_prefix)
|
||||
self.netG = netG
|
||||
self.netD = netD
|
||||
self.loss_fn = nn.BCELoss(reduction="mean")
|
||||
|
||||
|
||||
def construct(self, latent_code, label):
|
||||
"""cgan construct"""
|
||||
fake_data = self.netG(latent_code, label)
|
||||
|
||||
# loss
|
||||
fake_out = self.netD(fake_data, label)
|
||||
ones = OnesLike()(fake_out)
|
||||
loss_G = self.loss_fn(fake_out, ones)
|
||||
|
||||
return loss_G
|
||||
|
||||
|
||||
class DisWithLossCell(nn.Cell):
|
||||
"""DisWithLossCell"""
|
||||
def __init__(self, netG, netD, auto_prefix=True):
|
||||
super(DisWithLossCell, self).__init__(auto_prefix=auto_prefix)
|
||||
self.netG = netG
|
||||
self.netD = netD
|
||||
self.loss_fn = nn.BCELoss(reduction="mean")
|
||||
|
||||
def construct(self, real_data, latent_code, label):
|
||||
"""construct"""
|
||||
# fake_data
|
||||
fake_data = self.netG(latent_code, label)
|
||||
# fake_loss
|
||||
fake_out = self.netD(fake_data, label)
|
||||
zeros = ZerosLike()(fake_out)
|
||||
fake_loss = self.loss_fn(fake_out, zeros)
|
||||
# real loss
|
||||
real_out = self.netD(real_data, label)
|
||||
ones = OnesLike()(real_out)
|
||||
real_loss = self.loss_fn(real_out, ones)
|
||||
|
||||
# d loss
|
||||
loss_D = real_loss + fake_loss
|
||||
|
||||
return loss_D
|
||||
|
||||
|
||||
class TrainOneStepCell(nn.Cell):
|
||||
"""define TrainOneStepCell"""
|
||||
def __init__(self,
|
||||
netG,
|
||||
netD,
|
||||
optimizerG: nn.Optimizer,
|
||||
optimizerD: nn.Optimizer,
|
||||
sens=1.0,
|
||||
auto_prefix=True):
|
||||
|
||||
super(TrainOneStepCell, self).__init__(auto_prefix=auto_prefix)
|
||||
self.netG = netG
|
||||
self.netG.set_grad()
|
||||
self.netG.add_flags(defer_inline=True)
|
||||
|
||||
self.netD = netD
|
||||
self.netD.set_grad()
|
||||
self.netD.add_flags(defer_inline=True)
|
||||
|
||||
self.weights_G = optimizerG.parameters
|
||||
self.optimizerG = optimizerG
|
||||
self.weights_D = optimizerD.parameters
|
||||
self.optimizerD = optimizerD
|
||||
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
|
||||
self.sens = sens
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer_G = F.identity
|
||||
self.grad_reducer_D = F.identity
|
||||
self.parallel_mode = _get_parallel_mode()
|
||||
if self.parallel_mode in (ParallelMode.DATA_PARALLEL,
|
||||
ParallelMode.HYBRID_PARALLEL):
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer_G = DistributedGradReducer(
|
||||
self.weights_G, mean, degree)
|
||||
self.grad_reducer_D = DistributedGradReducer(
|
||||
self.weights_D, mean, degree)
|
||||
|
||||
def trainD(self, real_data, latent_code, label, loss):
|
||||
"""trainD"""
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.netD, self.weights_D)(real_data, latent_code, label, sens)
|
||||
grads = self.grad_reducer_D(grads)
|
||||
return F.depend(loss, self.optimizerD(grads))
|
||||
|
||||
def trainG(self, latent_code, label, loss):
|
||||
"""trainG"""
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.netG, self.weights_G)(latent_code, label, sens)
|
||||
grads = self.grad_reducer_G(grads)
|
||||
return F.depend(loss, self.optimizerG(grads))
|
||||
|
||||
def construct(self, real_data, latent_code, label):
|
||||
"""construct"""
|
||||
loss_D = self.netD(real_data, latent_code, label)
|
||||
loss_G = self.netG(latent_code, label)
|
||||
d_out = self.trainD(real_data, latent_code, label, loss_D)
|
||||
g_out = self.trainG(latent_code, label, loss_G)
|
||||
return d_out, g_out
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ckpt_util"""
|
||||
import os
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net
|
||||
|
||||
|
||||
def save_ckpt(args, G, D, epoch):
|
||||
# should remove old ckpt
|
||||
save_checkpoint(G, os.path.join(args.ckpt_dir, f"G_{epoch}.ckpt"))
|
||||
|
||||
|
||||
def load_ckpt(args, G, D, epoch):
|
||||
if args.ckpt_dir is not None:
|
||||
param_G = load_checkpoint(os.path.join(
|
||||
args.ckpt_dir, f"G_{epoch}.ckpt"))
|
||||
load_param_into_net(G, param_G)
|
||||
if args.ckpt_dir is not None and D is not None:
|
||||
param_D = load_checkpoint(os.path.join(
|
||||
args.ckpt_dir, f"G_{epoch}.ckpt"))
|
||||
load_param_into_net(D, param_D)
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""create dataset"""
|
||||
import os
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.dataset.transforms.c_transforms as CT
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
|
||||
|
||||
def create_dataset(data_path,
|
||||
flatten_size,
|
||||
batch_size,
|
||||
repeat_size=1,
|
||||
num_parallel_workers=1):
|
||||
"""create_dataset"""
|
||||
device_num, rank_id = _get_rank_info()
|
||||
|
||||
if device_num == 1:
|
||||
mnist_ds = ds.MnistDataset(data_path)
|
||||
else:
|
||||
mnist_ds = ds.MnistDataset(data_path, num_parallel_workers=8, shuffle=True,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
type_cast_op = CT.TypeCast(mstype.float32)
|
||||
onehot_op = CT.OneHot(num_classes=10)
|
||||
|
||||
mnist_ds = mnist_ds.map(input_columns="label",
|
||||
operations=onehot_op,
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(input_columns="label",
|
||||
operations=type_cast_op,
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(input_columns="image",
|
||||
operations=lambda x: ((x - 127.5) / 127.5).astype("float32"),
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(input_columns="image",
|
||||
operations=lambda x: (x.reshape((flatten_size,))),
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
buffer_size = 10000
|
||||
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
|
||||
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
|
||||
mnist_ds = mnist_ds.repeat(repeat_size)
|
||||
|
||||
return mnist_ds
|
||||
|
||||
|
||||
def one_hot(num_classes=10, arr=None):
|
||||
"""onehot process"""
|
||||
if arr is not None:
|
||||
arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
return np.eye(num_classes)[arr]
|
||||
|
||||
def _get_rank_info():
|
||||
"""
|
||||
get rank size and rank id
|
||||
"""
|
||||
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||
|
||||
if rank_size > 1:
|
||||
rank_size = get_group_size()
|
||||
rank_id = get_rank()
|
||||
else:
|
||||
rank_size = 1
|
||||
rank_id = 0
|
||||
|
||||
return rank_size, rank_id
|
|
@ -0,0 +1,141 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""model define"""
|
||||
import math
|
||||
from mindspore import nn
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore.common import initializer as init
|
||||
|
||||
|
||||
def init_weights(net, init_type='normal', init_gain=0.02):
|
||||
"""
|
||||
Initialize network weights.
|
||||
|
||||
Parameters:
|
||||
net (Cell): Network to be initialized
|
||||
init_type (str): The name of an initialization method: normal | xavier.
|
||||
init_gain (float): Gain factor for normal and xavier.
|
||||
|
||||
"""
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
|
||||
if init_type == 'normal':
|
||||
cell.weight.set_data(init.initializer(
|
||||
init.Normal(init_gain), cell.weight.shape))
|
||||
elif init_type == 'xavier':
|
||||
cell.weight.set_data(init.initializer(
|
||||
init.XavierUniform(init_gain), cell.weight.shape))
|
||||
elif init_type == 'KaimingUniform':
|
||||
cell.weight.set_data(init.initializer(
|
||||
init.HeUniform(init_gain), cell.weight.shape))
|
||||
elif init_type == 'constant':
|
||||
cell.weight.set_data(
|
||||
init.initializer(0.001, cell.weight.shape))
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'initialization method [%s] is not implemented' % init_type)
|
||||
elif isinstance(cell, nn.GroupNorm):
|
||||
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
|
||||
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
|
||||
|
||||
|
||||
class Generator(nn.Cell):
|
||||
"""Generator"""
|
||||
def __init__(self, input_dim, output_dim=1, input_size=28, class_num=10):
|
||||
super(Generator, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.input_size = input_size
|
||||
self.class_num = class_num
|
||||
self.concat = P.Concat(1)
|
||||
|
||||
self.fc = nn.SequentialCell(
|
||||
nn.Dense(self.input_dim + self.class_num, 1024),
|
||||
nn.BatchNorm1d(1024),
|
||||
nn.ReLU(),
|
||||
nn.Dense(1024, 128 * (self.input_size // 4)
|
||||
* (self.input_size // 4)),
|
||||
nn.BatchNorm1d(128 * (self.input_size // 4)
|
||||
* (self.input_size // 4)),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.deconv = nn.SequentialCell(
|
||||
nn.Conv2dTranspose(128, 64, 4, 2, padding=0,
|
||||
has_bias=True, pad_mode='same'),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(),
|
||||
nn.Conv2dTranspose(64, self.output_dim, 4, 2,
|
||||
padding=0, has_bias=True, pad_mode='same'),
|
||||
nn.Tanh(),
|
||||
)
|
||||
init_weights(self.deconv, 'KaimingUniform', math.sqrt(5))
|
||||
|
||||
def construct(self, input_param, label):
|
||||
"""construct"""
|
||||
x = self.concat((input_param, label))
|
||||
x = self.fc(x)
|
||||
x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
|
||||
x = self.deconv(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Discriminator(nn.Cell):
|
||||
"""Discriminator"""
|
||||
def __init__(self, batch_size, input_dim=1, output_dim=1, input_size=28, class_num=10):
|
||||
super(Discriminator, self).__init__()
|
||||
self.batch_size = batch_size
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.input_size = input_size
|
||||
self.class_num = class_num
|
||||
self.concat = P.Concat(1)
|
||||
self.ExpandDims = P.ExpandDims()
|
||||
self.expand = P.BroadcastTo
|
||||
|
||||
self.conv = nn.SequentialCell(
|
||||
nn.Conv2d(self.input_dim + self.class_num, 64, 4, 2,
|
||||
padding=0, has_bias=True, pad_mode='same'),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.Conv2d(64, 128, 4, 2, padding=0,
|
||||
has_bias=True, pad_mode='same'),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.LeakyReLU(0.2),
|
||||
)
|
||||
self.fc = nn.SequentialCell(
|
||||
nn.Dense(128 * (self.input_size // 4) *
|
||||
(self.input_size // 4), 1024),
|
||||
nn.BatchNorm1d(1024),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.Dense(1024, self.output_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
init_weights(self.conv, 'KaimingUniform', math.sqrt(5))
|
||||
|
||||
def construct(self, input_param, label):
|
||||
"""construct"""
|
||||
# expand_fill
|
||||
label_fill = self.ExpandDims(label, 2)
|
||||
label_fill = self.ExpandDims(label_fill, 3)
|
||||
shape = (self.batch_size, 10, self.input_size, self.input_size)
|
||||
label_fill = self.expand(shape)(label_fill)
|
||||
|
||||
# forward
|
||||
x = self.concat((input_param, label_fill))
|
||||
x = self.conv(x)
|
||||
x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
|
@ -0,0 +1,140 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train"""
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.communication.management import init, get_group_size
|
||||
import mindspore.ops as ops
|
||||
from src.dataset import create_dataset
|
||||
from src.ckpt_util import save_ckpt
|
||||
from src.model import Generator, Discriminator
|
||||
from src.cell import GenWithLossCell, DisWithLossCell, TrainOneStepCell
|
||||
|
||||
|
||||
def preLauch():
|
||||
"""parse the console argument"""
|
||||
parser = argparse.ArgumentParser(description='MindSpore cgan training')
|
||||
parser.add_argument("--distribute", type=bool, default=False,
|
||||
help="Run distribute, default is false.")
|
||||
parser.add_argument('--device_id', type=int, default=0,
|
||||
help='device id of Ascend (Default: 0)')
|
||||
parser.add_argument('--ckpt_dir', type=str,
|
||||
default='ckpt', help='checkpoint dir of CGAN')
|
||||
parser.add_argument('--dataset', type=str, default='data/MNIST_Data/train',
|
||||
help='dataset dir (default data/MNISt_Data/train)')
|
||||
args = parser.parse_args()
|
||||
|
||||
# if not exists 'imgs4', 'gif' or 'ckpt_dir', make it
|
||||
if not os.path.exists(args.ckpt_dir):
|
||||
os.mkdir(args.ckpt_dir)
|
||||
# deal with the distribute analyze problem
|
||||
if args.distribute:
|
||||
device_id = args.device_id
|
||||
context.set_context(save_graphs=False,
|
||||
device_id=device_id,
|
||||
device_target="Ascend",
|
||||
mode=context.GRAPH_MODE)
|
||||
init()
|
||||
args.device_num = get_group_size()
|
||||
context.set_auto_parallel_context(gradients_mean=True,
|
||||
device_num=args.device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL)
|
||||
else:
|
||||
device_id = args.device_id
|
||||
args.device_num = 1
|
||||
context.set_context(save_graphs=False,
|
||||
mode=context.GRAPH_MODE,
|
||||
device_target="Ascend")
|
||||
context.set_context(device_id=device_id)
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
# before training, we should set some arguments
|
||||
args = preLauch()
|
||||
|
||||
# training argument
|
||||
batch_size = 128
|
||||
input_dim = 100
|
||||
epoch_start = 0
|
||||
epoch_end = 51
|
||||
lr = 0.001
|
||||
|
||||
dataset = create_dataset(args.dataset,
|
||||
flatten_size=28 * 28,
|
||||
batch_size=batch_size,
|
||||
num_parallel_workers=args.device_num)
|
||||
|
||||
# create G Cell & D Cell
|
||||
netG = Generator(input_dim)
|
||||
netD = Discriminator(batch_size)
|
||||
# create WithLossCell
|
||||
netG_with_loss = GenWithLossCell(netG, netD)
|
||||
netD_with_loss = DisWithLossCell(netG, netD)
|
||||
# create optimizer cell
|
||||
optimizerG = nn.Adam(netG.trainable_params(), lr)
|
||||
optimizerD = nn.Adam(netD.trainable_params(), lr)
|
||||
|
||||
net_train = TrainOneStepCell(netG_with_loss,
|
||||
netD_with_loss,
|
||||
optimizerG,
|
||||
optimizerD)
|
||||
|
||||
netG.set_train()
|
||||
netD.set_train()
|
||||
|
||||
# latent_code_eval = Tensor(np.random.randn(
|
||||
# 200, input_dim), dtype=mstype.float32)
|
||||
|
||||
# label_eval = np.zeros((200, 10))
|
||||
# for i in range(200):
|
||||
# j = i // 20
|
||||
# label_eval[i][j] = 1
|
||||
# label_eval = Tensor(label_eval, dtype=mstype.float32)
|
||||
|
||||
data_size = dataset.get_dataset_size()
|
||||
print("data-size", data_size)
|
||||
print("=========== start training ===========")
|
||||
for epoch in range(epoch_start, epoch_end):
|
||||
step = 0
|
||||
start = time.time()
|
||||
for data in dataset:
|
||||
img = data[0]
|
||||
label = data[1]
|
||||
img = ops.Reshape()(img, (batch_size, 1, 28, 28))
|
||||
latent_code = Tensor(np.random.randn(
|
||||
batch_size, input_dim), dtype=mstype.float32)
|
||||
dout, gout = net_train(img, latent_code, label)
|
||||
step += 1
|
||||
|
||||
if step % data_size == 0:
|
||||
end = time.time()
|
||||
pref = (end-start)*1000 / data_size
|
||||
print("epoch {}, {:.3f} ms per step, d_loss is {:.4f}, g_loss is {:.4f}".format(epoch,
|
||||
pref, dout.asnumpy(),
|
||||
gout.asnumpy()))
|
||||
|
||||
save_ckpt(args, netG, netD, epoch)
|
||||
print("===========training success================")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,25 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export MODE='test'
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python eval.py --run_distribute=0 --device_num=$DEVICE_NUM --device_id=$DEVICE_ID --mode=$MODE> log_eval.txt 2>&1 &
|
||||
|
||||
cd ..
|
|
@ -0,0 +1,50 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [DISTRIBUTE] [RANK_TABLE_FILE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
|
||||
|
||||
export RANK_SIZE=$1
|
||||
DISTRIBUTE=$2
|
||||
export RANK_TABLE_FILE=$3
|
||||
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
rm -rf LOG$i
|
||||
mkdir ./LOG$i
|
||||
cp ./*.json ./LOG$i
|
||||
cp ./*.py ./LOG$i
|
||||
cp -r ./src ./LOG$i
|
||||
cp -r ./scripts ./LOG$i
|
||||
cd ./LOG$i || exit
|
||||
export RANK_ID=$i
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
env > env.log
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
python train.py \
|
||||
--run_distribute=$DISTRIBUTE \
|
||||
--device_num=$RANK_SIZE \
|
||||
--device_id=$DEVICE_ID > log.txt 2>&1 &
|
||||
fi
|
||||
cd ../
|
||||
done
|
Loading…
Reference in New Issue