forked from mindspore-Ecosystem/mindspore
!17363 STGAN-MindSpore
Merge pull request !17363 from Mingchung-Suen/STGAN
Commit: c055bc9b99

@@ -0,0 +1,203 @@ README.md
# Contents

- [Contents](#contents)
- [STGAN Description](#stgan-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
## [STGAN Description](#contents)

STGAN, proposed at CVPR 2019, is a facial attribute editing network built on Generative Adversarial Networks (GANs). It introduces a new Selective Transfer Unit (STU) that achieves better facial attribute transfer than earlier approaches.

[Paper](https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_STGAN_A_Unified_Selective_Transfer_Network_for_Arbitrary_Image_Attribute_CVPR_2019_paper.pdf): Liu M, Ding Y, Xia M, et al. STGAN: A Unified Selective Transfer Network for Arbitrary Image Attribute Editing[C]. IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2019: 3668-3677.

## [Model Architecture](#contents)

STGAN consists of a Generator, a Discriminator and Selective Transfer Units (STUs). The STUs sit on the skip connections between encoder and decoder and selectively pass source features downward, which helps the network preserve the attributes that should remain unchanged while editing the selected ones.
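At its core, each STU is a GRU-style gate applied to an encoder feature map. Below is a minimal, illustrative sketch of the update implemented by `ConvGRUCell` in `src/models/networks.py`; the function name `stu_step` and its callable arguments are stand-ins for the conv blocks defined in that cell, not names from the repository:

```python
# Illustrative sketch of the Selective Transfer Unit update, following
# ConvGRUCell.construct in src/models/networks.py. `upsample`, `reset_gate`,
# `update_gate` and `hidden` stand for the conv blocks defined in that cell.
def stu_step(x, old_state, attr, upsample, reset_gate, update_gate, hidden, concat):
    state_hat = upsample(concat((old_state, attr)))  # inject target attributes
    r = reset_gate(concat((x, state_hat)))           # what to forget
    z = update_gate(concat((x, state_hat)))          # how much to transfer
    new_state = r * state_hat                        # gated hidden state
    h = hidden(concat((x, new_state)))               # candidate features
    out = (1 - z) * state_hat + z * h                # selective transfer
    return out, new_state
```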

## [Dataset](#contents)

In the following sections, we will introduce how to run the scripts using the dataset below.

Dataset used: [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)

- Dataset size: 1011 MB, 202,599 color images of size 128×128, each labeled with 40 attributes
    - Train: 182,599 images
    - Test: 18,800 images
- Data format: binary files
    - Note: Data will be processed in celeba.py
- Download the dataset; the directory structure is as follows:

```bash
├── dataroot
    ├── anno
        ├── list_attr_celeba.txt
    ├── image
        ├── 000001.jpg
        ├── ...
```
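The attribute annotation file follows the standard CelebA layout: the first line gives the number of images, the second line lists the 40 attribute names, and each following line gives a filename plus 40 `1`/`-1` flags. A sketch of how it is parsed, simplified from `make_dataset` in `src/dataset/datasets.py` (`read_anno` is a hypothetical helper name used only for illustration):

```python
# Simplified from make_dataset in src/dataset/datasets.py: parse the CelebA
# annotation file and keep the labels of the selected attributes only.
import os

def read_anno(dataroot, selected_attrs):
    with open(os.path.join(dataroot, 'anno', 'list_attr_celeba.txt')) as f:
        lines = [line.rstrip() for line in f]
    attr_names = lines[1].split()           # line 0: image count, line 1: attribute names
    attr2idx = {name: i for i, name in enumerate(attr_names)}
    items = []
    for line in lines[2:]:                  # one line per image: filename then 40 flags
        parts = line.split()
        filename, values = parts[0], parts[1:]
        label = [float(values[attr2idx[a]] == '1') for a in selected_attrs]
        items.append((filename, label))
    return items
```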

## [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare a hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

## [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

```bash
# enter the script dir and train STGAN
sh scripts/run_standalone_train.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID]
# distributed training
sh scripts/run_distribute_train.sh [RANK_TABLE_FILE] [EXPERIMENT_NAME] [DATA_PATH]
# enter the script dir and evaluate STGAN
sh scripts/run_eval.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID] [CHECKPOINT_PATH]
```

## [Script Description](#contents)

### [Script and Sample Code](#contents)

```bash
├── cv
    ├── STGAN
        ├── README.md                       // descriptions about STGAN
        ├── requirements.txt                // packages needed
        ├── scripts
        │   ├── run_standalone_train.sh     // training on Ascend
        │   ├── run_eval.sh                 // evaluation on Ascend
        │   ├── run_distribute_train.sh     // distributed training on Ascend
        ├── src
            ├── dataset
                ├── datasets.py             // creating dataset
                ├── celeba.py               // processing the CelebA dataset
                ├── distributed_sampler.py  // distributed sampler
            ├── models
                ├── base_model.py           // base class of the models
                ├── losses.py               // loss cells
                ├── networks.py             // basic networks of STGAN
                ├── stgan.py                // executing procedure
            ├── utils
                ├── args.py                 // argument parser
                ├── tools.py                // simple tools
        ├── train.py                        // training script
        ├── eval.py                         // evaluation script
        ├── export.py                       // model-export script
```

### [Script Parameters](#contents)

```python
Major parameters in train.py and utils/args.py are as follows:

--dataroot: The relative path from the current directory to the train and evaluation datasets.
--n_epochs: Total number of training epochs.
--batch_size: Training batch size.
--image_size: Image size used as input to the model.
--device_target: Device where the code will be run. The only supported value is "Ascend".
```
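For example, a standalone training run with these parameters made explicit might look like `python train.py --dataroot ./dataset --experiment_name 128 --n_epochs 100 --batch_size 128 --image_size 128` (the values shown here are illustrative; they mirror the defaults reported in the performance table below).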

### [Training Process](#contents)

#### Training

- running on Ascend

```bash
python train.py --dataroot ./dataset --experiment_name 128 > log 2>&1 &
# or enter the script dir and run the script
sh scripts/run_standalone_train.sh ./dataset 128 0
# distributed training
sh scripts/run_distribute_train.sh ./config/rank_table_8pcs.json 128 /data/dataset
```

After training, loss values such as the following will appear in the log file:

```bash
# grep "loss is " log
epoch: 1 step: 1, loss is 2.2791853
...
epoch: 1 step: 1536, loss is 1.9366643
epoch: 1 step: 1537, loss is 1.6983616
epoch: 1 step: 1538, loss is 1.0221305
...
```

The model checkpoints will be saved in the output directory.

### [Evaluation Process](#contents)

#### Evaluation

Before running the command below, please check the checkpoint path used for evaluation.

- running on Ascend

```bash
python eval.py --dataroot ./dataset --experiment_name 128 > eval_log.txt 2>&1 &
# or enter the script dir and run the script
sh scripts/run_eval.sh ./dataset 128 0 ./ckpt/generator.ckpt
```

You can view the results in the output directory, which contains a batch of result sample images.

### Model Export

```shell
python export.py --ckpt_path [CHECKPOINT_PATH] --platform [PLATFORM] --file_format [EXPORT_FORMAT]
```

`EXPORT_FORMAT` should be "MINDIR".
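For example (the checkpoint path here is illustrative): `python export.py --ckpt_path ./outputs/128/ckpt/latest_G.ckpt --platform Ascend --file_format MINDIR`. The `latest_G.ckpt` naming follows the checkpoint-saving convention in `src/models/base_model.py`.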

## [Model Description](#contents)

### Performance

#### Evaluation Performance

| Parameters                 | Ascend                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | V1                                                           |
| Resource                   | Ascend 910; CPU 2.60 GHz, 192 cores; memory, 755 GB          |
| Uploaded Date              | 05/07/2021 (month/day/year)                                  |
| MindSpore Version          | 1.2.0                                                        |
| Dataset                    | CelebA                                                       |
| Training Parameters        | epoch=100, batch_size=128                                    |
| Optimizer                  | Adam                                                         |
| Loss Function              | Adversarial loss + attribute classification loss + reconstruction loss |
| Output                     | Edited face images                                           |
| Loss                       | 6.5523                                                       |
| Speed                      | 1 pc: 400 ms/step; 8 pcs: 143 ms/step                        |
| Total time                 | 1 pc: 41:36:07                                               |
| Checkpoint for Fine-tuning | 170.55 MB (.ckpt file)                                       |
| Scripts                    | [STGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/STGAN) |
## [Description of Random Situation](#contents)

In `src/dataset/datasets.py` we fix Python's random seed at module level (`random.seed(1)`), and `eval.py` calls `set_seed(1)` for MindSpore.
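Concretely, the fixed seeds in this commit are the following (copied from `eval.py` and `src/dataset/datasets.py`; the distributed sampler additionally shuffles with `np.random.RandomState(seed=epoch)`, so shuffling is deterministic per epoch):

```python
# Seeds fixed in this repository:
from mindspore.common import set_seed  # used in eval.py
import random                          # used in src/dataset/datasets.py

set_seed(1)     # fixes MindSpore's global random seed
random.seed(1)  # fixes Python's random module
```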

## [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@@ -0,0 +1,47 @@ eval.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Model Test """
import tqdm

from mindspore.common import set_seed

from src.models import STGANModel
from src.utils import get_args
from src.dataset import CelebADataLoader

set_seed(1)


def test():
    """ test function """
    args = get_args("test")
    print('\n\n=============== start testing ===============\n\n')
    data_loader = CelebADataLoader(args.dataroot,
                                   mode=args.phase,
                                   selected_attrs=args.attrs,
                                   batch_size=1,
                                   image_size=args.image_size)
    iter_per_epoch = len(data_loader)
    args.dataset_size = iter_per_epoch
    model = STGANModel(args)

    for _ in tqdm.trange(iter_per_epoch, desc='Test Loop'):
        data = next(data_loader.test_loader)
        model.test(data, data_loader.test_set.get_current_filename())

    print('\n\n=============== finish testing ===============\n\n')


if __name__ == '__main__':
    test()

@@ -0,0 +1,33 @@ export.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Model Export """
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import export
from src.models import STGANModel
from src.utils import get_args

if __name__ == '__main__':
    args = get_args("test")
    context.set_context(mode=context.GRAPH_MODE, device_id=args.device_id)
    model = STGANModel(args)
    model.netG.set_train(True)
    input_shp = [16, 3, 128, 128]
    input_shp_2 = [16, 4]
    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
    input_array_2 = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp_2).astype(np.float32))
    G_file = f"{args.file_name}_model"
    export(model.netG, input_array, input_array_2, file_name=G_file, file_format=args.file_format)

@@ -0,0 +1,2 @@ requirements.txt
numpy
tqdm

@@ -0,0 +1,54 @@ scripts/run_distribute_train.sh
#!/bin/bash

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
    echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [EXPERIMENT_NAME] [DATA_PATH]"
    exit 1
fi

if [ ! -f $1 ]
then
    echo "error: RANK_TABLE_FILE=$1 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
export EXPERIMENT_NAME=$2
export DATA_PATH=$3
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$((rank_start + i))
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp -r ./src ./train_parallel$i
    cp ./train.py ./train_parallel$i
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    cd ./train_parallel$i || exit
    env > env.log
    python train.py --device_num ${DEVICE_NUM} --experiment_name=$EXPERIMENT_NAME --dataroot=$DATA_PATH > log 2>&1 &
    cd ..
done

@@ -0,0 +1,31 @@ scripts/run_eval.sh
#!/bin/bash

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 4 ]
then
    echo "Usage: sh run_eval.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID] [CHECKPOINT_PATH]"
    exit 1
fi

export DATA_PATH=$1
export EXPERIMENT_NAME=$2
export DEVICE_ID=$3
export CHECKPOINT_PATH=$4

python eval.py --dataroot=$DATA_PATH --experiment_name=$EXPERIMENT_NAME \
    --device_id=$DEVICE_ID --ckpt_path=$CHECKPOINT_PATH \
    --platform="Ascend" > eval_log 2>&1 &

@@ -0,0 +1,29 @@ scripts/run_standalone_train.sh
#!/bin/bash

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
    echo "Usage: sh run_standalone_train.sh [DATA_PATH] [EXPERIMENT_NAME] [DEVICE_ID]"
    exit 1
fi

export DATA_PATH=$1
export EXPERIMENT_NAME=$2
export DEVICE_ID=$3

python train.py --dataroot=$DATA_PATH --experiment_name=$EXPERIMENT_NAME \
    --device_id=$DEVICE_ID --platform="Ascend" > log 2>&1 &

@@ -0,0 +1,16 @@ src/dataset/__init__.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" dataset """
from .celeba import CelebADataLoader

@@ -0,0 +1,157 @@ src/dataset/celeba.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" CelebA Dataset """
import os
import multiprocessing
import numpy as np
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank

from PIL import Image
from .distributed_sampler import DistributedSampler
from .datasets import make_dataset


class CelebADataset:
    """ CelebA """
    def __init__(self, dir_path, mode, selected_attrs):
        self.items = make_dataset(dir_path, mode, selected_attrs)
        self.dir_path = dir_path
        self.mode = mode
        self.filename = ''

    def __getitem__(self, index):
        filename, label = self.items[index]
        image = Image.open(os.path.join(self.dir_path, 'image', filename))
        image = np.array(image.convert('RGB'))
        label = np.array(label)
        if self.mode == 'test':
            self.filename = filename
        return image, label

    def __len__(self):
        return len(self.items)

    def get_current_filename(self):
        return self.filename


class CelebADataLoader:
    """ CelebADataLoader """
    def __init__(self,
                 root,
                 mode,
                 selected_attrs,
                 crop_size=None,
                 image_size=128,
                 batch_size=64,
                 device_num=1):
        if mode not in ['train', 'test', 'val']:
            return

        mean = [0.5 * 255] * 3
        std = [0.5 * 255] * 3
        parallel_mode = context.get_auto_parallel_context("parallel_mode")
        rank = 0
        if parallel_mode in [
                ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL,
                ParallelMode.AUTO_PARALLEL
        ]:
            rank = get_rank()
        shuffle = True
        cores = multiprocessing.cpu_count()
        num_parallel_workers = min(16, int(cores / device_num))

        if mode == 'train':
            dataset = CelebADataset(root, mode, selected_attrs)
            distributed_sampler = DistributedSampler(len(dataset),
                                                     device_num,
                                                     rank,
                                                     shuffle=shuffle)
            self.dataset_size = int(len(distributed_sampler) / batch_size)

            val_set = CelebADataset(root, 'val', selected_attrs)
            self.val_dataset_size = len(val_set)

            transform = [
                C.Resize((image_size, image_size)),
                C.Normalize(mean=mean, std=std),
                C.HWC2CHW()
            ]
            if crop_size is not None:
                # crop before resizing: the original code appended the crop
                # after HWC2CHW, where the image is already in CHW layout
                transform.insert(0, C.CenterCrop(crop_size))
            val_distributed_sampler = DistributedSampler(len(val_set),
                                                         device_num,
                                                         rank,
                                                         shuffle=shuffle)
            val_dataset = de.GeneratorDataset(val_set,
                                              column_names=["image", "label"],
                                              sampler=val_distributed_sampler,
                                              num_parallel_workers=min(
                                                  32, num_parallel_workers))
            val_dataset = val_dataset.map(operations=transform,
                                          input_columns=["image"],
                                          num_parallel_workers=min(
                                              32, num_parallel_workers))
            transform.insert(0, C.RandomHorizontalFlip())
            train_dataset = de.GeneratorDataset(
                dataset,
                column_names=["image", "label"],
                sampler=distributed_sampler,
                num_parallel_workers=min(32, num_parallel_workers))
            train_dataset = train_dataset.map(operations=transform,
                                              input_columns=["image"],
                                              num_parallel_workers=min(
                                                  32, num_parallel_workers))
            train_dataset = train_dataset.batch(batch_size,
                                                drop_remainder=True)
            train_dataset = train_dataset.repeat(200)
            val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
            val_dataset = val_dataset.repeat(200)
            self.train_loader = train_dataset.create_dict_iterator()
            self.val_loader = val_dataset.create_dict_iterator()
        else:
            dataset = CelebADataset(root, mode, selected_attrs)
            self.test_set = dataset
            self.dataset_size = int(len(dataset) / batch_size)
            distributed_sampler = DistributedSampler(len(dataset),
                                                     device_num,
                                                     rank,
                                                     shuffle=shuffle)
            test_transform = [
                C.Resize((image_size, image_size)),
                C.Normalize(mean=mean, std=std),
                C.HWC2CHW()
            ]
            test_dataset = de.GeneratorDataset(dataset,
                                               column_names=["image", "label"],
                                               sampler=distributed_sampler,
                                               num_parallel_workers=min(
                                                   1, num_parallel_workers))
            test_dataset = test_dataset.map(operations=test_transform,
                                            input_columns=["image"],
                                            num_parallel_workers=min(
                                                32, num_parallel_workers))
            test_dataset = test_dataset.batch(batch_size, drop_remainder=True)
            test_dataset = test_dataset.repeat(1)

            self.test_loader = test_dataset.create_dict_iterator()

    def __len__(self):
        return self.dataset_size

@@ -0,0 +1,64 @@ src/dataset/datasets.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""datasets"""
import os
import random
import numpy as np

random.seed(1)
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']


def is_image_file(filename):
    """check whether a file is an image"""
    return any(filename.lower().endswith(extension)
               for extension in IMG_EXTENSIONS)


def make_dataset(dir_path, mode, selected_attrs):
    """ make dataset """
    assert mode in ['train', 'val',
                    'test'], "Mode [{}] is not supported".format(mode)
    assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path
    lines = [
        line.rstrip() for line in open(
            os.path.join(dir_path, 'anno', 'list_attr_celeba.txt'), 'r')
    ]
    all_attr_names = lines[1].split()
    attr2idx = {}
    idx2attr = {}
    for i, attr_name in enumerate(all_attr_names):
        attr2idx[attr_name] = i
        idx2attr[i] = attr_name

    lines = lines[2:]
    if mode == 'train':
        lines = lines[:-20000]  # the train set contains 182,599 images
    if mode == 'val':
        lines = lines[-20000:-18800]  # the val set contains 1,200 images
    if mode == 'test':
        lines = lines[-18800:]  # the test set contains 18,800 images

    items = []
    for line in lines:
        split = line.split()
        filename = split[0]
        values = split[1:]
        label = []
        for attr_name in selected_attrs:
            idx = attr2idx[attr_name]
            label.append(np.float32(values[idx] == '1'))
        items.append([filename, label])
    return items

@@ -0,0 +1,70 @@ src/dataset/distributed_sampler.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset distributed sampler."""
from __future__ import division
import math
import numpy as np


class DistributedSampler:
    """Distributed sampler."""
    def __init__(self,
                 dataset_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=True):
        if num_replicas is None:
            print(
                "***********Setting world_size to 1 since it is not passed in ******************"
            )
            num_replicas = 1
        if rank is None:
            print(
                "***********Setting rank to 0 since it is not passed in ******************"
            )
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(
            math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(
                self.dataset_size)
            # indices is an np.array of numbers from 0 to dataset_size - 1,
            # used as indices into the dataset; convert it to a list
            indices = indices.tolist()
            self.epoch += 1
        else:
            indices = list(range(self.dataset_size))

        # add extra samples to make the count evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample this rank's share
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

@@ -0,0 +1,16 @@ src/models/__init__.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" models"""
from .stgan import STGANModel

@@ -0,0 +1,256 @@ src/models/base_model.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Base Model """
import os
import json
from abc import ABC, abstractmethod
import numpy as np
import mindspore.numpy as numpy
import mindspore.common.dtype as mstype

from mindspore import load_checkpoint, load_param_into_net, save_checkpoint
from mindspore.ops import composite as C
from mindspore import Tensor
from PIL import Image

from src.utils import mkdirs


class BaseModel(ABC):
    """ BaseModel """
    def __init__(self, args):
        self.isTrain = args.isTrain
        self.save_dir = os.path.join(args.outputs_dir, args.experiment_name)
        self.model_names = []
        self.loss_names = []
        self.args = args
        self.optimizers = []
        self.current_iteration = 0
        self.netG = None
        self.netD = None

        # continue train
        if self.isTrain:
            if self.args.continue_train:
                assert (os.path.exists(self.save_dir)
                        ), 'Checkpoint path not found at %s' % self.save_dir
                self.current_iteration = self.args.continue_iter
            else:
                if not os.path.exists(self.save_dir):
                    mkdirs(self.save_dir)

        # save config
        self.config_save_path = os.path.join(self.save_dir, 'config')
        if not os.path.exists(self.config_save_path):
            mkdirs(self.config_save_path)
        if self.isTrain:
            with open(os.path.join(self.config_save_path, 'train.conf'),
                      'w') as f:
                f.write(json.dumps(vars(self.args)))
        if self.current_iteration == -1:
            with open(os.path.join(self.config_save_path, 'latest.conf'),
                      'r') as f:
                self.current_iteration = int(f.read())

        # sample save path
        if self.isTrain:
            self.sample_save_path = os.path.join(self.save_dir, 'sample')
            if not os.path.exists(self.sample_save_path):
                mkdirs(self.sample_save_path)

        # test result save path
        if self.args.phase == 'test':
            self.test_result_save_path = os.path.join(self.save_dir, 'test')
            if not os.path.exists(self.test_result_save_path):
                mkdirs(self.test_result_save_path)

        # train log save path
        if self.isTrain:
            self.train_log_path = os.path.join(self.save_dir, 'logs')
            if not os.path.exists(self.train_log_path):
                mkdirs(self.train_log_path)

    @abstractmethod
    def set_input(self, input_data):
        pass

    @abstractmethod
    def optimize_parameters(self):
        pass

    @abstractmethod
    def test(self):
        pass

    @abstractmethod
    def eval(self):
        pass

    def load_config(self):
        print('loading config from {}\n\n'.format(
            os.path.join(self.config_save_path, 'train.conf')))
        with open(os.path.join(self.config_save_path, 'train.conf'), 'r') as f:
            config = json.loads(f.read())
            print('config: ', config)
            return config

    def save_networks(self):
        """ saving networks """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_%s.ckpt' % (self.current_iteration, name)
                save_filename_latest = 'latest_%s.ckpt' % name
                save_path = os.path.join(self.save_dir, 'ckpt')
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                save_path_latest = os.path.join(save_path,
                                                save_filename_latest)
                save_path = os.path.join(save_path, save_filename)
                net = getattr(self, 'net' + name)

                print('saving the model to %s' % save_path)
                print('saving the model to %s' % save_path_latest)
                save_checkpoint(net, save_path)
                save_checkpoint(net, save_path_latest)

        with open(os.path.join(self.config_save_path, 'latest.conf'),
                  'w') as f:
            f.write("{}".format(self.current_iteration))

    def load_networks(self, epoch='latest'):
        """ Load model checkpoint file

        Parameters:
            epoch: epoch saved, default is latest
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_%s.ckpt' % (epoch, name)
                load_path = os.path.join(self.save_dir, 'ckpt', load_filename)
                net = getattr(self, 'net' + name)

                print('loading the model from %s' % load_path)
                if name == 'G':
                    net.encoder.update_parameters_name(
                        'network.network.netG.encoder.')
                    net.decoder.update_parameters_name(
                        'network.network.netG.decoder.')
                    net.stu.update_parameters_name('network.network.netG.stu.')
                params = load_checkpoint(load_path, net, strict_load=True)
                load_param_into_net(net, params, strict_load=True)
        if epoch == 'latest':
            assert os.path.exists(
                os.path.join(self.config_save_path, 'latest.conf')
            ), 'Missing iteration information of latest checkpoint file.'
            with open(os.path.join(self.config_save_path, 'latest.conf'),
                      'r') as f:
                self.current_iteration = f.read()

    def load_generator_from_path(self, path=None):
        """ Load generator checkpoint file from given path

        Parameters:
            path: path of checkpoint file, required
        """
        assert path is not None, 'Path of checkpoint can not be None'
        print('loading the model from %s' % path)
        net = getattr(self, 'netG')
        net.encoder.update_parameters_name('network.network.netG.encoder.')
        net.decoder.update_parameters_name('network.network.netG.decoder.')
        net.stu.update_parameters_name('network.network.netG.stu.')
        params = load_checkpoint(path, net, strict_load=True)
        load_param_into_net(net, params, strict_load=True)

    def get_learning_rate(self):
        """Learning rate generator."""
        lrs = [self.args.lr] * self.args.dataset_size * self.args.init_epoch
        lrs += [self.args.lr * 0.1] * self.args.dataset_size * (
            self.args.n_epochs - self.args.init_epoch)
        return Tensor(np.array(lrs).astype(
            np.float32))[self.current_iteration:]

    def save_image(self, img, img_path):
        """Save a numpy image to the disk

        Parameters:
            img (numpy array / Tensor): image to save.
            img_path (str): the path of the image.
        """
        if isinstance(img, Tensor):
            img = self.decode_image(img)
        elif not isinstance(img, np.ndarray):
            raise ValueError(
                "img should be Tensor or numpy array, but get {}".format(
                    type(img)))
        img_pil = Image.fromarray(img)
        img_pil.save(img_path)

    def decode_image(self, img):
        """Decode a [1, C, H, W] Tensor to image numpy array."""
        mean = 0.5 * 255
        std = 0.5 * 255
        return (img.asnumpy() * std + mean).astype(np.uint8).transpose(
            (1, 2, 0))

    def create_labels(self, c_org, selected_attrs=None):
        """Generate target domain labels for debugging and testing."""
        hair_color_indices = []
        for i, attr_name in enumerate(selected_attrs):
            if attr_name in [
                    'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'
            ]:
                hair_color_indices.append(i)

        c_trg_list = []
        for i in range(len(selected_attrs)):
            c_trg = numpy.copy(c_org)
            if i in hair_color_indices:
                c_trg[:, i] = Tensor(1, mstype.float32)
                for j in hair_color_indices:
                    if j != i:
                        c_trg[:, j] = Tensor(0, mstype.float32)
            else:
                c_trg[:, i] = Tensor((c_trg[:, i] == 0).asnumpy(),
                                     mstype.float32)

            c_trg_list.append(c_trg)
        return c_trg_list

    def create_test_label(self, c_org, selected_attrs=None):
        """Generate target domain labels for debugging and testing."""
        hair_color_indices = []
        for i, attr_name in enumerate(selected_attrs):
            if attr_name in [
                    'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'
            ]:
                hair_color_indices.append(i)

        c_trg = numpy.copy(c_org)
        for i in range(len(selected_attrs)):
            if i in hair_color_indices:
                c_trg[:, i] = 1.
                for j in hair_color_indices:
                    if j != i:
                        c_trg[:, j] = Tensor(0, mstype.float32)
            else:
                c_trg[:, i] = Tensor((c_trg[:, i] == 0).asnumpy(),
                                     mstype.float32)

        return c_trg

    def denorm(self, x):
        """ Denormalization """
        out = (x + 1) / 2
        return C.clip_by_value(out, 0, 1)

@@ -0,0 +1,143 @@ src/models/losses.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" losses """
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.ops.operations as P

from mindspore import Tensor

from src.utils.tools import is_version_satisfied_1_2_0


class GeneratorLoss(nn.Cell):
    """ GeneratorLoss """
    def __init__(self, netG, netD, args, mode):
        super(GeneratorLoss, self).__init__()
        self.args = args
        self.platform = args.platform
        self.use_stu = args.use_stu
        self.mode = mode
        self.netG = netG
        self.netD = netD
        if self.platform == 'Ascend':
            self.cls_loss = nn.BCEWithLogitsLoss(reduction='sum')
        else:
            self.cls_loss = ClassificationLoss()
        self.cyc_loss = P.ReduceMean()

        self.lambda2 = Tensor(args.lambda2)
        self.lambda3 = Tensor(args.lambda3)
        self.attr_mode = args.attr_mode

    def construct(self, real_x, c_org, c_trg, attr_diff):
        """ construct """
        fake_x = self.netG(real_x, attr_diff)
        rec_x = self.netG(real_x, c_org - c_org)
        loss_fake_G, loss_att_fake = self.netD(fake_x)

        loss_adv_G = -loss_fake_G.mean()
        loss_cls_G = self.cls_loss(loss_att_fake,
                                   c_trg) / loss_att_fake.shape[0]
        loss_rec_G = (real_x - rec_x).abs().mean()

        loss_G = loss_adv_G + self.lambda2 * loss_cls_G + self.lambda3 * loss_rec_G

        return (fake_x, loss_G, loss_fake_G.mean(), loss_cls_G, loss_rec_G,
                loss_adv_G)


class DiscriminatorLoss(nn.Cell):
    """ DiscriminatorLoss """
    def __init__(self, netD, netG, args, mode):
        super(DiscriminatorLoss, self).__init__()
        self.mode = mode
        self.netD = netD
        self.netG = netG
        self.platform = args.platform
        self.gradient_penalty = WGANGPGradientPenalty(netD)
        if self.platform == 'Ascend' and is_version_satisfied_1_2_0(
                args.ms_version):
            self.cls_loss = nn.BCEWithLogitsLoss(reduction='sum')
        else:
            self.cls_loss = ClassificationLoss()
        self.cyc_loss = P.ReduceMean()

        self.lambda_gp = Tensor(args.lambda_gp)
        self.lambda1 = Tensor(args.lambda1)
        self.thres_int = Tensor(args.thres_int, ms.float32)
        self.attr_mode = args.attr_mode

    def construct(self, real_x, c_org, c_trg, attr_diff, alpha):
        """ construct """
        loss_real_D, loss_att_real = self.netD(real_x)
        loss_real_D = -loss_real_D.mean()
        loss_cls_D = self.cls_loss(loss_att_real,
                                   c_org) / loss_att_real.shape[0]

        fake_x = self.netG(real_x, attr_diff)
        loss_fake_D, _ = self.netD(ops.functional.stop_gradient(fake_x))
        loss_fake_D = loss_fake_D.mean()

        x_hat = (alpha * real_x + (1 - alpha) * fake_x)
        loss_gp_D = self.gradient_penalty(x_hat)
        loss_adv_D = loss_real_D + loss_fake_D + self.lambda_gp * loss_gp_D
        loss_D = self.lambda1 * loss_cls_D + loss_adv_D

        return (loss_D, loss_real_D, loss_fake_D, loss_cls_D, loss_gp_D,
                loss_adv_D, attr_diff)


class WGANGPGradientPenalty(nn.Cell):
    """ WGAN-GP gradient penalty: pushes the norm of the discriminator's
    gradient at interpolated samples towards 1. """
    def __init__(self, discriminator):
        super(WGANGPGradientPenalty, self).__init__()
        self.gradient_op = ops.GradOperation()
        self.reduce_sum = ops.ReduceSum()
        self.reduce_sum_keep_dim = ops.ReduceSum(keep_dims=True)
        self.sqrt = ops.Sqrt()
        self.discriminator = discriminator
        self.gradientWithInput = GradientWithInput(discriminator)

    def construct(self, x_hat):
        gradient = self.gradient_op(self.gradientWithInput)(x_hat)
        gradient_1 = ops.reshape(gradient, (x_hat.shape[0], -1))
        gradient_1 = self.sqrt(self.reduce_sum(gradient_1**2, 1))
        gradient_penalty = ((gradient_1 - 1.0)**2).mean()
        return gradient_penalty


class GradientWithInput(nn.Cell):
    """ Sums the discriminator output so its gradient with respect to the
    input can be taken with GradOperation. """
    def __init__(self, discriminator):
        super(GradientWithInput, self).__init__()
        self.reduce_sum = ops.ReduceSum()
        self.discriminator = discriminator

    def construct(self, interpolates):
        decision_interpolate, _ = self.discriminator(interpolates)
        decision_interpolate = self.reduce_sum(decision_interpolate, 0)
        return decision_interpolate


class ClassificationLoss(nn.Cell):
    """ Sigmoid followed by BCE, used for attribute classification. """
    def __init__(self):
        super(ClassificationLoss, self).__init__()
        self.BCELoss = nn.BCELoss(reduction='sum')
        self.sigmoid = nn.Sigmoid()

    def construct(self, logit, target):
        logit = self.sigmoid(logit)
        return self.BCELoss(logit, target)
@ -0,0 +1,455 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" networks """
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.ops.functional as F
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.ops import constexpr
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.common import initializer as init
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
|
||||
|
||||
def init_weights(net, init_type='normal', init_gain=0.02):
|
||||
"""
|
||||
Initialize network weights.
|
||||
|
||||
Parameters:
|
||||
net (Cell): Network to be initialized
|
||||
init_type (str): The name of an initialization method: normal | xavier.
|
||||
init_gain (float): Gain factor for normal and xavier.
|
||||
|
||||
"""
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
|
||||
if init_type == 'normal':
|
||||
cell.weight.set_data(
|
||||
init.initializer(init.Normal(init_gain),
|
||||
cell.weight.shape))
|
||||
elif init_type == 'xavier':
|
||||
cell.weight.set_data(
|
||||
init.initializer(init.XavierUniform(init_gain),
|
||||
cell.weight.shape))
|
||||
elif init_type == 'KaimingUniform':
|
||||
cell.weight.set_data(
|
||||
init.initializer(init.HeUniform(init_gain),
|
||||
cell.weight.shape))
|
||||
elif init_type == 'constant':
|
||||
cell.weight.set_data(init.initializer(0.001,
|
||||
cell.weight.shape))
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'initialization method [%s] is not implemented' %
|
||||
init_type)
|
||||
elif isinstance(cell, _GroupNorm):
|
||||
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
|
||||
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
|
||||
|
||||
|
||||
class ConvGRUCell(nn.Cell):
|
||||
""" Convolutional GRU Cell """
|
||||
def __init__(self, n_attrs, in_dim, out_dim, kernel_size=3, stu_norm='bn'):
|
||||
super(ConvGRUCell, self).__init__()
|
||||
self.concat = ops.Concat(axis=1)
|
||||
self.reshape = ops.Reshape()
|
||||
self.n_attrs = n_attrs
|
||||
|
||||
self.normalization = nn.BatchNorm2d(out_dim)
|
||||
if stu_norm == 'in':
|
||||
self.normalization = _GroupNorm(num_groups=out_dim,
|
||||
num_channels=out_dim)
|
||||
self.upsample = nn.Conv2dTranspose(in_dim * 2 + n_attrs,
|
||||
out_dim,
|
||||
4,
|
||||
2,
|
||||
padding=1,
|
||||
pad_mode='pad')
|
||||
self.reset_gate = nn.SequentialCell(
|
||||
nn.Conv2d(in_dim + out_dim,
|
||||
out_dim,
|
||||
kernel_size,
|
||||
1,
|
||||
padding=((kernel_size - 1) // 2),
|
||||
pad_mode='pad'), self.normalization, nn.Sigmoid())
|
||||
self.update_gate = nn.SequentialCell(
|
||||
nn.Conv2d(in_dim + out_dim,
|
||||
out_dim,
|
||||
kernel_size,
|
||||
1,
|
||||
padding=((kernel_size - 1) // 2),
|
||||
pad_mode='pad'), self.normalization, nn.Sigmoid())
|
||||
self.hidden = nn.SequentialCell(
|
||||
nn.Conv2d(in_dim + out_dim,
|
||||
out_dim,
|
||||
kernel_size,
|
||||
1,
|
||||
padding=((kernel_size - 1) // 2),
|
||||
pad_mode='pad'), self.normalization, nn.Tanh())
|
||||
|
||||
def construct(self, input_data, old_state, attr):
|
||||
""" construct """
|
||||
n, _, h, w = old_state.shape
|
||||
attr = self.reshape(attr, (n, self.n_attrs, 1, 1))
|
||||
tile = ops.Tile()
|
||||
attr = tile(attr, (1, 1, h, w))
|
||||
state_hat = self.upsample(self.concat((old_state, attr)))
|
||||
r = self.reset_gate(self.concat((input_data, state_hat)))
|
||||
z = self.update_gate(self.concat((input_data, state_hat)))
|
||||
new_state = r * state_hat
|
||||
hidden_info = self.hidden(self.concat((input_data, new_state)))
|
||||
output = (1 - z) * state_hat + z * hidden_info
|
||||
return output, new_state
|
||||
|
||||
|
||||
class Generator(nn.Cell):
|
||||
""" Generator """
|
||||
def __init__(self,
|
||||
attr_dim,
|
||||
enc_dim=64,
|
||||
dec_dim=64,
|
||||
enc_layers=5,
|
||||
dec_layers=5,
|
||||
shortcut_layers=2,
|
||||
stu_kernel_size=3,
|
||||
use_stu=True,
|
||||
one_more_conv=True,
|
||||
stu_norm='bn'):
|
||||
super(Generator, self).__init__()
|
||||
|
||||
self.n_attrs = attr_dim
|
||||
self.enc_layers = enc_layers
|
||||
self.dec_layers = dec_layers
|
||||
self.shortcut_layers = min(shortcut_layers, enc_layers - 1,
|
||||
dec_layers - 1)
|
||||
self.use_stu = use_stu
|
||||
self.concat = ops.Concat(axis=1)
|
||||
|
||||
self.encoder = nn.CellList()
|
||||
in_channels = 3
|
||||
for i in range(self.enc_layers):
|
||||
self.encoder.append(
|
||||
nn.SequentialCell(
|
||||
nn.Conv2d(in_channels,
|
||||
enc_dim * 2**i,
|
||||
4,
|
||||
2,
|
||||
padding=1,
|
||||
pad_mode='pad'), nn.BatchNorm2d(enc_dim * 2**i),
|
||||
nn.LeakyReLU(alpha=0.2)))
|
||||
in_channels = enc_dim * 2**i
|
||||
|
||||
# Selective Transfer Unit (STU)
|
||||
if self.use_stu:
|
||||
self.stu = nn.CellList()
|
||||
for i in reversed(
|
||||
range(self.enc_layers - 1 - self.shortcut_layers,
|
||||
self.enc_layers - 1)):
|
||||
self.stu.append(
|
||||
ConvGRUCell(self.n_attrs, enc_dim * 2**i, enc_dim * 2**i,
|
||||
stu_kernel_size, stu_norm))
|
||||
|
||||
self.decoder = nn.CellList()
|
||||
for i in range(self.dec_layers):
|
||||
if i < self.dec_layers - 1:
|
||||
if i == 0:
|
||||
self.decoder.append(
|
||||
nn.SequentialCell(
|
||||
nn.Conv2dTranspose(
|
||||
dec_dim * 2**(self.dec_layers - 1) + attr_dim,
|
||||
dec_dim * 2**(self.dec_layers - 1),
|
||||
4,
|
||||
2,
|
||||
padding=1,
|
||||
pad_mode='pad'), nn.BatchNorm2d(in_channels),
|
||||
nn.ReLU()))
|
||||
elif i <= self.shortcut_layers:
|
||||
self.decoder.append(
|
||||
nn.SequentialCell(
|
||||
nn.Conv2dTranspose(
|
||||
dec_dim * 3 * 2**(self.dec_layers - 1 - i),
|
||||
dec_dim * 2**(self.dec_layers - 1 - i),
|
||||
4,
|
||||
2,
|
||||
padding=1,
|
||||
pad_mode='pad'),
|
||||
nn.BatchNorm2d(dec_dim *
|
||||
2**(self.dec_layers - 1 - i)),
|
||||
nn.ReLU()))
|
||||
else:
|
||||
self.decoder.append(
|
||||
nn.SequentialCell(
|
||||
nn.Conv2dTranspose(
|
||||
dec_dim * 2**(self.dec_layers - i),
|
||||
dec_dim * 2**(self.dec_layers - 1 - i),
|
||||
4,
|
||||
2,
|
||||
padding=1,
|
||||
pad_mode='pad'),
|
||||
nn.BatchNorm2d(dec_dim *
|
||||
2**(self.dec_layers - 1 - i)),
|
||||
nn.ReLU()))
|
||||
else:
|
||||
in_dim = dec_dim * 3 if self.shortcut_layers == self.dec_layers - 1 else dec_dim * 2
|
||||
if one_more_conv:
|
||||
self.decoder.append(
|
||||
nn.SequentialCell(
|
||||
nn.Conv2dTranspose(in_dim,
|
||||
dec_dim // 4,
|
||||
4,
|
||||
2,
|
||||
padding=1,
|
||||
pad_mode='pad'),
|
||||
nn.BatchNorm2d(dec_dim // 4), nn.ReLU(),
|
||||
nn.Conv2dTranspose(dec_dim // 4,
|
||||
3,
|
||||
3,
|
||||
1,
|
||||
padding=1,
|
||||
pad_mode='pad'), nn.Tanh()))
|
||||
else:
|
||||
self.decoder.append(
|
||||
nn.SequentialCell(
|
||||
nn.Conv2dTranspose(in_dim,
|
||||
3,
|
||||
4,
|
||||
2,
|
||||
padding=1,
|
||||
pad_mode='pad'), nn.Tanh()))
|
||||
|
||||
def construct(self, x, a):
|
||||
""" construct """
|
||||
# propagate encoder layers
|
||||
y = []
|
||||
x_ = x
|
||||
for layer in self.encoder:
|
||||
x_ = layer(x_)
|
||||
y.append(x_)
|
||||
|
||||
out = y[-1]
|
||||
reshape = ops.Reshape()
|
||||
(n, _, h, w) = out.shape
|
||||
attr = reshape(a, (n, self.n_attrs, 1, 1))
|
||||
tile = ops.Tile()
|
||||
attr = tile(attr, (1, 1, h, w))
|
||||
out = self.decoder[0](self.concat((out, attr)))
|
||||
stu_state = y[-1]
|
||||
|
||||
# propagate shortcut layers
|
||||
for i in range(1, self.shortcut_layers + 1):
|
||||
if self.use_stu:
|
||||
stu_out, stu_state = self.stu[i - 1](y[-(i + 1)], stu_state, a)
|
||||
out = self.concat((out, stu_out))
|
||||
out = self.decoder[i](out)
|
||||
else:
|
||||
out = self.concat((out, y[-(i + 1)]))
|
||||
out = self.decoder[i](out)
|
||||
|
||||
# propagate non-shortcut layers
|
||||
for i in range(self.shortcut_layers + 1, self.dec_layers):
|
||||
out = self.decoder[i](out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Discriminator(nn.Cell):
|
||||
""" Discriminator Cell """
|
||||
def __init__(self,
|
||||
image_size=128,
|
||||
attr_dim=10,
|
||||
conv_dim=64,
|
||||
fc_dim=1024,
|
||||
n_layers=5):
|
||||
super(Discriminator, self).__init__()
|
||||
|
||||
layers = []
|
||||
in_channels = 3
|
||||
for i in range(n_layers):
|
||||
layers.append(
|
||||
nn.SequentialCell(
|
||||
nn.Conv2d(in_channels,
|
||||
conv_dim * 2**i,
|
||||
4,
|
||||
2,
|
||||
padding=1,
|
||||
pad_mode='pad'),
|
||||
_GroupNorm(num_groups=conv_dim * 2**i,
|
||||
num_channels=conv_dim * 2**i),
|
||||
nn.LeakyReLU(alpha=0.2)))
|
||||
in_channels = conv_dim * 2**i
|
||||
self.conv = nn.SequentialCell(*layers)
|
||||
feature_size = image_size // 2**n_layers
|
||||
self.fc_adv = nn.SequentialCell(
|
||||
nn.Flatten(),
|
||||
nn.Dense(conv_dim * 2**(n_layers - 1) * feature_size**2, fc_dim),
|
||||
nn.LeakyReLU(alpha=0.2), nn.Flatten(), nn.Dense(fc_dim, 1))
|
||||
self.fc_att = nn.SequentialCell(
|
||||
nn.Flatten(),
|
||||
nn.Dense(conv_dim * 2**(n_layers - 1) * feature_size**2, fc_dim),
|
||||
nn.LeakyReLU(alpha=0.2),
|
||||
nn.Flatten(),
|
||||
nn.Dense(fc_dim, attr_dim),
|
||||
)
|
||||
|
||||
def construct(self, x):
|
||||
y = self.conv(x)
|
||||
reshape = ops.Reshape()
|
||||
y = reshape(y, (y.shape[0], -1))
|
||||
logit_adv = self.fc_adv(y)
|
||||
logit_att = self.fc_att(y)
|
||||
return logit_adv, logit_att
|
||||
|
||||
|
||||
class _GroupNorm(nn.GroupNorm):
|
||||
""" Rewrite of original GroupNorm """
|
||||
def __init__(self,
|
||||
num_groups,
|
||||
num_channels,
|
||||
eps=1e-05,
|
||||
affine=True,
|
||||
gamma_init='ones',
|
||||
beta_init='zeros'):
|
||||
super().__init__(num_groups,
|
||||
num_channels,
|
||||
eps=1e-05,
|
||||
affine=True,
|
||||
gamma_init='ones',
|
||||
beta_init='zeros')
|
||||
self.pow = ops.Pow()
|
||||
|
||||
def _cal_output(self, x):
|
||||
"""calculate groupnorm output"""
|
||||
batch, channel, height, width = self.shape(x)
|
||||
_channel_check(channel, self.num_channels)
|
||||
x = self.reshape(x, (batch, self.num_groups, -1))
|
||||
mean = self.reduce_mean(x, 2)
|
||||
var = self.reduce_sum(self.square(x - mean),
|
||||
2) / (channel * height * width / self.num_groups)
|
||||
std = self.pow((var + self.eps), 0.5)
|
||||
x = (x - mean) / std
|
||||
x = self.reshape(x, (batch, channel, height, width))
|
||||
output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(
|
||||
self.beta, (-1, 1, 1))
|
||||
return output
|
||||
|
||||
|
||||
@constexpr
|
||||
def _channel_check(channel, num_channel):
|
||||
if channel != num_channel:
|
||||
raise ValueError("the input channel is not equal with num_channel")
|
||||
|
||||
|
||||
class GeneratorWithLossCell(nn.Cell):
|
||||
""" GeneratorWithLossCell """
|
||||
def __init__(self, network, args):
|
||||
super(GeneratorWithLossCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.lambda2 = args.lambda2
|
||||
self.lambda3 = args.lambda3
|
||||
|
||||
def construct(self, x_real, c_org, c_trg, attr_diff):
|
||||
_, _, _, loss_cls_G, loss_rec_G, loss_adv_G = self.network(
|
||||
x_real, c_org, c_trg, attr_diff)
|
||||
return loss_adv_G + self.lambda2 * loss_cls_G + self.lambda3 * loss_rec_G
|
||||
|
||||
|
||||
class DiscriminatorWithLossCell(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(DiscriminatorWithLossCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
|
||||
def construct(self, x_real, c_org, c_trg, attr_diff, alpha):
|
||||
loss_D, _, _, _, _, _, _ = self.network(x_real, c_org, c_trg,
|
||||
attr_diff, alpha)
|
||||
return loss_D
|
||||
|
||||
|
||||
class TrainOneStepGenerator(nn.Cell):
    """ Training class of Generator """
    def __init__(self, loss_G_model, optimizer, args):
        super(TrainOneStepGenerator, self).__init__()
        self.optimizer = optimizer
        self.loss_G_model = loss_G_model
        self.loss_G_model.set_grad()
        self.loss_G_model.set_train()
        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
        self.weights = optimizer.parameters
        self.network = GeneratorWithLossCell(loss_G_model, args)
        self.network.add_flags(defer_inline=True)
        self.grad_reducer = F.identity
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [
                ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL,
                ParallelMode.AUTO_PARALLEL
        ]:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(
                self.weights, mean, degree)

    def construct(self, real_x, c_org, c_trg, attr_diff, sens=1.0):
        fake_x, loss_G, loss_fake_G, loss_cls_G, loss_rec_G, loss_adv_G =\
            self.loss_G_model(real_x, c_org, c_trg, attr_diff)
        sens = P.Fill()(P.DType()(loss_G), P.Shape()(loss_G), sens)
        grads = self.grad(self.network, self.weights)(real_x, c_org, c_trg,
                                                      attr_diff, sens)
        grads = self.grad_reducer(grads)
        return (ops.depend(loss_G, self.optimizer(grads)), fake_x, loss_G,
                loss_fake_G, loss_cls_G, loss_rec_G, loss_adv_G)

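
# Note on the pattern above: ops.GradOperation(..., sens_param=True) lets the
# caller feed an explicit "sensitivity" (the initial gradient of the output),
# so P.Fill()(P.DType()(loss), P.Shape()(loss), sens) builds a tensor of
# `sens` values matching the loss, and ops.depend ties the returned loss to
# the optimizer update so the update cannot be pruned from the graph. Under
# data/hybrid/auto parallel modes, DistributedGradReducer all-reduces the
# gradients across devices before they reach the optimizer.
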
class TrainOneStepDiscriminator(nn.Cell):
    """ Training class of Discriminator """
    def __init__(self, loss_D_model, optimizer):
        super(TrainOneStepDiscriminator, self).__init__()
        self.optimizer = optimizer
        self.loss_D_model = loss_D_model
        self.loss_D_model.set_grad()
        self.loss_D_model.set_train()
        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
        self.weights = optimizer.parameters
        self.network = DiscriminatorWithLossCell(loss_D_model)
        self.network.add_flags(defer_inline=True)
        self.grad_reducer = F.identity
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [
                ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL,
                ParallelMode.AUTO_PARALLEL
        ]:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(
                self.weights, mean, degree)

    def construct(self, real_x, c_org, c_trg, attr_diff, alpha, sens=1.0):
        loss_D, loss_real_D, loss_fake_D, loss_cls_D, loss_gp_D, loss_adv_D, attr_diff =\
            self.loss_D_model(real_x, c_org, c_trg, attr_diff, alpha)
        sens = P.Fill()(P.DType()(loss_D), P.Shape()(loss_D), sens)
        grads = self.grad(self.network, self.weights)(real_x, c_org, c_trg,
                                                      attr_diff, alpha, sens)
        grads = self.grad_reducer(grads)
        return (ops.depend(loss_D, self.optimizer(grads)), loss_D, loss_real_D,
                loss_fake_D, loss_cls_D, loss_gp_D, loss_adv_D, attr_diff)

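
# Hedged wiring sketch (this mirrors how STGANModel assembles these cells in
# the model file below; `netG`, `netD` and `args` are assumed to exist, and
# the plain `args.lr` learning rate is a simplification):
#
#     loss_G_model = GeneratorLoss(netG, netD, args, args.mode)
#     loss_D_model = DiscriminatorLoss(netD, netG, args, args.mode)
#     optimizer_G = nn.Adam(netG.trainable_params(), args.lr,
#                           beta1=args.beta1, beta2=args.beta2)
#     optimizer_D = nn.Adam(netD.trainable_params(), args.lr,
#                           beta1=args.beta1, beta2=args.beta2)
#     train_G = TrainOneStepGenerator(loss_G_model, optimizer_G, args)
#     train_D = TrainOneStepDiscriminator(loss_D_model, optimizer_D)
#     # per iteration: update D, and update G every args.n_critic iterations
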
@ -0,0 +1,189 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" STGAN Models """
import os
import math
import random
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.dtype as mstype
from mindspore import Tensor
import numpy as np

from .networks import init_weights, Discriminator, Generator, TrainOneStepGenerator, TrainOneStepDiscriminator
from .losses import GeneratorLoss, DiscriminatorLoss
from .base_model import BaseModel

class STGANModel(BaseModel):
    """ STGANModel """
    def __init__(self, args):
        """Build networks, losses and optimizers from the parsed arguments."""
        BaseModel.__init__(self, args)
        self.rand_int = ops.UniformInt()
        self.concat = ops.operations.Concat(axis=1)
        self.use_stu = args.use_stu
        self.args = args
        self.n_attrs = len(args.attrs)
        self.mode = args.mode
        self.loss_names = [
            'D', 'G', 'adv_D', 'cls_D', 'real_D', 'fake_D', 'gp_D', 'adv_G',
            'cls_G', 'rec_G'
        ]

        if self.isTrain:
            self.model_names = ['G', 'D']
        else:
            self.model_names = ['G']
            self.netD = None

        self.netG = Generator(self.n_attrs,
                              args.enc_dim,
                              args.dec_dim,
                              args.enc_layers,
                              args.dec_layers,
                              args.shortcut_layers,
                              args.stu_kernel_size,
                              use_stu=args.use_stu,
                              one_more_conv=args.one_more_conv,
                              stu_norm=args.stu_norm)
        print('Generator: ', self.netG)
        if self.isTrain:
            self.netD = Discriminator(args.image_size, self.n_attrs,
                                      args.dis_dim, args.dis_fc_dim,
                                      args.dis_layers)
            print('Discriminator: ', self.netD)

        if self.args.continue_train:
            continue_iter = self.args.continue_iter if self.args.continue_iter != -1 else 'latest'
            self.load_networks(continue_iter)

        if self.netG is not None:
            num_params = 0
            for p in self.netG.trainable_params():
                num_params += np.prod(p.shape)
            print(
                '\n\n\nGenerator trainable parameters: {}'.format(num_params))
        if self.netD is not None:
            num_params = 0
            for p in self.netD.trainable_params():
                num_params += np.prod(p.shape)
            print('Discriminator trainable parameters: {}\n\n\n'.format(
                num_params))

        if self.isTrain:
            if not self.args.continue_train:
                init_weights(self.netG, 'KaimingUniform', math.sqrt(5))
                init_weights(self.netD, 'KaimingUniform', math.sqrt(5))
            self.loss_D_model = DiscriminatorLoss(self.netD, self.netG,
                                                  self.args, self.mode)
            self.loss_G_model = GeneratorLoss(self.netG, self.netD, self.args,
                                              self.mode)
            self.optimizer_G = nn.Adam(self.netG.trainable_params(),
                                       self.get_learning_rate(),
                                       beta1=args.beta1,
                                       beta2=args.beta2)
            self.optimizer_D = nn.Adam(self.netD.trainable_params(),
                                       self.get_learning_rate(),
                                       beta1=args.beta1,
                                       beta2=args.beta2)
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            self.train_G = TrainOneStepGenerator(self.loss_G_model,
                                                 self.optimizer_G, self.args)
            self.train_D = TrainOneStepDiscriminator(self.loss_D_model,
                                                     self.optimizer_D)

            self.train_G.set_train()
            self.train_D.set_train()

        if self.args.phase == 'test':
            if self.args.ckpt_path is not None:
                self.load_generator_from_path(self.args.ckpt_path)
            else:
                self.load_networks(
                    'latest' if args.test_iter == -1 else args.test_iter)

    def set_input(self, input_data):
        """Unpack a batch and draw target labels by shuffling within the batch."""
        self.label_org = Tensor(input_data['label'], mstype.float32)
        (shape_0, _) = self.label_org.shape
        rand_idx = random.sample(range(0, shape_0), shape_0)
        self.label_trg = self.label_org[rand_idx]
        self.real_x = Tensor(input_data['image'], mstype.float32)

    def test(self, data, filename=''):
        """ Test Function """
        data_image = data['image']
        data_label = data['label']
        c_trg = self.create_test_label(data_label, self.args.attrs)
        attr_diff = c_trg - data_label if self.args.attr_mode == 'diff' else c_trg
        attr_diff = attr_diff * self.args.thres_int
        fake_x = self.netG(data_image, attr_diff)
        self.save_image(
            fake_x[0],
            os.path.join(self.test_result_save_path, '{}'.format(filename)))

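
    # Illustration of the 'diff' attribute mode used above (values are an
    # example, using the default attrs ['Bangs', 'Blond_Hair', 'Mustache',
    # 'Young'] and thres_int = 0.5):
    #     original label   [0, 1, 0, 1]
    #     target label     [1, 1, 0, 0]
    #     attr_diff        [1, 0, 0, -1] * 0.5 = [0.5, 0.0, 0.0, -0.5]
    # so the generator receives signed, scaled attribute changes rather than
    # absolute target labels.
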
    def optimize_parameters(self):
        """ Optimize the model's trainable parameters for one iteration """
        attr_diff = self.label_trg - self.label_org if self.args.attr_mode == 'diff' else self.label_trg
        (h, w) = attr_diff.shape
        rand_attr = Tensor(np.random.rand(h, w), mstype.float32)
        attr_diff = attr_diff * rand_attr * (2 * self.args.thres_int)
        # random factor consumed by the gradient-penalty term
        alpha = Tensor(np.random.randn(self.real_x.shape[0], 1, 1, 1),
                       mstype.float32)
        # train D
        _, loss_D, loss_real_D, loss_fake_D, loss_cls_D, loss_gp_D, loss_adv_D, attr_diff =\
            self.train_D(self.real_x, self.label_org, self.label_trg, attr_diff, alpha)
        if self.current_iteration % self.args.n_critic == 0:
            # train G
            _, _, loss_G, loss_fake_G, loss_cls_G, loss_rec_G, loss_adv_G =\
                self.train_G(self.real_x, self.label_org, self.label_trg, attr_diff)
            # saving losses
            if (self.current_iteration // self.args.n_critic) % self.args.print_freq == 0:
                with open(os.path.join(self.train_log_path, 'loss.log'),
                          'a+') as f:
                    f.write('Iter: %s\n' % self.current_iteration)
                    f.write(
                        'loss D: %s, loss D_real: %s, loss D_fake: %s, loss D_gp: %s, loss D_adv: %s, loss D_cls: %s \n'
                        % (loss_D, loss_real_D, loss_fake_D, loss_gp_D,
                           loss_adv_D, loss_cls_D))
                    f.write(
                        'loss G: %s, loss G_rec: %s, loss G_fake: %s, loss G_adv: %s, loss G_cls: %s \n\n'
                        % (loss_G, loss_rec_G, loss_fake_G, loss_adv_G,
                           loss_cls_G))

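
    # Training schedule: the discriminator is updated on every call, while the
    # generator (and the loss log) is updated once every args.n_critic calls
    # (5 by default), the usual WGAN critic/generator ratio.
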
    def eval(self, data_loader):
        """ Eval function of STGAN """
        val_loader = data_loader.val_loader
        concat_3d = ops.Concat(axis=3)
        concat_1d = ops.Concat(axis=1)
        data = next(val_loader)
        data_image = data['image']
        data_label = data['label']
        sample_list = self.create_labels(data_label, self.args.attrs)
        sample_list.insert(0, data_label)
        x_concat = data_image
        for c_trg_sample in sample_list:
            attr_diff = c_trg_sample - data_label if self.args.attr_mode == 'diff' else c_trg_sample
            attr_diff = attr_diff * self.args.thres_int
            fake_x = self.netG(data_image, attr_diff)
            x_concat = concat_3d((x_concat, fake_x))
        sample_result = x_concat[0]
        for i in range(1, x_concat.shape[0]):
            sample_result = concat_1d((sample_result, x_concat[i]))
        self.save_image(
            sample_result,
            os.path.join(self.sample_save_path,
                         'samples_{}.jpg'.format(self.current_iteration + 1)))

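
    # Layout of the saved sample sheet: images are concatenated along the
    # width axis (axis=3) per target label (the input first, then the
    # reconstruction in 'diff' mode, then one column per edited attribute),
    # and the batch samples are then stacked along the height axis (axis=1),
    # giving one row per validation image.
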
@ -0,0 +1,17 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" utils """
from .args import get_args
from .tools import mkdirs
@ -0,0 +1,296 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""arguments"""
import os
import argparse
import ast
import datetime
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.communication.management import init

def add_basic_parameters(parser):
    """ add basic parameters """
    parser.add_argument("--platform",
                        type=str,
                        default="Ascend",
                        choices=("Ascend", "GPU", "CPU"),
                        help="running platform, support Ascend, GPU and CPU")
    parser.add_argument("--device_id",
                        type=int,
                        default=0,
                        help="device id, default is 0")
    parser.add_argument('--device_num',
                        type=int,
                        default=1,
                        help='device num, default is 1.')
    parser.add_argument('--ms_version',
                        type=str,
                        default='1.2.0',
                        help="MindSpore version, default is 1.2.0")
    return parser

def add_model_parameters(parser):
    """ add model parameters """
    att_dict = {
        '5_o_Clock_Shadow': 0,
        'Arched_Eyebrows': 1,
        'Attractive': 2,
        'Bags_Under_Eyes': 3,
        'Bald': 4,
        'Bangs': 5,
        'Big_Lips': 6,
        'Big_Nose': 7,
        'Black_Hair': 8,
        'Blond_Hair': 9,
        'Blurry': 10,
        'Brown_Hair': 11,
        'Bushy_Eyebrows': 12,
        'Chubby': 13,
        'Double_Chin': 14,
        'Eyeglasses': 15,
        'Goatee': 16,
        'Gray_Hair': 17,
        'Heavy_Makeup': 18,
        'High_Cheekbones': 19,
        'Male': 20,
        'Mouth_Slightly_Open': 21,
        'Mustache': 22,
        'Narrow_Eyes': 23,
        'No_Beard': 24,
        'Oval_Face': 25,
        'Pale_Skin': 26,
        'Pointy_Nose': 27,
        'Receding_Hairline': 28,
        'Rosy_Cheeks': 29,
        'Sideburns': 30,
        'Smiling': 31,
        'Straight_Hair': 32,
        'Wavy_Hair': 33,
        'Wearing_Earrings': 34,
        'Wearing_Hat': 35,
        'Wearing_Lipstick': 36,
        'Wearing_Necklace': 37,
        'Wearing_Necktie': 38,
        'Young': 39
    }
    attr_default = ['Bangs', 'Blond_Hair', 'Mustache', 'Young']
    parser.add_argument("--attrs",
                        default=attr_default,
                        choices=att_dict,
                        nargs='+',
                        help='Attributes to modify by the model')
    parser.add_argument('--image_size',
                        type=int,
                        default=128,
                        help='input image size')
    parser.add_argument(
        '--shortcut_layers',
        type=int,
        default=3,
        help='# of skip connections between the encoder and the decoder')
    parser.add_argument('--enc_dim', type=int, default=64)
    parser.add_argument('--dec_dim', type=int, default=64)
    parser.add_argument('--dis_dim', type=int, default=64)
    parser.add_argument('--dis_fc_dim',
                        type=int,
                        default=1024,
                        help='# of discriminator fc channels')
    parser.add_argument('--enc_layers', type=int, default=5)
    parser.add_argument('--dec_layers', type=int, default=5)
    parser.add_argument('--dis_layers', type=int, default=5)
    # STGAN & STU
    parser.add_argument('--attr_mode',
                        type=str,
                        default='diff',
                        choices=['diff', 'target'])
    parser.add_argument('--use_stu', type=ast.literal_eval, default=True)
    parser.add_argument('--stu_dim', type=int, default=64)
    parser.add_argument('--stu_kernel_size', type=int, default=3)
    parser.add_argument('--stu_norm',
                        type=str,
                        default='bn',
                        choices=['bn', 'in'])
    parser.add_argument(
        '--stu_state',
        type=str,
        default='stu',
        choices=['stu', 'gru', 'direct'],
        help=
        'gru: gru arch.; stu: stu arch.; direct: directly pass the inner state to the outer layer'
    )
    parser.add_argument(
        '--multi_inputs',
        type=int,
        default=1,
        help='# of hierarchical inputs (in the first several encoder layers)')
    parser.add_argument(
        '--one_more_conv',
        type=int,
        default=1,
        choices=[0, 1, 3],
        help='0: no further conv after the decoder; 1: conv(k=1); 3: conv(k=3)'
    )
    return parser

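
# Note: the indices in att_dict follow the column order of CelebA's
# list_attr_celeba.txt, so attribute names passed via --attrs map directly
# to annotation columns.
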
def add_train_parameters(parser):
    """ add train parameters """
    parser.add_argument('--mode',
                        default='wgan',
                        choices=['wgan', 'lsgan', 'dcgan'])
    parser.add_argument('--continue_train',
                        type=ast.literal_eval,
                        default=False,
                        help='Flag of continuing training, default is False')
    parser.add_argument(
        '--continue_iter',
        type=int,
        default=-1,
        help='Continue point of continue training, -1 means latest')
    parser.add_argument('--test_iter',
                        type=int,
                        default=-1,
                        help='Checkpoint of model testing, -1 means latest')
    parser.add_argument('--n_epochs',
                        type=int,
                        default=100,
                        help='# of epochs')
    parser.add_argument('--n_critic',
                        type=int,
                        default=5,
                        help='number of D updates per each G update')
    parser.add_argument('--max_epoch',
                        type=int,
                        default=100,
                        help='# of epochs')
    parser.add_argument('--init_epoch',
                        type=int,
                        default=50,
                        help='# of epochs with init lr.')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument("--beta1",
                        type=float,
                        default=0.5,
                        help="Adam beta1, default is 0.5.")
    parser.add_argument("--beta2",
                        type=float,
                        default=0.999,
                        help="Adam beta2, default is 0.999.")
    parser.add_argument("--lambda_gp",
                        type=int,
                        default=10,
                        help="Lambda gp, default is 10")
    parser.add_argument("--lambda1",
                        type=int,
                        default=1,
                        help="Lambda1, default is 1")
    parser.add_argument("--lambda2",
                        type=int,
                        default=10,
                        help="Lambda2, default is 10")
    parser.add_argument("--lambda3",
                        type=int,
                        default=100,
                        help="Lambda3, default is 100")
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='learning rate')
    parser.add_argument('--thres_int', type=float, default=0.5)
    parser.add_argument('--test_int', type=float, default=1.0)
    parser.add_argument('--n_sample',
                        type=int,
                        default=64,
                        help='# of sample images')
    parser.add_argument('--print_freq',
                        type=int,
                        default=1,
                        help='print log freq (per critic), default is 1')
    parser.add_argument(
        '--save_freq',
        type=int,
        default=5000,
        help='save the model every save_freq iters, 0 means every epoch.')
    parser.add_argument(
        '--sample_freq',
        type=int,
        default=1000,
        help=
        'eval on validation set every sample_freq iters, 0 means to evaluate every epoch.'
    )
    return parser

def get_args(phase):
    """get args"""
    parser = argparse.ArgumentParser(description="STGAN")
    # basic parameters
    parser = add_basic_parameters(parser)

    # model parameters
    parser = add_model_parameters(parser)

    # training parameters
    parser = add_train_parameters(parser)

    # others
    parser.add_argument('--use_cropped_img', action='store_true')
    default_experiment_name = datetime.datetime.now().strftime(
        "%Y.%m.%d-%H%M%S")
    parser.add_argument('--experiment_name', default=default_experiment_name)
    parser.add_argument('--num_ckpt', type=int, default=1)
    parser.add_argument('--clear', default=False, action='store_true')
    parser.add_argument('--save_graphs', type=ast.literal_eval, default=False,
                        help='whether save graphs, default is False.')
    parser.add_argument('--outputs_dir', type=str, default='./outputs',
                        help='models are saved here, default is ./outputs.')
    parser.add_argument("--dataroot", type=str, default='./dataset')
    parser.add_argument('--file_format', type=str,
                        choices=['AIR', 'ONNX', 'MINDIR'], default='AIR',
                        help='file format')
    parser.add_argument('--file_name', type=str, default='STGAN',
                        help='output file name prefix.')
    parser.add_argument('--ckpt_path', default=None,
                        help='path of checkpoint file.')

    args = parser.parse_args()
    if phase == 'test':
        assert args.experiment_name != default_experiment_name, "--experiment_name should be assigned in test mode"
    if args.continue_train:
        assert args.experiment_name != default_experiment_name, "--experiment_name should be assigned when continuing training"
    if args.device_num > 1 and args.platform != "CPU":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args.platform,
                            save_graphs=args.save_graphs,
                            device_id=int(os.environ["DEVICE_ID"]))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.AUTO_PARALLEL,
            gradients_mean=True,
            device_num=args.device_num)
        init()
        args.rank = int(os.environ["DEVICE_ID"])
    else:
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args.platform,
                            save_graphs=args.save_graphs,
                            device_id=args.device_id)
        args.rank = 0
        args.device_num = 1

    args.n_epochs = min(args.max_epoch, args.n_epochs)
    args.n_epochs_decay = args.max_epoch - args.n_epochs
    if phase == 'train':
        args.isTrain = True
    else:
        args.isTrain = False
    args.phase = phase
    return args
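
# Hedged usage example (flag names come from the parsers above; the dataset
# path and the entry script name are assumptions based on this repository's
# layout):
#
#     python train.py --dataroot ./dataset/celeba --experiment_name stgan_128 \
#         --attrs Bangs Blond_Hair Mustache Young --batch_size 64 --mode wgan
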
@ -0,0 +1,50 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils"""
import os

def mkdirs(paths):
    """create empty directories if they don't exist

    Parameters:
        paths (str list) -- a list of directory paths
    """
    if isinstance(paths, list):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)

def mkdir(path):
    """create a single empty directory if it didn't exist

    Parameters:
        path (str) -- a single directory path
    """
    if not os.path.exists(path):
        os.makedirs(path)

def is_version_satisfied_1_2_0(version):
    """check whether `version` is at least 1.2.0"""
    old_arr = '1.2.0'.split('.')
    new_arr = version.split('.')
    assert len(old_arr) == 3 and len(
        new_arr) == 3, 'version input must be a str like x.x.x'
    for i in range(3):
        # compare numerically, component by component
        if int(new_arr[i]) > int(old_arr[i]):
            return True
        if int(new_arr[i]) < int(old_arr[i]):
            return False
    return True
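
# Quick illustrative checks for the comparison above (the numeric comparison
# is what makes multi-digit components such as '1.10.x' work):
#
#     assert is_version_satisfied_1_2_0('1.2.0')
#     assert is_version_satisfied_1_2_0('1.10.1')
#     assert not is_version_satisfied_1_2_0('1.1.9')
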
@ -0,0 +1,78 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" STGAN TRAIN"""
import tqdm

from mindspore.common import set_seed

from src.models import STGANModel
from src.utils import get_args
from src.dataset import CelebADataLoader

set_seed(1)

def train():
    """Train Function"""
    args = get_args("train")
    print(args)

    print('\n\n=============== start training ===============\n\n')

    # Get DataLoader
    data_loader = CelebADataLoader(args.dataroot,
                                   mode=args.phase,
                                   selected_attrs=args.attrs,
                                   batch_size=args.batch_size,
                                   image_size=args.image_size,
                                   device_num=args.device_num)
    iter_per_epoch = len(data_loader)
    args.dataset_size = iter_per_epoch

    # Get STGAN MODEL
    model = STGANModel(args)
    it_count = 0

    for _ in tqdm.trange(args.n_epochs, desc='Epoch Loop'):
        for _ in tqdm.trange(iter_per_epoch, desc='Inner Epoch Loop'):
            # skip iterations already covered by a restored checkpoint
            if model.current_iteration > it_count:
                it_count += 1
                continue
            try:
                # training model
                data = next(data_loader.train_loader)
                model.set_input(data)
                model.optimize_parameters()

                # saving model
                if (it_count + 1) % args.save_freq == 0:
                    model.save_networks()

                # sampling
                if (it_count + 1) % args.sample_freq == 0:
                    model.eval(data_loader)

            except KeyboardInterrupt:
                print('You have entered CTRL+C. Waiting to finalize.')
                model.save_networks()

            it_count += 1
            model.current_iteration = it_count

    model.save_networks()
    print('\n\n=============== finish training ===============\n\n')


if __name__ == '__main__':
    train()