!18615 Add a new model - Conditional GAN

Merge pull request !18615 from Ucalan/ucalan/mindspore
i-robot 2021-07-10 06:56:57 +00:00 committed by Gitee
commit 0d2eede453
14 changed files with 1052 additions and 0 deletions

View File

@ -0,0 +1,181 @@
# Contents
- [CGAN Description](#cgan-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Script Parameters](#training-script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Training Result](#training-result)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Evaluation result](#evaluation-result)
- [Model Export](#model-export)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [CGAN Description](#contents)
Generative Adversarial Nets were recently introduced as a novel way to train generative models. In this work we introduce the conditional version of generative adversarial nets, which can be constructed by simply feeding the data, y, we wish to condition on to both the generator and discriminator. We show that this model can generate MNIST digits conditioned on class labels. We also illustrate how this model could be used to learn a multi-modal model, and provide preliminary examples of an application to image tagging in which we demonstrate how this approach can generate descriptive tags which are not part of training labels.
[Paper](https://arxiv.org/pdf/1411.1784.pdf): Conditional Generative Adversarial Nets.
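For reference, the conditional objective from the paper simply conditions both players of the standard GAN minimax game on the label y:

```text
min_G max_D V(D, G) = E_{x~p_data(x)}[ log D(x | y) ] + E_{z~p_z(z)}[ log(1 - D(G(z | y) | y)) ]
```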
# [Model Architecture](#contents)
Architecture guidelines for Conditional GANs
- Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator).
- Use batchnorm in both the generator and the discriminator.
- Remove fully connected hidden layers for deeper architectures.
- Use ReLU activation in generator for all layers except for the output, which uses Tanh.
- Use LeakyReLU activation in the discriminator for all layers.
# [Dataset](#contents)
Dataset used to train CGAN: [MNIST](<http://yann.lecun.com/exdb/mnist/>)
- Dataset size: 52.4 MB, 60,000 28×28 images in 10 classes
    - Train: 60,000 images
    - Test: 10,000 images
- Data format: binary files
    - Note: data will be processed in dataset.py (a loading sketch follows the directory layout below)
```text
└─data
└─MNIST_Data
└─train
```
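As a hedged sketch of how this directory is consumed (it mirrors `src/dataset.py` and `train.py` from this PR and assumes the repository root as the working directory):

```python
from src.dataset import create_dataset

# images are normalized to [-1, 1] and flattened to 784 values; labels are one-hot encoded
mnist = create_dataset("data/MNIST_Data/train", flatten_size=28 * 28, batch_size=128)
print("batches per epoch:", mnist.get_dataset_size())
```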
# [Environment Requirements](#contents)
- Hardware (Ascend)
    - Prepare the hardware environment with an Ascend processor (a quick environment sanity check follows this list).
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
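A minimal environment sanity check (a hedged sketch; it assumes a single Ascend device with id 0, matching the defaults used by the scripts in this PR):

```python
import mindspore
from mindspore import context

# the training, evaluation and export scripts in this model all run in graph mode on Ascend
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=0)
print("MindSpore version:", mindspore.__version__)
```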
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```text
.
└─CGAN
  ├─README.md                         # README
  ├─requirements.txt                  # required modules
  ├─scripts                           # shell scripts
    ├─run_standalone_train.sh         # training in standalone mode (1 device)
    ├─run_distributed_train_ascend.sh # training in parallel mode (8 devices)
    └─run_eval_ascend.sh              # evaluation
  ├─src
    ├─dataset.py                      # dataset creation
    ├─cell.py                         # network definition
    ├─ckpt_util.py                    # checkpoint utilities
    └─model.py                        # discriminator & generator structure
  ├─train.py                          # train CGAN
  ├─eval.py                           # evaluate CGAN
  └─export.py                         # export MindIR
```
## [Script Parameters](#contents)
### [Training Script Parameters](#contents)
```shell
# distributed training
bash run_distributed_train_ascend.sh /path/to/MNIST_Data/train /path/to/hccl_8p_01234567_127.0.0.1.json 8
# standalone training
bash run_standalone_train.sh /path/MNIST_Data/train 0
# evaluating
bash run_eval_ascend.sh /path/to/script/train_parallel/0/ckpt/G_50.ckpt 0
```
## [Training Process](#contents)
### [Training](#contents)
- Run `run_standalone_train.sh` for non-distributed training of the CGAN model.
```bash
# standalone training
bash run_standalone_train.sh /path/MNIST_Data/train 0
```
### [Distributed Training](#contents)
- Run `run_distributed_train_ascend.sh` for distributed training of the CGAN model.
```bash
bash run_distributed_train_ascend.sh /path/to/MNIST_Data/train /path/to/hccl_8p_01234567_127.0.0.1.json 8
```
- Notes
  1. The hccl.json file specified by `RANK_TABLE_FILE` is required when running a distributed task. You can generate it with the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
### [Training Result](#contents)
Training results (generator checkpoints `G_<epoch>.ckpt`) are stored in the `ckpt` directory under the training folder (`train` for standalone runs, `train_parallel/<rank>` for distributed runs).
## [Evaluation Process](#contents)
### [Evaluation](#contents)
- Run `run_eval_ascend.sh` for evaluation.
```bash
# eval
bash run_eval_ascend.sh /path/to/script/train_parallel/0/ckpt/G_50.ckpt 0
```
### [Evaluation result](#contents)
Evaluation results are stored in the `img_eval` directory, where the generator output is saved as `result.png`.
## [Model Export](#contents)
```bash
python export.py --ckpt_dir /path/to/train/ckpt/G_50.ckpt
```
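The exported `CGAN.mindir` keeps the input signature that `export.py` traces: 200 latent vectors of length 100 and 200 one-hot labels. A hedged sketch of building matching inputs:

```python
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype

latent = Tensor(np.random.randn(200, 100), dtype=mstype.float32)         # (200, input_dim)
labels = Tensor(np.eye(10)[np.arange(200) // 20], dtype=mstype.float32)  # (200, 10) one-hot, 20 samples per digit
```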
# [Model Description](#contents)
## [Performance](#contents)
### [Evaluation Performance](#contents)
| Parameters                 | Ascend                                                                                      |
| -------------------------- | ------------------------------------------------------------------------------------------- |
| Model Version              | V1                                                                                          |
| Resource                   | CentOS 8.2; Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB                              |
| Uploaded Date              | 07/04/2021 (month/day/year)                                                                 |
| MindSpore Version          | 1.2.0                                                                                       |
| Dataset                    | MNIST Dataset                                                                               |
| Training Parameters        | epoch=50, batch_size=128                                                                    |
| Optimizer                  | Adam                                                                                        |
| Loss Function              | BCELoss                                                                                     |
| Output                     | generated images                                                                            |
| Loss                       | g_loss: 4.9693, d_loss: 0.1540                                                              |
| Total time                 | 7.5 min (8 devices)                                                                         |
| Checkpoint for Fine tuning | 26.2 MB (.ckpt file)                                                                        |
| Scripts                    | [CGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/CGAN) |
# [Description of Random Situation](#contents)
We use a random seed in train.py and cell.py for weight initialization.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,92 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval cgan"""
import os
import itertools
import argparse
import numpy as np
import matplotlib.pyplot as plt
from mindspore import Tensor
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.model import Generator
def preLauch():
"""parse the console argument"""
    parser = argparse.ArgumentParser(description='MindSpore CGAN evaluation')
parser.add_argument('--device_id', type=int, default=0,
help='device id of Ascend (Default: 0)')
parser.add_argument('--ckpt_dir', type=str,
default='ckpt', help='checkpoint dir of CGAN')
parser.add_argument('--img_out', type=str,
default='img_eval', help='the dir of output img')
args = parser.parse_args()
context.set_context(device_id=args.device_id,
mode=context.GRAPH_MODE,
device_target="Ascend")
# if not exists 'img_out', make it
if not os.path.exists(args.img_out):
os.mkdir(args.img_out)
return args
def main():
    # parse arguments and set the execution context (done inside preLauch)
    args = preLauch()
    # evaluation arguments
    input_dim = 100
    # create the generator
    netG = Generator(input_dim)
latent_code_eval = Tensor(np.random.randn(200, input_dim), dtype=mstype.float32)
label_eval = np.zeros((200, 10))
for i in range(200):
j = i // 20
label_eval[i][j] = 1
label_eval = Tensor(label_eval, dtype=mstype.float32)
fig, ax = plt.subplots(10, 20, figsize=(10, 5))
for digit, num in itertools.product(range(10), range(20)):
ax[digit, num].get_xaxis().set_visible(False)
ax[digit, num].get_yaxis().set_visible(False)
param_G = load_checkpoint(args.ckpt_dir)
load_param_into_net(netG, param_G)
gen_imgs_eval = netG(latent_code_eval, label_eval)
for i in range(200):
if (i + 1) % 20 == 0:
print("process ========= {}/200".format(i+1))
digit = i // 20
num = i % 20
img = gen_imgs_eval[i].asnumpy().reshape((28, 28))
ax[digit, num].cla()
ax[digit, num].imshow(img * 127.5 + 127.5, cmap="gray")
label = 'eval result'
fig.text(0.5, 0.01, label, ha='center')
fig.tight_layout()
plt.subplots_adjust(wspace=0.01, hspace=0.01)
print("===========saving image===========")
plt.savefig("./img_eval/result.png")
print("===========success================")
if __name__ == '__main__':
main()
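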

View File

@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export"""
import argparse
import numpy as np
from mindspore import Tensor
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.model import Generator
def preLauch():
"""parse the console argument"""
    parser = argparse.ArgumentParser(description='MindSpore CGAN export')
parser.add_argument('--device_id', type=int, default=0,
help='device id of Ascend (Default: 0)')
parser.add_argument('--ckpt_dir', type=str,
default='ckpt', help='checkpoint dir of CGAN')
args = parser.parse_args()
context.set_context(device_id=args.device_id, mode=context.GRAPH_MODE, device_target="Ascend")
return args
def main():
    # parse arguments; the execution context is configured inside preLauch
    args = preLauch()
    # export arguments
    input_dim = 100
    # create the generator
    netG = Generator(input_dim)
latent_code_eval = Tensor(np.random.randn(200, input_dim), dtype=mstype.float32)
label_eval = np.zeros((200, 10))
for i in range(200):
j = i // 20
label_eval[i][j] = 1
label_eval = Tensor(label_eval, dtype=mstype.float32)
param_G = load_checkpoint(args.ckpt_dir)
load_param_into_net(netG, param_G)
netG.set_train(False)
export(netG, latent_code_eval, label_eval, file_name="CGAN", file_format="MINDIR")
print("CGAN exported")
if __name__ == '__main__':
main()
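A hedged sketch of running inference with the exported file, assuming MindSpore's MindIR loading API (`mindspore.load` plus `nn.GraphCell`) is available in the installed version:

```python
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
from mindspore.common import dtype as mstype

graph = ms.load("CGAN.mindir")      # load the exported graph
net = nn.GraphCell(graph)           # wrap it as a callable cell
latent = Tensor(np.random.randn(200, 100), dtype=mstype.float32)
labels = Tensor(np.eye(10)[np.arange(200) // 20], dtype=mstype.float32)
images = net(latent, labels)        # expected shape: (200, 1, 28, 28), values in [-1, 1]
print(images.shape)
```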

Binary file not shown.

View File

@ -0,0 +1,56 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_standalone_train_ascend.sh [dataset] [rank_table] [device_num]"
exit 1
fi
export DATASET=$1
export RANK_TABLE_FILE=$2
export DEVICE_NUM=$3
export RANK_SIZE=$DEVICE_NUM
export HCCL_CONNECT_TIMEOUT=600
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
# remove old train_parallel files
rm -rf ./train_parallel
mkdir ./train_parallel
echo "device count=$DEVICE_NUM"
i=0
while [ $i -lt ${DEVICE_NUM} ]; do
export DEVICE_ID=${i}
export RANK_ID=$((rank_start + i))
# create working directories for this rank
mkdir ./train_parallel/$i
mkdir ./train_parallel/$i/src
# copy the entry scripts and source files
cp ../*.py ./train_parallel/$i
cp ../src/*.py ./train_parallel/$i/src
# enter the training directory for this rank
cd ./train_parallel/$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
# record the environment to env.log
env > env.log
python -u train.py --device_id=$i --distribute=True --ckpt_dir=./ckpt --dataset=$DATASET > log 2>&1 &
cd ../..
i=$((i + 1))
done

View File

@ -0,0 +1,26 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_eval_ascend.sh [checkpoint_path] [device_id]"
exit 1
fi
export CKPT=$1
export DEVICE_ID=$2
python -u ../eval.py --ckpt_dir=$CKPT --device_id=$DEVICE_ID > log 2>&1 &

View File

@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_standalone_train_ascend.sh [dataset] [device_id]"
exit 1
fi
export DATASET=$1
export DEVICE_ID=$2
rm -rf ./train
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cd ./train
python -u ./train.py --dataset=$DATASET --device_id=$DEVICE_ID > log 2>&1 &

View File

@ -0,0 +1,136 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cell define"""
from mindspore import nn
import mindspore.ops.operations as P
import mindspore.ops.functional as F
import mindspore.ops.composite as C
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from mindspore.context import ParallelMode
from mindspore.ops import OnesLike, ZerosLike
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
class GenWithLossCell(nn.Cell):
"""GenWithLossCell"""
def __init__(self, netG, netD, auto_prefix=True):
super(GenWithLossCell, self).__init__(auto_prefix=auto_prefix)
self.netG = netG
self.netD = netD
self.loss_fn = nn.BCELoss(reduction="mean")
def construct(self, latent_code, label):
"""cgan construct"""
fake_data = self.netG(latent_code, label)
# loss
fake_out = self.netD(fake_data, label)
ones = OnesLike()(fake_out)
loss_G = self.loss_fn(fake_out, ones)
return loss_G
class DisWithLossCell(nn.Cell):
"""DisWithLossCell"""
def __init__(self, netG, netD, auto_prefix=True):
super(DisWithLossCell, self).__init__(auto_prefix=auto_prefix)
self.netG = netG
self.netD = netD
self.loss_fn = nn.BCELoss(reduction="mean")
def construct(self, real_data, latent_code, label):
"""construct"""
# fake_data
fake_data = self.netG(latent_code, label)
# fake_loss
fake_out = self.netD(fake_data, label)
zeros = ZerosLike()(fake_out)
fake_loss = self.loss_fn(fake_out, zeros)
# real loss
real_out = self.netD(real_data, label)
ones = OnesLike()(real_out)
real_loss = self.loss_fn(real_out, ones)
# d loss
loss_D = real_loss + fake_loss
return loss_D
class TrainOneStepCell(nn.Cell):
"""define TrainOneStepCell"""
def __init__(self,
netG,
netD,
optimizerG: nn.Optimizer,
optimizerD: nn.Optimizer,
sens=1.0,
auto_prefix=True):
super(TrainOneStepCell, self).__init__(auto_prefix=auto_prefix)
self.netG = netG
self.netG.set_grad()
self.netG.add_flags(defer_inline=True)
self.netD = netD
self.netD.set_grad()
self.netD.add_flags(defer_inline=True)
self.weights_G = optimizerG.parameters
self.optimizerG = optimizerG
self.weights_D = optimizerD.parameters
self.optimizerD = optimizerD
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer_G = F.identity
self.grad_reducer_D = F.identity
self.parallel_mode = _get_parallel_mode()
if self.parallel_mode in (ParallelMode.DATA_PARALLEL,
ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer_G = DistributedGradReducer(
self.weights_G, mean, degree)
self.grad_reducer_D = DistributedGradReducer(
self.weights_D, mean, degree)
def trainD(self, real_data, latent_code, label, loss):
"""trainD"""
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.netD, self.weights_D)(real_data, latent_code, label, sens)
grads = self.grad_reducer_D(grads)
return F.depend(loss, self.optimizerD(grads))
def trainG(self, latent_code, label, loss):
"""trainG"""
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.netG, self.weights_G)(latent_code, label, sens)
grads = self.grad_reducer_G(grads)
return F.depend(loss, self.optimizerG(grads))
def construct(self, real_data, latent_code, label):
"""construct"""
loss_D = self.netD(real_data, latent_code, label)
loss_G = self.netG(latent_code, label)
d_out = self.trainD(real_data, latent_code, label, loss_D)
g_out = self.trainG(latent_code, label, loss_G)
return d_out, g_out
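For orientation, a minimal sketch of how these cells are wired together; it mirrors `train.py` later in this PR (dataset and context setup omitted):

```python
from mindspore import nn
from src.model import Generator, Discriminator
from src.cell import GenWithLossCell, DisWithLossCell, TrainOneStepCell

batch_size, input_dim, lr = 128, 100, 0.001
netG, netD = Generator(input_dim), Discriminator(batch_size)
net_train = TrainOneStepCell(GenWithLossCell(netG, netD),
                             DisWithLossCell(netG, netD),
                             nn.Adam(netG.trainable_params(), lr),
                             nn.Adam(netD.trainable_params(), lr))
# each call runs one discriminator update and one generator update:
# d_loss, g_loss = net_train(real_images, latent_code, one_hot_labels)
```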

View File

@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ckpt_util"""
import os
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net
def save_ckpt(args, G, D, epoch):
    """save the generator checkpoint for the given epoch (old checkpoints are not removed)"""
    save_checkpoint(G, os.path.join(args.ckpt_dir, f"G_{epoch}.ckpt"))


def load_ckpt(args, G, D, epoch):
    """load checkpoints back into the generator (and the discriminator, if provided)"""
    if args.ckpt_dir is not None:
        param_G = load_checkpoint(os.path.join(args.ckpt_dir, f"G_{epoch}.ckpt"))
        load_param_into_net(G, param_G)
    if args.ckpt_dir is not None and D is not None:
        # note: save_ckpt above only writes generator checkpoints, so a discriminator
        # checkpoint D_<epoch>.ckpt must have been saved separately
        param_D = load_checkpoint(os.path.join(args.ckpt_dir, f"D_{epoch}.ckpt"))
        load_param_into_net(D, param_D)

View File

@ -0,0 +1,79 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create dataset"""
import os
import numpy as np
import mindspore.dataset as ds
from mindspore.common import dtype as mstype
import mindspore.dataset.transforms.c_transforms as CT
from mindspore.communication.management import get_rank, get_group_size
def create_dataset(data_path,
flatten_size,
batch_size,
repeat_size=1,
num_parallel_workers=1):
"""create_dataset"""
device_num, rank_id = _get_rank_info()
if device_num == 1:
mnist_ds = ds.MnistDataset(data_path)
else:
mnist_ds = ds.MnistDataset(data_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
type_cast_op = CT.TypeCast(mstype.float32)
onehot_op = CT.OneHot(num_classes=10)
mnist_ds = mnist_ds.map(input_columns="label",
operations=onehot_op,
num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="label",
operations=type_cast_op,
num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image",
operations=lambda x: ((x - 127.5) / 127.5).astype("float32"),
num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image",
operations=lambda x: (x.reshape((flatten_size,))),
num_parallel_workers=num_parallel_workers)
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds
def one_hot(num_classes=10, arr=None):
    """one-hot encode arr; defaults to the digits 0-9 when arr is None"""
    if arr is None:
        arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    return np.eye(num_classes)[arr]
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = 1
rank_id = 0
return rank_size, rank_id
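A hedged sketch of what one batch from `create_dataset` looks like; note that `train.py` reshapes the flattened images back to `(batch, 1, 28, 28)` before feeding the discriminator:

```python
from src.dataset import create_dataset

mnist = create_dataset("data/MNIST_Data/train", flatten_size=28 * 28, batch_size=128)
for image, label in mnist.create_tuple_iterator():
    print(image.shape, label.shape)   # expected: (128, 784) float32 and (128, 10) one-hot float32
    break
```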

View File

@ -0,0 +1,141 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model define"""
import math
from mindspore import nn
import mindspore.ops.operations as P
from mindspore.common import initializer as init
def init_weights(net, init_type='normal', init_gain=0.02):
"""
Initialize network weights.
Parameters:
net (Cell): Network to be initialized
        init_type (str): The name of an initialization method: normal | xavier | KaimingUniform | constant.
init_gain (float): Gain factor for normal and xavier.
"""
for _, cell in net.cells_and_names():
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
if init_type == 'normal':
cell.weight.set_data(init.initializer(
init.Normal(init_gain), cell.weight.shape))
elif init_type == 'xavier':
cell.weight.set_data(init.initializer(
init.XavierUniform(init_gain), cell.weight.shape))
elif init_type == 'KaimingUniform':
cell.weight.set_data(init.initializer(
init.HeUniform(init_gain), cell.weight.shape))
elif init_type == 'constant':
cell.weight.set_data(
init.initializer(0.001, cell.weight.shape))
else:
raise NotImplementedError(
'initialization method [%s] is not implemented' % init_type)
elif isinstance(cell, nn.GroupNorm):
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
class Generator(nn.Cell):
"""Generator"""
def __init__(self, input_dim, output_dim=1, input_size=28, class_num=10):
super(Generator, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.input_size = input_size
self.class_num = class_num
self.concat = P.Concat(1)
self.fc = nn.SequentialCell(
nn.Dense(self.input_dim + self.class_num, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(),
nn.Dense(1024, 128 * (self.input_size // 4)
* (self.input_size // 4)),
nn.BatchNorm1d(128 * (self.input_size // 4)
* (self.input_size // 4)),
nn.ReLU(),
)
self.deconv = nn.SequentialCell(
nn.Conv2dTranspose(128, 64, 4, 2, padding=0,
has_bias=True, pad_mode='same'),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2dTranspose(64, self.output_dim, 4, 2,
padding=0, has_bias=True, pad_mode='same'),
nn.Tanh(),
)
init_weights(self.deconv, 'KaimingUniform', math.sqrt(5))
def construct(self, input_param, label):
"""construct"""
x = self.concat((input_param, label))
x = self.fc(x)
x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
x = self.deconv(x)
return x
class Discriminator(nn.Cell):
"""Discriminator"""
def __init__(self, batch_size, input_dim=1, output_dim=1, input_size=28, class_num=10):
super(Discriminator, self).__init__()
self.batch_size = batch_size
self.input_dim = input_dim
self.output_dim = output_dim
self.input_size = input_size
self.class_num = class_num
self.concat = P.Concat(1)
self.ExpandDims = P.ExpandDims()
self.expand = P.BroadcastTo
self.conv = nn.SequentialCell(
nn.Conv2d(self.input_dim + self.class_num, 64, 4, 2,
padding=0, has_bias=True, pad_mode='same'),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, 4, 2, padding=0,
has_bias=True, pad_mode='same'),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),
)
self.fc = nn.SequentialCell(
nn.Dense(128 * (self.input_size // 4) *
(self.input_size // 4), 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.2),
nn.Dense(1024, self.output_dim),
nn.Sigmoid(),
)
init_weights(self.conv, 'KaimingUniform', math.sqrt(5))
def construct(self, input_param, label):
"""construct"""
# expand_fill
label_fill = self.ExpandDims(label, 2)
label_fill = self.ExpandDims(label_fill, 3)
shape = (self.batch_size, 10, self.input_size, self.input_size)
label_fill = self.expand(shape)(label_fill)
# forward
x = self.concat((input_param, label_fill))
x = self.conv(x)
x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
x = self.fc(x)
return x
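A hedged forward-pass sketch of the input and output shapes of the two networks (assumes a configured MindSpore context; the batch size of 16 is arbitrary):

```python
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from src.model import Generator, Discriminator

batch = 16
noise = Tensor(np.random.randn(batch, 100), dtype=mstype.float32)
label = Tensor(np.eye(10)[np.random.randint(0, 10, batch)], dtype=mstype.float32)

netG = Generator(input_dim=100)
fake = netG(noise, label)                 # expected shape: (16, 1, 28, 28)

netD = Discriminator(batch_size=batch)    # labels are broadcast to (16, 10, 28, 28) feature maps
score = netD(fake, label)                 # expected shape: (16, 1), sigmoid outputs in (0, 1)
print(fake.shape, score.shape)
```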

View File

@ -0,0 +1,140 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train"""
import os
import time
import argparse
import numpy as np
from mindspore import nn
from mindspore import Tensor
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.common import dtype as mstype
from mindspore.communication.management import init, get_group_size
import mindspore.ops as ops
from src.dataset import create_dataset
from src.ckpt_util import save_ckpt
from src.model import Generator, Discriminator
from src.cell import GenWithLossCell, DisWithLossCell, TrainOneStepCell
def preLauch():
"""parse the console argument"""
parser = argparse.ArgumentParser(description='MindSpore cgan training')
parser.add_argument("--distribute", type=bool, default=False,
help="Run distribute, default is false.")
parser.add_argument('--device_id', type=int, default=0,
help='device id of Ascend (Default: 0)')
parser.add_argument('--ckpt_dir', type=str,
default='ckpt', help='checkpoint dir of CGAN')
    parser.add_argument('--dataset', type=str, default='data/MNIST_Data/train',
                        help='dataset dir (default: data/MNIST_Data/train)')
args = parser.parse_args()
# if not exists 'imgs4', 'gif' or 'ckpt_dir', make it
if not os.path.exists(args.ckpt_dir):
os.mkdir(args.ckpt_dir)
# deal with the distribute analyze problem
if args.distribute:
device_id = args.device_id
context.set_context(save_graphs=False,
device_id=device_id,
device_target="Ascend",
mode=context.GRAPH_MODE)
init()
args.device_num = get_group_size()
context.set_auto_parallel_context(gradients_mean=True,
device_num=args.device_num,
parallel_mode=ParallelMode.DATA_PARALLEL)
else:
device_id = args.device_id
args.device_num = 1
context.set_context(save_graphs=False,
mode=context.GRAPH_MODE,
device_target="Ascend")
context.set_context(device_id=device_id)
return args
def main():
# before training, we should set some arguments
args = preLauch()
# training argument
batch_size = 128
input_dim = 100
epoch_start = 0
epoch_end = 51
lr = 0.001
dataset = create_dataset(args.dataset,
flatten_size=28 * 28,
batch_size=batch_size,
num_parallel_workers=args.device_num)
# create G Cell & D Cell
netG = Generator(input_dim)
netD = Discriminator(batch_size)
# create WithLossCell
netG_with_loss = GenWithLossCell(netG, netD)
netD_with_loss = DisWithLossCell(netG, netD)
# create optimizer cell
optimizerG = nn.Adam(netG.trainable_params(), lr)
optimizerD = nn.Adam(netD.trainable_params(), lr)
net_train = TrainOneStepCell(netG_with_loss,
netD_with_loss,
optimizerG,
optimizerD)
netG.set_train()
netD.set_train()
# latent_code_eval = Tensor(np.random.randn(
# 200, input_dim), dtype=mstype.float32)
# label_eval = np.zeros((200, 10))
# for i in range(200):
# j = i // 20
# label_eval[i][j] = 1
# label_eval = Tensor(label_eval, dtype=mstype.float32)
data_size = dataset.get_dataset_size()
print("data-size", data_size)
print("=========== start training ===========")
for epoch in range(epoch_start, epoch_end):
step = 0
start = time.time()
for data in dataset:
img = data[0]
label = data[1]
img = ops.Reshape()(img, (batch_size, 1, 28, 28))
latent_code = Tensor(np.random.randn(
batch_size, input_dim), dtype=mstype.float32)
dout, gout = net_train(img, latent_code, label)
step += 1
            if step % data_size == 0:
                end = time.time()
                # average milliseconds per step over this epoch
                pref = (end - start) * 1000 / data_size
                print("epoch {}, {:.3f} ms per step, d_loss is {:.4f}, g_loss is {:.4f}".format(
                    epoch, pref, dout.asnumpy(), gout.asnumpy()))
save_ckpt(args, netG, netD, epoch)
print("===========training success================")
if __name__ == '__main__':
main()

View File

@ -0,0 +1,25 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_NUM=1
export DEVICE_ID=0
echo "start evaluation for device $DEVICE_ID"
env > env.log
# eval.py accepts --device_id, --ckpt_dir and --img_out; ckpt/G_50.ckpt is the
# generator checkpoint that train.py writes by default after the last epoch
python eval.py --device_id=$DEVICE_ID --ckpt_dir=ckpt/G_50.ckpt > log_eval.txt 2>&1 &

View File

@ -0,0 +1,50 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [DISTRIBUTE] [RANK_TABLE_FILE]"
exit 1
fi
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
export RANK_SIZE=$1
DISTRIBUTE=$2
export RANK_TABLE_FILE=$3
for((i=0;i<RANK_SIZE;i++))
do
export DEVICE_ID=$i
rm -rf LOG$i
mkdir ./LOG$i
cp ./*.json ./LOG$i
cp ./*.py ./LOG$i
cp -r ./src ./LOG$i
cp -r ./scripts ./LOG$i
cd ./LOG$i || exit
export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
if [ $# == 3 ]
then
    # train.py accepts --distribute, --device_id, --ckpt_dir and --dataset;
    # the dataset path defaults to data/MNIST_Data/train relative to each LOG$i directory
    python train.py \
    --distribute=$DISTRIBUTE \
    --device_id=$DEVICE_ID > log.txt 2>&1 &
fi
cd ../
done