!17363 STGAN-MindSpore

Merge pull request !17363 from Mingchung-Suen/STGAN
This commit is contained in:
i-robot 2021-07-16 13:24:06 +00:00 committed by Gitee
commit c055bc9b99
21 changed files with 2206 additions and 0 deletions

View File

@ -0,0 +1,203 @@
# Contents
- [Contents](#contents)
- [STGAN Description](#stgan-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
## [STGAN Description](#contents)
STGAN was proposed in CVPR 2019, one of the facial attributes transfer networks using Generative Adversarial Networks (GANs). It introduces a new Selective Transfer Unit (STU) to get better facial attributes transfer than others.
[Paper](https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_STGAN_A_Unified_Selective_Transfer_Network_for_Arbitrary_Image_Attribute_CVPR_2019_paper.pdf): Liu M, Ding Y, Xia M, et al. STGAN: A Unified Selective Transfer Network for Arbitrary Image
Attribute Editing[C]. IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR).
IEEE, 2019: 3668-3677.
## [Model Architecture](#contents)
STGAN composition consists of Generator, Discriminator and Selective Transfer Unit. Using Selective Transfer Unit can help networks keep more attributes in the long term of training.
## [Dataset](#contents)
In the following sections, we will introduce how to run the scripts using the related dataset below.
Dataset used: [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
- Dataset size1011M202,599 128*128 colorful images, marked as 40 attributes
- Train182,599 images
- Test18,800 images
- Data formatbinary files
- NoteData will be processed in celeba.py
- Download the dataset, the directory structure is as follows:
```bash
├── dataroot
├── anno
├── list_attr_celeba.txt
├── image
├── 000001.jpg
├── ...
```
## [Environment Requirements](#contents)
- HardwareAscend
- Prepare hardware environment with Ascend processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
## [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
```python
# enter script dir, train STGAN
sh scripts/run_standalone_train.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID]
# distributed training
sh scripts/run_distribute_train.sh [RANK_TABLE_FILE] [EXPERIMENT_NAME] [DATA_PATH]
# enter script dir, evaluate STGAN
sh scripts/run_eval.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID] [CHECKPOINT_PATH]
```
## [Script Description](#contents)
### [Script and Sample Code](#contents)
```bash
├── cv
├── STGAN
├── README.md // descriptions about STGAN
├── requirements.txt // package needed
├── scripts
│ ├──run_standalone_train.sh // train in ascend
│ ├──run_eval.sh // evaluate in ascend
│ ├──run_distribute_train.sh // distributed train in ascend
├── src
├── dataset
├── datasets.py // creating dataset
├── celeba.py // processing celeba dataset
├── distributed_sampler.py // distributed sampler
├── models
├── base_model.py
├── losses.py // loss models
├── networks.py // basic models of STGAN
├── stgan.py // executing procedure
├── utils
├── args.py // argument parser
├── tools.py // simple tools
├── train.py // training script
├── eval.py // evaluation script
├── export.py // model-export script
```
### [Script Parameters](#contents)
```python
Major parameters in train.py and utils/args.py as follows:
--dataroot: The relative path from the current path to the train and evaluation datasets.
--n_epochs: Total training epochs.
--batch_size: Training batch size.
--image_size: Image size used as input to the model.
--device_target: Device where the code will be implemented. Optional value is "Ascend".
```
### [Training Process](#contents)
#### Training
- running on Ascend
```bash
python train.py --dataroot ./dataset --experiment_name 128 > log 2>&1 &
# or enter script dir, and run the script
sh scripts/run_standalone_train.sh ./dataset 128 0
# distributed training
sh scripts/run_distribute_train.sh ./config/rank_table_8pcs.json 128 /data/dataset
```
After training, the loss value will be achieved as follows:
```bash
# grep "loss is " log
epoch: 1 step: 1, loss is 2.2791853
...
epoch: 1 step: 1536, loss is 1.9366643
epoch: 1 step: 1537, loss is 1.6983616
epoch: 1 step: 1538, loss is 1.0221305
...
```
The model checkpoint will be saved in the output directory.
### [Evaluation Process](#contents)
#### Evaluation
Before running the command below, please check the checkpoint path used for evaluation.
- running on Ascend
```bash
python eval.py --dataroot ./dataset --experiment_name 128 > eval_log.txt 2>&1 &
# or enter script dir, and run the script
sh scripts/run_eval.sh ./dataset 128 0 ./ckpt/generator.ckpt
```
You can view the results in the output directory, which contains a batch of result sample images.
### Model Export
```shell
python export.py --ckpt_path [CHECKPOINT_PATH] --platform [PLATFORM] --file_format[EXPORT_FORMAT]
```
`EXPORT_FORMAT` should be "MINDIR"
## Model Description
### Performance
#### Evaluation Performance
| Parameters | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | V1 |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G |
| uploaded Date | 05/07/2021 (month/day/year) |
| MindSpore Version | 1.2.0 |
| Dataset | CelebA |
| Training Parameters | epoch=100, batch_size = 128 |
| Optimizer | Adam |
| Loss Function | Loss |
| Output | predict class |
| Loss | 6.5523 |
| Speed | 1pc: 400 ms/step; 8pcs: 143 ms/step |
| Total time | 1pc: 41:36:07 |
| Checkpoint for Fine tuning | 170.55M(.ckpt file) |
| Scripts | [STGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/STGAN) |
## [Model Description](#contents)
## [Description of Random Situation](#contents)
In dataset.py, we set the seed inside ```create_dataset``` function.
## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Model Test """
import tqdm
from mindspore.common import set_seed
from src.models import STGANModel
from src.utils import get_args
from src.dataset import CelebADataLoader
set_seed(1)
def test():
""" test function """
args = get_args("test")
print('\n\n=============== start testing ===============\n\n')
data_loader = CelebADataLoader(args.dataroot,
mode=args.phase,
selected_attrs=args.attrs,
batch_size=1,
image_size=args.image_size)
iter_per_epoch = len(data_loader)
args.dataset_size = iter_per_epoch
model = STGANModel(args)
for _ in tqdm.trange(iter_per_epoch, desc='Test Loop'):
data = next(data_loader.test_loader)
model.test(data, data_loader.test_set.get_current_filename())
print('\n\n=============== finish testing ===============\n\n')
if __name__ == '__main__':
test()

View File

@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Model Export """
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import export
from src.models import STGANModel
from src.utils import get_args
if __name__ == '__main__':
args = get_args("test")
context.set_context(mode=context.GRAPH_MODE, device_id=args.device_id)
model = STGANModel(args)
model.netG.set_train(True)
input_shp = [16, 3, 128, 128]
input_shp_2 = [16, 4]
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
input_array_2 = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp_2).astype(np.float32))
G_file = f"{args.file_name}_model"
export(model.netG, input_array, input_array_2, file_name=G_file, file_format=args.file_format)

View File

@ -0,0 +1,2 @@
numpy
tqdm

View File

@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [EXPERIMENT_NAME] [DATA_PATH]"
exit 1
fi
if [ ! -f $1 ]
then
echo "error: RANK_TABLE_FILE=$1 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
export EXPERIMENT_NAME=$2
export DATA_PATH=$3
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp -r ./src ./train_parallel$i
cp ./train.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
cd ./train_parallel$i ||exit
env > env.log
python train.py --device_num ${DEVICE_NUM} --experiment_name=$EXPERIMENT_NAME --dataroot=$DATA_PATH > log 2>&1 &
cd ..
done

View File

@ -0,0 +1,31 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ]
then
echo "Usage: sh run_eval.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID] [CHECKPOINT_PATH]"
exit 1
fi
export DATA_PATH=$1
export EXPERIMENT_NAME=$2
export DEVICE_ID=$3
export CHECKPOINT_PATH=$4
python eval.py --dataroot=$DATA_PATH --experiment_name=$EXPERIMENT_NAME \
--device_id=$DEVICE_ID --ckpt_path=$CHECKPOINT_PATH \
--platform="Ascend" > eval_log 2>&1 &

View File

@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_standalone_train.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID]"
exit 1
fi
export DATA_PATH=$1
export EXPERIMENT_NAME=$2
export DEVICE_ID=$3
python train.py --dataroot=$DATA_PATH --experiment_name=$EXPERIMENT_NAME \
--device_id=$DEVICE_ID --platform="Ascend" > log 2>&1 &

View File

@ -0,0 +1,16 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" dataset """
from .celeba import CelebADataLoader

View File

@ -0,0 +1,157 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" CelebA Dataset """
import os
import multiprocessing
import numpy as np
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank
from PIL import Image
from .distributed_sampler import DistributedSampler
from .datasets import make_dataset
class CelebADataset:
""" CelebA """
def __init__(self, dir_path, mode, selected_attrs):
self.items = make_dataset(dir_path, mode, selected_attrs)
self.dir_path = dir_path
self.mode = mode
self.filename = ''
def __getitem__(self, index):
filename, label = self.items[index]
image = Image.open(os.path.join(self.dir_path, 'image', filename))
image = np.array(image.convert('RGB'))
label = np.array(label)
if self.mode == 'test':
self.filename = filename
return image, label
def __len__(self):
return len(self.items)
def get_current_filename(self):
return self.filename
class CelebADataLoader:
""" CelebADataLoader """
def __init__(self,
root,
mode,
selected_attrs,
crop_size=None,
image_size=128,
batch_size=64,
device_num=1):
if mode not in ['train', 'test', 'val']:
return
mean = [0.5 * 255] * 3
std = [0.5 * 255] * 3
parallel_mode = context.get_auto_parallel_context("parallel_mode")
rank = 0
if parallel_mode in [
ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL,
ParallelMode.AUTO_PARALLEL
]:
rank = get_rank()
shuffle = True
cores = multiprocessing.cpu_count()
num_parallel_workers = min(16, int(cores / device_num))
if mode == 'train':
dataset = CelebADataset(root, mode, selected_attrs)
distributed_sampler = DistributedSampler(len(dataset),
device_num,
rank,
shuffle=shuffle)
self.dataset_size = int(len(distributed_sampler) / batch_size)
val_set = CelebADataset(root, 'val', selected_attrs)
self.val_dataset_size = len(val_set)
transform = [
C.Resize((image_size, image_size)),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
if crop_size is not None:
transform.append(C.CenterCrop(crop_size))
val_distributed_sampler = DistributedSampler(len(val_set),
device_num,
rank,
shuffle=shuffle)
val_dataset = de.GeneratorDataset(val_set,
column_names=["image", "label"],
sampler=val_distributed_sampler,
num_parallel_workers=min(
32, num_parallel_workers))
val_dataset = val_dataset.map(operations=transform,
input_columns=["image"],
num_parallel_workers=min(
32, num_parallel_workers))
transform.insert(0, C.RandomHorizontalFlip())
train_dataset = de.GeneratorDataset(
dataset,
column_names=["image", "label"],
sampler=distributed_sampler,
num_parallel_workers=min(32, num_parallel_workers))
train_dataset = train_dataset.map(operations=transform,
input_columns=["image"],
num_parallel_workers=min(
32, num_parallel_workers))
train_dataset = train_dataset.batch(batch_size,
drop_remainder=True)
train_dataset = train_dataset.repeat(200)
val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
val_dataset = val_dataset.repeat(200)
self.train_loader = train_dataset.create_dict_iterator()
self.val_loader = val_dataset.create_dict_iterator()
else:
dataset = CelebADataset(root, mode, selected_attrs)
self.test_set = dataset
self.dataset_size = int(len(dataset) / batch_size)
distributed_sampler = DistributedSampler(len(dataset),
device_num,
rank,
shuffle=shuffle)
test_transform = [
C.Resize((image_size, image_size)),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
test_dataset = de.GeneratorDataset(dataset,
column_names=["image", "label"],
sampler=distributed_sampler,
num_parallel_workers=min(
1, num_parallel_workers))
test_dataset = test_dataset.map(operations=test_transform,
input_columns=["image"],
num_parallel_workers=min(
32, num_parallel_workers))
test_dataset = test_dataset.batch(batch_size, drop_remainder=True)
test_dataset = test_dataset.repeat(1)
self.test_loader = test_dataset.create_dict_iterator()
def __len__(self):
return self.dataset_size

View File

@ -0,0 +1,64 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""datasets"""
import os
import random
import numpy as np
random.seed(1)
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']
def is_image_file(filename):
"""if a file is a image"""
return any(filename.lower().endswith(extension)
for extension in IMG_EXTENSIONS)
def make_dataset(dir_path, mode, selected_attrs):
""" make dataset """
assert mode in ['train', 'val',
'test'], "Mode [{}] is not supportable".format(mode)
assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path
lines = [
line.rstrip() for line in open(
os.path.join(dir_path, 'anno', 'list_attr_celeba.txt'), 'r')
]
all_attr_names = lines[1].split()
attr2idx = {}
idx2attr = {}
for i, attr_name in enumerate(all_attr_names):
attr2idx[attr_name] = i
idx2attr[i] = attr_name
lines = lines[2:]
if mode == 'train':
lines = lines[:-20000] # train set contains 182599 images
if mode == 'val':
lines = lines[-20000:-18800] # val set contains 200 images
if mode == 'test':
lines = lines[-18800:] # test set contains 18800 images
items = []
for i, line in enumerate(lines):
split = line.split()
filename = split[0]
values = split[1:]
label = []
for attr_name in selected_attrs:
idx = attr2idx[attr_name]
label.append(np.float32(values[idx] == '1'))
items.append([filename, label])
return items

View File

@ -0,0 +1,70 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset distributed sampler."""
from __future__ import division
import math
import numpy as np
class DistributedSampler:
"""Distributed sampler."""
def __init__(self,
dataset_size,
num_replicas=None,
rank=None,
shuffle=True):
if num_replicas is None:
print(
"***********Setting world_size to 1 since it is not passed in ******************"
)
num_replicas = 1
if rank is None:
print(
"***********Setting rank to 0 since it is not passed in ******************"
)
rank = 0
self.dataset_size = dataset_size
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(
math.ceil(dataset_size * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
indices = np.random.RandomState(seed=self.epoch).permutation(
self.dataset_size)
# np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
indices = indices.tolist()
self.epoch += 1
# change to list type
else:
indices = list(range(self.dataset_size))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples

View File

@ -0,0 +1,16 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" models"""
from .stgan import STGANModel

View File

@ -0,0 +1,256 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Base Model """
import os
import json
from abc import ABC, abstractmethod
import numpy as np
import mindspore.numpy as numpy
import mindspore.common.dtype as mstype
from mindspore import load_checkpoint, load_param_into_net, save_checkpoint
from mindspore.ops import composite as C
from mindspore import Tensor
from PIL import Image
from src.utils import mkdirs
class BaseModel(ABC):
""" BaseModel """
def __init__(self, args):
self.isTrain = args.isTrain
self.save_dir = os.path.join(args.outputs_dir, args.experiment_name)
self.model_names = []
self.loss_names = []
self.args = args
self.optimizers = []
self.current_iteration = 0
self.netG = None
self.netD = None
# continue train
if self.isTrain:
if self.args.continue_train:
assert (os.path.exists(self.save_dir)
), 'Checkpoint path not found at %s' % self.save_dir
self.current_iteration = self.args.continue_iter
else:
if not os.path.exists(self.save_dir):
mkdirs(self.save_dir)
# save config
self.config_save_path = os.path.join(self.save_dir, 'config')
if not os.path.exists(self.config_save_path):
mkdirs(self.config_save_path)
if self.isTrain:
with open(os.path.join(self.config_save_path, 'train.conf'),
'w') as f:
f.write(json.dumps(vars(self.args)))
if self.current_iteration == -1:
with open(os.path.join(self.config_save_path, 'latest.conf'),
'r') as f:
self.current_iteration = int(f.read())
# sample save path
if self.isTrain:
self.sample_save_path = os.path.join(self.save_dir, 'sample')
if not os.path.exists(self.sample_save_path):
mkdirs(self.sample_save_path)
# test result save path
if self.args.phase == 'test':
self.test_result_save_path = os.path.join(self.save_dir, 'test')
if not os.path.exists(self.test_result_save_path):
mkdirs(self.test_result_save_path)
# train log save path
if self.isTrain:
self.train_log_path = os.path.join(self.save_dir, 'logs')
if not os.path.exists(self.train_log_path):
mkdirs(self.train_log_path)
@abstractmethod
def set_input(self, input_data):
pass
@abstractmethod
def optimize_parameters(self):
pass
@abstractmethod
def test(self):
pass
@abstractmethod
def eval(self):
pass
def load_config(self):
print('loading config from {}\n\n'.format(
os.path.join(self.config_save_path, 'train.conf')))
with open(os.path.join(self.config_save_path, 'train.conf'), 'r') as f:
config = json.loads(f.read())
print('config: ', config)
return config
def save_networks(self):
""" saving networks """
for name in self.model_names:
if isinstance(name, str):
save_filename = '%s_%s.ckpt' % (self.current_iteration, name)
save_filename_latest = 'latest_%s.ckpt' % name
save_path = os.path.join(self.save_dir, 'ckpt')
if not os.path.exists(save_path):
os.makedirs(save_path)
save_path_latest = os.path.join(save_path,
save_filename_latest)
save_path = os.path.join(save_path, save_filename)
net = getattr(self, 'net' + name)
print('saving the model to %s' % save_path)
print('saving the model to %s' % save_path_latest)
save_checkpoint(net, save_path)
save_checkpoint(net, save_path_latest)
with open(os.path.join(self.config_save_path, 'latest.conf'),
'w') as f:
f.write("{}".format(self.current_iteration))
def load_networks(self, epoch='latest'):
""" Load model checkpoint file
Parameters:
epoch: epoch saved, default is latest
"""
for name in self.model_names:
if isinstance(name, str):
load_filename = '%s_%s.ckpt' % (epoch, name)
load_path = os.path.join(self.save_dir, 'ckpt', load_filename)
net = getattr(self, 'net' + name)
print('loading the model from %s' % load_path)
if name == 'G':
net.encoder.update_parameters_name(
'network.network.netG.encoder.')
net.decoder.update_parameters_name(
'network.network.netG.decoder.')
net.stu.update_parameters_name('network.network.netG.stu.')
params = load_checkpoint(load_path, net, strict_load=True)
load_param_into_net(net, params, strict_load=True)
if epoch == 'latest':
assert os.path.exists(
os.path.join(self.config_save_path, 'latest.conf')
), 'Missing iteration information of latest checkpoint file.'
with open(os.path.join(self.config_save_path, 'latest.conf'),
'r') as f:
self.current_iteration = f.read()
def load_generator_from_path(self, path=None):
""" Load generator checkpoint file from given path
Parameters:
path: path of checkpoint file, required
"""
assert path is not None, 'Path of checkpoint can not be None'
print('loading the model from %s' % path)
net = getattr(self, 'netG')
net.encoder.update_parameters_name('network.network.netG.encoder.')
net.decoder.update_parameters_name('network.network.netG.decoder.')
net.stu.update_parameters_name('network.network.netG.stu.')
params = load_checkpoint(path, net, strict_load=True)
load_param_into_net(net, params, strict_load=True)
def get_learning_rate(self):
"""Learning rate generator."""
lrs = [self.args.lr] * self.args.dataset_size * self.args.init_epoch
lrs += [self.args.lr * 0.1] * self.args.dataset_size * (
self.args.n_epochs - self.args.init_epoch)
return Tensor(np.array(lrs).astype(
np.float32))[self.current_iteration:]
def save_image(self, img, img_path):
"""Save a numpy image to the disk
Parameters:
img (numpy array / Tensor): image to save.
image_path (str): the path of the image.
"""
if isinstance(img, Tensor):
img = self.decode_image(img)
elif not isinstance(img, np.ndarray):
raise ValueError(
"img should be Tensor or numpy array, but get {}".format(
type(img)))
img_pil = Image.fromarray(img)
img_pil.save(img_path)
def decode_image(self, img):
"""Decode a [1, C, H, W] Tensor to image numpy array."""
mean = 0.5 * 255
std = 0.5 * 255
return (img.asnumpy() * std + mean).astype(np.uint8).transpose(
(1, 2, 0))
def create_labels(self, c_org, selected_attrs=None):
"""Generate target domain labels for debugging and testing."""
hair_color_indices = []
for i, attr_name in enumerate(selected_attrs):
if attr_name in [
'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'
]:
hair_color_indices.append(i)
c_trg_list = []
for i in range(len(selected_attrs)):
c_trg = numpy.copy(c_org)
if i in hair_color_indices:
c_trg[:, i] = Tensor(1, mstype.float32)
for j in hair_color_indices:
if j != i:
c_trg[:, j] = Tensor(0, mstype.float32)
else:
c_trg[:, i] = Tensor((c_trg[:, i] == 0).asnumpy(),
mstype.float32)
c_trg_list.append(c_trg)
return c_trg_list
def create_test_label(self, c_org, selected_attrs=None):
"""Generate target domain labels for debugging and testing."""
hair_color_indices = []
for i, attr_name in enumerate(selected_attrs):
if attr_name in [
'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'
]:
hair_color_indices.append(i)
c_trg = numpy.copy(c_org)
for i in range(len(selected_attrs)):
if i in hair_color_indices:
c_trg[:, i] = 1.
for j in hair_color_indices:
if j != i:
c_trg[:, j] = Tensor(0, mstype.float32)
else:
c_trg[:, i] = Tensor((c_trg[:, i] == 0).asnumpy(),
mstype.float32)
return c_trg
def denorm(self, x):
""" Denormalization """
out = (x + 1) / 2
return C.clip_by_value(out, 0, 1)

View File

@ -0,0 +1,143 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" losses """
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.ops.operations as P
from mindspore import Tensor
from src.utils.tools import is_version_satisfied_1_2_0
class GeneratorLoss(nn.Cell):
""" GeneratorLoss """
def __init__(self, netG, netD, args, mode):
super(GeneratorLoss, self).__init__()
self.args = args
self.platform = args.platform
self.use_stu = args.use_stu
self.mode = mode
self.netG = netG
self.netD = netD
if self.platform == 'Ascend':
self.cls_loss = nn.BCEWithLogitsLoss(reduction='sum')
else:
self.cls_loss = ClassificationLoss()
self.cyc_loss = P.ReduceMean()
self.lambda2 = Tensor(args.lambda2)
self.lambda3 = Tensor(args.lambda3)
self.attr_mode = args.attr_mode
def construct(self, real_x, c_org, c_trg, attr_diff):
""" construct """
fake_x = self.netG(real_x, attr_diff)
rec_x = self.netG(real_x, c_org - c_org)
loss_fake_G, loss_att_fake = self.netD(fake_x)
loss_adv_G = -loss_fake_G.mean()
loss_cls_G = self.cls_loss(loss_att_fake,
c_trg) / loss_att_fake.shape[0]
loss_rec_G = (real_x - rec_x).abs().mean()
loss_G = loss_adv_G + self.lambda2 * loss_cls_G + self.lambda3 * loss_rec_G
return (fake_x, loss_G, loss_fake_G.mean(), loss_cls_G, loss_rec_G,
loss_adv_G)
class DiscriminatorLoss(nn.Cell):
""" DiscriminatorLoss """
def __init__(self, netD, netG, args, mode):
super(DiscriminatorLoss, self).__init__()
self.mode = mode
self.netD = netD
self.netG = netG
self.platform = args.platform
self.gradient_penalty = WGANGPGradientPenalty(netD)
if self.platform == 'Ascend' and is_version_satisfied_1_2_0(
args.ms_version):
self.cls_loss = nn.BCEWithLogitsLoss(reduction='sum')
else:
self.cls_loss = ClassificationLoss()
self.cyc_loss = P.ReduceMean()
self.lambda_gp = Tensor(args.lambda_gp)
self.lambda1 = Tensor(args.lambda1)
self.thres_int = Tensor(args.thres_int, ms.float32)
self.attr_mode = args.attr_mode
def construct(self, real_x, c_org, c_trg, attr_diff, alpha):
""" construct """
loss_real_D, loss_att_real = self.netD(real_x)
loss_real_D = -loss_real_D.mean()
loss_cls_D = self.cls_loss(loss_att_real,
c_org) / loss_att_real.shape[0]
fake_x = self.netG(real_x, attr_diff)
loss_fake_D, _ = self.netD(ops.functional.stop_gradient(fake_x))
loss_fake_D = loss_fake_D.mean()
x_hat = (alpha * real_x + (1 - alpha) * fake_x)
loss_gp_D = self.gradient_penalty(x_hat)
loss_adv_D = loss_real_D + loss_fake_D + self.lambda_gp * loss_gp_D
loss_D = self.lambda1 * loss_cls_D + loss_adv_D
return (loss_D, loss_real_D, loss_fake_D, loss_cls_D, loss_gp_D,
loss_adv_D, attr_diff)
class WGANGPGradientPenalty(nn.Cell):
""" WGANGPGradientPenalty """
def __init__(self, discriminator):
super(WGANGPGradientPenalty, self).__init__()
self.gradient_op = ops.GradOperation()
self.reduce_sum = ops.ReduceSum()
self.reduce_sum_keep_dim = ops.ReduceSum(keep_dims=True)
self.sqrt = ops.Sqrt()
self.discriminator = discriminator
self.gradientWithInput = GradientWithInput(discriminator)
def construct(self, x_hat):
gradient = self.gradient_op(self.gradientWithInput)(x_hat)
gradient_1 = ops.reshape(gradient, (x_hat.shape[0], -1))
gradient_1 = self.sqrt(self.reduce_sum(gradient_1**2, 1))
gradient_penalty = ((gradient_1 - 1.0)**2).mean()
return gradient_penalty
class GradientWithInput(nn.Cell):
def __init__(self, discriminator):
super(GradientWithInput, self).__init__()
self.reduce_sum = ops.ReduceSum()
self.discriminator = discriminator
def construct(self, interpolates):
decision_interpolate, _ = self.discriminator(interpolates)
decision_interpolate = self.reduce_sum(decision_interpolate, 0)
return decision_interpolate
class ClassificationLoss(nn.Cell):
def __init__(self):
super(ClassificationLoss, self).__init__()
self.BCELoss = nn.BCELoss(reduction='sum')
self.sigmoid = nn.Sigmoid()
def construct(self, logit, target):
logit = self.sigmoid(logit)
return self.BCELoss(logit, target)

View File

@ -0,0 +1,455 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" networks """
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.ops.operations as P
import mindspore.ops.functional as F
from mindspore import context
from mindspore.ops import constexpr
from mindspore.context import ParallelMode
from mindspore.common import initializer as init
from mindspore.communication.management import get_group_size
from mindspore.parallel._auto_parallel_context import auto_parallel_context
def init_weights(net, init_type='normal', init_gain=0.02):
"""
Initialize network weights.
Parameters:
net (Cell): Network to be initialized
init_type (str): The name of an initialization method: normal | xavier.
init_gain (float): Gain factor for normal and xavier.
"""
for _, cell in net.cells_and_names():
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
if init_type == 'normal':
cell.weight.set_data(
init.initializer(init.Normal(init_gain),
cell.weight.shape))
elif init_type == 'xavier':
cell.weight.set_data(
init.initializer(init.XavierUniform(init_gain),
cell.weight.shape))
elif init_type == 'KaimingUniform':
cell.weight.set_data(
init.initializer(init.HeUniform(init_gain),
cell.weight.shape))
elif init_type == 'constant':
cell.weight.set_data(init.initializer(0.001,
cell.weight.shape))
else:
raise NotImplementedError(
'initialization method [%s] is not implemented' %
init_type)
elif isinstance(cell, _GroupNorm):
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
class ConvGRUCell(nn.Cell):
""" Convolutional GRU Cell """
def __init__(self, n_attrs, in_dim, out_dim, kernel_size=3, stu_norm='bn'):
super(ConvGRUCell, self).__init__()
self.concat = ops.Concat(axis=1)
self.reshape = ops.Reshape()
self.n_attrs = n_attrs
self.normalization = nn.BatchNorm2d(out_dim)
if stu_norm == 'in':
self.normalization = _GroupNorm(num_groups=out_dim,
num_channels=out_dim)
self.upsample = nn.Conv2dTranspose(in_dim * 2 + n_attrs,
out_dim,
4,
2,
padding=1,
pad_mode='pad')
self.reset_gate = nn.SequentialCell(
nn.Conv2d(in_dim + out_dim,
out_dim,
kernel_size,
1,
padding=((kernel_size - 1) // 2),
pad_mode='pad'), self.normalization, nn.Sigmoid())
self.update_gate = nn.SequentialCell(
nn.Conv2d(in_dim + out_dim,
out_dim,
kernel_size,
1,
padding=((kernel_size - 1) // 2),
pad_mode='pad'), self.normalization, nn.Sigmoid())
self.hidden = nn.SequentialCell(
nn.Conv2d(in_dim + out_dim,
out_dim,
kernel_size,
1,
padding=((kernel_size - 1) // 2),
pad_mode='pad'), self.normalization, nn.Tanh())
def construct(self, input_data, old_state, attr):
""" construct """
n, _, h, w = old_state.shape
attr = self.reshape(attr, (n, self.n_attrs, 1, 1))
tile = ops.Tile()
attr = tile(attr, (1, 1, h, w))
state_hat = self.upsample(self.concat((old_state, attr)))
r = self.reset_gate(self.concat((input_data, state_hat)))
z = self.update_gate(self.concat((input_data, state_hat)))
new_state = r * state_hat
hidden_info = self.hidden(self.concat((input_data, new_state)))
output = (1 - z) * state_hat + z * hidden_info
return output, new_state
class Generator(nn.Cell):
""" Generator """
def __init__(self,
attr_dim,
enc_dim=64,
dec_dim=64,
enc_layers=5,
dec_layers=5,
shortcut_layers=2,
stu_kernel_size=3,
use_stu=True,
one_more_conv=True,
stu_norm='bn'):
super(Generator, self).__init__()
self.n_attrs = attr_dim
self.enc_layers = enc_layers
self.dec_layers = dec_layers
self.shortcut_layers = min(shortcut_layers, enc_layers - 1,
dec_layers - 1)
self.use_stu = use_stu
self.concat = ops.Concat(axis=1)
self.encoder = nn.CellList()
in_channels = 3
for i in range(self.enc_layers):
self.encoder.append(
nn.SequentialCell(
nn.Conv2d(in_channels,
enc_dim * 2**i,
4,
2,
padding=1,
pad_mode='pad'), nn.BatchNorm2d(enc_dim * 2**i),
nn.LeakyReLU(alpha=0.2)))
in_channels = enc_dim * 2**i
# Selective Transfer Unit (STU)
if self.use_stu:
self.stu = nn.CellList()
for i in reversed(
range(self.enc_layers - 1 - self.shortcut_layers,
self.enc_layers - 1)):
self.stu.append(
ConvGRUCell(self.n_attrs, enc_dim * 2**i, enc_dim * 2**i,
stu_kernel_size, stu_norm))
self.decoder = nn.CellList()
for i in range(self.dec_layers):
if i < self.dec_layers - 1:
if i == 0:
self.decoder.append(
nn.SequentialCell(
nn.Conv2dTranspose(
dec_dim * 2**(self.dec_layers - 1) + attr_dim,
dec_dim * 2**(self.dec_layers - 1),
4,
2,
padding=1,
pad_mode='pad'), nn.BatchNorm2d(in_channels),
nn.ReLU()))
elif i <= self.shortcut_layers:
self.decoder.append(
nn.SequentialCell(
nn.Conv2dTranspose(
dec_dim * 3 * 2**(self.dec_layers - 1 - i),
dec_dim * 2**(self.dec_layers - 1 - i),
4,
2,
padding=1,
pad_mode='pad'),
nn.BatchNorm2d(dec_dim *
2**(self.dec_layers - 1 - i)),
nn.ReLU()))
else:
self.decoder.append(
nn.SequentialCell(
nn.Conv2dTranspose(
dec_dim * 2**(self.dec_layers - i),
dec_dim * 2**(self.dec_layers - 1 - i),
4,
2,
padding=1,
pad_mode='pad'),
nn.BatchNorm2d(dec_dim *
2**(self.dec_layers - 1 - i)),
nn.ReLU()))
else:
in_dim = dec_dim * 3 if self.shortcut_layers == self.dec_layers - 1 else dec_dim * 2
if one_more_conv:
self.decoder.append(
nn.SequentialCell(
nn.Conv2dTranspose(in_dim,
dec_dim // 4,
4,
2,
padding=1,
pad_mode='pad'),
nn.BatchNorm2d(dec_dim // 4), nn.ReLU(),
nn.Conv2dTranspose(dec_dim // 4,
3,
3,
1,
padding=1,
pad_mode='pad'), nn.Tanh()))
else:
self.decoder.append(
nn.SequentialCell(
nn.Conv2dTranspose(in_dim,
3,
4,
2,
padding=1,
pad_mode='pad'), nn.Tanh()))
def construct(self, x, a):
""" construct """
# propagate encoder layers
y = []
x_ = x
for layer in self.encoder:
x_ = layer(x_)
y.append(x_)
out = y[-1]
reshape = ops.Reshape()
(n, _, h, w) = out.shape
attr = reshape(a, (n, self.n_attrs, 1, 1))
tile = ops.Tile()
attr = tile(attr, (1, 1, h, w))
out = self.decoder[0](self.concat((out, attr)))
stu_state = y[-1]
# propagate shortcut layers
for i in range(1, self.shortcut_layers + 1):
if self.use_stu:
stu_out, stu_state = self.stu[i - 1](y[-(i + 1)], stu_state, a)
out = self.concat((out, stu_out))
out = self.decoder[i](out)
else:
out = self.concat((out, y[-(i + 1)]))
out = self.decoder[i](out)
# propagate non-shortcut layers
for i in range(self.shortcut_layers + 1, self.dec_layers):
out = self.decoder[i](out)
return out
class Discriminator(nn.Cell):
""" Discriminator Cell """
def __init__(self,
image_size=128,
attr_dim=10,
conv_dim=64,
fc_dim=1024,
n_layers=5):
super(Discriminator, self).__init__()
layers = []
in_channels = 3
for i in range(n_layers):
layers.append(
nn.SequentialCell(
nn.Conv2d(in_channels,
conv_dim * 2**i,
4,
2,
padding=1,
pad_mode='pad'),
_GroupNorm(num_groups=conv_dim * 2**i,
num_channels=conv_dim * 2**i),
nn.LeakyReLU(alpha=0.2)))
in_channels = conv_dim * 2**i
self.conv = nn.SequentialCell(*layers)
feature_size = image_size // 2**n_layers
self.fc_adv = nn.SequentialCell(
nn.Flatten(),
nn.Dense(conv_dim * 2**(n_layers - 1) * feature_size**2, fc_dim),
nn.LeakyReLU(alpha=0.2), nn.Flatten(), nn.Dense(fc_dim, 1))
self.fc_att = nn.SequentialCell(
nn.Flatten(),
nn.Dense(conv_dim * 2**(n_layers - 1) * feature_size**2, fc_dim),
nn.LeakyReLU(alpha=0.2),
nn.Flatten(),
nn.Dense(fc_dim, attr_dim),
)
def construct(self, x):
y = self.conv(x)
reshape = ops.Reshape()
y = reshape(y, (y.shape[0], -1))
logit_adv = self.fc_adv(y)
logit_att = self.fc_att(y)
return logit_adv, logit_att
class _GroupNorm(nn.GroupNorm):
""" Rewrite of original GroupNorm """
def __init__(self,
num_groups,
num_channels,
eps=1e-05,
affine=True,
gamma_init='ones',
beta_init='zeros'):
super().__init__(num_groups,
num_channels,
eps=1e-05,
affine=True,
gamma_init='ones',
beta_init='zeros')
self.pow = ops.Pow()
def _cal_output(self, x):
"""calculate groupnorm output"""
batch, channel, height, width = self.shape(x)
_channel_check(channel, self.num_channels)
x = self.reshape(x, (batch, self.num_groups, -1))
mean = self.reduce_mean(x, 2)
var = self.reduce_sum(self.square(x - mean),
2) / (channel * height * width / self.num_groups)
std = self.pow((var + self.eps), 0.5)
x = (x - mean) / std
x = self.reshape(x, (batch, channel, height, width))
output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(
self.beta, (-1, 1, 1))
return output
@constexpr
def _channel_check(channel, num_channel):
if channel != num_channel:
raise ValueError("the input channel is not equal with num_channel")
class GeneratorWithLossCell(nn.Cell):
""" GeneratorWithLossCell """
def __init__(self, network, args):
super(GeneratorWithLossCell, self).__init__(auto_prefix=False)
self.network = network
self.lambda2 = args.lambda2
self.lambda3 = args.lambda3
def construct(self, x_real, c_org, c_trg, attr_diff):
_, _, _, loss_cls_G, loss_rec_G, loss_adv_G = self.network(
x_real, c_org, c_trg, attr_diff)
return loss_adv_G + self.lambda2 * loss_cls_G + self.lambda3 * loss_rec_G
class DiscriminatorWithLossCell(nn.Cell):
def __init__(self, network):
super(DiscriminatorWithLossCell, self).__init__(auto_prefix=False)
self.network = network
def construct(self, x_real, c_org, c_trg, attr_diff, alpha):
loss_D, _, _, _, _, _, _ = self.network(x_real, c_org, c_trg,
attr_diff, alpha)
return loss_D
class TrainOneStepGenerator(nn.Cell):
""" Training class of Generator """
def __init__(self, loss_G_model, optimizer, args):
super(TrainOneStepGenerator, self).__init__()
self.optimizer = optimizer
self.loss_G_model = loss_G_model
self.loss_G_model.set_grad()
self.loss_G_model.set_train()
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
self.weights = optimizer.parameters
self.network = GeneratorWithLossCell(loss_G_model, args)
self.network.add_flags(defer_inline=True)
self.grad_reducer = F.identity
self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [
ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL,
ParallelMode.AUTO_PARALLEL
]:
mean = context.get_auto_parallel_context("gradients_mean")
if auto_parallel_context().get_device_num_is_set():
degree = context.get_auto_parallel_context("device_num")
else:
degree = get_group_size()
self.grad_reducer = nn.DistributedGradReducer(
self.weights, mean, degree)
def construct(self, real_x, c_org, c_trg, attr_diff, sens=1.0):
fake_x, loss_G, loss_fake_G, loss_cls_G, loss_rec_G, loss_adv_G =\
self.loss_G_model(real_x, c_org, c_trg, attr_diff)
sens = P.Fill()(P.DType()(loss_G), P.Shape()(loss_G), sens)
grads = self.grad(self.network, self.weights)(real_x, c_org, c_trg,
attr_diff, sens)
grads = self.grad_reducer(grads)
return (ops.depend(loss_G, self.optimizer(grads)), fake_x, loss_G,
loss_fake_G, loss_cls_G, loss_rec_G, loss_adv_G)
class TrainOneStepDiscriminator(nn.Cell):
""" Training class of Discriminator """
def __init__(self, loss_D_model, optimizer):
super(TrainOneStepDiscriminator, self).__init__()
self.optimizer = optimizer
self.loss_D_model = loss_D_model
self.loss_D_model.set_grad()
self.loss_D_model.set_train()
self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
self.weights = optimizer.parameters
self.network = DiscriminatorWithLossCell(loss_D_model)
self.network.add_flags(defer_inline=True)
self.grad_reducer = F.identity
self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [
ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL,
ParallelMode.AUTO_PARALLEL
]:
mean = context.get_auto_parallel_context("gradients_mean")
if auto_parallel_context().get_device_num_is_set():
degree = context.get_auto_parallel_context("device_num")
else:
degree = get_group_size()
self.grad_reducer = nn.DistributedGradReducer(
self.weights, mean, degree)
def construct(self, real_x, c_org, c_trg, attr_diff, alpha, sens=1.0):
loss_D, loss_real_D, loss_fake_D, loss_cls_D, loss_gp_D, loss_adv_D, attr_diff =\
self.loss_D_model(real_x, c_org, c_trg, attr_diff, alpha)
sens = P.Fill()(P.DType()(loss_D), P.Shape()(loss_D), sens)
grads = self.grad(self.network, self.weights)(real_x, c_org, c_trg,
attr_diff, alpha, sens)
grads = self.grad_reducer(grads)
return (ops.depend(loss_D, self.optimizer(grads)), loss_D, loss_real_D,
loss_fake_D, loss_cls_D, loss_gp_D, loss_adv_D, attr_diff)

View File

@ -0,0 +1,189 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" STGAN Models """
import os
import math
import random
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.dtype as mstype
from mindspore import Tensor
import numpy as np
from .networks import init_weights, Discriminator, Generator, TrainOneStepGenerator, TrainOneStepDiscriminator
from .losses import GeneratorLoss, DiscriminatorLoss
from .base_model import BaseModel
class STGANModel(BaseModel):
""" STGANModel """
def __init__(self, args):
""" """
BaseModel.__init__(self, args)
self.rand_int = ops.UniformInt()
self.concat = ops.operations.Concat(axis=1)
self.use_stu = args.use_stu
self.args = args
self.n_attrs = len(args.attrs)
self.mode = args.mode
self.loss_names = [
'D', 'G', 'adv_D', 'cls_D', 'real_D', 'fake_D', 'gp_D', 'adv_G',
'cls_G', 'rec_G'
]
if self.isTrain:
self.model_names = ['G', 'D']
else:
self.model_names = ['G']
self.netG = Generator(self.n_attrs,
args.enc_dim,
args.dec_dim,
args.enc_layers,
args.dec_layers,
args.shortcut_layers,
args.stu_kernel_size,
use_stu=args.use_stu,
one_more_conv=args.one_more_conv,
stu_norm=args.stu_norm)
print('Generator: ', self.netG)
if self.isTrain:
self.netD = Discriminator(args.image_size, self.n_attrs,
args.dis_dim, args.dis_fc_dim,
args.dis_layers)
print('Discriminator: ', self.netD)
if self.args.continue_train:
continue_iter = self.args.continue_iter if self.args.continue_iter != -1 else 'latest'
self.load_networks(continue_iter)
if self.netG is not None:
num_params = 0
for p in self.netG.trainable_params():
num_params += np.prod(p.shape)
print(
'\n\n\nGenerator trainable parameters: {}'.format(num_params))
if self.netD is not None:
num_params = 0
for p in self.netD.trainable_params():
num_params += np.prod(p.shape)
print('Discriminator trainable parameters: {}\n\n\n'.format(
num_params))
if self.isTrain:
if not self.args.continue_train:
init_weights(self.netG, 'KaimingUniform', math.sqrt(5))
init_weights(self.netD, 'KaimingUniform', math.sqrt(5))
self.loss_D_model = DiscriminatorLoss(self.netD, self.netG,
self.args, self.mode)
self.loss_G_model = GeneratorLoss(self.netG, self.netD, self.args,
self.mode)
self.optimizer_G = nn.Adam(self.netG.trainable_params(),
self.get_learning_rate(),
beta1=args.beta1,
beta2=args.beta2)
self.optimizer_D = nn.Adam(self.netD.trainable_params(),
self.get_learning_rate(),
beta1=args.beta1,
beta2=args.beta2)
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
self.train_G = TrainOneStepGenerator(self.loss_G_model,
self.optimizer_G, self.args)
self.train_D = TrainOneStepDiscriminator(self.loss_D_model,
self.optimizer_D)
self.train_G.set_train()
self.train_D.set_train()
if self.args.phase == 'test':
if self.args.ckpt_path is not None:
self.load_generator_from_path(self.args.ckpt_path)
else:
self.load_networks('latest' if args.test == -1 else args.test)
def set_input(self, input_data):
self.label_org = Tensor(input_data['label'], mstype.float32)
(shape_0, _) = self.label_org.shape
rand_idx = random.sample(range(0, shape_0), shape_0)
self.label_trg = self.label_org[rand_idx]
self.real_x = Tensor(input_data['image'], mstype.float32)
def test(self, data, filename=''):
""" Test Function """
data_image = data['image']
data_label = data['label']
c_trg = self.create_test_label(data_label, self.args.attrs)
attr_diff = c_trg - data_label if self.args.attr_mode == 'diff' else c_trg
attr_diff = attr_diff * self.args.thres_int
fake_x = self.netG(data_image, attr_diff)
self.save_image(
fake_x[0],
os.path.join(self.test_result_save_path, '{}'.format(filename)))
def optimize_parameters(self):
""" Optimizing Model's Trainable Parameters
"""
attr_diff = self.label_trg - self.label_org if self.args.attr_mode == 'diff' else self.label_trg
(h, w) = attr_diff.shape
rand_attr = Tensor(np.random.rand(h, w), mstype.float32)
attr_diff = attr_diff * rand_attr * (2 * self.args.thres_int)
alpha = Tensor(np.random.randn(self.real_x.shape[0], 1, 1, 1),
mstype.float32)
# train D
_, loss_D, loss_real_D, loss_fake_D, loss_cls_D, loss_gp_D, loss_adv_D, attr_diff =\
self.train_D(self.real_x, self.label_org, self.label_trg, attr_diff, alpha)
if self.current_iteration % self.args.n_critic == 0:
# train G
_, _, loss_G, loss_fake_G, loss_cls_G, loss_rec_G, loss_adv_G =\
self.train_G(self.real_x, self.label_org, self.label_trg, attr_diff)
# saving losses
if (self.current_iteration / 5) % self.args.print_freq == 0:
with open(os.path.join(self.train_log_path, 'loss.log'),
'a+') as f:
f.write('Iter: %s\n' % self.current_iteration)
f.write(
'loss D: %s, loss D_real: %s, loss D_fake: %s, loss D_gp: %s, loss D_adv: %s, loss D_cls: %s \n'
% (loss_D, loss_real_D, loss_fake_D, loss_gp_D,
loss_adv_D, loss_cls_D))
f.write(
'loss G: %s, loss G_rec: %s, loss G_fake: %s, loss G_adv: %s, loss G_cls: %s \n\n'
% (loss_G, loss_rec_G, loss_fake_G, loss_adv_G,
loss_cls_G))
def eval(self, data_loader):
""" Eval function of STGAN
"""
val_loader = data_loader.val_loader
concat_3d = ops.Concat(axis=3)
concat_1d = ops.Concat(axis=1)
data = next(val_loader)
data_image = data['image']
data_label = data['label']
sample_list = self.create_labels(data_label, self.args.attrs)
sample_list.insert(0, data_label)
x_concat = data_image
for c_trg_sample in sample_list:
attr_diff = c_trg_sample - data_label if self.args.attr_mode == 'diff' else c_trg_sample
attr_diff = attr_diff * self.args.thres_int
fake_x = self.netG(data_image, attr_diff)
x_concat = concat_3d((x_concat, fake_x))
sample_result = x_concat[0]
for i in range(1, x_concat.shape[0]):
sample_result = concat_1d((sample_result, x_concat[i]))
self.save_image(
sample_result,
os.path.join(self.sample_save_path,
'samples_{}.jpg'.format(self.current_iteration + 1)))

View File

@ -0,0 +1,17 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" utils """
from .args import get_args
from .tools import mkdirs

View File

@ -0,0 +1,296 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""arguments"""
import os
import argparse
import ast
import datetime
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.communication.management import init
def add_basic_parameters(parser):
""" add basic parameters """
parser.add_argument("--platform",
type=str,
default="Ascend",
choices=("Ascend", "GPU", "CPU"),
help="running platform, support Ascend, GPU and CPU")
parser.add_argument("--device_id",
type=int,
default=0,
help="device id, default is 0")
parser.add_argument('--device_num',
type=int,
default=1,
help='device num, default is 1.')
parser.add_argument('--ms_version',
type=str,
default='1.2.0',
help="Mindspore's Version, default is 1.2.0")
return parser
def add_model_parameters(parser):
""" add model parameters """
att_dict = {
'5_o_Clock_Shadow': 0,
'Arched_Eyebrows': 1,
'Attractive': 2,
'Bags_Under_Eyes': 3,
'Bald': 4,
'Bangs': 5,
'Big_Lips': 6,
'Big_Nose': 7,
'Black_Hair': 8,
'Blond_Hair': 9,
'Blurry': 10,
'Brown_Hair': 11,
'Bushy_Eyebrows': 12,
'Chubby': 13,
'Double_Chin': 14,
'Eyeglasses': 15,
'Goatee': 16,
'Gray_Hair': 17,
'Heavy_Makeup': 18,
'High_Cheekbones': 19,
'Male': 20,
'Mouth_Slightly_Open': 21,
'Mustache': 22,
'Narrow_Eyes': 23,
'No_Beard': 24,
'Oval_Face': 25,
'Pale_Skin': 26,
'Pointy_Nose': 27,
'Receding_Hairline': 28,
'Rosy_Cheeks': 29,
'Sideburns': 30,
'Smiling': 31,
'Straight_Hair': 32,
'Wavy_Hair': 33,
'Wearing_Earrings': 34,
'Wearing_Hat': 35,
'Wearing_Lipstick': 36,
'Wearing_Necklace': 37,
'Wearing_Necktie': 38,
'Young': 39
}
attr_default = ['Bangs', 'Blond_Hair', 'Mustache', 'Young']
parser.add_argument("--attrs",
default=attr_default,
choices=att_dict,
nargs='+',
help='Attributes to modify by the model')
parser.add_argument('--image_size',
type=int,
default=128,
help='input image size')
parser.add_argument(
'--shortcut_layers',
type=int,
default=3,
help='# of skip connections between the encoder and the decoder')
parser.add_argument('--enc_dim', type=int, default=64)
parser.add_argument('--dec_dim', type=int, default=64)
parser.add_argument('--dis_dim', type=int, default=64)
parser.add_argument('--dis_fc_dim',
type=int,
default=1024,
help='# of discriminator fc channels')
parser.add_argument('--enc_layers', type=int, default=5)
parser.add_argument('--dec_layers', type=int, default=5)
parser.add_argument('--dis_layers', type=int, default=5)
# STGAN & STU
parser.add_argument('--attr_mode',
type=str,
default='diff',
choices=['diff', 'target'])
parser.add_argument('--use_stu', type=bool, default=True)
parser.add_argument('--stu_dim', type=int, default=64)
parser.add_argument('--stu_kernel_size', type=int, default=3)
parser.add_argument('--stu_norm',
type=str,
default='bn',
choices=['bn', 'in'])
parser.add_argument(
'--stu_state',
type=str,
default='stu',
choices=['stu', 'gru', 'direct'],
help=
'gru: gru arch.; stu: stu arch.; direct: directly pass the inner state to the outer layer'
)
parser.add_argument(
'--multi_inputs',
type=int,
default=1,
help='# of hierarchical inputs (in the first several encoder layers')
parser.add_argument(
'--one_more_conv',
type=int,
default=1,
choices=[0, 1, 3],
help='0: no further conv after the decoder; 1: conv(k=1); 3: conv(k=3)'
)
return parser
def add_train_parameters(parser):
""" add train parameters """
parser.add_argument('--mode',
default='wgan',
choices=['wgan', 'lsgan', 'dcgan'])
parser.add_argument('--continue_train',
type=bool,
default=False,
help='Flag of continue train, default is false')
parser.add_argument(
'--continue_iter',
type=int,
default=-1,
help='Continue point of continue training, -1 means latest')
parser.add_argument('--test_iter',
type=int,
default=-1,
help='Checkpoint of model testing, -1 means latest')
parser.add_argument('--n_epochs',
type=int,
default=100,
help='# of epochs')
parser.add_argument('--n_critic',
type=int,
default=5,
help='number of D updates per each G update')
parser.add_argument('--max_epoch',
type=int,
default=100,
help='# of epochs')
parser.add_argument('--init_epoch',
type=int,
default=50,
help='# of epochs with init lr.')
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument("--beta1",
type=float,
default=0.5,
help="Adam beta1, default is 0.5.")
parser.add_argument("--beta2",
type=float,
default=0.999,
help="Adam beta2, default is 0.999.")
parser.add_argument("--lambda_gp",
type=int,
default=10,
help="Lambda gp, default is 10")
parser.add_argument("--lambda1",
type=int,
default=1,
help="Lambda1, default is 1")
parser.add_argument("--lambda2",
type=int,
default=10,
help="Lambda2, default is 10")
parser.add_argument("--lambda3",
type=int,
default=100,
help="Lambda3, default is 100")
parser.add_argument('--lr',
type=float,
default=0.0002,
help='learning rate')
parser.add_argument('--thres_int', type=float, default=0.5)
parser.add_argument('--test_int', type=float, default=1.0)
parser.add_argument('--n_sample',
type=int,
default=64,
help='# of sample images')
parser.add_argument('--print_freq',
type=int,
default=1,
help='print log freq (per critic), default is 1')
parser.add_argument(
'--save_freq',
type=int,
default=5000,
help='save model evary save_freq iters, 0 means to save evary epoch.')
parser.add_argument(
'--sample_freq',
type=int,
default=1000,
help=
'eval on validation set every sample_freq iters, 0 means to save evary epoch.'
)
return parser
def get_args(phase):
"""get args"""
parser = argparse.ArgumentParser(description="STGAN")
# basic parameters
parser = add_basic_parameters(parser)
#model parameters
parser = add_model_parameters(parser)
# training
parser = add_train_parameters(parser)
# others
parser.add_argument('--use_cropped_img', action='store_true')
default_experiment_name = datetime.datetime.now().strftime(
"%Y.%m.%d-%H%M%S")
parser.add_argument('--experiment_name', default=default_experiment_name)
parser.add_argument('--num_ckpt', type=int, default=1)
parser.add_argument('--clear', default=False, action='store_true')
parser.add_argument('--save_graphs', type=ast.literal_eval, default=False, \
help='whether save graphs, default is False.')
parser.add_argument('--outputs_dir', type=str, default='./outputs', \
help='models are saved here, default is ./outputs.')
parser.add_argument("--dataroot", type=str, default='./dataset')
parser.add_argument('--file_format', type=str, choices=['AIR', 'ONNX', 'MINDIR'], default='AIR', \
help='file format')
parser.add_argument('--file_name', type=str, default='STGAN', help='output file name prefix.')
parser.add_argument('--ckpt_path', default=None, help='path of checkpoint file.')
args = parser.parse_args()
if phase == 'test':
assert args.experiment_name != default_experiment_name, "--experiment_name should be assigned in test mode"
if args.continue_train:
assert args.experiment_name != default_experiment_name, "--experiment_name should be assigned in continue"
if args.device_num > 1 and args.platform != "CPU":
context.set_context(mode=context.GRAPH_MODE,
device_target=args.platform,
save_graphs=args.save_graphs,
device_id=int(os.environ["DEVICE_ID"]))
context.reset_auto_parallel_context()
context.set_auto_parallel_context(
parallel_mode=ParallelMode.AUTO_PARALLEL,
gradients_mean=True,
device_num=args.device_num)
init()
args.rank = int(os.environ["DEVICE_ID"])
else:
context.set_context(mode=context.GRAPH_MODE,
device_target=args.platform,
save_graphs=args.save_graphs,
device_id=args.device_id)
args.rank = 0
args.device_num = 1
args.n_epochs = min(args.max_epoch, args.n_epochs)
args.n_epochs_decay = args.max_epoch - args.n_epochs
if phase == 'train':
args.isTrain = True
else:
args.isTrain = False
args.phase = phase
return args

View File

@ -0,0 +1,50 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils"""
import os
def mkdirs(paths):
"""create empty directories if they don't exist
Parameters:
paths (str list) -- a list of directory paths
"""
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
"""create a single empty directory if it didn't exist
Parameters:
path (str) -- a single directory path
"""
if not os.path.exists(path):
os.makedirs(path)
def is_version_satisfied_1_2_0(version):
old_arr = '1.2.0'.split('.')
new_arr = version.split('.')
assert len(old_arr) == 3 and len(
new_arr) == 3, 'version input must str like x.x.x'
for i in range(3):
if new_arr[i] < old_arr[i]:
return False
return True

View File

@ -0,0 +1,78 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" STGAN TRAIN"""
import tqdm
from mindspore.common import set_seed
from src.models import STGANModel
from src.utils import get_args
from src.dataset import CelebADataLoader
set_seed(1)
def train():
"""Train Function"""
args = get_args("train")
print(args)
print('\n\n=============== start training ===============\n\n')
# Get DataLoader
data_loader = CelebADataLoader(args.dataroot,
mode=args.phase,
selected_attrs=args.attrs,
batch_size=args.batch_size,
image_size=args.image_size,
device_num=args.device_num)
iter_per_epoch = len(data_loader)
args.dataset_size = iter_per_epoch
# Get STGAN MODEL
model = STGANModel(args)
it_count = 0
for _ in tqdm.trange(args.n_epochs, desc='Epoch Loop'):
for _ in tqdm.trange(iter_per_epoch, desc='Inner Epoch Loop'):
if model.current_iteration > it_count:
it_count += 1
continue
try:
# training model
data = next(data_loader.train_loader)
model.set_input(data)
model.optimize_parameters()
# saving model
if (it_count + 1) % args.save_freq == 0:
model.save_networks()
# sampling
if (it_count + 1) % args.sample_freq == 0:
model.eval(data_loader)
except KeyboardInterrupt:
logger.info('You have entered CTRL+C.. Wait to finalize')
model.save_networks()
it_count += 1
model.current_iteration = it_count
model.save_networks()
print('\n\n=============== finish training ===============\n\n')
if __name__ == '__main__':
train()